Skip to content

Commit 99b5af1

Browse files
committed
add migration for deepgo1 - 2018 paper
1 parent 3e0bae0 commit 99b5af1

File tree

3 files changed

+318
-2
lines changed

3 files changed

+318
-2
lines changed

chebai/preprocessing/migration/deep_go/__init__.py

Whitespace-only changes.
Lines changed: 310 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,310 @@
1+
import os
2+
from collections import OrderedDict
3+
from typing import List, Literal, Optional
4+
5+
import pandas as pd
6+
from jsonargparse import CLI
7+
8+
from chebai.preprocessing.datasets.go_uniprot import (
9+
GOUniProtOver50,
10+
GOUniProtOver250,
11+
_GOUniProtDataExtractor,
12+
)
13+
14+
15+
class DeepGo1DataMigration:
    """
    A class to handle data migration and processing for the DeepGO project.
    It migrates the deepGO data to our data structure followed for GO-UniProt data.

    It migrates the data of DeepGO model of the below research paper:
    Maxat Kulmanov, Mohammed Asif Khan, Robert Hoehndorf,
    DeepGO: predicting protein functions from sequence and interactions using a deep ontology-aware classifier,
    Bioinformatics, Volume 34, Issue 4, February 2018, Pages 660-668
    (https://doi.org/10.1093/bioinformatics/btx624),

    NOTE(review): only train and test pickles are migrated here; the validation
    split was deliberately left out (previously commented-out code) - presumably
    the DeepGO-2018 release has no usable valid-{branch}.pkl. Confirm before
    re-enabling a validation split.

    Attributes:
        _CORRESPONDING_GO_CLASSES (dict): Mapping of GO branches to specific data extractor classes.
        _MAXLEN (int): Maximum sequence length for sequences.
        _LABELS_START_IDX (int): Starting index for labels in the dataset.

    Methods:
        __init__(data_dir, go_branch): Initializes the data directory and GO branch.
        _load_data(): Loads train, test, and terms data from the specified directory.
        _record_splits(): Creates a DataFrame with IDs and their corresponding split.
        migrate(): Executes the migration process including data loading, processing, and saving.
        _extract_required_data_from_splits(): Extracts required columns from the splits data.
        _get_labels_columns(data_df): Generates label columns for the data based on GO terms.
        extract_go_id(go_list): Extracts GO IDs from a list.
        save_migrated_data(data_df, splits_df): Saves the processed data and splits.
    """

    # Number of annotations for each go_branch as per the research paper
    # (cc/mf use the >=50-annotations extractor, bp the >=250 one).
    _CORRESPONDING_GO_CLASSES = {
        "cc": GOUniProtOver50,
        "mf": GOUniProtOver50,
        "bp": GOUniProtOver250,
    }

    # Maximum protein sequence length used by the DeepGO (2018) model.
    _MAXLEN = 1002
    _LABELS_START_IDX = _GOUniProtDataExtractor._LABELS_START_IDX

    def __init__(self, data_dir: str, go_branch: Literal["cc", "mf", "bp"]):
        """
        Initializes the data migration object with a data directory and GO branch.

        Args:
            data_dir (str): Directory containing the data files.
            go_branch (Literal["cc", "mf", "bp"]): GO branch to use
                (cellular_component, molecular_function, or biological_process).

        Raises:
            ValueError: If `go_branch` is not one of "cc", "mf" or "bp".
        """
        valid_go_branches = list(self._CORRESPONDING_GO_CLASSES.keys())
        if go_branch not in valid_go_branches:
            raise ValueError(f"go_branch must be one of {valid_go_branches}")
        self._go_branch = go_branch

        self._data_dir: str = data_dir
        self._train_df: Optional[pd.DataFrame] = None
        self._test_df: Optional[pd.DataFrame] = None
        self._terms_df: Optional[pd.DataFrame] = None
        self._classes: Optional[List[str]] = None

    def _load_data(self) -> None:
        """
        Loads the test, train, and terms data from the pickled files
        in the data directory.

        On a missing file the error is printed and the corresponding instance
        variable stays None; `migrate` validates this afterwards.
        """
        try:
            print(f"Loading data from {self._data_dir}......")
            self._test_df = pd.DataFrame(
                pd.read_pickle(
                    os.path.join(self._data_dir, f"test-{self._go_branch}.pkl")
                )
            )
            self._train_df = pd.DataFrame(
                pd.read_pickle(
                    os.path.join(self._data_dir, f"train-{self._go_branch}.pkl")
                )
            )
            self._terms_df = pd.DataFrame(
                pd.read_pickle(os.path.join(self._data_dir, f"{self._go_branch}.pkl"))
            )

        except FileNotFoundError as e:
            print(f"Error loading data: {e}")

    def _record_splits(self) -> pd.DataFrame:
        """
        Creates a DataFrame that stores the IDs and their corresponding data splits.

        Returns:
            pd.DataFrame: A combined DataFrame containing split assignments.
        """
        print("Recording splits...")
        split_assignment_list: List[pd.DataFrame] = [
            pd.DataFrame({"id": self._train_df["proteins"], "split": "train"}),
            pd.DataFrame({"id": self._test_df["proteins"], "split": "test"}),
        ]

        combined_split_assignment = pd.concat(split_assignment_list, ignore_index=True)
        return combined_split_assignment

    def migrate(self) -> None:
        """
        Executes the data migration by loading, processing, and saving the data.

        Raises:
            Exception: If the input pickles could not be loaded, or if the
                processed outputs are unexpectedly missing.
        """
        print("Migration started......")
        self._load_data()
        if not all(
            df is not None
            for df in [
                self._train_df,
                self._test_df,
                self._terms_df,
            ]
        ):
            raise Exception(
                "Data splits or terms data is not available in instance variables."
            )
        splits_df = self._record_splits()

        data_with_labels_df = self._extract_required_data_from_splits()

        if not all(
            var is not None for var in [data_with_labels_df, splits_df, self._classes]
        ):
            # Message fixed: this check guards the processed outputs, not the
            # raw input data validated above.
            raise Exception(
                "Processed data, splits, or classes are not available for saving."
            )

        self.save_migrated_data(data_with_labels_df, splits_df)

    def _extract_required_data_from_splits(self) -> pd.DataFrame:
        """
        Extracts required columns from the combined data splits.

        Returns:
            pd.DataFrame: A DataFrame containing the essential columns for processing.
        """
        print("Combining the data splits with required data..... ")
        required_columns = [
            "proteins",
            "accessions",
            "sequences",
            # Note: The GO classes here are only the directly related ones,
            # not the transitive GO classes.
            "gos",
            "labels",
        ]

        new_df = pd.concat(
            [
                self._train_df[required_columns],
                self._test_df[required_columns],
            ],
            ignore_index=True,
        )
        # Series.apply on the single relevant column instead of a row-wise
        # DataFrame.apply(axis=1) - same result, less overhead.
        new_df["go_ids"] = new_df["gos"].apply(self.extract_go_id)

        labels_df = self._get_labels_columns(new_df)

        data_df = pd.DataFrame(
            OrderedDict(
                swiss_id=new_df["proteins"],
                accession=new_df["accessions"],
                go_ids=new_df["go_ids"],
                sequence=new_df["sequences"],
            )
        )

        df = pd.concat([data_df, labels_df], axis=1)

        return df

    def _get_labels_columns(self, data_df: pd.DataFrame) -> pd.DataFrame:
        """
        Generates a DataFrame with one-hot encoded columns for each GO term label,
        based on the terms provided in `self._terms_df` and the existing labels in `data_df`.

        This method extracts GO IDs from the `functions` column of `self._terms_df`,
        creating a list of all unique GO IDs. It then uses this list to create new
        columns in the returned DataFrame, where each row has binary values
        (0 or 1) indicating the presence of each GO ID in the corresponding entry of
        `data_df['labels']`.

        Args:
            data_df (pd.DataFrame): DataFrame containing data with a 'labels' column,
                which holds lists of GO ID labels for each row.

        Returns:
            pd.DataFrame: A DataFrame with the same index as `data_df` and one column
                per GO ID, containing binary values indicating label presence.
        """
        print("Generating labels based on terms.pkl file.......")
        parsed_go_ids: pd.Series = self._terms_df["functions"].apply(
            _GOUniProtDataExtractor._parse_go_id
        )
        all_go_ids_list = parsed_go_ids.values.tolist()
        self._classes = all_go_ids_list

        new_label_columns = pd.DataFrame(
            data_df["labels"].tolist(), index=data_df.index, columns=all_go_ids_list
        )

        return new_label_columns

    @staticmethod
    def extract_go_id(go_list: List[str]) -> List[int]:
        """
        Extracts and parses GO IDs from a list of GO annotations.

        Args:
            go_list (List[str]): List of GO annotation strings.

        Returns:
            List[int]: List of parsed GO IDs.
        """
        return [
            _GOUniProtDataExtractor._parse_go_id(go_id_str) for go_id_str in go_list
        ]

    def save_migrated_data(
        self, data_df: pd.DataFrame, splits_df: pd.DataFrame
    ) -> None:
        """
        Saves the processed data and split information.

        Args:
            data_df (pd.DataFrame): Data with GO labels.
            splits_df (pd.DataFrame): Split assignment DataFrame.
        """
        print("Saving transformed data......")
        go_class_instance: _GOUniProtDataExtractor = self._CORRESPONDING_GO_CLASSES[
            self._go_branch
        ](go_branch=self._go_branch.upper(), max_sequence_length=self._MAXLEN)

        go_class_instance.save_processed(
            data_df, go_class_instance.processed_main_file_names_dict["data"]
        )
        print(
            f"{go_class_instance.processed_main_file_names_dict['data']} saved to {go_class_instance.processed_dir_main}"
        )

        splits_df.to_csv(
            os.path.join(go_class_instance.processed_dir_main, "splits.csv"),
            index=False,
        )
        print(f"splits.csv saved to {go_class_instance.processed_dir_main}")

        classes = sorted(self._classes)
        with open(
            os.path.join(go_class_instance.processed_dir_main, "classes.txt"), "wt"
        ) as fout:
            fout.writelines(str(node) + "\n" for node in classes)
        print(f"classes.txt saved to {go_class_instance.processed_dir_main}")
        print("Migration completed!")
275+
276+
277+
class Main:
    """
    Main class to handle the migration process for DeepGo1DataMigration.

    Methods:
        migrate(data_dir: str, go_branch: Literal["cc", "mf", "bp"]):
            Initiates the migration process for the specified data directory and GO branch.
    """

    @staticmethod
    def migrate(data_dir: str, go_branch: Literal["cc", "mf", "bp"]) -> None:
        """
        Initiates the migration process by creating a DeepGo1DataMigration instance
        and invoking its migrate method.

        Args:
            data_dir (str): Directory containing the data files.
            go_branch (Literal["cc", "mf", "bp"]): GO branch to use
                ("cc" for cellular_component,
                "mf" for molecular_function,
                or "bp" for biological_process).
        """
        DeepGo1DataMigration(data_dir, go_branch).migrate()
300+
301+
302+
if __name__ == "__main__":
    # Usage example:
    #   python script_name.py migrate --data_dir="data/deep_go1" --go_branch="mf"
    # --data_dir: directory holding the DeepGO pickle files.
    # --go_branch: which GO branch (cc, mf, or bp) to migrate.
    CLI(
        Main,
        as_positional=False,
        description="DeepGo1DataMigration CLI tool to handle migration of GO data for specified branches (cc, mf, bp).",
    )

chebai/preprocessing/migration/deep_go_data_mirgration.py renamed to chebai/preprocessing/migration/deep_go/migrate_deep_go_2_data.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,17 @@
1212
)
1313

1414

15-
class DeepGoDataMigration:
15+
class DeepGo2DataMigration:
1616
"""
1717
A class to handle data migration and processing for the DeepGO project. It migrates the data from the DeepGO-SE
1818
data structure to our data structure followed for GO-UniProt data.
1919
20+
It migrates the data of DeepGO model of the below research paper:
21+
Maxat Kulmanov, Mohammed Asif Khan, Robert Hoehndorf,
22+
DeepGO: predicting protein functions from sequence and interactions using a deep ontology-aware classifier,
23+
Bioinformatics, Volume 34, Issue 4, February 2018, Pages 660–668
24+
(https://doi.org/10.1093/bioinformatics/btx624),
25+
2026
Attributes:
2127
_CORRESPONDING_GO_CLASSES (dict): Mapping of GO branches to specific data extractor classes.
2228
_MAXLEN (int): Maximum sequence length for sequences.
@@ -283,7 +289,7 @@ def migrate(data_dir: str, go_branch: Literal["cc", "mf", "bp"]) -> None:
283289
"mf" for molecular_function,
284290
or "bp" for biological_process).
285291
"""
286-
DeepGoDataMigration(data_dir, go_branch).migrate()
292+
DeepGo2DataMigration(data_dir, go_branch).migrate()
287293

288294

289295
if __name__ == "__main__":

0 commit comments

Comments
 (0)