Skip to content

Commit 767b210

Browse files
committed
scope: fix for true values less given threshold for some labels
1 parent 6d7b467 commit 767b210

File tree

1 file changed

+72
-27
lines changed
  • chebai/preprocessing/datasets/scope

1 file changed

+72
-27
lines changed

chebai/preprocessing/datasets/scope/scope.py

Lines changed: 72 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -198,49 +198,91 @@ def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph:
198198
pdb_chain_df = self._parse_pdb_sequence_file()
199199
pdb_id_set = set(pdb_chain_df["pdb_id"]) # Search time complexity - O(1)
200200

201-
g = nx.DiGraph()
202-
203-
edges = []
201+
# Initialize sets and dictionaries for storing edges and attributes
202+
parent_node_edges, node_child_edges = set(), set()
204203
node_attrs = {}
205204
px_level_nodes = set()
205+
sequence_nodes = dict()
206+
px_to_seq_edges = set()
207+
required_graph_nodes = set()
208+
209+
# Create a lookup dictionary for PDB chain sequences
210+
lookup_dict = (
211+
pdb_chain_df.groupby("pdb_id")[["chain_id", "sequence"]]
212+
.apply(lambda x: dict(zip(x["chain_id"], x["sequence"])))
213+
.to_dict()
214+
)
215+
216+
def add_sequence_nodes_edges(chain_sequence, px_sun_id):
217+
"""Adds sequence nodes and edges connecting px-level nodes to sequence nodes."""
218+
if chain_sequence not in sequence_nodes:
219+
sequence_nodes[chain_sequence] = f"seq_{len(sequence_nodes)}"
220+
px_to_seq_edges.add((px_sun_id, sequence_nodes[chain_sequence]))
206221

207-
# Step 1: Build the graph and store attributes
222+
# Step 1: Build the graph structure and store node attributes
208223
for row in df_scope.itertuples(index=False):
209224
if row.level == "px":
210-
if row.sid[1:5] not in pdb_id_set:
225+
226+
pdb_id, chain_id = row.sid[1:5], row.sid[5]
227+
228+
if pdb_id not in pdb_id_set or chain_id == "_":
211229
# Don't add domain level nodes that don't have pdb_id in pdb_sequences.txt file
230+
# Also chain_id with "_" which corresponds to no chain
212231
continue
213232
px_level_nodes.add(row.sunid)
214233

234+
# Add edges between px-level nodes and sequence nodes
235+
if chain_id != ".":
236+
if chain_id not in lookup_dict[pdb_id]:
237+
continue
238+
add_sequence_nodes_edges(lookup_dict[pdb_id][chain_id], row.sunid)
239+
else:
240+
# If chain_id is '.', connect all chains of this PDB ID
241+
for chain, chain_sequence in lookup_dict[pdb_id].items():
242+
add_sequence_nodes_edges(chain_sequence, row.sunid)
243+
else:
244+
required_graph_nodes.add(row.sunid)
245+
215246
node_attrs[row.sunid] = {"sid": row.sid, "level": row.level}
216247

217248
if row.parent_sunid != -1:
218-
edges.append((row.parent_sunid, row.sunid))
249+
parent_node_edges.add((row.parent_sunid, row.sunid))
219250

220251
for child_id in row.children_sunids:
221-
edges.append((row.sunid, child_id))
252+
node_child_edges.add((row.sunid, child_id))
222253

223-
g.add_nodes_from((node, attrs) for node, attrs in node_attrs.items())
224-
g.add_edges_from(edges)
254+
del df_scope, pdb_chain_df, pdb_id_set
225255

226-
# Step 2: Compute the transitive closure first
227-
print("Computing transitive closure")
228-
g_tc = nx.transitive_closure_dag(g)
229-
230-
print(
231-
"Remove node without domain descendants that don't have pdb correspondence"
256+
g = nx.DiGraph()
257+
g.add_nodes_from(node_attrs.items())
258+
# Note - `add_edges` internally create a node, if a node doesn't exist already
259+
g.add_edges_from({(p, c) for p, c in parent_node_edges if p in node_attrs})
260+
g.add_edges_from({(p, c) for p, c in node_child_edges if c in node_attrs})
261+
262+
seq_nodes = set(sequence_nodes.values())
263+
g.add_nodes_from([(seq_id, {"level": "sequence"}) for seq_id in seq_nodes])
264+
g.add_edges_from(
265+
{
266+
(px_node, seq_node)
267+
for px_node, seq_node in px_to_seq_edges
268+
if px_node in node_attrs and seq_node in seq_nodes
269+
}
232270
)
233-
# Step 3: Identify and remove nodes that don’t have a "px" descendant with correspondence to pdb_sequences file
234-
nodes_to_remove = set()
235-
for node in g_tc.nodes:
236-
if node not in px_level_nodes and not any(
237-
desc in px_level_nodes for desc in g_tc.successors(node)
238-
):
239-
nodes_to_remove.add(node)
240271

241-
g_tc.remove_nodes_from(nodes_to_remove)
272+
# Step 2: Count sequence successors for required graph nodes only
273+
for node in required_graph_nodes:
274+
num_seq_successors = sum(
275+
g.nodes[child]["level"] == "sequence"
276+
for child in nx.descendants(g, node)
277+
)
278+
g.nodes[node]["num_seq_successors"] = num_seq_successors
242279

243-
return g_tc
280+
# Step 3: Remove nodes which are not required before computing transitive closure for better efficiency
281+
g.remove_nodes_from(px_level_nodes | seq_nodes)
282+
283+
print("Computing Transitive Closure.........")
284+
# Transitive closure is not needed in `select_classes` method but is required in _SCOPeOverXPartial
285+
return nx.transitive_closure_dag(g)
244286

245287
def _get_scope_data(self) -> pd.DataFrame:
246288
"""
@@ -808,12 +850,15 @@ def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> Dict[str, List[int]]
808850
"""
809851
selected_sunids_for_level = {}
810852
for node, attr_dict in g.nodes(data=True):
811-
if g.out_degree(node) >= self.THRESHOLD:
853+
if attr_dict["level"] in {"root", "px", "sequence"}:
854+
# Skip nodes with level "root", "px", or "sequence"
855+
continue
856+
857+
# Check if the number of "sequence"-level successors meets or exceeds the threshold
858+
if g.nodes[node]["num_seq_successors"] >= self.THRESHOLD:
812859
selected_sunids_for_level.setdefault(attr_dict["level"], []).append(
813860
node
814861
)
815-
# Remove root node, as it will True for all instances
816-
selected_sunids_for_level.pop("root", None)
817862
return selected_sunids_for_level
818863

819864

0 commit comments

Comments
 (0)