From e9121fe9088191edef4413f39a7ba2fcc04eba33 Mon Sep 17 00:00:00 2001 From: Farzin Houshmand Date: Thu, 20 Nov 2025 08:00:24 -0800 Subject: [PATCH] This change introduces a pure JAX implementation of flash attention to Maxtext, designed as a drop-in replacement for the existing Pallas kernel. In this cl we set up the stage by integrating it with maxtext in fsdp mode. We have plans for further optimizations to close the gap with pallas using different techniques such as: iteration skipping, must_fuse, and memory space coloring. The new implementation is located in maxtext/src/maxtext/kernels/jax_flash_attention.py and can be enabled with the use_jax_splash config flag. To validate the implementation and compare it against the Tokamax kernel and the baseline dot-product attention, this change also introduces: A new test suite in google_mla_attention_test.py for correctness and performance comparison, particularly for FSDP cases. Refactored common MLA test utilities into attention_test_util.py. PiperOrigin-RevId: 834764107 --- src/MaxText/common_types.py | 1 + src/MaxText/configs/base.yml | 1 + src/MaxText/configs/types.py | 3 + src/MaxText/kernels/jax_flash_attention.py | 269 +++++++++++++++++++++ src/MaxText/layers/attention_op.py | 100 ++++---- tests/attention_test.py | 163 +------------ tests/attention_test_util.py | 236 ++++++++++++++++++ 7 files changed, 578 insertions(+), 195 deletions(-) create mode 100644 src/MaxText/kernels/jax_flash_attention.py create mode 100644 tests/attention_test_util.py diff --git a/src/MaxText/common_types.py b/src/MaxText/common_types.py index f26d02cb1..f2d630c7d 100644 --- a/src/MaxText/common_types.py +++ b/src/MaxText/common_types.py @@ -65,6 +65,7 @@ # expert_shard_attention_option EP_AS_CONTEXT = "context" +EP_AS_FSDP = "fsdp" DECODING_ACTIVE_SEQUENCE_INDICATOR = 1 diff --git a/src/MaxText/configs/base.yml b/src/MaxText/configs/base.yml index e949f8156..c0b9c9302 100644 --- a/src/MaxText/configs/base.yml +++ b/src/MaxText/configs/base.yml @@ -967,3 +967,4 @@ partial_rotary_factor: 1.0 # Use tokamax library for gmm kernel implementation use_tokamax_gmm: false use_tokamax_splash: false +use_jax_splash: false diff --git a/src/MaxText/configs/types.py b/src/MaxText/configs/types.py index 9fd1bbf3b..8b6a1856a 100644 --- a/src/MaxText/configs/types.py +++ b/src/MaxText/configs/types.py @@ -451,6 +451,9 @@ class Attention(BaseModel): ragged_block_size: int = Field(256, description="Block size for ragged attention.") enable_padding_causal_mask: bool = Field(True, description="Temporary flag for TE padding.") use_tokamax_splash: bool = Field(False, description="Whether to use tokamax splash attention.") + use_jax_splash: bool = Field( + False, description="Whether to use jax splash attention." + ) class MoBa(BaseModel): diff --git a/src/MaxText/kernels/jax_flash_attention.py b/src/MaxText/kernels/jax_flash_attention.py new file mode 100644 index 000000000..25c7ebd40 --- /dev/null +++ b/src/MaxText/kernels/jax_flash_attention.py @@ -0,0 +1,269 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""JAX implementation of Flash Attention.""" + +from typing import Optional, Tuple, Union + +import jax +import jax.numpy as jnp +from MaxText.kernels import splash_attention_kernel + +SegmentIds = splash_attention_kernel.SegmentIds + + +def flash_attention_block_masked( + q: jnp.ndarray, + k: jnp.ndarray, + v: jnp.ndarray, + segment_ids: SegmentIds | None, + block_kv: int, + block_q: int, + mask: jnp.ndarray, + mask_value: float, + cap: Optional[float] = None, + save_residuals: bool = False, +) -> Union[jnp.ndarray, Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]]: + """Computes masked flash attention using block-sparse masking. + + Args: + q: Query tensor with shape (batch_size, num_kv_heads, + num_q_heads_per_kv_head, q_seq_len, head_dim). + k: Key tensor with shape (batch_size, num_kv_heads, kv_seq_len, head_dim). + v: Value tensor with shape (batch_size, num_kv_heads, kv_seq_len, + v_head_dim). + segment_ids: SegmentIds are a mechanism to ensure that there is no + cross-attention between segments (fraction of a sequence) that have been + concatenated together into a sequence. Each array is a list of ids + (integers). Only tokens with the same id are allowed to attend to each + other. It stores the segment ids of the query and key/value sequences. + block_kv: Block size for the key/value sequence dimension. + block_q: Block size for the query sequence dimension. + mask: The full attention mask. Shape (q_seq_len, kv_seq_len). + mask_value: The value to use for masked-out attention scores. + cap: Optional cap for attention logits. + save_residuals: Whether to save residuals. If True, returns a tuple of + (output, dict=(logsumexp, max_logits)). + + Returns: + A tuple containing: + - The output of the attention computation. + - A dict of (logsumexp, max_logits) + """ + batch_size, num_q_heads, q_seq_len, qk_head_dim_size = q.shape + _, num_kv_heads, kv_seq_len, _ = k.shape + v_head_dim_size = v.shape[-1] + data_type = q.dtype + q_groups = num_q_heads // num_kv_heads + q = q.reshape(( + batch_size, + num_kv_heads, + q_groups, + q_seq_len, + qk_head_dim_size, + )) + + tc = kv_seq_len // block_kv + tr = q_seq_len // block_q + + mask_full = jnp.broadcast_to( + mask[None, :, :], (batch_size, q_seq_len, kv_seq_len) + ) + + if segment_ids is not None: + segment_ids_q = segment_ids.q[:, :, None] + segment_ids_kv = segment_ids.kv[:, None, :] + mask_full = jnp.logical_and( + mask_full, segment_ids_q == segment_ids_kv + ) + mask_blocked = jax.jit(mask_blocker, static_argnums=[1, 2])( + mask_full, block_q, block_kv + ) + + l = jnp.zeros( + (batch_size, num_kv_heads, q_groups, q_seq_len), dtype=jnp.float32 + ) + m = jnp.full( + (batch_size, num_kv_heads, q_groups, q_seq_len), + -jnp.inf, + dtype=jnp.float32, + ) + + output = jnp.zeros( + ( + batch_size, + num_kv_heads, + q_groups, + q_seq_len, + v_head_dim_size, + ), + dtype=data_type, + ) + + def outer_loop_body(j, carried): + output, l, m = carried + k_j_slice = jax.lax.dynamic_slice_in_dim(k, j * block_kv, block_kv, axis=-2) + v_j_slice = jax.lax.dynamic_slice_in_dim(v, j * block_kv, block_kv, axis=-2) + + def inner_loop_body(i, carried_inner): + output, l, m = carried_inner + + # this assumes default mask value, + def _true_fn(output, l, m): + # let's get the slice of Q in N dimension + q_slice = jax.lax.dynamic_slice_in_dim(q, i * block_q, block_q, axis=-2) + output_i_slice = jax.lax.dynamic_slice_in_dim( + output, i * block_q, block_q, axis=-2 + ) + l_i_slice = jax.lax.dynamic_slice_in_dim( + l, i * block_q, block_q, axis=-1 + ) + m_i_slice = jax.lax.dynamic_slice_in_dim( + m, i * block_q, block_q, axis=-1 + ) + s_i_j = jnp.einsum( + "bxhqc,bxkc->bxhqk", + q_slice, + k_j_slice, + preferred_element_type=jnp.float32, + ) + full_mask_i_j_slice = jax.lax.dynamic_slice( + mask_full, + (0, i * block_q, j * block_kv), + (batch_size, block_q, block_kv), + ) + broadcasted_mask = jnp.broadcast_to( + full_mask_i_j_slice[:, None, None, :, :], + (batch_size, num_kv_heads, q_groups, block_q, block_kv), + ) + + s_i_j = jnp.where(broadcasted_mask, s_i_j, mask_value) + if cap is not None: + s_i_j = jnp.tanh(s_i_j / cap) + s_i_j = s_i_j * cap + m_i_j = s_i_j.max(axis=-1) + p_i_j = jnp.exp(s_i_j - m_i_j[..., None]) + l_i_j = p_i_j.sum(axis=-1) + assert m_i_j.shape == m_i_slice.shape + m_i_new = jnp.maximum(m_i_slice, m_i_j) + m_i_difference = jnp.exp(m_i_slice - m_i_new) + m_i_j_difference = jnp.exp(m_i_j - m_i_new) + l_i_new = m_i_difference * l_i_slice + m_i_j_difference * l_i_j + + divider = l_i_new[..., None] + numerator = l_i_slice[..., None] * m_i_difference[ + ..., None + ] * output_i_slice + m_i_j_difference[..., None] * jnp.einsum( + "bxhqk,bxkc->bxhqc", + p_i_j, + v_j_slice, + preferred_element_type=data_type, + ) + + output_i_slice_new = numerator / divider + output = jax.lax.dynamic_update_index_in_dim( + output, output_i_slice_new.astype(data_type), i * block_q, axis=-2 + ) + l = jax.lax.dynamic_update_index_in_dim( + l, l_i_new, i * block_q, axis=-1 + ) + m = jax.lax.dynamic_update_index_in_dim( + m, m_i_new, i * block_q, axis=-1 + ) + return output, l, m + + def _false_fn(output, l, m): + """Dummy function.""" + + return output, l, m + + batch_size = mask_blocked.shape[0] + mask_i_j_slice = jax.lax.dynamic_slice( + mask_blocked, (0, i, j), (batch_size, 1, 1) + ) + # The _true_fn should be executed if at least one element in the slice is non-zero, + # meaning at least one batch requires computation for this block. + output, l, m = jax.lax.cond( + jnp.any(jnp.not_equal(mask_i_j_slice, 0)), + _true_fn, + _false_fn, + output, + l, + m, + ) + + return output, l, m + + output, l, m = jax.lax.fori_loop( + 0, tr, inner_loop_body, (output, l, m), unroll=True + ) + + return (output, l, m) + + output, l, m = jax.lax.fori_loop( + 0, tc, outer_loop_body, (output, l, m), unroll=True + ) + + # Reshape the output to drop the size one dimension at index 2, + # which corresponds to `num_q_heads // num_kv_heads` when num_q_heads == num_kv_heads. + output = output.squeeze(axis=2) + if not save_residuals: + return output + + l = l.squeeze(axis=2) + m = m.squeeze(axis=2) + stats = {"logsumexp": m + jnp.log(l), "max_logits": m} + stats = jax.tree.map(jax.lax.stop_gradient, stats) + return output, stats + + +def mask_blocker(mask: jnp.ndarray, block_q: int, block_kv: int) -> jnp.ndarray: + """Creates a blocked mask from a full mask. + + Args: + mask: The full attention mask. + block_q: Block size for the query sequence dimension. + block_kv: Block size for the key/value sequence dimension. + + Returns: + A blocked mask where each element indicates the number of non-zero + elements in the corresponding block of the original mask. + """ + if mask.ndim == 3: + batch_size, q_seq_len, kv_seq_len = mask.shape + has_batch = True + elif mask.ndim == 2: + q_seq_len, kv_seq_len = mask.shape + has_batch = False + else: + raise ValueError(f"mask must have 2 or 3 dimensions, got {mask.ndim}") + + if q_seq_len % block_q != 0: + raise ValueError( + f"q_seq_len {q_seq_len} must be divisible by block_q {block_q}" + ) + if kv_seq_len % block_kv != 0: + raise ValueError( + f"kv_seq_len {kv_seq_len} must be divisible by block_kv {block_kv}" + ) + q_blocks = q_seq_len // block_q + kv_blocks = kv_seq_len // block_kv + + if has_batch: + blocked_mask = mask.reshape( + batch_size, q_blocks, block_q, kv_blocks, block_kv + ) + return jnp.count_nonzero(blocked_mask, axis=(2, 4)).astype(jnp.int32) + else: + blocked_mask = mask.reshape(q_blocks, block_q, kv_blocks, block_kv) + return jnp.count_nonzero(blocked_mask, axis=(1, 3)).astype(jnp.int32) diff --git a/src/MaxText/layers/attention_op.py b/src/MaxText/layers/attention_op.py index c8924e608..298209ada 100644 --- a/src/MaxText/layers/attention_op.py +++ b/src/MaxText/layers/attention_op.py @@ -15,69 +15,59 @@ """Attentions Ops Layers.""" import dataclasses import functools -from typing import Any, Callable, Optional, Tuple from functools import partial import math +from typing import Any, Callable, Optional, Tuple -import numpy as np - +from flax import linen as nn +from flax import nnx +from flax.linen import partitioning import jax from jax import lax from jax.ad_checkpoint import checkpoint_name +from jax.experimental import pallas as pl from jax.experimental.pallas.ops.gpu import attention as gpu_pallas_attention from jax.experimental.pallas.ops.gpu import decode_attention as gpu_pallas_decode_attention -from jax.experimental import pallas as pl -from jax.sharding import Mesh, NamedSharding -import jax.numpy as jnp - from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask - -from tokamax._src.ops.experimental.tpu.splash_attention import splash_attention_kernel as tokamax_splash_kernel -from tokamax._src.ops.experimental.tpu.splash_attention import splash_attention_mask as tokamax_splash_mask - - -from flax import linen as nn -from flax import nnx -from flax.linen import partitioning - - +import jax.numpy as jnp +from jax.sharding import Mesh, NamedSharding from MaxText import max_utils -from MaxText.sharding import maybe_shard_with_name from MaxText.common_types import ( - DEFAULT_MASK_VALUE, + Array, + AttentionType, + AxisIdxes, + AxisNames, BATCH, BATCH_NO_EXP, - HEAD, - KV_LENGTH, - PREFILL_LENGTH, - D_KV, - CACHE_BATCH_PREFILL, - CACHE_SEQUENCE, - AxisNames, CACHE_BATCH, + CACHE_BATCH_PREFILL, CACHE_HEADS, - CACHE_SCALE_BATCH, CACHE_KV, - CACHE_SCALE_SEQUENCE, + CACHE_SCALE_BATCH, CACHE_SCALE_HEADS, CACHE_SCALE_KV, - AxisIdxes, - LENGTH, - LENGTH_NO_EXP, - DType, + CACHE_SCALE_SEQUENCE, + CACHE_SEQUENCE, Config, - Array, - Q_LENGTH, - Q_LENGTH_NO_EXP, - DECODE_LENGTH, DECODE_BATCH, - MODEL_MODE_AUTOREGRESSIVE, + DECODE_LENGTH, DECODING_ACTIVE_SEQUENCE_INDICATOR, - MODEL_MODE_TRAIN, - MODEL_MODE_PREFILL, + DEFAULT_MASK_VALUE, + DType, + D_KV, EP_AS_CONTEXT, - AttentionType, + EP_AS_FSDP, + HEAD, + KV_LENGTH, + LENGTH, + LENGTH_NO_EXP, + MODEL_MODE_AUTOREGRESSIVE, + MODEL_MODE_PREFILL, + MODEL_MODE_TRAIN, + PREFILL_LENGTH, + Q_LENGTH, + Q_LENGTH_NO_EXP, ) from MaxText.inference import page_manager from MaxText.inference.kvcache import KVQuant, KVTensor @@ -86,6 +76,11 @@ from MaxText.layers import nnx_wrappers from MaxText.layers.initializers import variable_to_logically_partitioned from MaxText.layers.quantizations import AqtQuantization as Quant +from MaxText.sharding import maybe_shard_with_name +from MaxText.kernels import jax_flash_attention +import numpy as np +from tokamax._src.ops.experimental.tpu.splash_attention import splash_attention_kernel as tokamax_splash_kernel +from tokamax._src.ops.experimental.tpu.splash_attention import splash_attention_mask as tokamax_splash_mask # pylint: disable=line-too-long, g-doc-args, g-doc-return-or-yield, bad-continuation, g-inconsistent-quotes # pytype: disable=attribute-error @@ -1200,6 +1195,17 @@ def wrap_splash_kernel(single_head_mask, shard_head_size=1): segment_axis_names_splash_kernel = nn.logical_to_mesh_axes((Q_LENGTH,)) else: segment_axis_names_splash_kernel = nn.logical_to_mesh_axes((Q_LENGTH_NO_EXP,)) + elif ( + self.config.use_jax_splash + and self.config.expert_shard_attention_option == EP_AS_FSDP + ): + if self.config.use_max_logit_estimate > 0: + sa_config = dataclasses.replace( + sa_config, max_logit_const=self.config.use_max_logit_estimate + ) + segment_axis_names_splash_kernel = nn.logical_to_mesh_axes(( + Q_LENGTH_NO_EXP, + )) else: # Create multi-head mask multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1]) @@ -1295,6 +1301,18 @@ def wrap_flash_attention( attention_output = jax.vmap(lambda q, k, v, d, s: kernel(q, k, v, d, sinks=s), in_axes=(0, 0, 0, 0, None))( query, key, value, decoder_segment_ids_tuple, sinks ) + elif self.config.use_jax_splash: + materialized_mask = jnp.asarray(mask[:, :]) + attention_output = jax_flash_attention.flash_attention_block_masked( + query, + key, + value, + decoder_segment_ids_tuple, + block_kv=self.config.sa_block_kv, + block_q=self.config.sa_block_q, + mask=materialized_mask, + mask_value=DEFAULT_MASK_VALUE, + ) else: attention_output = jax.vmap(splash_kernel, in_axes=(0, 0, 0, 0, None))( query, key, value, decoder_segment_ids_tuple, sinks @@ -1321,7 +1339,7 @@ def _maybe_shard_with_pspec(inputs, pspec: jax.sharding.PartitionSpec | None): value, decoder_segment_ids_q, decoder_segment_ids_kv, - splash_kernel, + None if self.config.use_jax_splash else splash_kernel, cp_size, load_balanced_context_parallel, sinks, diff --git a/tests/attention_test.py b/tests/attention_test.py index 1d7c755d0..020076bf8 100644 --- a/tests/attention_test.py +++ b/tests/attention_test.py @@ -14,6 +14,7 @@ """Tests for Attentions.""" +from maxtext.tests import attention_test_util import itertools import os.path import random @@ -1123,114 +1124,9 @@ def test_forward_serve_vllm(self, mock_sharded_ragged_paged_attention): self.assertEqual(output.shape, (self.global_batch_size, seq_len, self.embed_dim)) -class MLATest(parameterized.TestCase): +class MLATest(attention_test_util.MLATestBase): """Test for the Multi-Headed Latent Attention""" - config_arguments = { - "per_device_batch_size": 1.0, - "run_name": "test", - "enable_checkpointing": False, - "max_target_length": 128, - "max_prefill_predict_length": 16, - "attention_type": AttentionType.MLA.value, - "head_dim": 192, - "q_lora_rank": 10, - "kv_lora_rank": 20, - "qk_nope_head_dim": 128, - "qk_rope_head_dim": 64, - "v_head_dim": 192, - } - - def setUp(self): - """Initializes the configuration for each test""" - super().setUp() - jax.config.update("jax_remove_size_one_mesh_axis_from_type", True) - config = pyconfig.initialize( - [sys.argv[0], os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")], - **self.config_arguments, - ) - self.cfg = config - self.rng = jax.random.PRNGKey(0) - self.nnx_rng = nnx.Rngs(params=0, dropout=jax.random.PRNGKey(42)) - devices_array = maxtext_utils.create_device_mesh(self.cfg) - self.mesh = Mesh(devices_array, self.cfg.mesh_axes) - - def init_mla(self, config_arguments, rope_type): - """Helper function to initialize MLA with different model names.""" - cfg = pyconfig.initialize( - [sys.argv[0], os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")], - **config_arguments, - rope_type=rope_type, - ) - - devices_array = maxtext_utils.create_device_mesh(cfg) - mesh = Mesh(devices_array, cfg.mesh_axes) - - dummy_inputs_q = jnp.ones((cfg.global_batch_size_to_train_on, cfg.max_target_length, cfg.base_emb_dim)) - dummy_inputs_kv = jnp.ones((cfg.global_batch_size_to_train_on, cfg.max_target_length, cfg.base_emb_dim)) - - mla = MLA( - config=cfg, - num_query_heads=cfg.num_query_heads, - num_kv_heads=cfg.num_kv_heads, - head_dim=cfg.head_dim, - inputs_q_shape=dummy_inputs_q.shape, - inputs_kv_shape=dummy_inputs_kv.shape, - max_target_length=cfg.max_target_length, - max_prefill_predict_length=cfg.max_prefill_predict_length, - mesh=mesh, - attention_kernel="dot_product", - dtype=cfg.dtype, - dropout_rate=cfg.dropout_rate, - attention_type=cfg.attention_type, - q_lora_rank=cfg.q_lora_rank, - kv_lora_rank=cfg.kv_lora_rank, - qk_nope_head_dim=cfg.qk_nope_head_dim, - qk_rope_head_dim=cfg.qk_rope_head_dim, - v_head_dim=cfg.v_head_dim, - model_mode=MODEL_MODE_PREFILL, - rngs=self.nnx_rng, - ) - - return cfg, mla - - def get_data(self, cfg, dtype): - """get data""" - lnx = jax.random.normal( - self.rng, - shape=(cfg.global_batch_size_to_train_on, cfg.max_target_length, cfg.base_emb_dim), - dtype=dtype, - ) - - decoder_segment_ids = jax.random.randint(self.rng, (cfg.global_batch_size_to_train_on, cfg.max_target_length), 0, 4) - decoder_positions = jax.random.randint( - self.rng, (cfg.global_batch_size_to_train_on, cfg.max_target_length), 0, cfg.max_target_length - ) - - return lnx, decoder_segment_ids, decoder_positions - - def get_structured_data(self, cfg, dtype): - """get structured data""" - lnx = jax.random.normal( - self.rng, - shape=( - cfg.global_batch_size_to_train_on, - cfg.max_target_length, - cfg.base_emb_dim, - ), - dtype=dtype, - ) - - decoder_positions = jnp.stack( - [jnp.arange(cfg.max_target_length, dtype=jnp.int32) for _ in range(cfg.global_batch_size_to_train_on)] - ) - - decoder_segment_ids = ( - jax.numpy.zeros((cfg.global_batch_size_to_train_on, cfg.max_target_length)) + DECODING_ACTIVE_SEQUENCE_INDICATOR - ) - - return lnx, decoder_segment_ids, decoder_positions - @parameterized.named_parameters( {"testcase_name": "RoPE_Yarn_Autoregression", "rope_type": "yarn"}, {"testcase_name": "Default_Autoregression", "rope_type": "default"}, @@ -1463,8 +1359,13 @@ def test_tpu_flash_attention_context_parallel( rngs=self.nnx_rng, ) nnx.update(attention_as_mla_flash_cp, generic_state) - mla_generic_flash_cp_output = _forward_with_context_expert_parallelism( - cfg_cp, mesh_cp, attention_as_mla_flash_cp, lnx, decoder_segment_ids, decoder_positions + mla_generic_flash_cp_output = self.forward_with_context_expert_parallelism( + cfg_cp, + mesh_cp, + attention_as_mla_flash_cp, + lnx, + decoder_segment_ids, + decoder_positions, ) self.assertTrue( @@ -1475,51 +1376,5 @@ def test_tpu_flash_attention_context_parallel( ) -def _forward_with_context_expert_parallelism(cfg_cp, mesh_cp, attention_cp, lnx, decoder_segment_ids, decoder_positions): - """Get logits from attention under context/expert parallelism.""" - # If load balanced cp, shuffle along seq dim for input - # This corresponds to the pre-shuffle step in training - context_parallel_size = cfg_cp.context_parallel_size - if context_parallel_size > 1 and cfg_cp.context_parallel_load_balance: - batch = {"inputs": lnx, "inputs_segmentation": decoder_segment_ids, "inputs_position": decoder_positions} - with mesh_cp: - reordered_batch = maxtext_utils.get_reorder_callable(context_parallel_size, ShardMode.AUTO)(batch) - lnx = reordered_batch["inputs"] - decoder_segment_ids = reordered_batch["inputs_segmentation"] - decoder_positions = reordered_batch["inputs_position"] - # apply attention with sharding - with mesh_cp, nn_partitioning.axis_rules(cfg_cp.logical_axis_rules): - lnx_spec = nn_partitioning.logical_to_mesh_axes( - ("activation_batch_no_exp", "activation_length_no_exp", "activation_embed"), nn_partitioning.get_axis_rules() - ) - pos_spec = nn_partitioning.logical_to_mesh_axes( - ("activation_batch_no_exp", "activation_length_no_exp"), nn_partitioning.get_axis_rules() - ) - lnx_sharding = NamedSharding(mesh_cp, lnx_spec) - pos_sharding = NamedSharding(mesh_cp, pos_spec) - - lnx = jax.device_put(lnx, lnx_sharding) - decoder_segment_ids = jax.device_put(decoder_segment_ids, pos_sharding) - decoder_positions = jax.device_put(decoder_positions, pos_sharding) - - attention_cp_output, _ = attention_cp( - lnx, - lnx, - decoder_segment_ids=decoder_segment_ids, - inputs_positions=decoder_positions, - deterministic=True, - model_mode=MODEL_MODE_TRAIN, - ) - attention_cp_output = attention_cp_output[0] if isinstance(attention_cp_output, tuple) else attention_cp_output - - # If load balanced cp, de-shuffle and gather along seq dim for output - # Note training does not need post-shuffle. Since the target seq is also pre-shuffled, the loss remains correct - if context_parallel_size > 1 and cfg_cp.context_parallel_load_balance: - attention_cp_output = max_utils.reorder_sequence( - tensor=attention_cp_output, cp_size=context_parallel_size, seq_dim=1, to_contiguous=True - ) - return attention_cp_output - - if __name__ == "__main__": unittest.main() diff --git a/tests/attention_test_util.py b/tests/attention_test_util.py new file mode 100644 index 000000000..f761e3511 --- /dev/null +++ b/tests/attention_test_util.py @@ -0,0 +1,236 @@ +# Copyright 2023–2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Test util for attention tests.""" + +import os +import sys + +from absl.testing import parameterized +from flax import nnx +from flax.linen import partitioning as nn_partitioning +import jax +import jax.numpy as jnp +from jax.sharding import Mesh, NamedSharding +from MaxText import max_utils +from MaxText import maxtext_utils +from MaxText import pyconfig +from MaxText.common_types import AttentionType, DECODING_ACTIVE_SEQUENCE_INDICATOR, MODEL_MODE_PREFILL, MODEL_MODE_TRAIN, ShardMode +from MaxText.globals import MAXTEXT_PKG_DIR +from MaxText.layers.attention_mla import MLA + + +class MLATestBase(parameterized.TestCase): + """Test base for MLATest.""" + + config_arguments = { + "per_device_batch_size": 1.0, + "run_name": "test", + "enable_checkpointing": False, + "max_target_length": 128, + "max_prefill_predict_length": 16, + "attention_type": AttentionType.MLA.value, + "head_dim": 192, + "q_lora_rank": 10, + "kv_lora_rank": 20, + "qk_nope_head_dim": 128, + "qk_rope_head_dim": 64, + "v_head_dim": 192, + } + + def setUp(self): + """Initializes the configuration for each test""" + super().setUp() + jax.config.update("jax_remove_size_one_mesh_axis_from_type", True) + config = pyconfig.initialize( + [sys.argv[0], os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")], + **self.config_arguments, + ) + self.cfg = config + self.rng = jax.random.PRNGKey(0) + self.nnx_rng = nnx.Rngs(params=0, dropout=jax.random.PRNGKey(42)) + devices_array = maxtext_utils.create_device_mesh(self.cfg) + self.mesh = Mesh(devices_array, self.cfg.mesh_axes) + + def init_mla(self, config_arguments, rope_type): + """Helper function to initialize MLA with different model names.""" + cfg = pyconfig.initialize( + [sys.argv[0], os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")], + **config_arguments, + rope_type=rope_type, + ) + + devices_array = maxtext_utils.create_device_mesh(cfg) + mesh = Mesh(devices_array, cfg.mesh_axes) + + dummy_inputs_q = jnp.ones(( + cfg.global_batch_size_to_train_on, + cfg.max_target_length, + cfg.base_emb_dim, + )) + dummy_inputs_kv = jnp.ones(( + cfg.global_batch_size_to_train_on, + cfg.max_target_length, + cfg.base_emb_dim, + )) + + mla = MLA( + config=cfg, + num_query_heads=cfg.num_query_heads, + num_kv_heads=cfg.num_kv_heads, + head_dim=cfg.head_dim, + inputs_q_shape=dummy_inputs_q.shape, + inputs_kv_shape=dummy_inputs_kv.shape, + max_target_length=cfg.max_target_length, + max_prefill_predict_length=cfg.max_prefill_predict_length, + mesh=mesh, + attention_kernel="dot_product", + dtype=cfg.dtype, + dropout_rate=cfg.dropout_rate, + attention_type=cfg.attention_type, + q_lora_rank=cfg.q_lora_rank, + kv_lora_rank=cfg.kv_lora_rank, + qk_nope_head_dim=cfg.qk_nope_head_dim, + qk_rope_head_dim=cfg.qk_rope_head_dim, + v_head_dim=cfg.v_head_dim, + model_mode=MODEL_MODE_PREFILL, + rngs=self.nnx_rng, + ) + + return cfg, mla + + def get_data(self, cfg, dtype): + """get data""" + lnx = jax.random.normal( + self.rng, + shape=( + cfg.global_batch_size_to_train_on, + cfg.max_target_length, + cfg.base_emb_dim, + ), + dtype=dtype, + ) + + decoder_segment_ids = jax.random.randint( + self.rng, + (cfg.global_batch_size_to_train_on, cfg.max_target_length), + 0, + 4, + ) + # decoder_segment_ids = None + decoder_positions = jax.random.randint( + self.rng, + (cfg.global_batch_size_to_train_on, cfg.max_target_length), + 0, + cfg.max_target_length, + ) + + return lnx, decoder_segment_ids, decoder_positions + + def get_structured_data(self, cfg, dtype): + """get structured data""" + lnx = jax.random.normal( + self.rng, + shape=( + cfg.global_batch_size_to_train_on, + cfg.max_target_length, + cfg.base_emb_dim, + ), + dtype=dtype, + ) + + decoder_positions = jnp.stack([ + jnp.arange(cfg.max_target_length, dtype=jnp.int32) + for _ in range(cfg.global_batch_size_to_train_on) + ]) + + decoder_segment_ids = ( + jax.numpy.zeros( + (cfg.global_batch_size_to_train_on, cfg.max_target_length) + ) + + DECODING_ACTIVE_SEQUENCE_INDICATOR + ) + + return lnx, decoder_segment_ids, decoder_positions + + def forward_with_context_expert_parallelism( + self, + cfg_cp, + mesh_cp, + attention_cp, + lnx, + decoder_segment_ids, + decoder_positions, + ): + """Get logits from attention under context/expert parallelism.""" + # If load balanced cp, shuffle along seq dim for input + # This corresponds to the pre-shuffle step in training + context_parallel_size = cfg_cp.context_parallel_size + if context_parallel_size > 1 and cfg_cp.context_parallel_load_balance: + batch = { + "inputs": lnx, + "inputs_segmentation": decoder_segment_ids, + "inputs_position": decoder_positions, + } + with mesh_cp: + reordered_batch = maxtext_utils.get_reorder_callable( + context_parallel_size, ShardMode.AUTO + )(batch) + lnx = reordered_batch["inputs"] + decoder_segment_ids = reordered_batch["inputs_segmentation"] + decoder_positions = reordered_batch["inputs_position"] + # apply attention with sharding + with mesh_cp, nn_partitioning.axis_rules(cfg_cp.logical_axis_rules): + lnx_spec = nn_partitioning.logical_to_mesh_axes( + ( + "activation_batch_no_exp", + "activation_length_no_exp", + "activation_embed", + ), + nn_partitioning.get_axis_rules(), + ) + pos_spec = nn_partitioning.logical_to_mesh_axes( + ("activation_batch_no_exp", "activation_length_no_exp"), + nn_partitioning.get_axis_rules(), + ) + lnx_sharding = NamedSharding(mesh_cp, lnx_spec) + pos_sharding = NamedSharding(mesh_cp, pos_spec) + + lnx = jax.device_put(lnx, lnx_sharding) + decoder_segment_ids = jax.device_put(decoder_segment_ids, pos_sharding) + decoder_positions = jax.device_put(decoder_positions, pos_sharding) + + attention_cp_output, _ = attention_cp( + lnx, + lnx, + decoder_segment_ids=decoder_segment_ids, + inputs_positions=decoder_positions, + deterministic=True, + model_mode=MODEL_MODE_TRAIN, + ) + attention_cp_output = ( + attention_cp_output[0] + if isinstance(attention_cp_output, tuple) + else attention_cp_output + ) + + # If load balanced cp, de-shuffle and gather along seq dim for output + # Note training does not need post-shuffle. Since the target seq is also pre-shuffled, the loss remains correct + if context_parallel_size > 1 and cfg_cp.context_parallel_load_balance: + attention_cp_output = max_utils.reorder_sequence( + tensor=attention_cp_output, + cp_size=context_parallel_size, + seq_dim=1, + to_contiguous=True, + ) + return attention_cp_output