
Commit 847ae58

Fix FP8 tests, enable FP8 to be used without direct Accelerator() configuring (#3677)
* single-gpu tests passing
* install deepspeed in fp8 container
* revert mixed_precision check
1 parent 6e104f3 commit 847ae58

6 files changed: +128 additions, -55 deletions
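The headline change: FP8 mixed precision can now be requested without hand-building a recipe handler. A minimal sketch of the intended usage (not code from this commit), assuming an FP8-capable setup with at least one of `torchao`, `transformer-engine`, or `msamp` installed:

```python
# With no kwargs_handlers passed, Accelerator now auto-selects an installed
# FP8 backend (AO -> TE -> MSAMP) or honors the ACCELERATE_FP8_BACKEND
# environment variable exported by `accelerate launch`.
from accelerate import Accelerator

accelerator = Accelerator(mixed_precision="fp8")  # no recipe kwargs needed
print(accelerator.fp8_backend)  # e.g. FP8BackendType.AO when torchao is installed
```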

benchmarks/fp8/transformer_engine/Dockerfile

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@ RUN pip install transformers evaluate datasets
 RUN git clone https://github.com/huggingface/accelerate.git

 RUN cd accelerate && \
-    pip install -e . && \
+    pip install -e .[deepspeed] && \
     cd benchmarks/fp8

 RUN /bin/bash

examples/config_yaml_templates/fp8.yaml

Lines changed: 2 additions & 2 deletions
@@ -11,8 +11,8 @@ fp8_config:
   fp8_format: E4M3
   interval: 1
   margin: 0
-  override_linear_precision: (false, false, false)
+  override_linear_precision: [false, false, false]
   # Generally this should always be set to `false` to have the most realistic fp8 eval performance
   use_autocast_during_eval: false
   # If using MS-AMP, we ignore all of the prior and set a opt_level
-  #opt_level: O1
+  #opt_level: O1
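The `override_linear_precision` fix matters because YAML has no tuple literal. A small sketch (assuming PyYAML is installed) showing how the two spellings parse:

```python
import yaml

# Old template value: "(false, false, false)" is a plain scalar, i.e. a string
old = yaml.safe_load("override_linear_precision: (false, false, false)")
print(type(old["override_linear_precision"]))  # <class 'str'>

# New template value: a flow sequence parses into a list of booleans
new = yaml.safe_load("override_linear_precision: [false, false, false]")
print(new["override_linear_precision"])  # [False, False, False]
```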

src/accelerate/accelerator.py

Lines changed: 47 additions & 32 deletions
@@ -33,6 +33,8 @@
 import torch.utils.hooks as hooks
 from huggingface_hub import split_torch_state_dict_into_shards

+from accelerate.utils.dataclasses import FP8BackendType
+
 from .checkpointing import load_accelerator_state, load_custom_state, save_accelerator_state, save_custom_state
 from .data_loader import DataLoaderDispatcher, prepare_data_loader, skip_first_batches
 from .logging import get_logger
@@ -301,6 +303,7 @@ def __init__(
         self.project_configuration = ProjectConfiguration(project_dir=project_dir)
         if project_dir is not None and self.project_dir is None:
             self.project_configuration.set_directories(project_dir)
+
         if mixed_precision is not None:
             mixed_precision = str(mixed_precision)
             if mixed_precision not in PrecisionType:
@@ -458,27 +461,34 @@ def __init__(

         # Check for automatic FP8 recipe creation
         if self.fp8_enabled and not self.has_fp8_handler:
-            # Prioritize AO -> TE -> MSAMP
-            if is_torchao_available():
-                logger.info("Found `torchao` installed, using it for FP8 training.")
+            if self.fp8_backend == FP8BackendType.AO:
                 self.ao_recipe_handler = AORecipeKwargs()
-            elif is_transformer_engine_available():
-                logger.info("Found `transformer-engine` installed, using it for FP8 training.")
+            elif self.fp8_backend == FP8BackendType.TE:
                 self.te_recipe_handler = TERecipeKwargs()
-            elif is_msamp_available():
-                logger.info("Found `msamp` installed, using it for FP8 training.")
+            elif self.fp8_backend == FP8BackendType.MSAMP:
                 self.msamp_recipe_handler = MSAMPRecipeKwargs()
-            else:
-                raise ImportError(
-                    "Tried to train with `fp8` and auto-detect backend, but no FP8-compatible backend was installed. "
-                    "Valid backends are: `torchao`, `transformer-engine`, and `msamp`."
-                )
+            elif self.fp8_backend == FP8BackendType.NO:
+                # Prioritize AO -> TE -> MSAMP
+                if is_torchao_available():
+                    logger.info("Found `torchao` installed, using it for FP8 training.")
+                    self.ao_recipe_handler = AORecipeKwargs()
+                elif is_transformer_engine_available():
+                    logger.info("Found `transformer-engine` installed, using it for FP8 training.")
+                    self.te_recipe_handler = TERecipeKwargs()
+                elif is_msamp_available():
+                    logger.info("Found `msamp` installed, using it for FP8 training.")
+                    self.msamp_recipe_handler = MSAMPRecipeKwargs()
+                else:
+                    raise ImportError(
+                        "Tried to train with `fp8` and auto-detect backend, but no FP8-compatible backend was installed. "
+                        "Valid backends are: `torchao`, `transformer-engine`, and `msamp`."
+                    )
             self.has_fp8_handler = True

         self.delayed_fp8_autocast = False
         if self.has_fp8_handler:
             # We already check if FP8 is available during `self.state`
-            if mixed_precision != "fp8" and (
+            if not self.fp8_enabled and (
                 self.distributed_type not in (DistributedType.FSDP, DistributedType.DEEPSPEED)
             ):
                 raise ValueError("Passing in an FP8 configuration requires setting `mixed_precision='fp8'`.")
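The detection order is unchanged, but it now only runs when no backend was explicitly requested. A standalone sketch of the priority logic, where the boolean flags stand in for accelerate's `is_*_available()` helpers:

```python
from enum import Enum


class FP8BackendType(str, Enum):
    # redefined here so the sketch is self-contained
    NO = "NO"
    TE = "TE"
    MSAMP = "MSAMP"
    AO = "AO"


def autodetect_fp8_backend(ao: bool, te: bool, msamp: bool) -> FP8BackendType:
    # Prioritize AO -> TE -> MSAMP, mirroring the hunk above
    if ao:
        return FP8BackendType.AO
    if te:
        return FP8BackendType.TE
    if msamp:
        return FP8BackendType.MSAMP
    raise ImportError("No FP8-compatible backend was installed.")
```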
@@ -488,7 +498,11 @@ def __init__(
             )

             # TODO: S1ro - this is probably gonna be a problem with other fp8 backends too
-            if self.fp8_backend == "AO" and self.state.fsdp_plugin.cpu_ram_efficient_loading:
+            if (
+                self.fp8_backend == FP8BackendType.AO
+                and self.state.distributed_type == DistributedType.FSDP
+                and self.state.fsdp_plugin.cpu_ram_efficient_loading
+            ):
                 raise ValueError(
                     "torchao with FSDP2 and cpu_ram_efficient_loading is not supported, setting `cpu_ram_efficient_loading` to False will fix the issue and work as intended."
                 )
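The extra `distributed_type == DistributedType.FSDP` clause is what makes this hunk a fix: outside FSDP there is no populated `fsdp_plugin`, so the old one-line check could dereference a missing plugin. A hypothetical minimal reproduction of the failure the short-circuit avoids:

```python
class State:
    distributed_type = "NO"  # not running under FSDP
    fsdp_plugin = None       # only populated when FSDP is active


state = State()

# Old-style check: AttributeError, since fsdp_plugin is None
# state.fsdp_plugin.cpu_ram_efficient_loading

# New-style check: `and` short-circuits before the attribute access
if state.distributed_type == "FSDP" and state.fsdp_plugin.cpu_ram_efficient_loading:
    raise ValueError("torchao with FSDP2 and cpu_ram_efficient_loading is not supported")
```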
@@ -572,7 +586,7 @@ def __init__(
         elif self.fp8_enabled:
             # We always enable `native_amp` for FP8
             self.native_amp = True
-            if self.fp8_backend == "MSAMP":
+            if self.fp8_backend == FP8BackendType.MSAMP:
                 if self.distributed_type == DistributedType.FSDP:
                     raise NotImplementedError(
                         "`accelerate` + `MS-AMP` + `FSDP` is not supported at this time. "
@@ -1419,9 +1433,9 @@ def prepare(self, *args, device_placement=None):
                 "You are using lower version of PyTorch(< 2.7.0) with ipex acceleration on Intel CPU or XPU, Intel has upstreamed most of the optimizations into stock PyTorch from 2.7.0, we enourage you to install the latest stock PyTorch and enjoy the out-of-experience on Intel CPU/XPU."
             )
             args = self._prepare_ipex(*args)
-        if self.fp8_backend == "TE":
+        if self.fp8_backend == FP8BackendType.TE:
             args = self._prepare_te(*args)
-        elif self.fp8_backend == "AO":
+        elif self.fp8_backend == FP8BackendType.AO:
             args = self._prepare_ao(*args)
         if self.distributed_type == DistributedType.DEEPSPEED:
             result = self._prepare_deepspeed(*args)

@@ -1430,7 +1444,7 @@ def prepare(self, *args, device_placement=None):
         elif self.is_fsdp2:
             result = self._prepare_fsdp2(*args)
         else:
-            if self.fp8_backend == "MSAMP":
+            if self.fp8_backend == FP8BackendType.MSAMP:
                 args, device_placement = self._prepare_msamp(*args, device_placement=device_placement)
             result = tuple(
                 self._prepare_one(obj, first_pass=True, device_placement=d) for obj, d in zip(args, device_placement)
@@ -1570,7 +1584,7 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
             model._original_forward = model.forward
             autocast_context = get_mixed_precision_context_manager(self.native_amp, self.autocast_handler)
             # NOTE: MS-AMP adds `__func__` already to `model.forward`, so we should always use `model.forward`
-            if self.fp8_backend == "MSAMP" or not hasattr(model.forward, "__func__"):
+            if self.fp8_backend == FP8BackendType.MSAMP or not hasattr(model.forward, "__func__"):
                 model_forward_func = model.forward
                 model.forward = convert_outputs_to_fp32(autocast_context(model_forward_func))
             else:

@@ -1580,7 +1594,7 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
                 model.forward = MethodType(convert_outputs_to_fp32(model.forward.__func__), model)

         # We prepare TE after, allowing for bf16 autocast to happen first
-        if self.fp8_backend == "TE" and not self.delayed_fp8_autocast:
+        if self.fp8_backend == FP8BackendType.TE and not self.delayed_fp8_autocast:
             model = apply_fp8_autowrap(model, self.te_recipe_handler or self.fp8_recipe_handler)

         if (getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False)) and getattr(

@@ -1806,7 +1820,7 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
         elif self.distributed_type == DistributedType.XLA and self.state.fork_launched:
             model = xmp.MpModelWrapper(model).to(self.device)
         # Now we can apply the FP8 autocast
-        if self.fp8_backend == "TE" and self.delayed_fp8_autocast:
+        if self.fp8_backend == FP8BackendType.TE and self.delayed_fp8_autocast:
             model = apply_fp8_autowrap(model, self.te_recipe_handler or self.fp8_recipe_handler)
         # torch.compile should be called last and only if the model isn't already compiled
         if self.state.dynamo_plugin.backend != DynamoBackend.NO and not is_compiled_module(model):
@@ -1884,7 +1898,7 @@ def _prepare_deepspeed(self, *args):
         import deepspeed

         ds_initialize = deepspeed.initialize
-        if self.fp8_backend == "MSAMP":
+        if self.fp8_backend == FP8BackendType.MSAMP:
             # MS-AMP requires DeepSpeed patches
             from msamp import deepspeed as msamp_deepspeed

@@ -2022,7 +2036,7 @@ def _prepare_deepspeed(self, *args):

         if model is not None:
             # If we are using FP8, we need to apply the autowrap now
-            if self.fp8_backend == "TE":
+            if self.fp8_backend == FP8BackendType.TE:
                 model = apply_fp8_autowrap(model, self.fp8_recipe_handler)
             # if the model is an MOE, set the appropriate MOE layers as leaf Z3 modules
             deepspeed_plugin.set_moe_leaf_modules(model)
@@ -2479,7 +2493,7 @@ def prepare_optimizer(self, optimizer: torch.optim.Optimizer, device_placement=N
             device_placement = self.device_placement
         # NOTE: Special case with MS-AMP we do *not* pass in the scaler explicitly to the `AcceleratedOptimizer`,
         # Their optimizer handles it for us.
-        scaler = None if self.fp8_backend == "MSAMP" else self.scaler
+        scaler = None if self.fp8_backend == FP8BackendType.MSAMP else self.scaler
         optimizer = AcceleratedOptimizer(optimizer, device_placement=device_placement, scaler=scaler)
         self._optimizers.append(optimizer)
         return optimizer

@@ -3668,7 +3682,7 @@ def _get_named_parameters(self, *args, drop_refs=False):

         # we need this bit as `WeightWithDynamic...` returns 0 when `data_ptr()` is called,
         # the underlying pointer is actually hidden in `_tensor` attribute
-        if self.fp8_backend == "AO":
+        if self.fp8_backend == FP8BackendType.AO:
             from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor

             accessor_mapping[WeightWithDynamicFloat8CastTensor] = "_tensor"
@@ -3977,17 +3991,18 @@ def lomo_backward(self, loss: torch.Tensor, learning_rate: float) -> None:
             )

     @property
-    def fp8_backend(self):
+    def fp8_backend(self) -> FP8BackendType:
         "Returns the configured backend for training in FP8"
         if self.has_fp8_handler:
             if self.fp8_recipe_handler is not None:
-                return self.fp8_recipe_handler.backend
+                return FP8BackendType(self.fp8_recipe_handler.backend)
             elif self.ao_recipe_handler is not None:
-                return "AO"
+                return FP8BackendType.AO
             elif self.te_recipe_handler is not None:
-                return "TE"
+                return FP8BackendType.TE
             elif self.msamp_recipe_handler is not None:
-                return "MSAMP"
+                return FP8BackendType.MSAMP
         elif self.state.deepspeed_plugin is not None and self.state.deepspeed_plugin.enable_msamp:
-            return "MSAMP"
-        return None
+            return FP8BackendType.MSAMP
+
+        return FP8BackendType(parse_choice_from_env("ACCELERATE_FP8_BACKEND", "NO"))
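With no handler configured, the property now resolves the backend from the environment instead of returning `None`, which is what lets `accelerate launch`-style configuration work without touching `Accelerator()`. A sketch of the fallback, assuming a build of accelerate that includes this commit:

```python
import os

from accelerate.utils import parse_choice_from_env
from accelerate.utils.dataclasses import FP8BackendType

os.environ["ACCELERATE_FP8_BACKEND"] = "TE"  # as `accelerate launch` would export
print(FP8BackendType(parse_choice_from_env("ACCELERATE_FP8_BACKEND", "NO")))
# FP8BackendType.TE; with the variable unset this falls back to FP8BackendType.NO
```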

src/accelerate/utils/dataclasses.py

Lines changed: 2 additions & 0 deletions
@@ -616,8 +616,10 @@ class FP8BackendType(str, enum.Enum):
     """

     # Subclassing str as well as Enum allows the `FP8BackendType` to be JSON-serializable out of the box.
+    NO = "NO"
     TE = "TE"
     MSAMP = "MSAMP"
+    AO = "AO"


 class ComputeEnvironment(str, enum.Enum):
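Why the enum subclasses `str` (per the comment in the hunk): members serialize to JSON and compare equal to plain strings out of the box, so code that previously compared against `"MSAMP"`-style literals keeps working during the migration. A quick self-contained check:

```python
import enum
import json


class FP8BackendType(str, enum.Enum):
    # redefined here so the sketch is self-contained
    NO = "NO"
    TE = "TE"
    MSAMP = "MSAMP"
    AO = "AO"


print(json.dumps({"backend": FP8BackendType.AO}))  # {"backend": "AO"}
print(FP8BackendType.MSAMP == "MSAMP")             # True
```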

src/accelerate/utils/launch.py

Lines changed: 3 additions & 3 deletions
@@ -89,9 +89,9 @@ def setup_fp8_env(args: argparse.Namespace, current_env: dict[str, str]):
         value = getattr(args, arg)
         if value is not None:
             if arg == "fp8_override_linear_precision":
-                current_env[prefix + "FP8_OVERRIDE_FPROP"] = value[0]
-                current_env[prefix + "FP8_OVERRIDE_DGRAD"] = value[1]
-                current_env[prefix + "FP8_OVERRIDE_WGRAD"] = value[2]
+                current_env[prefix + "FP8_OVERRIDE_FPROP"] = str(value[0])
+                current_env[prefix + "FP8_OVERRIDE_DGRAD"] = str(value[1])
+                current_env[prefix + "FP8_OVERRIDE_WGRAD"] = str(value[2])
             else:
                 current_env[f"{prefix}{arg.upper()}"] = str(getattr(args, arg))
     return current_env
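The `str()` casts are needed because process environments are string-only; assigning a non-string raises. A minimal illustration (the `ACCELERATE_` prefix is assumed to match `prefix` in the function above):

```python
import os

value = [False, False, False]  # the parsed --fp8_override_linear_precision flags

# Without the cast this raises TypeError: str expected, not bool
# os.environ["ACCELERATE_FP8_OVERRIDE_FPROP"] = value[0]

os.environ["ACCELERATE_FP8_OVERRIDE_FPROP"] = str(value[0])  # "False"
print(os.environ["ACCELERATE_FP8_OVERRIDE_FPROP"])
```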
