Skip to content

Commit ca5461f

Browse files
committed
deepgo se migration: add class to migrate
1 parent c6d60cd commit ca5461f

File tree

1 file changed

+297
-45
lines changed

1 file changed

+297
-45
lines changed
Lines changed: 297 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,54 +1,306 @@
1-
from typing import List
1+
import os
2+
from collections import OrderedDict
3+
from random import randint
4+
from typing import List, Literal
25

36
import pandas as pd
7+
from jsonargparse import CLI
48

5-
# https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/utils.py#L18-L22
6-
NAMESPACES = {
7-
"cc": "cellular_component",
8-
"mf": "molecular_function",
9-
"bp": "biological_process",
10-
}
11-
12-
# https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/aminoacids.py#L11
13-
MAXLEN = 1000
14-
15-
16-
def load_data(data_dir):
17-
test_df = pd.DataFrame(pd.read_pickle("test_data.pkl"))
18-
train_df = pd.DataFrame(pd.read_pickle("train_data.pkl"))
19-
validation_df = pd.DataFrame(pd.read_pickle("valid_data.pkl"))
20-
21-
required_columns = [
22-
"proteins",
23-
"accessions",
24-
"sequences",
25-
# https://github.com/bio-ontology-research-group/deepgo2/blob/main/gendata/uni2pandas.py#L45-L58
26-
"exp_annotations", # Directly associated GO ids
27-
# https://github.com/bio-ontology-research-group/deepgo2/blob/main/gendata/uni2pandas.py#L60-L69
28-
"prop_annotations", # Transitively associated GO ids
29-
]
30-
31-
new_df = pd.concat(
32-
[
33-
train_df[required_columns],
34-
validation_df[required_columns],
35-
test_df[required_columns],
36-
],
37-
ignore_index=True,
38-
)
39-
# Generate splits.csv file to store ids of each corresponding split
40-
split_assignment_list: List[pd.DataFrame] = [
41-
pd.DataFrame({"id": train_df["proteins"], "split": "train"}),
42-
pd.DataFrame({"id": validation_df["proteins"], "split": "validation"}),
43-
pd.DataFrame({"id": test_df["proteins"], "split": "test"}),
44-
]
9+
from chebai.preprocessing.datasets.go_uniprot import (
10+
GOUniProtOver50,
11+
GOUniProtOver250,
12+
_GOUniProtDataExtractor,
13+
)
14+
15+
16+
class DeepGoDataMigration:
    """
    Migrates DeepGO-SE training data into the layout used for GO-UniProt data.

    The migration reads the pickled train/validation/test splits and the terms
    file of one GO branch, merges the splits into a single table with one
    boolean label column per GO term, and writes the table, the split
    assignment (``splits.csv``) and the class list (``classes.txt``) into the
    processed directory of the matching GO-UniProt dataset class.

    Attributes:
        _CORRESPONDING_GO_CLASSES (dict): Mapping of GO branches to specific data extractor classes.
        _MAXLEN (int): Maximum sequence length for sequences.
        _LABELS_START_IDX (int): Starting index for labels in the dataset.

    Methods:
        __init__(data_dir, go_branch): Initializes the data directory and GO branch.
        _load_data(): Loads train, validation, test, and terms data from the specified directory.
        _record_splits(): Creates a DataFrame with IDs and their corresponding split.
        migrate(): Executes the migration process including data loading, processing, and saving.
        _extract_required_data_from_splits(): Extracts required columns from the splits data.
        _generate_labels(data_df): Generates label columns for the data based on GO terms.
        extract_go_id(go_list): Extracts GO IDs from a list.
        save_migrated_data(data_df, splits_df): Saves the processed data and splits.
    """

    # Link for the namespaces convention used for GO branch
    # https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/utils.py#L18-L22
    _CORRESPONDING_GO_CLASSES = {
        "cc": GOUniProtOver50,
        "mf": GOUniProtOver50,
        "bp": GOUniProtOver250,
    }

    # https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/aminoacids.py#L11
    _MAXLEN = 1000
    _LABELS_START_IDX = _GOUniProtDataExtractor._LABELS_START_IDX

    def __init__(self, data_dir: str, go_branch: Literal["cc", "mf", "bp"]):
        """
        Initializes the data migration object with a data directory and GO branch.

        Args:
            data_dir (str): Directory containing one sub-directory per GO branch.
            go_branch (Literal["cc", "mf", "bp"]): GO branch to use (cellular_component, molecular_function, or biological_process).

        Raises:
            ValueError: If ``go_branch`` is not a key of ``_CORRESPONDING_GO_CLASSES``.
        """
        valid_go_branches = list(self._CORRESPONDING_GO_CLASSES.keys())
        if go_branch not in valid_go_branches:
            raise ValueError(f"go_branch must be one of {valid_go_branches}")
        self._go_branch = go_branch

        # The pickles of a branch live in <data_dir>/<go_branch>/.
        self._data_dir = os.path.join(data_dir, go_branch)
        # All of the following stay None until _load_data() / _generate_labels() ran.
        self._train_df: pd.DataFrame = None
        self._test_df: pd.DataFrame = None
        self._validation_df: pd.DataFrame = None
        self._terms_df: pd.DataFrame = None
        self._classes: List[str] = None

    def _read_pickle_as_frame(self, file_name: str) -> pd.DataFrame:
        """Reads ``file_name`` from the branch data directory into a DataFrame."""
        # NOTE(review): unpickling executes arbitrary code; only point this
        # tool at data files obtained from a trusted source.
        return pd.DataFrame(pd.read_pickle(os.path.join(self._data_dir, file_name)))

    def _load_data(self) -> None:
        """
        Loads the test, train, validation, and terms data from the pickled files
        in the data directory.

        A missing file is reported on stdout and leaves the remaining frames
        as ``None``; ``migrate`` turns that state into an error.
        """
        try:
            print(f"Loading data from {self._data_dir}......")
            self._test_df = self._read_pickle_as_frame("test_data.pkl")
            self._train_df = self._read_pickle_as_frame("train_data.pkl")
            self._validation_df = self._read_pickle_as_frame("valid_data.pkl")
            self._terms_df = self._read_pickle_as_frame("terms.pkl")
        except FileNotFoundError as e:
            print(f"Error loading data: {e}")

    def _record_splits(self) -> pd.DataFrame:
        """
        Creates a DataFrame that stores the IDs and their corresponding data splits.

        Returns:
            pd.DataFrame: A combined DataFrame with the columns ``id`` and
            ``split`` (``train``, ``validation`` or ``test``).
        """
        print("Recording splits...")
        split_assignment_list: List[pd.DataFrame] = [
            pd.DataFrame({"id": self._train_df["proteins"], "split": "train"}),
            pd.DataFrame(
                {"id": self._validation_df["proteins"], "split": "validation"}
            ),
            pd.DataFrame({"id": self._test_df["proteins"], "split": "test"}),
        ]

        return pd.concat(split_assignment_list, ignore_index=True)

    def migrate(self) -> None:
        """
        Executes the data migration by loading, processing, and saving the data.

        Raises:
            RuntimeError: If an input file could not be loaded, or if the
                processing produced no labelled rows, splits or classes.
        """
        print("Migration started......")
        self._load_data()
        loaded_frames = (
            self._train_df,
            self._validation_df,
            self._test_df,
            self._terms_df,
        )
        # DataFrames have no unambiguous truth value, so compare against None
        # explicitly instead of relying on all()/bool().
        if any(frame is None for frame in loaded_frames):
            raise RuntimeError(
                "Data splits or terms data is not available in instance variables."
            )
        splits_df = self._record_splits()

        data_df = self._extract_required_data_from_splits()
        data_with_labels_df = self._generate_labels(data_df)

        if data_with_labels_df.empty or splits_df.empty or not self._classes:
            raise RuntimeError(
                "Migration produced no labelled data, split assignments or classes."
            )

        # Save the frame that carries the label columns, not the unlabelled one.
        self.save_migrated_data(data_with_labels_df, splits_df)

    def _extract_required_data_from_splits(self) -> pd.DataFrame:
        """
        Extracts required columns from the combined data splits.

        Returns:
            pd.DataFrame: A DataFrame with the columns ``swiss_id``,
            ``accession``, ``go_ids`` and ``sequence`` (in this order), where
            ``go_ids`` holds the parsed direct and transitive annotations.
        """
        print("Combining the data splits with required data..... ")
        required_columns = [
            "proteins",
            "accessions",
            "sequences",
            # https://github.com/bio-ontology-research-group/deepgo2/blob/main/gendata/uni2pandas.py#L45-L58
            "exp_annotations",  # Directly associated GO ids
            # https://github.com/bio-ontology-research-group/deepgo2/blob/main/gendata/uni2pandas.py#L60-L69
            "prop_annotations",  # Transitively associated GO ids
        ]

        new_df = pd.concat(
            [
                self._train_df[required_columns],
                self._validation_df[required_columns],
                self._test_df[required_columns],
            ],
            ignore_index=True,
        )
        # Build the combined id lists directly; a row-wise DataFrame.apply
        # returning lists is slower and changes shape on empty input.
        go_ids = pd.Series(
            [
                self.extract_go_id(exp_annotations)
                + self.extract_go_id(prop_annotations)
                for exp_annotations, prop_annotations in zip(
                    new_df["exp_annotations"], new_df["prop_annotations"]
                )
            ],
            index=new_df.index,
            dtype=object,
        )

        # Column order matters: the label columns are appended after these.
        data_df = pd.DataFrame(
            OrderedDict(
                swiss_id=new_df["proteins"],
                accession=new_df["accessions"],
                go_ids=go_ids,
                sequence=new_df["sequences"],
            )
        )
        return data_df

    def _generate_labels(self, data_df: pd.DataFrame) -> pd.DataFrame:
        """
        Generates label columns for each GO term in the dataset.

        The classes are taken from the ``gos`` column of the terms data and
        stored in ``self._classes``. Rows without any positive label are
        dropped from the result.

        Args:
            data_df (pd.DataFrame): DataFrame containing data with GO IDs.

        Returns:
            pd.DataFrame: Copy of ``data_df`` with one boolean column per class.
        """
        print("Generating labels based on terms.pkl file.......")
        # Kept sorted so the label column order agrees with classes.txt, which
        # save_migrated_data writes in sorted order.
        classes = sorted(self.extract_go_id(self._terms_df["gos"].tolist()))
        self._classes = classes

        label_columns = pd.DataFrame(False, index=data_df.index, columns=classes)
        known_classes = set(classes)
        # Iterate only over the go_ids column; iterrows() on the wide label
        # frame would materialise one Series per protein.
        for index, go_ids in data_df["go_ids"].items():
            for go_id in go_ids:
                if go_id in known_classes:
                    label_columns.at[index, go_id] = True

        labelled_df = pd.concat([data_df, label_columns], axis=1)
        return labelled_df[label_columns.any(axis=1)]

    @staticmethod
    def extract_go_id(go_list: List[str]) -> List[str]:
        """
        Extracts and parses GO IDs from a list of GO annotations.

        Args:
            go_list (List[str]): List of GO annotation strings.

        Returns:
            List[str]: One ``_GOUniProtDataExtractor._parse_go_id`` result per
            input element, in input order.
        """
        return [
            _GOUniProtDataExtractor._parse_go_id(go_id_str) for go_id_str in go_list
        ]

    def save_migrated_data(
        self, data_df: pd.DataFrame, splits_df: pd.DataFrame
    ) -> None:
        """
        Saves the processed data and split information.

        Writes the data file, ``splits.csv`` and ``classes.txt`` into the
        ``processed_dir_main`` of the dataset class mapped to the GO branch.

        Args:
            data_df (pd.DataFrame): Data with GO labels.
            splits_df (pd.DataFrame): Split assignment DataFrame.
        """
        print("Saving transformed data......")
        # NOTE(review): the branch is passed in lower case ("cc"/"mf"/"bp");
        # confirm that the dataset classes accept this spelling.
        go_class_instance: _GOUniProtDataExtractor = self._CORRESPONDING_GO_CLASSES[
            self._go_branch
        ](go_branch=self._go_branch, max_sequence_length=self._MAXLEN)
        processed_dir = go_class_instance.processed_dir_main
        data_file_name = go_class_instance.processed_file_names_dict["data"]

        # Make sure the target directory exists before anything is written.
        os.makedirs(processed_dir, exist_ok=True)

        go_class_instance.save_processed(data_df, data_file_name)
        print(f"{data_file_name} saved to {processed_dir}")

        splits_df.to_csv(os.path.join(processed_dir, "splits.csv"), index=False)
        print(f"splits.csv saved to {processed_dir}")

        classes = sorted(self._classes)
        with open(
            os.path.join(processed_dir, "classes.txt"), "wt", encoding="utf-8"
        ) as fout:
            fout.writelines(str(node) + "\n" for node in classes)
        print(f"classes.txt saved to {processed_dir}")
        print("Migration completed!")
255+
256+
257+
class Main:
    """
    Main class to handle the migration process for DeepGoDataMigration.

    Methods:
        migrate(data_dir: str, go_branch: Literal["cc", "mf", "bp"]):
            Initiates the migration process for the specified data directory and GO branch.
    """

    def migrate(self, data_dir: str, go_branch: Literal["cc", "mf", "bp"]) -> None:
        """
        Initiates the migration process by creating a DeepGoDataMigration instance
        and invoking its migrate method.

        Args:
            data_dir (str): Directory containing the data files.
            go_branch (Literal["cc", "mf", "bp"]): GO branch to use
                ("cc" for cellular_component,
                "mf" for molecular_function,
                or "bp" for biological_process).
        """
        # The annotation mirrors DeepGoDataMigration.__init__ so that the
        # allowed branch values are visible in the signature itself.
        DeepGoDataMigration(data_dir, go_branch).migrate()
279+
280+
281+
class Main1:
    """Draws a random prize of at most ``max_prize`` for a named person."""

    def __init__(self, max_prize: int = 100):
        """
        Args:
            max_prize: Maximum prize that can be awarded.
        """
        self.max_prize = max_prize

    def person(self, name: str, additional_prize: int = 0):
        """
        Args:
            name: Name of the winner.
            additional_prize: Additional prize that can be added to the prize amount.
        """
        # Draw the base amount first, then top it up with the fixed bonus.
        drawn_amount = randint(0, self.max_prize)
        total_amount = drawn_amount + additional_prize
        return "{} won {}€!".format(name, total_amount)
51297

52298

53299
if __name__ == "__main__":
    # Example: python script_name.py migrate --data_dir="data/deep_go_se_training_data" --go_branch="bp"
    # --data_dir specifies the directory containing the data files.
    # --go_branch specifies the GO branch (cc, mf, or bp) you want to use for the migration.
    CLI(
        # Expose the migration entry point; only Main provides the `migrate`
        # sub-command used in the example above.
        Main,
        description="DeepGoDataMigration CLI tool to handle migration of GO data for specified branches (cc, mf, bp).",
        # Required parameters become --options instead of positionals, which
        # matches the documented --data_dir / --go_branch usage.
        as_positional=False,
    )

0 commit comments

Comments
 (0)