@@ -141,26 +141,38 @@ def test_lightning_default_checkpointing(importing, in_studio, monkeypatch, tmp_
141141 pytest .param ("pytorch_lightning" , marks = _SKIP_IF_PYTORCHLIGHTNING_BELLOW_2_5_1 ),
142142 ],
143143)
144+ @pytest .mark .parametrize (
145+ "in_studio" ,
146+ [
147+ False ,
148+ pytest .param (True , marks = pytest .mark .skipif (platform .system () == "Windows" , reason = "studio is not Windows" )),
149+ ],
150+ )
144151@pytest .mark .cloud
145- # todo: mock env variables as it would run in studio
146- def test_lightning_plain_resume (trainer_method , registry , importing , tmp_path ):
152+ def test_lightning_plain_resume (trainer_method , registry , importing , in_studio , tmp_path , monkeypatch ):
147153 if importing == "lightning" :
148154 from lightning import Trainer
149155 from lightning .pytorch .demos .boring_classes import BoringModel
150156 elif importing == "pytorch_lightning" :
151157 from pytorch_lightning import Trainer
152158 from pytorch_lightning .demos .boring_classes import BoringModel
153159
160+ if in_studio :
161+ # mock env variables as it would run in studio
162+ monkeypatch .setenv ("LIGHTNING_ORG" , LIT_ORG )
163+ monkeypatch .setenv ("LIGHTNING_TEAMSPACE" , LIT_TEAMSPACE )
164+
154165 trainer = Trainer (max_epochs = 1 , limit_train_batches = 50 , limit_val_batches = 20 , default_root_dir = tmp_path )
155166 trainer .fit (BoringModel ())
156167 checkpoint_path = getattr (trainer .checkpoint_callback , "best_model_path" )
157168
158169 # model name with random hash
159170 teamspace , org_team , model_name = _prepare_variables (f"resume_{ trainer_method } " )
160- upload_model (model = checkpoint_path , name = f"{ org_team } /{ model_name } " )
171+ model_registry = f"{ org_team } /{ model_name } " if not in_studio else model_name
172+ upload_model (model = checkpoint_path , name = model_registry )
161173 expected_num_versions = 1
162174
163- trainer_kwargs = {"model_registry" : f" { org_team } / { model_name } " } if "<model>" not in registry else {}
175+ trainer_kwargs = {"model_registry" : model_registry } if "<model>" not in registry else {}
164176 trainer = Trainer (
165177 max_epochs = 2 ,
166178 default_root_dir = tmp_path ,
@@ -170,7 +182,7 @@ def test_lightning_plain_resume(trainer_method, registry, importing, tmp_path):
170182 limit_predict_batches = 10 ,
171183 ** trainer_kwargs ,
172184 )
173- registry = registry .replace ("<model>" , f" { org_team } / { model_name } " )
185+ registry = registry .replace ("<model>" , model_registry )
174186 if trainer_method == "fit" :
175187 trainer .fit (BoringModel (), ckpt_path = registry )
176188 if trainer_kwargs :
0 commit comments