diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 1bba5e4ca0da7..b2c14e8812eee 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -55,6 +55,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed `LightningCLI` loading of hyperparameters from `ckpt_path` failing for subclass model mode ([#21246](https://github.com/Lightning-AI/pytorch-lightning/pull/21246)) +- Fixed check the init args only when the given frames are in `__init__` method ([#21227](https://github.com/Lightning-AI/pytorch-lightning/pull/21227)) + + --- ## [2.5.5] - 2025-09-05 diff --git a/src/lightning/pytorch/utilities/parsing.py b/src/lightning/pytorch/utilities/parsing.py index 829cc7a994b93..64aa0209819ab 100644 --- a/src/lightning/pytorch/utilities/parsing.py +++ b/src/lightning/pytorch/utilities/parsing.py @@ -91,7 +91,7 @@ def get_init_args(frame: types.FrameType) -> dict[str, Any]: # pragma: no-cover def _get_init_args(frame: types.FrameType) -> tuple[Optional[Any], dict[str, Any]]: _, _, _, local_vars = inspect.getargvalues(frame) - if "__class__" not in local_vars: + if "__class__" not in local_vars or frame.f_code.co_name != "__init__": return None, {} cls = local_vars["__class__"] init_parameters = inspect.signature(cls.__init__).parameters diff --git a/tests/tests_pytorch/models/test_hparams.py b/tests/tests_pytorch/models/test_hparams.py index 92a07f0a3d05e..9b31d18aa9edb 100644 --- a/tests/tests_pytorch/models/test_hparams.py +++ b/tests/tests_pytorch/models/test_hparams.py @@ -341,6 +341,20 @@ def __init__(obj, *more_args, other_arg=300, **more_kwargs): obj.save_hyperparameters() +class _MetaType(type): + def __call__(cls, *args, **kwargs): + instance = super().__call__(*args, **kwargs) # Create the instance + if hasattr(instance, "_after_init"): + instance._after_init(**kwargs) # Call the method if defined + return instance + + +class MetaTypeBoringModel(CustomBoringModel, metaclass=_MetaType): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.save_hyperparameters() + + if _OMEGACONF_AVAILABLE: class DictConfSubClassBoringModel(SubClassBoringModel): @@ -365,6 +379,7 @@ class DictConfSubClassBoringModel: ... pytest.param(DictConfSubClassBoringModel, marks=RunIf(omegaconf=True)), BoringModelWithMixin, BoringModelWithMixinAndInit, + MetaTypeBoringModel, ], ) def test_collect_init_arguments(tmp_path, cls):