Skip to content

Commit 940ce9d

Browse files
committed
multi-layer lstm
1 parent 0e6afe2 commit 940ce9d

File tree

1 file changed

+25
-7
lines changed

1 file changed

+25
-7
lines changed

chebai/models/lstm.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,42 @@
11
import logging
22

33
from torch import nn
4-
from torch.nn.utils.rnn import pack_padded_sequence
4+
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
55

66
from chebai.models.base import ChebaiBaseNet
77

88
logging.getLogger("pysmiles").setLevel(logging.CRITICAL)
99

1010

1111
class ChemLSTM(ChebaiBaseNet):
12-
def __init__(self, out_d, in_d, num_classes, criterion: nn.Module = None, num_layers=6, dropout=0.2, **kwargs):
12+
def __init__(
13+
self,
14+
out_d,
15+
in_d,
16+
num_classes,
17+
criterion: nn.Module = None,
18+
num_layers=6,
19+
dropout=0.2,
20+
**kwargs,
21+
):
1322
super().__init__(
1423
out_dim=out_d,
1524
input_dim=in_d,
1625
criterion=criterion,
1726
num_classes=num_classes,
1827
**kwargs,
1928
)
20-
self.lstm = nn.LSTM(in_d, out_d, batch_first=True, dropout=dropout, bidirectional=True, num_layers=num_layers)
29+
self.lstm = nn.LSTM(
30+
in_d,
31+
out_d,
32+
batch_first=True,
33+
dropout=dropout,
34+
bidirectional=True,
35+
num_layers=num_layers,
36+
)
2137
self.embedding = nn.Embedding(1400, in_d)
2238
self.output = nn.Sequential(
23-
nn.Linear(out_d, out_d),
39+
nn.Linear(out_d * 2, out_d),
2440
nn.ReLU(),
2541
nn.Dropout(0.2),
2642
nn.Linear(out_d, num_classes),
@@ -31,7 +47,9 @@ def forward(self, data, *args, **kwargs):
3147
x_lens = data["model_kwargs"]["lens"]
3248
x = self.embedding(x)
3349
x = pack_padded_sequence(x, x_lens, batch_first=True, enforce_sorted=False)
34-
x = self.lstm(x)[1][0]
35-
# = pad_packed_sequence(x, batch_first=True)[0]
50+
x = self.lstm(x)[0]
51+
x = pad_packed_sequence(x, batch_first=True)[0][
52+
:, 0
53+
] # reduce sequence dimension to first element
3654
x = self.output(x)
37-
return x.squeeze(0)
55+
return x

0 commit comments

Comments
 (0)