Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion src/MaxText/layers/decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from typing import Any
import functools
import warnings

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -279,7 +280,7 @@ def setup(self):
config=self.config, mesh=self.mesh, layers=pipeline_stage_module, remat_policy=remat_policy
)

def minimal_policy(self, with_context=False):
def minimal_policy(self, with_context=False, with_quantization=False):
"""Helper for creating minimal checkpoint policies."""
names = [
"query_proj",
Expand All @@ -294,6 +295,8 @@ def minimal_policy(self, with_context=False):
]
if with_context:
names.append("context")
if with_quantization:
names.append("quantization")
return jax.checkpoint_policies.save_only_these_names(*names)

def get_remat_policy(self):
Expand All @@ -310,6 +313,14 @@ def get_remat_policy(self):
elif cfg.remat_policy == "minimal":
# save all except context
policy = self.minimal_policy()
elif cfg.remat_policy == "minimal_with_quantization":
if cfg.scan_layers:
warnings.warn('Scan layers can introduce overhead to checkpointed values that in some configurations is slower than not checkpointing at all. If you are using scan layers, benchmark with and without quantization checkpointing in your workflow to see which is faster. Without scan layers, checkpointing quantizations is beneficial for performance.')
policy = self.minimal_policy(with_context=False, with_quantization=True)
elif cfg.remat_policy == "minimal_with_context_and_quantization":
if cfg.scan_layers:
warnings.warn('Scan layers can introduce overhead to checkpointed values that in some configurations is slower than not checkpointing at all. If you are using scan layers, benchmark with and without quantization checkpointing in your workflow to see which is faster. Without scan layers, checkpointing quantizations is beneficial for performance.')
policy = self.minimal_policy(with_context=True, with_quantization=True)
elif cfg.remat_policy == "save_dot_with_context_except_mlp":
policy = jax.checkpoint_policies.save_only_these_names(
"query_proj",
Expand Down
5 changes: 4 additions & 1 deletion src/MaxText/layers/quantizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -807,7 +807,10 @@ class TEWrapper(transformer_engine.jax.flax.module.TransformerEngineBase):
def generate_quantizer_set(self, postfix: str = ""):
OVERWRITE_WITH_GRADIENT = "_overwrite_with_gradient"
return super().generate_quantizer_set( # pytype: disable=wrong-keyword-args
postfix=postfix, variable_collection=OVERWRITE_WITH_GRADIENT, fp8_recipe=fp8_recipe
postfix=postfix,
variable_collection=OVERWRITE_WITH_GRADIENT,
quantization_checkpoint_name="quantization",
fp8_recipe=fp8_recipe
)

@nn.compact
Expand Down
Loading