3D Batch input for the mLSTM #7

@matiashaggman

Hi! Thank you for your great contribution.

I am trying to see whether I can achieve a better result by using the mLSTM architecture for time series classification as opposed to a traditional LSTM.

This is part of my code:


        super().__init__()
        self.num_classes = num_classes
        self.inp_dim = num_features
        self.head_dim = 8
        self.head_num = 4
        self.hid_dim = self.head_num * self.head_dim
        
        self.batch_size = 5
        
        # Create an instance of mLSTM
        self.model = mLSTM(self.inp_dim, self.head_num, self.head_dim)

        self.hid_0 = self.model.init_hidden(self.batch_size)

        self.out = nn.Linear(in_features=self.hid_dim, out_features=self.num_classes)

    def forward(self, features:torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the network.
        
        Args:
          features (Tensor): The input tensor of shape (batch_size, rec_length, n_features).

        Returns:
          Tensor: The output logits of shape (batch_size, rec_length, n_classes).
        """
        # Pass the input through the mLSTM layer.
        x, _ = self.model(features, self.hid_0)
        
        # Pass the output of the mLSTM through the output layer to get class logits for each time step.
        y = self.out(x)
        
        return y

However, when I run the forward pass with, for example, the following tensor (the general shape of my data):

x = torch.rand(seq_len, batch_size, num_features)
y = model(x)

I get the following error:

--> 306 return F.conv1d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)

RuntimeError: Given groups=1, weight of size [1, 1, 4], expected input[1199, 5, 10] to have 1 channels, but got 5 channels instead
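
For reference, the message itself can be reproduced in isolation. `nn.Conv1d` expects input of shape (N, C_in, L), so in a tensor of shape [1199, 5, 10] the batch size 5 sits on the channel axis. The sketch below uses a stand-in `nn.Conv1d` with the same weight shape as in the traceback; it is not the layer from the mLSTM code, and `conv_error` is a helper made up for this illustration.

```python
import torch
import torch.nn as nn

# Stand-in for the layer in the traceback: weight shape [1, 1, 4] means
# out_channels=1, in_channels=1, kernel_size=4. Not the repository's layer.
conv = nn.Conv1d(in_channels=1, out_channels=1, kernel_size=4)

# (seq_len, batch_size, num_features), the shape reported in the error.
x = torch.rand(1199, 5, 10)


def conv_error(inp):
    """Return the RuntimeError message raised by conv(inp), or None on success."""
    try:
        conv(inp)
    except RuntimeError as err:
        return str(err)
    return None


# Conv1d reads dim 1 as channels, so the batch size 5 is taken as 5 channels.
message = conv_error(x)
print(message)

# One time step reshaped to (batch, 1, features) has the single channel
# that the layer expects.
step = x[0].unsqueeze(1)               # shape (5, 1, 10)
out_shape = tuple(conv(step).shape)    # no padding: length 10 - 4 + 1 = 7
print(out_shape)
```

This suggests, though it does not prove, that the layer is meant to receive one time step of shape (batch_size, num_features) at a time.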

Is it possible to run the model with 3D input? Thanks in advance.
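
One possible workaround, assuming the mLSTM here processes a single time step of shape (batch_size, num_features) per call and returns an (output, hidden) pair, is to loop over the time axis and stack the per-step outputs. The sketch below is not the repository's API: `StepCell` is a hypothetical stand-in built on `nn.GRUCell` that only mimics that assumed interface, and `SeqClassifier` is an illustrative name.

```python
import torch
import torch.nn as nn


class StepCell(nn.Module):
    """Hypothetical stand-in for a cell that handles ONE time step per call.

    This is not the repository's mLSTM. It only mimics the assumed interface
    forward(x_t, hid) -> (out_t, hid), with x_t of shape (batch, inp_dim).
    """

    def __init__(self, inp_dim, hid_dim):
        super().__init__()
        self.hid_dim = hid_dim
        self.cell = nn.GRUCell(inp_dim, hid_dim)

    def init_hidden(self, batch_size):
        return torch.zeros(batch_size, self.hid_dim)

    def forward(self, x_t, hid):
        hid = self.cell(x_t, hid)
        return hid, hid


class SeqClassifier(nn.Module):
    def __init__(self, num_features, num_classes, hid_dim=32):
        super().__init__()
        self.cell = StepCell(num_features, hid_dim)
        self.out = nn.Linear(hid_dim, num_classes)

    def forward(self, features):
        # features: (seq_len, batch_size, num_features)
        # Build the initial state from the actual batch size of this input.
        hid = self.cell.init_hidden(features.shape[1])
        outputs = []
        for x_t in features:  # iterating a tensor walks dim 0 (time steps)
            out_t, hid = self.cell(x_t, hid)
            outputs.append(out_t)
        x = torch.stack(outputs, dim=0)  # (seq_len, batch_size, hid_dim)
        return self.out(x)               # (seq_len, batch_size, num_classes)


model = SeqClassifier(num_features=10, num_classes=3)
logits = model(torch.rand(20, 5, 10))
print(logits.shape)
```

Creating the hidden state inside `forward` rather than in `__init__` also avoids tying the module to one fixed batch size.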
