Skip to content

Commit e0a8524

Browse files
committed
deepgo1: further split train set into train and val; migration structure changes
1 parent a15d492 commit e0a8524

File tree

1 file changed

+118
-123
lines changed

1 file changed

+118
-123
lines changed

chebai/preprocessing/migration/deep_go/migrate_deep_go_1_data.py

Lines changed: 118 additions & 123 deletions
Original file line numberDiff line numberDiff line change
@@ -1,63 +1,39 @@
11
import os
22
from collections import OrderedDict
3-
from typing import List, Literal, Optional
3+
from typing import List, Literal, Optional, Tuple
44

55
import pandas as pd
6+
from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit
67
from jsonargparse import CLI
78

8-
from chebai.preprocessing.datasets.go_uniprot import (
9-
GOUniProtOver50,
10-
GOUniProtOver250,
11-
_GOUniProtDataExtractor,
12-
)
9+
from chebai.preprocessing.datasets.go_uniprot import DeepGO1MigratedData
1310

1411

1512
class DeepGo1DataMigration:
1613
"""
1714
A class to handle data migration and processing for the DeepGO project.
18-
It migrates the deepGO data to our data structure followed for GO-UniProt data.
15+
It migrates the DeepGO data to our data structure followed for GO-UniProt data.
1916
20-
It migrates the data of DeepGO model of the below research paper:
17+
This class handles data from the DeepGO model as described in:
2118
Maxat Kulmanov, Mohammed Asif Khan, Robert Hoehndorf,
2219
DeepGO: predicting protein functions from sequence and interactions using a deep ontology-aware classifier,
2320
Bioinformatics, Volume 34, Issue 4, February 2018, Pages 660–668
24-
(https://doi.org/10.1093/bioinformatics/btx624),
25-
26-
Attributes:
27-
_CORRESPONDING_GO_CLASSES (dict): Mapping of GO branches to specific data extractor classes.
28-
_MAXLEN (int): Maximum sequence length for sequences.
29-
_LABELS_START_IDX (int): Starting index for labels in the dataset.
30-
31-
Methods:
32-
__init__(data_dir, go_branch): Initializes the data directory and GO branch.
33-
_load_data(): Loads train, validation, test, and terms data from the specified directory.
34-
_record_splits(): Creates a DataFrame with IDs and their corresponding split.
35-
migrate(): Executes the migration process including data loading, processing, and saving.
36-
_extract_required_data_from_splits(): Extracts required columns from the splits data.
37-
_get_labels_columns(data_df): Generates label columns for the data based on GO terms.
38-
extract_go_id(go_list): Extracts GO IDs from a list.
39-
save_migrated_data(data_df, splits_df): Saves the processed data and splits.
21+
(https://doi.org/10.1093/bioinformatics/btx624).
4022
"""
4123

42-
# Number of annotations for each go_branch as per the research paper
43-
_CORRESPONDING_GO_CLASSES = {
44-
"cc": GOUniProtOver50,
45-
"mf": GOUniProtOver50,
46-
"bp": GOUniProtOver250,
47-
}
48-
24+
# Max sequence length as per DeepGO1
4925
_MAXLEN = 1002
50-
_LABELS_START_IDX = _GOUniProtDataExtractor._LABELS_START_IDX
26+
_LABELS_START_IDX = DeepGO1MigratedData._LABELS_START_IDX
5127

5228
def __init__(self, data_dir: str, go_branch: Literal["cc", "mf", "bp"]):
5329
"""
5430
Initializes the data migration object with a data directory and GO branch.
5531
5632
Args:
5733
data_dir (str): Directory containing the data files.
58-
go_branch (Literal["cc", "mf", "bp"]): GO branch to use (cellular_component, molecular_function, or biological_process).
34+
go_branch (Literal["cc", "mf", "bp"]): GO branch to use.
5935
"""
60-
valid_go_branches = list(self._CORRESPONDING_GO_CLASSES.keys())
36+
valid_go_branches = list(DeepGO1MigratedData.GO_BRANCH_MAPPING.keys())
6137
if go_branch not in valid_go_branches:
6238
raise ValueError(f"go_branch must be one of {valid_go_branches}")
6339
self._go_branch = go_branch
@@ -69,49 +45,104 @@ def __init__(self, data_dir: str, go_branch: Literal["cc", "mf", "bp"]):
6945
self._terms_df: Optional[pd.DataFrame] = None
7046
self._classes: Optional[List[str]] = None
7147

48+
def migrate(self) -> None:
49+
"""
50+
Executes the data migration by loading, processing, and saving the data.
51+
"""
52+
print("Starting the migration process...")
53+
self._load_data()
54+
if not all(
55+
df is not None
56+
for df in [
57+
self._train_df,
58+
self._validation_df,
59+
self._test_df,
60+
self._terms_df,
61+
]
62+
):
63+
raise Exception(
64+
"Data splits or terms data is not available in instance variables."
65+
)
66+
splits_df = self._record_splits()
67+
data_with_labels_df = self._extract_required_data_from_splits()
68+
69+
if not all(
70+
var is not None for var in [data_with_labels_df, splits_df, self._classes]
71+
):
72+
raise Exception(
73+
"Data splits or terms data is not available in instance variables."
74+
)
75+
76+
self.save_migrated_data(data_with_labels_df, splits_df)
77+
7278
def _load_data(self) -> None:
7379
"""
7480
Loads the test, train, validation, and terms data from the pickled files
7581
in the data directory.
7682
"""
7783
try:
78-
print(f"Loading data from {self._data_dir}......")
84+
print(f"Loading data files from directory: {self._data_dir}")
7985
self._test_df = pd.DataFrame(
8086
pd.read_pickle(
8187
os.path.join(self._data_dir, f"test-{self._go_branch}.pkl")
8288
)
8389
)
84-
self._train_df = pd.DataFrame(
90+
91+
# DeepGO 1 lacks a validation split, so we will create one by further splitting the training set.
92+
# Although this reduces the training data slightly compared to the original DeepGO setup,
93+
# given the data size, the impact should be minimal.
94+
train_df = pd.DataFrame(
8595
pd.read_pickle(
8696
os.path.join(self._data_dir, f"train-{self._go_branch}.pkl")
8797
)
8898
)
89-
# self._validation_df = pd.DataFrame(
90-
# pd.read_pickle(os.path.join(self._data_dir, f"valid-{self._go_branch}.pkl"))
91-
# )
92-
93-
# DeepGO1 data does not include a separate validation split, but our data structure requires one.
94-
# To accommodate this, we will create a placeholder validation split by duplicating a small subset of the
95-
# training data. However, to ensure a fair comparison with DeepGO1, we will retain the full training set
96-
# without creating an exclusive validation split from it.
97-
# Therefore, any metrics calculated on this placeholder validation set should be disregarded, as they do not
98-
# reflect true validation performance.
99-
self._validation_df = self._train_df[len(self._train_df) - 5 :]
99+
100+
self._train_df, self._validation_df = self._get_train_val_split(train_df)
101+
100102
self._terms_df = pd.DataFrame(
101103
pd.read_pickle(os.path.join(self._data_dir, f"{self._go_branch}.pkl"))
102104
)
103105

104106
except FileNotFoundError as e:
105107
print(f"Error loading data: {e}")
106108

109+
@staticmethod
110+
def _get_train_val_split(
111+
train_df: pd.DataFrame,
112+
) -> Tuple[pd.DataFrame, pd.DataFrame]:
113+
"""
114+
Splits the training data into a smaller training set and a validation set.
115+
116+
Args:
117+
train_df (pd.DataFrame): Original training DataFrame.
118+
119+
Returns:
120+
Tuple[pd.DataFrame, pd.DataFrame]: Training and validation DataFrames.
121+
"""
122+
labels_list_train = train_df["labels"].tolist()
123+
train_split = 0.85
124+
test_size = ((1 - train_split) ** 2) / train_split
125+
126+
splitter = MultilabelStratifiedShuffleSplit(
127+
n_splits=1, test_size=test_size, random_state=42
128+
)
129+
130+
train_indices, validation_indices = next(
131+
splitter.split(labels_list_train, labels_list_train)
132+
)
133+
134+
df_validation = train_df.iloc[validation_indices]
135+
df_train = train_df.iloc[train_indices]
136+
return df_train, df_validation
137+
107138
def _record_splits(self) -> pd.DataFrame:
108139
"""
109140
Creates a DataFrame that stores the IDs and their corresponding data splits.
110141
111142
Returns:
112143
pd.DataFrame: A combined DataFrame containing split assignments.
113144
"""
114-
print("Recording splits...")
145+
print("Recording data splits for train, validation, and test sets.")
115146
split_assignment_list: List[pd.DataFrame] = [
116147
pd.DataFrame({"id": self._train_df["proteins"], "split": "train"}),
117148
pd.DataFrame(
@@ -123,50 +154,18 @@ def _record_splits(self) -> pd.DataFrame:
123154
combined_split_assignment = pd.concat(split_assignment_list, ignore_index=True)
124155
return combined_split_assignment
125156

126-
def migrate(self) -> None:
127-
"""
128-
Executes the data migration by loading, processing, and saving the data.
129-
"""
130-
print("Migration started......")
131-
self._load_data()
132-
if not all(
133-
df is not None
134-
for df in [
135-
self._train_df,
136-
self._validation_df,
137-
self._test_df,
138-
self._terms_df,
139-
]
140-
):
141-
raise Exception(
142-
"Data splits or terms data is not available in instance variables."
143-
)
144-
splits_df = self._record_splits()
145-
146-
data_with_labels_df = self._extract_required_data_from_splits()
147-
148-
if not all(
149-
var is not None for var in [data_with_labels_df, splits_df, self._classes]
150-
):
151-
raise Exception(
152-
"Data splits or terms data is not available in instance variables."
153-
)
154-
155-
self.save_migrated_data(data_with_labels_df, splits_df)
156-
157157
def _extract_required_data_from_splits(self) -> pd.DataFrame:
158158
"""
159159
Extracts required columns from the combined data splits.
160160
161161
Returns:
162162
pd.DataFrame: A DataFrame containing the essential columns for processing.
163163
"""
164-
print("Combining the data splits with required data..... ")
164+
print("Combining data splits into a single DataFrame with required columns.")
165165
required_columns = [
166166
"proteins",
167167
"accessions",
168168
"sequences",
169-
# Note: The GO classes here only directly related one, and not transitive GO classes
170169
"gos",
171170
"labels",
172171
]
@@ -183,7 +182,7 @@ def _extract_required_data_from_splits(self) -> pd.DataFrame:
183182
lambda row: self.extract_go_id(row["gos"]), axis=1
184183
)
185184

186-
labels_df = self._get_labels_colums(new_df)
185+
labels_df = self._get_labels_columns(new_df)
187186

188187
data_df = pd.DataFrame(
189188
OrderedDict(
@@ -198,28 +197,32 @@ def _extract_required_data_from_splits(self) -> pd.DataFrame:
198197

199198
return df
200199

201-
def _get_labels_colums(self, data_df: pd.DataFrame) -> pd.DataFrame:
200+
@staticmethod
201+
def extract_go_id(go_list: List[str]) -> List[int]:
202202
"""
203-
Generates a DataFrame with one-hot encoded columns for each GO term label,
204-
based on the terms provided in `self._terms_df` and the existing labels in `data_df`.
203+
Extracts and parses GO IDs from a list of GO annotations.
205204
206-
This method extracts GO IDs from the `functions` column of `self._terms_df`,
207-
creating a list of all unique GO IDs. It then uses this list to create new
208-
columns in the returned DataFrame, where each row has binary values
209-
(0 or 1) indicating the presence of each GO ID in the corresponding entry of
210-
`data_df['labels']`.
205+
Args:
206+
go_list (List[str]): List of GO annotation strings.
207+
208+
Returns:
209+
List[int]: List of parsed GO IDs.
210+
"""
211+
return [DeepGO1MigratedData._parse_go_id(go_id_str) for go_id_str in go_list]
212+
213+
def _get_labels_columns(self, data_df: pd.DataFrame) -> pd.DataFrame:
214+
"""
215+
Generates columns for labels based on provided selected terms.
211216
212217
Args:
213-
data_df (pd.DataFrame): DataFrame containing data with a 'labels' column,
214-
which holds lists of GO ID labels for each row.
218+
data_df (pd.DataFrame): DataFrame with GO annotations and labels.
215219
216220
Returns:
217-
pd.DataFrame: A DataFrame with the same index as `data_df` and one column
218-
per GO ID, containing binary values indicating label presence.
221+
pd.DataFrame: DataFrame with label columns.
219222
"""
220-
print("Generating labels based on terms.pkl file.......")
223+
print("Generating label columns from provided selected terms.")
221224
parsed_go_ids: pd.Series = self._terms_df["functions"].apply(
222-
lambda gos: _GOUniProtDataExtractor._parse_go_id(gos)
225+
lambda gos: DeepGO1MigratedData._parse_go_id(gos)
223226
)
224227
all_go_ids_list = parsed_go_ids.values.tolist()
225228
self._classes = all_go_ids_list
@@ -230,21 +233,6 @@ def _get_labels_colums(self, data_df: pd.DataFrame) -> pd.DataFrame:
230233

231234
return new_label_columns
232235

233-
@staticmethod
234-
def extract_go_id(go_list: List[str]) -> List[int]:
235-
"""
236-
Extracts and parses GO IDs from a list of GO annotations.
237-
238-
Args:
239-
go_list (List[str]): List of GO annotation strings.
240-
241-
Returns:
242-
List[str]: List of parsed GO IDs.
243-
"""
244-
return [
245-
_GOUniProtDataExtractor._parse_go_id(go_id_str) for go_id_str in go_list
246-
]
247-
248236
def save_migrated_data(
249237
self, data_df: pd.DataFrame, splits_df: pd.DataFrame
250238
) -> None:
@@ -255,31 +243,38 @@ def save_migrated_data(
255243
data_df (pd.DataFrame): Data with GO labels.
256244
splits_df (pd.DataFrame): Split assignment DataFrame.
257245
"""
258-
print("Saving transformed data......")
259-
go_class_instance: _GOUniProtDataExtractor = self._CORRESPONDING_GO_CLASSES[
260-
self._go_branch
261-
](go_branch=self._go_branch.upper(), max_sequence_length=self._MAXLEN)
246+
print("Saving transformed data files.")
262247

263-
go_class_instance.save_processed(
264-
data_df, go_class_instance.processed_main_file_names_dict["data"]
248+
deepgo_migr_inst: DeepGO1MigratedData = DeepGO1MigratedData(
249+
go_branch=DeepGO1MigratedData.GO_BRANCH_MAPPING[self._go_branch],
250+
max_sequence_length=self._MAXLEN,
251+
)
252+
253+
# Save data file
254+
deepgo_migr_inst.save_processed(
255+
data_df, deepgo_migr_inst.processed_main_file_names_dict["data"]
265256
)
266257
print(
267-
f"{go_class_instance.processed_main_file_names_dict['data']} saved to {go_class_instance.processed_dir_main}"
258+
f"{deepgo_migr_inst.processed_main_file_names_dict['data']} saved to {deepgo_migr_inst.processed_dir_main}"
268259
)
269260

261+
# Save splits file
270262
splits_df.to_csv(
271-
os.path.join(go_class_instance.processed_dir_main, "splits.csv"),
263+
os.path.join(deepgo_migr_inst.processed_dir_main, "splits_deep_go1.csv"),
272264
index=False,
273265
)
274-
print(f"splits.csv saved to {go_class_instance.processed_dir_main}")
266+
print(f"splits_deep_go1.csv saved to {deepgo_migr_inst.processed_dir_main}")
275267

268+
# Save classes file
276269
classes = sorted(self._classes)
277270
with open(
278-
os.path.join(go_class_instance.processed_dir_main, "classes.txt"), "wt"
271+
os.path.join(deepgo_migr_inst.processed_dir_main, "classes_deep_go1.txt"),
272+
"wt",
279273
) as fout:
280274
fout.writelines(str(node) + "\n" for node in classes)
281-
print(f"classes.txt saved to {go_class_instance.processed_dir_main}")
282-
print("Migration completed!")
275+
print(f"classes_deep_go1.txt saved to {deepgo_migr_inst.processed_dir_main}")
276+
277+
print("Migration process completed!")
283278

284279

285280
class Main:

0 commit comments

Comments (0)