Skip to content

Commit 97846fa

Browse files
committed
add prepare_data method to base class
- after doing certain common processing the `prepare_data` method calls the _perform_data_preparation method which contain core logic for data preparation
1 parent bfe137b commit 97846fa

File tree

5 files changed

+13
-24
lines changed

5 files changed

+13
-24
lines changed

chebai/preprocessing/datasets/base.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -138,9 +138,10 @@ def __init_subclass__(cls, *args, **kwargs):
138138
super().__init_subclass__(*args, **kwargs)
139139
original_init = cls.__init__
140140

141+
# Creates updated definition for init method
141142
def new_init(self, *args, **kwargs):
142143
original_init(self, *args, **kwargs) # Call the original __init__
143-
if type(self) == cls: # Only run __post_init__ if it's the final class
144+
if type(self) == cls: # Only run method if it's the final class
144145
self._call_data_processing_methods(*args, **kwargs)
145146

146147
cls.__init__ = new_init
@@ -448,6 +449,10 @@ def prepare_data(self, *args, **kwargs) -> None:
448449
return
449450

450451
self._prepare_data_flag += 1
452+
self._perform_data_preparation(*args, **kwargs)
453+
454+
def _perform_data_preparation(self, *args, **kwargs) -> None:
455+
raise NotImplementedError
451456

452457
def setup(self, *args, **kwargs) -> None:
453458
"""
@@ -598,11 +603,10 @@ def __init__(
598603
os.makedirs(self.processed_dir, exist_ok=True)
599604
super(pl.LightningDataModule, self).__init__(**kwargs)
600605

601-
def prepare_data(self):
606+
def _perform_data_preparation(self):
602607
"""
603608
Placeholder for data preparation logic.
604609
"""
605-
super().prepare_data()
606610
for s in self.subsets:
607611
s.prepare_data()
608612

@@ -792,7 +796,7 @@ def _validate_splits_file_path(splits_file_path: Optional[str]) -> Optional[str]
792796
return splits_file_path
793797

794798
# ------------------------------ Phase: Prepare data -----------------------------------
795-
def prepare_data(self, *args: Any, **kwargs: Any) -> None:
799+
def _perform_data_preparation(self, *args: Any, **kwargs: Any) -> None:
796800
"""
797801
Prepares the data for the dataset.
798802
@@ -811,7 +815,6 @@ def prepare_data(self, *args: Any, **kwargs: Any) -> None:
811815
Returns:
812816
None
813817
"""
814-
super().prepare_data()
815818
print("Checking for processed data in", self.processed_dir_main)
816819

817820
processed_name = self.processed_main_file_names_dict["data"]

chebai/preprocessing/datasets/chebi.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,7 @@ def download(self):
5959
def raw_file_names(self):
6060
return ["test.pkl", "train.pkl", "validation.pkl"]
6161

62-
def prepare_data(self, *args, **kwargs):
63-
super().prepare_data()
62+
def _perform_data_preparation(self, *args, **kwargs):
6463
print("Check for raw data in", self.raw_dir)
6564
if any(
6665
not os.path.isfile(os.path.join(self.raw_dir, f))
@@ -156,7 +155,7 @@ def __init__(
156155
)
157156

158157
# ------------------------------ Phase: Prepare data -----------------------------------
159-
def prepare_data(self, *args: Any, **kwargs: Any) -> None:
158+
def _perform_data_preparation(self, *args: Any, **kwargs: Any) -> None:
160159
"""
161160
Prepares the data for the Chebi dataset.
162161
@@ -177,8 +176,6 @@ def prepare_data(self, *args: Any, **kwargs: Any) -> None:
177176
Returns:
178177
None
179178
"""
180-
super().prepare_data(args, kwargs)
181-
182179
if self.chebi_version_train is not None:
183180
if not os.path.isfile(
184181
os.path.join(

chebai/preprocessing/datasets/deepGO/go_uniprot.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -770,7 +770,7 @@ def _name(self) -> str:
770770
return f"{threshold_part}{self.max_sequence_length}"
771771

772772
# ------------------------------ Phase: Prepare data -----------------------------------
773-
def prepare_data(self, *args: Any, **kwargs: Any) -> None:
773+
def _perform_data_preparation(self, *args: Any, **kwargs: Any) -> None:
774774
"""
775775
Checks for the existence of migrated DeepGO data in the specified directory.
776776
Raises an error if the required data file is not found, prompting
@@ -783,11 +783,6 @@ def prepare_data(self, *args: Any, **kwargs: Any) -> None:
783783
Raises:
784784
FileNotFoundError: If the processed data file does not exist.
785785
"""
786-
if self._prepare_data_flag != 1:
787-
return
788-
789-
self._prepare_data_flag += 1
790-
791786
print("Checking for processed data in", self.processed_dir_main)
792787

793788
processed_name = self.processed_main_file_names_dict["data"]

chebai/preprocessing/datasets/deepGO/protein_pretraining.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def __init__(self, **kwargs):
5555
)
5656

5757
# ------------------------------ Phase: Prepare data -----------------------------------
58-
def prepare_data(self, *args: Any, **kwargs: Any) -> None:
58+
def _perform_data_preparation(self, *args: Any, **kwargs: Any) -> None:
5959
"""
6060
Prepares the data by downloading and parsing Swiss-Prot data if not already available. Saves the processed data
6161
for further use.
@@ -64,11 +64,6 @@ def prepare_data(self, *args: Any, **kwargs: Any) -> None:
6464
*args: Additional positional arguments.
6565
**kwargs: Additional keyword arguments.
6666
"""
67-
if self._prepare_data_flag != 1:
68-
return
69-
70-
self._prepare_data_flag += 1
71-
7267
processed_name = self.processed_main_file_names_dict["data"]
7368
if not os.path.isfile(os.path.join(self.processed_dir_main, processed_name)):
7469
print("Missing processed data file (`data.pkl` file)")

chebai/preprocessing/datasets/pubchem.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -179,11 +179,10 @@ def processed_file_names(self) -> List[str]:
179179
"""
180180
return ["test.pt", "train.pt", "validation.pt"]
181181

182-
def prepare_data(self, *args, **kwargs):
182+
def _perform_data_preparation(self, *args, **kwargs):
183183
"""
184184
Checks for raw data and downloads if necessary.
185185
"""
186-
super().prepare_data()
187186
print("Check for raw data in", self.raw_dir)
188187
if any(
189188
not os.path.isfile(os.path.join(self.raw_dir, f))

0 commit comments

Comments
 (0)