diff --git a/src/MaxText/layers/decoders.py b/src/MaxText/layers/decoders.py index aba50a8d6..1c95f79d4 100644 --- a/src/MaxText/layers/decoders.py +++ b/src/MaxText/layers/decoders.py @@ -18,6 +18,7 @@ from typing import Any import functools +import warnings import jax import jax.numpy as jnp @@ -279,7 +280,7 @@ def setup(self): config=self.config, mesh=self.mesh, layers=pipeline_stage_module, remat_policy=remat_policy ) - def minimal_policy(self, with_context=False): + def minimal_policy(self, with_context=False, with_quantization=False): """Helper for creating minimal checkpoint policies.""" names = [ "query_proj", @@ -294,6 +295,8 @@ def minimal_policy(self, with_context=False): ] if with_context: names.append("context") + if with_quantization: + names.append("quantization") return jax.checkpoint_policies.save_only_these_names(*names) def get_remat_policy(self): @@ -310,6 +313,14 @@ def get_remat_policy(self): elif cfg.remat_policy == "minimal": # save all except context policy = self.minimal_policy() + elif cfg.remat_policy == "minimal_with_quantization": + if cfg.scan_layers: + warnings.warn('Scan layers can introduce overhead to checkpointed values that in some configurations is slower than not checkpointing at all. If you are using scan layers, benchmark with and without quantization checkpointing in your workflow to see which is faster. Without scan layers, checkpointing quantizations is beneficial for performance.') + policy = self.minimal_policy(with_context=False, with_quantization=True) + elif cfg.remat_policy == "minimal_with_context_and_quantization": + if cfg.scan_layers: + warnings.warn('Scan layers can introduce overhead to checkpointed values that in some configurations is slower than not checkpointing at all. If you are using scan layers, benchmark with and without quantization checkpointing in your workflow to see which is faster. Without scan layers, checkpointing quantizations is beneficial for performance.') + policy = self.minimal_policy(with_context=True, with_quantization=True) elif cfg.remat_policy == "save_dot_with_context_except_mlp": policy = jax.checkpoint_policies.save_only_these_names( "query_proj", diff --git a/src/MaxText/layers/quantizations.py b/src/MaxText/layers/quantizations.py index d0f9353b6..b3832466e 100644 --- a/src/MaxText/layers/quantizations.py +++ b/src/MaxText/layers/quantizations.py @@ -807,7 +807,10 @@ class TEWrapper(transformer_engine.jax.flax.module.TransformerEngineBase): def generate_quantizer_set(self, postfix: str = ""): OVERWRITE_WITH_GRADIENT = "_overwrite_with_gradient" return super().generate_quantizer_set( # pytype: disable=wrong-keyword-args - postfix=postfix, variable_collection=OVERWRITE_WITH_GRADIENT, fp8_recipe=fp8_recipe + postfix=postfix, + variable_collection=OVERWRITE_WITH_GRADIENT, + quantization_checkpoint_name="quantization", + fp8_recipe=fp8_recipe ) @nn.compact