Skip to content

Commit ea066a9

Browse files
committed
fix tests
1 parent 5ac9695 commit ea066a9

File tree

3 files changed

+20
-14
lines changed

3 files changed

+20
-14
lines changed

tests/tests_pytorch/accelerators/test_cpu.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,9 @@ def setup(self, trainer: "pl.Trainer") -> None:
5353
def restore_checkpoint_after_setup(self) -> bool:
5454
return restore_after_pre_setup
5555

56-
def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> dict[str, Any]:
56+
def load_checkpoint(self, checkpoint_path: Union[str, Path], weights_only: bool) -> dict[str, Any]:
5757
assert self.setup_called == restore_after_pre_setup
58-
return super().load_checkpoint(checkpoint_path)
58+
return super().load_checkpoint(checkpoint_path, weights_only)
5959

6060
model = BoringModel()
6161
trainer = Trainer(default_root_dir=tmp_path, fast_dev_run=True)

tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def test_checkpoint_plugin_called(tmp_path):
6969
assert checkpoint_plugin.remove_checkpoint.call_count == 1
7070

7171
trainer.test(model, ckpt_path=ck.last_model_path)
72-
checkpoint_plugin.load_checkpoint.assert_called_with(str(tmp_path / "last.ckpt"))
72+
checkpoint_plugin.load_checkpoint.assert_called_with(str(tmp_path / "last.ckpt"), weights_only=False)
7373

7474
checkpoint_plugin.reset_mock()
7575
ck = ModelCheckpoint(dirpath=tmp_path, save_last=True)
@@ -97,7 +97,7 @@ def test_checkpoint_plugin_called(tmp_path):
9797

9898
trainer.test(model, ckpt_path=ck.last_model_path)
9999
checkpoint_plugin.load_checkpoint.assert_called_once()
100-
checkpoint_plugin.load_checkpoint.assert_called_with(str(tmp_path / "last-v1.ckpt"))
100+
checkpoint_plugin.load_checkpoint.assert_called_with(str(tmp_path / "last-v1.ckpt"), weights_only=False)
101101

102102

103103
@pytest.mark.flaky(reruns=3)

tests/tests_pytorch/test_cli.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1222,7 +1222,7 @@ def test_lightning_cli_model_short_arguments():
12221222
):
12231223
cli = LightningCLI(trainer_defaults={"fast_dev_run": 1})
12241224
assert isinstance(cli.model, BoringModel)
1225-
run.assert_called_once_with(cli.model, ANY, ANY, ANY, ANY)
1225+
run.assert_called_once_with(cli.model, ANY, ANY, ANY, ANY, ANY)
12261226

12271227
with (
12281228
mock.patch("sys.argv", ["any.py", "--model=TestModel", "--model.foo", "123"]),
@@ -1250,7 +1250,7 @@ def test_lightning_cli_datamodule_short_arguments():
12501250
):
12511251
cli = LightningCLI(BoringModel, trainer_defaults={"fast_dev_run": 1})
12521252
assert isinstance(cli.datamodule, BoringDataModule)
1253-
run.assert_called_once_with(ANY, ANY, ANY, cli.datamodule, ANY)
1253+
run.assert_called_once_with(ANY, ANY, ANY, cli.datamodule, ANY, ANY)
12541254

12551255
with (
12561256
mock.patch("sys.argv", ["any.py", "--data=MyDataModule", "--data.foo", "123"]),
@@ -1271,7 +1271,7 @@ def test_lightning_cli_datamodule_short_arguments():
12711271
cli = LightningCLI(trainer_defaults={"fast_dev_run": 1})
12721272
assert isinstance(cli.model, BoringModel)
12731273
assert isinstance(cli.datamodule, BoringDataModule)
1274-
run.assert_called_once_with(cli.model, ANY, ANY, cli.datamodule, ANY)
1274+
run.assert_called_once_with(cli.model, ANY, ANY, cli.datamodule, ANY, ANY)
12751275

12761276
with (
12771277
mock.patch("sys.argv", ["any.py", "--model", "BoringModel", "--data=MyDataModule"]),
@@ -1447,7 +1447,7 @@ def test_lightning_cli_config_with_subcommand():
14471447
):
14481448
cli = LightningCLI(BoringModel)
14491449

1450-
test_mock.assert_called_once_with(cli.trainer, cli.model, verbose=True, ckpt_path="foobar")
1450+
test_mock.assert_called_once_with(cli.trainer, cli.model, verbose=True, ckpt_path="foobar", weights_only=False)
14511451
assert cli.trainer.limit_test_batches == 1
14521452

14531453

@@ -1463,7 +1463,9 @@ def test_lightning_cli_config_before_subcommand():
14631463
):
14641464
cli = LightningCLI(BoringModel)
14651465

1466-
test_mock.assert_called_once_with(cli.trainer, model=cli.model, verbose=True, ckpt_path="foobar")
1466+
test_mock.assert_called_once_with(
1467+
cli.trainer, model=cli.model, verbose=True, ckpt_path="foobar", weights_only=False
1468+
)
14671469
assert cli.trainer.limit_test_batches == 1
14681470

14691471
save_config_callback = cli.trainer.callbacks[0]
@@ -1476,7 +1478,7 @@ def test_lightning_cli_config_before_subcommand():
14761478
):
14771479
cli = LightningCLI(BoringModel)
14781480

1479-
validate_mock.assert_called_once_with(cli.trainer, cli.model, verbose=False, ckpt_path="barfoo")
1481+
validate_mock.assert_called_once_with(cli.trainer, cli.model, verbose=False, ckpt_path="barfoo", weights_only=False)
14801482
assert cli.trainer.limit_val_batches == 1
14811483

14821484
save_config_callback = cli.trainer.callbacks[0]
@@ -1494,7 +1496,9 @@ def test_lightning_cli_config_before_subcommand_two_configs():
14941496
):
14951497
cli = LightningCLI(BoringModel)
14961498

1497-
test_mock.assert_called_once_with(cli.trainer, model=cli.model, verbose=True, ckpt_path="foobar")
1499+
test_mock.assert_called_once_with(
1500+
cli.trainer, model=cli.model, verbose=True, ckpt_path="foobar", weights_only=False
1501+
)
14981502
assert cli.trainer.limit_test_batches == 1
14991503

15001504
with (
@@ -1503,7 +1507,7 @@ def test_lightning_cli_config_before_subcommand_two_configs():
15031507
):
15041508
cli = LightningCLI(BoringModel)
15051509

1506-
validate_mock.assert_called_once_with(cli.trainer, cli.model, verbose=False, ckpt_path="barfoo")
1510+
validate_mock.assert_called_once_with(cli.trainer, cli.model, verbose=False, ckpt_path="barfoo", weights_only=False)
15071511
assert cli.trainer.limit_val_batches == 1
15081512

15091513

@@ -1515,7 +1519,7 @@ def test_lightning_cli_config_after_subcommand():
15151519
):
15161520
cli = LightningCLI(BoringModel)
15171521

1518-
test_mock.assert_called_once_with(cli.trainer, cli.model, verbose=True, ckpt_path="foobar")
1522+
test_mock.assert_called_once_with(cli.trainer, cli.model, verbose=True, ckpt_path="foobar", weights_only=False)
15191523
assert cli.trainer.limit_test_batches == 1
15201524

15211525

@@ -1528,7 +1532,9 @@ def test_lightning_cli_config_before_and_after_subcommand():
15281532
):
15291533
cli = LightningCLI(BoringModel)
15301534

1531-
test_mock.assert_called_once_with(cli.trainer, model=cli.model, verbose=False, ckpt_path="foobar")
1535+
test_mock.assert_called_once_with(
1536+
cli.trainer, model=cli.model, verbose=False, ckpt_path="foobar", weights_only=False
1537+
)
15321538
assert cli.trainer.limit_test_batches == 1
15331539
assert cli.trainer.fast_dev_run == 1
15341540

0 commit comments

Comments
 (0)