Skip to content

Commit 6487295

Browse files
committed
fsdp2 tests started
1 parent 28a2359 commit 6487295

File tree

2 files changed

+833
-3
lines changed

2 files changed

+833
-3
lines changed

src/lightning/pytorch/trainer/connectors/accelerator_connector.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -453,11 +453,19 @@ def _check_strategy_and_fallback(self) -> None:
453453
# TODO this logic should apply to both str and object config
454454
strategy_flag = "" if isinstance(self._strategy_flag, Strategy) else self._strategy_flag
455455

456-
if (
456+
is_fsdp1_str = (
457457
strategy_flag in FSDPStrategy.get_registered_strategies() or type(self._strategy_flag) is FSDPStrategy
458-
) and not (self._accelerator_flag in ("cuda", "gpu") or isinstance(self._accelerator_flag, CUDAAccelerator)):
458+
)
459+
is_fsdp2_str = (
460+
strategy_flag in FSDP2Strategy.get_registered_strategies() or type(self._strategy_flag) is FSDP2Strategy
461+
)
462+
463+
if (is_fsdp1_str or is_fsdp2_str) and not (
464+
self._accelerator_flag in ("cuda", "gpu") or isinstance(self._accelerator_flag, CUDAAccelerator)
465+
):
466+
strategy_name = FSDP2Strategy.strategy_name if is_fsdp2_str else FSDPStrategy.strategy_name
459467
raise ValueError(
460-
f"The strategy `{FSDPStrategy.strategy_name}` requires a GPU accelerator, but received "
468+
f"The strategy `{strategy_name}` requires a GPU accelerator, but received "
461469
f"`accelerator={self._accelerator_flag!r}`. Please set `accelerator='cuda'`, `accelerator='gpu'`,"
462470
" or pass a `CUDAAccelerator()` instance to use FSDP."
463471
)

0 commit comments

Comments
 (0)