
Commit c56064a

Borda and ethanwharris authored
test: integration Trainer's ddp_spawn with Prod (#62)
* test: integration Trainer's `ddp_spawn` with Prod

* Apply suggestions from code review

---------

Co-authored-by: Ethan Harris <[email protected]>
1 parent 078df24 commit c56064a

File tree

2 files changed: +52 / -7 lines


_requirements/extra.txt

Lines changed: 1 addition & 0 deletions
@@ -1 +1,2 @@
 lightning >= 2.0.0
+numpy <2.0.0 ; platform_system == "Darwin"  # compatibility fix for Torch

tests/integrations/test_cloud.py

Lines changed: 51 additions & 7 deletions
@@ -1,6 +1,7 @@
 import os
 from contextlib import redirect_stdout
 from io import StringIO
+from typing import Optional
 
 import pytest
 from lightning_sdk import Teamspace
@@ -15,13 +16,13 @@
 
 
 def _prepare_variables(test_name: str) -> tuple[Teamspace, str, str]:
-    model_name = f"litmodels_test_integrations_{test_name}+{os.urandom(8).hex()}"
+    model_name = f"ci-test_integrations_{test_name}+{os.urandom(8).hex()}"
     teamspace = _resolve_teamspace(org=LIT_ORG, teamspace=LIT_TEAMSPACE, user=None)
     org_team = f"{teamspace.owner.name}/{teamspace.name}"
     return teamspace, org_team, model_name
 
 
-def _cleanup_model(teamspace: Teamspace, model_name: str) -> None:
+def _cleanup_model(teamspace: Teamspace, model_name: str, expected_num_versions: Optional[int] = None) -> None:
     """Cleanup model from the teamspace."""
     client = GridRestClient()
     # cleaning created models as each test run shall have unique model name
@@ -30,7 +31,10 @@ def _cleanup_model(teamspace: Teamspace, model_name: str) -> None:
         project_name=teamspace.name,
         model_name=model_name,
     )
-    client.models_store_delete_model(project_id=teamspace.id, model_id=model.id)
+    if expected_num_versions is not None:
+        versions = client.models_store_list_model_versions(project_id=model.project_id, model_id=model.id)
+        assert expected_num_versions == len(versions.versions)
+    client.models_store_delete_model(project_id=model.project_id, model_id=model.id)
 
 
 @pytest.mark.cloud()
@@ -62,7 +66,7 @@ def test_upload_download_model(tmp_path):
         assert os.path.isfile(os.path.join(tmp_path, file))
 
     # CLEANING
-    _cleanup_model(teamspace, model_name)
+    _cleanup_model(teamspace, model_name, expected_num_versions=1)
 
 
 @pytest.mark.parametrize(
@@ -93,7 +97,7 @@ def test_lightning_default_checkpointing(importing, tmp_path):
     trainer.fit(BoringModel())
 
     # CLEANING
-    _cleanup_model(teamspace, model_name)
+    _cleanup_model(teamspace, model_name, expected_num_versions=2)
 
 
 @pytest.mark.parametrize("trainer_method", ["fit", "validate", "test", "predict"])
@@ -109,7 +113,7 @@ def test_lightning_default_checkpointing(importing, tmp_path):
 )
 @pytest.mark.cloud()
 # todo: mock env variables as it would run in studio
-def test_lightning_resume(trainer_method, registry, importing, tmp_path):
+def test_lightning_plain_resume(trainer_method, registry, importing, tmp_path):
     if importing == "lightning":
         from lightning import Trainer
         from lightning.pytorch.demos.boring_classes import BoringModel
@@ -124,6 +128,7 @@ def test_lightning_resume(trainer_method, registry, importing, tmp_path):
     # model name with random hash
     teamspace, org_team, model_name = _prepare_variables(f"resume_{trainer_method}")
     upload_model(model=checkpoint_path, name=f"{org_team}/{model_name}")
+    expected_num_versions = 1
 
     trainer_kwargs = {"model_registry": f"{org_team}/{model_name}"} if "<model>" not in registry else {}
     trainer = Trainer(
@@ -138,6 +143,8 @@ def test_lightning_resume(trainer_method, registry, importing, tmp_path):
     registry = registry.replace("<model>", f"{org_team}/{model_name}")
     if trainer_method == "fit":
         trainer.fit(BoringModel(), ckpt_path=registry)
+        if trainer_kwargs:
+            expected_num_versions += 1
     elif trainer_method == "validate":
         trainer.validate(BoringModel(), ckpt_path=registry)
     elif trainer_method == "test":
@@ -148,4 +155,41 @@ def test_lightning_resume(trainer_method, registry, importing, tmp_path):
         raise ValueError(f"Unknown trainer method: {trainer_method}")
 
     # CLEANING
-    _cleanup_model(teamspace, model_name)
+    _cleanup_model(teamspace, model_name, expected_num_versions=expected_num_versions)
+
+
+@pytest.mark.parametrize(
+    "importing",
+    [
+        pytest.param("lightning", marks=_SKIP_IF_LIGHTNING_BELLOW_2_5_1),
+        pytest.param("pytorch_lightning", marks=_SKIP_IF_PYTORCHLIGHTNING_BELLOW_2_5_1),
+    ],
+)
+@pytest.mark.cloud()
+def test_lightning_checkpoint_ddp(importing, tmp_path):
+    if importing == "lightning":
+        from lightning import Trainer
+        from lightning.pytorch.demos.boring_classes import BoringModel
+    elif importing == "pytorch_lightning":
+        from pytorch_lightning import Trainer
+        from pytorch_lightning.demos.boring_classes import BoringModel
+
+    # model name with random hash
+    teamspace, org_team, model_name = _prepare_variables("checkpoint_resume")
+    trainer_args = {
+        "default_root_dir": tmp_path,
+        "accelerator": "cpu",
+        "strategy": "ddp_spawn",
+        "devices": 4,
+        "model_registry": f"{org_team}/{model_name}",
+    }
+
+    trainer = Trainer(max_epochs=2, **trainer_args)
+    trainer.fit(BoringModel())
+
+    # FIXME: seems like barrier is not respected in the test, but in real life it correctly waits for all GPUs
+    # trainer = Trainer(max_epochs=5, **trainer_args)
+    # trainer.fit(BoringModel(), ckpt_path="registry")
+
+    # CLEANING
+    _cleanup_model(teamspace, model_name, expected_num_versions=2)
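
For reference, a minimal sketch of how the updated `_cleanup_model` signature is meant to be driven by a test (not part of the commit; it assumes the `_prepare_variables` and `_cleanup_model` helpers above are importable from the test module, that `upload_model` is the litmodels helper already imported there, that the checkpoint path is a placeholder, and that Lightning cloud credentials are set in the environment):

    # hypothetical manual check, mirroring the pattern used by the tests above
    from litmodels import upload_model

    from tests.integrations.test_cloud import _cleanup_model, _prepare_variables

    # unique teamspace-scoped model name, as in the tests
    teamspace, org_team, model_name = _prepare_variables("manual_check")

    # register exactly one version (placeholder checkpoint path)
    upload_model(model="path/to/checkpoint.ckpt", name=f"{org_team}/{model_name}")

    # the helper now asserts the version count before deleting the model
    _cleanup_model(teamspace, model_name, expected_num_versions=1)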
