@@ -1,8 +1,12 @@
+from typing import Final, Tuple, Union
+
 import torch
 import torch.nn.functional as F
-from torch import nn
+from torch.nn import ELU
 from torch_geometric import nn as tgnn
 from torch_geometric.data import Data as GraphData
+from torch_geometric.nn.conv import MessagePassing
+from torch_geometric.nn.models.basic_gnn import BasicGNN
 
 from .base import GraphModelBase, GraphNetWrapper
 
@@ -24,38 +28,38 @@ def __init__(self, config: dict, **kwargs):
 
         Args:
             config (dict): Configuration dictionary with keys:
-                - 'in_length' (int): Intermediate feature length used in GNN layers.
+                - 'hidden_channels' (int): Intermediate feature length used in GNN layers.
                 - Other parameters inherited from GraphModelBase.
             **kwargs: Additional keyword arguments passed to GraphModelBase.
         """
         super().__init__(config=config, **kwargs)
-        self.in_length = int(config["in_length"])
 
         self.activation = F.elu
-        self.dropout = nn.Dropout(self.dropout_rate)
-
         self.convs = torch.nn.ModuleList()
-        for i in range(self.n_conv_layers):
-            if i == 0:
-                # Initial layer uses atom features as input
-                self.convs.append(
-                    tgnn.ResGatedGraphConv(
-                        self.n_atom_properties,
-                        self.in_length,
-                        # dropout=self.dropout_rate,
-                        edge_dim=self.n_bond_properties,
-                    )
-                )
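+        # Initial layer projects raw node features to the hidden dimension.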
+        self.convs.append(
+            tgnn.ResGatedGraphConv(
+                self.n_node_properties,
+                self.hidden_channels,
+                # dropout=self.dropout,
+                edge_dim=self.n_bond_properties,
+            )
+        )
+
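+        # num_layers - 2 intermediate layers; the conv above and final_conv below supply the other two.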
+        for _ in range(self.num_layers - 2):
             # Intermediate layers
             self.convs.append(
                 tgnn.ResGatedGraphConv(
-                    self.in_length, self.in_length, edge_dim=self.n_bond_properties
+                    self.hidden_channels,
+                    self.hidden_channels,
+                    edge_dim=self.n_bond_properties,
                 )
             )
 
-        # Final projection layer to hidden dimension
+        # Final projection layer to the output dimension
         self.final_conv = tgnn.ResGatedGraphConv(
-            self.in_length, self.hidden_length, edge_dim=self.n_bond_properties
+            self.hidden_channels, self.out_channels, edge_dim=self.n_bond_properties
         )
 
     def forward(self, batch: dict) -> torch.Tensor:
@@ -109,3 +111,98 @@ def _get_gnn(self, config: dict) -> ResGatedGraphConvNetBase:
             ResGatedGraphConvNetBase: The GNN encoder.
         """
         return ResGatedGraphConvNetBase(config=config)
+
+
+class ResGatedModel(BasicGNN):
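+    # ResGatedGraphConv consumes edge attributes (edge_dim) but not scalar edge weights.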
+    supports_edge_weight: Final[bool] = False
+    supports_edge_attr: Final[bool] = True
+    supports_norm_batch: Final[bool]
+
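+    # Hook called by BasicGNN once per layer; constructor kwargs such as edge_dim are forwarded here.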
+    def init_conv(
+        self, in_channels: Union[int, Tuple[int, int]], out_channels: int, **kwargs
+    ) -> MessagePassing:
+        return tgnn.ResGatedGraphConv(
+            in_channels,
+            out_channels,
+            **kwargs,
+        )
+
+
+class ResGatedPyG(GraphModelBase):
+    """
+    Residual gated graph convolutional base module.
+
+    Uses PyTorch Geometric's `BasicGNN` framework with `ResGatedGraphConv`
+    layers to process atomic node features and bond edge attributes.
+    """
+
+    def __init__(self, config: dict, **kwargs):
+        """
+        Initialize the ResGatedPyG module.
+
+        Args:
+            config (dict): Model configuration containing the parameters
+                required by GraphModelBase (channel sizes, layer count, etc.).
+            **kwargs: Additional arguments for the base class.
+        """
+        super().__init__(config=config, **kwargs)
+        self.activation = ELU()  # Instantiate ELU once for reuse.
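+        # BasicGNN stacks num_layers ResGatedGraphConv layers, applying the activation between them.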
+        self.gnn = ResGatedModel(
+            in_channels=self.n_node_properties,
+            hidden_channels=self.hidden_channels,
+            out_channels=self.out_channels,
+            num_layers=self.num_layers,
+            edge_dim=self.n_bond_properties,
+            act=self.activation,
+        )
+
+    def forward(self, batch: dict) -> torch.Tensor:
+        """
+        Forward pass through the residual gated GNN.
+
+        Processes atomic node features and edge attributes, and applies
+        an ELU activation to the output.
+
+        Args:
+            batch (dict): Input batch containing:
+                - 'features': A list with a `GraphData` object as its first element.
+
+        Returns:
+            torch.Tensor: Node embeddings after graph convolution and activation.
+        """
+        graph_data = batch["features"][0]
+        assert isinstance(graph_data, GraphData)
+
+        out = self.gnn(
+            x=graph_data.x.float(),
+            edge_index=graph_data.edge_index.long(),
+            edge_attr=graph_data.edge_attr,
+        )
+
+        return self.activation(out)
+
+
+class ResGatedGraphPredPyG(GraphNetWrapper):
+    """
+    Residual Gated GNN for graph prediction.
+
+    Uses `ResGatedPyG` as the GNN encoder to compute node embeddings.
+    """
+
+    NAME = "ResGatedGraphPred"
+
+    def _get_gnn(self, config: dict) -> ResGatedPyG:
+        """
+        Instantiate the residual gated GNN backbone.
+
+        Args:
+            config (dict): Model configuration.
+
+        Returns:
+            ResGatedPyG: The GNN encoder.
+        """
+        return ResGatedPyG(config=config)