1515from torch_geometric .data .data import Data as GeomData
1616
1717from chebai_graph .preprocessing .properties import (
18+ AllNodeTypeProperty ,
19+ AtomNodeTypeProperty ,
1820 AtomProperty ,
1921 BondProperty ,
22+ FGNodeTypeProperty ,
2023 MolecularProperty ,
24+ MoleculeProperty ,
2125)
2226from chebai_graph .preprocessing .reader import (
2327 AtomFGReader_NoFGEdges_WithGraphNode ,
@@ -41,7 +45,7 @@ def __init__(self, **kwargs):
4145 super ().__init__ (** kwargs )
4246
4347
44- class GraphPropertiesMixIn (ChEBIOverX , ABC ):
48+ class DataPropertiesSetter (ChEBIOverX , ABC ):
4549 """Mixin for adding molecular property encodings to graph-based ChEBI datasets."""
4650
4751 READER = GraphPropertyReader
@@ -172,6 +176,8 @@ def _after_setup(self, **kwargs) -> None:
172176 self ._setup_properties ()
173177 super ()._after_setup (** kwargs )
174178
179+
180+ class GraphPropertiesMixIn (DataPropertiesSetter , ABC ):
175181 def _merge_props_into_base (self , row : pd .Series ) -> GeomData :
176182 """
177183 Merge encoded molecular properties into the GeomData object.
@@ -208,8 +214,10 @@ def _merge_props_into_base(self, row: pd.Series) -> GeomData:
208214 [edge_attr , torch .cat ([property_values , property_values ], dim = 0 )],
209215 dim = 1 ,
210216 )
211- else :
217+ elif isinstance ( property , MoleculeProperty ) :
212218 molecule_attr = torch .cat ([molecule_attr , property_values ], dim = 1 )
219+ else :
220+ raise TypeError (f"Unsupported property type: { type (property ).__name__ } " )
213221
214222 return GeomData (
215223 x = x ,
@@ -261,11 +269,153 @@ def load_processed_data_from_file(self, filename: str) -> list[dict]:
261269 f"Finished loading dataset from properties.\n Encoding lengths: { prop_lengths } \n "
262270 f"Use n_atom_properties: { sum (p .encoder .get_encoding_length () for p in self .properties if isinstance (p , AtomProperty ))} , "
263271 f"n_bond_properties: { sum (p .encoder .get_encoding_length () for p in self .properties if isinstance (p , BondProperty ))} , "
264- f"n_molecule_properties: { sum (p .encoder .get_encoding_length () for p in self .properties if not isinstance (p , (AtomProperty , BondProperty )))} "
272+ f"n_molecule_properties: { sum (p .encoder .get_encoding_length () for p in self .properties if isinstance (p , MoleculeProperty ))} "
273+ )
274+
275+ return base_df [base_data [0 ].keys ()].to_dict ("records" )
276+
277+
class GraphPropertiesAsPerNodeType(DataPropertiesSetter, ABC):
    """Merge cached property encodings into graphs read with a graph-node reader.

    For reporting, each configured property is assigned to one category
    (all-node, functional-group-node, atom-node, graph-node or bond) based on
    its class. The encodings themselves are concatenated onto the node, edge
    and molecule tensors of every ``GeomData`` object.
    """

    READER = AtomFGReader_WithFGEdges_WithGraphNode

    def load_processed_data_from_file(self, filename: str) -> list[dict]:
        """
        Load dataset and merge cached properties into base features.

        Args:
            filename: The path to the file to load.

        Returns:
            List of data entries, each a dictionary.

        Raises:
            TypeError: If a configured property is of an unsupported type.
        """
        base_data = super().load_processed_data_from_file(filename)
        base_df = pd.DataFrame(base_data)

        # Reject unsupported property types before any property file is read.
        for prop in self.properties:
            self._property_category(prop)

        for prop in self.properties:
            # NOTE(review): `weights_only=False` unpickles arbitrary objects;
            # this is only safe for property files produced by this project.
            property_data = torch.load(
                self.get_property_path(prop), weights_only=False
            )
            first_value = property_data[0][prop.name]
            if len(first_value.shape) > 1:
                # The cached tensors are authoritative for the encoding width.
                prop.encoder.set_encoding_length(first_value.shape[1])

            base_df = base_df.merge(
                pd.DataFrame(property_data), on="ident", how="left"
            )

        base_df["features"] = base_df.apply(self._merge_props_into_base, axis=1)

        # apply transformation, e.g. masking for pretraining task
        if self.transform is not None:
            base_df["features"] = base_df["features"].apply(self.transform)

        # Report only now: encoding lengths may have been updated while the
        # property files were loaded, and the message says "Finished loading".
        self._log_property_summary()

        return base_df[base_data[0].keys()].to_dict("records")

    @staticmethod
    def _property_category(prop) -> str:
        """
        Return the reporting category of a property.

        The checks run from the most specific node-type classes to the generic
        ones; the first match wins.

        Args:
            prop: The property object to classify.

        Returns:
            One of the category keys used in the property summary.

        Raises:
            TypeError: If the property matches none of the supported types.
        """
        if isinstance(prop, AllNodeTypeProperty):
            return "AllNodeTypeProperties"
        if isinstance(prop, FGNodeTypeProperty):
            return "FGNodeTypeProperties"
        if isinstance(prop, AtomNodeTypeProperty):
            return "AtomNodeTypeProperties"
        if isinstance(prop, BondProperty):
            return "BondProperties"
        if isinstance(prop, MoleculeProperty):
            # molecule props will be used as graph node props
            return "GraphNodeTypeProperties"
        raise TypeError(f"Unsupported property type: {type(prop).__name__}")

    def _log_property_summary(self) -> None:
        """Log encoding lengths per property and the resulting feature sizes."""
        props_categories = {
            "AllNodeTypeProperties": [],
            "FGNodeTypeProperties": [],
            "AtomNodeTypeProperties": [],
            "GraphNodeTypeProperties": [],
            "BondProperties": [],
        }
        category_lengths = dict.fromkeys(props_categories, 0)
        prop_lengths = []
        for prop in self.properties:
            prop_length = prop.encoder.get_encoding_length()
            category = self._property_category(prop)
            prop_lengths.append((prop.name, prop_length))
            props_categories[category].append(prop.name)
            category_lengths[category] += prop_length

        # Properties defined for all node types count for both atom nodes and
        # functional-group nodes.
        shared_length = category_lengths["AllNodeTypeProperties"]
        n_atom_node_properties = (
            shared_length + category_lengths["AtomNodeTypeProperties"]
        )
        n_fg_node_properties = shared_length + category_lengths["FGNodeTypeProperties"]
        n_graph_node_properties = category_lengths["GraphNodeTypeProperties"]
        n_bond_properties = category_lengths["BondProperties"]

        # The widest node type determines the reported node feature size.
        n_atom_properties = max(
            n_atom_node_properties, n_fg_node_properties, n_graph_node_properties
        )
        rank_zero_info(
            f"Finished loading dataset from properties.\nEncoding lengths: {prop_lengths}\n"
            f"Properties Categories {props_categories}\n"
            f"n_atom_node_properties: {n_atom_node_properties}, "
            f"n_fg_node_properties: {n_fg_node_properties}, "
            f"n_bond_properties: {n_bond_properties}, "
            f"n_graph_node_properties: {n_graph_node_properties}\n"
            f"Use n_atom_properties: {n_atom_properties}, n_bond_properties: {n_bond_properties}, n_molecule_properties: 0"
        )

    def _merge_props_into_base(self, row: pd.Series) -> GeomData:
        """
        Merge encoded molecular properties into the GeomData object.

        Args:
            row: A row holding 'features' (a GeomData object) and one entry per
                property, keyed by the property name.

        Returns:
            A new GeomData object with `x`, `edge_index`, `edge_attr` and
            `molecule_attr` set.

        Raises:
            TypeError: If a property is neither an atom, bond nor molecule
                property.
        """
        geom_data = row["features"]
        assert isinstance(geom_data, GeomData)
        is_atom_node = geom_data.is_atom_node
        assert is_atom_node is not None, "`is_atom_node` must be set in the geom_data"
        is_graph_node = geom_data.is_graph_node
        assert is_graph_node is not None, "`is_graph_node` must be set in the geom_data"

        edge_attr = geom_data.edge_attr
        x = geom_data.x
        molecule_attr = torch.empty((1, 0))

        for prop in self.properties:
            property_values = row[prop.name]
            if isinstance(property_values, torch.Tensor):
                # Bring scalars and vectors into 2-D form: a scalar becomes
                # (1, 1), a vector of length n becomes (n, 1).
                if property_values.dim() == 0:
                    property_values = property_values.unsqueeze(0)
                if property_values.dim() == 1:
                    property_values = property_values.unsqueeze(1)
            else:
                # No tensor stored for this entry: use an empty block of the
                # right width.
                property_values = torch.zeros(
                    (0, prop.encoder.get_encoding_length())
                )

            if isinstance(prop, AtomProperty):
                x = torch.cat([x, property_values], dim=1)
            elif isinstance(prop, BondProperty):
                # Concat/Duplicate properties values for undirected graph as `edge_index` has first src to tgt edges, then tgt to src edges
                edge_attr = torch.cat(
                    [edge_attr, torch.cat([property_values, property_values], dim=0)],
                    dim=1,
                )
            elif isinstance(prop, MoleculeProperty):
                molecule_attr = torch.cat([molecule_attr, property_values], dim=1)
            else:
                raise TypeError(f"Unsupported property type: {type(prop).__name__}")

        return GeomData(
            x=x,
            edge_index=geom_data.edge_index,
            edge_attr=edge_attr,
            molecule_attr=molecule_attr,
        )
418+
269419
270420class ChEBI50GraphProperties (GraphPropertiesMixIn , ChEBIOver50 ):
271421 """ChEBIOver50 dataset with molecular property encodings."""
@@ -310,7 +460,7 @@ def _merge_props_into_base(self, row: pd.Series) -> GeomData:
310460 data = super ()._merge_props_into_base (row )
311461 return self ._add_graph_node_mask (data , row )
312462
313- def _add_graph_node_mask (self , data : GeomData , row ) -> GeomData :
463+ def _add_graph_node_mask (self , data : GeomData , row : pd . Series ) -> GeomData :
314464 """
315465 Add a graph node mask to the GeomData object.
316466
0 commit comments