import argparse
import os
import shutil
from typing import Dict, List, Optional, Tuple, Type

import pandas as pd
import torch
from jsonargparse import CLI

from chebai.preprocessing.datasets.chebi import ChEBIOverXPartial, _ChEBIDataExtractor
1011
1112
1213class ChebiDataMigration :
@@ -17,34 +18,25 @@ class ChebiDataMigration:
1718 __MODULE_PATH (str): The path to the module containing ChEBI classes.
1819 __DATA_ROOT_DIR (str): The root directory for data.
1920 _chebi_cls (_ChEBIDataExtractor): The ChEBI class instance.
20- _chebi_version (int): The version of the ChEBI dataset.
21- _single_class (int, optional): The ID of a single class to predict.
22- _class_name (str): The name of the ChEBI class.
2321 """
2422
2523 __MODULE_PATH : str = "chebai.preprocessing.datasets.chebi"
2624 __DATA_ROOT_DIR : str = "data"
2725
28- def __init__ (self , class_name : str , chebi_version : int , single_class : int = None ):
29- """
30- Initialize the ChebiDataMigration class.
26+ def __init__ (self , datamodule : _ChEBIDataExtractor ):
27+ self ._chebi_cls = datamodule
3128
32- Args:
33- class_name (str): The name of the ChEBI class.
34- chebi_version (int): The version of the ChEBI dataset.
35- single_class (int, optional): The ID of the single class to predict.
36- """
37- self ._chebi_cls : Type [_ChEBIDataExtractor ] = self ._dynamic_import_chebi_cls (
29+ @classmethod
30+ def from_args (cls , class_name : str , chebi_version : int , single_class : int = None ):
31+ chebi_cls : _ChEBIDataExtractor = ChebiDataMigration ._dynamic_import_chebi_cls (
3832 class_name , chebi_version , single_class
3933 )
40- self ._chebi_version : int = chebi_version
41- self ._single_class : int = single_class
42- self ._class_name : str = class_name
34+ return cls (chebi_cls )
4335
4436 @classmethod
4537 def _dynamic_import_chebi_cls (
4638 cls , class_name : str , chebi_version : int , single_class : int
47- ) -> Type [ _ChEBIDataExtractor ] :
39+ ) -> _ChEBIDataExtractor :
4840 """
4941 Dynamically import the ChEBI class.
5042
@@ -67,47 +59,55 @@ def migrate(self) -> None:
6759 """
6860 os .makedirs (self ._chebi_cls .base_dir , exist_ok = True )
6961 print ("Migration started....." )
70- self ._migrate_old_raw_data ()
62+ old_raw_data_exists = self ._migrate_old_raw_data ()
7163
7264 # Either we can combine `.pt` split files to form `data.pt` file
7365 # self._migrate_old_processed_data()
7466 # OR
7567 # we can transform `data.pkl` to `data.pt` file (this seems efficient along with less code)
76- self ._chebi_cls .setup_processed ()
68+ if old_raw_data_exists :
69+ self ._chebi_cls .setup_processed ()
70+ else :
71+ self ._migrate_old_processed_data ()
7772 print ("Migration completed....." )
7873
79- def _migrate_old_raw_data (self ) -> None :
74+ def _migrate_old_raw_data (self ) -> bool :
8075 """
8176 Migrate old raw data files to the new data folder structure.
8277 """
8378 print ("-" * 50 )
84- print ("Migrating old raw Data ...." )
79+ print ("Migrating old raw data ...." )
8580
8681 self ._copy_file (self ._old_raw_dir , self ._chebi_cls .raw_dir , "chebi.obo" )
8782 self ._copy_file (
8883 self ._old_raw_dir , self ._chebi_cls .processed_dir_main , "classes.txt"
8984 )
9085
91- old_splits_file_names = {
86+ old_splits_file_names_raw = {
9287 "train" : "train.pkl" ,
9388 "validation" : "validation.pkl" ,
9489 "test" : "test.pkl" ,
9590 }
91+
9692 data_file_path = os .path .join (self ._chebi_cls .processed_dir_main , "data.pkl" )
9793 if os .path .isfile (data_file_path ):
9894 print (f"File { data_file_path } already exists in new data-folder structure" )
99- return
95+ return True
10096
101- data_df , split_ass_df = self ._combine_pkl_splits (
102- self ._old_raw_dir , old_splits_file_names
97+ data_df_split_ass_df = self ._combine_pkl_splits (
98+ self ._old_raw_dir , old_splits_file_names_raw
10399 )
104-
105- self ._chebi_cls .save_processed (data_df , "data.pkl" )
106- print (f"File { data_file_path } saved to new data-folder structure" )
107-
108- split_file = os .path .join (self ._chebi_cls .processed_dir_main , "splits.csv" )
109- split_ass_df .to_csv (split_file ) # overwrites the files with same name
110- print (f"File { split_file } saved to new data-folder structure" )
100+ if data_df_split_ass_df is not None :
101+ data_df = data_df_split_ass_df [0 ]
102+ split_ass_df = data_df_split_ass_df [1 ]
103+ self ._chebi_cls .save_processed (data_df , "data.pkl" )
104+ print (f"File { data_file_path } saved to new data-folder structure" )
105+
106+ split_file = os .path .join (self ._chebi_cls .processed_dir_main , "splits.csv" )
107+ split_ass_df .to_csv (split_file ) # overwrites the files with same name
108+ print (f"File { split_file } saved to new data-folder structure" )
109+ return True
110+ return False
111111
112112 def _migrate_old_processed_data (self ) -> None :
113113 """
@@ -130,13 +130,13 @@ def _migrate_old_processed_data(self) -> None:
130130 data_df = self ._combine_pt_splits (
131131 self ._old_processed_dir , old_splits_file_names
132132 )
133-
134- torch .save (data_df , data_file_path )
135- print (f"File { data_file_path } saved to new data-folder structure" )
133+ if data_df is not None :
134+ torch .save (data_df , data_file_path )
135+ print (f"File { data_file_path } saved to new data-folder structure" )
136136
137137 def _combine_pt_splits (
138138 self , old_dir : str , old_splits_file_names : Dict [str , str ]
139- ) -> pd .DataFrame :
139+ ) -> Optional [ pd .DataFrame ] :
140140 """
141141 Combine old `.pt` split files into a single DataFrame.
142142
@@ -147,7 +147,11 @@ def _combine_pt_splits(
147147 Returns:
148148 pd.DataFrame: The combined DataFrame.
149149 """
150- self ._check_if_old_splits_exists (old_dir , old_splits_file_names )
150+ if not self ._check_if_old_splits_exists (old_dir , old_splits_file_names ):
151+ print (
152+ f"Missing at least one of [{ ', ' .join (old_splits_file_names .values ())} ] in { old_dir } "
153+ )
154+ return None
151155
152156 print ("Combining `.pt` splits..." )
153157 df_list : List [pd .DataFrame ] = []
@@ -160,7 +164,7 @@ def _combine_pt_splits(
160164
161165 def _combine_pkl_splits (
162166 self , old_dir : str , old_splits_file_names : Dict [str , str ]
163- ) -> Tuple [pd .DataFrame , pd .DataFrame ]:
167+ ) -> Optional [ Tuple [pd .DataFrame , pd .DataFrame ] ]:
164168 """
165169 Combine old `.pkl` split files into a single DataFrame and create split assignments.
166170
@@ -171,7 +175,11 @@ def _combine_pkl_splits(
171175 Returns:
172176 Tuple[pd.DataFrame, pd.DataFrame]: The combined DataFrame and split assignments DataFrame.
173177 """
174- self ._check_if_old_splits_exists (old_dir , old_splits_file_names )
178+ if not self ._check_if_old_splits_exists (old_dir , old_splits_file_names ):
179+ print (
180+ f"Missing at least one of [{ ', ' .join (old_splits_file_names .values ())} ] in { old_dir } "
181+ )
182+ return None
175183
176184 df_list : List [pd .DataFrame ] = []
177185 split_assignment_list : List [pd .DataFrame ] = []
@@ -195,25 +203,19 @@ def _combine_pkl_splits(
195203 @staticmethod
196204 def _check_if_old_splits_exists (
197205 old_dir : str , old_splits_file_names : Dict [str , str ]
198- ) -> None :
206+ ) -> bool :
199207 """
200208 Check if the old split files exist in the specified directory.
201209
202210 Args:
203211 old_dir (str): The directory containing the old split files.
204212 old_splits_file_names (Dict[str, str]): A dictionary of split names and file names.
205213
206- Raises:
207- FileNotFoundError: If any of the split files do not exist.
208214 """
209- if any (
210- not os .path .isfile (os .path .join (old_dir , file ))
215+ return all (
216+ os .path .isfile (os .path .join (old_dir , file ))
211217 for file in old_splits_file_names .values ()
212- ):
213- raise FileNotFoundError (
214- f"One of the split { old_splits_file_names .values ()} doesn't exist "
215- f"in old data-folder structure: { old_dir } "
216- )
218+ )
217219
218220 @staticmethod
219221 def _copy_file (old_file_dir : str , new_file_dir : str , file_name : str ) -> None :
@@ -230,18 +232,19 @@ def _copy_file(old_file_dir: str, new_file_dir: str, file_name: str) -> None:
230232 """
231233 os .makedirs (new_file_dir , exist_ok = True )
232234 new_file_path = os .path .join (new_file_dir , file_name )
233- if os .path .isfile (new_file_path ):
234- print (f"File { new_file_path } already exists in new data-folder structure" )
235- return
236-
237235 old_file_path = os .path .join (old_file_dir , file_name )
238- if not os .path .isfile (old_file_path ):
239- raise FileNotFoundError (
240- f"File { old_file_path } doesn't exist in old data-folder structure"
236+
237+ if os .path .isfile (new_file_path ):
238+ print (
239+ f"Skipping { old_file_path } (file already exists at new location { new_file_path } )"
241240 )
241+ return
242242
243- shutil .copy2 (os .path .abspath (old_file_path ), os .path .abspath (new_file_path ))
244- print (f"Copied from { old_file_path } to { new_file_path } " )
243+ if os .path .isfile (old_file_path ):
244+ shutil .copy2 (os .path .abspath (old_file_path ), os .path .abspath (new_file_path ))
245+ print (f"Copied { old_file_path } to { new_file_path } " )
246+ else :
247+ print (f"Skipping expected file { old_file_path } (not found)" )
245248
246249 @property
247250 def _old_base_dir (self ) -> str :
@@ -251,6 +254,13 @@ def _old_base_dir(self) -> str:
251254 Returns:
252255 str: The base directory for the old data.
253256 """
257+ if isinstance (self ._chebi_cls , ChEBIOverXPartial ):
258+ return os .path .join (
259+ self .__DATA_ROOT_DIR ,
260+ self ._chebi_cls ._name ,
261+ f"chebi_v{ self ._chebi_cls .chebi_version } " ,
262+ f"partial_{ self ._chebi_cls .top_class_id } " ,
263+ )
254264 return os .path .join (
255265 self .__DATA_ROOT_DIR ,
256266 self ._chebi_cls ._name ,
@@ -286,29 +296,32 @@ def _old_raw_dir(self) -> str:
286296 return os .path .join (self ._old_base_dir , "raw" )
287297
288298
class Main:
    """Command-line entry points for the ChEBI data migration."""

    def migrate(
        self,
        datamodule: Optional[_ChEBIDataExtractor] = None,
        class_name: Optional[str] = None,
        chebi_version: Optional[int] = None,
        single_class: Optional[int] = None,
    ) -> None:
        """
        Migrate ChEBI dataset to new structure and handle splits.

        Args:
            datamodule (Optional[_ChEBIDataExtractor]): The datamodule instance. If not provided, class_name and
                chebi_version are required.
            class_name (Optional[str]): The name of the ChEBI class.
            chebi_version (Optional[int]): The version of the ChEBI dataset.
            single_class (Optional[int]): The ID of the single class to predict.

        Raises:
            ValueError: If `datamodule` is not given and `class_name` or `chebi_version` is missing.
        """
        if datamodule is not None:
            ChebiDataMigration(datamodule).migrate()
            return

        # Fail early with a clear message instead of forwarding `None` to the dynamic import.
        if class_name is None or chebi_version is None:
            raise ValueError(
                "Either `datamodule` or both `class_name` and `chebi_version` must be provided"
            )
        ChebiDataMigration.from_args(
            class_name, chebi_version, single_class
        ).migrate()
325+
if __name__ == "__main__":
    # jsonargparse derives the command line from `Main`: its `migrate` method is exposed as a
    # subcommand and the method's parameters become the command-line options.
    CLI(Main)
0 commit comments