Skip to content

Commit 83e0c4a

Browse files
kaushikb11 authored and lexierule committed
Raise MisconfigurationException when the accelerator is available but the user passes invalid values to the devices flag (#12708)
1 parent ba1e869 commit 83e0c4a

File tree

4 files changed

+20
-11
lines changed

4 files changed

+20
-11
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
9191
- Don't raise a warning when `nn.Module` is not saved under hparams ([#12669](https://github.com/PyTorchLightning/pytorch-lightning/pull/12669))
9292

9393

94-
-
94+
- Raise `MisconfigurationException` when the accelerator is available but the user passes invalid `([]/0/"0")` values to the `devices` flag ([#12708](https://github.com/PyTorchLightning/pytorch-lightning/pull/12708))
9595

9696

9797
## [1.6.0] - 2022-03-29

pytorch_lightning/trainer/connectors/accelerator_connector.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -413,6 +413,17 @@ def _check_device_config_and_set_final_flags(
413413
self._num_nodes_flag = int(num_nodes) if num_nodes is not None else 1
414414
self._devices_flag = devices
415415

416+
if self._devices_flag in ([], 0, "0"):
417+
accelerator_name = (
418+
self._accelerator_flag.__class__.__qualname__
419+
if isinstance(self._accelerator_flag, Accelerator)
420+
else self._accelerator_flag
421+
)
422+
raise MisconfigurationException(
423+
f"`Trainer(devices={self._devices_flag!r})` value is not a valid input"
424+
f" using {accelerator_name} accelerator."
425+
)
426+
416427
# TODO: Delete this method when num_processes, gpus, ipus and tpu_cores gets removed
417428
self._map_deprecated_devices_specfic_info_to_accelerator_and_device_flag(
418429
devices, num_processes, gpus, ipus, tpu_cores

tests/accelerators/test_accelerator_connector.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -505,15 +505,6 @@ def test_accelerator_cpu(_):
505505
trainer = Trainer(accelerator="cpu", gpus=1)
506506

507507

508-
@mock.patch("torch.cuda.is_available", return_value=False)
509-
@pytest.mark.parametrize("devices", ["0", 0, []])
510-
def test_passing_zero_and_empty_list_to_devices_flag(_, devices):
511-
with pytest.raises(
512-
MisconfigurationException, match="can not run on your system since the accelerator is not available."
513-
):
514-
Trainer(accelerator="gpu", devices=devices)
515-
516-
517508
@RunIf(min_gpus=1)
518509
def test_accelerator_gpu():
519510
trainer = Trainer(accelerator="gpu", devices=1)
@@ -1015,3 +1006,10 @@ def __init__(self, **kwargs):
10151006
def test_plugin_only_one_instance_for_one_type(plugins, expected):
10161007
with pytest.raises(MisconfigurationException, match=f"Received multiple values for {expected}"):
10171008
Trainer(plugins=plugins)
1009+
1010+
1011+
@pytest.mark.parametrize("accelerator", ("cpu", "gpu", "tpu", "ipu"))
1012+
@pytest.mark.parametrize("devices", ("0", 0, []))
1013+
def test_passing_zero_and_empty_list_to_devices_flag(accelerator, devices):
1014+
with pytest.raises(MisconfigurationException, match="value is not a valid input using"):
1015+
Trainer(accelerator=accelerator, devices=devices)

tests/accelerators/test_cpu.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
6969
func(model, ckpt_path=checkpoint_path)
7070

7171

72-
@pytest.mark.parametrize("devices", ([3], -1, 0))
72+
@pytest.mark.parametrize("devices", ([3], -1))
7373
def test_invalid_devices_with_cpu_accelerator(devices):
7474
"""Test invalid device flag raises MisconfigurationException with CPUAccelerator."""
7575
with pytest.raises(MisconfigurationException, match="should be an int > 0"):

0 commit comments

Comments (0)