|
-from typing import List
+import os
+from collections import OrderedDict
+from typing import List, Literal, Optional
 
 import pandas as pd
+from jsonargparse import CLI
 
-# https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/utils.py#L18-L22
-NAMESPACES = {
-    "cc": "cellular_component",
-    "mf": "molecular_function",
-    "bp": "biological_process",
-}
-
-# https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/aminoacids.py#L11
-MAXLEN = 1000
-
-
-def load_data(data_dir):
-    test_df = pd.DataFrame(pd.read_pickle("test_data.pkl"))
-    train_df = pd.DataFrame(pd.read_pickle("train_data.pkl"))
-    validation_df = pd.DataFrame(pd.read_pickle("valid_data.pkl"))
-
-    required_columns = [
-        "proteins",
-        "accessions",
-        "sequences",
-        # https://github.com/bio-ontology-research-group/deepgo2/blob/main/gendata/uni2pandas.py#L45-L58
-        "exp_annotations",  # Directly associated GO ids
-        # https://github.com/bio-ontology-research-group/deepgo2/blob/main/gendata/uni2pandas.py#L60-L69
-        "prop_annotations",  # Transitively associated GO ids
-    ]
-
-    new_df = pd.concat(
-        [
-            train_df[required_columns],
-            validation_df[required_columns],
-            test_df[required_columns],
-        ],
-        ignore_index=True,
-    )
-    # Generate splits.csv file to store ids of each corresponding split
-    split_assignment_list: List[pd.DataFrame] = [
-        pd.DataFrame({"id": train_df["proteins"], "split": "train"}),
-        pd.DataFrame({"id": validation_df["proteins"], "split": "validation"}),
-        pd.DataFrame({"id": test_df["proteins"], "split": "test"}),
-    ]
+from chebai.preprocessing.datasets.go_uniprot import (
+    GOUniProtOver50,
+    GOUniProtOver250,
+    _GOUniProtDataExtractor,
+)
+
+
+class DeepGoDataMigration:
+    """
+    Handles data migration and processing for the DeepGO project: converts data
+    from the DeepGO-SE data structure to the structure we use for GO-UniProt data.
+
+    Attributes:
+        _CORRESPONDING_GO_CLASSES (dict): Mapping of GO branches to specific data extractor classes.
+        _MAXLEN (int): Maximum sequence length for sequences.
+        _LABELS_START_IDX (int): Starting index for labels in the dataset.
+
+    Methods:
+        __init__(data_dir, go_branch): Initializes the data directory and GO branch.
+        _load_data(): Loads train, validation, test, and terms data from the specified directory.
+        _record_splits(): Creates a DataFrame with IDs and their corresponding split.
+        migrate(): Executes the migration process, including data loading, processing, and saving.
+        _extract_required_data_from_splits(): Extracts required columns from the splits data.
+        _generate_labels(data_df): Generates label columns for the data based on GO terms.
+        extract_go_id(go_list): Extracts GO IDs from a list.
+        save_migrated_data(data_df, splits_df): Saves the processed data and splits.
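+
+    Example (hypothetical paths; assumes the DeepGO-SE pickles sit in "data/bp"):
+        DeepGoDataMigration("data", "bp").migrate()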
| 35 | + """ |
+
+    # GO-branch naming convention follows DeepGO:
+    # https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/utils.py#L18-L22
+    _CORRESPONDING_GO_CLASSES = {
+        "cc": GOUniProtOver50,
+        "mf": GOUniProtOver50,
+        "bp": GOUniProtOver250,
+    }
+
+    # https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/aminoacids.py#L11
+    _MAXLEN = 1000
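+    # Index of the first label column in the processed DataFrame, taken from
+    # the chebai extractor so the migrated data matches its expected layout.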
+    _LABELS_START_IDX = _GOUniProtDataExtractor._LABELS_START_IDX
+
+    def __init__(self, data_dir: str, go_branch: Literal["cc", "mf", "bp"]):
+        """
+        Initializes the data migration object with a data directory and GO branch.
+
+        Args:
+            data_dir (str): Directory containing the data files.
+            go_branch (Literal["cc", "mf", "bp"]): GO branch to use (cellular_component,
+                molecular_function, or biological_process).
+        """
+        valid_go_branches = list(self._CORRESPONDING_GO_CLASSES.keys())
+        if go_branch not in valid_go_branches:
+            raise ValueError(f"go_branch must be one of {valid_go_branches}")
+        self._go_branch = go_branch
+
+        self._data_dir = os.path.join(data_dir, go_branch)
+        self._train_df: Optional[pd.DataFrame] = None
+        self._test_df: Optional[pd.DataFrame] = None
+        self._validation_df: Optional[pd.DataFrame] = None
+        self._terms_df: Optional[pd.DataFrame] = None
+        self._classes: Optional[List[str]] = None
+
+    def _load_data(self) -> None:
+        """
+        Loads the test, train, validation, and terms data from the pickled files
+        in the data directory.
+        """
+        try:
+            print(f"Loading data from {self._data_dir}...")
+            self._test_df = pd.DataFrame(
+                pd.read_pickle(os.path.join(self._data_dir, "test_data.pkl"))
+            )
+            self._train_df = pd.DataFrame(
+                pd.read_pickle(os.path.join(self._data_dir, "train_data.pkl"))
+            )
+            self._validation_df = pd.DataFrame(
+                pd.read_pickle(os.path.join(self._data_dir, "valid_data.pkl"))
+            )
+            self._terms_df = pd.DataFrame(
+                pd.read_pickle(os.path.join(self._data_dir, "terms.pkl"))
+            )
+        except FileNotFoundError as e:
+            print(f"Error loading data: {e}")
+
+    def _record_splits(self) -> pd.DataFrame:
+        """
+        Creates a DataFrame that stores the IDs and their corresponding data splits.
+
+        Returns:
+            pd.DataFrame: A combined DataFrame containing split assignments.
+        """
+        print("Recording splits...")
+        split_assignment_list: List[pd.DataFrame] = [
+            pd.DataFrame({"id": self._train_df["proteins"], "split": "train"}),
+            pd.DataFrame(
+                {"id": self._validation_df["proteins"], "split": "validation"}
+            ),
+            pd.DataFrame({"id": self._test_df["proteins"], "split": "test"}),
+        ]
+
+        combined_split_assignment = pd.concat(split_assignment_list, ignore_index=True)
+        return combined_split_assignment
+
+    def migrate(self) -> None:
+        """
+        Executes the data migration by loading, processing, and saving the data.
+        """
+        print("Migration started...")
+        self._load_data()
+        # `all()` over a list of DataFrames raises "truth value is ambiguous",
+        # so check for missing frames explicitly.
+        if any(
+            df is None
+            for df in (self._train_df, self._validation_df, self._test_df, self._terms_df)
+        ):
+            raise Exception(
+                "Data splits or terms data is not available in instance variables."
+            )
+        splits_df = self._record_splits()
+
+        data_df = self._extract_required_data_from_splits()
+        data_with_labels_df = self._generate_labels(data_df)
+
+        if data_with_labels_df.empty or splits_df.empty or not self._classes:
+            raise Exception(
+                "Labelled data, splits, or classes are not available in instance variables."
+            )
+
+        # Save the labelled frame, not the intermediate unlabelled one.
+        self.save_migrated_data(data_with_labels_df, splits_df)
+
+    def _extract_required_data_from_splits(self) -> pd.DataFrame:
+        """
+        Extracts required columns from the combined data splits.
+
+        Returns:
+            pd.DataFrame: A DataFrame containing the essential columns for processing.
+        """
+        print("Combining the data splits with required data...")
+        required_columns = [
+            "proteins",
+            "accessions",
+            "sequences",
+            # https://github.com/bio-ontology-research-group/deepgo2/blob/main/gendata/uni2pandas.py#L45-L58
+            "exp_annotations",  # Directly associated GO ids
+            # https://github.com/bio-ontology-research-group/deepgo2/blob/main/gendata/uni2pandas.py#L60-L69
+            "prop_annotations",  # Transitively associated GO ids
+        ]
+
+        new_df = pd.concat(
+            [
+                self._train_df[required_columns],
+                self._validation_df[required_columns],
+                self._test_df[required_columns],
+            ],
+            ignore_index=True,
+        )
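+        # Merge directly and transitively associated GO ids into one list per protein.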
+        new_df["go_ids"] = new_df.apply(
+            lambda row: self.extract_go_id(row["exp_annotations"])
+            + self.extract_go_id(row["prop_annotations"]),
+            axis=1,
+        )
 
-    combined_split_assignment = pd.concat(split_assignment_list, ignore_index=True)
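+        # Rename columns to the layout expected downstream:
+        # swiss_id, accession, go_ids, sequence.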
+        data_df = pd.DataFrame(
+            OrderedDict(
+                swiss_id=new_df["proteins"],
+                accession=new_df["accessions"],
+                go_ids=new_df["go_ids"],
+                sequence=new_df["sequences"],
+            )
+        )
+        return data_df
 
+    def _generate_labels(self, data_df: pd.DataFrame) -> pd.DataFrame:
+        """
+        Generates label columns for each GO term in the dataset.
 
-def save_data(data_dir, data_df):
-    pass
+        Args:
+            data_df (pd.DataFrame): DataFrame containing data with GO IDs.
+
+        Returns:
+            pd.DataFrame: DataFrame with new label columns.
+        """
+        print("Generating labels based on terms.pkl file...")
+        # Each row of terms.pkl is assumed to hold a single GO id string (the
+        # DeepGO2 prediction targets), so parse the "gos" column entry-wise;
+        # the original column-wise apply never received rows.
+        parsed_go_ids: pd.Series = self._terms_df["gos"].apply(
+            _GOUniProtDataExtractor._parse_go_id
+        )
+        # Sort so that the label column order matches the sorted classes.txt
+        # written in save_migrated_data.
+        all_go_ids_list = sorted(parsed_go_ids.values.tolist())
+        self._classes = all_go_ids_list
+        new_label_columns = pd.DataFrame(
+            False, index=data_df.index, columns=all_go_ids_list
+        )
+        data_df = pd.concat([data_df, new_label_columns], axis=1)
+
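+        # Mark each GO id attached to a protein as True in its label column;
+        # ids absent from terms.pkl (hence no column) are skipped.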
+        for index, row in data_df.iterrows():
+            for go_id in row["go_ids"]:
+                if go_id in data_df.columns:
+                    data_df.at[index, go_id] = True
+
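+        # Keep only proteins with at least one positive label.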
+        data_df = data_df[data_df.iloc[:, self._LABELS_START_IDX :].any(axis=1)]
+        return data_df
+
+    @staticmethod
+    def extract_go_id(go_list: List[str]) -> List[str]:
+        """
+        Extracts and parses GO IDs from a list of GO annotations.
+
+        Args:
+            go_list (List[str]): List of GO annotation strings.
+
+        Returns:
+            List[str]: List of parsed GO IDs.
+        """
+        return [
+            _GOUniProtDataExtractor._parse_go_id(go_id_str) for go_id_str in go_list
+        ]
+
+    def save_migrated_data(
+        self, data_df: pd.DataFrame, splits_df: pd.DataFrame
+    ) -> None:
+        """
+        Saves the processed data and split information.
+
+        Args:
+            data_df (pd.DataFrame): Data with GO labels.
+            splits_df (pd.DataFrame): Split assignment DataFrame.
+        """
+        print("Saving transformed data...")
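+        # The chebai extractor instance determines where the migrated files
+        # are written (its processed_dir_main) and how they are serialized.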
+        go_class_instance: _GOUniProtDataExtractor = self._CORRESPONDING_GO_CLASSES[
+            self._go_branch
+        ](go_branch=self._go_branch, max_sequence_length=self._MAXLEN)
+
+        go_class_instance.save_processed(
+            data_df, go_class_instance.processed_file_names_dict["data"]
+        )
+        print(
+            f"{go_class_instance.processed_file_names_dict['data']} saved to {go_class_instance.processed_dir_main}"
+        )
+
+        splits_df.to_csv(
+            os.path.join(go_class_instance.processed_dir_main, "splits.csv"),
+            index=False,
+        )
+        print(f"splits.csv saved to {go_class_instance.processed_dir_main}")
+
+        classes = sorted(self._classes)
+        with open(
+            os.path.join(go_class_instance.processed_dir_main, "classes.txt"), "wt"
+        ) as fout:
+            fout.writelines(str(node) + "\n" for node in classes)
+        print(f"classes.txt saved to {go_class_instance.processed_dir_main}")
+        print("Migration completed!")
+
+
+class Main:
+    """
+    Main class to handle the migration process for DeepGoDataMigration.
+
+    Methods:
+        migrate(data_dir: str, go_branch: Literal["cc", "mf", "bp"]):
+            Initiates the migration process for the specified data directory and GO branch.
+    """
+
+    def migrate(self, data_dir: str, go_branch: Literal["cc", "mf", "bp"]) -> None:
+        """
+        Initiates the migration process by creating a DeepGoDataMigration instance
+        and invoking its migrate method.
+
+        Args:
+            data_dir (str): Directory containing the data files.
+            go_branch (Literal["cc", "mf", "bp"]): GO branch to use
+                ("cc" for cellular_component,
+                "mf" for molecular_function,
+                or "bp" for biological_process).
+        """
+        DeepGoDataMigration(data_dir, go_branch).migrate()
+
+
 
 
 if __name__ == "__main__":
-    pass
+    # Example: python script_name.py migrate --data_dir="data/deep_go_se_training_data" --go_branch="bp"
+    # --data_dir specifies the directory containing the data files.
+    # --go_branch specifies the GO branch (cc, mf, or bp) to use for the migration.
+    CLI(
+        Main,
+        description="DeepGoDataMigration CLI tool to handle migration of GO data for specified branches (cc, mf, bp).",
+    )