Skip to content

Commit 0f2f6ba

Browse files
committed
test: integration Trainer's resume {fit|val|test|pred} with Prod
1 parent 2514b8b commit 0f2f6ba

File tree

1 file changed

+22
-4
lines changed

1 file changed

+22
-4
lines changed

tests/integrations/test_cloud.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ def test_lightning_default_checkpointing(importing, tmp_path):
9696
_cleanup_model(teamspace, model_name)
9797

9898

99+
@pytest.mark.parametrize("trainer_method", ["fit", "validate", "test", "predict"])
99100
@pytest.mark.parametrize(
100101
"registry", ["registry", "registry:version:v1", "registry:<model>", "registry:<model>:version:v1"]
101102
)
@@ -108,15 +109,15 @@ def test_lightning_default_checkpointing(importing, tmp_path):
108109
)
109110
@pytest.mark.cloud()
110111
# todo: mock env variables as it would run in studio
111-
def test_lightning_resume(importing, registry, tmp_path):
112+
def test_lightning_resume(trainer_method, registry, importing, tmp_path):
112113
if importing == "lightning":
113114
from lightning import Trainer
114115
from lightning.pytorch.demos.boring_classes import BoringModel
115116
elif importing == "pytorch_lightning":
116117
from pytorch_lightning import Trainer
117118
from pytorch_lightning.demos.boring_classes import BoringModel
118119

119-
trainer = Trainer(max_epochs=1, default_root_dir=tmp_path)
120+
trainer = Trainer(max_epochs=1, limit_train_batches=50, limit_val_batches=20, default_root_dir=tmp_path)
120121
trainer.fit(BoringModel())
121122
checkpoint_path = getattr(trainer.checkpoint_callback, "best_model_path")
122123

@@ -125,9 +126,26 @@ def test_lightning_resume(importing, registry, tmp_path):
125126
upload_model(model=checkpoint_path, name=f"{org_team}/{model_name}")
126127

127128
trainer_kwargs = {"model_registry": f"{org_team}/{model_name}"} if "<model>" not in registry else {}
128-
trainer = Trainer(max_epochs=2, default_root_dir=tmp_path, **trainer_kwargs)
129+
trainer = Trainer(
130+
max_epochs=2,
131+
default_root_dir=tmp_path,
132+
limit_train_batches=10,
133+
limit_val_batches=10,
134+
limit_test_batches=10,
135+
limit_predict_batches=10,
136+
**trainer_kwargs,
137+
)
129138
registry = registry.replace("<model>", f"{org_team}/{model_name}")
130-
trainer.fit(BoringModel(), ckpt_path=registry)
139+
if trainer_method == "fit":
140+
trainer.fit(BoringModel(), ckpt_path=registry)
141+
elif trainer_method == "validate":
142+
trainer.validate(BoringModel(), ckpt_path=registry)
143+
elif trainer_method == "test":
144+
trainer.test(BoringModel(), ckpt_path=registry)
145+
elif trainer_method == "predict":
146+
trainer.predict(BoringModel(), ckpt_path=registry)
147+
else:
148+
raise ValueError(f"Unknown trainer method: {trainer_method}")
131149

132150
# CLEANING
133151
_cleanup_model(teamspace, model_name)

0 commit comments

Comments
 (0)