1414LIT_TEAMSPACE = "LitModels"
1515
1616
17+ def _prepare_variables (test_name : str ) -> tuple [Teamspace , str , str ]:
18+ model_name = f"litmodels_test_integrations_{ test_name } +{ os .urandom (8 ).hex ()} "
19+ teamspace = _resolve_teamspace (org = LIT_ORG , teamspace = LIT_TEAMSPACE , user = None )
20+ org_team = f"{ teamspace .owner .name } /{ teamspace .name } "
21+ return teamspace , org_team , model_name
22+
23+
1724def _cleanup_model (teamspace : Teamspace , model_name : str ) -> None :
1825 """Cleanup model from the teamspace."""
1926 client = GridRestClient ()
@@ -35,9 +42,7 @@ def test_upload_download_model(tmp_path):
3542 f .write ("dummy" )
3643
3744 # model name with random hash
38- model_name = f"litmodels_test_integrations+{ os .urandom (8 ).hex ()} "
39- teamspace = _resolve_teamspace (org = LIT_ORG , teamspace = LIT_TEAMSPACE , user = None )
40- org_team = f"{ teamspace .owner .name } /{ teamspace .name } "
45+ teamspace , org_team , model_name = _prepare_variables ("upload_download" )
4146
4247 out = StringIO ()
4348 with redirect_stdout (out ):
@@ -78,9 +83,7 @@ def test_lightning_default_checkpointing(importing, tmp_path):
7883 from pytorch_lightning .demos .boring_classes import BoringModel
7984
8085 # model name with random hash
81- model_name = f"litmodels_test_integrations+{ os .urandom (8 ).hex ()} "
82- teamspace = _resolve_teamspace (org = LIT_ORG , teamspace = LIT_TEAMSPACE , user = None )
83- org_team = f"{ teamspace .owner .name } /{ teamspace .name } "
86+ teamspace , org_team , model_name = _prepare_variables ("default_checkpoint" )
8487
8588 trainer = Trainer (
8689 max_epochs = 2 ,
@@ -91,3 +94,39 @@ def test_lightning_default_checkpointing(importing, tmp_path):
9194
9295 # CLEANING
9396 _cleanup_model (teamspace , model_name )
97+
98+
99+ @pytest .mark .parametrize (
100+ "importing" ,
101+ [
102+ pytest .param ("lightning" , marks = _SKIP_IF_LIGHTNING_BELLOW_2_5_1 ),
103+ pytest .param ("pytorch_lightning" , marks = _SKIP_IF_PYTORCHLIGHTNING_BELLOW_2_5_1 ),
104+ ],
105+ )
106+ @pytest .mark .parametrize (
107+ "registry" , ["registry" , "registry:version:v1" , "registry:<model>" , "registry:<model>:version:v1" ]
108+ )
109+ @pytest .mark .cloud ()
110+ # todo: mock env variables as it would run in studio
111+ def test_lightning_resume (importing , registry , tmp_path ):
112+ if importing == "lightning" :
113+ from lightning import Trainer
114+ from lightning .pytorch .demos .boring_classes import BoringModel
115+ elif importing == "pytorch_lightning" :
116+ from pytorch_lightning import Trainer
117+ from pytorch_lightning .demos .boring_classes import BoringModel
118+
119+ trainer = Trainer (max_epochs = 1 , default_root_dir = tmp_path )
120+ trainer .fit (BoringModel ())
121+ checkpoint_path = getattr (trainer .checkpoint_callback , "best_model_path" )
122+
123+ # model name with random hash
124+ teamspace , org_team , model_name = _prepare_variables ("resume" )
125+ upload_model (model = checkpoint_path , name = f"{ org_team } /{ model_name } " )
126+
127+ trainer = Trainer (max_epochs = 2 , default_root_dir = tmp_path )
128+ registry = registry .replace ("<model>" , f"{ org_team } /{ model_name } " )
129+ trainer .fit (BoringModel (), ckpt_path = registry )
130+
131+ # CLEANING
132+ _cleanup_model (teamspace , model_name )
0 commit comments