@@ -186,16 +186,50 @@ def _after_setup(self, **kwargs) -> None:
186186
187187class GraphPropertiesMixIn (DataPropertiesSetter , ABC ):
188188 def __init__ (
189- self , properties = None , transform = None , zero_pad_atom : int = None , ** kwargs
189+ self ,
190+ properties = None ,
191+ transform = None ,
192+ zero_pad_node : int = None ,
193+ zero_pad_edge : int = None ,
194+ random_pad_node : int = None ,
195+ random_pad_edge : int = None ,
196+ distribution : str = "normal" ,
197+ ** kwargs ,
190198 ):
191199 super ().__init__ (properties , transform , ** kwargs )
192- self .zero_pad_atom = int (zero_pad_atom ) if zero_pad_atom is not None else None
193- if self .zero_pad_atom :
200+ self .zero_pad_node = int (zero_pad_node ) if zero_pad_node else None
201+ if self .zero_pad_node :
202+ print (
203+ f"[Info] Node-level features will be zero-padded with "
204+ f"{ self .zero_pad_node } additional dimensions."
205+ )
206+
207+ self .zero_pad_edge = int (zero_pad_edge ) if zero_pad_edge else None
208+ if self .zero_pad_edge :
194209 print (
195- f"[Info] Atom -level features will be zero-padded with "
196- f"{ self .zero_pad_atom } additional dimensions."
210+ f"[Info] Edge -level features will be zero-padded with "
211+ f"{ self .zero_pad_edge } additional dimensions."
197212 )
198213
214+ self .random_pad_edge = int (random_pad_edge ) if random_pad_edge else None
215+ self .random_pad_node = int (random_pad_node ) if random_pad_node else None
216+ if self .random_pad_node or self .random_pad_edge :
217+ assert (
218+ distribution is not None
219+ and distribution in RandomFeatureInitializationReader .DISTRIBUTIONS
220+ ), "When using random padding, a valid distribution must be specified."
221+ self .distribution = distribution
222+ if self .random_pad_node :
223+ print (
224+ f"[Info] Node-level features will be padded with "
225+ f"{ self .random_pad_node } additional dimensions initialized from { self .distribution } distribution."
226+ )
227+ if self .random_pad_edge :
228+ print (
229+ f"[Info] Edge-level features will be padded with "
230+ f"{ self .random_pad_edge } additional dimensions initialized from { self .distribution } distribution."
231+ )
232+
199233 if self .properties :
200234 print (
201235 f"Data module uses these properties (ordered): { ', ' .join ([str (p ) for p in self .properties ])} "
@@ -242,8 +276,24 @@ def _merge_props_into_base(self, row: pd.Series) -> GeomData:
242276 else :
243277 raise TypeError (f"Unsupported property type: { type (property ).__name__ } " )
244278
245- if self .zero_pad_atom is not None :
246- x = torch .cat ([x , torch .zeros ((x .shape [0 ], self .zero_pad_atom ))], dim = 1 )
279+ if self .zero_pad_node :
280+ x = torch .cat ([x , torch .zeros ((x .shape [0 ], self .zero_pad_node ))], dim = 1 )
281+
282+ if self .zero_pad_edge :
283+ edge_attr = torch .cat (
284+ [edge_attr , torch .zeros ((edge_attr .shape [0 ], self .zero_pad_edge ))],
285+ dim = 1 ,
286+ )
287+
288+ if self .random_pad_node :
289+ random_pad = torch .empty ((x .shape [0 ], self .random_pad_node ))
290+ RandomFeatureInitializationReader .random_gni (random_pad , self .distribution )
291+ x = torch .cat ([x , random_pad ], dim = 1 )
292+
293+ if self .random_pad_edge :
294+ random_pad = torch .empty ((edge_attr .shape [0 ], self .random_pad_edge ))
295+ RandomFeatureInitializationReader .random_gni (random_pad , self .distribution )
296+ edge_attr = torch .cat ([edge_attr , random_pad ], dim = 1 )
247297
248298 return GeomData (
249299 x = x ,
@@ -291,18 +341,44 @@ def load_processed_data_from_file(self, filename: str) -> list[dict]:
291341 prop_lengths = [
292342 (prop .name , prop .encoder .get_encoding_length ()) for prop in self .properties
293343 ]
344+
345+ # -------------------------- Count total node properties
294346 n_node_properties = sum (
295347 p .encoder .get_encoding_length ()
296348 for p in self .properties
297349 if isinstance (p , AtomProperty )
298350 )
299- if self .zero_pad_atom :
300- n_node_properties += self .zero_pad_atom
351+
352+ in_channels_str = f"in_channels: { n_node_properties } "
353+ if self .zero_pad_node :
354+ n_node_properties += self .zero_pad_node
355+ in_channels_str += f"(with { self .zero_pad_node } padded zeros)"
356+
357+ if self .random_pad_node :
358+ n_node_properties += self .random_pad_node
359+ in_channels_str += f"(with { self .random_pad_node } random padded values from { self .distribution } distribution)"
360+
361+ # -------------------------- Count total edge properties
362+ n_edge_properties = sum (
363+ p .encoder .get_encoding_length ()
364+ for p in self .properties
365+ if isinstance (p , BondProperty )
366+ )
367+ edge_dim_str = f"edge_dim: { n_edge_properties } "
368+
369+ if self .zero_pad_edge :
370+ n_edge_properties += self .zero_pad_edge
371+ edge_dim_str += f"(with { self .zero_pad_edge } padded zeros)"
372+
373+ if self .random_pad_edge :
374+ n_edge_properties += self .random_pad_edge
375+ edge_dim_str += f"(with { self .random_pad_edge } random padded values from { self .distribution } distribution)"
376+
301377 rank_zero_info (
302378 f"Finished loading dataset from properties.\n Encoding lengths: { prop_lengths } \n "
303379 f"Use following values for given parameters for model configuration: \n \t "
304- f"in_channels: { n_node_properties } (with { self . zero_pad_atom } padded zeros) , "
305- f"edge_dim: { sum ( p . encoder . get_encoding_length () for p in self . properties if isinstance ( p , BondProperty )) } , "
380+ f"{ in_channels_str } , "
381+ f"{ edge_dim_str } , "
306382 f"n_molecule_properties: { sum (p .encoder .get_encoding_length () for p in self .properties if isinstance (p , MoleculeProperty ))} "
307383 )
308384
0 commit comments