3333import torch .utils .hooks as hooks
3434from huggingface_hub import split_torch_state_dict_into_shards
3535
36+ from accelerate .utils .dataclasses import FP8BackendType
37+
3638from .checkpointing import load_accelerator_state , load_custom_state , save_accelerator_state , save_custom_state
3739from .data_loader import DataLoaderDispatcher , prepare_data_loader , skip_first_batches
3840from .logging import get_logger
@@ -301,6 +303,7 @@ def __init__(
301303 self .project_configuration = ProjectConfiguration (project_dir = project_dir )
302304 if project_dir is not None and self .project_dir is None :
303305 self .project_configuration .set_directories (project_dir )
306+
304307 if mixed_precision is not None :
305308 mixed_precision = str (mixed_precision )
306309 if mixed_precision not in PrecisionType :
@@ -458,27 +461,34 @@ def __init__(
458461
459462 # Check for automatic FP8 recipe creation
460463 if self .fp8_enabled and not self .has_fp8_handler :
461- # Prioritize AO -> TE -> MSAMP
462- if is_torchao_available ():
463- logger .info ("Found `torchao` installed, using it for FP8 training." )
464+ if self .fp8_backend == FP8BackendType .AO :
464465 self .ao_recipe_handler = AORecipeKwargs ()
465- elif is_transformer_engine_available ():
466- logger .info ("Found `transformer-engine` installed, using it for FP8 training." )
466+ elif self .fp8_backend == FP8BackendType .TE :
467467 self .te_recipe_handler = TERecipeKwargs ()
468- elif is_msamp_available ():
469- logger .info ("Found `msamp` installed, using it for FP8 training." )
468+ elif self .fp8_backend == FP8BackendType .MSAMP :
470469 self .msamp_recipe_handler = MSAMPRecipeKwargs ()
471- else :
472- raise ImportError (
473- "Tried to train with `fp8` and auto-detect backend, but no FP8-compatible backend was installed. "
474- "Valid backends are: `torchao`, `transformer-engine`, and `msamp`."
475- )
470+ elif self .fp8_backend == FP8BackendType .NO :
471+ # Prioritize AO -> TE -> MSAMP
472+ if is_torchao_available ():
473+ logger .info ("Found `torchao` installed, using it for FP8 training." )
474+ self .ao_recipe_handler = AORecipeKwargs ()
475+ elif is_transformer_engine_available ():
476+ logger .info ("Found `transformer-engine` installed, using it for FP8 training." )
477+ self .te_recipe_handler = TERecipeKwargs ()
478+ elif is_msamp_available ():
479+ logger .info ("Found `msamp` installed, using it for FP8 training." )
480+ self .msamp_recipe_handler = MSAMPRecipeKwargs ()
481+ else :
482+ raise ImportError (
483+ "Tried to train with `fp8` and auto-detect backend, but no FP8-compatible backend was installed. "
484+ "Valid backends are: `torchao`, `transformer-engine`, and `msamp`."
485+ )
476486 self .has_fp8_handler = True
477487
478488 self .delayed_fp8_autocast = False
479489 if self .has_fp8_handler :
480490 # We already check if FP8 is available during `self.state`
481- if mixed_precision != "fp8" and (
491+ if not self . fp8_enabled and (
482492 self .distributed_type not in (DistributedType .FSDP , DistributedType .DEEPSPEED )
483493 ):
484494 raise ValueError ("Passing in an FP8 configuration requires setting `mixed_precision='fp8'`." )
@@ -488,7 +498,11 @@ def __init__(
488498 )
489499
490500 # TODO: S1ro - this is probably gonna be a problem with other fp8 backends too
491- if self .fp8_backend == "AO" and self .state .fsdp_plugin .cpu_ram_efficient_loading :
501+ if (
502+ self .fp8_backend == FP8BackendType .AO
503+ and self .state .distributed_type == DistributedType .FSDP
504+ and self .state .fsdp_plugin .cpu_ram_efficient_loading
505+ ):
492506 raise ValueError (
493507 "torchao with FSDP2 and cpu_ram_efficient_loading is not supported, setting `cpu_ram_efficient_loading` to False will fix the issue and work as intended."
494508 )
@@ -572,7 +586,7 @@ def __init__(
572586 elif self .fp8_enabled :
573587 # We always enable `native_amp` for FP8
574588 self .native_amp = True
575- if self .fp8_backend == " MSAMP" :
589+ if self .fp8_backend == FP8BackendType . MSAMP :
576590 if self .distributed_type == DistributedType .FSDP :
577591 raise NotImplementedError (
578592 "`accelerate` + `MS-AMP` + `FSDP` is not supported at this time. "
@@ -1419,9 +1433,9 @@ def prepare(self, *args, device_placement=None):
14191433 "You are using lower version of PyTorch(< 2.7.0) with ipex acceleration on Intel CPU or XPU, Intel has upstreamed most of the optimizations into stock PyTorch from 2.7.0, we enourage you to install the latest stock PyTorch and enjoy the out-of-experience on Intel CPU/XPU."
14201434 )
14211435 args = self ._prepare_ipex (* args )
1422- if self .fp8_backend == "TE" :
1436+ if self .fp8_backend == FP8BackendType . TE :
14231437 args = self ._prepare_te (* args )
1424- elif self .fp8_backend == "AO" :
1438+ elif self .fp8_backend == FP8BackendType . AO :
14251439 args = self ._prepare_ao (* args )
14261440 if self .distributed_type == DistributedType .DEEPSPEED :
14271441 result = self ._prepare_deepspeed (* args )
@@ -1430,7 +1444,7 @@ def prepare(self, *args, device_placement=None):
14301444 elif self .is_fsdp2 :
14311445 result = self ._prepare_fsdp2 (* args )
14321446 else :
1433- if self .fp8_backend == " MSAMP" :
1447+ if self .fp8_backend == FP8BackendType . MSAMP :
14341448 args , device_placement = self ._prepare_msamp (* args , device_placement = device_placement )
14351449 result = tuple (
14361450 self ._prepare_one (obj , first_pass = True , device_placement = d ) for obj , d in zip (args , device_placement )
@@ -1570,7 +1584,7 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
15701584 model ._original_forward = model .forward
15711585 autocast_context = get_mixed_precision_context_manager (self .native_amp , self .autocast_handler )
15721586 # NOTE: MS-AMP adds `__func__` already to `model.forward`, so we should always use `model.forward`
1573- if self .fp8_backend == " MSAMP" or not hasattr (model .forward , "__func__" ):
1587+ if self .fp8_backend == FP8BackendType . MSAMP or not hasattr (model .forward , "__func__" ):
15741588 model_forward_func = model .forward
15751589 model .forward = convert_outputs_to_fp32 (autocast_context (model_forward_func ))
15761590 else :
@@ -1580,7 +1594,7 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
15801594 model .forward = MethodType (convert_outputs_to_fp32 (model .forward .__func__ ), model )
15811595
15821596 # We prepare TE after, allowing for bf16 autocast to happen first
1583- if self .fp8_backend == "TE" and not self .delayed_fp8_autocast :
1597+ if self .fp8_backend == FP8BackendType . TE and not self .delayed_fp8_autocast :
15841598 model = apply_fp8_autowrap (model , self .te_recipe_handler or self .fp8_recipe_handler )
15851599
15861600 if (getattr (model , "is_loaded_in_8bit" , False ) or getattr (model , "is_loaded_in_4bit" , False )) and getattr (
@@ -1806,7 +1820,7 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
18061820 elif self .distributed_type == DistributedType .XLA and self .state .fork_launched :
18071821 model = xmp .MpModelWrapper (model ).to (self .device )
18081822 # Now we can apply the FP8 autocast
1809- if self .fp8_backend == "TE" and self .delayed_fp8_autocast :
1823+ if self .fp8_backend == FP8BackendType . TE and self .delayed_fp8_autocast :
18101824 model = apply_fp8_autowrap (model , self .te_recipe_handler or self .fp8_recipe_handler )
18111825 # torch.compile should be called last and only if the model isn't already compiled
18121826 if self .state .dynamo_plugin .backend != DynamoBackend .NO and not is_compiled_module (model ):
@@ -1884,7 +1898,7 @@ def _prepare_deepspeed(self, *args):
18841898 import deepspeed
18851899
18861900 ds_initialize = deepspeed .initialize
1887- if self .fp8_backend == " MSAMP" :
1901+ if self .fp8_backend == FP8BackendType . MSAMP :
18881902 # MS-AMP requires DeepSpeed patches
18891903 from msamp import deepspeed as msamp_deepspeed
18901904
@@ -2022,7 +2036,7 @@ def _prepare_deepspeed(self, *args):
20222036
20232037 if model is not None :
20242038 # If we are using FP8, we need to apply the autowrap now
2025- if self .fp8_backend == "TE" :
2039+ if self .fp8_backend == FP8BackendType . TE :
20262040 model = apply_fp8_autowrap (model , self .fp8_recipe_handler )
20272041 # if the model is an MOE, set the appropriate MOE layers as leaf Z3 modules
20282042 deepspeed_plugin .set_moe_leaf_modules (model )
@@ -2479,7 +2493,7 @@ def prepare_optimizer(self, optimizer: torch.optim.Optimizer, device_placement=N
24792493 device_placement = self .device_placement
24802494 # NOTE: Special case with MS-AMP we do *not* pass in the scaler explicitly to the `AcceleratedOptimizer`,
24812495 # Their optimizer handles it for us.
2482- scaler = None if self .fp8_backend == " MSAMP" else self .scaler
2496+ scaler = None if self .fp8_backend == FP8BackendType . MSAMP else self .scaler
24832497 optimizer = AcceleratedOptimizer (optimizer , device_placement = device_placement , scaler = scaler )
24842498 self ._optimizers .append (optimizer )
24852499 return optimizer
@@ -3668,7 +3682,7 @@ def _get_named_parameters(self, *args, drop_refs=False):
36683682
36693683 # we need this bit as `WeightWithDynamic...` returns 0 when `data_ptr()` is called,
36703684 # the underlying pointer is actually hidden in `_tensor` attribute
3671- if self .fp8_backend == "AO" :
3685+ if self .fp8_backend == FP8BackendType . AO :
36723686 from torchao .float8 .fsdp_utils import WeightWithDynamicFloat8CastTensor
36733687
36743688 accessor_mapping [WeightWithDynamicFloat8CastTensor ] = "_tensor"
@@ -3977,17 +3991,18 @@ def lomo_backward(self, loss: torch.Tensor, learning_rate: float) -> None:
39773991 )
39783992
39793993 @property
3980- def fp8_backend (self ):
3994+ def fp8_backend (self ) -> FP8BackendType :
39813995 "Returns the configured backend for training in FP8"
39823996 if self .has_fp8_handler :
39833997 if self .fp8_recipe_handler is not None :
3984- return self .fp8_recipe_handler .backend
3998+ return FP8BackendType ( self .fp8_recipe_handler .backend )
39853999 elif self .ao_recipe_handler is not None :
3986- return "AO"
4000+ return FP8BackendType . AO
39874001 elif self .te_recipe_handler is not None :
3988- return "TE"
4002+ return FP8BackendType . TE
39894003 elif self .msamp_recipe_handler is not None :
3990- return " MSAMP"
4004+ return FP8BackendType . MSAMP
39914005 elif self .state .deepspeed_plugin is not None and self .state .deepspeed_plugin .enable_msamp :
3992- return "MSAMP"
3993- return None
4006+ return FP8BackendType .MSAMP
4007+
4008+ return FP8BackendType (parse_choice_from_env ("ACCELERATE_FP8_BACKEND" , "NO" ))
0 commit comments