Skip to content

Commit 392852d

Browse files
committed
add testing in cli
1 parent 5a0df4d commit 392852d

File tree

1 file changed

+28
-2
lines changed

1 file changed

+28
-2
lines changed

tests/tests_pytorch/test_cli.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,6 @@
2929
import yaml
3030
from lightning_utilities import compare_version
3131
from lightning_utilities.test.warning import no_warning_call
32-
from tensorboard.backend.event_processing import event_accumulator
33-
from tensorboard.plugins.hparams.plugin_data_pb2 import HParamsPluginData
3432
from torch.optim import SGD
3533
from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR
3634

@@ -53,6 +51,8 @@
5351
from lightning.pytorch.trainer.states import TrainerFn
5452
from lightning.pytorch.utilities.exceptions import MisconfigurationException
5553
from lightning.pytorch.utilities.imports import _TORCHVISION_AVAILABLE
54+
from tensorboard.backend.event_processing import event_accumulator
55+
from tensorboard.plugins.hparams.plugin_data_pb2 import HParamsPluginData
5656
from tests_pytorch.helpers.runif import RunIf
5757

5858
if _JSONARGPARSE_SIGNATURES_AVAILABLE:
@@ -1917,3 +1917,29 @@ def __init__(self, main_param: int = 1):
19171917
cli = LightningCLI(MainModule, run=False, parser_kwargs={"parser_mode": "jsonnet"})
19181918

19191919
assert cli.config["model"]["main_param"] == 2
1920+
1921+
1922+
def test_lightning_cli_callback_trainer_default(cleandir):
1923+
"""Check that callbacks passed as trainer_defaults are properly instantiated."""
1924+
with mock.patch("sys.argv", ["any.py"]):
1925+
cli = LightningCLI(
1926+
BoringModel,
1927+
BoringDataModule,
1928+
trainer_defaults={
1929+
"logger": {
1930+
"class_path": "lightning.pytorch.loggers.TensorBoardLogger",
1931+
"init_args": {
1932+
"save_dir": ".",
1933+
"name": "demo",
1934+
},
1935+
},
1936+
"callbacks": {
1937+
"class_path": "lightning.pytorch.callbacks.ModelCheckpoint",
1938+
"init_args": {
1939+
"monitor": "val_loss",
1940+
},
1941+
},
1942+
},
1943+
run=False,
1944+
)
1945+
assert any(isinstance(c, ModelCheckpoint) for c in cli.trainer.callbacks)

0 commit comments

Comments
 (0)