Skip to content

Commit 513121d

Browse files
committed
Force hook standalone tests to single device
1 parent 5495204 commit 513121d

File tree

1 file changed

+20
-5
lines changed

1 file changed

+20
-5
lines changed

tests/tests_pytorch/models/test_hooks.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,12 @@ def on_before_zero_grad(self, optimizer):
6161

6262
model = CurrentTestModel()
6363

64-
trainer = Trainer(default_root_dir=tmp_path, max_steps=max_steps, max_epochs=2)
64+
trainer = Trainer(
65+
devices=1,
66+
default_root_dir=tmp_path,
67+
max_steps=max_steps,
68+
max_epochs=2
69+
)
6570
assert model.on_before_zero_grad_called == 0
6671
trainer.fit(model)
6772
assert max_steps == model.on_before_zero_grad_called
@@ -406,7 +411,7 @@ def prepare_data(self): ...
406411
@pytest.mark.parametrize(
407412
"kwargs",
408413
[
409-
{},
414+
{"devices": 1},
410415
# these precision plugins modify the optimization flow, so testing them explicitly
411416
pytest.param({"accelerator": "gpu", "devices": 1, "precision": "16-mixed"}, marks=RunIf(min_cuda_gpus=1)),
412417
pytest.param(
@@ -528,6 +533,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_epochs(tmp_path):
528533
# initial training to get a checkpoint
529534
model = BoringModel()
530535
trainer = Trainer(
536+
devices=1,
531537
default_root_dir=tmp_path,
532538
max_epochs=1,
533539
limit_train_batches=2,
@@ -543,6 +549,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_epochs(tmp_path):
543549
callback = HookedCallback(called)
544550
# already performed 1 step, resume and do 2 more
545551
trainer = Trainer(
552+
devices=1,
546553
default_root_dir=tmp_path,
547554
max_epochs=2,
548555
limit_train_batches=2,
@@ -605,6 +612,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_steps(tmp_path):
605612
# initial training to get a checkpoint
606613
model = BoringModel()
607614
trainer = Trainer(
615+
devices=1,
608616
default_root_dir=tmp_path,
609617
max_steps=1,
610618
limit_val_batches=0,
@@ -624,6 +632,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_steps(tmp_path):
624632
train_batches = 2
625633
steps_after_reload = 1 + train_batches
626634
trainer = Trainer(
635+
devices=1,
627636
default_root_dir=tmp_path,
628637
max_steps=steps_after_reload,
629638
limit_val_batches=0,
@@ -690,6 +699,7 @@ def test_trainer_model_hook_system_eval(tmp_path, override_on_x_model_train, bat
690699
assert is_overridden(f"on_{noun}_model_train", model) == override_on_x_model_train
691700
callback = HookedCallback(called)
692701
trainer = Trainer(
702+
devices=1,
693703
default_root_dir=tmp_path,
694704
max_epochs=1,
695705
limit_val_batches=batches,
@@ -731,7 +741,10 @@ def test_trainer_model_hook_system_predict(tmp_path):
731741
callback = HookedCallback(called)
732742
batches = 2
733743
trainer = Trainer(
734-
default_root_dir=tmp_path, limit_predict_batches=batches, enable_progress_bar=False, callbacks=[callback]
744+
devices=1,
745+
default_root_dir=tmp_path,
746+
limit_predict_batches=batches,
747+
enable_progress_bar=False, callbacks=[callback]
735748
)
736749
trainer.predict(model)
737750
expected = [
@@ -797,7 +810,7 @@ def predict_dataloader(self):
797810

798811
model = CustomBoringModel()
799812

800-
trainer = Trainer(default_root_dir=tmp_path, fast_dev_run=5)
813+
trainer = Trainer(devices=1, default_root_dir=tmp_path, fast_dev_run=5)
801814

802815
trainer.fit(model)
803816
trainer.test(model)
@@ -812,6 +825,7 @@ def test_trainer_datamodule_hook_system(tmp_path):
812825
model = BoringModel()
813826
batches = 2
814827
trainer = Trainer(
828+
devices=1,
815829
default_root_dir=tmp_path,
816830
max_epochs=1,
817831
limit_train_batches=batches,
@@ -887,7 +901,7 @@ class CustomHookedModel(HookedModel):
887901
assert is_overridden("configure_model", model) == override_configure_model
888902

889903
datamodule = CustomHookedDataModule(ldm_called)
890-
trainer = Trainer()
904+
trainer = Trainer(devices=1)
891905
trainer.strategy.connect(model)
892906
trainer._data_connector.attach_data(model, datamodule=datamodule)
893907
ckpt_path = str(tmp_path / "file.ckpt")
@@ -960,6 +974,7 @@ def predict_step(self, *args, **kwargs):
960974

961975
model = MixedTrainModeModule()
962976
trainer = Trainer(
977+
devices=1,
963978
default_root_dir=tmp_path,
964979
max_epochs=1,
965980
val_check_interval=1,

0 commit comments

Comments
 (0)