|
1 | | -from __future__ import annotations |
2 | | - |
3 | 1 | import os |
4 | 2 | import random |
5 | 3 | from abc import ABC, abstractmethod |
6 | | -from typing import Any, Dict, Generator, List, Optional, Tuple, Type, Union |
| 4 | +from dataclasses import dataclass |
| 5 | +from typing import Any, Dict, Generator, List, Optional, Tuple, Union |
7 | 6 |
|
8 | 7 | import lightning as pl |
9 | 8 | import networkx as nx |
|
22 | 21 | from chebai.preprocessing import reader as dr |
23 | 22 |
|
24 | 23 |
|
25 | | -class _InitMeta(type): |
26 | | - """ |
27 | | - Metaclass that ensures a specific method (`_call_data_processing_methods`) |
28 | | - is called after an instance (meaning the most derived class instance) is fully initialized. |
29 | | -
|
30 | | - Purpose: |
31 | | - - Automatically calls `_call_data_processing_methods(**kwargs)` on instances |
32 | | - of classes that use this metaclass, if the method is defined. |
33 | | - - Ensures additional processing logic is executed immediately after object instantiation. |
34 | | - - Useful in cases where post-initialization processing is required across multiple subclasses. |
35 | | - """ |
36 | | - |
37 | | - def __call__( |
38 | | - cls: Type[XYBaseDataModule], *args: Any, **kwargs: Any |
39 | | - ) -> XYBaseDataModule: |
40 | | - """ |
41 | | - Overrides the instance creation process to call `_after_init` after the most derived class instance |
42 | | - is initialized. |
43 | | -
|
44 | | - Args: |
45 | | - cls (Type[XYBaseDataModule]): The class being instantiated. |
46 | | - *args (Any): Positional arguments for the class constructor. |
47 | | - **kwargs (Any): Keyword arguments for the class constructor. |
48 | | -
|
49 | | - Returns: |
50 | | - XYBaseDataModule: The initialized instance of the class. |
51 | | - """ |
52 | | - instance = super().__call__(*args, **kwargs) # Create the instance |
53 | | - if hasattr(instance, "_after_init"): |
54 | | - instance._after_init(**kwargs) # Call the method if defined |
55 | | - return instance |
56 | | - |
57 | | - |
58 | | -class XYBaseDataModule(LightningDataModule, metaclass=_InitMeta): |
| 24 | +@dataclass |
| 25 | +class XYBaseDataModule(LightningDataModule): |
59 | 26 | """ |
60 | 27 | Base class for data modules. |
61 | 28 |
|
@@ -157,26 +124,47 @@ def __init__( |
157 | 124 | self._prepare_data_flag = 1 |
158 | 125 | self._setup_data_flag = 1 |
159 | 126 |
|
160 | | - def _after_init(self, **kwargs): |
| 127 | + def __init_subclass__(cls, *args, **kwargs): |
161 | 128 | """ |
162 | | - This method is called after the instantiation of most derived class is completed. |
163 | | - Refer the `_InitMeta` metaclass for more details. |
| 129 | + This method ensures that the '_call_data_processing_methods' is called only for the final subclass |
| 130 | + in the class hierarchy. It overrides the default __init__ behavior to add custom initialization logic. |
| 131 | +
|
| 132 | + - The method saves the original `__init__` method of the class and then defines a new `__init__` method. |
| 133 | + - This new `__init__` method calls the original `__init__` method of the class and then checks if the |
| 134 | + current class is the final subclass (i.e., not a subclass of a subclass). |
| 135 | + - If it's the final class, it invokes the `_call_data_processing_methods` method to perform any necessary |
| 136 | + data processing tasks. |
164 | 137 | """ |
165 | | - self._call_data_processing_methods(**kwargs) |
| 138 | + super().__init_subclass__(*args, **kwargs) |
| 139 | + original_init = cls.__init__ |
| 140 | + |
| 141 | + def new_init(self, *args, **kwargs): |
| 142 | + original_init(self, *args, **kwargs) # Call the original __init__ |
| 143 | + if type(self) == cls: # Only run __post_init__ if it's the final class |
| 144 | + self._call_data_processing_methods(*args, **kwargs) |
| 145 | + |
| 146 | + cls.__init__ = new_init |
166 | 147 |
|
167 | | - def _call_data_processing_methods(self, **kwargs) -> None: |
| 148 | + def _call_data_processing_methods(self, *args, **kwargs) -> None: |
168 | 149 | """ |
169 | 150 | Calls data processing methods unless explicitly skipped. |
170 | 151 |
|
171 | 152 | - Skips execution if `_skip_data_methods_on_init` is `True` (e.g., for unit tests). |
172 | 153 | - Otherwise, calls `prepare_data()` and `setup()` for data preparation. |
173 | 154 |
|
| 155 | + Note: This method is called after the instantiation of most derived class is completed. |
174 | 156 | """ |
175 | 157 | if kwargs.get("_skip_data_methods_on_init", False): |
| 158 | + print( |
| 159 | + f"Skipping data methods of class '{os.path.join(self.base_dir, self._name)}' during initialization" |
| 160 | + ) |
176 | 161 | return |
177 | 162 |
|
178 | | - self.prepare_data() |
179 | | - self.setup() |
| 163 | + print( |
| 164 | + f"Calling data method of class {os.path.join(self.base_dir, self._name)} during initialization" |
| 165 | + ) |
| 166 | + self.prepare_data(*args, **kwargs) |
| 167 | + self.setup(*args, **kwargs) |
180 | 168 |
|
181 | 169 | @property |
182 | 170 | def num_of_labels(self): |
@@ -455,13 +443,13 @@ def predict_dataloader( |
455 | 443 | """ |
456 | 444 | return self.dataloader(self.prediction_kind, shuffle=False, **kwargs) |
457 | 445 |
|
458 | | - def prepare_data(self) -> None: |
| 446 | + def prepare_data(self, *args, **kwargs) -> None: |
459 | 447 | if self._prepare_data_flag != 1: |
460 | 448 | return |
461 | 449 |
|
462 | 450 | self._prepare_data_flag += 1 |
463 | 451 |
|
464 | | - def setup(self, **kwargs): |
| 452 | + def setup(self, *args, **kwargs) -> None: |
465 | 453 | """ |
466 | 454 | Setup the data module. |
467 | 455 |
|
|
0 commit comments