Skip to content

Commit ad60f75

Browse files
committed
fix device defaulting for macos
1 parent 8910453 commit ad60f75

File tree

1 file changed

+12
-8
lines changed

1 file changed

+12
-8
lines changed

src/lightning/fabric/cli.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -187,8 +187,19 @@ def _set_env_variables(args: Namespace) -> None:
187187

188188
def _get_num_processes(accelerator: str, devices: str) -> int:
189189
"""Parse the `devices` argument to determine how many processes need to be launched on the current machine."""
190+
if accelerator == "auto" or accelerator is None:
191+
if torch.cuda.is_available() and torch.cuda.device_count() > 0:
192+
accelerator = "cuda"
193+
elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
194+
accelerator = "mps"
195+
else:
196+
accelerator = "cpu"
197+
190198
if devices == "auto":
191-
devices = "1" # default to 1 device if 'auto' is specified
199+
if accelerator == "cuda" and torch.cuda.device_count() > 0 or accelerator == "mps" or accelerator == "cpu":
200+
devices = "1"
201+
else:
202+
raise ValueError(f"Cannot default to '1' device for accelerator='{accelerator}'")
192203
if accelerator == "gpu":
193204
parsed_devices = _parse_gpu_ids(devices, include_cuda=True, include_mps=True)
194205
elif accelerator == "cuda":
@@ -197,13 +208,6 @@ def _get_num_processes(accelerator: str, devices: str) -> int:
197208
parsed_devices = MPSAccelerator.parse_devices(devices)
198209
elif accelerator == "tpu":
199210
raise ValueError("Launching processes for TPU through the CLI is not supported.")
200-
elif accelerator == "auto" or accelerator is None:
201-
if torch.cuda.is_available() and torch.cuda.device_count() > 0:
202-
parsed_devices = CUDAAccelerator.parse_devices(devices)
203-
elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
204-
parsed_devices = MPSAccelerator.parse_devices(devices)
205-
else:
206-
return CPUAccelerator.parse_devices(devices)
207211
else:
208212
return CPUAccelerator.parse_devices(devices)
209213
return len(parsed_devices) if parsed_devices is not None else 0

0 commit comments

Comments
 (0)