@@ -61,7 +61,7 @@ def on_before_zero_grad(self, optimizer):
6161
6262 model = CurrentTestModel ()
6363
64- trainer = Trainer (default_root_dir = tmp_path , max_steps = max_steps , max_epochs = 2 )
64+ trainer = Trainer (devices = 1 , default_root_dir = tmp_path , max_steps = max_steps , max_epochs = 2 )
6565 assert model .on_before_zero_grad_called == 0
6666 trainer .fit (model )
6767 assert max_steps == model .on_before_zero_grad_called
@@ -406,7 +406,7 @@ def prepare_data(self): ...
406406@pytest .mark .parametrize (
407407 "kwargs" ,
408408 [
409- {},
409+ {"devices" : 1 },
410410 # these precision plugins modify the optimization flow, so testing them explicitly
411411 pytest .param ({"accelerator" : "gpu" , "devices" : 1 , "precision" : "16-mixed" }, marks = RunIf (min_cuda_gpus = 1 )),
412412 pytest .param (
@@ -528,6 +528,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_epochs(tmp_path):
528528 # initial training to get a checkpoint
529529 model = BoringModel ()
530530 trainer = Trainer (
531+ devices = 1 ,
531532 default_root_dir = tmp_path ,
532533 max_epochs = 1 ,
533534 limit_train_batches = 2 ,
@@ -543,6 +544,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_epochs(tmp_path):
543544 callback = HookedCallback (called )
544545 # already performed 1 step, resume and do 2 more
545546 trainer = Trainer (
547+ devices = 1 ,
546548 default_root_dir = tmp_path ,
547549 max_epochs = 2 ,
548550 limit_train_batches = 2 ,
@@ -605,6 +607,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_steps(tmp_path):
605607 # initial training to get a checkpoint
606608 model = BoringModel ()
607609 trainer = Trainer (
610+ devices = 1 ,
608611 default_root_dir = tmp_path ,
609612 max_steps = 1 ,
610613 limit_val_batches = 0 ,
@@ -624,6 +627,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_steps(tmp_path):
624627 train_batches = 2
625628 steps_after_reload = 1 + train_batches
626629 trainer = Trainer (
630+ devices = 1 ,
627631 default_root_dir = tmp_path ,
628632 max_steps = steps_after_reload ,
629633 limit_val_batches = 0 ,
@@ -690,6 +694,7 @@ def test_trainer_model_hook_system_eval(tmp_path, override_on_x_model_train, bat
690694 assert is_overridden (f"on_{ noun } _model_train" , model ) == override_on_x_model_train
691695 callback = HookedCallback (called )
692696 trainer = Trainer (
697+ devices = 1 ,
693698 default_root_dir = tmp_path ,
694699 max_epochs = 1 ,
695700 limit_val_batches = batches ,
@@ -731,7 +736,11 @@ def test_trainer_model_hook_system_predict(tmp_path):
731736 callback = HookedCallback (called )
732737 batches = 2
733738 trainer = Trainer (
734- default_root_dir = tmp_path , limit_predict_batches = batches , enable_progress_bar = False , callbacks = [callback ]
739+ devices = 1 ,
740+ default_root_dir = tmp_path ,
741+ limit_predict_batches = batches ,
742+ enable_progress_bar = False ,
743+ callbacks = [callback ],
735744 )
736745 trainer .predict (model )
737746 expected = [
@@ -797,7 +806,7 @@ def predict_dataloader(self):
797806
798807 model = CustomBoringModel ()
799808
800- trainer = Trainer (default_root_dir = tmp_path , fast_dev_run = 5 )
809+ trainer = Trainer (devices = 1 , default_root_dir = tmp_path , fast_dev_run = 5 )
801810
802811 trainer .fit (model )
803812 trainer .test (model )
@@ -812,6 +821,7 @@ def test_trainer_datamodule_hook_system(tmp_path):
812821 model = BoringModel ()
813822 batches = 2
814823 trainer = Trainer (
824+ devices = 1 ,
815825 default_root_dir = tmp_path ,
816826 max_epochs = 1 ,
817827 limit_train_batches = batches ,
@@ -887,7 +897,7 @@ class CustomHookedModel(HookedModel):
887897 assert is_overridden ("configure_model" , model ) == override_configure_model
888898
889899 datamodule = CustomHookedDataModule (ldm_called )
890- trainer = Trainer ()
900+ trainer = Trainer (devices = 1 )
891901 trainer .strategy .connect (model )
892902 trainer ._data_connector .attach_data (model , datamodule = datamodule )
893903 ckpt_path = str (tmp_path / "file.ckpt" )
@@ -960,6 +970,7 @@ def predict_step(self, *args, **kwargs):
960970
961971 model = MixedTrainModeModule ()
962972 trainer = Trainer (
973+ devices = 1 ,
963974 default_root_dir = tmp_path ,
964975 max_epochs = 1 ,
965976 val_check_interval = 1 ,
0 commit comments