Skip to content

Commit 6e90049

Browse files
fnhirwa and Borda
authored
Fabric: Enable "auto" for devices and accelerator as cli arguments (#20913)
* add "auto" option for accelerator and device
* use auto-detected accelerator and add "auto" to the unit tests

Co-authored-by: Jirka Borovec <[email protected]>
1 parent 3138305 commit 6e90049

File tree

5 files changed

+37
-17
lines changed

5 files changed

+37
-17
lines changed

src/lightning/fabric/cli.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from lightning.fabric.plugins.precision.precision import _PRECISION_INPUT_STR, _PRECISION_INPUT_STR_ALIAS
2626
from lightning.fabric.strategies import STRATEGY_REGISTRY
2727
from lightning.fabric.utilities.consolidate_checkpoint import _process_cli_args
28-
from lightning.fabric.utilities.device_parser import _parse_gpu_ids
28+
from lightning.fabric.utilities.device_parser import _parse_gpu_ids, _select_auto_accelerator
2929
from lightning.fabric.utilities.distributed import _suggested_max_num_threads
3030
from lightning.fabric.utilities.load import _load_distributed_checkpoint
3131

@@ -34,7 +34,7 @@
3434
_CLICK_AVAILABLE = RequirementCache("click")
3535
_LIGHTNING_SDK_AVAILABLE = RequirementCache("lightning_sdk")
3636

37-
_SUPPORTED_ACCELERATORS = ("cpu", "gpu", "cuda", "mps", "tpu")
37+
_SUPPORTED_ACCELERATORS = ("cpu", "gpu", "cuda", "mps", "tpu", "auto")
3838

3939

4040
def _get_supported_strategies() -> list[str]:
@@ -187,6 +187,14 @@ def _set_env_variables(args: Namespace) -> None:
187187

188188
def _get_num_processes(accelerator: str, devices: str) -> int:
189189
"""Parse the `devices` argument to determine how many processes need to be launched on the current machine."""
190+
191+
if accelerator == "auto" or accelerator is None:
192+
accelerator = _select_auto_accelerator()
193+
if devices == "auto":
194+
if accelerator == "cuda" or accelerator == "mps" or accelerator == "cpu":
195+
devices = "1"
196+
else:
197+
raise ValueError(f"Cannot default to '1' device for accelerator='{accelerator}'")
190198
if accelerator == "gpu":
191199
parsed_devices = _parse_gpu_ids(devices, include_cuda=True, include_mps=True)
192200
elif accelerator == "cuda":

src/lightning/fabric/utilities/device_parser.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,3 +204,18 @@ def _check_data_type(device_ids: object) -> None:
204204
raise TypeError(f"{msg} a sequence of {type(id_).__name__}.")
205205
elif type(device_ids) not in (int, str):
206206
raise TypeError(f"{msg} {device_ids!r}.")
207+
208+
209+
def _select_auto_accelerator() -> str:
210+
"""Choose the accelerator type (str) based on availability."""
211+
from lightning.fabric.accelerators.cuda import CUDAAccelerator
212+
from lightning.fabric.accelerators.mps import MPSAccelerator
213+
from lightning.fabric.accelerators.xla import XLAAccelerator
214+
215+
if XLAAccelerator.is_available():
216+
return "tpu"
217+
if MPSAccelerator.is_available():
218+
return "mps"
219+
if CUDAAccelerator.is_available():
220+
return "cuda"
221+
return "cpu"

src/lightning/pytorch/trainer/connectors/accelerator_connector.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
SLURMEnvironment,
3030
TorchElasticEnvironment,
3131
)
32-
from lightning.fabric.utilities.device_parser import _determine_root_gpu_device
32+
from lightning.fabric.utilities.device_parser import _determine_root_gpu_device, _select_auto_accelerator
3333
from lightning.fabric.utilities.imports import _IS_INTERACTIVE
3434
from lightning.pytorch.accelerators import AcceleratorRegistry
3535
from lightning.pytorch.accelerators.accelerator import Accelerator
@@ -332,18 +332,12 @@ def _check_device_config_and_set_final_flags(self, devices: Union[list[int], str
332332
@staticmethod
333333
def _choose_auto_accelerator() -> str:
334334
"""Choose the accelerator type (str) based on availability."""
335-
if XLAAccelerator.is_available():
336-
return "tpu"
337335
if _habana_available_and_importable():
338336
from lightning_habana import HPUAccelerator
339337

340338
if HPUAccelerator.is_available():
341339
return "hpu"
342-
if MPSAccelerator.is_available():
343-
return "mps"
344-
if CUDAAccelerator.is_available():
345-
return "cuda"
346-
return "cpu"
340+
return _select_auto_accelerator()
347341

348342
@staticmethod
349343
def _choose_gpu_accelerator_backend() -> str:

tests/tests_fabric/test_cli.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def test_run_env_vars_defaults(monkeypatch, fake_script):
4646
assert "LT_PRECISION" not in os.environ
4747

4848

49-
@pytest.mark.parametrize("accelerator", ["cpu", "gpu", "cuda", pytest.param("mps", marks=RunIf(mps=True))])
49+
@pytest.mark.parametrize("accelerator", ["cpu", "gpu", "cuda", "auto", pytest.param("mps", marks=RunIf(mps=True))])
5050
@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
5151
@mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=2)
5252
def test_run_env_vars_accelerator(_, accelerator, monkeypatch, fake_script):
@@ -85,7 +85,7 @@ def test_run_env_vars_unsupported_strategy(strategy, fake_script):
8585
assert f"Invalid value for '--strategy': '{strategy}'" in ioerr.getvalue()
8686

8787

88-
@pytest.mark.parametrize("devices", ["1", "2", "0,", "1,0", "-1"])
88+
@pytest.mark.parametrize("devices", ["1", "2", "0,", "1,0", "-1", "auto"])
8989
@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
9090
@mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=2)
9191
def test_run_env_vars_devices_cuda(_, devices, monkeypatch, fake_script):
@@ -97,7 +97,7 @@ def test_run_env_vars_devices_cuda(_, devices, monkeypatch, fake_script):
9797

9898

9999
@RunIf(mps=True)
100-
@pytest.mark.parametrize("accelerator", ["mps", "gpu"])
100+
@pytest.mark.parametrize("accelerator", ["mps", "gpu", "auto"])
101101
@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
102102
def test_run_env_vars_devices_mps(accelerator, monkeypatch, fake_script):
103103
monkeypatch.setitem(sys.modules, "torch.distributed.run", Mock())

tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -491,13 +491,15 @@ def test_strategy_choice_ddp_torchelastic(_, __, mps_count_0, cuda_count_2):
491491
"LOCAL_RANK": "1",
492492
},
493493
)
494-
@mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=2)
495-
@mock.patch("lightning.fabric.accelerators.mps.MPSAccelerator.is_available", return_value=False)
496-
def test_torchelastic_priority_over_slurm(*_):
494+
def test_torchelastic_priority_over_slurm(monkeypatch):
497495
"""Test that the TorchElastic cluster environment is chosen over SLURM when both are detected."""
496+
with monkeypatch.context():
497+
mock_cuda_count(monkeypatch, 2)
498+
mock_mps_count(monkeypatch, 0)
499+
mock_hpu_count(monkeypatch, 0)
500+
connector = _AcceleratorConnector(strategy="ddp")
498501
assert TorchElasticEnvironment.detect()
499502
assert SLURMEnvironment.detect()
500-
connector = _AcceleratorConnector(strategy="ddp")
501503
assert isinstance(connector.strategy.cluster_environment, TorchElasticEnvironment)
502504

503505

@@ -1003,6 +1005,7 @@ def _mock_tpu_available(value):
10031005
with monkeypatch.context():
10041006
mock_cuda_count(monkeypatch, 2)
10051007
mock_mps_count(monkeypatch, 0)
1008+
mock_hpu_count(monkeypatch, 0)
10061009
_mock_tpu_available(True)
10071010
connector = _AcceleratorConnector()
10081011
assert isinstance(connector.accelerator, XLAAccelerator)

0 commit comments

Comments (0)