Skip to content

Commit 5140ac8

Browse files
committed
fixing
1 parent 7d03d63 commit 5140ac8

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

tests/integrations/test_cloud.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -96,16 +96,16 @@ def test_lightning_default_checkpointing(importing, tmp_path):
9696
_cleanup_model(teamspace, model_name)
9797

9898

99+
@pytest.mark.parametrize(
100+
"registry", ["registry", "registry:version:v1", "registry:<model>", "registry:<model>:version:v1"]
101+
)
99102
@pytest.mark.parametrize(
100103
"importing",
101104
[
102105
pytest.param("lightning", marks=_SKIP_IF_LIGHTNING_BELLOW_2_5_1),
103106
pytest.param("pytorch_lightning", marks=_SKIP_IF_PYTORCHLIGHTNING_BELLOW_2_5_1),
104107
],
105108
)
106-
@pytest.mark.parametrize(
107-
"registry", ["registry", "registry:version:v1", "registry:<model>", "registry:<model>:version:v1"]
108-
)
109109
@pytest.mark.cloud()
110110
# todo: mock env variables as it would run in studio
111111
def test_lightning_resume(importing, registry, tmp_path):
@@ -124,7 +124,8 @@ def test_lightning_resume(importing, registry, tmp_path):
124124
teamspace, org_team, model_name = _prepare_variables("resume")
125125
upload_model(model=checkpoint_path, name=f"{org_team}/{model_name}")
126126

127-
trainer = Trainer(max_epochs=2, default_root_dir=tmp_path)
127+
trainer_kwargs = {"model_registry": f"{org_team}/{model_name}"} if "<model>" not in registry else {}
128+
trainer = Trainer(max_epochs=2, default_root_dir=tmp_path, **trainer_kwargs)
128129
registry = registry.replace("<model>", f"{org_team}/{model_name}")
129130
trainer.fit(BoringModel(), ckpt_path=registry)
130131

0 commit comments

Comments
 (0)