diff --git a/tests/integrations/test_cloud.py b/tests/integrations/test_cloud.py index 9214768..93c7577 100644 --- a/tests/integrations/test_cloud.py +++ b/tests/integrations/test_cloud.py @@ -141,9 +141,15 @@ def test_lightning_default_checkpointing(importing, in_studio, monkeypatch, tmp_ pytest.param("pytorch_lightning", marks=_SKIP_IF_PYTORCHLIGHTNING_BELLOW_2_5_1), ], ) +@pytest.mark.parametrize( + "in_studio", + [ + False, + pytest.param(True, marks=pytest.mark.skipif(platform.system() == "Windows", reason="studio is not Windows")), + ], +) @pytest.mark.cloud -# todo: mock env variables as it would run in studio -def test_lightning_plain_resume(trainer_method, registry, importing, tmp_path): +def test_lightning_plain_resume(trainer_method, registry, importing, in_studio, tmp_path, monkeypatch): if importing == "lightning": from lightning import Trainer from lightning.pytorch.demos.boring_classes import BoringModel @@ -151,16 +157,22 @@ def test_lightning_plain_resume(trainer_method, registry, importing, tmp_path): from pytorch_lightning import Trainer from pytorch_lightning.demos.boring_classes import BoringModel + if in_studio: + # mock env variables as it would run in studio + monkeypatch.setenv("LIGHTNING_ORG", LIT_ORG) + monkeypatch.setenv("LIGHTNING_TEAMSPACE", LIT_TEAMSPACE) + trainer = Trainer(max_epochs=1, limit_train_batches=50, limit_val_batches=20, default_root_dir=tmp_path) trainer.fit(BoringModel()) checkpoint_path = getattr(trainer.checkpoint_callback, "best_model_path") # model name with random hash teamspace, org_team, model_name = _prepare_variables(f"resume_{trainer_method}") - upload_model(model=checkpoint_path, name=f"{org_team}/{model_name}") + model_registry = f"{org_team}/{model_name}" if not in_studio else model_name + upload_model(model=checkpoint_path, name=model_registry) expected_num_versions = 1 - trainer_kwargs = {"model_registry": f"{org_team}/{model_name}"} if "" not in registry else {} + trainer_kwargs = {"model_registry": model_registry} if "" not in registry else {} trainer = Trainer( max_epochs=2, default_root_dir=tmp_path, @@ -170,7 +182,7 @@ def test_lightning_plain_resume(trainer_method, registry, importing, tmp_path): limit_predict_batches=10, **trainer_kwargs, ) - registry = registry.replace("", f"{org_team}/{model_name}") + registry = registry.replace("", model_registry) if trainer_method == "fit": trainer.fit(BoringModel(), ckpt_path=registry) if trainer_kwargs: