diff --git a/src/lightning/fabric/cli.py b/src/lightning/fabric/cli.py index 2268614abb97b..1e543a17a731f 100644 --- a/src/lightning/fabric/cli.py +++ b/src/lightning/fabric/cli.py @@ -34,7 +34,7 @@ _CLICK_AVAILABLE = RequirementCache("click") _LIGHTNING_SDK_AVAILABLE = RequirementCache("lightning_sdk") -_SUPPORTED_ACCELERATORS = ("cpu", "gpu", "cuda", "mps", "tpu") +_SUPPORTED_ACCELERATORS = ("cpu", "gpu", "cuda", "mps", "tpu", "auto") def _get_supported_strategies() -> list[str]: @@ -187,6 +187,14 @@ def _set_env_variables(args: Namespace) -> None: def _get_num_processes(accelerator: str, devices: str) -> int: """Parse the `devices` argument to determine how many processes need to be launched on the current machine.""" + if accelerator == "auto": + if torch.cuda.is_available(): + accelerator = "cuda" + elif getattr(torch.backends, "mps", None) and torch.backends.mps.is_available(): + accelerator = "mps" + else: + accelerator = "cpu" + if accelerator == "gpu": parsed_devices = _parse_gpu_ids(devices, include_cuda=True, include_mps=True) elif accelerator == "cuda":