1- from typing import Final , Tuple , Union
1+ from typing import Any , Final
22
3- import torch
4- import torch .nn .functional as F
3+ from torch import Tensor
54from torch .nn import ELU
65from torch_geometric import nn as tgnn
76from torch_geometric .data import Data as GraphData
1110from .base import GraphModelBase , GraphNetWrapper
1211
1312
14- class ResGatedGraphConvNetBase (GraphModelBase ):
15- """
16- Residual Gated Graph Convolutional Network with edge attributes support.
17-
18- This model uses a stack of `ResGatedGraphConv` layers from PyTorch Geometric,
19- allowing edge attributes as part of message passing. A final projection layer maps
20- to the hidden length specified for downstream graph prediction tasks.
21- """
22-
23- NAME = "ResGatedGraphConvNetBase"
24-
25- def __init__ (self , config : dict , ** kwargs ):
26- """
27- Initialize the ResGatedGraphConvNetBase.
28-
29- Args:
30- config (dict): Configuration dictionary with keys:
31- - 'hidden_length' (int): Intermediate feature length used in GNN layers.
32- - Other parameters inherited from GraphModelBase.
33- **kwargs: Additional keyword arguments passed to GraphModelBase.
34- """
35- super ().__init__ (config = config , ** kwargs )
36-
37- self .activation = F .elu
38- self .convs = torch .nn .ModuleList ()
39- self .convs .append (
40- tgnn .ResGatedGraphConv (
41- self .n_node_properties ,
42- self .hidden_channels ,
43- # dropout=self.dropout,
44- edge_dim = self .n_bond_properties ,
45- )
46- )
47-
48- for _ in range (self .num_layers - 2 ):
49- # Intermediate layers
50- self .convs .append (
51- tgnn .ResGatedGraphConv (
52- self .hidden_channels ,
53- self .hidden_channels ,
54- edge_dim = self .n_bond_properties ,
55- )
56- )
57-
58- # Final projection layer to hidden dimension
59- self .final_conv = tgnn .ResGatedGraphConv (
60- self .hidden_channels , self .out_channels , edge_dim = self .n_bond_properties
61- )
62-
63- def forward (self , batch : dict ) -> torch .Tensor :
64- """
65- Forward pass through residual gated GNN layers.
66-
67- Args:
68- batch (dict): A batch containing:
69- - 'features': A list with a `GraphData` instance as the first element.
70-
71- Returns:
72- torch.Tensor: Node-level embeddings of shape [num_nodes, hidden_length].
73- """
74- graph_data = batch ["features" ][0 ]
75- assert isinstance (graph_data , GraphData )
76-
77- x = graph_data .x .float () # Atom features
78-
79- for conv in self .convs :
80- assert isinstance (conv , tgnn .ResGatedGraphConv )
81- x = self .activation (
82- conv (x , graph_data .edge_index .long (), edge_attr = graph_data .edge_attr )
83- )
84-
85- x = self .activation (
86- self .final_conv (
87- x , graph_data .edge_index .long (), edge_attr = graph_data .edge_attr
88- )
89- )
90-
91- return x
92-
93-
94- class ResGatedGraphPred (GraphNetWrapper ):
13+ class ResGatedModel (BasicGNN ):
9514 """
96- Residual Gated GNN for Graph Prediction .
15+ A residual gated GNN model based on PyG's BasicGNN using ResGatedGraphConv layers .
9716
98- Uses `ResGatedGraphConvNetBase` as the GNN encoder to compute node embeddings.
17+ Attributes:
18+ supports_edge_weight (bool): Indicates edge weights are not supported.
19+ supports_edge_attr (bool): Indicates edge attributes are supported.
20+ supports_norm_batch (bool): Indicates if batch normalization is supported.
9921 """
10022
101- NAME = "ResGatedGraphPred"
23+ supports_edge_weight : Final [bool ] = False
24+ supports_edge_attr : Final [bool ] = True
25+ supports_norm_batch : Final [bool ]
10226
103- def _get_gnn (self , config : dict ) -> ResGatedGraphConvNetBase :
27+ def init_conv (
28+ self , in_channels : int | tuple [int , int ], out_channels : int , ** kwargs : Any
29+ ) -> MessagePassing :
10430 """
105- Instantiate the residual gated GNN backbone .
31+ Initializes a ResGatedGraphConv layer .
10632
10733 Args:
108- config (dict): Model configuration.
34+ in_channels (int or Tuple[int, int]): Number of input channels.
35+ out_channels (int): Number of output channels.
36+ **kwargs: Additional keyword arguments for the convolution layer.
10937
11038 Returns:
111- ResGatedGraphConvNetBase: The GNN encoder .
39+ MessagePassing: A ResGatedGraphConv layer instance .
11240 """
113- return ResGatedGraphConvNetBase (config = config )
114-
115-
116- class ResGatedModel (BasicGNN ):
117- supports_edge_weight : Final [bool ] = False
118- supports_edge_attr : Final [bool ] = True
119- supports_norm_batch : Final [bool ]
120-
121- def init_conv (
122- self , in_channels : Union [int , Tuple [int , int ]], out_channels : int , ** kwargs
123- ) -> MessagePassing :
12441 return tgnn .ResGatedGraphConv (
12542 in_channels ,
12643 out_channels ,
12744 ** kwargs ,
12845 )
12946
13047
131- class ResGatedPyG (GraphModelBase ):
48+ class ResGatedGraphConvNetBase (GraphModelBase ):
13249 """
133- Graph Attention Network (GAT) base module for graph convolution .
50+ Base model class for applying ResGatedGraphConv layers to graph-structured data .
13451
135- Uses PyTorch Geometric's `GAT` implementation to process atomic node features
136- and bond edge attributes through multiple attention heads and layers.
52+ Args:
53+ config (dict): Configuration dictionary containing model hyperparameters.
54+ **kwargs: Additional keyword arguments for parent class.
13755 """
13856
139- def __init__ (self , config : dict , ** kwargs ):
140- """
141- Initialize the GATGraphConvNetBase.
142-
143- Args:
144- config (dict): Model configuration containing:
145- - 'heads' (int): Number of attention heads.
146- - 'v2' (bool): Whether to use the GATv2 variant.
147- - Other required GraphModelBase parameters.
148- **kwargs: Additional arguments for the base class.
149- """
57+ def __init__ (self , config : dict [str , Any ], ** kwargs : Any ):
15058 super ().__init__ (config = config , ** kwargs )
15159 self .activation = ELU () # Instantiate ELU once for reuse.
152- self .gat = ResGatedModel (
153- in_channels = self .n_node_properties ,
60+
61+ self .resgated : BasicGNN = ResGatedModel (
62+ in_channels = self .in_channels ,
15463 hidden_channels = self .hidden_channels ,
15564 out_channels = self .out_channels ,
15665 num_layers = self .num_layers ,
157- edge_dim = self .n_bond_properties ,
66+ edge_dim = self .edge_dim ,
15867 act = self .activation ,
15968 )
16069
161- def forward (self , batch : dict ) -> torch . Tensor :
70+ def forward (self , batch : dict [ str , Any ] ) -> Tensor :
16271 """
163- Forward pass through the GAT network.
164-
165- Processes atomic node features and edge attributes, and applies
166- an ELU activation to the output.
72+ Forward pass of the model.
16773
16874 Args:
169- batch (dict): Input batch containing:
170- - 'features': A list with a `GraphData` object as its first element.
75+ batch (dict): A batch containing graph input features under the key "features".
17176
17277 Returns:
173- torch. Tensor: Node embeddings after GAT and activation.
78+ Tensor: The output node-level embeddings after the final activation.
17479 """
17580 graph_data = batch ["features" ][0 ]
176- assert isinstance (graph_data , GraphData )
81+ assert isinstance (graph_data , GraphData ), "Expected GraphData instance"
17782
178- out = self .gat (
83+ out = self .resgated (
17984 x = graph_data .x .float (),
18085 edge_index = graph_data .edge_index .long (),
18186 edge_attr = graph_data .edge_attr ,
@@ -184,23 +89,21 @@ def forward(self, batch: dict) -> torch.Tensor:
18489 return self .activation (out )
18590
18691
187- class ResGatedGraphPredPyG (GraphNetWrapper ):
92+ class ResGatedGraphPred (GraphNetWrapper ):
18893 """
189- Residual Gated GNN for Graph Prediction .
94+ Wrapper for graph-level prediction using ResGatedGraphConvNetBase .
19095
191- Uses `ResGatedGraphConvNetBase` as the GNN encoder to compute node embeddings .
96+ This class instantiates the core GNN model using the provided config .
19297 """
19398
194- NAME = "ResGatedGraphPred"
195-
196- def _get_gnn (self , config : dict ) -> ResGatedPyG :
99+ def _get_gnn (self , config : dict [str , Any ]) -> ResGatedGraphConvNetBase :
197100 """
198- Instantiate the residual gated GNN backbone .
101+ Returns the core ResGated GNN model .
199102
200103 Args:
201- config (dict): Model configuration .
104+ config (dict): Configuration dictionary for the GNN model .
202105
203106 Returns:
204- ResGatedGraphConvNetBase: The GNN encoder .
107+ ResGatedGraphConvNetBase: The core graph convolutional network .
205108 """
206- return ResGatedPyG (config = config )
109+ return ResGatedGraphConvNetBase (config = config )
0 commit comments