Skip to content

Commit e77cedf

Browse files
committed
use a provate method to avoid dependency errors
1 parent a2e18d3 commit e77cedf

File tree

1 file changed

+13
-3
lines changed

1 file changed

+13
-3
lines changed

src/lightning/fabric/cli.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from lightning_utilities.core.imports import RequirementCache
2222
from typing_extensions import get_args
2323

24-
from lightning.fabric.accelerators import CPUAccelerator, CUDAAccelerator, MPSAccelerator
24+
from lightning.fabric.accelerators import CPUAccelerator, CUDAAccelerator, MPSAccelerator, XLAAccelerator
2525
from lightning.fabric.plugins.precision.precision import _PRECISION_INPUT_STR, _PRECISION_INPUT_STR_ALIAS
2626
from lightning.fabric.strategies import STRATEGY_REGISTRY
2727
from lightning.fabric.utilities.consolidate_checkpoint import _process_cli_args
@@ -37,6 +37,17 @@
3737
_SUPPORTED_ACCELERATORS = ("cpu", "gpu", "cuda", "mps", "tpu", "auto")
3838

3939

40+
def _choose_auto_accelerator() -> str:
41+
"""Choose the best available accelerator for the current environment."""
42+
if CUDAAccelerator.is_available():
43+
return "cuda"
44+
if MPSAccelerator.is_available():
45+
return "mps"
46+
if XLAAccelerator.is_available():
47+
return "tpu"
48+
return "cpu"
49+
50+
4051
def _get_supported_strategies() -> list[str]:
4152
"""Returns strategy choices from the registry, with the ones removed that are incompatible to be launched from the
4253
CLI or ones that require further configuration by the user."""
@@ -187,10 +198,9 @@ def _set_env_variables(args: Namespace) -> None:
187198

188199
def _get_num_processes(accelerator: str, devices: str) -> int:
189200
"""Parse the `devices` argument to determine how many processes need to be launched on the current machine."""
190-
from lightning.pytorch.trainer.connectors.accelerator_connector import _AcceleratorConnector
191201

192202
if accelerator == "auto" or accelerator is None:
193-
accelerator = _AcceleratorConnector._choose_auto_accelerator()
203+
accelerator = _choose_auto_accelerator()
194204
if devices == "auto":
195205
if accelerator == "cuda" or accelerator == "mps" or accelerator == "cpu":
196206
devices = "1"

0 commit comments

Comments
 (0)