Skip to content

CUDA out of memory #18

@ybu-lxd

Description

@ybu-lxd
class KanMLP(nn.Module):
    """MLP block built from KAN layers instead of plain linear layers.

    Structure: KAN(in -> hidden) -> NewGELU -> KAN(hidden -> out) -> Dropout.

    Args:
        in_features: input feature dimension (default 1152).
        hidden_features: hidden dimension; falls back to ``in_features`` if None.
        out_features: output dimension; falls back to ``in_features`` if None.
        drop: dropout probability applied after the projection.
    """
    def __init__(self,
                 in_features=1152,
                 hidden_features=None,
                 out_features=None,
                 drop=0.
                 ):
        super().__init__()

        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.mlp = nn.ModuleDict(
            dict(
                c_fc=KAN(width=[in_features, hidden_features]),
                c_proj=KAN(width=[hidden_features, out_features]),
                act=NewGELU(),
                # BUG FIX: was hard-coded nn.Dropout(0.0), silently ignoring
                # the `drop` argument. (Also removed an unused `approx_gelu`
                # lambda that was never referenced.)
                dropout=nn.Dropout(drop),
            )
        )

    def forward(self, x):
        # Inlined the former `self.mlpf` lambda: storing a lambda on the
        # instance makes the module unpicklable (breaks torch.save of the
        # whole model) and added no value.
        m = self.mlp
        return m.dropout(m.c_proj(m.act(m.c_fc(x))))

# Reproduction: a KAN-based MLP on a large token count triggers CUDA OOM.
net = KanMLP(1152, 1152 * 4).to("cuda")
x = torch.rand(size=(4, 4096 * 4, 1152)).to("cuda")
# BUG FIX: was `nex(x)` — a NameError typo; the model is named `net`.
y = net(x)

When the number of tokens grows beyond a certain size, the following error occurs:

 CUDA out of memory.

Metadata

Metadata

Assignees

No one assigned

    Labels

    bugSomething isn't workinghelp wantedExtra attention is needed

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions