Skip to content

Commit 1b8b270

Browse files
committed
migration fix : truncate seq and save data with labels
1 parent f75e30b commit 1b8b270

File tree

1 file changed

+49
-13
lines changed

1 file changed

+49
-13
lines changed

chebai/preprocessing/migration/deep_go/migrate_deep_go_2_data.py

Lines changed: 49 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,24 +20,28 @@ class DeepGo2DataMigration:
2020
(https://doi.org/10.1093/bioinformatics/btx624)
2121
"""
2222

23-
# https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/aminoacids.py#L11
24-
_MAXLEN = 1000
2523
_LABELS_START_IDX = DeepGO2MigratedData._LABELS_START_IDX
2624

27-
def __init__(self, data_dir: str, go_branch: Literal["cc", "mf", "bp"]):
25+
def __init__(
26+
self, data_dir: str, go_branch: Literal["cc", "mf", "bp"], max_len: int = 1000
27+
):
2828
"""
2929
Initializes the data migration object with a data directory and GO branch.
3030
3131
Args:
3232
data_dir (str): Directory containing the data files.
3333
go_branch (Literal["cc", "mf", "bp"]): GO branch to use.
34+
max_len (int): Maximum sequence length; longer sequences are truncated to this length. Defaults to 1000.
35+
# https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/aminoacids.py#L11
3436
"""
3537
valid_go_branches = list(DeepGO2MigratedData.GO_BRANCH_MAPPING.keys())
3638
if go_branch not in valid_go_branches:
3739
raise ValueError(f"go_branch must be one of {valid_go_branches}")
3840
self._go_branch = go_branch
3941

4042
self._data_dir: str = os.path.join(rf"{data_dir}", go_branch)
43+
self._max_len: int = max_len
44+
4145
self._train_df: Optional[pd.DataFrame] = None
4246
self._test_df: Optional[pd.DataFrame] = None
4347
self._validation_df: Optional[pd.DataFrame] = None
@@ -74,33 +78,61 @@ def migrate(self) -> None:
7478
"Data splits or terms data is not available in instance variables."
7579
)
7680

77-
self.save_migrated_data(data_df, splits_df)
81+
self.save_migrated_data(data_with_labels_df, splits_df)
7882

7983
def _load_data(self) -> None:
8084
"""
8185
Loads the test, train, validation, and terms data from the pickled files
8286
in the data directory.
8387
"""
88+
8489
try:
8590
print(f"Loading data from directory: {self._data_dir}......")
86-
self._test_df = pd.DataFrame(
87-
pd.read_pickle(os.path.join(self._data_dir, "test_data.pkl"))
91+
self._test_df = self._truncate_sequences(
92+
pd.DataFrame(
93+
pd.read_pickle(os.path.join(self._data_dir, "test_data.pkl"))
94+
)
8895
)
89-
self._train_df = pd.DataFrame(
90-
pd.read_pickle(os.path.join(self._data_dir, "train_data.pkl"))
96+
self._train_df = self._truncate_sequences(
97+
pd.DataFrame(
98+
pd.read_pickle(os.path.join(self._data_dir, "train_data.pkl"))
99+
)
91100
)
92-
self._validation_df = pd.DataFrame(
93-
pd.read_pickle(os.path.join(self._data_dir, "valid_data.pkl"))
101+
self._validation_df = self._truncate_sequences(
102+
pd.DataFrame(
103+
pd.read_pickle(os.path.join(self._data_dir, "valid_data.pkl"))
104+
)
94105
)
106+
95107
self._terms_df = pd.DataFrame(
96108
pd.read_pickle(os.path.join(self._data_dir, "terms.pkl"))
97109
)
110+
98111
except FileNotFoundError as e:
99112
raise FileNotFoundError(
100113
f"Data file not found in directory: {e}. "
101114
"Please ensure all required files are available in the specified directory."
102115
)
103116

117+
def _truncate_sequences(
118+
self, df: pd.DataFrame, column: str = "sequences"
119+
) -> pd.DataFrame:
120+
"""
121+
Truncate sequences in a specified column of a dataframe to the maximum length.
122+
123+
https://github.com/bio-ontology-research-group/deepgo2/blob/main/train_cnn.py#L206-L217
124+
125+
Args:
126+
df (pd.DataFrame): The input dataframe containing the data to be processed.
127+
column (str, optional): The column containing sequences to truncate.
128+
Defaults to "sequences".
129+
130+
Returns:
131+
pd.DataFrame: The same dataframe, modified in place, with sequences truncated to `self._max_len`.
132+
"""
133+
df[column] = df[column].apply(lambda x: x[: self._max_len])
134+
return df
135+
104136
def _record_splits(self) -> pd.DataFrame:
105137
"""
106138
Creates a DataFrame that stores the IDs and their corresponding data splits.
@@ -217,7 +249,7 @@ def save_migrated_data(
217249
print("Saving transformed data......")
218250
deepgo_migr_inst: DeepGO2MigratedData = DeepGO2MigratedData(
219251
go_branch=DeepGO2MigratedData.GO_BRANCH_MAPPING[self._go_branch],
220-
max_sequence_length=self._MAXLEN,
252+
max_sequence_length=self._max_len,
221253
)
222254

223255
# Save data file
@@ -257,7 +289,9 @@ class Main:
257289
"""
258290

259291
@staticmethod
260-
def migrate(data_dir: str, go_branch: Literal["cc", "mf", "bp"]) -> None:
292+
def migrate(
293+
data_dir: str, go_branch: Literal["cc", "mf", "bp"], max_len: int = 1000
294+
) -> None:
261295
"""
262296
Initiates the migration process by creating a DeepGo2DataMigration instance
263297
and invoking its migrate method.
@@ -268,8 +302,10 @@ def migrate(data_dir: str, go_branch: Literal["cc", "mf", "bp"]) -> None:
268302
("cc" for cellular_component,
269303
"mf" for molecular_function,
270304
or "bp" for biological_process).
305+
max_len (int): Maximum sequence length; longer sequences are truncated to this length. Defaults to 1000.
306+
# https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/aminoacids.py#L11
271307
"""
272-
DeepGo2DataMigration(data_dir, go_branch).migrate()
308+
DeepGo2DataMigration(data_dir, go_branch, max_len).migrate()
273309

274310

275311
if __name__ == "__main__":

0 commit comments

Comments
 (0)