2 changes: 1 addition & 1 deletion requirements/pytorch/extra.txt
@@ -5,7 +5,7 @@
matplotlib>3.1, <3.10.0
omegaconf >=2.2.3, <2.4.0
hydra-core >=1.2.0, <1.4.0
jsonargparse[signatures] >=4.28.0, <=4.40.0
jsonargparse[signatures] >=4.39.0, <4.40.0
rich >=12.3.0, <14.1.0
tensorboardX >=2.2, <2.7.0 # min version is set by torch.onnx missing attribute
bitsandbytes >=0.45.2,<0.45.3; platform_system != "Darwin"
6 changes: 5 additions & 1 deletion src/lightning/pytorch/CHANGELOG.md
@@ -11,6 +11,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Add enable_autolog_hparams argument to Trainer ([#20593](https://github.com/Lightning-AI/pytorch-lightning/pull/20593))


- For cross-device local checkpoints, instruct users to install `fsspec>=2025.5.0` if unavailable ([#20780](https://github.com/Lightning-AI/pytorch-lightning/pull/20780))


Expand All @@ -25,7 +26,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Fixed logger_connector has edge case where step can be a float ([#20692](https://github.com/Lightning-AI/pytorch-lightning/issues/20692))
- Fixed `save_hyperparameters` not working correctly with `LightningCLI` when there are parsing links applied on instantiation ([#20777](https://github.com/Lightning-AI/pytorch-lightning/pull/20777))


- Fixed logger_connector has edge case where step can be a float ([#20692](https://github.com/Lightning-AI/pytorch-lightning/pull/20692))


- Fix: Synchronize SIGTERM Handling in DDP to Prevent Deadlocks ([#20825](https://github.com/Lightning-AI/pytorch-lightning/pull/20825))
40 changes: 37 additions & 3 deletions src/lightning/pytorch/cli.py
@@ -327,6 +327,7 @@ def __init__(
args: ArgsType = None,
run: bool = True,
auto_configure_optimizers: bool = True,
load_from_checkpoint_support: bool = True,
) -> None:
"""Receives as input pytorch-lightning classes (or callables which return pytorch-lightning classes), which are
called / instantiated using a parsed configuration file and / or command line args.
@@ -367,6 +368,11 @@
``dict`` or ``jsonargparse.Namespace``.
run: Whether subcommands should be added to run a :class:`~lightning.pytorch.trainer.trainer.Trainer`
method. If set to ``False``, the trainer and model classes will be instantiated only.
auto_configure_optimizers: Whether to automatically add default optimizer and lr_scheduler arguments.
load_from_checkpoint_support: Whether ``save_hyperparameters`` should save the original parsed
hyperparameters (instead of what ``__init__`` receives), such that it is possible for
``load_from_checkpoint`` to correctly instantiate classes even when using complex nesting and
dependency injection.

"""
self.save_config_callback = save_config_callback
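For context, a minimal sketch of opting out of the new behaviour (using the `BoringModel`/`BoringDataModule` demo classes purely for illustration): with `load_from_checkpoint_support=False` the CLI skips `_add_instantiators()`, so `save_hyperparameters` records whatever `__init__` receives, as before.

```python
# Sketch only: restore the previous save_hyperparameters() behaviour.
from lightning.pytorch.cli import LightningCLI
from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel

if __name__ == "__main__":
    cli = LightningCLI(
        BoringModel,
        BoringDataModule,
        run=False,                           # only instantiate, do not call a Trainer method
        args=[],                             # parse an empty command line for this example
        load_from_checkpoint_support=False,  # skip the checkpoint-friendly instantiators
    )
    print(type(cli.model).__name__)
```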
@@ -396,7 +402,8 @@ def __init__(

self._set_seed()

self._add_instantiators()
if load_from_checkpoint_support:
self._add_instantiators()
self.before_instantiate_classes()
self.instantiate_classes()
self.after_instantiate_classes()
@@ -544,11 +551,14 @@ def parse_arguments(self, parser: LightningArgumentParser, args: ArgsType) -> None:
else:
self.config = parser.parse_args(args)

def _add_instantiators(self) -> None:
def _dump_config(self) -> None:
if hasattr(self, "config_dump"):
return
self.config_dump = yaml.safe_load(self.parser.dump(self.config, skip_link_targets=False, skip_none=False))
if "subcommand" in self.config:
self.config_dump = self.config_dump[self.config.subcommand]
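As an aside, a hedged illustration (made-up YAML, not captured from a real run) of what `config_dump` ends up holding after `_dump_config` for a `fit` invocation: `parser.dump` serialises the parsed config to YAML, `yaml.safe_load` turns it back into plain dicts, and the subcommand level is peeled off so entries like `config_dump["model"]` can be read directly.

```python
import yaml

# Made-up dump of a parsed `fit` config, nested under the subcommand name.
dumped = """
fit:
  model:
    batch_size: 12
  data:
    batch_size: 12
  trainer:
    max_epochs: 1
"""

subcommand = "fit"  # in the CLI this comes from self.config.subcommand
config_dump = yaml.safe_load(dumped)[subcommand]
assert config_dump["model"] == {"batch_size": 12}
```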

def _add_instantiators(self) -> None:
self.parser.add_instantiator(
_InstantiatorFn(cli=self, key="model"),
_get_module_type(self._model_class),
@@ -799,12 +809,27 @@ def _get_module_type(value: Union[Callable, type]) -> type:
return value


def _set_dict_nested(data: dict, key: str, value: Any) -> None:
keys = key.split(".")
for k in keys[:-1]:
assert k in data, f"Expected key {key} to be in data"
data = data[k]
data[keys[-1]] = value
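A quick illustration of the helper above (with the `_set_dict_nested` definition in scope): dotted keys address nested dictionary entries, and every intermediate key must already exist or the `assert` fires.

```python
# Illustration only: set a nested value through a dotted key.
config = {"optimizer": {"init_args": {"lr": 0.001}}}
_set_dict_nested(config, "optimizer.init_args.lr", 0.1)
assert config == {"optimizer": {"init_args": {"lr": 0.1}}}
```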


class _InstantiatorFn:
def __init__(self, cli: LightningCLI, key: str) -> None:
self.cli = cli
self.key = key

def __call__(self, class_type: type[ModuleType], *args: Any, **kwargs: Any) -> ModuleType:
def __call__(
self,
class_type: type[ModuleType],
*args: Any,
applied_instantiation_links: dict,
**kwargs: Any,
) -> ModuleType:
self.cli._dump_config()
hparams = self.cli.config_dump.get(self.key, {})
if "class_path" in hparams:
# To make hparams backwards compatible, and so that it is the same irrespective of subclass_mode, the
@@ -815,6 +840,15 @@ def __call__(self, class_type: type[ModuleType], *args: Any, **kwargs: Any) -> ModuleType:
**hparams.get("init_args", {}),
**hparams.get("dict_kwargs", {}),
}
# get instantiation link target values from kwargs
for key, value in applied_instantiation_links.items():
if not key.startswith(f"{self.key}."):
continue
key = key[len(f"{self.key}.") :]
if key.startswith("init_args."):
key = key[len("init_args.") :]
_set_dict_nested(hparams, key, value)

with _given_hyperparameters_context(
hparams=hparams,
instantiator="lightning.pytorch.cli.instantiate_module",
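To summarise the new instantiator logic, here is a standalone, simplified re-creation with plain dicts (illustrative values only; the `class_path` handling hidden in the hunk above is elided): link targets applied on instantiation arrive keyed by their full config path, the `model.`/`init_args.` prefixes are stripped, and the values are written into the nested hparams so `save_hyperparameters` records them.

```python
def set_dict_nested(data: dict, key: str, value) -> None:
    # Same walk as the _set_dict_nested helper added in this diff.
    keys = key.split(".")
    for k in keys[:-1]:
        data = data[k]
    data[keys[-1]] = value

# What config_dump.get("model", {}) could look like in subclass mode,
# flattened to its init_args as in the visible part of __call__.
hparams = {
    "optimizer": {"class_path": "tests_pytorch.test_cli.CustomAdam", "init_args": {}},
}

# Instantiation-link targets, keyed by their full config path.
applied_instantiation_links = {"model.init_args.optimizer.init_args.num_classes": 5}

prefix = "model."
for key, value in applied_instantiation_links.items():
    if not key.startswith(prefix):
        continue
    key = key[len(prefix):]
    if key.startswith("init_args."):
        key = key[len("init_args."):]
    set_dict_nested(hparams, key, value)

assert hparams["optimizer"]["init_args"]["num_classes"] == 5
```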
79 changes: 74 additions & 5 deletions tests/tests_pytorch/test_cli.py
@@ -550,6 +550,7 @@ def __init__(self, activation: torch.nn.Module = None, transform: Optional[list[
class BoringModelRequiredClasses(BoringModel):
def __init__(self, num_classes: int, batch_size: int = 8):
super().__init__()
self.save_hyperparameters()
self.num_classes = num_classes
self.batch_size = batch_size

@@ -561,35 +562,103 @@ def __init__(self, batch_size: int = 8):
self.num_classes = 5 # only available after instantiation


def test_lightning_cli_link_arguments():
def test_lightning_cli_link_arguments(cleandir):
class MyLightningCLI(LightningCLI):
def add_arguments_to_parser(self, parser):
parser.link_arguments("data.batch_size", "model.batch_size")
parser.link_arguments("data.num_classes", "model.num_classes", apply_on="instantiate")

cli_args = ["--data.batch_size=12"]
cli_args = ["--data.batch_size=12", "--trainer.max_epochs=1"]

with mock.patch("sys.argv", ["any.py"] + cli_args):
cli = MyLightningCLI(BoringModelRequiredClasses, BoringDataModuleBatchSizeAndClasses, run=False)

assert cli.model.batch_size == 12
assert cli.model.num_classes == 5

class MyLightningCLI(LightningCLI):
cli.trainer.fit(cli.model)
hparams_path = Path(cli.trainer.log_dir) / "hparams.yaml"
assert hparams_path.is_file()
hparams = yaml.safe_load(hparams_path.read_text())

hparams.pop("_instantiator")
assert hparams == {"batch_size": 12, "num_classes": 5}

class MyLightningCLI2(LightningCLI):
def add_arguments_to_parser(self, parser):
parser.link_arguments("data.batch_size", "model.init_args.batch_size")
parser.link_arguments("data.num_classes", "model.init_args.num_classes", apply_on="instantiate")

cli_args[-1] = "--model=tests_pytorch.test_cli.BoringModelRequiredClasses"
cli_args[0] = "--model=tests_pytorch.test_cli.BoringModelRequiredClasses"

with mock.patch("sys.argv", ["any.py"] + cli_args):
cli = MyLightningCLI(
cli = MyLightningCLI2(
BoringModelRequiredClasses, BoringDataModuleBatchSizeAndClasses, subclass_mode_model=True, run=False
)

assert cli.model.batch_size == 8
assert cli.model.num_classes == 5

cli.trainer.fit(cli.model)
hparams_path = Path(cli.trainer.log_dir) / "hparams.yaml"
assert hparams_path.is_file()
hparams = yaml.safe_load(hparams_path.read_text())

hparams.pop("_instantiator")
assert hparams == {"batch_size": 8, "num_classes": 5}


class CustomAdam(torch.optim.Adam):
def __init__(self, params, num_classes: Optional[int] = None, **kwargs):
super().__init__(params, **kwargs)


class DeepLinkTargetModel(BoringModel):
def __init__(
self,
optimizer: OptimizerCallable = torch.optim.Adam,
):
super().__init__()
self.save_hyperparameters()
self.optimizer = optimizer

def configure_optimizers(self):
optimizer = self.optimizer(self.parameters())
return {"optimizer": optimizer}


def test_lightning_cli_link_arguments_subcommands_nested_target(cleandir):
class MyLightningCLI(LightningCLI):
def add_arguments_to_parser(self, parser):
parser.link_arguments(
"data.num_classes",
"model.init_args.optimizer.init_args.num_classes",
apply_on="instantiate",
)

cli_args = [
"fit",
"--data.batch_size=12",
"--trainer.max_epochs=1",
"--model=tests_pytorch.test_cli.DeepLinkTargetModel",
"--model.optimizer=tests_pytorch.test_cli.CustomAdam",
]

with mock.patch("sys.argv", ["any.py"] + cli_args):
cli = MyLightningCLI(
DeepLinkTargetModel,
BoringDataModuleBatchSizeAndClasses,
subclass_mode_model=True,
auto_configure_optimizers=False,
)

hparams_path = Path(cli.trainer.log_dir) / "hparams.yaml"
assert hparams_path.is_file()
hparams = yaml.safe_load(hparams_path.read_text())

assert hparams["optimizer"]["class_path"] == "tests_pytorch.test_cli.CustomAdam"
assert hparams["optimizer"]["init_args"]["num_classes"] == 5


class EarlyExitTestModel(BoringModel):
def on_fit_start(self):