Skip to content

Commit 1ac9424

Browse files
test: validate scenario being in studio on cloud (#91)
* test: validate scenario being in studio on cloud --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 886b425 commit 1ac9424

File tree

1 file changed

+17
-5
lines changed

1 file changed

+17
-5
lines changed

tests/integrations/test_cloud.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -141,26 +141,38 @@ def test_lightning_default_checkpointing(importing, in_studio, monkeypatch, tmp_
141141
pytest.param("pytorch_lightning", marks=_SKIP_IF_PYTORCHLIGHTNING_BELLOW_2_5_1),
142142
],
143143
)
144+
@pytest.mark.parametrize(
145+
"in_studio",
146+
[
147+
False,
148+
pytest.param(True, marks=pytest.mark.skipif(platform.system() == "Windows", reason="studio is not Windows")),
149+
],
150+
)
144151
@pytest.mark.cloud
145-
# todo: mock env variables as it would run in studio
146-
def test_lightning_plain_resume(trainer_method, registry, importing, tmp_path):
152+
def test_lightning_plain_resume(trainer_method, registry, importing, in_studio, tmp_path, monkeypatch):
147153
if importing == "lightning":
148154
from lightning import Trainer
149155
from lightning.pytorch.demos.boring_classes import BoringModel
150156
elif importing == "pytorch_lightning":
151157
from pytorch_lightning import Trainer
152158
from pytorch_lightning.demos.boring_classes import BoringModel
153159

160+
if in_studio:
161+
# mock env variables as it would run in studio
162+
monkeypatch.setenv("LIGHTNING_ORG", LIT_ORG)
163+
monkeypatch.setenv("LIGHTNING_TEAMSPACE", LIT_TEAMSPACE)
164+
154165
trainer = Trainer(max_epochs=1, limit_train_batches=50, limit_val_batches=20, default_root_dir=tmp_path)
155166
trainer.fit(BoringModel())
156167
checkpoint_path = getattr(trainer.checkpoint_callback, "best_model_path")
157168

158169
# model name with random hash
159170
teamspace, org_team, model_name = _prepare_variables(f"resume_{trainer_method}")
160-
upload_model(model=checkpoint_path, name=f"{org_team}/{model_name}")
171+
model_registry = f"{org_team}/{model_name}" if not in_studio else model_name
172+
upload_model(model=checkpoint_path, name=model_registry)
161173
expected_num_versions = 1
162174

163-
trainer_kwargs = {"model_registry": f"{org_team}/{model_name}"} if "<model>" not in registry else {}
175+
trainer_kwargs = {"model_registry": model_registry} if "<model>" not in registry else {}
164176
trainer = Trainer(
165177
max_epochs=2,
166178
default_root_dir=tmp_path,
@@ -170,7 +182,7 @@ def test_lightning_plain_resume(trainer_method, registry, importing, tmp_path):
170182
limit_predict_batches=10,
171183
**trainer_kwargs,
172184
)
173-
registry = registry.replace("<model>", f"{org_team}/{model_name}")
185+
registry = registry.replace("<model>", model_registry)
174186
if trainer_method == "fit":
175187
trainer.fit(BoringModel(), ckpt_path=registry)
176188
if trainer_kwargs:

0 commit comments

Comments
 (0)