Skip to content

Commit bc19a21

Browse files
author
sfluegel
committed
add jsonargparse cli to migration, gentle file-not-found handling
1 parent 07340cb commit bc19a21

File tree

2 files changed

+106
-90
lines changed

2 files changed

+106
-90
lines changed

chebai/preprocessing/datasets/chebi.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -514,9 +514,7 @@ def processed_dir_main(self):
514514
@property
515515
def processed_dir(self):
516516
res = os.path.join(
517-
self.base_dir,
518-
self._name,
519-
"processed",
517+
self.processed_dir_main,
520518
*self.identifier,
521519
)
522520
if self.single_class is None:
@@ -940,15 +938,20 @@ class ChEBIOver100SELFIES(ChEBIOverXSELFIES, ChEBIOver100):
940938

941939

942940
class ChEBIOverXPartial(ChEBIOverX):
943-
"""Dataset that doesn't use the full ChEBI, but extracts are part of ChEBI"""
941+
"""Dataset that doesn't use the full ChEBI, but extracts a part of ChEBI (subclasses of a given top class)"""
944942

945943
def __init__(self, top_class_id: int, **kwargs):
946944
self.top_class_id = top_class_id
947945
super().__init__(**kwargs)
948946

949947
@property
950-
def base_dir(self):
951-
return os.path.join(super().base_dir, f"partial_{self.top_class_id}")
948+
def processed_dir_main(self):
949+
return os.path.join(
950+
self.base_dir,
951+
self._name,
952+
f"partial_{self.top_class_id}",
953+
"processed",
954+
)
952955

953956
def extract_class_hierarchy(self, chebi_path):
954957
with open(chebi_path, encoding="utf-8") as chebi:

chebai/preprocessing/migration/chebi_data_migration.py

Lines changed: 97 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
import argparse
22
import os
33
import shutil
4-
from typing import Dict, List, Tuple, Type
4+
from typing import Dict, List, Optional, Tuple, Type
55

66
import pandas as pd
77
import torch
8+
from jsonargparse import CLI
89

9-
from chebai.preprocessing.datasets.chebi import _ChEBIDataExtractor
10+
from chebai.preprocessing.datasets.chebi import ChEBIOverXPartial, _ChEBIDataExtractor
1011

1112

1213
class ChebiDataMigration:
@@ -17,34 +18,25 @@ class ChebiDataMigration:
1718
__MODULE_PATH (str): The path to the module containing ChEBI classes.
1819
__DATA_ROOT_DIR (str): The root directory for data.
1920
_chebi_cls (_ChEBIDataExtractor): The ChEBI class instance.
20-
_chebi_version (int): The version of the ChEBI dataset.
21-
_single_class (int, optional): The ID of a single class to predict.
22-
_class_name (str): The name of the ChEBI class.
2321
"""
2422

2523
__MODULE_PATH: str = "chebai.preprocessing.datasets.chebi"
2624
__DATA_ROOT_DIR: str = "data"
2725

28-
def __init__(self, class_name: str, chebi_version: int, single_class: int = None):
29-
"""
30-
Initialize the ChebiDataMigration class.
26+
def __init__(self, datamodule: _ChEBIDataExtractor):
27+
self._chebi_cls = datamodule
3128

32-
Args:
33-
class_name (str): The name of the ChEBI class.
34-
chebi_version (int): The version of the ChEBI dataset.
35-
single_class (int, optional): The ID of the single class to predict.
36-
"""
37-
self._chebi_cls: Type[_ChEBIDataExtractor] = self._dynamic_import_chebi_cls(
29+
@classmethod
30+
def from_args(cls, class_name: str, chebi_version: int, single_class: int = None):
31+
chebi_cls: _ChEBIDataExtractor = ChebiDataMigration._dynamic_import_chebi_cls(
3832
class_name, chebi_version, single_class
3933
)
40-
self._chebi_version: int = chebi_version
41-
self._single_class: int = single_class
42-
self._class_name: str = class_name
34+
return cls(chebi_cls)
4335

4436
@classmethod
4537
def _dynamic_import_chebi_cls(
4638
cls, class_name: str, chebi_version: int, single_class: int
47-
) -> Type[_ChEBIDataExtractor]:
39+
) -> _ChEBIDataExtractor:
4840
"""
4941
Dynamically import the ChEBI class.
5042
@@ -67,47 +59,55 @@ def migrate(self) -> None:
6759
"""
6860
os.makedirs(self._chebi_cls.base_dir, exist_ok=True)
6961
print("Migration started.....")
70-
self._migrate_old_raw_data()
62+
old_raw_data_exists = self._migrate_old_raw_data()
7163

7264
# Either we can combine `.pt` split files to form `data.pt` file
7365
# self._migrate_old_processed_data()
7466
# OR
7567
# we can transform the `data.pkl` file into a `data.pt` file (this seems more efficient and needs less code)
76-
self._chebi_cls.setup_processed()
68+
if old_raw_data_exists:
69+
self._chebi_cls.setup_processed()
70+
else:
71+
self._migrate_old_processed_data()
7772
print("Migration completed.....")
7873

79-
def _migrate_old_raw_data(self) -> None:
74+
def _migrate_old_raw_data(self) -> bool:
8075
"""
8176
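Migrate old raw data files to the new data folder structure. Returns True if `data.pkl` is available in the new structure afterwards (already present or newly created), False if the old `.pkl` splits could not be combined.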
8277
"""
8378
print("-" * 50)
84-
print("Migrating old raw Data....")
79+
print("Migrating old raw data....")
8580

8681
self._copy_file(self._old_raw_dir, self._chebi_cls.raw_dir, "chebi.obo")
8782
self._copy_file(
8883
self._old_raw_dir, self._chebi_cls.processed_dir_main, "classes.txt"
8984
)
9085

91-
old_splits_file_names = {
86+
old_splits_file_names_raw = {
9287
"train": "train.pkl",
9388
"validation": "validation.pkl",
9489
"test": "test.pkl",
9590
}
91+
9692
data_file_path = os.path.join(self._chebi_cls.processed_dir_main, "data.pkl")
9793
if os.path.isfile(data_file_path):
9894
print(f"File {data_file_path} already exists in new data-folder structure")
99-
return
95+
return True
10096

101-
data_df, split_ass_df = self._combine_pkl_splits(
102-
self._old_raw_dir, old_splits_file_names
97+
data_df_split_ass_df = self._combine_pkl_splits(
98+
self._old_raw_dir, old_splits_file_names_raw
10399
)
104-
105-
self._chebi_cls.save_processed(data_df, "data.pkl")
106-
print(f"File {data_file_path} saved to new data-folder structure")
107-
108-
split_file = os.path.join(self._chebi_cls.processed_dir_main, "splits.csv")
109-
split_ass_df.to_csv(split_file) # overwrites the files with same name
110-
print(f"File {split_file} saved to new data-folder structure")
100+
if data_df_split_ass_df is not None:
101+
data_df = data_df_split_ass_df[0]
102+
split_ass_df = data_df_split_ass_df[1]
103+
self._chebi_cls.save_processed(data_df, "data.pkl")
104+
print(f"File {data_file_path} saved to new data-folder structure")
105+
106+
split_file = os.path.join(self._chebi_cls.processed_dir_main, "splits.csv")
107+
split_ass_df.to_csv(split_file) # overwrites the files with same name
108+
print(f"File {split_file} saved to new data-folder structure")
109+
return True
110+
return False
111111

112112
def _migrate_old_processed_data(self) -> None:
113113
"""
@@ -130,13 +130,13 @@ def _migrate_old_processed_data(self) -> None:
130130
data_df = self._combine_pt_splits(
131131
self._old_processed_dir, old_splits_file_names
132132
)
133-
134-
torch.save(data_df, data_file_path)
135-
print(f"File {data_file_path} saved to new data-folder structure")
133+
if data_df is not None:
134+
torch.save(data_df, data_file_path)
135+
print(f"File {data_file_path} saved to new data-folder structure")
136136

137137
def _combine_pt_splits(
138138
self, old_dir: str, old_splits_file_names: Dict[str, str]
139-
) -> pd.DataFrame:
139+
) -> Optional[pd.DataFrame]:
140140
"""
141141
Combine old `.pt` split files into a single DataFrame.
142142
@@ -147,7 +147,11 @@ def _combine_pt_splits(
147147
Returns:
148148
Optional[pd.DataFrame]: The combined DataFrame, or None if any of the old split files is missing.
149149
"""
150-
self._check_if_old_splits_exists(old_dir, old_splits_file_names)
150+
if not self._check_if_old_splits_exists(old_dir, old_splits_file_names):
151+
print(
152+
f"Missing at least one of [{', '.join(old_splits_file_names.values())}] in {old_dir}"
153+
)
154+
return None
151155

152156
print("Combining `.pt` splits...")
153157
df_list: List[pd.DataFrame] = []
@@ -160,7 +164,7 @@ def _combine_pt_splits(
160164

161165
def _combine_pkl_splits(
162166
self, old_dir: str, old_splits_file_names: Dict[str, str]
163-
) -> Tuple[pd.DataFrame, pd.DataFrame]:
167+
) -> Optional[Tuple[pd.DataFrame, pd.DataFrame]]:
164168
"""
165169
Combine old `.pkl` split files into a single DataFrame and create split assignments.
166170
@@ -171,7 +175,11 @@ def _combine_pkl_splits(
171175
Returns:
172176
Optional[Tuple[pd.DataFrame, pd.DataFrame]]: The combined DataFrame and split assignments DataFrame, or None if any of the old split files is missing.
173177
"""
174-
self._check_if_old_splits_exists(old_dir, old_splits_file_names)
178+
if not self._check_if_old_splits_exists(old_dir, old_splits_file_names):
179+
print(
180+
f"Missing at least one of [{', '.join(old_splits_file_names.values())}] in {old_dir}"
181+
)
182+
return None
175183

176184
df_list: List[pd.DataFrame] = []
177185
split_assignment_list: List[pd.DataFrame] = []
@@ -195,25 +203,19 @@ def _combine_pkl_splits(
195203
@staticmethod
196204
def _check_if_old_splits_exists(
197205
old_dir: str, old_splits_file_names: Dict[str, str]
198-
) -> None:
206+
) -> bool:
199207
"""
200208
Check whether all old split files exist in the specified directory; returns True only if every file is present.
201209
202210
Args:
203211
old_dir (str): The directory containing the old split files.
204212
old_splits_file_names (Dict[str, str]): A dictionary of split names and file names.
205213
206-
Raises:
207-
FileNotFoundError: If any of the split files do not exist.
208214
"""
209-
if any(
210-
not os.path.isfile(os.path.join(old_dir, file))
215+
return all(
216+
os.path.isfile(os.path.join(old_dir, file))
211217
for file in old_splits_file_names.values()
212-
):
213-
raise FileNotFoundError(
214-
f"One of the split {old_splits_file_names.values()} doesn't exist "
215-
f"in old data-folder structure: {old_dir}"
216-
)
218+
)
217219

218220
@staticmethod
219221
def _copy_file(old_file_dir: str, new_file_dir: str, file_name: str) -> None:
@@ -230,18 +232,19 @@ def _copy_file(old_file_dir: str, new_file_dir: str, file_name: str) -> None:
230232
"""
231233
os.makedirs(new_file_dir, exist_ok=True)
232234
new_file_path = os.path.join(new_file_dir, file_name)
233-
if os.path.isfile(new_file_path):
234-
print(f"File {new_file_path} already exists in new data-folder structure")
235-
return
236-
237235
old_file_path = os.path.join(old_file_dir, file_name)
238-
if not os.path.isfile(old_file_path):
239-
raise FileNotFoundError(
240-
f"File {old_file_path} doesn't exist in old data-folder structure"
236+
237+
if os.path.isfile(new_file_path):
238+
print(
239+
f"Skipping {old_file_path} (file already exists at new location {new_file_path})"
241240
)
241+
return
242242

243-
shutil.copy2(os.path.abspath(old_file_path), os.path.abspath(new_file_path))
244-
print(f"Copied from {old_file_path} to {new_file_path}")
243+
if os.path.isfile(old_file_path):
244+
shutil.copy2(os.path.abspath(old_file_path), os.path.abspath(new_file_path))
245+
print(f"Copied {old_file_path} to {new_file_path}")
246+
else:
247+
print(f"Skipping expected file {old_file_path} (not found)")
245248

246249
@property
247250
def _old_base_dir(self) -> str:
@@ -251,6 +254,13 @@ def _old_base_dir(self) -> str:
251254
Returns:
252255
str: The base directory for the old data.
253256
"""
257+
if isinstance(self._chebi_cls, ChEBIOverXPartial):
258+
return os.path.join(
259+
self.__DATA_ROOT_DIR,
260+
self._chebi_cls._name,
261+
f"chebi_v{self._chebi_cls.chebi_version}",
262+
f"partial_{self._chebi_cls.top_class_id}",
263+
)
254264
return os.path.join(
255265
self.__DATA_ROOT_DIR,
256266
self._chebi_cls._name,
@@ -286,29 +296,32 @@ def _old_raw_dir(self) -> str:
286296
return os.path.join(self._old_base_dir, "raw")
287297

288298

299+
class Main:
300+
301+
def migrate(
302+
self,
303+
datamodule: Optional[_ChEBIDataExtractor] = None,
304+
class_name: Optional[str] = None,
305+
chebi_version: Optional[int] = None,
306+
single_class: Optional[int] = None,
307+
):
308+
"""
309+
Migrate ChEBI dataset to new structure and handle splits.
310+
311+
Args:
312+
datamodule (Optional[_ChEBIDataExtractor]): The datamodule instance. If not provided, class_name and
313+
chebi_version are required.
314+
class_name (Optional[str]): The name of the ChEBI class.
315+
chebi_version (Optional[int]): The version of the ChEBI dataset.
316+
single_class (Optional[int]): The ID of the single class to predict.
317+
"""
318+
if datamodule is not None:
319+
ChebiDataMigration(datamodule).migrate()
320+
else:
321+
ChebiDataMigration.from_args(
322+
class_name, chebi_version, single_class
323+
).migrate()
324+
325+
289326
if __name__ == "__main__":
290-
parser = argparse.ArgumentParser(
291-
description="Migrate ChEBI dataset to new structure and handle splits."
292-
)
293-
parser.add_argument(
294-
"--chebi_class",
295-
type=str,
296-
required=True,
297-
help="Chebi class name from the `chebai/preprocessing/datasets/chebi.py`",
298-
)
299-
parser.add_argument(
300-
"--chebi_version", type=int, required=True, help="Chebi data version"
301-
)
302-
parser.add_argument(
303-
"--single_class",
304-
type=int,
305-
help="The ID of the single class to predict",
306-
default=None,
307-
)
308-
args = parser.parse_args()
309-
310-
ChebiDataMigration(
311-
class_name=args.chebi_class,
312-
chebi_version=args.chebi_version,
313-
single_class=args.single_class,
314-
).migrate()
327+
CLI(Main)

0 commit comments

Comments
 (0)