Skip to content

Commit 8fc7b4a

Browse files
authored
Remove the requirement for FSDPStrategy subclasses to only support GPU (#19894)
1 parent 987c2c4 commit 8fc7b4a

File tree

2 files changed

+11
-1
lines changed

2 files changed

+11
-1
lines changed

src/lightning/fabric/connector.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -429,7 +429,7 @@ def _check_strategy_and_fallback(self) -> None:
429429
f" platform. We recommed `Fabric(strategy='ddp_spawn')` instead."
430430
)
431431
if (
432-
strategy_flag in _FSDP_ALIASES or isinstance(self._strategy_flag, FSDPStrategy)
432+
strategy_flag in _FSDP_ALIASES or type(self._strategy_flag) is FSDPStrategy
433433
) and self._accelerator_flag not in ("cuda", "gpu"):
434434
raise ValueError(
435435
"You selected the FSDP strategy but FSDP is only available on GPU. Set `Fabric(accelerator='gpu', ...)`"

tests/tests_fabric/test_connector.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -978,6 +978,16 @@ def test_fsdp_unsupported_on_cpu(_):
978978
with pytest.raises(ValueError, match="You selected the FSDP strategy but FSDP is only available on GPU"):
979979
_Connector(accelerator="cpu", strategy="fsdp")
980980

981+
class FSDPStrategySubclass(FSDPStrategy):
982+
pass
983+
984+
class AcceleratorSubclass(CPUAccelerator):
985+
pass
986+
987+
# we allow subclasses of FSDPStrategy to be used with other accelerators
988+
_Connector(accelerator="cpu", strategy=FSDPStrategySubclass())
989+
_Connector(accelerator=AcceleratorSubclass(), strategy=FSDPStrategySubclass())
990+
981991

982992
def test_connector_defaults_match_fabric_defaults():
983993
"""Test that the default values for the init arguments of Connector match the ones in Fabric."""

0 commit comments

Comments
 (0)