Skip to content

Commit eec7b50

Browse files
committed
use auto detect accelerator and add auto to the unit tests
1 parent 2675c7b commit eec7b50

File tree

2 files changed

+5
-10
lines changed

2 files changed

+5
-10
lines changed

src/lightning/fabric/cli.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from lightning.fabric.utilities.device_parser import _parse_gpu_ids
2929
from lightning.fabric.utilities.distributed import _suggested_max_num_threads
3030
from lightning.fabric.utilities.load import _load_distributed_checkpoint
31+
from lightning.pytorch.trainer.connectors.accelerator_connector import _AcceleratorConnector
3132

3233
_log = logging.getLogger(__name__)
3334

@@ -188,15 +189,9 @@ def _set_env_variables(args: Namespace) -> None:
188189
def _get_num_processes(accelerator: str, devices: str) -> int:
189190
"""Parse the `devices` argument to determine how many processes need to be launched on the current machine."""
190191
if accelerator == "auto" or accelerator is None:
191-
if torch.cuda.is_available() and torch.cuda.device_count() > 0:
192-
accelerator = "cuda"
193-
elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
194-
accelerator = "mps"
195-
else:
196-
accelerator = "cpu"
197-
192+
accelerator = _AcceleratorConnector._choose_auto_accelerator()
198193
if devices == "auto":
199-
if accelerator == "cuda" and torch.cuda.device_count() > 0 or accelerator == "mps" or accelerator == "cpu":
194+
if accelerator == "cuda" or accelerator == "mps" or accelerator == "cpu":
200195
devices = "1"
201196
else:
202197
raise ValueError(f"Cannot default to '1' device for accelerator='{accelerator}'")

tests/tests_fabric/test_cli.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def test_run_env_vars_unsupported_strategy(strategy, fake_script):
8585
assert f"Invalid value for '--strategy': '{strategy}'" in ioerr.getvalue()
8686

8787

88-
@pytest.mark.parametrize("devices", ["1", "2", "0,", "1,0", "-1"])
88+
@pytest.mark.parametrize("devices", ["1", "2", "0,", "1,0", "-1", "auto"])
8989
@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
9090
@mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=2)
9191
def test_run_env_vars_devices_cuda(_, devices, monkeypatch, fake_script):
@@ -97,7 +97,7 @@ def test_run_env_vars_devices_cuda(_, devices, monkeypatch, fake_script):
9797

9898

9999
@RunIf(mps=True)
100-
@pytest.mark.parametrize("accelerator", ["mps", "gpu"])
100+
@pytest.mark.parametrize("accelerator", ["mps", "gpu", "auto"])
101101
@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
102102
def test_run_env_vars_devices_mps(accelerator, monkeypatch, fake_script):
103103
monkeypatch.setitem(sys.modules, "torch.distributed.run", Mock())

0 commit comments

Comments
 (0)