@@ -141,6 +141,8 @@ def __init__(
141141 self ._accelerator_flag = self ._choose_auto_accelerator ()
142142 elif self ._accelerator_flag == "gpu" :
143143 self ._accelerator_flag = self ._choose_gpu_accelerator_backend ()
144+ elif isinstance (self ._accelerator_flag , Accelerator ):
145+ pass # do nothing
144146
145147 self ._set_parallel_devices_and_init_accelerator ()
146148
@@ -461,7 +463,7 @@ def _check_and_init_precision(self) -> Precision:
461463 if isinstance (self .strategy , DeepSpeedStrategy ):
462464 return DeepSpeedPrecision (self ._precision_input ) # type: ignore
463465 if isinstance (self .strategy , FSDPStrategy ):
464- return FSDPPrecision (precision = self ._precision_input ) # type: ignore[arg-type]
466+ return FSDPPrecision (precision = self ._precision_input , device = self . _accelerator_flag . get_device () if isinstance ( self . _accelerator_flag , Accelerator ) else None ) # type: ignore[arg-type]
465467 mp_precision_supported = ("32-true" , "bf16-mixed" , "bf16-true" , "16-true" )
466468 if isinstance (self .strategy , ModelParallelStrategy ) and self ._precision_input not in mp_precision_supported :
467469 raise ValueError (
@@ -493,6 +495,8 @@ def _check_and_init_precision(self) -> Precision:
493495 else "Using bfloat16 Automatic Mixed Precision (AMP)"
494496 )
495497 device = "cpu" if self ._accelerator_flag == "cpu" else "cuda"
498+ if isinstance (self ._accelerator_flag , Accelerator ):
499+ device = self ._accelerator_flag .get_device ()
496500 return MixedPrecision (precision = self ._precision_input , device = device ) # type: ignore[arg-type]
497501
498502 raise RuntimeError ("No precision set" )
0 commit comments