@@ -25,10 +25,34 @@ class ResGatedDynamicGNI(GraphModelBase):
2525 def __init__ (self , config : dict [str , Any ], ** kwargs : Any ):
2626 super ().__init__ (config = config , ** kwargs )
2727 self .activation = ELU () # Instantiate ELU once for reuse.
28+
2829 distribution = config .get ("distribution" , "normal" )
29- assert distribution in ["normal" , "uniform" , "xavier_normal" , "xavier_uniform" ]
30+ assert distribution in RandomFeatureInitializationReader .DISTRIBUTIONS , (
31+ f"Unsupported distribution: { distribution } . "
32+ f"Choose from { RandomFeatureInitializationReader .DISTRIBUTIONS } ."
33+ )
3034 self .distribution = distribution
3135
36+ self .complete_randomness = config .get ("complete_randomness" , True )
37+
38+ if not self .complete_randomness :
39+ assert (
40+ "random_pad_node" in config or "random_pad_edge" in config
41+ ), "Missing 'random_pad_node' or 'random_pad_edge' in config when complete_randomness is False"
42+ self .random_pad_node = (
43+ int (config ["random_pad_node" ])
44+ if config .get ("random_pad_node" ) is not None
45+ else None
46+ )
47+ self .random_pad_edge = (
48+ int (config ["random_pad_edge" ])
49+ if config .get ("random_pad_edge" ) is not None
50+ else None
51+ )
52+ assert (
53+ self .random_pad_node > 0 or self .random_pad_edge > 0
54+ ), "'random_pad_node' or 'random_pad_edge' must be positive integers"
55+
3256 self .resgated : BasicGNN = ResGatedModel (
3357 in_channels = self .in_channels ,
3458 hidden_channels = self .hidden_channels ,
@@ -52,24 +76,52 @@ def forward(self, batch: dict[str, Any]) -> Tensor:
5276 graph_data = batch ["features" ][0 ]
5377 assert isinstance (graph_data , GraphData ), "Expected GraphData instance"
5478
55- random_x = torch .empty (
56- graph_data .x .shape [0 ], graph_data .x .shape [1 ], device = self .device
57- )
58- RandomFeatureInitializationReader .random_gni (random_x , self .distribution )
59-
60- random_edge_attr = torch .empty (
61- graph_data .edge_attr .shape [0 ],
62- graph_data .edge_attr .shape [1 ],
63- device = self .device ,
64- )
65- RandomFeatureInitializationReader .random_gni (
66- random_edge_attr , self .distribution
67- )
68-
79+ new_x = None
80+ new_edge_attr = None
81+ if self .complete_randomness :
82+ new_x = torch .empty (
83+ graph_data .x .shape [0 ], graph_data .x .shape [1 ], device = self .device
84+ )
85+ RandomFeatureInitializationReader .random_gni (new_x , self .distribution )
86+
87+ new_edge_attr = torch .empty (
88+ graph_data .edge_attr .shape [0 ],
89+ graph_data .edge_attr .shape [1 ],
90+ device = self .device ,
91+ )
92+ RandomFeatureInitializationReader .random_gni (
93+ new_edge_attr , self .distribution
94+ )
95+ else :
96+ if self .random_pad_node is not None :
97+ pad_node = torch .empty (
98+ graph_data .x .shape [0 ],
99+ self .random_pad_node ,
100+ device = self .device ,
101+ )
102+ RandomFeatureInitializationReader .random_gni (
103+ pad_node , self .distribution
104+ )
105+ new_x = torch .cat ((graph_data .x , pad_node ), dim = 1 )
106+
107+ if self .random_pad_edge is not None :
108+ pad_edge = torch .empty (
109+ graph_data .edge_attr .shape [0 ],
110+ self .random_pad_edge ,
111+ device = self .device ,
112+ )
113+ RandomFeatureInitializationReader .random_gni (
114+ pad_edge , self .distribution
115+ )
116+ new_edge_attr = torch .cat ((graph_data .edge_attr , pad_edge ), dim = 1 )
117+
118+ assert (
119+ new_x is not None and new_edge_attr is not None
120+ ), "Feature initialization failed"
69121 out = self .resgated (
70- x = random_x .float (),
122+ x = new_x .float (),
71123 edge_index = graph_data .edge_index .long (),
72- edge_attr = random_edge_attr .float (),
124+ edge_attr = new_edge_attr .float (),
73125 )
74126
75127 return self .activation (out )
0 commit comments