diff --git a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py
index 82da5248e2d6d..7f44de0589938 100644
--- a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py
+++ b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py
@@ -453,10 +453,11 @@ def _check_strategy_and_fallback(self) -> None:
         if (
             strategy_flag in FSDPStrategy.get_registered_strategies()
             or type(self._strategy_flag) is FSDPStrategy
-        ) and self._accelerator_flag not in ("cuda", "gpu"):
+        ) and not (self._accelerator_flag in ("cuda", "gpu") or isinstance(self._accelerator_flag, CUDAAccelerator)):
             raise ValueError(
-                f"The strategy `{FSDPStrategy.strategy_name}` requires a GPU accelerator, but got:"
-                f" {self._accelerator_flag}"
+                f"The strategy `{FSDPStrategy.strategy_name}` requires a GPU accelerator, but received "
+                f"`accelerator={self._accelerator_flag!r}`. Please set `accelerator='cuda'`, `accelerator='gpu'`,"
+                " or pass a `CUDAAccelerator()` instance to use FSDP."
             )
         if strategy_flag in _DDP_FORK_ALIASES and "fork" not in torch.multiprocessing.get_all_start_methods():
             raise ValueError(
diff --git a/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py b/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py
index d079a1c7b9a1e..f3d98cf444c36 100644
--- a/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py
+++ b/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py
@@ -582,6 +582,11 @@ class AcceleratorSubclass(CPUAccelerator):
     Trainer(accelerator=AcceleratorSubclass(), strategy=FSDPStrategySubclass())
 
 
+@RunIf(min_cuda_gpus=1)
+def test_check_fsdp_strategy_and_fallback_with_cudaaccelerator():
+    Trainer(strategy="fsdp", accelerator=CUDAAccelerator())
+
+
 @mock.patch.dict(os.environ, {}, clear=True)
 def test_unsupported_tpu_choice(xla_available, tpu_available):
     # if user didn't set strategy, _Connector will choose the SingleDeviceXLAStrategy or XLAStrategy
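
Note: the check previously accepted only the string flags "cuda" and "gpu", so passing an accelerator instance rejected FSDP even on a CUDA machine. With this patch, an explicit `CUDAAccelerator` instance is treated the same as the string flags. A minimal usage sketch of the now-accepted call, assuming a machine with at least one CUDA device (the `devices=1` argument is illustrative, not part of the patch):

    from lightning.pytorch import Trainer
    from lightning.pytorch.accelerators import CUDAAccelerator

    # Previously raised ValueError ("requires a GPU accelerator");
    # accepted after this patch.
    trainer = Trainer(strategy="fsdp", accelerator=CUDAAccelerator(), devices=1)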