Commit 0e6afe2

add num_layers and dropout parameters, make lstm bidirectional
1 parent 4288689 commit 0e6afe2

File tree: 1 file changed (+2 -2 lines)


chebai/models/lstm.py

Lines changed: 2 additions & 2 deletions
@@ -9,15 +9,15 @@
 class ChemLSTM(ChebaiBaseNet):
-    def __init__(self, out_d, in_d, num_classes, criterion: nn.Module = None, **kwargs):
+    def __init__(self, out_d, in_d, num_classes, criterion: nn.Module = None, num_layers=6, dropout=0.2, **kwargs):
         super().__init__(
             out_dim=out_d,
             input_dim=in_d,
             criterion=criterion,
             num_classes=num_classes,
             **kwargs,
         )
-        self.lstm = nn.LSTM(in_d, out_d, batch_first=True, dropout=0.2)
+        self.lstm = nn.LSTM(in_d, out_d, batch_first=True, dropout=dropout, bidirectional=True, num_layers=num_layers)
         self.embedding = nn.Embedding(1400, in_d)
         self.output = nn.Sequential(
             nn.Linear(out_d, out_d),
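
For context, a minimal standalone sketch (illustrative values, not from the repo) of what the new nn.LSTM arguments change. In PyTorch, the dropout argument of nn.LSTM is applied only between stacked layers, so it is a no-op until num_layers > 1; and bidirectional=True concatenates the forward and backward passes, doubling the per-step output width to 2 * out_d.

    import torch
    import torch.nn as nn

    # Illustrative values only; in_d/out_d mirror the constructor arguments above.
    in_d, out_d, num_layers, dropout = 64, 128, 6, 0.2

    lstm = nn.LSTM(
        in_d, out_d,
        batch_first=True,
        dropout=dropout,        # applied between stacked layers (no-op when num_layers=1)
        bidirectional=True,     # run the sequence forward and backward
        num_layers=num_layers,
    )

    x = torch.randn(8, 50, in_d)    # (batch, seq_len, in_d)
    out, (h_n, c_n) = lstm(x)

    print(out.shape)   # torch.Size([8, 50, 256]) -- 2 * out_d, both directions concatenated
    print(h_n.shape)   # torch.Size([12, 8, 128]) -- num_layers * 2 directions

One consequence: since the LSTM output is now 2 * out_d wide, any layer consuming it that expects out_d features (such as the nn.Linear(out_d, out_d) at the start of self.output in the hunk above) would presumably need its input width adjusted to match.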
