11import logging
22import typing
3+ from abc import ABC
34
45import torch
56import torch .nn .functional as F
1516logging .getLogger ("pysmiles" ).setLevel (logging .CRITICAL )
1617
1718
18- class GraphBaseNet (ChebaiBaseNet ):
19+ class GraphBaseNet (ChebaiBaseNet , ABC ):
1920 def _get_prediction_and_labels (self , data , labels , output ):
2021 return torch .sigmoid (output ), labels .int ()
2122
@@ -104,26 +105,26 @@ def __init__(self, config: typing.Dict, **kwargs):
104105
105106 self .activation = F .elu
106107 self .dropout = nn .Dropout (self .dropout_rate )
107-
108108 self .convs = torch .nn .ModuleList ([])
109- for i in range (self .n_conv_layers ):
110- if i == 0 :
111- self .convs .append (
112- tgnn .ResGatedGraphConv (
113- self .n_atom_properties ,
114- self .in_length ,
115- # dropout=self.dropout_rate,
116- edge_dim = self .n_bond_properties ,
117- )
118- )
109+
110+ self .convs .append (
111+ tgnn .ResGatedGraphConv (
112+ self .n_atom_properties ,
113+ self .hidden_length ,
114+ # dropout=self.dropout_rate,
115+ edge_dim = self .n_bond_properties ,
116+ )
117+ )
118+
119+ for _ in range (self .n_conv_layers - 1 ):
119120 self .convs .append (
120121 tgnn .ResGatedGraphConv (
121- self .in_length , self .in_length , edge_dim = self .n_bond_properties
122+ self .hidden_length ,
123+ self .hidden_length ,
124+ # dropout=self.dropout_rate,
125+ edge_dim = self .n_bond_properties ,
122126 )
123127 )
124- self .final_conv = tgnn .ResGatedGraphConv (
125- self .in_length , self .hidden_length , edge_dim = self .n_bond_properties
126- )
127128
128129 def forward (self , batch ):
129130 graph_data = batch ["features" ][0 ]
@@ -136,11 +137,6 @@ def forward(self, batch):
136137 a = self .activation (
137138 conv (a , graph_data .edge_index .long (), edge_attr = graph_data .edge_attr )
138139 )
139- a = self .activation (
140- self .final_conv (
141- a , graph_data .edge_index .long (), edge_attr = graph_data .edge_attr
142- )
143- )
144140 return a
145141
146142
0 commit comments