import torch
import torch.nn as nn

class AWDLSTMModel(nn.Module):
    """Stacked-LSTM language model backbone for AWD-LSTM-style training.

    Besides the logits, the forward pass exposes the full per-timestep
    hidden-state sequence of the final LSTM layer, which the caller needs
    to compute activation regularization (AR) and temporal activation
    regularization (TAR).
    """

    def __init__(self, vocab_size, emb_dim, hidden_dim, n_layers):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_dim)
        # One nn.LSTM with num_layers handles the stacking internally.
        self.lstm = nn.LSTM(emb_dim, hidden_dim, num_layers=n_layers)
        self.linear = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, hidden=None):
        """Run one forward pass.

        Args:
            x: token indices shaped (seq_len, batch).
            hidden: optional (h_0, c_0) state to continue from; defaults
                to zeros inside nn.LSTM when None.

        Returns:
            A tuple ``(logits, seq_states, (h_n, c_n))`` where
            ``seq_states`` is the hidden state h_t for EVERY timestep of
            the LAST layer only — exactly what AR/TAR penalties consume —
            while ``h_n``/``c_n`` hold the LAST timestep's state for
            EVERY layer in the stack.
        """
        embedded = self.embedding(x)
        seq_states, (final_h, final_c) = self.lstm(embedded, hidden)
        # nn.Linear applies over the trailing feature dim, so the full
        # (seq_len, batch, hidden_dim) tensor can be decoded in one call.
        logits = self.linear(seq_states)
        return logits, seq_states, (final_h, final_c)

# Example usage:
# model_output, final_layer_h_seq, _ = model(input_data)
# loss = awd_loss(model_output, targets, final_layer_h_seq)