__all__ = ["SwissProteinPretrain"]

import os
from abc import ABC
from collections import OrderedDict
from typing import Any, Dict, Generator, List, Tuple

import networkx as nx
import pandas as pd
import torch
from Bio import SwissProt
from sklearn.model_selection import train_test_split

from chebai.preprocessing.datasets.base import _DynamicDataset
from chebai.preprocessing.datasets.go_uniprot import (
    AMBIGUOUS_AMINO_ACIDS,
    EXPERIMENTAL_EVIDENCE_CODES,
    GOUniProtOver250,
)
from chebai.preprocessing.reader import ProteinDataReader

| 22 | + |
| 23 | +class _ProteinPretrainingData(_DynamicDataset, ABC): |
| 24 | + """ |
| 25 | + Data module for pretraining protein sequences, specifically designed for Swiss-UniProt data. It includes methods for |
| 26 | + data preparation, loading, and dynamic splitting of protein sequences. |
| 27 | + The data is parsed and filtered to only select proteins with no associated `valid` Gene Ontology (GO) labels. |
| 28 | + A valid GO label is the one which has one of evidence codes defined in `EXPERIMENTAL_EVIDENCE_CODES`. |
| 29 | + """ |

    _ID_IDX: int = 0
    _DATA_REPRESENTATION_IDX: int = 1  # Index of `sequence` column

    def __init__(self, **kwargs):
        """
        Initializes the data module, using a `GOUniProtOver250` extractor object to locate the Swiss-Prot data.

        Args:
            **kwargs: Additional arguments for the superclass initialization. The key
                `max_sequence_length` (int, default 1002) is consumed here: proteins with
                longer sequences are excluded from the dataset.
        """
        self._go_uniprot_extractor = GOUniProtOver250()
        assert self._go_uniprot_extractor.go_branch == GOUniProtOver250._ALL_GO_BRANCHES

        self.max_sequence_length: int = int(kwargs.get("max_sequence_length", 1002))
        assert (
            self.max_sequence_length >= 1
        ), "Max sequence length should be greater than or equal to 1."

        super(_ProteinPretrainingData, self).__init__(**kwargs)

        if self.reader.n_gram is not None:
            assert self.max_sequence_length >= self.reader.n_gram, (
                f"max_sequence_length ({self.max_sequence_length}) must be greater than "
                f"or equal to n_gram ({self.reader.n_gram})."
            )
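        # Illustration (assuming ProteinDataReader tokenizes sequences into
        # overlapping n-grams, which is an assumption here): with n_gram=3, the
        # sequence "MKTAYI" (length 6) yields four 3-mers ("MKT", "KTA", "TAY",
        # "AYI"), while any sequence shorter than n_gram yields no tokens at
        # all, hence the check above.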

    # ------------------------------ Phase: Prepare data -----------------------------------
    def prepare_data(self, *args: Any, **kwargs: Any) -> None:
        """
        Prepares the data by downloading and parsing Swiss-Prot data if not already available. Saves the
        processed data for further use.

        Args:
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.
        """
        processed_name = self.processed_dir_main_file_names_dict["data"]
        if not os.path.isfile(os.path.join(self.processed_dir_main, processed_name)):
            print(f"Missing processed data file ({processed_name})")
            os.makedirs(self.processed_dir_main, exist_ok=True)
            self._download_required_data()
            protein_df = self._parse_protein_data_for_pretraining()
            self.save_processed(protein_df, processed_name)

    def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph:
        # Method not required, as Swiss-Prot contains no ontological data
        pass

    def _graph_to_raw_dataset(self, graph: nx.DiGraph) -> pd.DataFrame:
        # Method not required, as Swiss-Prot contains no ontological data
        pass

    def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> List:
        # Method not required, as Swiss-Prot contains no ontological data
        pass

    def _download_required_data(self) -> str:
        """
        Downloads the required Swiss-Prot data using the GOUniProt extractor class.

        Returns:
            str: Path to the downloaded data.
        """
        return self._go_uniprot_extractor._download_swiss_uni_prot_data()

    def _parse_protein_data_for_pretraining(self) -> pd.DataFrame:
        """
        Parses the Swiss-Prot data and returns a DataFrame containing the Swiss-Prot proteins that do not have
        any valid Gene Ontology (GO) label. A valid GO label is one with one of the following evidence codes:
        EXP, IDA, IPI, IMP, IGI, IEP, TAS, or IC.

        The DataFrame includes the following columns:
            - "swiss_id": The unique identifier for each Swiss-Prot record.
            - "sequence": The protein sequence.

        Note:
            Proteins with ambiguous amino acid codes (B, O, J, U, X, Z) in their sequence are ignored.

        Returns:
            pd.DataFrame: A DataFrame where each row corresponds to a Swiss-Prot record with no associated
            valid GO label.
        """
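        # For reference, Biopython's SwissProt records expose each GO annotation as a
        # cross-reference tuple whose fourth element carries the evidence code, e.g.
        # (hypothetical values):
        #     ("GO", "GO:0005524", "F:ATP binding", "IEA:UniProtKB-KW")
        # The loop below extracts the code via cross_ref[3].split(":")[0] ("IEA" here)
        # and compares it against EXPERIMENTAL_EVIDENCE_CODES.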
        print("Parsing Swiss-Prot raw data...")

        swiss_ids, sequences = [], []

        # Use a context manager so the raw-data file handle is closed after parsing
        with open(
            os.path.join(
                self._go_uniprot_extractor.raw_dir,
                self._go_uniprot_extractor.raw_file_names_dict["SwissUniProt"],
            )
        ) as swiss_file:
            for record in SwissProt.parse(swiss_file):
                if record.data_class != "Reviewed":
                    # Consider only manually annotated (reviewed) entries
                    continue

                if not record.sequence:
                    # Consider only proteins with a sequence representation
                    continue

                if len(record.sequence) > self.max_sequence_length:
                    # Consider only sequences no longer than the maximum length
                    continue

                if any(aa in AMBIGUOUS_AMINO_ACIDS for aa in record.sequence):
                    # Skip proteins with ambiguous amino acid codes
                    continue

                has_valid_associated_go_label = False
                for cross_ref in record.cross_references:
                    if cross_ref[0] == self._go_uniprot_extractor._GO_DATA_INIT:
                        if len(cross_ref) <= 3:
                            # No evidence code
                            continue

                        # https://github.com/bio-ontology-research-group/deepgo/blob/master/get_functions.py#L63-L66
                        evidence_code = cross_ref[3].split(":")[0]
                        if evidence_code in EXPERIMENTAL_EVIDENCE_CODES:
                            has_valid_associated_go_label = True
                            break

                if has_valid_associated_go_label:
                    # Skip proteins with at least one valid associated GO label
                    continue

                swiss_ids.append(record.entry_name)
                sequences.append(record.sequence)

        data_dict = OrderedDict(
            swiss_id=swiss_ids,  # "swiss_id" column at index 0
            sequence=sequences,  # "sequence" column at index 1
        )

        return pd.DataFrame(data_dict)

    # ------------------------------ Phase: Setup data -----------------------------------
    def _load_dict(self, input_file_path: str) -> Generator[Dict[str, Any], None, None]:
        """
        Loads data from a pickled file and yields individual dictionaries for each row.

        The pickled file is expected to contain rows with the following structure:
            - Data at row index `self._ID_IDX`: the Swiss-Prot ID of the protein
            - Data at row index `self._DATA_REPRESENTATION_IDX`: the sequence representation of the protein

        This method is used by `_load_data_from_file` to generate dictionaries that are then
        processed and converted into a list of dictionaries containing the features and labels.

        Args:
            input_file_path (str): The path to the pickled input file.

        Yields:
            Dict[str, Any]: A dictionary containing:
                - `features` (str): The sequence data from the file.
                - `ident` (Any): The identifier from row index `self._ID_IDX`.
                - `labels`: Set to None, since the pretraining data has no labels.
        """
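        # A yielded item might look like (hypothetical values):
        #     {"features": "MKTAYIAKQR", "ident": "AMPK1_HUMAN", "labels": None}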
        with open(input_file_path, "rb") as input_file:
            df = pd.read_pickle(input_file)
            for row in df.values:
                yield dict(
                    features=row[self._DATA_REPRESENTATION_IDX],
                    ident=row[self._ID_IDX],
                    labels=None,
                )

    # ------------------------------ Phase: Dynamic Splits -----------------------------------
    def _get_data_splits(self) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
        """
        Loads encoded data and generates training, validation, and test splits.

        This method attempts to load encoded data from a file named `data.pt`. It then splits this data into
        training, validation, and test sets.

        Raises:
            FileNotFoundError: If the `data.pt` file does not exist. Ensure that the `prepare_data` and/or
                `setup` methods have been called to generate the necessary dataset files.

        Returns:
            Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: A tuple containing three DataFrames:
                - Training set
                - Validation set
                - Test set
        """
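        # Note on split sizes: train_test_split is applied twice with the same
        # train_size, so the effective fractions are train_split**2 (train),
        # train_split * (1 - train_split) (validation), and 1 - train_split
        # (test); e.g. 72.25% / 12.75% / 15% if train_split were 0.85.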
        try:
            filename = self.processed_file_names_dict["data"]
            data = torch.load(
                os.path.join(self.processed_dir, filename), weights_only=False
            )
        except FileNotFoundError:
            raise FileNotFoundError(
                f"File {filename} does not exist. "
                f"Please call the 'prepare_data' and/or 'setup' methods to generate the dataset files."
            )

        df_data = pd.DataFrame(data)

        # First split: hold out the test set
        df_train_pool, df_test = train_test_split(
            df_data,
            train_size=self.train_split,
            random_state=self.dynamic_data_split_seed,
        )

        # Second split: separate training and validation sets from the remaining pool
        df_train, df_val = train_test_split(
            df_train_pool,
            train_size=self.train_split,
            random_state=self.dynamic_data_split_seed,
        )

        return df_train, df_val, df_test

    # ------------------------------ Phase: Raw Properties -----------------------------------
    @property
    def base_dir(self) -> str:
        """
        str: The base directory for pretraining data storage.
        """
        return os.path.join(self._go_uniprot_extractor.base_dir, "Pretraining")

    @property
    def raw_dir(self) -> str:
        """str: The directory where the raw data is stored."""
        return self._go_uniprot_extractor.raw_dir


class SwissProteinPretrain(_ProteinPretrainingData):
    """
    Data module for Swiss-Prot protein pretraining, inheriting from `_ProteinPretrainingData`.
    This class is specifically designed to handle data processing and loading for Swiss-Prot-based protein datasets.

    Attributes:
        READER (Type): The data reader class used to load and process protein pretraining data.
    """

    READER = ProteinDataReader

    @property
    def _name(self) -> str:
        """
        The name identifier for this data module.

        Returns:
            str: A string identifier of the form `Swiss_<max_sequence_length>`, e.g. `Swiss_1002`
                for the default maximum sequence length.
        """
        return f"Swiss_{self.max_sequence_length}"
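
# Minimal usage sketch (illustrative only; `setup` is inherited from the
# Lightning-style `_DynamicDataset` base class, and the exact call sequence
# shown here is an assumption, not a documented API):
#
#     data_module = SwissProteinPretrain(max_sequence_length=1002)
#     data_module.prepare_data()  # download + parse Swiss-Prot, write data.pkl
#     data_module.setup()         # encode sequences and build dynamic splits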