@@ -453,10 +453,11 @@ def _check_strategy_and_fallback(self) -> None:
 
         if (
             strategy_flag in FSDPStrategy.get_registered_strategies() or type(self._strategy_flag) is FSDPStrategy
-        ) and self._accelerator_flag not in ("cuda", "gpu"):
+        ) and not (self._accelerator_flag in ("cuda", "gpu") or isinstance(self._accelerator_flag, CUDAAccelerator)):
             raise ValueError(
-                f"The strategy `{FSDPStrategy.strategy_name}` requires a GPU accelerator, but got:"
-                f" {self._accelerator_flag}"
+                f"The strategy `{FSDPStrategy.strategy_name}` requires a GPU accelerator, but received "
+                f"`accelerator={self._accelerator_flag!r}`. Please set `accelerator='cuda'`, `accelerator='gpu'`,"
+                " or pass a `CUDAAccelerator()` instance to use FSDP."
             )
         if strategy_flag in _DDP_FORK_ALIASES and "fork" not in torch.multiprocessing.get_all_start_methods():
             raise ValueError(
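For illustration, a minimal sketch (not part of the change) of the failure path this check guards, assuming the `lightning.pytorch` import path and that `pytest` is available; the expected message fragment follows the error text added in the hunk above:

import pytest
from lightning.pytorch import Trainer

# A non-GPU accelerator flag such as "cpu" is still rejected for FSDP,
# now with the more actionable error message introduced above.
with pytest.raises(ValueError, match="requires a GPU accelerator"):
    Trainer(strategy="fsdp", accelerator="cpu")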
@@ -582,6 +582,11 @@ class AcceleratorSubclass(CPUAccelerator):
     Trainer(accelerator=AcceleratorSubclass(), strategy=FSDPStrategySubclass())
 
 
+@RunIf(min_cuda_gpus=1)
+def test_check_fsdp_strategy_and_fallback_with_cudaaccelerator():
+    Trainer(strategy="fsdp", accelerator=CUDAAccelerator())
+
+
 @mock.patch.dict(os.environ, {}, clear=True)
 def test_unsupported_tpu_choice(xla_available, tpu_available):
     # if user didn't set strategy, _Connector will choose the SingleDeviceXLAStrategy or XLAStrategy
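For completeness, a minimal sketch (not part of the change) of the case the new test exercises, assuming the `lightning.pytorch` import path and at least one visible CUDA device, mirroring the `RunIf(min_cuda_gpus=1)` guard:

from lightning.pytorch import Trainer
from lightning.pytorch.accelerators import CUDAAccelerator

# With the updated check, an explicit CUDAAccelerator instance is accepted
# for FSDP in the same way as the "cuda" and "gpu" string flags.
Trainer(strategy="fsdp", accelerator=CUDAAccelerator())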