|
13 | 13 | import gzip |
14 | 14 | import os |
15 | 15 | import shutil |
16 | | -from abc import ABC |
| 16 | +from abc import ABC, abstractmethod |
17 | 17 | from tempfile import NamedTemporaryFile |
18 | | -from typing import Any, Dict, Generator, Optional, Tuple |
| 18 | +from typing import Any, Dict, Generator, List, Optional, Tuple |
19 | 19 |
|
20 | 20 | import networkx as nx |
21 | 21 | import pandas as pd |
@@ -350,21 +350,7 @@ def _graph_to_raw_dataset(self, graph: nx.DiGraph) -> pd.DataFrame: |
350 | 350 | """ |
351 | 351 | print(f"Process graph") |
352 | 352 |
|
353 | | - sids = nx.get_node_attributes(graph, "sid") |
354 | | - levels = nx.get_node_attributes(graph, "level") |
355 | | - |
356 | | - sun_ids = {} |
357 | | - sids_list = [] |
358 | | - |
359 | | - selected_sids_dict = self.select_classes(graph) |
360 | | - |
361 | | - for sun_id, level in levels.items(): |
362 | | - if sun_id in selected_sids_dict: |
363 | | - sun_ids.setdefault(level, []).append(sun_id) |
364 | | - sids_list.append(sids.get(sun_id)) |
365 | | - |
366 | | - # Remove root node, as it will True for all instances |
367 | | - sun_ids.pop("root", None) |
| 353 | + sun_ids = self.select_classes(graph) |
368 | 354 |
|
369 | 355 | if not sun_ids: |
370 | 356 | raise RuntimeError("No sunid selected.") |
@@ -440,6 +426,10 @@ def _graph_to_raw_dataset(self, graph: nx.DiGraph) -> pd.DataFrame: |
440 | 426 | sequence_hierarchy_df = sequence_hierarchy_df[ |
441 | 427 | ["id", "sids", "sequence"] + encoded_target_columns |
442 | 428 | ] |
| 429 | + |
| 430 | + with open(os.path.join(self.processed_dir_main, "classes.txt"), "wt") as fout: |
| 431 | + fout.writelines(str(sun_id) + "\n" for sun_id in encoded_target_columns) |
| 432 | + |
443 | 433 | return sequence_hierarchy_df |
444 | 434 |
|
445 | 435 | def _parse_pdb_sequence_file(self) -> Dict[str, Dict[str, str]]: |
@@ -498,6 +488,11 @@ def _update_or_add_sequence( |
498 | 488 | new_row["sids"] = [row["sid"]] |
499 | 489 | sequence_hierarchy_df.loc[sequence] = new_row |
500 | 490 |
|
| 491 | + @abstractmethod |
| 492 | + def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> Dict[str, List[int]]: |
| 493 | + # Override the return type of the method from superclass |
| 494 | + pass |
| 495 | + |
501 | 496 | # ------------------------------ Phase: Setup data ----------------------------------- |
502 | 497 | def setup_processed(self) -> None: |
503 | 498 | """ |
@@ -755,24 +750,36 @@ def _name(self) -> str: |
755 | 750 | """ |
756 | 751 | return f"SCOPe{self.THRESHOLD}" |
757 | 752 |
|
758 | | - def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> Dict: |
759 | | - # Filter nodes and create a dictionary of node and out-degree |
760 | | - sun_ids_dict = { |
761 | | - node: g.out_degree(node) # Store node and its out-degree |
762 | | - for node in g.nodes |
763 | | - if g.out_degree(node) >= self.THRESHOLD |
764 | | - } |
| 753 | + def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> Dict[str, List[int]]: |
| 754 | + """ |
| 755 | + Selects classes from the SCOPe dataset based on the number of successors meeting a specified threshold. |
765 | 756 |
|
766 | | - # Return a sorted dictionary (by out-degree or node id) |
767 | | - sorted_dict = dict( |
768 | | - sorted(sun_ids_dict.items(), key=lambda item: item[0], reverse=False) |
769 | | - ) |
| 757 | + This method iterates over the nodes in the graph, counting the number of successors for each node. |
| 758 | + Nodes with a number of successors greater than or equal to the defined threshold are selected. |
| 759 | +
|
| 760 | + Note: |
| 761 | + The input graph must be transitive closure of a directed acyclic graph. |
770 | 762 |
|
771 | | - filename = "classes.txt" |
772 | | - with open(os.path.join(self.processed_dir_main, filename), "wt") as fout: |
773 | | - fout.writelines(str(sun_id) + "\n" for sun_id in sorted_dict.keys()) |
| 763 | + Args: |
| 764 | + g (nx.Graph): The graph representing the dataset. |
| 765 | + *args: Additional positional arguments (not used). |
| 766 | + **kwargs: Additional keyword arguments (not used). |
| 767 | +
|
| 768 | + Returns: |
| 769 | + Dict: A dict containing selected nodes at each hierarchy level. |
774 | 770 |
|
775 | | - return sorted_dict |
| 771 | + Notes: |
| 772 | + - The `THRESHOLD` attribute should be defined in the subclass of this class. |
| 773 | + """ |
| 774 | + selected_sunids_for_level = {} |
| 775 | + for node, attr_dict in g.nodes(data=True): |
| 776 | + if g.out_degree(node) >= self.THRESHOLD: |
| 777 | + selected_sunids_for_level.setdefault(attr_dict["level"], []).append( |
| 778 | + node |
| 779 | + ) |
| 780 | + # Remove root node, as it will True for all instances |
| 781 | + selected_sunids_for_level.pop("root", None) |
| 782 | + return selected_sunids_for_level |
776 | 783 |
|
777 | 784 |
|
778 | 785 | class _SCOPeOverXPartial(_SCOPeOverX, ABC): |
@@ -860,6 +867,6 @@ class SCOPeOverPartial2000(_SCOPeOverXPartial): |
860 | 867 |
|
861 | 868 |
|
862 | 869 | if __name__ == "__main__": |
863 | | - scope = SCOPE(scope_version=2.08) |
864 | | - g = scope._extract_class_hierarchy("d") |
| 870 | + scope = SCOPeOver2000(scope_version=2.08) |
| 871 | + g = scope._extract_class_hierarchy("dummy/path") |
865 | 872 | scope._graph_to_raw_dataset(g) |
0 commit comments