Skip to content

Commit dfb9430

Browse files
committed
migration: rectify errors
1 parent af54954 commit dfb9430

File tree

1 file changed

+28
-37
lines changed

1 file changed

+28
-37
lines changed

chebai/preprocessing/migration/deep_go_data_mirgration.py

Lines changed: 28 additions & 37 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,6 @@
11
import os
22
from collections import OrderedDict
3-
from random import randint
4-
from typing import List, Literal
3+
from typing import List, Literal, Optional
54

65
import pandas as pd
76
from jsonargparse import CLI
@@ -59,12 +58,12 @@ def __init__(self, data_dir: str, go_branch: Literal["cc", "mf", "bp"]):
5958
raise ValueError(f"go_branch must be one of {valid_go_branches}")
6059
self._go_branch = go_branch
6160

62-
self._data_dir = os.path.join(data_dir, go_branch)
63-
self._train_df: pd.DataFrame = None
64-
self._test_df: pd.DataFrame = None
65-
self._validation_df: pd.DataFrame = None
66-
self._terms_df: pd.DataFrame = None
67-
self._classes: List[str] = None
61+
self._data_dir: str = os.path.join(rf"{data_dir}", go_branch)
62+
self._train_df: Optional[pd.DataFrame] = None
63+
self._test_df: Optional[pd.DataFrame] = None
64+
self._validation_df: Optional[pd.DataFrame] = None
65+
self._terms_df: Optional[pd.DataFrame] = None
66+
self._classes: Optional[List[str]] = None
6867

6968
def _load_data(self) -> None:
7069
"""
@@ -114,7 +113,13 @@ def migrate(self) -> None:
114113
print("Migration started......")
115114
self._load_data()
116115
if not all(
117-
[self._train_df, self._validation_df, self._test_df, self._terms_df]
116+
df is not None
117+
for df in [
118+
self._train_df,
119+
self._validation_df,
120+
self._test_df,
121+
self._terms_df,
122+
]
118123
):
119124
raise Exception(
120125
"Data splits or terms data is not available in instance variables."
@@ -124,7 +129,9 @@ def migrate(self) -> None:
124129
data_df = self._extract_required_data_from_splits()
125130
data_with_labels_df = self._generate_labels(data_df)
126131

127-
if not all([data_with_labels_df, splits_df, self._classes]):
132+
if not all(
133+
var is not None for var in [data_with_labels_df, splits_df, self._classes]
134+
):
128135
raise Exception(
129136
"Data splits or terms data is not available in instance variables."
130137
)
@@ -184,8 +191,8 @@ def _generate_labels(self, data_df: pd.DataFrame) -> pd.DataFrame:
184191
pd.DataFrame: DataFrame with new label columns.
185192
"""
186193
print("Generating labels based on terms.pkl file.......")
187-
parsed_go_ids: pd.Series = self._terms_df.apply(
188-
lambda row: self.extract_go_id(row["gos"])
194+
parsed_go_ids: pd.Series = self._terms_df["gos"].apply(
195+
lambda gos: _GOUniProtDataExtractor._parse_go_id(gos)
189196
)
190197
all_go_ids_list = parsed_go_ids.values.tolist()
191198
self._classes = all_go_ids_list
@@ -203,7 +210,7 @@ def _generate_labels(self, data_df: pd.DataFrame) -> pd.DataFrame:
203210
return data_df
204211

205212
@staticmethod
206-
def extract_go_id(go_list: List[str]) -> List[str]:
213+
def extract_go_id(go_list: List[str]) -> List[int]:
207214
"""
208215
Extracts and parses GO IDs from a list of GO annotations.
209216
@@ -230,13 +237,13 @@ def save_migrated_data(
230237
print("Saving transformed data......")
231238
go_class_instance: _GOUniProtDataExtractor = self._CORRESPONDING_GO_CLASSES[
232239
self._go_branch
233-
](go_branch=self._go_branch, max_sequence_length=self._MAXLEN)
240+
](go_branch=self._go_branch.upper(), max_sequence_length=self._MAXLEN)
234241

235242
go_class_instance.save_processed(
236-
data_df, go_class_instance.processed_file_names_dict["data"]
243+
data_df, go_class_instance.processed_main_file_names_dict["data"]
237244
)
238245
print(
239-
f"{go_class_instance.processed_file_names_dict['data']} saved to {go_class_instance.processed_dir_main}"
246+
f"{go_class_instance.processed_main_file_names_dict['data']} saved to {go_class_instance.processed_dir_main}"
240247
)
241248

242249
splits_df.to_csv(
@@ -263,7 +270,8 @@ class Main:
263270
Initiates the migration process for the specified data directory and GO branch.
264271
"""
265272

266-
def migrate(self, data_dir: str, go_branch: str) -> None:
273+
@staticmethod
274+
def migrate(data_dir: str, go_branch: Literal["cc", "mf", "bp"]) -> None:
267275
"""
268276
Initiates the migration process by creating a DeepGoDataMigration instance
269277
and invoking its migrate method.
@@ -278,29 +286,12 @@ def migrate(self, data_dir: str, go_branch: str) -> None:
278286
DeepGoDataMigration(data_dir, go_branch).migrate()
279287

280288

281-
class Main1:
282-
def __init__(self, max_prize: int = 100):
283-
"""
284-
Args:
285-
max_prize: Maximum prize that can be awarded.
286-
"""
287-
self.max_prize = max_prize
288-
289-
def person(self, name: str, additional_prize: int = 0):
290-
"""
291-
Args:
292-
name: Name of the winner.
293-
additional_prize: Additional prize that can be added to the prize amount.
294-
"""
295-
prize = randint(0, self.max_prize) + additional_prize
296-
return f"{name} won {prize}€!"
297-
298-
299289
if __name__ == "__main__":
300-
# Example: python script_name.py migrate data_dir="data/deep_go_se_training_data" go_branch="bp"
290+
# Example: python script_name.py migrate --data_dir="data/deep_go_se_training_data" --go_branch="bp"
301291
# --data_dir specifies the directory containing the data files.
302292
# --go_branch specifies the GO branch (cc, mf, or bp) you want to use for the migration.
303293
CLI(
304-
Main1,
294+
Main,
305295
description="DeepGoDataMigration CLI tool to handle migration of GO data for specified branches (cc, mf, bp).",
296+
as_positional=False,
306297
)

0 commit comments

Comments
 (0)