Skip to content

Commit 3a7598d

Browse files
carmoccalexierule
authored andcommitted
Backwards compatibility for get_init_args (#16851)
1 parent 6822960 commit 3a7598d

File tree

3 files changed

+11
-4
lines changed

3 files changed

+11
-4
lines changed

src/pytorch_lightning/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1111

1212
- Fixed DDP spawn hang on TPU Pods ([#16844](https://github.com/Lightning-AI/lightning/pull/16844))
1313
- Fixed edge cases in parsing device ids using NVML ([#16795](https://github.com/Lightning-AI/lightning/pull/16795))
14+
- Fixed backwards compatibility for `lightning.pytorch.utilities.parsing.get_init_args` ([#16851](https://github.com/Lightning-AI/lightning/pull/16851))
1415

1516

1617
## [1.9.3] - 2023-02-21

src/pytorch_lightning/utilities/parsing.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,13 @@ def _get_first_if_any(
137137
return n_self, n_args, n_kwargs
138138

139139

140-
def get_init_args(frame: types.FrameType) -> Tuple[Optional[Any], Dict[str, Any]]:
140+
def get_init_args(frame: types.FrameType) -> Dict[str, Any]: # pragma: no-cover
141+
"""For backwards compatibility: #16369."""
142+
_, local_args = _get_init_args(frame)
143+
return local_args
144+
145+
146+
def _get_init_args(frame: types.FrameType) -> Tuple[Optional[Any], Dict[str, Any]]:
141147
_, _, _, local_vars = inspect.getargvalues(frame)
142148
if "__class__" not in local_vars:
143149
return None, {}
@@ -180,7 +186,7 @@ def collect_init_args(
180186
if not isinstance(frame.f_back, types.FrameType):
181187
return path_args
182188

183-
local_self, local_args = get_init_args(frame)
189+
local_self, local_args = _get_init_args(frame)
184190
if "__class__" in local_vars and (not classes or isinstance(local_self, classes)):
185191
# recursive update
186192
path_args.append(local_args)

tests/tests_pytorch/utilities/test_parsing.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,11 @@
1818

1919
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
2020
from pytorch_lightning.utilities.parsing import (
21+
_get_init_args,
2122
AttributeDict,
2223
clean_namespace,
2324
collect_init_args,
2425
flatten_dict,
25-
get_init_args,
2626
is_picklable,
2727
lightning_getattr,
2828
lightning_hasattr,
@@ -252,7 +252,7 @@ def __init__(self, anyarg, anykw=42, **kwargs):
252252

253253
def get_init_args_wrapper(self):
254254
frame = inspect.currentframe().f_back
255-
self.result = get_init_args(frame)
255+
self.result = _get_init_args(frame)
256256

257257
my_class = AutomaticArgsModel("test", anykw=32, otherkw=123)
258258
assert my_class.result == (my_class, {"anyarg": "test", "anykw": 32, "otherkw": 123})

0 commit comments

Comments
 (0)