@@ -1,26 +1,42 @@
 import logging
 
 from torch import nn
-from torch.nn.utils.rnn import pack_padded_sequence
+from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
 
 from chebai.models.base import ChebaiBaseNet
 
 logging.getLogger("pysmiles").setLevel(logging.CRITICAL)
 
 
 class ChemLSTM(ChebaiBaseNet):
-    def __init__(self, out_d, in_d, num_classes, criterion: nn.Module = None, num_layers=6, dropout=0.2, **kwargs):
+    def __init__(
+        self,
+        out_d,
+        in_d,
+        num_classes,
+        criterion: nn.Module = None,
+        num_layers=6,
+        dropout=0.2,
+        **kwargs,
+    ):
         super().__init__(
             out_dim=out_d,
             input_dim=in_d,
             criterion=criterion,
             num_classes=num_classes,
             **kwargs,
         )
-        self.lstm = nn.LSTM(in_d, out_d, batch_first=True, dropout=dropout, bidirectional=True, num_layers=num_layers)
+        self.lstm = nn.LSTM(
+            in_d,
+            out_d,
+            batch_first=True,
+            dropout=dropout,
+            bidirectional=True,
+            num_layers=num_layers,
+        )
         self.embedding = nn.Embedding(1400, in_d)
         self.output = nn.Sequential(
-            nn.Linear(out_d, out_d),
+            nn.Linear(out_d * 2, out_d),
             nn.ReLU(),
             nn.Dropout(0.2),
             nn.Linear(out_d, num_classes),
@@ -31,7 +47,9 @@ def forward(self, data, *args, **kwargs):
         x_lens = data["model_kwargs"]["lens"]
         x = self.embedding(x)
         x = pack_padded_sequence(x, x_lens, batch_first=True, enforce_sorted=False)
-        x = self.lstm(x)[1][0]
-        # = pad_packed_sequence(x, batch_first=True)[0]
+        x = self.lstm(x)[0]
+        x = pad_packed_sequence(x, batch_first=True)[0][
+            :, 0
+        ]  # reduce sequence dimension to first element
         x = self.output(x)
-        return x.squeeze(0)
+        return x
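For reference, here is a quick shape check of the revised forward path, written as a standalone sketch rather than as part of the commit: a bidirectional LSTM emits 2 * hidden_size features per time step, which is why the first Linear in the output head now takes out_d * 2, and pad_packed_sequence recovers the padded (batch, seq, 2 * out_d) output from which only the first time step is kept. The sizes, the names head/tokens/lens, and the dummy batch below are made up for illustration.

# Standalone sketch (not part of the commit): verify the tensor shapes that
# motivate the out_d * 2 input width and the pad_packed_sequence call.
# All sizes and the dummy batch are illustrative.
import torch
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

in_d, out_d, num_classes = 8, 16, 5
embedding = nn.Embedding(1400, in_d)
lstm = nn.LSTM(in_d, out_d, batch_first=True, bidirectional=True, num_layers=2)
head = nn.Sequential(
    nn.Linear(out_d * 2, out_d),  # bidirectional outputs are 2 * hidden_size wide
    nn.ReLU(),
    nn.Linear(out_d, num_classes),
)

tokens = torch.randint(0, 1400, (3, 10))  # batch of 3 padded token sequences
lens = torch.tensor([10, 7, 4])           # true (unpadded) lengths

x = embedding(tokens)
x = pack_padded_sequence(x, lens, batch_first=True, enforce_sorted=False)
out = lstm(x)[0]                                            # packed per-step outputs
out = pad_packed_sequence(out, batch_first=True)[0][:, 0]   # keep first time step
print(out.shape)        # torch.Size([3, 32]), i.e. (batch, 2 * out_d)
print(head(out).shape)  # torch.Size([3, 5]),  i.e. (batch, num_classes)

With this change the head already produces a (batch, num_classes) tensor, so the previous squeeze(0) on the return value is no longer needed.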