@@ -189,45 +189,29 @@ def __init__(
189189 self ,
190190 properties = None ,
191191 transform = None ,
192- zero_pad_node : int = None ,
193- zero_pad_edge : int = None ,
194- random_pad_node : int = None ,
195- random_pad_edge : int = None ,
192+ pad_node_features : int = None ,
193+ pad_edge_features : int = None ,
196194 distribution : str = "normal" ,
197195 ** kwargs ,
198196 ):
199197 super ().__init__ (properties , transform , ** kwargs )
200- self .zero_pad_node = int (zero_pad_node ) if zero_pad_node else None
201- if self .zero_pad_node :
202- print (
203- f"[Info] Node-level features will be zero-padded with "
204- f"{ self .zero_pad_node } additional dimensions."
205- )
206-
207- self .zero_pad_edge = int (zero_pad_edge ) if zero_pad_edge else None
208- if self .zero_pad_edge :
209- print (
210- f"[Info] Edge-level features will be zero-padded with "
211- f"{ self .zero_pad_edge } additional dimensions."
212- )
213-
214- self .random_pad_edge = int (random_pad_edge ) if random_pad_edge else None
215- self .random_pad_node = int (random_pad_node ) if random_pad_node else None
216- if self .random_pad_node or self .random_pad_edge :
198+ self .pad_edge_features = int (pad_edge_features ) if pad_edge_features else None
199+ self .pad_node_features = int (pad_node_features ) if pad_node_features else None
200+ if self .pad_node_features or self .pad_edge_features :
217201 assert (
218202 distribution is not None
219203 and distribution in RandomFeatureInitializationReader .DISTRIBUTIONS
220- ), "When using random padding, a valid distribution must be specified."
204+ ), "When using padding for features , a valid distribution must be specified."
221205 self .distribution = distribution
222- if self .random_pad_node :
206+ if self .pad_node_features :
223207 print (
224- f"[Info] Node-level features will be padded with "
225- f"{ self .random_pad_node } additional dimensions initialized from { self .distribution } distribution."
208+ f"[Info] Node-level features will be padded with random "
209+ f"{ self .pad_node_features } values from { self .distribution } distribution."
226210 )
227- if self .random_pad_edge :
211+ if self .pad_edge_features :
228212 print (
229- f"[Info] Edge-level features will be padded with "
230- f"{ self .random_pad_edge } additional dimensions initialized from { self .distribution } distribution."
213+ f"[Info] Edge-level features will be padded with random "
214+ f"{ self .pad_edge_features } values from { self .distribution } distribution."
231215 )
232216
233217 if self .properties :
@@ -276,24 +260,19 @@ def _merge_props_into_base(self, row: pd.Series) -> GeomData:
276260 else :
277261 raise TypeError (f"Unsupported property type: { type (property ).__name__ } " )
278262
279- if self .zero_pad_node :
280- x = torch .cat ([x , torch .zeros ((x .shape [0 ], self .zero_pad_node ))], dim = 1 )
281-
282- if self .zero_pad_edge :
283- edge_attr = torch .cat (
284- [edge_attr , torch .zeros ((edge_attr .shape [0 ], self .zero_pad_edge ))],
285- dim = 1 ,
263+ if self .pad_node_features :
264+ padding_values = torch .empty ((x .shape [0 ], self .pad_node_features ))
265+ RandomFeatureInitializationReader .random_gni (
266+ padding_values , self .distribution
286267 )
268+ x = torch .cat ([x , padding_values ], dim = 1 )
287269
288- if self .random_pad_node :
289- random_pad = torch .empty ((x .shape [0 ], self .random_pad_node ))
290- RandomFeatureInitializationReader .random_gni (random_pad , self .distribution )
291- x = torch .cat ([x , random_pad ], dim = 1 )
292-
293- if self .random_pad_edge :
294- random_pad = torch .empty ((edge_attr .shape [0 ], self .random_pad_edge ))
295- RandomFeatureInitializationReader .random_gni (random_pad , self .distribution )
296- edge_attr = torch .cat ([edge_attr , random_pad ], dim = 1 )
270+ if self .pad_edge_features :
271+ padding_values = torch .empty ((edge_attr .shape [0 ], self .pad_edge_features ))
272+ RandomFeatureInitializationReader .random_gni (
273+ padding_values , self .distribution
274+ )
275+ edge_attr = torch .cat ([edge_attr , padding_values ], dim = 1 )
297276
298277 return GeomData (
299278 x = x ,
@@ -350,13 +329,9 @@ def load_processed_data_from_file(self, filename: str) -> list[dict]:
350329 )
351330
352331 in_channels_str = ""
353- if self .zero_pad_node :
354- n_node_properties += self .zero_pad_node
355- in_channels_str += f" (with { self .zero_pad_node } padded zeros)"
356-
357- if self .random_pad_node :
358- n_node_properties += self .random_pad_node
359- in_channels_str += f" (with { self .random_pad_node } random padded values from { self .distribution } distribution)"
332+ if self .pad_node_features :
333+ n_node_properties += self .pad_node_features
334+ in_channels_str += f" (with { self .pad_node_features } padded random values from { self .distribution } distribution)"
360335
361336 in_channels_str = f"in_channels: { n_node_properties } " + in_channels_str
362337
@@ -367,14 +342,9 @@ def load_processed_data_from_file(self, filename: str) -> list[dict]:
367342 if isinstance (p , BondProperty )
368343 )
369344 edge_dim_str = ""
370-
371- if self .zero_pad_edge :
372- n_edge_properties += self .zero_pad_edge
373- edge_dim_str += f" (with { self .zero_pad_edge } padded zeros)"
374-
375- if self .random_pad_edge :
376- n_edge_properties += self .random_pad_edge
377- edge_dim_str += f" (with { self .random_pad_edge } random padded values from { self .distribution } distribution)"
345+ if self .pad_edge_features :
346+ n_edge_properties += self .pad_edge_features
347+ edge_dim_str += f" (with { self .pad_edge_features } padded random values from { self .distribution } distribution)"
378348
379349 edge_dim_str = f"edge_dim: { n_edge_properties } " + edge_dim_str
380350
0 commit comments