@@ -96,6 +96,7 @@ def test_lightning_default_checkpointing(importing, tmp_path):
9696 _cleanup_model (teamspace , model_name )
9797
9898
99+ @pytest .mark .parametrize ("trainer_method" , ["fit" , "validate" , "test" , "predict" ])
99100@pytest .mark .parametrize (
100101 "registry" , ["registry" , "registry:version:v1" , "registry:<model>" , "registry:<model>:version:v1" ]
101102)
@@ -108,15 +109,15 @@ def test_lightning_default_checkpointing(importing, tmp_path):
108109)
109110@pytest .mark .cloud ()
110111# todo: mock env variables as it would run in studio
111- def test_lightning_resume (importing , registry , tmp_path ):
112+ def test_lightning_resume (trainer_method , registry , importing , tmp_path ):
112113 if importing == "lightning" :
113114 from lightning import Trainer
114115 from lightning .pytorch .demos .boring_classes import BoringModel
115116 elif importing == "pytorch_lightning" :
116117 from pytorch_lightning import Trainer
117118 from pytorch_lightning .demos .boring_classes import BoringModel
118119
119- trainer = Trainer (max_epochs = 1 , default_root_dir = tmp_path )
120+ trainer = Trainer (max_epochs = 1 , limit_train_batches = 50 , limit_val_batches = 20 , default_root_dir = tmp_path )
120121 trainer .fit (BoringModel ())
121122 checkpoint_path = getattr (trainer .checkpoint_callback , "best_model_path" )
122123
@@ -125,9 +126,26 @@ def test_lightning_resume(importing, registry, tmp_path):
125126 upload_model (model = checkpoint_path , name = f"{ org_team } /{ model_name } " )
126127
127128 trainer_kwargs = {"model_registry" : f"{ org_team } /{ model_name } " } if "<model>" not in registry else {}
128- trainer = Trainer (max_epochs = 2 , default_root_dir = tmp_path , ** trainer_kwargs )
129+ trainer = Trainer (
130+ max_epochs = 2 ,
131+ default_root_dir = tmp_path ,
132+ limit_train_batches = 10 ,
133+ limit_val_batches = 10 ,
134+ limit_test_batches = 10 ,
135+ limit_predict_batches = 10 ,
136+ ** trainer_kwargs ,
137+ )
129138 registry = registry .replace ("<model>" , f"{ org_team } /{ model_name } " )
130- trainer .fit (BoringModel (), ckpt_path = registry )
139+ if trainer_method == "fit" :
140+ trainer .fit (BoringModel (), ckpt_path = registry )
141+ elif trainer_method == "validate" :
142+ trainer .validate (BoringModel (), ckpt_path = registry )
143+ elif trainer_method == "test" :
144+ trainer .test (BoringModel (), ckpt_path = registry )
145+ elif trainer_method == "predict" :
146+ trainer .predict (BoringModel (), ckpt_path = registry )
147+ else :
148+ raise ValueError (f"Unknown trainer method: { trainer_method } " )
131149
132150 # CLEANING
133151 _cleanup_model (teamspace , model_name )
0 commit comments