
Commit 937b6ea

added torchao params as cli launch params

1 parent ee8011d · commit 937b6ea

File tree: 5 files changed, +42 -7 lines

src/accelerate/accelerator.py

Lines changed: 0 additions & 1 deletion

@@ -2042,7 +2042,6 @@ def _prepare_ao(self, *args):
         if (
             self.is_fsdp2
             and len(optimizers) > 0
-            and self.ao_recipe_handler is not None
             and self.ao_recipe_handler.config.enable_fsdp_float8_all_gather
         ):
             from torchao.float8 import precompute_float8_dynamic_scale_for_fsdp
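
For context, torchao's precompute_float8_dynamic_scale_for_fsdp is meant to run once per training step after the optimizer update, so the next FSDP all-gather can ship weights already in float8. Below is a minimal, hypothetical sketch of that call site, not Accelerate's actual code path; model, optimizer, and dataloader are placeholders, and the model is assumed to be FSDP2-wrapped and converted with torchao's convert_to_float8_training with enable_fsdp_float8_all_gather=True:

from torchao.float8 import precompute_float8_dynamic_scale_for_fsdp

# Hypothetical training loop; `model`, `optimizer`, and `dataloader` are placeholders.
for batch in dataloader:
    loss = model(**batch).loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    # Recompute per-parameter float8 scales so the next all-gather
    # can communicate the weights directly in float8.
    precompute_float8_dynamic_scale_for_fsdp(model)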

src/accelerate/commands/config/cluster.py

Lines changed: 16 additions & 2 deletions

@@ -794,11 +794,11 @@ def get_cluster_input():
     )
     if mixed_precision == "fp8":
         if not is_fp8_available():
-            raise ValueError("FP8 (either Transformer Engine or MSAMP) is not installed on this machine.")
+            raise ValueError("FP8 (either TorchAO, Transformer Engine or MSAMP) is not installed on this machine.")
         fp8_config = {}
         fp8_config["backend"] = _ask_options(
             "Which FP8 backend do you want to use?",
-            ["te", "msamp"],
+            ["ao", "te", "msamp"],
             _convert_fp8_backend,
         )
         if fp8_config["backend"] == "TE":
@@ -870,6 +870,20 @@ def get_cluster_input():
                 lambda x: "O1" if x == 0 else "O2",
                 default=1,
             )
+
+        elif fp8_config["backend"] == "AO":
+            if not is_torch_ao_available():
+                raise ValueError("TorchAO was selected, but it is not installed on this machine.")
+            fp8_config["enable_fsdp_float8_all_gather"] = _ask_field(
+                "Do you want to enable FSDP2 float8 all gather? This is recommended for better performance if using FSDP2. [YES/no]: ",
+                _convert_yes_no_to_bool,
+                default=True,
+            )
+            fp8_config["pad_inner_dim"] = _ask_field(
+                "Do you want to pad the inner dimension of weight matrices to multiples of 16 before float8 matmuls? Required for _scaled_mm which has strict alignment requirements. Note: padding may cause memory spikes. [YES/no]: ",
+                _convert_yes_no_to_bool,
+                default=True,
+            )
 
     if use_dynamo and mixed_precision == "no" and not use_cpu:
         print(
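
For illustration, answering the two new prompts with their defaults leaves get_cluster_input with roughly the following fp8_config for the AO backend (a sketch of the in-memory dict only; how it is serialized into the saved config file is outside this diff):

# Sketch of the prompt results when "ao" is selected and the defaults are accepted.
fp8_config = {
    "backend": "AO",                        # via _ask_options + _convert_fp8_backend
    "enable_fsdp_float8_all_gather": True,  # recommended for better performance with FSDP2
    "pad_inner_dim": True,                  # pad weight inner dims to multiples of 16 for _scaled_mm
}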

src/accelerate/commands/config/config_utils.py

Lines changed: 1 addition & 1 deletion

@@ -104,7 +104,7 @@ def _convert_sagemaker_distributed_mode(value):
 
 def _convert_fp8_backend(value):
     value = int(value)
-    return FP8BackendType(["TE", "MSAMP"][value])
+    return FP8BackendType(["AO", "TE", "MSAMP"][value])
 
 
 def _convert_yes_no_to_bool(value):
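
Since the conversion is positional, adding "AO" at index 0 shifts "TE" and "MSAMP" to indices 1 and 2. A self-contained sketch with a stand-in enum (not an import of accelerate's real FP8BackendType) shows the mapping:

from enum import Enum


# Stand-in mirroring accelerate's FP8BackendType, for illustration only.
class FP8BackendType(Enum):
    AO = "AO"
    TE = "TE"
    MSAMP = "MSAMP"


def _convert_fp8_backend(value):
    # The interactive menu returns the chosen option's index.
    value = int(value)
    return FP8BackendType(["AO", "TE", "MSAMP"][value])


print(_convert_fp8_backend(0))  # FP8BackendType.AO
print(_convert_fp8_backend(1))  # FP8BackendType.TE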

src/accelerate/commands/launch.py

Lines changed: 14 additions & 2 deletions

@@ -667,8 +667,8 @@ def launch_command_parser(subparsers=None):
     fp8_args.add_argument(
         "--fp8_backend",
         type=str,
-        choices=["te", "msamp"],
-        help="Choose a backend to train with FP8 (te: TransformerEngine, msamp: MS-AMP)",
+        choices=["ao", "te", "msamp"],
+        help="Choose a backend to train with FP8 (ao: TorchAO,te: TransformerEngine, msamp: MS-AMP)",
     )
     fp8_args.add_argument(
         "--fp8_use_autocast_during_eval",
@@ -721,6 +721,18 @@
         choices=["O1", "O2"],
         help="What level of 8-bit collective communication should be used with MS-AMP (useful only when `--fp8_backend=msamp` is passed).",
     )
+    fp8_args.add_argument(
+        "--fp8_enable_fsdp_float8_all_gather",
+        default="true",
+        type=str_to_bool,
+        help="Whether to enable FSDP float8 all gather (useful only when `--fp8_backend=ao` is passed).",
+    )
+    fp8_args.add_argument(
+        "--fp8_pad_inner_dim",
+        default="true",
+        type=str_to_bool,
+        help="Whether to pad the inner dimension for FP8 GEMMs (useful only when `--fp8_backend=ao` is passed).",
+    )
 
     # AWS arguments
     aws_args = parser.add_argument_group("AWS Arguments", "Arguments related to AWS.")
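
Both new flags default to true, so they only need to be passed explicitly to turn the behavior off. Assuming a training script named train.py (hypothetical), a launch command using the AO backend might look like:

accelerate launch --mixed_precision fp8 --fp8_backend ao \
    --fp8_pad_inner_dim false --fp8_enable_fsdp_float8_all_gather false train.py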

src/accelerate/utils/dataclasses.py

Lines changed: 11 additions & 1 deletion

@@ -335,7 +335,17 @@ def __post_init__(self):
         if self.config is None:
             from torchao.float8 import Float8LinearConfig
 
-            self.config = Float8LinearConfig(pad_inner_dim=True, enable_fsdp_float8_all_gather=True)
+            env_prefix = "ACCELERATE_FP8_"
+            # Check environment variables for overrides
+            pad_inner_dim = parse_flag_from_env(env_prefix + "PAD_INNER_DIM", default=True)
+            enable_fsdp_float8_all_gather = parse_flag_from_env(
+                env_prefix + "ENABLE_FSDP_FLOAT8_ALL_GATHER", default=True
+            )
+
+            self.config = Float8LinearConfig(
+                pad_inner_dim=pad_inner_dim,
+                enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
+            )
 
 
 @dataclass
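
A minimal sketch of the new override path, assuming torchao is installed: parse_flag_from_env below is a simplified stand-in for accelerate's helper, and the environment variable is set inline only to demonstrate what a launch with --fp8_pad_inner_dim false is expected to surface.

import os

from torchao.float8 import Float8LinearConfig  # real torchao config class


# Simplified stand-in for accelerate.utils.parse_flag_from_env, for illustration only.
def parse_flag_from_env(key, default=False):
    value = os.environ.get(key)
    if value is None:
        return default
    return value.strip().lower() in ("1", "true", "yes", "on")


os.environ["ACCELERATE_FP8_PAD_INNER_DIM"] = "false"  # stands in for the launcher-set env var

config = Float8LinearConfig(
    pad_inner_dim=parse_flag_from_env("ACCELERATE_FP8_PAD_INNER_DIM", default=True),
    enable_fsdp_float8_all_gather=parse_flag_from_env(
        "ACCELERATE_FP8_ENABLE_FSDP_FLOAT8_ALL_GATHER", default=True
    ),
)
print(config.pad_inner_dim, config.enable_fsdp_float8_all_gather)  # False True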
