|
| 1 | +import os |
| 2 | +from collections import OrderedDict |
| 3 | +from typing import List, Literal, Optional |
| 4 | + |
| 5 | +import pandas as pd |
| 6 | +from jsonargparse import CLI |
| 7 | + |
| 8 | +from chebai.preprocessing.datasets.go_uniprot import ( |
| 9 | + GOUniProtOver50, |
| 10 | + GOUniProtOver250, |
| 11 | + _GOUniProtDataExtractor, |
| 12 | +) |
| 13 | + |
| 14 | + |
class DeepGo1DataMigration:
    """
    Handles data migration and processing for the DeepGO (version 1) project.

    It migrates the DeepGO data to the data structure followed for our
    GO-UniProt data.

    The data being migrated belongs to the DeepGO model of the research paper:
        Maxat Kulmanov, Mohammed Asif Khan, Robert Hoehndorf,
        DeepGO: predicting protein functions from sequence and interactions
        using a deep ontology-aware classifier, Bioinformatics, Volume 34,
        Issue 4, February 2018, Pages 660-668
        (https://doi.org/10.1093/bioinformatics/btx624)

    Attributes:
        _CORRESPONDING_GO_CLASSES (dict): Maps each GO branch to the extractor
            class whose annotation threshold matches the research paper
            (50 for "cc"/"mf", 250 for "bp").
        _MAXLEN (int): Maximum sequence length for protein sequences.
        _LABELS_START_IDX (int): Starting index for label columns in the dataset.

    Methods:
        __init__(data_dir, go_branch): Initializes the data directory and GO branch.
        _load_data(): Loads train, test, and terms data from the specified directory.
        _record_splits(): Creates a DataFrame with IDs and their corresponding split.
        migrate(): Executes the migration process including data loading, processing, and saving.
        _extract_required_data_from_splits(): Extracts required columns from the splits data.
        _get_labels_columns(data_df): Generates label columns for the data based on GO terms.
        extract_go_id(go_list): Extracts GO IDs from a list.
        save_migrated_data(data_df, splits_df): Saves the processed data and splits.
    """

    # Extractor class per GO branch; the Over50/Over250 threshold mirrors the
    # minimum number of annotations per class used in the research paper.
    _CORRESPONDING_GO_CLASSES = {
        "cc": GOUniProtOver50,
        "mf": GOUniProtOver50,
        "bp": GOUniProtOver250,
    }

    _MAXLEN = 1002
    _LABELS_START_IDX = _GOUniProtDataExtractor._LABELS_START_IDX

    def __init__(self, data_dir: str, go_branch: Literal["cc", "mf", "bp"]):
        """
        Initializes the data migration object with a data directory and GO branch.

        Args:
            data_dir (str): Directory containing the DeepGO data files.
            go_branch (Literal["cc", "mf", "bp"]): GO branch to use
                (cellular_component, molecular_function, or biological_process).

        Raises:
            ValueError: If ``go_branch`` is not one of "cc", "mf", or "bp".
        """
        valid_go_branches = list(self._CORRESPONDING_GO_CLASSES.keys())
        if go_branch not in valid_go_branches:
            raise ValueError(f"go_branch must be one of {valid_go_branches}")
        self._go_branch = go_branch

        self._data_dir: str = rf"{data_dir}"
        self._train_df: Optional[pd.DataFrame] = None
        self._test_df: Optional[pd.DataFrame] = None
        self._validation_df: Optional[pd.DataFrame] = None
        self._terms_df: Optional[pd.DataFrame] = None
        self._classes: Optional[List[str]] = None

    def _load_data(self) -> None:
        """
        Loads the test, train, and terms data from the pickled files in the
        data directory.

        NOTE(review): loading of a validation split was disabled upstream —
        only the train and test splits are migrated for DeepGO-1 data.

        A missing file is reported but deliberately not re-raised here;
        ``migrate`` detects the still-``None`` instance variables afterwards
        and raises with context.
        """
        try:
            print(f"Loading data from {self._data_dir}......")
            self._test_df = pd.DataFrame(
                pd.read_pickle(
                    os.path.join(self._data_dir, f"test-{self._go_branch}.pkl")
                )
            )
            self._train_df = pd.DataFrame(
                pd.read_pickle(
                    os.path.join(self._data_dir, f"train-{self._go_branch}.pkl")
                )
            )
            self._terms_df = pd.DataFrame(
                pd.read_pickle(os.path.join(self._data_dir, f"{self._go_branch}.pkl"))
            )

        except FileNotFoundError as e:
            print(f"Error loading data: {e}")

    def _record_splits(self) -> pd.DataFrame:
        """
        Creates a DataFrame that stores the protein IDs and their
        corresponding data splits.

        Returns:
            pd.DataFrame: A combined DataFrame with columns "id" and "split".
        """
        print("Recording splits...")
        split_assignment_list: List[pd.DataFrame] = [
            pd.DataFrame({"id": self._train_df["proteins"], "split": "train"}),
            pd.DataFrame({"id": self._test_df["proteins"], "split": "test"}),
        ]

        combined_split_assignment = pd.concat(split_assignment_list, ignore_index=True)
        return combined_split_assignment

    def migrate(self) -> None:
        """
        Executes the data migration by loading, processing, and saving the data.

        Raises:
            RuntimeError: If the input data could not be loaded, or if the
                processed data is incomplete before saving.
        """
        print("Migration started......")
        self._load_data()
        if any(
            df is None for df in (self._train_df, self._test_df, self._terms_df)
        ):
            raise RuntimeError(
                "Data splits or terms data is not available in instance variables."
            )
        splits_df = self._record_splits()

        data_with_labels_df = self._extract_required_data_from_splits()

        if any(
            var is None for var in (data_with_labels_df, splits_df, self._classes)
        ):
            raise RuntimeError(
                "Processed data, split assignments or class list is missing; cannot save."
            )

        self.save_migrated_data(data_with_labels_df, splits_df)

    def _extract_required_data_from_splits(self) -> pd.DataFrame:
        """
        Extracts required columns from the combined data splits and appends
        the one-hot encoded label columns.

        Returns:
            pd.DataFrame: A DataFrame with the essential columns
                (swiss_id, accession, go_ids, sequence) followed by one binary
                column per GO class.
        """
        print("Combining the data splits with required data..... ")
        required_columns = [
            "proteins",
            "accessions",
            "sequences",
            # Note: "gos" holds only the directly annotated GO classes, not
            # the transitive GO classes.
            "gos",
            "labels",
        ]

        new_df = pd.concat(
            [
                self._train_df[required_columns],
                self._test_df[required_columns],
            ],
            ignore_index=True,
        )
        # Column-wise Series.apply instead of a row-wise DataFrame.apply with
        # axis=1 — identical result, avoids materializing whole rows.
        new_df["go_ids"] = new_df["gos"].apply(self.extract_go_id)

        labels_df = self._get_labels_columns(new_df)

        data_df = pd.DataFrame(
            OrderedDict(
                swiss_id=new_df["proteins"],
                accession=new_df["accessions"],
                go_ids=new_df["go_ids"],
                sequence=new_df["sequences"],
            )
        )

        df = pd.concat([data_df, labels_df], axis=1)

        return df

    def _get_labels_columns(self, data_df: pd.DataFrame) -> pd.DataFrame:
        """
        Generates a DataFrame with one-hot encoded columns for each GO term label,
        based on the terms provided in `self._terms_df` and the existing labels in `data_df`.

        This method extracts GO IDs from the `functions` column of `self._terms_df`,
        creating a list of all unique GO IDs. It then uses this list to create new
        columns in the returned DataFrame, where each row has binary values
        (0 or 1) indicating the presence of each GO ID in the corresponding entry of
        `data_df['labels']`.

        As a side effect, stores the parsed GO ID list in ``self._classes``.

        Args:
            data_df (pd.DataFrame): DataFrame containing data with a 'labels' column,
                which holds lists of binary label indicators for each row.

        Returns:
            pd.DataFrame: A DataFrame with the same index as `data_df` and one column
                per GO ID, containing binary values indicating label presence.
        """
        print("Generating labels based on terms.pkl file.......")
        parsed_go_ids: pd.Series = self._terms_df["functions"].apply(
            _GOUniProtDataExtractor._parse_go_id
        )
        all_go_ids_list = parsed_go_ids.values.tolist()
        self._classes = all_go_ids_list

        new_label_columns = pd.DataFrame(
            data_df["labels"].tolist(), index=data_df.index, columns=all_go_ids_list
        )

        return new_label_columns

    @staticmethod
    def extract_go_id(go_list: List[str]) -> List[int]:
        """
        Extracts and parses GO IDs from a list of GO annotation strings.

        Args:
            go_list (List[str]): List of GO annotation strings.

        Returns:
            List[int]: List of parsed numeric GO IDs.
        """
        return [
            _GOUniProtDataExtractor._parse_go_id(go_id_str) for go_id_str in go_list
        ]

    def save_migrated_data(
        self, data_df: pd.DataFrame, splits_df: pd.DataFrame
    ) -> None:
        """
        Saves the processed data, split assignments, and class list to the
        processed directory of the corresponding GO-UniProt extractor.

        Args:
            data_df (pd.DataFrame): Data with GO labels.
            splits_df (pd.DataFrame): Split assignment DataFrame.
        """
        print("Saving transformed data......")
        go_class_instance: _GOUniProtDataExtractor = self._CORRESPONDING_GO_CLASSES[
            self._go_branch
        ](go_branch=self._go_branch.upper(), max_sequence_length=self._MAXLEN)

        go_class_instance.save_processed(
            data_df, go_class_instance.processed_main_file_names_dict["data"]
        )
        print(
            f"{go_class_instance.processed_main_file_names_dict['data']} saved to {go_class_instance.processed_dir_main}"
        )

        splits_df.to_csv(
            os.path.join(go_class_instance.processed_dir_main, "splits.csv"),
            index=False,
        )
        print(f"splits.csv saved to {go_class_instance.processed_dir_main}")

        classes = sorted(self._classes)
        with open(
            os.path.join(go_class_instance.processed_dir_main, "classes.txt"), "wt"
        ) as fout:
            fout.writelines(str(node) + "\n" for node in classes)
        print(f"classes.txt saved to {go_class_instance.processed_dir_main}")
        print("Migration completed!")
| 275 | + |
| 276 | + |
class Main:
    """
    Entry-point wrapper exposing the DeepGO-1 migration as a CLI command.

    Methods:
        migrate(data_dir: str, go_branch: Literal["cc", "mf", "bp"]):
            Runs the migration for the given data directory and GO branch.
    """

    @staticmethod
    def migrate(data_dir: str, go_branch: Literal["cc", "mf", "bp"]) -> None:
        """
        Runs the migration by constructing a DeepGo1DataMigration instance
        and invoking its ``migrate`` method.

        Args:
            data_dir (str): Directory containing the data files.
            go_branch (Literal["cc", "mf", "bp"]): GO branch to use
                ("cc" for cellular_component,
                 "mf" for molecular_function,
                 or "bp" for biological_process).
        """
        migration = DeepGo1DataMigration(data_dir, go_branch)
        migration.migrate()
| 300 | + |
| 301 | + |
if __name__ == "__main__":
    # Example invocation:
    #   python script_name.py migrate --data_dir="data/deep_go1" --go_branch="mf"
    # where --data_dir is the directory holding the DeepGO pickle files and
    # --go_branch selects the GO branch (cc, mf, or bp) to migrate.
    CLI(
        Main,
        description="DeepGo1DataMigration CLI tool to handle migration of GO data for specified branches (cc, mf, bp).",
        as_positional=False,
    )
0 commit comments