Skip to content

Commit 764b812

Browse files
committed
scope: modify select classes and labels save operation
1 parent c3ba8da commit 764b812

File tree

1 file changed

+41
-34
lines changed
  • chebai/preprocessing/datasets/scope

1 file changed

+41
-34
lines changed

chebai/preprocessing/datasets/scope/scope.py

Lines changed: 41 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@
1313
import gzip
1414
import os
1515
import shutil
16-
from abc import ABC
16+
from abc import ABC, abstractmethod
1717
from tempfile import NamedTemporaryFile
18-
from typing import Any, Dict, Generator, Optional, Tuple
18+
from typing import Any, Dict, Generator, List, Optional, Tuple
1919

2020
import networkx as nx
2121
import pandas as pd
@@ -350,21 +350,7 @@ def _graph_to_raw_dataset(self, graph: nx.DiGraph) -> pd.DataFrame:
350350
"""
351351
print(f"Process graph")
352352

353-
sids = nx.get_node_attributes(graph, "sid")
354-
levels = nx.get_node_attributes(graph, "level")
355-
356-
sun_ids = {}
357-
sids_list = []
358-
359-
selected_sids_dict = self.select_classes(graph)
360-
361-
for sun_id, level in levels.items():
362-
if sun_id in selected_sids_dict:
363-
sun_ids.setdefault(level, []).append(sun_id)
364-
sids_list.append(sids.get(sun_id))
365-
366-
# Remove root node, as it will True for all instances
367-
sun_ids.pop("root", None)
353+
sun_ids = self.select_classes(graph)
368354

369355
if not sun_ids:
370356
raise RuntimeError("No sunid selected.")
@@ -440,6 +426,10 @@ def _graph_to_raw_dataset(self, graph: nx.DiGraph) -> pd.DataFrame:
440426
sequence_hierarchy_df = sequence_hierarchy_df[
441427
["id", "sids", "sequence"] + encoded_target_columns
442428
]
429+
430+
with open(os.path.join(self.processed_dir_main, "classes.txt"), "wt") as fout:
431+
fout.writelines(str(sun_id) + "\n" for sun_id in encoded_target_columns)
432+
443433
return sequence_hierarchy_df
444434

445435
def _parse_pdb_sequence_file(self) -> Dict[str, Dict[str, str]]:
@@ -498,6 +488,11 @@ def _update_or_add_sequence(
498488
new_row["sids"] = [row["sid"]]
499489
sequence_hierarchy_df.loc[sequence] = new_row
500490

491+
@abstractmethod
492+
def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> Dict[str, List[int]]:
493+
# Override the return type of the method from superclass
494+
pass
495+
501496
# ------------------------------ Phase: Setup data -----------------------------------
502497
def setup_processed(self) -> None:
503498
"""
@@ -755,24 +750,36 @@ def _name(self) -> str:
755750
"""
756751
return f"SCOPe{self.THRESHOLD}"
757752

758-
def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> Dict:
759-
# Filter nodes and create a dictionary of node and out-degree
760-
sun_ids_dict = {
761-
node: g.out_degree(node) # Store node and its out-degree
762-
for node in g.nodes
763-
if g.out_degree(node) >= self.THRESHOLD
764-
}
753+
def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> Dict[str, List[int]]:
754+
"""
755+
Selects classes from the SCOPe dataset based on the number of successors meeting a specified threshold.
765756
766-
# Return a sorted dictionary (by out-degree or node id)
767-
sorted_dict = dict(
768-
sorted(sun_ids_dict.items(), key=lambda item: item[0], reverse=False)
769-
)
757+
This method iterates over the nodes in the graph, counting the number of successors for each node.
758+
Nodes with a number of successors greater than or equal to the defined threshold are selected.
759+
760+
Note:
761+
The input graph must be the transitive closure of a directed acyclic graph.
770762
771-
filename = "classes.txt"
772-
with open(os.path.join(self.processed_dir_main, filename), "wt") as fout:
773-
fout.writelines(str(sun_id) + "\n" for sun_id in sorted_dict.keys())
763+
Args:
764+
g (nx.DiGraph): The graph representing the dataset.
765+
*args: Additional positional arguments (not used).
766+
**kwargs: Additional keyword arguments (not used).
767+
768+
Returns:
769+
Dict[str, List[int]]: A dict mapping each hierarchy level to the list of selected nodes at that level.
774770
775-
return sorted_dict
771+
Notes:
772+
- The `THRESHOLD` attribute should be defined in the subclass of this class.
773+
"""
774+
selected_sunids_for_level = {}
775+
for node, attr_dict in g.nodes(data=True):
776+
if g.out_degree(node) >= self.THRESHOLD:
777+
selected_sunids_for_level.setdefault(attr_dict["level"], []).append(
778+
node
779+
)
780+
# Remove root node, as it will be True for all instances
781+
selected_sunids_for_level.pop("root", None)
782+
return selected_sunids_for_level
776783

777784

778785
class _SCOPeOverXPartial(_SCOPeOverX, ABC):
@@ -860,6 +867,6 @@ class SCOPeOverPartial2000(_SCOPeOverXPartial):
860867

861868

862869
if __name__ == "__main__":
863-
scope = SCOPE(scope_version=2.08)
864-
g = scope._extract_class_hierarchy("d")
870+
scope = SCOPeOver2000(scope_version=2.08)
871+
g = scope._extract_class_hierarchy("dummy/path")
865872
scope._graph_to_raw_dataset(g)

0 commit comments

Comments
 (0)