Skip to content

Commit bfe137b

Browse files
committed
__init_subclass__ to call data methods at end of initialization
- @dataclass to avoid conflicts with the stack frame of save_hyperparameters()
1 parent 335adbf commit bfe137b

File tree

2 files changed

+35
-46
lines changed

2 files changed

+35
-46
lines changed

chebai/preprocessing/datasets/base.py

Lines changed: 34 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
1-
from __future__ import annotations
2-
31
import os
42
import random
53
from abc import ABC, abstractmethod
6-
from typing import Any, Dict, Generator, List, Optional, Tuple, Type, Union
4+
from dataclasses import dataclass
5+
from typing import Any, Dict, Generator, List, Optional, Tuple, Union
76

87
import lightning as pl
98
import networkx as nx
@@ -22,40 +21,8 @@
2221
from chebai.preprocessing import reader as dr
2322

2423

25-
class _InitMeta(type):
26-
"""
27-
Metaclass that ensures a specific method (`_call_data_processing_methods`)
28-
is called after an instance (meaning the most derived class instance) is fully initialized.
29-
30-
Purpose:
31-
- Automatically calls `_call_data_processing_methods(**kwargs)` on instances
32-
of classes that use this metaclass, if the method is defined.
33-
- Ensures additional processing logic is executed immediately after object instantiation.
34-
- Useful in cases where post-initialization processing is required across multiple subclasses.
35-
"""
36-
37-
def __call__(
38-
cls: Type[XYBaseDataModule], *args: Any, **kwargs: Any
39-
) -> XYBaseDataModule:
40-
"""
41-
Overrides the instance creation process to call `_after_init` after the most derived class instance
42-
is initialized.
43-
44-
Args:
45-
cls (Type[XYBaseDataModule]): The class being instantiated.
46-
*args (Any): Positional arguments for the class constructor.
47-
**kwargs (Any): Keyword arguments for the class constructor.
48-
49-
Returns:
50-
XYBaseDataModule: The initialized instance of the class.
51-
"""
52-
instance = super().__call__(*args, **kwargs) # Create the instance
53-
if hasattr(instance, "_after_init"):
54-
instance._after_init(**kwargs) # Call the method if defined
55-
return instance
56-
57-
58-
class XYBaseDataModule(LightningDataModule, metaclass=_InitMeta):
24+
@dataclass
25+
class XYBaseDataModule(LightningDataModule):
5926
"""
6027
Base class for data modules.
6128
@@ -157,26 +124,47 @@ def __init__(
157124
self._prepare_data_flag = 1
158125
self._setup_data_flag = 1
159126

160-
def _after_init(self, **kwargs):
127+
def __init_subclass__(cls, *args, **kwargs):
161128
"""
162-
This method is called after the instantiation of most derived class is completed.
163-
Refer the `_InitMeta` metaclass for more details.
129+
This method ensures that the '_call_data_processing_methods' is called only for the final subclass
130+
in the class hierarchy. It overrides the default __init__ behavior to add custom initialization logic.
131+
132+
- The method saves the original `__init__` method of the class and then defines a new `__init__` method.
133+
- This new `__init__` method calls the original `__init__` method of the class and then checks if the
134+
current class is the final subclass (i.e., the exact type of the instance being created, not one of its base classes).
135+
- If it's the final class, it invokes the `_call_data_processing_methods` method to perform any necessary
136+
data processing tasks.
164137
"""
165-
self._call_data_processing_methods(**kwargs)
138+
super().__init_subclass__(*args, **kwargs)
139+
original_init = cls.__init__
140+
141+
def new_init(self, *args, **kwargs):
142+
original_init(self, *args, **kwargs) # Call the original __init__
143+
if type(self) == cls: # Only call _call_data_processing_methods if it's the final class
144+
self._call_data_processing_methods(*args, **kwargs)
145+
146+
cls.__init__ = new_init
166147

167-
def _call_data_processing_methods(self, **kwargs) -> None:
148+
def _call_data_processing_methods(self, *args, **kwargs) -> None:
168149
"""
169150
Calls data processing methods unless explicitly skipped.
170151
171152
- Skips execution if `_skip_data_methods_on_init` is `True` (e.g., for unit tests).
172153
- Otherwise, calls `prepare_data()` and `setup()` for data preparation.
173154
155+
Note: This method is called after the instantiation of the most derived class is completed.
174156
"""
175157
if kwargs.get("_skip_data_methods_on_init", False):
158+
print(
159+
f"Skipping data methods of class '{os.path.join(self.base_dir, self._name)}' during initialization"
160+
)
176161
return
177162

178-
self.prepare_data()
179-
self.setup()
163+
print(
164+
f"Calling data method of class {os.path.join(self.base_dir, self._name)} during initialization"
165+
)
166+
self.prepare_data(*args, **kwargs)
167+
self.setup(*args, **kwargs)
180168

181169
@property
182170
def num_of_labels(self):
@@ -455,13 +443,13 @@ def predict_dataloader(
455443
"""
456444
return self.dataloader(self.prediction_kind, shuffle=False, **kwargs)
457445

458-
def prepare_data(self) -> None:
446+
def prepare_data(self, *args, **kwargs) -> None:
459447
if self._prepare_data_flag != 1:
460448
return
461449

462450
self._prepare_data_flag += 1
463451

464-
def setup(self, **kwargs):
452+
def setup(self, *args, **kwargs) -> None:
465453
"""
466454
Setup the data module.
467455

chebai/preprocessing/datasets/chebi.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ def __init__(
149149
# This is to get the data from respective directory related to "chebi_version_train"
150150
_init_kwargs = kwargs
151151
_init_kwargs["chebi_version"] = self.chebi_version_train
152+
_init_kwargs["_skip_data_methods_on_init"] = True
152153
self._chebi_version_train_obj = self.__class__(
153154
single_class=self.single_class,
154155
**_init_kwargs,

0 commit comments

Comments
 (0)