diff --git a/docs/source/concept_guides/low_precision_training.md b/docs/source/concept_guides/low_precision_training.md
index e7527cce758..8a2c960fe80 100644
--- a/docs/source/concept_guides/low_precision_training.md
+++ b/docs/source/concept_guides/low_precision_training.md
@@ -61,6 +61,12 @@ If we notice in the chart mentioned earlier, TE simply casts the computation lay
 
 ## `MS-AMP`
 
+<Tip warning={true}>
+
+MS-AMP is no longer actively maintained and has known compatibility issues with newer CUDA versions (12.x+) and PyTorch builds. We recommend using `TransformersEngine` or `torchao` instead for FP8 training.
+
+</Tip>
+
 MS-AMP takes a different approach to `TransformersEngine` by providing three different optimization levels to convert more operations in FP8 or FP16.
 
 * The base optimization level (`O1`), passes communications of the weights (such as in DDP) in FP8, stores the weights of the model in FP16, and leaves the optimizer states in FP32. The main benefit of this optimization level is that we can reduce the communication bandwidth by essentially half. Additionally, more GPU memory is saved due to 1/2 of everything being cast in FP8, and the weights being cast to FP16. Notably, both the optimizer states remain in FP32.
diff --git a/docs/source/usage_guides/low_precision_training.md b/docs/source/usage_guides/low_precision_training.md
index 2d81dcb8885..b07285fffa8 100644
--- a/docs/source/usage_guides/low_precision_training.md
+++ b/docs/source/usage_guides/low_precision_training.md
@@ -39,7 +39,7 @@ from accelerate import Accelerator
 accelerator = Accelerator(mixed_precision="fp8")
 ```
 
-By default, if `MS-AMP` is available in your environment, Accelerate will automatically utilize it as a backend. To specify it yourself (and customize other parts of the FP8 mixed precision setup), you can utilize one of the `RecipeKwargs` dataclasses such as [`utils.AORecipeKwargs`], [`utils.TERecipeKwargs`], or [`utils.MSAMPRecipeKwargs`]; you can also clarify it in your config `yaml`/during `accelerate launch`:
+To specify a backend (and customize other parts of the FP8 mixed precision setup), you can utilize one of the `RecipeKwargs` dataclasses such as [`utils.AORecipeKwargs`], [`utils.TERecipeKwargs`], or [`utils.MSAMPRecipeKwargs`]; you can also clarify it in your config `yaml`/during `accelerate launch`. We recommend using `TransformersEngine` or `torchao` for new projects:
 
 ```{python}
 from accelerate import Accelerator
@@ -67,6 +67,12 @@ fp8_config:
 
 ## Configuring MS-AMP
 
+<Tip warning={true}>
+
+MS-AMP is no longer actively maintained and has known compatibility issues with newer CUDA versions (12.x+) and PyTorch builds. We recommend using `TransformersEngine` or `torchao` instead for FP8 training.
+
+</Tip>
+
 Of the two, `MS-AMP` is traditionally the easier one to configure as there is only a single argument: the optimization level.
 
 Currently two levels of optimization are supported in the Accelerate integration, `"O1"` and `"O2"` (using the letter 'o', not zero).
diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py
index 5abf98b305a..23bb2138215 100755
--- a/src/accelerate/accelerator.py
+++ b/src/accelerate/accelerator.py
@@ -2625,6 +2625,12 @@ def _prepare_device_mesh(self):
         return self.torch_device_mesh
 
     def _prepare_msamp(self, *args, device_placement):
+        warnings.warn(
+            "MS-AMP is deprecated and will be removed in a future version of Accelerate. "
+            "Please use `'te'` (Transformer Engine) or `'torchao'` as the backend for FP8 "
+            "mixed precision training instead.",
+            FutureWarning,
+        )
         if not is_msamp_available():
             raise ImportError(
                 "MS-AMP was not found on your system. Please ensure that MS-AMP is available "