@@ -178,6 +178,17 @@ def _after_setup(self, **kwargs) -> None:
178178
179179
180180class GraphPropertiesMixIn (DataPropertiesSetter , ABC ):
181+ def __init__ (
182+ self , properties = None , transform = None , zero_pad_atom : int = None , ** kwargs
183+ ):
184+ super ().__init__ (properties , transform , ** kwargs )
185+ self .zero_pad_atom = int (zero_pad_atom ) if zero_pad_atom is not None else None
186+ if self .zero_pad_atom :
187+ print (
188+ f"[Info] Atom-level features will be zero-padded with "
189+ f"{ self .zero_pad_atom } additional dimensions."
190+ )
191+
181192 def _merge_props_into_base (self , row : pd .Series ) -> GeomData :
182193 """
183194 Merge encoded molecular properties into the GeomData object.
@@ -219,6 +230,9 @@ def _merge_props_into_base(self, row: pd.Series) -> GeomData:
219230 else :
220231 raise TypeError (f"Unsupported property type: { type (property ).__name__ } " )
221232
233+ if self .zero_pad_atom is not None :
234+ x = torch .cat ([x , torch .zeros ((x .shape [0 ], self .zero_pad_atom ))], dim = 1 )
235+
222236 return GeomData (
223237 x = x ,
224238 edge_index = geom_data .edge_index ,
@@ -265,10 +279,17 @@ def load_processed_data_from_file(self, filename: str) -> list[dict]:
265279 prop_lengths = [
266280 (prop .name , prop .encoder .get_encoding_length ()) for prop in self .properties
267281 ]
282+ n_node_properties = sum (
283+ p .encoder .get_encoding_length ()
284+ for p in self .properties
285+ if isinstance (p , AtomProperty )
286+ )
287+ if self .zero_pad_atom :
288+ n_node_properties += self .zero_pad_atom
268289 rank_zero_info (
269290 f"Finished loading dataset from properties.\n Encoding lengths: { prop_lengths } \n "
270291 f"Use following values for given parameters for model configuration: \n \t "
271- f"in_channels: { sum ( p . encoder . get_encoding_length () for p in self .properties if isinstance ( p , AtomProperty )) } , "
292+ f"in_channels: { n_node_properties } (with { self .zero_pad_atom } padded zeros) , "
272293 f"edge_dim: { sum (p .encoder .get_encoding_length () for p in self .properties if isinstance (p , BondProperty ))} , "
273294 f"n_molecule_properties: { sum (p .encoder .get_encoding_length () for p in self .properties if isinstance (p , MoleculeProperty ))} "
274295 )
0 commit comments