Skip to content

Commit 4fb00e6

Browse files
committed
assert statements for num of edges and nodes
1 parent 6bc4f35 commit 4fb00e6

File tree

1 file changed

+36
-25
lines changed

1 file changed

+36
-25
lines changed

chebai_graph/preprocessing/reader/augmented_reader.py

Lines changed: 36 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,10 @@ def __init__(self, *args, **kwargs):
3535
**kwargs: Additional keyword arguments passed to the ChemDataReader.
3636
"""
3737
super().__init__(*args, **kwargs)
38-
self.f_cnt_for_smiles = (
39-
0 # Record number of failures when constructing molecule from smiles
40-
)
41-
self.f_cnt_for_aug_graph = (
42-
0 # Record number of failure during augmented graph construction
43-
)
38+
# Record number of failures when constructing molecule from smiles
39+
self.f_cnt_for_smiles = 0
40+
# Record number of failures during augmented graph construction
41+
self.f_cnt_for_aug_graph = 0
4442
self.mol_object_buffer = {}
4543
self._num_of_nodes = 0
4644
self._num_of_edges = 0
@@ -245,15 +243,15 @@ def _augment_graph_structure(
245243
if returned_result is None:
246244
return None
247245

248-
fg_atom_edge_index, fg_nodes, atom_fg_edges, structured_fg_map, bonds = (
246+
fg_atom_edge_index, fg_nodes, atom_fg_edges, fg_to_atoms_map, bonds = (
249247
returned_result
250248
)
251249

252250
fg_internal_edge_index, internal_fg_edges = self._construct_fg_level_structure(
253-
structured_fg_map, bonds
251+
fg_to_atoms_map, bonds
254252
)
255253
fg_graph_edge_index, graph_node, fg_to_graph_edges = (
256-
self._construct_fg_to_graph_node_structure(structured_fg_map)
254+
self._construct_fg_to_graph_node_structure(fg_to_atoms_map)
257255
)
258256

259257
# Merge all edge types
@@ -272,20 +270,35 @@ def _augment_graph_structure(
272270
[directed_edge_index, directed_edge_index[[1, 0], :]], dim=1
273271
)
274272

273+
total_atoms = sum([mol.GetNumAtoms(), len(fg_nodes), 1])
274+
assert (
275+
self._num_of_nodes == total_atoms
276+
), f"Mismatch in number of nodes: expected {total_atoms}, got {self._num_of_nodes}"
275277
node_info = {
276278
"atom_nodes": mol,
277279
"fg_nodes": fg_nodes,
278280
"graph_node": graph_node,
279281
"num_nodes": self._num_of_nodes,
280282
}
283+
284+
total_edges = sum(
285+
[
286+
mol.GetNumBonds(),
287+
len(atom_fg_edges),
288+
len(internal_fg_edges),
289+
len(fg_to_graph_edges),
290+
]
291+
)
292+
assert (
293+
self._num_of_edges == total_edges
294+
), f"Mismatch in number of edges: expected {total_edges}, got {self._num_of_edges}"
281295
edge_info = {
282296
WITHIN_ATOMS_EDGE: mol,
283297
ATOM_FG_EDGE: atom_fg_edges,
284298
WITHIN_FG_EDGE: internal_fg_edges,
285299
FG_GRAPHNODE_EDGE: fg_to_graph_edges,
286-
"num_edges": self._num_of_edges * 2, # Undirected edges
300+
"num_undirected_edges": self._num_of_edges * 2, # Undirected edges
287301
}
288-
289302
return undirected_edge_index, node_info, edge_info
290303

291304
@staticmethod
@@ -342,12 +355,11 @@ def _construct_fg_to_atom_structure(
342355

343356
fg_atom_edge_index = [[], []]
344357
fg_nodes, atom_fg_edges = {}, {}
345-
structured_fg_map = (
346-
{}
347-
) # Contains augmented fg-nodes and connected atoms indices
358+
# Contains augmented fg-nodes and connected atoms indices
359+
fg_to_atoms_map = {}
348360

349361
for fg_group in structure.values():
350-
structured_fg_map[self._num_of_nodes] = {"atom": fg_group["atom"]}
362+
fg_to_atoms_map[self._num_of_nodes] = {"atom": fg_group["atom"]}
351363

352364
# Build edge index for fg to atom nodes connections
353365
for atom_idx in fg_group["atom"]:
@@ -370,9 +382,8 @@ def _construct_fg_to_atom_structure(
370382
"A functional group must not span multiple ring sizes."
371383
)
372384

373-
if (
374-
len(ring_fg) == 1
375-
): # FG atoms have ring size, which indicates the FG is a Ring or Fused Rings
385+
if len(ring_fg) == 1:
386+
# All FG atoms share a single ring size, which indicates the FG is a ring or fused rings
376387
ring_size = next(iter(ring_fg))
377388
fg_nodes[self._num_of_nodes] = {
378389
NODE_LEVEL: FG_NODE_LEVEL,
@@ -406,16 +417,16 @@ def _construct_fg_to_atom_structure(
406417

407418
self._num_of_nodes += 1
408419

409-
return fg_atom_edge_index, fg_nodes, atom_fg_edges, structured_fg_map, bonds
420+
return fg_atom_edge_index, fg_nodes, atom_fg_edges, fg_to_atoms_map, bonds
410421

411422
def _construct_fg_level_structure(
412-
self, structured_fg_map: dict, bonds: list
423+
self, fg_to_atoms_map: dict, bonds: list
413424
) -> Tuple[List[List[int]], dict]:
414425
"""
415426
Constructs internal edges between functional group nodes based on bond connections.
416427
417428
Args:
418-
structured_fg_map (dict): Mapping from FG ID to atom indices.
429+
fg_to_atoms_map (dict): Mapping from FG ID to atom indices.
419430
bonds (list): List of bond tuples (source, target, ...).
420431
421432
Returns:
@@ -428,7 +439,7 @@ def _construct_fg_level_structure(
428439
source_atom, target_atom = bond[:2]
429440
source_fg, target_fg = None, None
430441

431-
for fg_id, data in structured_fg_map.items():
442+
for fg_id, data in fg_to_atoms_map.items():
432443
if source_atom in data["atom"]:
433444
source_fg = fg_id
434445
if target_atom in data["atom"]:
@@ -450,13 +461,13 @@ def _construct_fg_level_structure(
450461
return internal_edge_index, internal_fg_edges
451462

452463
def _construct_fg_to_graph_node_structure(
453-
self, structured_fg_map: dict
464+
self, fg_to_atoms_map: dict
454465
) -> Tuple[List[List[int]], dict, dict]:
455466
"""
456467
Constructs edges between functional group nodes and a global graph-level node.
457468
458469
Args:
459-
structured_fg_map (dict): Mapping from FG ID to atom indices.
470+
fg_to_atoms_map (dict): Mapping from FG ID to atom indices.
460471
461472
Returns:
462473
Tuple[List[List[int]], dict, dict]: Edge index, graph-level node, edge attributes.
@@ -466,7 +477,7 @@ def _construct_fg_to_graph_node_structure(
466477
fg_graph_edges = {}
467478
graph_edge_index = [[], []]
468479

469-
for fg_id in structured_fg_map:
480+
for fg_id in fg_to_atoms_map:
470481
graph_edge_index[0].append(self._num_of_nodes)
471482
graph_edge_index[1].append(fg_id)
472483
fg_graph_edges[f"{self._num_of_nodes}_{fg_id}"] = {

0 commit comments

Comments
 (0)