@@ -96,16 +96,16 @@ def test_lightning_default_checkpointing(importing, tmp_path):
9696 _cleanup_model (teamspace , model_name )
9797
9898
99+ @pytest .mark .parametrize (
100+ "registry" , ["registry" , "registry:version:v1" , "registry:<model>" , "registry:<model>:version:v1" ]
101+ )
99102@pytest .mark .parametrize (
100103 "importing" ,
101104 [
102105 pytest .param ("lightning" , marks = _SKIP_IF_LIGHTNING_BELLOW_2_5_1 ),
103106 pytest .param ("pytorch_lightning" , marks = _SKIP_IF_PYTORCHLIGHTNING_BELLOW_2_5_1 ),
104107 ],
105108)
106- @pytest .mark .parametrize (
107- "registry" , ["registry" , "registry:version:v1" , "registry:<model>" , "registry:<model>:version:v1" ]
108- )
109109@pytest .mark .cloud ()
110110# todo: mock env variables as it would run in studio
111111def test_lightning_resume (importing , registry , tmp_path ):
@@ -124,7 +124,8 @@ def test_lightning_resume(importing, registry, tmp_path):
124124 teamspace , org_team , model_name = _prepare_variables ("resume" )
125125 upload_model (model = checkpoint_path , name = f"{ org_team } /{ model_name } " )
126126
127- trainer = Trainer (max_epochs = 2 , default_root_dir = tmp_path )
127+ trainer_kwargs = {"model_registry" : f"{ org_team } /{ model_name } " } if "<model>" not in registry else {}
128+ trainer = Trainer (max_epochs = 2 , default_root_dir = tmp_path , ** trainer_kwargs )
128129 registry = registry .replace ("<model>" , f"{ org_team } /{ model_name } " )
129130 trainer .fit (BoringModel (), ckpt_path = registry )
130131
0 commit comments