@@ -187,8 +187,19 @@ def _set_env_variables(args: Namespace) -> None:
187187
188188def _get_num_processes (accelerator : str , devices : str ) -> int :
189189 """Parse the `devices` argument to determine how many processes need to be launched on the current machine."""
190+ if accelerator == "auto" or accelerator is None :
191+ if torch .cuda .is_available () and torch .cuda .device_count () > 0 :
192+ accelerator = "cuda"
193+ elif torch .backends .mps .is_available () and torch .backends .mps .is_built ():
194+ accelerator = "mps"
195+ else :
196+ accelerator = "cpu"
197+
190198 if devices == "auto" :
191- devices = "1" # default to 1 device if 'auto' is specified
199+ if accelerator == "cuda" and torch .cuda .device_count () > 0 or accelerator == "mps" or accelerator == "cpu" :
200+ devices = "1"
201+ else :
202+ raise ValueError (f"Cannot default to '1' device for accelerator='{ accelerator } '" )
192203 if accelerator == "gpu" :
193204 parsed_devices = _parse_gpu_ids (devices , include_cuda = True , include_mps = True )
194205 elif accelerator == "cuda" :
@@ -197,13 +208,6 @@ def _get_num_processes(accelerator: str, devices: str) -> int:
197208 parsed_devices = MPSAccelerator .parse_devices (devices )
198209 elif accelerator == "tpu" :
199210 raise ValueError ("Launching processes for TPU through the CLI is not supported." )
200- elif accelerator == "auto" or accelerator is None :
201- if torch .cuda .is_available () and torch .cuda .device_count () > 0 :
202- parsed_devices = CUDAAccelerator .parse_devices (devices )
203- elif torch .backends .mps .is_available () and torch .backends .mps .is_built ():
204- parsed_devices = MPSAccelerator .parse_devices (devices )
205- else :
206- return CPUAccelerator .parse_devices (devices )
207211 else :
208212 return CPUAccelerator .parse_devices (devices )
209213 return len (parsed_devices ) if parsed_devices is not None else 0
0 commit comments