diff --git a/tests/integrations/test_cloud.py b/tests/integrations/test_cloud.py index d0600e6..bc60490 100644 --- a/tests/integrations/test_cloud.py +++ b/tests/integrations/test_cloud.py @@ -14,6 +14,13 @@ LIT_TEAMSPACE = "LitModels" +def _prepare_variables(test_name: str) -> tuple[Teamspace, str, str]: + model_name = f"litmodels_test_integrations_{test_name}+{os.urandom(8).hex()}" + teamspace = _resolve_teamspace(org=LIT_ORG, teamspace=LIT_TEAMSPACE, user=None) + org_team = f"{teamspace.owner.name}/{teamspace.name}" + return teamspace, org_team, model_name + + def _cleanup_model(teamspace: Teamspace, model_name: str) -> None: """Cleanup model from the teamspace.""" client = GridRestClient() @@ -35,9 +42,7 @@ def test_upload_download_model(tmp_path): f.write("dummy") # model name with random hash - model_name = f"litmodels_test_integrations+{os.urandom(8).hex()}" - teamspace = _resolve_teamspace(org=LIT_ORG, teamspace=LIT_TEAMSPACE, user=None) - org_team = f"{teamspace.owner.name}/{teamspace.name}" + teamspace, org_team, model_name = _prepare_variables("upload_download") out = StringIO() with redirect_stdout(out): @@ -78,9 +83,7 @@ def test_lightning_default_checkpointing(importing, tmp_path): from pytorch_lightning.demos.boring_classes import BoringModel # model name with random hash - model_name = f"litmodels_test_integrations+{os.urandom(8).hex()}" - teamspace = _resolve_teamspace(org=LIT_ORG, teamspace=LIT_TEAMSPACE, user=None) - org_team = f"{teamspace.owner.name}/{teamspace.name}" + teamspace, org_team, model_name = _prepare_variables("default_checkpoint") trainer = Trainer( max_epochs=2, @@ -91,3 +94,40 @@ def test_lightning_default_checkpointing(importing, tmp_path): # CLEANING _cleanup_model(teamspace, model_name) + + +@pytest.mark.parametrize( + "registry", ["registry", "registry:version:v1", "registry:", "registry::version:v1"] +) +@pytest.mark.parametrize( + "importing", + [ + pytest.param("lightning", marks=_SKIP_IF_LIGHTNING_BELLOW_2_5_1), + pytest.param("pytorch_lightning", marks=_SKIP_IF_PYTORCHLIGHTNING_BELLOW_2_5_1), + ], +) +@pytest.mark.cloud() +# todo: mock env variables as it would run in studio +def test_lightning_resume(importing, registry, tmp_path): + if importing == "lightning": + from lightning import Trainer + from lightning.pytorch.demos.boring_classes import BoringModel + elif importing == "pytorch_lightning": + from pytorch_lightning import Trainer + from pytorch_lightning.demos.boring_classes import BoringModel + + trainer = Trainer(max_epochs=1, default_root_dir=tmp_path) + trainer.fit(BoringModel()) + checkpoint_path = getattr(trainer.checkpoint_callback, "best_model_path") + + # model name with random hash + teamspace, org_team, model_name = _prepare_variables("resume") + upload_model(model=checkpoint_path, name=f"{org_team}/{model_name}") + + trainer_kwargs = {"model_registry": f"{org_team}/{model_name}"} if "" not in registry else {} + trainer = Trainer(max_epochs=2, default_root_dir=tmp_path, **trainer_kwargs) + registry = registry.replace("", f"{org_team}/{model_name}") + trainer.fit(BoringModel(), ckpt_path=registry) + + # CLEANING + _cleanup_model(teamspace, model_name)