Commit 951c0bc

add auto option for accelerator and device
1 parent c03660a commit 951c0bc

File tree: 2 files changed (+18 / -5 lines)

src/lightning/fabric/cli.py
Lines changed: 15 additions & 2 deletions

@@ -34,7 +34,7 @@
 _CLICK_AVAILABLE = RequirementCache("click")
 _LIGHTNING_SDK_AVAILABLE = RequirementCache("lightning_sdk")
 
-_SUPPORTED_ACCELERATORS = ("cpu", "gpu", "cuda", "mps", "tpu")
+_SUPPORTED_ACCELERATORS = ("cpu", "gpu", "cuda", "mps", "tpu", "auto")
 
 
 def _get_supported_strategies() -> list[str]:

@@ -187,6 +187,8 @@ def _set_env_variables(args: Namespace) -> None:
 
 def _get_num_processes(accelerator: str, devices: str) -> int:
     """Parse the `devices` argument to determine how many processes need to be launched on the current machine."""
+    if devices == "auto":
+        devices = "1"  # default to 1 device if 'auto' is specified
     if accelerator == "gpu":
         parsed_devices = _parse_gpu_ids(devices, include_cuda=True, include_mps=True)
     elif accelerator == "cuda":

@@ -195,9 +197,20 @@ def _get_num_processes(accelerator: str, devices: str) -> int:
         parsed_devices = MPSAccelerator.parse_devices(devices)
     elif accelerator == "tpu":
         raise ValueError("Launching processes for TPU through the CLI is not supported.")
+    elif accelerator == "auto" or accelerator is None:
+        if torch.cuda.is_available():
+            parsed_devices = CUDAAccelerator.parse_devices(devices)
+        elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
+            parsed_devices = MPSAccelerator.parse_devices(devices)
+        else:
+            parsed_devices = CPUAccelerator.parse_devices(devices)
     else:
         return CPUAccelerator.parse_devices(devices)
-    return len(parsed_devices) if parsed_devices is not None else 0
+    return (
+        len(parsed_devices)
+        if isinstance(parsed_devices, list)
+        else (parsed_devices if isinstance(parsed_devices, int) else 0)
+    )
 
 
 def _torchrun_launch(args: Namespace, script_args: list[str]) -> None:
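
The net effect of the cli.py change is easiest to see by calling the helper directly. Below is a minimal usage sketch, assuming a standard PyTorch install and the import path shown in the diff; which accelerator class ends up parsing the devices depends on what torch detects on the local machine:

    # Hedged sketch: probe how many processes the CLI would launch when both
    # options are left at the new "auto" value. On a CUDA machine the count
    # comes from CUDAAccelerator.parse_devices, on Apple silicon from
    # MPSAccelerator, otherwise from CPUAccelerator.
    from lightning.fabric.cli import _get_num_processes

    num_processes = _get_num_processes(accelerator="auto", devices="auto")
    print(num_processes)  # devices="auto" is coerced to "1", so this prints 1

In CLI terms, this is what would let an invocation along the lines of `fabric run script.py --accelerator auto --devices auto` pick a backend without the user naming one (the `fabric run` entry point and exact flag spelling are assumptions here, not shown in this diff).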

tests/tests_fabric/test_cli.py
Lines changed: 3 additions & 3 deletions

@@ -46,7 +46,7 @@ def test_run_env_vars_defaults(monkeypatch, fake_script):
     assert "LT_PRECISION" not in os.environ
 
 
-@pytest.mark.parametrize("accelerator", ["cpu", "gpu", "cuda", pytest.param("mps", marks=RunIf(mps=True))])
+@pytest.mark.parametrize("accelerator", ["cpu", "gpu", "cuda", "auto", pytest.param("mps", marks=RunIf(mps=True))])
 @mock.patch.dict(os.environ, os.environ.copy(), clear=True)
 @mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=2)
 def test_run_env_vars_accelerator(_, accelerator, monkeypatch, fake_script):

@@ -85,7 +85,7 @@ def test_run_env_vars_unsupported_strategy(strategy, fake_script):
     assert f"Invalid value for '--strategy': '{strategy}'" in ioerr.getvalue()
 
 
-@pytest.mark.parametrize("devices", ["1", "2", "0,", "1,0", "-1"])
+@pytest.mark.parametrize("devices", ["1", "2", "0,", "1,0", "-1", "auto"])
 @mock.patch.dict(os.environ, os.environ.copy(), clear=True)
 @mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=2)
 def test_run_env_vars_devices_cuda(_, devices, monkeypatch, fake_script):

@@ -97,7 +97,7 @@ def test_run_env_vars_devices_cuda(_, devices, monkeypatch, fake_script):
 
 
 @RunIf(mps=True)
-@pytest.mark.parametrize("accelerator", ["mps", "gpu"])
+@pytest.mark.parametrize("accelerator", ["mps", "gpu", "auto"])
 @mock.patch.dict(os.environ, os.environ.copy(), clear=True)
 def test_run_env_vars_devices_mps(accelerator, monkeypatch, fake_script):
     monkeypatch.setitem(sys.modules, "torch.distributed.run", Mock())
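
Since the test changes only widen existing parametrizations, the new cases can be selected by their parameter id. A small sketch, assuming pytest is available and the repository layout matches the paths above; the -k expression simply matches the "auto" id that pytest derives from the new parameter value:

    # Hedged sketch: collect and run only the "auto" cases touched by this commit.
    import pytest

    exit_code = pytest.main(["tests/tests_fabric/test_cli.py", "-k", "auto"])
    print(exit_code)  # 0 if the selected cases pass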
