File tree Expand file tree Collapse file tree 2 files changed +11
-1
lines changed Expand file tree Collapse file tree 2 files changed +11
-1
lines changed Original file line number Diff line number Diff line change @@ -429,7 +429,7 @@ def _check_strategy_and_fallback(self) -> None:
429
429
f" platform. We recommed `Fabric(strategy='ddp_spawn')` instead."
430
430
)
431
431
if (
432
- strategy_flag in _FSDP_ALIASES or isinstance (self ._strategy_flag , FSDPStrategy )
432
+ strategy_flag in _FSDP_ALIASES or type (self ._strategy_flag ) is FSDPStrategy
433
433
) and self ._accelerator_flag not in ("cuda" , "gpu" ):
434
434
raise ValueError (
435
435
"You selected the FSDP strategy but FSDP is only available on GPU. Set `Fabric(accelerator='gpu', ...)`"
Original file line number Diff line number Diff line change @@ -978,6 +978,16 @@ def test_fsdp_unsupported_on_cpu(_):
978
978
with pytest .raises (ValueError , match = "You selected the FSDP strategy but FSDP is only available on GPU" ):
979
979
_Connector (accelerator = "cpu" , strategy = "fsdp" )
980
980
981
+ class FSDPStrategySubclass (FSDPStrategy ):
982
+ pass
983
+
984
+ class AcceleratorSubclass (CPUAccelerator ):
985
+ pass
986
+
987
+ # we allow subclasses of FSDPStrategy to be used with other accelerators
988
+ _Connector (accelerator = "cpu" , strategy = FSDPStrategySubclass ())
989
+ _Connector (accelerator = AcceleratorSubclass (), strategy = FSDPStrategySubclass ())
990
+
981
991
982
992
def test_connector_defaults_match_fabric_defaults ():
983
993
"""Test that the default values for the init arguments of Connector match the ones in Fabric."""
You can’t perform that action at this time.
0 commit comments