Skip to content

Commit 7d03d63

Browse files
committed
test: integration Trainer's resume fit ckpt with Prod
1 parent 7aa2912 commit 7d03d63

File tree

1 file changed

+45
-6
lines changed

1 file changed

+45
-6
lines changed

tests/integrations/test_cloud.py

Lines changed: 45 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,13 @@
1414
LIT_TEAMSPACE = "LitModels"
1515

1616

17+
def _prepare_variables(test_name: str) -> tuple[Teamspace, str, str]:
18+
model_name = f"litmodels_test_integrations_{test_name}+{os.urandom(8).hex()}"
19+
teamspace = _resolve_teamspace(org=LIT_ORG, teamspace=LIT_TEAMSPACE, user=None)
20+
org_team = f"{teamspace.owner.name}/{teamspace.name}"
21+
return teamspace, org_team, model_name
22+
23+
1724
def _cleanup_model(teamspace: Teamspace, model_name: str) -> None:
1825
"""Cleanup model from the teamspace."""
1926
client = GridRestClient()
@@ -35,9 +42,7 @@ def test_upload_download_model(tmp_path):
3542
f.write("dummy")
3643

3744
# model name with random hash
38-
model_name = f"litmodels_test_integrations+{os.urandom(8).hex()}"
39-
teamspace = _resolve_teamspace(org=LIT_ORG, teamspace=LIT_TEAMSPACE, user=None)
40-
org_team = f"{teamspace.owner.name}/{teamspace.name}"
45+
teamspace, org_team, model_name = _prepare_variables("upload_download")
4146

4247
out = StringIO()
4348
with redirect_stdout(out):
@@ -78,9 +83,7 @@ def test_lightning_default_checkpointing(importing, tmp_path):
7883
from pytorch_lightning.demos.boring_classes import BoringModel
7984

8085
# model name with random hash
81-
model_name = f"litmodels_test_integrations+{os.urandom(8).hex()}"
82-
teamspace = _resolve_teamspace(org=LIT_ORG, teamspace=LIT_TEAMSPACE, user=None)
83-
org_team = f"{teamspace.owner.name}/{teamspace.name}"
86+
teamspace, org_team, model_name = _prepare_variables("default_checkpoint")
8487

8588
trainer = Trainer(
8689
max_epochs=2,
@@ -91,3 +94,39 @@ def test_lightning_default_checkpointing(importing, tmp_path):
9194

9295
# CLEANING
9396
_cleanup_model(teamspace, model_name)
97+
98+
99+
@pytest.mark.parametrize(
100+
"importing",
101+
[
102+
pytest.param("lightning", marks=_SKIP_IF_LIGHTNING_BELLOW_2_5_1),
103+
pytest.param("pytorch_lightning", marks=_SKIP_IF_PYTORCHLIGHTNING_BELLOW_2_5_1),
104+
],
105+
)
106+
@pytest.mark.parametrize(
107+
"registry", ["registry", "registry:version:v1", "registry:<model>", "registry:<model>:version:v1"]
108+
)
109+
@pytest.mark.cloud()
110+
# todo: mock env variables as it would run in studio
111+
def test_lightning_resume(importing, registry, tmp_path):
112+
if importing == "lightning":
113+
from lightning import Trainer
114+
from lightning.pytorch.demos.boring_classes import BoringModel
115+
elif importing == "pytorch_lightning":
116+
from pytorch_lightning import Trainer
117+
from pytorch_lightning.demos.boring_classes import BoringModel
118+
119+
trainer = Trainer(max_epochs=1, default_root_dir=tmp_path)
120+
trainer.fit(BoringModel())
121+
checkpoint_path = getattr(trainer.checkpoint_callback, "best_model_path")
122+
123+
# model name with random hash
124+
teamspace, org_team, model_name = _prepare_variables("resume")
125+
upload_model(model=checkpoint_path, name=f"{org_team}/{model_name}")
126+
127+
trainer = Trainer(max_epochs=2, default_root_dir=tmp_path)
128+
registry = registry.replace("<model>", f"{org_team}/{model_name}")
129+
trainer.fit(BoringModel(), ckpt_path=registry)
130+
131+
# CLEANING
132+
_cleanup_model(teamspace, model_name)

0 commit comments

Comments
 (0)