Skip to content

Commit c3ba8da

Browse files
committed
scope: add OverX classes and their derivatives
1 parent 6735e41 commit c3ba8da

File tree

1 file changed

+124
-10
lines changed
  • chebai/preprocessing/datasets/scope

1 file changed

+124
-10
lines changed

chebai/preprocessing/datasets/scope/scope.py

Lines changed: 124 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,9 @@ def _download_pdb_sequence_data(self) -> None:
121121
The file is downloaded as a temporary gzip file, which is then extracted to the
122122
specified directory.
123123
"""
124-
pdb_seq_file_path = os.path.join(self.raw_dir, self.raw_file_names_dict["PDB"])
124+
pdb_seq_file_path = os.path.join(
125+
self.scope_root_dir, self.raw_file_names_dict["PDB"]
126+
)
125127
os.makedirs(os.path.dirname(pdb_seq_file_path), exist_ok=True)
126128

127129
if not os.path.isfile(pdb_seq_file_path):
@@ -450,7 +452,7 @@ def _parse_pdb_sequence_file(self) -> Dict[str, Dict[str, str]]:
450452
"""
451453
pdb_chain_seq_mapping: Dict[str, Dict[str, str]] = {}
452454
for record in SeqIO.parse(
453-
os.path.join(self.raw_dir, self.raw_file_names_dict["PDB"]), "fasta"
455+
os.path.join(self.scope_root_dir, self.raw_file_names_dict["PDB"]), "fasta"
454456
):
455457
pdb_id, chain = record.id.split("_")
456458
if str(record.seq):
@@ -655,11 +657,11 @@ def _setup_pruned_test_set(
655657
# filename_new = f"classes_v{self.scope_version_train}.txt"
656658
# dataset = torch.load(os.path.join(self.processed_dir, "test.pt"))
657659

658-
# Load original classes (from the current ChEBI version - scope_version)
660+
# Load original classes (from the current SCOPe version - scope_version)
659661
with open(os.path.join(self.processed_dir_main, filename_old), "r") as file:
660662
orig_classes = file.readlines()
661663

662-
# Load new classes (from the training ChEBI version - scope_version_train)
664+
# Load new classes (from the training SCOPe version - scope_version_train)
663665
with open(
664666
os.path.join(
665667
self._scope_version_train_obj.processed_dir_main, filename_old
@@ -690,15 +692,25 @@ def _setup_pruned_test_set(
690692
return df_test_scope_version
691693

692694
# ------------------------------ Phase: Raw Properties -----------------------------------
695+
@property
696+
def scope_root_dir(self) -> str:
697+
"""
698+
Returns the root directory of scope data.
699+
700+
Returns:
701+
str: The path to the SCOPe root directory, which is "data/SCOPe".
702+
"""
703+
return os.path.join("data", "SCOPe")
704+
693705
@property
694706
def base_dir(self) -> str:
695707
"""
696-
Returns the base directory path for storing GO-Uniprot data.
708+
Returns the base directory path for storing SCOPe data.
697709
698710
Returns:
699711
str: The path to the base directory, which is "data/SCOPe/version_<scope_version>".
700712
"""
701-
return os.path.join("data", "SCOPe", f"version_{self.scope_version}")
713+
return os.path.join(self.scope_root_dir, f"version_{self.scope_version}")
702714

703715
@property
704716
def raw_file_names_dict(self) -> dict:
@@ -707,7 +719,6 @@ def raw_file_names_dict(self) -> dict:
707719
708720
Returns:
709721
dict: A dictionary mapping dataset names to their respective file names.
710-
For example, {"GO": "go-basic.obo", "SwissUniProt": "uniprot_sprot.dat"}.
711722
"""
712723
return {
713724
"CLA": "cla.txt",
@@ -717,13 +728,32 @@ def raw_file_names_dict(self) -> dict:
717728
}
718729

719730

720-
class SCOPE(_SCOPeDataExtractor):
731+
class _SCOPeOverX(_SCOPeDataExtractor, ABC):
732+
"""
733+
A class for extracting data from the SCOPe dataset with a threshold for selecting classes/labels based on
734+
the number of subclasses.
735+
736+
This class is designed to filter SCOPe classes/labels based on a specified threshold, selecting only those classes
737+
which have a certain number of subclasses in the hierarchy.
738+
739+
Attributes:
740+
READER (dr.ProteinDataReader): The reader used for reading the dataset.
741+
THRESHOLD (int): The threshold for selecting classes/labels based on the number of subclasses.
742+
743+
"""
744+
721745
READER = ProteinDataReader
722-
THRESHOLD = 2143
746+
THRESHOLD: int = None
723747

724748
@property
725749
def _name(self) -> str:
726-
return "test"
750+
"""
751+
Returns the name of the dataset.
752+
753+
Returns:
754+
str: The dataset name, formatted with the current threshold.
755+
"""
756+
return f"SCOPe{self.THRESHOLD}"
727757

728758
def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> Dict:
729759
# Filter nodes and create a dictionary of node and out-degree
@@ -745,6 +775,90 @@ def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> Dict:
745775
return sorted_dict
746776

747777

778+
class _SCOPeOverXPartial(_SCOPeOverX, ABC):
779+
"""
780+
Dataset that doesn't use the full SCOPe dataset, but extracts a part of SCOPe (subclasses of a given top class)
781+
782+
Attributes:
783+
top_class_sunid (int): The Sun-ID of the top class from which to extract subclasses.
784+
"""
785+
786+
def __init__(self, top_class_sunid: int, **kwargs):
787+
"""
788+
Initializes the _SCOPeOverXPartial dataset.
789+
790+
Args:
791+
top_class_sunid (int): The Sun-ID of the top class from which to extract subclasses.
792+
**kwargs: Additional keyword arguments passed to the superclass initializer.
793+
"""
794+
if "top_class_sunid" not in kwargs:
795+
kwargs["top_class_sunid"] = top_class_sunid
796+
797+
self.top_class_sunid: int = top_class_sunid
798+
super().__init__(**kwargs)
799+
800+
@property
801+
def processed_dir_main(self) -> str:
802+
"""
803+
Returns the main processed data directory specific to the top class.
804+
805+
Returns:
806+
str: The processed data directory path.
807+
"""
808+
return os.path.join(
809+
self.base_dir,
810+
self._name,
811+
f"partial_{self.top_class_sunid}",
812+
"processed",
813+
)
814+
815+
def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph:
816+
"""
817+
Extracts a subset of SCOPe based on subclasses of the top class ID.
818+
819+
This method calls the superclass method to extract the full class hierarchy,
820+
then extracts the subgraph containing only the direct successors of the top class ID (via `g.successors`), including itself.
821+
822+
Args:
823+
data_path (str): The file path to the SCOPe ontology file.
824+
825+
Returns:
826+
nx.DiGraph: The extracted class hierarchy as a directed graph, limited to the
827+
descendants of the top class ID.
828+
"""
829+
g = super()._extract_class_hierarchy(data_path)
830+
g = g.subgraph(
831+
list(g.successors(self.top_class_sunid)) + [self.top_class_sunid]
832+
)
833+
return g
834+
835+
836+
class SCOPeOver2000(_SCOPeOverX):
837+
"""
838+
A class for extracting data from the SCOPe dataset with a threshold of 2000 for selecting classes.
839+
840+
Inherits from `_SCOPeOverX` and sets the threshold for selecting classes to 2000.
841+
842+
Attributes:
843+
THRESHOLD (int): The threshold for selecting classes (2000).
844+
"""
845+
846+
THRESHOLD: int = 2000
847+
848+
849+
class SCOPeOverPartial2000(_SCOPeOverXPartial):
850+
"""
851+
A class for extracting data from a part of the SCOPe dataset (the subclasses of a given top class) with a threshold of 2000 for selecting classes.
852+
853+
Inherits from `_SCOPeOverXPartial` and sets the threshold for selecting classes to 2000.
854+
855+
Attributes:
856+
THRESHOLD (int): The threshold for selecting classes (2000).
857+
"""
858+
859+
THRESHOLD: int = 2000
860+
861+
748862
if __name__ == "__main__":
749863
scope = SCOPE(scope_version=2.08)
750864
g = scope._extract_class_hierarchy("d")

0 commit comments

Comments
 (0)