Skip to content

Commit 93c7fc5

Browse files
committed
scope: fix for no True labels for some classes/columns
1 parent 36e6162 commit 93c7fc5

File tree

1 file changed

+48
-16
lines changed
  • chebai/preprocessing/datasets/scope

1 file changed

+48
-16
lines changed

chebai/preprocessing/datasets/scope/scope.py

Lines changed: 48 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -195,22 +195,52 @@ def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph:
195195
"""
196196
print("Extracting class hierarchy...")
197197
df_scope = self._get_scope_data()
198+
pdb_chain_df = self._parse_pdb_sequence_file()
199+
pdb_id_set = set(pdb_chain_df["pdb_id"]) # Search time complexity - O(1)
198200

199201
g = nx.DiGraph()
200202

201-
egdes = []
202-
for _, row in df_scope.iterrows():
203-
g.add_node(row["sunid"], **{"sid": row["sid"], "level": row["level"]})
204-
if row["parent_sunid"] != -1:
205-
egdes.append((row["parent_sunid"], row["sunid"]))
203+
edges = []
204+
node_attrs = {}
205+
px_level_nodes = set()
206+
207+
# Step 1: Build the graph and store attributes
208+
for row in df_scope.itertuples(index=False):
209+
if row.level == "px":
210+
if row.sid[1:5] not in pdb_id_set:
211+
# Skip domain-level ("px") nodes whose pdb_id is not present in the pdb_sequences.txt file
212+
continue
213+
px_level_nodes.add(row.sunid)
214+
215+
node_attrs[row.sunid] = {"sid": row.sid, "level": row.level}
216+
217+
if row.parent_sunid != -1:
218+
edges.append((row.parent_sunid, row.sunid))
206219

207-
for children_id in row["children_sunids"]:
208-
egdes.append((row["sunid"], children_id))
220+
for child_id in row.children_sunids:
221+
edges.append((row.sunid, child_id))
209222

210-
g.add_edges_from(egdes)
223+
g.add_nodes_from((node, attrs) for node, attrs in node_attrs.items())
224+
g.add_edges_from(edges)
211225

226+
# Step 2: Compute the transitive closure first
212227
print("Computing transitive closure")
213-
return nx.transitive_closure_dag(g)
228+
g_tc = nx.transitive_closure_dag(g)
229+
230+
print(
231+
"Remove node without domain descendants that don't have pdb correspondence"
232+
)
233+
# Step 3: Identify and remove nodes that don't have a "px" descendant with a corresponding entry in the pdb_sequences file
234+
nodes_to_remove = set()
235+
for node in g_tc.nodes:
236+
if node not in px_level_nodes and not any(
237+
desc in px_level_nodes for desc in g_tc.successors(node)
238+
):
239+
nodes_to_remove.add(node)
240+
241+
g_tc.remove_nodes_from(nodes_to_remove)
242+
243+
return g_tc
214244

215245
def _get_scope_data(self) -> pd.DataFrame:
216246
"""
@@ -388,7 +418,8 @@ def _graph_to_raw_dataset(self, graph: nx.DiGraph) -> pd.DataFrame:
388418

389419
encoded_target_columns = []
390420
for level in hierarchy_levels:
391-
encoded_target_columns.extend(lvl_to_target_cols_mapping[level])
421+
if level in lvl_to_target_cols_mapping:
422+
encoded_target_columns.extend(lvl_to_target_cols_mapping[level])
392423

393424
print(
394425
f"{len(encoded_target_columns)} labels has been selected for specified threshold, "
@@ -471,12 +502,12 @@ def _parse_pdb_sequence_file(self) -> pd.DataFrame:
471502
for record in SeqIO.parse(
472503
os.path.join(self.scope_root_dir, self.raw_file_names_dict["PDB"]), "fasta"
473504
):
505+
506+
if not record.seq:
507+
continue
508+
474509
pdb_id, chain = record.id.split("_")
475-
sequence = (
476-
re.sub(f"[^{valid_amino_acids}]", "X", str(record.seq))
477-
if record.seq
478-
else ""
479-
)
510+
sequence = re.sub(f"[^{valid_amino_acids}]", "X", str(record.seq))
480511

481512
# Store as a dictionary entry (list of dicts -> DataFrame later)
482513
records.append(
@@ -876,7 +907,8 @@ class SCOPeOverPartial2000(_SCOPeOverXPartial):
876907

877908

878909
if __name__ == "__main__":
879-
scope = SCOPeOver2000(scope_version="2.08")
910+
scope = SCOPeOver50(scope_version="2.08")
911+
880912
# g = scope._extract_class_hierarchy("dummy/path")
881913
# # Save graph
882914
# import pickle

0 commit comments

Comments
 (0)