|
29 | 29 | import yaml
|
30 | 30 | from lightning_utilities import compare_version
|
31 | 31 | from lightning_utilities.test.warning import no_warning_call
|
32 |
| -from tensorboard.backend.event_processing import event_accumulator |
33 |
| -from tensorboard.plugins.hparams.plugin_data_pb2 import HParamsPluginData |
34 | 32 | from torch.optim import SGD
|
35 | 33 | from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR
|
36 | 34 |
|
|
53 | 51 | from lightning.pytorch.trainer.states import TrainerFn
|
54 | 52 | from lightning.pytorch.utilities.exceptions import MisconfigurationException
|
55 | 53 | from lightning.pytorch.utilities.imports import _TORCHVISION_AVAILABLE
|
| 54 | +from tensorboard.backend.event_processing import event_accumulator |
| 55 | +from tensorboard.plugins.hparams.plugin_data_pb2 import HParamsPluginData |
56 | 56 | from tests_pytorch.helpers.runif import RunIf
|
57 | 57 |
|
58 | 58 | if _JSONARGPARSE_SIGNATURES_AVAILABLE:
|
@@ -1917,3 +1917,29 @@ def __init__(self, main_param: int = 1):
|
1917 | 1917 | cli = LightningCLI(MainModule, run=False, parser_kwargs={"parser_mode": "jsonnet"})
|
1918 | 1918 |
|
1919 | 1919 | assert cli.config["model"]["main_param"] == 2
|
| 1920 | + |
| 1921 | + |
| 1922 | +def test_lightning_cli_callback_trainer_default(cleandir): |
| 1923 | + """Check that callbacks passed as trainer_defaults are properly instantiated.""" |
| 1924 | + with mock.patch("sys.argv", ["any.py"]): |
| 1925 | + cli = LightningCLI( |
| 1926 | + BoringModel, |
| 1927 | + BoringDataModule, |
| 1928 | + trainer_defaults={ |
| 1929 | + "logger": { |
| 1930 | + "class_path": "lightning.pytorch.loggers.TensorBoardLogger", |
| 1931 | + "init_args": { |
| 1932 | + "save_dir": ".", |
| 1933 | + "name": "demo", |
| 1934 | + }, |
| 1935 | + }, |
| 1936 | + "callbacks": { |
| 1937 | + "class_path": "lightning.pytorch.callbacks.ModelCheckpoint", |
| 1938 | + "init_args": { |
| 1939 | + "monitor": "val_loss", |
| 1940 | + }, |
| 1941 | + }, |
| 1942 | + }, |
| 1943 | + run=False, |
| 1944 | + ) |
| 1945 | + assert any(isinstance(c, ModelCheckpoint) for c in cli.trainer.callbacks) |
0 commit comments