@@ -35,12 +35,10 @@ def __init__(self, *args, **kwargs):
3535 **kwargs: Additional keyword arguments passed to the ChemDataReader.
3636 """
3737 super ().__init__ (* args , ** kwargs )
38- self .f_cnt_for_smiles = (
39- 0 # Record number of failures when constructing molecule from smiles
40- )
41- self .f_cnt_for_aug_graph = (
42- 0 # Record number of failure during augmented graph construction
43- )
38+ # Record number of failures when constructing molecule from smiles
39+ self .f_cnt_for_smiles = 0
40+ # Record number of failure during augmented graph construction
41+ self .f_cnt_for_aug_graph = 0
4442 self .mol_object_buffer = {}
4543 self ._num_of_nodes = 0
4644 self ._num_of_edges = 0
@@ -245,15 +243,15 @@ def _augment_graph_structure(
245243 if returned_result is None :
246244 return None
247245
248- fg_atom_edge_index , fg_nodes , atom_fg_edges , structured_fg_map , bonds = (
246+ fg_atom_edge_index , fg_nodes , atom_fg_edges , fg_to_atoms_map , bonds = (
249247 returned_result
250248 )
251249
252250 fg_internal_edge_index , internal_fg_edges = self ._construct_fg_level_structure (
253- structured_fg_map , bonds
251+ fg_to_atoms_map , bonds
254252 )
255253 fg_graph_edge_index , graph_node , fg_to_graph_edges = (
256- self ._construct_fg_to_graph_node_structure (structured_fg_map )
254+ self ._construct_fg_to_graph_node_structure (fg_to_atoms_map )
257255 )
258256
259257 # Merge all edge types
@@ -272,20 +270,35 @@ def _augment_graph_structure(
272270 [directed_edge_index , directed_edge_index [[1 , 0 ], :]], dim = 1
273271 )
274272
273+ total_atoms = sum ([mol .GetNumAtoms (), len (fg_nodes ), 1 ])
274+ assert (
275+ self ._num_of_nodes == total_atoms
276+ ), f"Mismatch in number of nodes: expected { total_atoms } , got { self ._num_of_nodes } "
275277 node_info = {
276278 "atom_nodes" : mol ,
277279 "fg_nodes" : fg_nodes ,
278280 "graph_node" : graph_node ,
279281 "num_nodes" : self ._num_of_nodes ,
280282 }
283+
284+ total_edges = sum (
285+ [
286+ mol .GetNumBonds (),
287+ len (atom_fg_edges ),
288+ len (internal_fg_edges ),
289+ len (fg_to_graph_edges ),
290+ ]
291+ )
292+ assert (
293+ self ._num_of_edges == total_edges
294+ ), f"Mismatch in number of edges: expected { total_edges } , got { self ._num_of_edges } "
281295 edge_info = {
282296 WITHIN_ATOMS_EDGE : mol ,
283297 ATOM_FG_EDGE : atom_fg_edges ,
284298 WITHIN_FG_EDGE : internal_fg_edges ,
285299 FG_GRAPHNODE_EDGE : fg_to_graph_edges ,
286- "num_edges " : self ._num_of_edges * 2 , # Undirected edges
300+ "num_undirected_edges " : self ._num_of_edges * 2 , # Undirected edges
287301 }
288-
289302 return undirected_edge_index , node_info , edge_info
290303
291304 @staticmethod
@@ -342,12 +355,11 @@ def _construct_fg_to_atom_structure(
342355
343356 fg_atom_edge_index = [[], []]
344357 fg_nodes , atom_fg_edges = {}, {}
345- structured_fg_map = (
346- {}
347- ) # Contains augmented fg-nodes and connected atoms indices
358+ # Contains augmented fg-nodes and connected atoms indices
359+ fg_to_atoms_map = {}
348360
349361 for fg_group in structure .values ():
350- structured_fg_map [self ._num_of_nodes ] = {"atom" : fg_group ["atom" ]}
362+ fg_to_atoms_map [self ._num_of_nodes ] = {"atom" : fg_group ["atom" ]}
351363
352364 # Build edge index for fg to atom nodes connections
353365 for atom_idx in fg_group ["atom" ]:
@@ -370,9 +382,8 @@ def _construct_fg_to_atom_structure(
370382 "A functional group must not span multiple ring sizes."
371383 )
372384
373- if (
374- len (ring_fg ) == 1
375- ): # FG atoms have ring size, which indicates the FG is a Ring or Fused Rings
385+ if len (ring_fg ) == 1 :
386+ # FG atoms have ring size, which indicates the FG is a Ring or Fused Rings
376387 ring_size = next (iter (ring_fg ))
377388 fg_nodes [self ._num_of_nodes ] = {
378389 NODE_LEVEL : FG_NODE_LEVEL ,
@@ -406,16 +417,16 @@ def _construct_fg_to_atom_structure(
406417
407418 self ._num_of_nodes += 1
408419
409- return fg_atom_edge_index , fg_nodes , atom_fg_edges , structured_fg_map , bonds
420+ return fg_atom_edge_index , fg_nodes , atom_fg_edges , fg_to_atoms_map , bonds
410421
411422 def _construct_fg_level_structure (
412- self , structured_fg_map : dict , bonds : list
423+ self , fg_to_atoms_map : dict , bonds : list
413424 ) -> Tuple [List [List [int ]], dict ]:
414425 """
415426 Constructs internal edges between functional group nodes based on bond connections.
416427
417428 Args:
418- structured_fg_map (dict): Mapping from FG ID to atom indices.
429+ fg_to_atoms_map (dict): Mapping from FG ID to atom indices.
419430 bonds (list): List of bond tuples (source, target, ...).
420431
421432 Returns:
@@ -428,7 +439,7 @@ def _construct_fg_level_structure(
428439 source_atom , target_atom = bond [:2 ]
429440 source_fg , target_fg = None , None
430441
431- for fg_id , data in structured_fg_map .items ():
442+ for fg_id , data in fg_to_atoms_map .items ():
432443 if source_atom in data ["atom" ]:
433444 source_fg = fg_id
434445 if target_atom in data ["atom" ]:
@@ -450,13 +461,13 @@ def _construct_fg_level_structure(
450461 return internal_edge_index , internal_fg_edges
451462
452463 def _construct_fg_to_graph_node_structure (
453- self , structured_fg_map : dict
464+ self , fg_to_atoms_map : dict
454465 ) -> Tuple [List [List [int ]], dict , dict ]:
455466 """
456467 Constructs edges between functional group nodes and a global graph-level node.
457468
458469 Args:
459- structured_fg_map (dict): Mapping from FG ID to atom indices.
470+ fg_to_atoms_map (dict): Mapping from FG ID to atom indices.
460471
461472 Returns:
462473 Tuple[List[List[int]], dict, dict]: Edge index, graph-level node, edge attributes.
@@ -466,7 +477,7 @@ def _construct_fg_to_graph_node_structure(
466477 fg_graph_edges = {}
467478 graph_edge_index = [[], []]
468479
469- for fg_id in structured_fg_map :
480+ for fg_id in fg_to_atoms_map :
470481 graph_edge_index [0 ].append (self ._num_of_nodes )
471482 graph_edge_index [1 ].append (fg_id )
472483 fg_graph_edges [f"{ self ._num_of_nodes } _{ fg_id } " ] = {
0 commit comments