Skip to content

Commit dfdd810

Browse files
committed
use pyg model impl without changing the architecture
- #12
1 parent e4e5397 commit dfdd810

File tree

6 files changed

+81
-172
lines changed

6 files changed

+81
-172
lines changed

chebai_graph/models/base.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -49,20 +49,22 @@ def __init__(self, config: dict, **kwargs) -> None:
4949
5050
Args:
5151
config (dict): Configuration dictionary with keys:
52-
- 'hidden_length'
53-
- 'dropout_rate'
54-
- 'n_conv_layers'
55-
- 'n_node_properties'
56-
- 'n_bond_properties'
52+
- 'num_layers'
53+
- 'in_channels'
54+
- 'hidden_channels'
55+
- 'out_channels'
56+
- 'edge_dim'
57+
- 'dropout'
5758
**kwargs: Additional keyword arguments for torch.nn.Module.
5859
"""
5960
super().__init__(**kwargs)
60-
self.hidden_channels = int(config["hidden_channels"])
61-
self.out_channels = int(config["out_channels"])
6261
self.num_layers = int(config["num_layers"])
6362
assert self.num_layers > 1, "Need atleast two convolution layers"
64-
self.n_node_properties = int(config["n_node_properties"]) # in_channels
65-
self.n_bond_properties = int(config["n_bond_properties"]) # edge_dim
63+
self.in_channels = int(config["in_channels"]) # number of node/atom properties
64+
self.hidden_channels = int(config["hidden_channels"])
65+
self.out_channels = int(config["out_channels"])
66+
self.edge_dim = int(config["edge_dim"]) # number of bond properties
67+
self.dropout = float(config["dropout"])
6668

6769

6870
class GraphNetWrapper(GraphBaseNet, ABC):

chebai_graph/models/gat.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,12 @@ def __init__(self, config: dict, **kwargs):
3030
self.v2 = bool(config["v2"])
3131
self.activation = ELU() # Instantiate ELU once for reuse.
3232
self.gat = GAT(
33-
in_channels=self.n_atom_properties,
34-
hidden_channels=self.hidden_length,
35-
num_layers=self.n_conv_layers,
36-
dropout=self.dropout_rate,
37-
edge_dim=self.n_bond_properties,
33+
in_channels=self.in_channels,
34+
hidden_channels=self.hidden_channels,
35+
out_channels=self.out_channels,
36+
num_layers=self.num_layers,
37+
dropout=self.dropout,
38+
edge_dim=self.edge_dim,
3839
heads=self.heads,
3940
v2=self.v2,
4041
act=self.activation,

chebai_graph/models/resgated.py

Lines changed: 43 additions & 140 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
1-
from typing import Final, Tuple, Union
1+
from typing import Any, Final
22

3-
import torch
4-
import torch.nn.functional as F
3+
from torch import Tensor
54
from torch.nn import ELU
65
from torch_geometric import nn as tgnn
76
from torch_geometric.data import Data as GraphData
@@ -11,171 +10,77 @@
1110
from .base import GraphModelBase, GraphNetWrapper
1211

1312

14-
class ResGatedGraphConvNetBase(GraphModelBase):
15-
"""
16-
Residual Gated Graph Convolutional Network with edge attributes support.
17-
18-
This model uses a stack of `ResGatedGraphConv` layers from PyTorch Geometric,
19-
allowing edge attributes as part of message passing. A final projection layer maps
20-
to the hidden length specified for downstream graph prediction tasks.
21-
"""
22-
23-
NAME = "ResGatedGraphConvNetBase"
24-
25-
def __init__(self, config: dict, **kwargs):
26-
"""
27-
Initialize the ResGatedGraphConvNetBase.
28-
29-
Args:
30-
config (dict): Configuration dictionary with keys:
31-
- 'hidden_length' (int): Intermediate feature length used in GNN layers.
32-
- Other parameters inherited from GraphModelBase.
33-
**kwargs: Additional keyword arguments passed to GraphModelBase.
34-
"""
35-
super().__init__(config=config, **kwargs)
36-
37-
self.activation = F.elu
38-
self.convs = torch.nn.ModuleList()
39-
self.convs.append(
40-
tgnn.ResGatedGraphConv(
41-
self.n_node_properties,
42-
self.hidden_channels,
43-
# dropout=self.dropout,
44-
edge_dim=self.n_bond_properties,
45-
)
46-
)
47-
48-
for _ in range(self.num_layers - 2):
49-
# Intermediate layers
50-
self.convs.append(
51-
tgnn.ResGatedGraphConv(
52-
self.hidden_channels,
53-
self.hidden_channels,
54-
edge_dim=self.n_bond_properties,
55-
)
56-
)
57-
58-
# Final projection layer to hidden dimension
59-
self.final_conv = tgnn.ResGatedGraphConv(
60-
self.hidden_channels, self.out_channels, edge_dim=self.n_bond_properties
61-
)
62-
63-
def forward(self, batch: dict) -> torch.Tensor:
64-
"""
65-
Forward pass through residual gated GNN layers.
66-
67-
Args:
68-
batch (dict): A batch containing:
69-
- 'features': A list with a `GraphData` instance as the first element.
70-
71-
Returns:
72-
torch.Tensor: Node-level embeddings of shape [num_nodes, hidden_length].
73-
"""
74-
graph_data = batch["features"][0]
75-
assert isinstance(graph_data, GraphData)
76-
77-
x = graph_data.x.float() # Atom features
78-
79-
for conv in self.convs:
80-
assert isinstance(conv, tgnn.ResGatedGraphConv)
81-
x = self.activation(
82-
conv(x, graph_data.edge_index.long(), edge_attr=graph_data.edge_attr)
83-
)
84-
85-
x = self.activation(
86-
self.final_conv(
87-
x, graph_data.edge_index.long(), edge_attr=graph_data.edge_attr
88-
)
89-
)
90-
91-
return x
92-
93-
94-
class ResGatedGraphPred(GraphNetWrapper):
13+
class ResGatedModel(BasicGNN):
9514
"""
96-
Residual Gated GNN for Graph Prediction.
15+
A residual gated GNN model based on PyG's BasicGNN using ResGatedGraphConv layers.
9716
98-
Uses `ResGatedGraphConvNetBase` as the GNN encoder to compute node embeddings.
17+
Attributes:
18+
supports_edge_weight (bool): Indicates edge weights are not supported.
19+
supports_edge_attr (bool): Indicates edge attributes are supported.
20+
supports_norm_batch (bool): Indicates whether the normalization layer accepts a `batch` argument (set by the parent class).
9921
"""
10022

101-
NAME = "ResGatedGraphPred"
23+
supports_edge_weight: Final[bool] = False
24+
supports_edge_attr: Final[bool] = True
25+
supports_norm_batch: Final[bool]
10226

103-
def _get_gnn(self, config: dict) -> ResGatedGraphConvNetBase:
27+
def init_conv(
28+
self, in_channels: int | tuple[int, int], out_channels: int, **kwargs: Any
29+
) -> MessagePassing:
10430
"""
105-
Instantiate the residual gated GNN backbone.
31+
Initializes a ResGatedGraphConv layer.
10632
10733
Args:
108-
config (dict): Model configuration.
34+
in_channels (int or Tuple[int, int]): Number of input channels.
35+
out_channels (int): Number of output channels.
36+
**kwargs: Additional keyword arguments for the convolution layer.
10937
11038
Returns:
111-
ResGatedGraphConvNetBase: The GNN encoder.
39+
MessagePassing: A ResGatedGraphConv layer instance.
11240
"""
113-
return ResGatedGraphConvNetBase(config=config)
114-
115-
116-
class ResGatedModel(BasicGNN):
117-
supports_edge_weight: Final[bool] = False
118-
supports_edge_attr: Final[bool] = True
119-
supports_norm_batch: Final[bool]
120-
121-
def init_conv(
122-
self, in_channels: Union[int, Tuple[int, int]], out_channels: int, **kwargs
123-
) -> MessagePassing:
12441
return tgnn.ResGatedGraphConv(
12542
in_channels,
12643
out_channels,
12744
**kwargs,
12845
)
12946

13047

131-
class ResGatedPyG(GraphModelBase):
48+
class ResGatedGraphConvNetBase(GraphModelBase):
13249
"""
133-
Graph Attention Network (GAT) base module for graph convolution.
50+
Base model class for applying ResGatedGraphConv layers to graph-structured data.
13451
135-
Uses PyTorch Geometric's `GAT` implementation to process atomic node features
136-
and bond edge attributes through multiple attention heads and layers.
52+
Args:
53+
config (dict): Configuration dictionary containing model hyperparameters.
54+
**kwargs: Additional keyword arguments for parent class.
13755
"""
13856

139-
def __init__(self, config: dict, **kwargs):
140-
"""
141-
Initialize the GATGraphConvNetBase.
142-
143-
Args:
144-
config (dict): Model configuration containing:
145-
- 'heads' (int): Number of attention heads.
146-
- 'v2' (bool): Whether to use the GATv2 variant.
147-
- Other required GraphModelBase parameters.
148-
**kwargs: Additional arguments for the base class.
149-
"""
57+
def __init__(self, config: dict[str, Any], **kwargs: Any):
15058
super().__init__(config=config, **kwargs)
15159
self.activation = ELU() # Instantiate ELU once for reuse.
152-
self.gat = ResGatedModel(
153-
in_channels=self.n_node_properties,
60+
61+
self.resgated: BasicGNN = ResGatedModel(
62+
in_channels=self.in_channels,
15463
hidden_channels=self.hidden_channels,
15564
out_channels=self.out_channels,
15665
num_layers=self.num_layers,
157-
edge_dim=self.n_bond_properties,
66+
edge_dim=self.edge_dim,
15867
act=self.activation,
15968
)
16069

161-
def forward(self, batch: dict) -> torch.Tensor:
70+
def forward(self, batch: dict[str, Any]) -> Tensor:
16271
"""
163-
Forward pass through the GAT network.
164-
165-
Processes atomic node features and edge attributes, and applies
166-
an ELU activation to the output.
72+
Forward pass of the model.
16773
16874
Args:
169-
batch (dict): Input batch containing:
170-
- 'features': A list with a `GraphData` object as its first element.
75+
batch (dict): A batch containing graph input features under the key "features".
17176
17277
Returns:
173-
torch.Tensor: Node embeddings after GAT and activation.
78+
Tensor: The output node-level embeddings after the final activation.
17479
"""
17580
graph_data = batch["features"][0]
176-
assert isinstance(graph_data, GraphData)
81+
assert isinstance(graph_data, GraphData), "Expected GraphData instance"
17782

178-
out = self.gat(
83+
out = self.resgated(
17984
x=graph_data.x.float(),
18085
edge_index=graph_data.edge_index.long(),
18186
edge_attr=graph_data.edge_attr,
@@ -184,23 +89,21 @@ def forward(self, batch: dict) -> torch.Tensor:
18489
return self.activation(out)
18590

18691

187-
class ResGatedGraphPredPyG(GraphNetWrapper):
92+
class ResGatedGraphPred(GraphNetWrapper):
18893
"""
189-
Residual Gated GNN for Graph Prediction.
94+
Wrapper for graph-level prediction using ResGatedGraphConvNetBase.
19095
191-
Uses `ResGatedGraphConvNetBase` as the GNN encoder to compute node embeddings.
96+
This class instantiates the core GNN model using the provided config.
19297
"""
19398

194-
NAME = "ResGatedGraphPred"
195-
196-
def _get_gnn(self, config: dict) -> ResGatedPyG:
99+
def _get_gnn(self, config: dict[str, Any]) -> ResGatedGraphConvNetBase:
197100
"""
198-
Instantiate the residual gated GNN backbone.
101+
Returns the core ResGated GNN model.
199102
200103
Args:
201-
config (dict): Model configuration.
104+
config (dict): Configuration dictionary for the GNN model.
202105
203106
Returns:
204-
ResGatedGraphConvNetBase: The GNN encoder.
107+
ResGatedGraphConvNetBase: The core graph convolutional network.
205108
"""
206-
return ResGatedPyG(config=config)
109+
return ResGatedGraphConvNetBase(config=config)

chebai_graph/preprocessing/datasets/chebi.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -267,8 +267,9 @@ def load_processed_data_from_file(self, filename: str) -> list[dict]:
267267
]
268268
rank_zero_info(
269269
f"Finished loading dataset from properties.\nEncoding lengths: {prop_lengths}\n"
270-
f"Use n_atom_properties: {sum(p.encoder.get_encoding_length() for p in self.properties if isinstance(p, AtomProperty))}, "
271-
f"n_bond_properties: {sum(p.encoder.get_encoding_length() for p in self.properties if isinstance(p, BondProperty))}, "
270+
f"Use following values for given parameters for model configuration: \n\t"
271+
f"in_channels: {sum(p.encoder.get_encoding_length() for p in self.properties if isinstance(p, AtomProperty))}, "
272+
f"edge_dim: {sum(p.encoder.get_encoding_length() for p in self.properties if isinstance(p, BondProperty))}, "
272273
f"n_molecule_properties: {sum(p.encoder.get_encoding_length() for p in self.properties if isinstance(p, MoleculeProperty))}"
273274
)
274275

@@ -337,7 +338,7 @@ def load_processed_data_from_file(self, filename: str) -> list[dict]:
337338
else:
338339
raise TypeError(f"Unsupported property type: {type(prop).__name__}")
339340

340-
n_atom_properties = max(
341+
n_node_properties = max(
341342
n_atom_node_properties, n_fg_node_properties, n_graph_node_properties
342343
)
343344
rank_zero_info(
@@ -347,7 +348,8 @@ def load_processed_data_from_file(self, filename: str) -> list[dict]:
347348
f"n_fg_node_properties: {n_fg_node_properties}, "
348349
f"n_bond_properties: {n_bond_properties}, "
349350
f"n_graph_node_properties: {n_graph_node_properties}\n"
350-
f"Use n_atom_properties: {n_atom_properties}, n_bond_properties: {n_bond_properties}, n_molecule_properties: 0"
351+
f"Use following values for given parameters for model configuration: \n\t"
352+
f"in_channels: {n_node_properties}, edge_dim: {n_bond_properties}, n_molecule_properties: 0"
351353
)
352354

353355
for property in self.properties:
@@ -368,7 +370,7 @@ def load_processed_data_from_file(self, filename: str) -> list[dict]:
368370
base_df["features"] = base_df.apply(
369371
lambda row: self._merge_props_into_base(
370372
row,
371-
max_len_node_properties=n_atom_properties,
373+
max_len_node_properties=n_node_properties,
372374
),
373375
axis=1,
374376
)

configs/model/gat.yml

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@ init_args:
33
optimizer_kwargs:
44
lr: 1e-3
55
config:
6-
hidden_length: 512
7-
dropout_rate: 0
8-
n_conv_layers: 3
6+
in_channels: 158 # number of node/atom properties
7+
hidden_channels: 256
8+
out_channels: 512
9+
num_layers: 5
10+
edge_dim: 7 # number of bond properties
911
heads: 8 # the output channels (hidden channels if output channels are not given) should be divisible by the number of heads
1012
v2: False # set True to use `torch_geometric.nn.conv.GATv2Conv` convolution layers, default is GATConv
11-
n_atom_properties: 158
12-
n_bond_properties: 7
13-
14-
n_molecule_properties: 200
15-
n_linear_layers: 3
13+
dropout: 0
14+
n_molecule_properties: 0
15+
n_linear_layers: 2

configs/model/resgated.yml

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,11 @@ init_args:
33
optimizer_kwargs:
44
lr: 1e-3
55
config:
6-
n_node_properties: 68 # in_channels
7-
hidden_channels : 256
8-
out_channels : 512
9-
num_layers : 4
10-
n_bond_properties: 4 # edge_dim
6+
in_channels: 158 # number of node/atom properties
7+
hidden_channels: 256
8+
out_channels: 512
9+
num_layers: 5
10+
edge_dim: 7 # number of bond properties
11+
dropout: 0
1112
n_molecule_properties: 0
1213
n_linear_layers: 2

0 commit comments

Comments
 (0)