
Commit f776af1

restructure resgated params - add pyg impl to compare
1 parent 6aa18d6 commit f776af1

3 files changed: +126 -33 lines

chebai_graph/models/base.py

Lines changed: 8 additions & 9 deletions
@@ -52,16 +52,17 @@ def __init__(self, config: dict, **kwargs) -> None:
                 - 'hidden_length'
                 - 'dropout_rate'
                 - 'n_conv_layers'
-                - 'n_atom_properties'
+                - 'n_node_properties'
                 - 'n_bond_properties'
             **kwargs: Additional keyword arguments for torch.nn.Module.
         """
         super().__init__(**kwargs)
-        self.hidden_length = int(config["hidden_length"])
-        self.dropout_rate = float(config["dropout_rate"])
-        self.n_conv_layers = int(config["n_conv_layers"])
-        self.n_atom_properties = int(config["n_atom_properties"])
-        self.n_bond_properties = int(config["n_bond_properties"])
+        self.hidden_channels = int(config["hidden_channels"])
+        self.out_channels = int(config["out_channels"])
+        self.num_layers = int(config["num_layers"])
+        assert self.num_layers > 1, "Need at least two convolution layers"
+        self.n_node_properties = int(config["n_node_properties"])  # in_channels
+        self.n_bond_properties = int(config["n_bond_properties"])  # edge_dim
 
 
 class GraphNetWrapper(GraphBaseNet, ABC):
@@ -83,9 +84,7 @@ def __init__(
         """
         super().__init__(**kwargs)
         self.gnn = self._get_gnn(config)
-        gnn_out_dim = (
-            config["out_dim"] if "out_dim" in config else config["hidden_length"]
-        )
+        gnn_out_dim = int(config["out_channels"])
         self.activation = torch.nn.ELU
         self.lin_input_dim = self._get_lin_seq_input_dim(
             gnn_out_dim=gnn_out_dim,
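
The renaming aligns the config vocabulary with PyTorch Geometric's `BasicGNN` constructor (`in_channels`, `hidden_channels`, `out_channels`, `num_layers`, `edge_dim`), which is what makes the hand-rolled and PyG models in this commit directly comparable. A minimal sketch of a config dict that satisfies the new `__init__` (values taken from `configs/model/resgated.yml` below; the dict itself is illustrative):

# Sketch: the keys GraphModelBase.__init__ now reads (example values).
config = {
    "hidden_channels": 256,    # width of the intermediate conv layers
    "out_channels": 512,       # width of the final conv, i.e. the GNN output
    "num_layers": 4,           # asserted to be > 1
    "n_node_properties": 68,   # node feature size -> PyG in_channels
    "n_bond_properties": 4,    # edge feature size -> PyG edge_dim
}

# GraphNetWrapper now sizes its linear head from out_channels directly,
# replacing the old "out_dim"-with-"hidden_length"-fallback lookup:
gnn_out_dim = int(config["out_channels"])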

chebai_graph/models/resgated.py

Lines changed: 113 additions & 18 deletions
@@ -1,8 +1,12 @@
+from typing import Final, Tuple, Union
+
 import torch
 import torch.nn.functional as F
-from torch import nn
+from torch.nn import ELU
 from torch_geometric import nn as tgnn
 from torch_geometric.data import Data as GraphData
+from torch_geometric.nn.conv import MessagePassing
+from torch_geometric.nn.models.basic_gnn import BasicGNN
 
 from .base import GraphModelBase, GraphNetWrapper
 
@@ -24,38 +28,36 @@ def __init__(self, config: dict, **kwargs):
 
         Args:
             config (dict): Configuration dictionary with keys:
-                - 'in_length' (int): Intermediate feature length used in GNN layers.
+                - 'hidden_channels' (int): Intermediate feature length used in GNN layers.
                 - Other parameters inherited from GraphModelBase.
             **kwargs: Additional keyword arguments passed to GraphModelBase.
         """
         super().__init__(config=config, **kwargs)
-        self.in_length = int(config["in_length"])
 
         self.activation = F.elu
-        self.dropout = nn.Dropout(self.dropout_rate)
-
         self.convs = torch.nn.ModuleList()
-        for i in range(self.n_conv_layers):
-            if i == 0:
-                # Initial layer uses atom features as input
-                self.convs.append(
-                    tgnn.ResGatedGraphConv(
-                        self.n_atom_properties,
-                        self.in_length,
-                        # dropout=self.dropout_rate,
-                        edge_dim=self.n_bond_properties,
-                    )
-                )
+        self.convs.append(
+            tgnn.ResGatedGraphConv(
+                self.n_node_properties,
+                self.hidden_channels,
+                # dropout=self.dropout,
+                edge_dim=self.n_bond_properties,
+            )
+        )
+
+        for _ in range(self.num_layers - 2):
             # Intermediate layers
             self.convs.append(
                 tgnn.ResGatedGraphConv(
-                    self.in_length, self.in_length, edge_dim=self.n_bond_properties
+                    self.hidden_channels,
+                    self.hidden_channels,
+                    edge_dim=self.n_bond_properties,
                 )
             )
 
         # Final projection layer to hidden dimension
         self.final_conv = tgnn.ResGatedGraphConv(
-            self.in_length, self.hidden_length, edge_dim=self.n_bond_properties
+            self.hidden_channels, self.out_channels, edge_dim=self.n_bond_properties
         )
 
     def forward(self, batch: dict) -> torch.Tensor:
@@ -109,3 +111,96 @@ def _get_gnn(self, config: dict) -> ResGatedGraphConvNetBase:
             ResGatedGraphConvNetBase: The GNN encoder.
         """
         return ResGatedGraphConvNetBase(config=config)
+
+
+class ResGatedModel(BasicGNN):
+    supports_edge_weight: Final[bool] = False
+    supports_edge_attr: Final[bool] = True
+    supports_norm_batch: Final[bool]
+
+    def init_conv(
+        self, in_channels: Union[int, Tuple[int, int]], out_channels: int, **kwargs
+    ) -> MessagePassing:
+        return tgnn.ResGatedGraphConv(
+            in_channels,
+            out_channels,
+            **kwargs,
+        )
+
+
+class ResGatedPyG(GraphModelBase):
+    """
+    Residual gated graph convolution base module.
+
+    Uses PyTorch Geometric's `BasicGNN` machinery to process atomic node
+    features and bond edge attributes through multiple residual gated layers.
+    """
+
+    def __init__(self, config: dict, **kwargs):
+        """
+        Initialize the ResGatedPyG module.
+
+        Args:
+            config (dict): Model configuration containing:
+                - 'hidden_channels' (int): Width of the intermediate conv layers.
+                - 'num_layers' (int): Total number of convolution layers.
+                - Other required GraphModelBase parameters.
+            **kwargs: Additional arguments for the base class.
+        """
+        super().__init__(config=config, **kwargs)
+        self.activation = ELU()  # Instantiate ELU once for reuse.
+        self.gat = ResGatedModel(
+            in_channels=self.n_node_properties,
+            hidden_channels=self.hidden_channels,
+            out_channels=self.out_channels,
+            num_layers=self.num_layers,
+            edge_dim=self.n_bond_properties,
+            act=self.activation,
+        )
+
+    def forward(self, batch: dict) -> torch.Tensor:
+        """
+        Forward pass through the residual gated network.
+
+        Processes atomic node features and edge attributes, and applies
+        an ELU activation to the output.
+
+        Args:
+            batch (dict): Input batch containing:
+                - 'features': A list with a `GraphData` object as its first element.
+
+        Returns:
+            torch.Tensor: Node embeddings after the convolutions and activation.
+        """
+        graph_data = batch["features"][0]
+        assert isinstance(graph_data, GraphData)
+
+        out = self.gat(
+            x=graph_data.x.float(),
+            edge_index=graph_data.edge_index.long(),
+            edge_attr=graph_data.edge_attr,
+        )
+
+        return self.activation(out)
+
+
+class ResGatedGraphPredPyG(GraphNetWrapper):
+    """
+    Residual Gated GNN for Graph Prediction.
+
+    Uses `ResGatedPyG` as the GNN encoder to compute node embeddings.
+    """
+
+    NAME = "ResGatedGraphPredPyG"
+
+    def _get_gnn(self, config: dict) -> ResGatedPyG:
+        """
+        Instantiate the residual gated GNN backbone.
+
+        Args:
+            config (dict): Model configuration.
+
+        Returns:
+            ResGatedPyG: The GNN encoder.
+        """
+        return ResGatedPyG(config=config)
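
PyG's `BasicGNN` assembles the same `in_channels -> hidden_channels -> ... -> out_channels` stack that `ResGatedGraphConvNetBase` builds by hand, which is what makes the two implementations comparable. A minimal smoke test of the new subclass on a toy graph, assuming a recent torch_geometric where `ResGatedGraphConv` accepts `edge_dim` (dimensions mirror `configs/model/resgated.yml` below):

import torch

# Toy graph: 5 nodes with 68 features each, 4 directed edges with 4 features.
x = torch.randn(5, 68)
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]])
edge_attr = torch.randn(4, 4)

model = ResGatedModel(
    in_channels=68,    # n_node_properties
    hidden_channels=256,
    out_channels=512,
    num_layers=4,
    edge_dim=4,        # n_bond_properties, forwarded to each conv via init_conv
    act=torch.nn.ELU(),
)

out = model(x=x, edge_index=edge_index, edge_attr=edge_attr)
print(out.shape)  # torch.Size([5, 512])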

configs/model/resgated.yml

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,10 @@ init_args:
33
optimizer_kwargs:
44
lr: 1e-3
55
config:
6-
in_length: 256
7-
hidden_length: 512
8-
dropout_rate: 0
9-
n_conv_layers: 3
10-
n_atom_properties: 158
11-
n_bond_properties: 7
6+
n_node_properties: 68 # in_channels
7+
hidden_channels : 256
8+
out_channels : 512
9+
num_layers : 4
10+
n_bond_properties: 4 # edge_dim
1211
n_molecule_properties: 0
1312
n_linear_layers: 2
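
Since all five renamed keys feed `GraphModelBase.__init__`, a yml that drifts from the code fails only at model construction. A small sanity-check sketch, assuming the file is loaded with PyYAML and has the `init_args.config` nesting shown in the hunk header (the check itself is illustrative, not part of the repo):

import yaml

# Keys that GraphModelBase.__init__ reads after this commit.
REQUIRED_KEYS = {
    "hidden_channels",
    "out_channels",
    "num_layers",
    "n_node_properties",
    "n_bond_properties",
}

with open("configs/model/resgated.yml") as fh:
    cfg = yaml.safe_load(fh)["init_args"]["config"]

missing = REQUIRED_KEYS - cfg.keys()
assert not missing, f"config is missing keys: {missing}"
assert int(cfg["num_layers"]) > 1, "Need at least two convolution layers"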
