11import os
22from contextlib import redirect_stdout
33from io import StringIO
4+ from typing import Optional
45
56import pytest
67from lightning_sdk import Teamspace
1516
1617
1718def _prepare_variables (test_name : str ) -> tuple [Teamspace , str , str ]:
18- model_name = f"litmodels_test_integrations_ { test_name } +{ os .urandom (8 ).hex ()} "
19+ model_name = f"ci-test_integrations_ { test_name } +{ os .urandom (8 ).hex ()} "
1920 teamspace = _resolve_teamspace (org = LIT_ORG , teamspace = LIT_TEAMSPACE , user = None )
2021 org_team = f"{ teamspace .owner .name } /{ teamspace .name } "
2122 return teamspace , org_team , model_name
2223
2324
24- def _cleanup_model (teamspace : Teamspace , model_name : str ) -> None :
25+ def _cleanup_model (teamspace : Teamspace , model_name : str , expected_num_versions : Optional [ int ] = None ) -> None :
2526 """Cleanup model from the teamspace."""
2627 client = GridRestClient ()
2728 # cleaning created models as each test run shall have unique model name
@@ -30,7 +31,10 @@ def _cleanup_model(teamspace: Teamspace, model_name: str) -> None:
3031 project_name = teamspace .name ,
3132 model_name = model_name ,
3233 )
33- client .models_store_delete_model (project_id = teamspace .id , model_id = model .id )
34+ if expected_num_versions is not None :
35+ versions = client .models_store_list_model_versions (project_id = model .project_id , model_id = model .id )
36+ assert expected_num_versions == len (versions .versions )
37+ client .models_store_delete_model (project_id = model .project_id , model_id = model .id )
3438
3539
3640@pytest .mark .cloud ()
@@ -62,7 +66,7 @@ def test_upload_download_model(tmp_path):
6266 assert os .path .isfile (os .path .join (tmp_path , file ))
6367
6468 # CLEANING
65- _cleanup_model (teamspace , model_name )
69+ _cleanup_model (teamspace , model_name , expected_num_versions = 1 )
6670
6771
6872@pytest .mark .parametrize (
@@ -93,7 +97,7 @@ def test_lightning_default_checkpointing(importing, tmp_path):
9397 trainer .fit (BoringModel ())
9498
9599 # CLEANING
96- _cleanup_model (teamspace , model_name )
100+ _cleanup_model (teamspace , model_name , expected_num_versions = 2 )
97101
98102
99103@pytest .mark .parametrize ("trainer_method" , ["fit" , "validate" , "test" , "predict" ])
@@ -109,7 +113,7 @@ def test_lightning_default_checkpointing(importing, tmp_path):
109113)
110114@pytest .mark .cloud ()
111115# todo: mock env variables as it would run in studio
112- def test_lightning_resume (trainer_method , registry , importing , tmp_path ):
116+ def test_lightning_plain_resume (trainer_method , registry , importing , tmp_path ):
113117 if importing == "lightning" :
114118 from lightning import Trainer
115119 from lightning .pytorch .demos .boring_classes import BoringModel
@@ -124,6 +128,7 @@ def test_lightning_resume(trainer_method, registry, importing, tmp_path):
124128 # model name with random hash
125129 teamspace , org_team , model_name = _prepare_variables (f"resume_{ trainer_method } " )
126130 upload_model (model = checkpoint_path , name = f"{ org_team } /{ model_name } " )
131+ expected_num_versions = 1
127132
128133 trainer_kwargs = {"model_registry" : f"{ org_team } /{ model_name } " } if "<model>" not in registry else {}
129134 trainer = Trainer (
@@ -138,6 +143,8 @@ def test_lightning_resume(trainer_method, registry, importing, tmp_path):
138143 registry = registry .replace ("<model>" , f"{ org_team } /{ model_name } " )
139144 if trainer_method == "fit" :
140145 trainer .fit (BoringModel (), ckpt_path = registry )
146+ if trainer_kwargs :
147+ expected_num_versions += 1
141148 elif trainer_method == "validate" :
142149 trainer .validate (BoringModel (), ckpt_path = registry )
143150 elif trainer_method == "test" :
@@ -148,4 +155,41 @@ def test_lightning_resume(trainer_method, registry, importing, tmp_path):
148155 raise ValueError (f"Unknown trainer method: { trainer_method } " )
149156
150157 # CLEANING
151- _cleanup_model (teamspace , model_name )
158+ _cleanup_model (teamspace , model_name , expected_num_versions = expected_num_versions )
159+
160+
161+ @pytest .mark .parametrize (
162+ "importing" ,
163+ [
164+ pytest .param ("lightning" , marks = _SKIP_IF_LIGHTNING_BELLOW_2_5_1 ),
165+ pytest .param ("pytorch_lightning" , marks = _SKIP_IF_PYTORCHLIGHTNING_BELLOW_2_5_1 ),
166+ ],
167+ )
168+ @pytest .mark .cloud ()
169+ def test_lightning_checkpoint_ddp (importing , tmp_path ):
170+ if importing == "lightning" :
171+ from lightning import Trainer
172+ from lightning .pytorch .demos .boring_classes import BoringModel
173+ elif importing == "pytorch_lightning" :
174+ from pytorch_lightning import Trainer
175+ from pytorch_lightning .demos .boring_classes import BoringModel
176+
177+ # model name with random hash
178+ teamspace , org_team , model_name = _prepare_variables ("checkpoint_resume" )
179+ trainer_args = {
180+ "default_root_dir" : tmp_path ,
181+ "accelerator" : "cpu" ,
182+ "strategy" : "ddp_spawn" ,
183+ "devices" : 4 ,
184+ "model_registry" : f"{ org_team } /{ model_name } " ,
185+ }
186+
187+ trainer = Trainer (max_epochs = 2 , ** trainer_args )
188+ trainer .fit (BoringModel ())
189+
190+ # FIXME: seems like barrier is not respected in the test, but in real life it correctly waits for all GPUs
191+ # trainer = Trainer(max_epochs=5, **trainer_args)
192+ # trainer.fit(BoringModel(), ckpt_path="registry")
193+
194+ # CLEANING
195+ _cleanup_model (teamspace , model_name , expected_num_versions = 2 )
0 commit comments