Skip to content

Commit 335adbf

Browse files
committed
data_methods should only be called after complete instantiation
- prepare_data and setup method should be called only after the most derived based class is full instantiated
1 parent 1ae6543 commit 335adbf

File tree

1 file changed

+52
-4
lines changed
  • chebai/preprocessing/datasets

1 file changed

+52
-4
lines changed

chebai/preprocessing/datasets/base.py

Lines changed: 52 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
from __future__ import annotations
2+
13
import os
24
import random
35
from abc import ABC, abstractmethod
4-
from typing import Any, Dict, Generator, List, Optional, Tuple, Union
6+
from typing import Any, Dict, Generator, List, Optional, Tuple, Type, Union
57

68
import lightning as pl
79
import networkx as nx
@@ -20,7 +22,40 @@
2022
from chebai.preprocessing import reader as dr
2123

2224

23-
class XYBaseDataModule(LightningDataModule):
25+
class _InitMeta(type):
26+
"""
27+
Metaclass that ensures a specific method (`_call_data_processing_methods`)
28+
is called after an instance (meaning the most derived class instance) is fully initialized.
29+
30+
Purpose:
31+
- Automatically calls `_call_data_processing_methods(**kwargs)` on instances
32+
of classes that use this metaclass, if the method is defined.
33+
- Ensures additional processing logic is executed immediately after object instantiation.
34+
- Useful in cases where post-initialization processing is required across multiple subclasses.
35+
"""
36+
37+
def __call__(
38+
cls: Type[XYBaseDataModule], *args: Any, **kwargs: Any
39+
) -> XYBaseDataModule:
40+
"""
41+
Overrides the instance creation process to call `_after_init` after the most derived class instance
42+
is initialized.
43+
44+
Args:
45+
cls (Type[XYBaseDataModule]): The class being instantiated.
46+
*args (Any): Positional arguments for the class constructor.
47+
**kwargs (Any): Keyword arguments for the class constructor.
48+
49+
Returns:
50+
XYBaseDataModule: The initialized instance of the class.
51+
"""
52+
instance = super().__call__(*args, **kwargs) # Create the instance
53+
if hasattr(instance, "_after_init"):
54+
instance._after_init(**kwargs) # Call the method if defined
55+
return instance
56+
57+
58+
class XYBaseDataModule(LightningDataModule, metaclass=_InitMeta):
2459
"""
2560
Base class for data modules.
2661
@@ -122,9 +157,22 @@ def __init__(
122157
self._prepare_data_flag = 1
123158
self._setup_data_flag = 1
124159

125-
# Skips data setup in the constructor; methods will be called later according to the CLI workflow.
160+
def _after_init(self, **kwargs):
161+
"""
162+
This method is called after the instantiation of most derived class is completed.
163+
Refer the `_InitMeta` metaclass for more details.
164+
"""
165+
self._call_data_processing_methods(**kwargs)
166+
167+
def _call_data_processing_methods(self, **kwargs) -> None:
168+
"""
169+
Calls data processing methods unless explicitly skipped.
170+
171+
- Skips execution if `_skip_data_methods_on_init` is `True` (e.g., for unit tests).
172+
- Otherwise, calls `prepare_data()` and `setup()` for data preparation.
173+
174+
"""
126175
if kwargs.get("_skip_data_methods_on_init", False):
127-
# This change enables to skip these methods during initialization for unit-testing GitHub CI/CD
128176
return
129177

130178
self.prepare_data()

0 commit comments

Comments
 (0)