@@ -463,11 +463,16 @@ def _check_strategy_and_fallback(self) -> None:
 
         if (
             strategy_flag in FSDPStrategy.get_registered_strategies() or type(self._strategy_flag) is FSDPStrategy
-        ) and self._accelerator_flag not in ("cuda", "gpu"):
+        ) and self._accelerator_flag not in ("cuda", "gpu") and isinstance(self._accelerator_flag, str):
             raise ValueError(
                 f"The strategy `{FSDPStrategy.strategy_name}` requires a GPU accelerator, but got:"
                 f" {self._accelerator_flag}"
             )
+        elif isinstance(self._accelerator_flag, Accelerator):
+            rank_zero_warn(
+                f"Using a custom accelerator `{self._accelerator_flag.__class__.__name__}`."
+                f" Please ensure it is compatible with the selected strategy `{strategy_flag}`."
+            )
         if strategy_flag in _DDP_FORK_ALIASES and "fork" not in torch.multiprocessing.get_all_start_methods():
             raise ValueError(
                 f"You selected `Trainer(strategy='{strategy_flag}')` but process forking is not supported on this"
@@ -501,7 +506,7 @@ def _check_and_init_precision(self) -> Precision:
         if isinstance(self.strategy, DeepSpeedStrategy):
             return DeepSpeedPrecision(self._precision_flag)  # type: ignore[arg-type]
         if isinstance(self.strategy, FSDPStrategy):
-            return FSDPPrecision(precision=self._precision_input, device=self._accelerator_flag.get_device() if isinstance(self._accelerator_flag, Accelerator) else None)
+            return FSDPPrecision(precision=self._precision_flag, device=self._accelerator_flag.get_device() if isinstance(self._accelerator_flag, Accelerator) else None)
         if self._precision_flag in ("16-true", "bf16-true"):
             return HalfPrecision(self._precision_flag)  # type: ignore
         if self._precision_flag == "32-true":
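
A minimal, self-contained sketch of the behavioural change in `_check_strategy_and_fallback`: the stub `Accelerator` and `MyCustomAccelerator` classes below are illustrative stand-ins, not the real Lightning API. It only shows why the added `isinstance(..., str)` guard matters; before the change, any custom accelerator instance would fail the `not in ("cuda", "gpu")` membership test and raise, whereas now only string flags raise and custom instances fall through to the warning branch.

```python
import warnings


class Accelerator:
    """Stub standing in for lightning.pytorch.accelerators.Accelerator (illustration only)."""


class MyCustomAccelerator(Accelerator):
    """Hypothetical user-defined accelerator."""


def check_fsdp_accelerator(accelerator_flag):
    """Mirror of the new check: error on unsupported string flags, warn on custom instances."""
    if isinstance(accelerator_flag, str) and accelerator_flag not in ("cuda", "gpu"):
        raise ValueError(f"The strategy `fsdp` requires a GPU accelerator, but got: {accelerator_flag}")
    elif isinstance(accelerator_flag, Accelerator):
        warnings.warn(
            f"Using a custom accelerator `{accelerator_flag.__class__.__name__}`."
            " Please ensure it is compatible with the selected strategy `fsdp`."
        )


check_fsdp_accelerator(MyCustomAccelerator())  # now warns instead of raising
# check_fsdp_accelerator("cpu")                # still raises ValueError
```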