Skip to content

Commit 66ea248

Browse files
committed
Why is in_length needed if n_atom_properties is available?
1 parent dc4d3a7 commit 66ea248

File tree

2 files changed

+3
-4
lines changed

2 files changed

+3
-4
lines changed

chebai_graph/models/_gat.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ class GATModelWrapper(GraphBaseNet):
1313
def __init__(self, config: dict, **kwargs):
1414
super().__init__(**kwargs)
1515

16-
self._in_length = int(config.pop("in_length"))
1716
self._hidden_length = int(config.pop("hidden_length"))
1817
self._dropout_rate = float(config.pop("dropout_rate", 0.1))
1918
self._n_conv_layers = int(config.pop("n_conv_layers", 3))
@@ -22,10 +21,11 @@ def __init__(self, config: dict, **kwargs):
2221
self._n_bond_properties = int(config.pop("n_bond_properties", 0))
2322
self._n_molecule_properties = int(config.pop("n_molecule_properties", 0))
2423
self._gat = GAT(
25-
in_channels=self._in_length,
24+
in_channels=self._n_atom_properties,
2625
hidden_channels=self._hidden_length,
2726
num_layers=self._n_conv_layers,
2827
dropout=self._dropout_rate,
28+
edge_dim=self._n_bond_properties,
2929
**config,
3030
)
3131

configs/model/gat.yml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,10 @@ init_args:
33
optimizer_kwargs:
44
lr: 1e-3
55
config:
6-
in_length: 256
76
hidden_length: 512
87
dropout_rate: 0.1
98
n_conv_layers: 3
10-
heads: 4 # the number of output channels (hidden channels if output channels are not given) should be divisible by the number of heads
9+
heads: 8 # the number of output channels (hidden channels if output channels are not given) should be divisible by the number of heads
1110
# v2: True # -- to use `torch_geometric.nn.conv.GATv2Conv` convolution layers, default is GATConv
1211
n_linear_layers: 3
1312
n_atom_properties: 158

0 commit comments

Comments
 (0)