@@ -61,7 +61,12 @@ def on_before_zero_grad(self, optimizer):
 
     model = CurrentTestModel()
 
-    trainer = Trainer(default_root_dir=tmp_path, max_steps=max_steps, max_epochs=2)
+    trainer = Trainer(
+        devices=1,
+        default_root_dir=tmp_path,
+        max_steps=max_steps,
+        max_epochs=2
+    )
     assert model.on_before_zero_grad_called == 0
     trainer.fit(model)
     assert max_steps == model.on_before_zero_grad_called
@@ -406,7 +411,7 @@ def prepare_data(self): ...
 @pytest.mark.parametrize(
     "kwargs",
     [
-        {},
+        {"devices": 1},
         # these precision plugins modify the optimization flow, so testing them explicitly
         pytest.param({"accelerator": "gpu", "devices": 1, "precision": "16-mixed"}, marks=RunIf(min_cuda_gpus=1)),
         pytest.param(
@@ -528,6 +533,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_epochs(tmp_path):
     # initial training to get a checkpoint
     model = BoringModel()
     trainer = Trainer(
+        devices=1,
         default_root_dir=tmp_path,
         max_epochs=1,
         limit_train_batches=2,
@@ -543,6 +549,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_epochs(tmp_path):
     callback = HookedCallback(called)
     # already performed 1 step, resume and do 2 more
     trainer = Trainer(
+        devices=1,
         default_root_dir=tmp_path,
         max_epochs=2,
         limit_train_batches=2,
@@ -605,6 +612,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_steps(tmp_path):
     # initial training to get a checkpoint
     model = BoringModel()
     trainer = Trainer(
+        devices=1,
         default_root_dir=tmp_path,
         max_steps=1,
         limit_val_batches=0,
@@ -624,6 +632,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_steps(tmp_path):
     train_batches = 2
     steps_after_reload = 1 + train_batches
     trainer = Trainer(
+        devices=1,
         default_root_dir=tmp_path,
         max_steps=steps_after_reload,
         limit_val_batches=0,
@@ -690,6 +699,7 @@ def test_trainer_model_hook_system_eval(tmp_path, override_on_x_model_train, bat
     assert is_overridden(f"on_{noun}_model_train", model) == override_on_x_model_train
     callback = HookedCallback(called)
     trainer = Trainer(
+        devices=1,
         default_root_dir=tmp_path,
         max_epochs=1,
         limit_val_batches=batches,
@@ -731,7 +741,10 @@ def test_trainer_model_hook_system_predict(tmp_path):
     callback = HookedCallback(called)
     batches = 2
     trainer = Trainer(
-        default_root_dir=tmp_path, limit_predict_batches=batches, enable_progress_bar=False, callbacks=[callback]
+        devices=1,
+        default_root_dir=tmp_path,
+        limit_predict_batches=batches,
+        enable_progress_bar=False, callbacks=[callback]
     )
     trainer.predict(model)
     expected = [
@@ -797,7 +810,7 @@ def predict_dataloader(self):
 
     model = CustomBoringModel()
 
-    trainer = Trainer(default_root_dir=tmp_path, fast_dev_run=5)
+    trainer = Trainer(devices=1, default_root_dir=tmp_path, fast_dev_run=5)
 
     trainer.fit(model)
     trainer.test(model)
@@ -812,6 +825,7 @@ def test_trainer_datamodule_hook_system(tmp_path):
     model = BoringModel()
     batches = 2
     trainer = Trainer(
+        devices=1,
         default_root_dir=tmp_path,
         max_epochs=1,
         limit_train_batches=batches,
@@ -887,7 +901,7 @@ class CustomHookedModel(HookedModel):
     assert is_overridden("configure_model", model) == override_configure_model
 
     datamodule = CustomHookedDataModule(ldm_called)
-    trainer = Trainer()
+    trainer = Trainer(devices=1)
     trainer.strategy.connect(model)
     trainer._data_connector.attach_data(model, datamodule=datamodule)
     ckpt_path = str(tmp_path / "file.ckpt")
@@ -960,6 +974,7 @@ def predict_step(self, *args, **kwargs):
 
     model = MixedTrainModeModule()
     trainer = Trainer(
+        devices=1,
         default_root_dir=tmp_path,
         max_epochs=1,
         val_check_interval=1,