Skip to content

Commit a23306d

Browse files
Create NN.lstm.pyx
1 parent a825b8f commit a23306d

File tree

1 file changed

+28
-0
lines changed

1 file changed

+28
-0
lines changed

Modellib/NN.lstm.pyx

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import torch
2+
import torch.nn as nn
3+
4+
class AWDLSTMModel(nn.Module):
    """Minimal AWD-LSTM-style language model.

    Pipeline: token embedding -> stacked LSTM -> linear decoder over the
    vocabulary. The forward pass also returns the full sequence of
    final-layer hidden states, which the AR/TAR regularization terms of
    the AWD-LSTM training scheme need.

    Args:
        vocab_size: number of tokens in the vocabulary.
        emb_dim: embedding dimension.
        hidden_dim: LSTM hidden state size.
        n_layers: number of stacked LSTM layers.
        dropout: inter-layer LSTM dropout probability. Defaults to 0.0,
            which preserves the original (dropout-free) behavior; per
            ``nn.LSTM`` semantics it only takes effect when n_layers > 1.
    """

    def __init__(self, vocab_size, emb_dim, hidden_dim, n_layers, dropout=0.0):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_dim)
        # num_layers stacks multiple LSTMs automatically; 'dropout' is
        # applied between stacked layers (not after the last one).
        self.lstm = nn.LSTM(
            emb_dim, hidden_dim, num_layers=n_layers, dropout=dropout
        )
        self.linear = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, hidden=None):
        """Run one forward pass.

        Args:
            x: LongTensor of token indices, shape (seq_len, batch).
            hidden: optional (h_0, c_0) initial state; zeros when None.

        Returns:
            decoded: vocabulary logits, shape (seq_len, batch, vocab_size).
            output: final-layer hidden state for every timestep, shape
                (seq_len, batch, hidden_dim).
            (h_n, c_n): last-timestep state for every layer in the stack.
        """
        emb = self.embedding(x)
        # 'output' contains the hidden state h_t for EVERY timestep,
        # but only for the LAST layer in the stack.
        # 'h_n' contains the LAST timestep hidden state for EVERY layer.
        output, (h_n, c_n) = self.lstm(emb, hidden)
        # We also return 'output' because AR and TAR regularization
        # require the full sequence of hidden states from the final layer.
        decoded = self.linear(output)
        return decoded, output, (h_n, c_n)
25+
26+
# Example usage:
27+
# model_output, final_layer_h_seq, _ = model(input_data)
28+
# loss = awd_loss(model_output, targets, final_layer_h_seq)

0 commit comments

Comments
 (0)