Skip to content

Commit 460c60c

Browse files
Fix: Allow trainer to accept CUDAAccelerator instance as accelerator with FSDP strategy (#20964)
* Add test for FSDP with CUDAAccelerator instance * update error message * update the test --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent b087c1a commit 460c60c

File tree

2 files changed

+9
-3
lines changed

2 files changed

+9
-3
lines changed

src/lightning/pytorch/trainer/connectors/accelerator_connector.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -453,10 +453,11 @@ def _check_strategy_and_fallback(self) -> None:
453453

454454
if (
455455
strategy_flag in FSDPStrategy.get_registered_strategies() or type(self._strategy_flag) is FSDPStrategy
456-
) and self._accelerator_flag not in ("cuda", "gpu"):
456+
) and not (self._accelerator_flag in ("cuda", "gpu") or isinstance(self._accelerator_flag, CUDAAccelerator)):
457457
raise ValueError(
458-
f"The strategy `{FSDPStrategy.strategy_name}` requires a GPU accelerator, but got:"
459-
f" {self._accelerator_flag}"
458+
f"The strategy `{FSDPStrategy.strategy_name}` requires a GPU accelerator, but received "
459+
f"`accelerator={self._accelerator_flag!r}`. Please set `accelerator='cuda'`, `accelerator='gpu'`,"
460+
" or pass a `CUDAAccelerator()` instance to use FSDP."
460461
)
461462
if strategy_flag in _DDP_FORK_ALIASES and "fork" not in torch.multiprocessing.get_all_start_methods():
462463
raise ValueError(

tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -582,6 +582,11 @@ class AcceleratorSubclass(CPUAccelerator):
582582
Trainer(accelerator=AcceleratorSubclass(), strategy=FSDPStrategySubclass())
583583

584584

585+
@RunIf(min_cuda_gpus=1)
586+
def test_check_fsdp_strategy_and_fallback_with_cudaaccelerator():
587+
Trainer(strategy="fsdp", accelerator=CUDAAccelerator())
588+
589+
585590
@mock.patch.dict(os.environ, {}, clear=True)
586591
def test_unsupported_tpu_choice(xla_available, tpu_available):
587592
# if user didn't set strategy, _Connector will choose the SingleDeviceXLAStrategy or XLAStrategy

0 commit comments

Comments
 (0)