Commit 51e609e

Why are input channels needed when n_atom_properties is there?
1 parent 0a39749 commit 51e609e

File tree

1 file changed: +17 -21 lines changed


chebai_graph/models/graph.py

Lines changed: 17 additions & 21 deletions
@@ -1,5 +1,6 @@
 import logging
 import typing
+from abc import ABC

 import torch
 import torch.nn.functional as F
@@ -15,7 +16,7 @@
 logging.getLogger("pysmiles").setLevel(logging.CRITICAL)


-class GraphBaseNet(ChebaiBaseNet):
+class GraphBaseNet(ChebaiBaseNet, ABC):
     def _get_prediction_and_labels(self, data, labels, output):
         return torch.sigmoid(output), labels.int()

@@ -104,26 +105,26 @@ def __init__(self, config: typing.Dict, **kwargs):

         self.activation = F.elu
         self.dropout = nn.Dropout(self.dropout_rate)
-
         self.convs = torch.nn.ModuleList([])
-        for i in range(self.n_conv_layers):
-            if i == 0:
-                self.convs.append(
-                    tgnn.ResGatedGraphConv(
-                        self.n_atom_properties,
-                        self.in_length,
-                        # dropout=self.dropout_rate,
-                        edge_dim=self.n_bond_properties,
-                    )
-                )
+
+        self.convs.append(
+            tgnn.ResGatedGraphConv(
+                self.n_atom_properties,
+                self.hidden_length,
+                # dropout=self.dropout_rate,
+                edge_dim=self.n_bond_properties,
+            )
+        )
+
+        for _ in range(self.n_conv_layers - 1):
             self.convs.append(
                 tgnn.ResGatedGraphConv(
-                    self.in_length, self.in_length, edge_dim=self.n_bond_properties
+                    self.hidden_length,
+                    self.hidden_length,
+                    # dropout=self.dropout_rate,
+                    edge_dim=self.n_bond_properties,
                 )
             )
-        self.final_conv = tgnn.ResGatedGraphConv(
-            self.in_length, self.hidden_length, edge_dim=self.n_bond_properties
-        )

     def forward(self, batch):
         graph_data = batch["features"][0]
@@ -136,11 +137,6 @@ def forward(self, batch):
             a = self.activation(
                 conv(a, graph_data.edge_index.long(), edge_attr=graph_data.edge_attr)
             )
-        a = self.activation(
-            self.final_conv(
-                a, graph_data.edge_index.long(), edge_attr=graph_data.edge_attr
-            )
-        )
         return a

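The refactor answers the question in the commit message: the first ResGatedGraphConv already takes n_atom_properties input channels, so a separate in_length width for the stack, and the extra final_conv projection onto hidden_length, were redundant. The sketch below shows the resulting layer stack in isolation, assuming tgnn is torch_geometric.nn (as the ResGatedGraphConv/edge_dim usage suggests); the class name ConvStackSketch and the concrete sizes are hypothetical, since the real model reads hidden_length, n_conv_layers, and the property counts from its config.

import torch
import torch.nn.functional as F
import torch_geometric.nn as tgnn  # assumed: the diff's `tgnn` alias


class ConvStackSketch(torch.nn.Module):
    # Hypothetical stand-in for the refactored module: the first layer maps
    # n_atom_properties -> hidden_length, every later layer maps
    # hidden_length -> hidden_length, so no trailing final_conv is needed.
    def __init__(self, n_atom_properties=8, n_bond_properties=4,
                 hidden_length=64, n_conv_layers=3):
        super().__init__()
        first = tgnn.ResGatedGraphConv(
            n_atom_properties, hidden_length, edge_dim=n_bond_properties
        )
        rest = [
            tgnn.ResGatedGraphConv(
                hidden_length, hidden_length, edge_dim=n_bond_properties
            )
            for _ in range(n_conv_layers - 1)
        ]
        self.convs = torch.nn.ModuleList([first] + rest)

    def forward(self, x, edge_index, edge_attr):
        # Same loop shape as the committed forward(): every layer, including
        # the last, is applied uniformly, then passed through ELU.
        for conv in self.convs:
            x = F.elu(conv(x, edge_index, edge_attr=edge_attr))
        return x  # [num_nodes, hidden_length]


# Toy smoke test with made-up sizes: 3 atoms, 8 atom properties,
# 2 directed bonds with 4 bond properties each.
x = torch.randn(3, 8)
edge_index = torch.tensor([[0, 1], [1, 2]])
edge_attr = torch.randn(2, 4)
out = ConvStackSketch()(x, edge_index, edge_attr)
assert out.shape == (3, 64)

Because every layer now produces hidden_length features, the output width no longer depends on a separate final projection, which is what allows final_conv to be deleted from both __init__ and forward.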