Skip to content

Commit 865176f

Browse files
committed
Fabric: add support for 'auto' accelerator
1 parent be608fa commit 865176f

File tree

1 file changed

+9
-1
lines changed

1 file changed

+9
-1
lines changed

src/lightning/fabric/cli.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
_CLICK_AVAILABLE = RequirementCache("click")
3737
_LIGHTNING_SDK_AVAILABLE = RequirementCache("lightning_sdk")
3838

39-
_SUPPORTED_ACCELERATORS = ("cpu", "gpu", "cuda", "mps", "tpu")
39+
_SUPPORTED_ACCELERATORS = ("cpu", "gpu", "cuda", "mps", "tpu", "auto")
4040

4141

4242
def _get_supported_strategies() -> list[str]:
@@ -208,6 +208,14 @@ def _set_env_variables(args: Namespace) -> None:
208208

209209
def _get_num_processes(accelerator: str, devices: str) -> int:
210210
"""Parse the `devices` argument to determine how many processes need to be launched on the current machine."""
211+
if accelerator == "auto":
212+
if torch.cuda.is_available():
213+
accelerator = "cuda"
214+
elif getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
215+
accelerator = "mps"
216+
else:
217+
accelerator = "cpu"
218+
211219
if accelerator == "gpu":
212220
parsed_devices = _parse_gpu_ids(devices, include_cuda=True, include_mps=True)
213221
elif accelerator == "cuda":

0 commit comments

Comments
 (0)