@@ -275,8 +275,19 @@ def load_processed_data_from_file(self, filename: str) -> list[dict]:
275275 return base_df [base_data [0 ].keys ()].to_dict ("records" )
276276
277277
278- class GraphPropertiesAsPerNodeType (DataPropertiesSetter , ABC ):
279- READER = AtomFGReader_WithFGEdges_WithGraphNode
278+ class GraphPropAsPerNodeType (DataPropertiesSetter , ABC ):
279+ def __init__ (self , properties = None , transform = None , ** kwargs ):
280+ super ().__init__ (properties , transform , ** kwargs )
281+ # Sort properties so that AllNodeTypeProperty instances come first, rest of the properties order remain same
282+ first = [
283+ prop for prop in self .properties if isinstance (prop , AllNodeTypeProperty )
284+ ]
285+ rest = [
286+ prop
287+ for prop in self .properties
288+ if not isinstance (prop , AllNodeTypeProperty )
289+ ]
290+ self .properties = first + rest
280291
281292 def load_processed_data_from_file (self , filename : str ) -> list [dict ]:
282293 """
@@ -308,6 +319,7 @@ def load_processed_data_from_file(self, filename: str) -> list[dict]:
308319 if isinstance (prop , AllNodeTypeProperty ):
309320 n_atom_node_properties += prop_length
310321 n_fg_node_properties += prop_length
322+ n_graph_node_properties += prop_length
311323 props_categories ["AllNodeTypeProperties" ].append (prop_name )
312324 elif isinstance (prop , FGNodeTypeProperty ):
313325 n_fg_node_properties += prop_length
@@ -354,7 +366,11 @@ def load_processed_data_from_file(self, filename: str) -> list[dict]:
354366 base_df = base_df .merge (property_df , on = "ident" , how = "left" )
355367
356368 base_df ["features" ] = base_df .apply (
357- lambda row : self ._merge_props_into_base (row ), axis = 1
369+ lambda row : self ._merge_props_into_base (
370+ row ,
371+ max_len_node_properties = n_atom_properties ,
372+ ),
373+ axis = 1 ,
358374 )
359375
360376 # apply transformation, e.g. masking for pretraining task
@@ -363,7 +379,9 @@ def load_processed_data_from_file(self, filename: str) -> list[dict]:
363379
364380 return base_df [base_data [0 ].keys ()].to_dict ("records" )
365381
366- def _merge_props_into_base (self , row : pd .Series ) -> GeomData :
382+ def _merge_props_into_base (
383+ self , row : pd .Series , max_len_node_properties : int
384+ ) -> GeomData :
367385 """
368386 Merge encoded molecular properties into the GeomData object.
369387
@@ -375,14 +393,24 @@ def _merge_props_into_base(self, row: pd.Series) -> GeomData:
375393 """
376394 geom_data = row ["features" ]
377395 assert isinstance (geom_data , GeomData )
396+
378397 is_atom_node = geom_data .is_atom_node
379398 assert is_atom_node is not None , "`is_atom_node` must be set in the geom_data"
380399 is_graph_node = geom_data .is_graph_node
381400 assert is_graph_node is not None , "`is_graph_node` must be set in the geom_data"
382401
402+ is_fg_node = ~ is_atom_node & ~ is_graph_node
403+ num_nodes = geom_data .x .size (0 )
383404 edge_attr = geom_data .edge_attr
384- x = geom_data .x
385- molecule_attr = torch .empty ((1 , 0 ))
405+
406+ # Initialize node feature matrix
407+ assert (
408+ max_len_node_properties is not None
409+ ), "Maximum len of node properties should not be None"
410+ x = torch .zeros ((num_nodes , max_len_node_properties ))
411+
412+ # Track column offsets for each node type
413+ atom_offset , fg_offset , graph_offset = 0 , 0 , 0
386414
387415 for property in self .properties :
388416 property_values = row [f"{ property .name } " ]
@@ -396,24 +424,51 @@ def _merge_props_into_base(self, row: pd.Series) -> GeomData:
396424 (0 , property .encoder .get_encoding_length ())
397425 )
398426
399- if isinstance (property , AtomProperty ):
400- x = torch .cat ([x , property_values ], dim = 1 )
427+ enc_len = property_values .shape [1 ]
428+ # -------------- Node properties ---------------
429+ if isinstance (property , AllNodeTypeProperty ):
430+ x [:, atom_offset : atom_offset + enc_len ] = property_values
431+ atom_offset += enc_len
432+ fg_offset += enc_len
433+ graph_offset += enc_len
434+
435+ elif isinstance (property , AtomNodeTypeProperty ):
436+ x [is_atom_node , atom_offset : atom_offset + enc_len ] = property_values [
437+ is_atom_node
438+ ]
439+ atom_offset += enc_len
440+
441+ elif isinstance (property , FGNodeTypeProperty ):
442+ x [is_fg_node , fg_offset : fg_offset + enc_len ] = property_values [
443+ is_fg_node
444+ ]
445+ fg_offset += enc_len
446+
447+ elif isinstance (property , MoleculeProperty ):
448+ x [is_graph_node , graph_offset : graph_offset + enc_len ] = (
449+ property_values [is_graph_node ]
450+ )
451+ graph_offset += enc_len
452+
453+ # ------------- Bond Properties --------------
401454 elif isinstance (property , BondProperty ):
402455 # Concat/Duplicate properties values for undirected graph as `edge_index` has first src to tgt edges, then tgt to src edges
403456 edge_attr = torch .cat (
404457 [edge_attr , torch .cat ([property_values , property_values ], dim = 0 )],
405458 dim = 1 ,
406459 )
407- elif isinstance (property , MoleculeProperty ):
408- molecule_attr = torch .cat ([molecule_attr , property_values ], dim = 1 )
409460 else :
410461 raise TypeError (f"Unsupported property type: { type (property ).__name__ } " )
411462
463+ total_used_columns = max (atom_offset , fg_offset , graph_offset )
464+ assert (
465+ total_used_columns <= max_len_node_properties
466+ ), f"Used { total_used_columns } columns, but max allowed is { max_len_node_properties } "
467+
412468 return GeomData (
413469 x = x ,
414470 edge_index = geom_data .edge_index ,
415471 edge_attr = edge_attr ,
416- molecule_attr = molecule_attr ,
417472 )
418473
419474
@@ -507,3 +562,7 @@ class ChEBI50_Atom_WGNOnly_GraphProp(AugGraphPropMixIn_WithGraphNode, ChEBIOver5
507562 """ChEBIOver50 with atom-level nodes and graph node only."""
508563
509564 READER = AtomReader_WithGraphNodeOnly
565+
566+
567+ class ChEBI50_WFGE_WGN_AsPerNodeType (GraphPropAsPerNodeType , ChEBIOver50 ):
568+ READER = AtomFGReader_WithFGEdges_WithGraphNode
0 commit comments