diff --git a/src/MaxText/common_types.py b/src/MaxText/common_types.py index f26d02cb1..f2d630c7d 100644 --- a/src/MaxText/common_types.py +++ b/src/MaxText/common_types.py @@ -65,6 +65,7 @@ # expert_shard_attention_option EP_AS_CONTEXT = "context" +EP_AS_FSDP = "fsdp" DECODING_ACTIVE_SEQUENCE_INDICATOR = 1 diff --git a/src/MaxText/configs/base.yml b/src/MaxText/configs/base.yml index e949f8156..c0b9c9302 100644 --- a/src/MaxText/configs/base.yml +++ b/src/MaxText/configs/base.yml @@ -967,3 +967,4 @@ partial_rotary_factor: 1.0 # Use tokamax library for gmm kernel implementation use_tokamax_gmm: false use_tokamax_splash: false +use_jax_splash: false diff --git a/src/MaxText/configs/types.py b/src/MaxText/configs/types.py index 9fd1bbf3b..8b6a1856a 100644 --- a/src/MaxText/configs/types.py +++ b/src/MaxText/configs/types.py @@ -451,6 +451,9 @@ class Attention(BaseModel): ragged_block_size: int = Field(256, description="Block size for ragged attention.") enable_padding_causal_mask: bool = Field(True, description="Temporary flag for TE padding.") use_tokamax_splash: bool = Field(False, description="Whether to use tokamax splash attention.") + use_jax_splash: bool = Field( + False, description="Whether to use jax splash attention." + ) class MoBa(BaseModel): diff --git a/src/MaxText/kernels/jax_flash_attention.py b/src/MaxText/kernels/jax_flash_attention.py new file mode 100644 index 000000000..25c7ebd40 --- /dev/null +++ b/src/MaxText/kernels/jax_flash_attention.py @@ -0,0 +1,269 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""JAX implementation of Flash Attention.""" + +from typing import Optional, Tuple, Union + +import jax +import jax.numpy as jnp +from MaxText.kernels import splash_attention_kernel + +SegmentIds = splash_attention_kernel.SegmentIds + + +def flash_attention_block_masked( + q: jnp.ndarray, + k: jnp.ndarray, + v: jnp.ndarray, + segment_ids: SegmentIds | None, + block_kv: int, + block_q: int, + mask: jnp.ndarray, + mask_value: float, + cap: Optional[float] = None, + save_residuals: bool = False, +) -> Union[jnp.ndarray, Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]]: + """Computes masked flash attention using block-sparse masking. + + Args: + q: Query tensor with shape (batch_size, num_kv_heads, + num_q_heads_per_kv_head, q_seq_len, head_dim). + k: Key tensor with shape (batch_size, num_kv_heads, kv_seq_len, head_dim). + v: Value tensor with shape (batch_size, num_kv_heads, kv_seq_len, + v_head_dim). + segment_ids: SegmentIds are a mechanism to ensure that there is no + cross-attention between segments (fraction of a sequence) that have been + concatenated together into a sequence. Each array is a list of ids + (integers). Only tokens with the same id are allowed to attend to each + other. It stores the segment ids of the query and key/value sequences. + block_kv: Block size for the key/value sequence dimension. + block_q: Block size for the query sequence dimension. + mask: The full attention mask. Shape (q_seq_len, kv_seq_len). 
+ mask_value: The value to use for masked-out attention scores. + cap: Optional cap for attention logits. + save_residuals: Whether to save residuals. If True, returns a tuple of + (output, dict=(logsumexp, max_logits)). + + Returns: + A tuple containing: + - The output of the attention computation. + - A dict of (logsumexp, max_logits) + """ + batch_size, num_q_heads, q_seq_len, qk_head_dim_size = q.shape + _, num_kv_heads, kv_seq_len, _ = k.shape + v_head_dim_size = v.shape[-1] + data_type = q.dtype + q_groups = num_q_heads // num_kv_heads + q = q.reshape(( + batch_size, + num_kv_heads, + q_groups, + q_seq_len, + qk_head_dim_size, + )) + + tc = kv_seq_len // block_kv + tr = q_seq_len // block_q + + mask_full = jnp.broadcast_to( + mask[None, :, :], (batch_size, q_seq_len, kv_seq_len) + ) + + if segment_ids is not None: + segment_ids_q = segment_ids.q[:, :, None] + segment_ids_kv = segment_ids.kv[:, None, :] + mask_full = jnp.logical_and( + mask_full, segment_ids_q == segment_ids_kv + ) + mask_blocked = jax.jit(mask_blocker, static_argnums=[1, 2])( + mask_full, block_q, block_kv + ) + + l = jnp.zeros( + (batch_size, num_kv_heads, q_groups, q_seq_len), dtype=jnp.float32 + ) + m = jnp.full( + (batch_size, num_kv_heads, q_groups, q_seq_len), + -jnp.inf, + dtype=jnp.float32, + ) + + output = jnp.zeros( + ( + batch_size, + num_kv_heads, + q_groups, + q_seq_len, + v_head_dim_size, + ), + dtype=data_type, + ) + + def outer_loop_body(j, carried): + output, l, m = carried + k_j_slice = jax.lax.dynamic_slice_in_dim(k, j * block_kv, block_kv, axis=-2) + v_j_slice = jax.lax.dynamic_slice_in_dim(v, j * block_kv, block_kv, axis=-2) + + def inner_loop_body(i, carried_inner): + output, l, m = carried_inner + + # this assumes default mask value, + def _true_fn(output, l, m): + # let's get the slice of Q in N dimension + q_slice = jax.lax.dynamic_slice_in_dim(q, i * block_q, block_q, axis=-2) + output_i_slice = jax.lax.dynamic_slice_in_dim( + output, i * block_q, block_q, axis=-2 + ) + l_i_slice = jax.lax.dynamic_slice_in_dim( + l, i * block_q, block_q, axis=-1 + ) + m_i_slice = jax.lax.dynamic_slice_in_dim( + m, i * block_q, block_q, axis=-1 + ) + s_i_j = jnp.einsum( + "bxhqc,bxkc->bxhqk", + q_slice, + k_j_slice, + preferred_element_type=jnp.float32, + ) + full_mask_i_j_slice = jax.lax.dynamic_slice( + mask_full, + (0, i * block_q, j * block_kv), + (batch_size, block_q, block_kv), + ) + broadcasted_mask = jnp.broadcast_to( + full_mask_i_j_slice[:, None, None, :, :], + (batch_size, num_kv_heads, q_groups, block_q, block_kv), + ) + + s_i_j = jnp.where(broadcasted_mask, s_i_j, mask_value) + if cap is not None: + s_i_j = jnp.tanh(s_i_j / cap) + s_i_j = s_i_j * cap + m_i_j = s_i_j.max(axis=-1) + p_i_j = jnp.exp(s_i_j - m_i_j[..., None]) + l_i_j = p_i_j.sum(axis=-1) + assert m_i_j.shape == m_i_slice.shape + m_i_new = jnp.maximum(m_i_slice, m_i_j) + m_i_difference = jnp.exp(m_i_slice - m_i_new) + m_i_j_difference = jnp.exp(m_i_j - m_i_new) + l_i_new = m_i_difference * l_i_slice + m_i_j_difference * l_i_j + + divider = l_i_new[..., None] + numerator = l_i_slice[..., None] * m_i_difference[ + ..., None + ] * output_i_slice + m_i_j_difference[..., None] * jnp.einsum( + "bxhqk,bxkc->bxhqc", + p_i_j, + v_j_slice, + preferred_element_type=data_type, + ) + + output_i_slice_new = numerator / divider + output = jax.lax.dynamic_update_index_in_dim( + output, output_i_slice_new.astype(data_type), i * block_q, axis=-2 + ) + l = jax.lax.dynamic_update_index_in_dim( + l, l_i_new, i * block_q, axis=-1 + ) + m = 
jax.lax.dynamic_update_index_in_dim( + m, m_i_new, i * block_q, axis=-1 + ) + return output, l, m + + def _false_fn(output, l, m): + """Dummy function.""" + + return output, l, m + + batch_size = mask_blocked.shape[0] + mask_i_j_slice = jax.lax.dynamic_slice( + mask_blocked, (0, i, j), (batch_size, 1, 1) + ) + # The _true_fn should be executed if at least one element in the slice is non-zero, + # meaning at least one batch requires computation for this block. + output, l, m = jax.lax.cond( + jnp.any(jnp.not_equal(mask_i_j_slice, 0)), + _true_fn, + _false_fn, + output, + l, + m, + ) + + return output, l, m + + output, l, m = jax.lax.fori_loop( + 0, tr, inner_loop_body, (output, l, m), unroll=True + ) + + return (output, l, m) + + output, l, m = jax.lax.fori_loop( + 0, tc, outer_loop_body, (output, l, m), unroll=True + ) + + # Reshape the output to drop the size one dimension at index 2, + # which corresponds to `num_q_heads // num_kv_heads` when num_q_heads == num_kv_heads. + output = output.squeeze(axis=2) + if not save_residuals: + return output + + l = l.squeeze(axis=2) + m = m.squeeze(axis=2) + stats = {"logsumexp": m + jnp.log(l), "max_logits": m} + stats = jax.tree.map(jax.lax.stop_gradient, stats) + return output, stats + + +def mask_blocker(mask: jnp.ndarray, block_q: int, block_kv: int) -> jnp.ndarray: + """Creates a blocked mask from a full mask. + + Args: + mask: The full attention mask. + block_q: Block size for the query sequence dimension. + block_kv: Block size for the key/value sequence dimension. + + Returns: + A blocked mask where each element indicates the number of non-zero + elements in the corresponding block of the original mask. + """ + if mask.ndim == 3: + batch_size, q_seq_len, kv_seq_len = mask.shape + has_batch = True + elif mask.ndim == 2: + q_seq_len, kv_seq_len = mask.shape + has_batch = False + else: + raise ValueError(f"mask must have 2 or 3 dimensions, got {mask.ndim}") + + if q_seq_len % block_q != 0: + raise ValueError( + f"q_seq_len {q_seq_len} must be divisible by block_q {block_q}" + ) + if kv_seq_len % block_kv != 0: + raise ValueError( + f"kv_seq_len {kv_seq_len} must be divisible by block_kv {block_kv}" + ) + q_blocks = q_seq_len // block_q + kv_blocks = kv_seq_len // block_kv + + if has_batch: + blocked_mask = mask.reshape( + batch_size, q_blocks, block_q, kv_blocks, block_kv + ) + return jnp.count_nonzero(blocked_mask, axis=(2, 4)).astype(jnp.int32) + else: + blocked_mask = mask.reshape(q_blocks, block_q, kv_blocks, block_kv) + return jnp.count_nonzero(blocked_mask, axis=(1, 3)).astype(jnp.int32) diff --git a/src/MaxText/layers/attention_op.py b/src/MaxText/layers/attention_op.py index c8924e608..298209ada 100644 --- a/src/MaxText/layers/attention_op.py +++ b/src/MaxText/layers/attention_op.py @@ -15,69 +15,59 @@ """Attentions Ops Layers.""" import dataclasses import functools -from typing import Any, Callable, Optional, Tuple from functools import partial import math +from typing import Any, Callable, Optional, Tuple -import numpy as np - +from flax import linen as nn +from flax import nnx +from flax.linen import partitioning import jax from jax import lax from jax.ad_checkpoint import checkpoint_name +from jax.experimental import pallas as pl from jax.experimental.pallas.ops.gpu import attention as gpu_pallas_attention from jax.experimental.pallas.ops.gpu import decode_attention as gpu_pallas_decode_attention -from jax.experimental import pallas as pl -from jax.sharding import Mesh, NamedSharding -import jax.numpy as jnp - from 
jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask - -from tokamax._src.ops.experimental.tpu.splash_attention import splash_attention_kernel as tokamax_splash_kernel -from tokamax._src.ops.experimental.tpu.splash_attention import splash_attention_mask as tokamax_splash_mask - - -from flax import linen as nn -from flax import nnx -from flax.linen import partitioning - - +import jax.numpy as jnp +from jax.sharding import Mesh, NamedSharding from MaxText import max_utils -from MaxText.sharding import maybe_shard_with_name from MaxText.common_types import ( - DEFAULT_MASK_VALUE, + Array, + AttentionType, + AxisIdxes, + AxisNames, BATCH, BATCH_NO_EXP, - HEAD, - KV_LENGTH, - PREFILL_LENGTH, - D_KV, - CACHE_BATCH_PREFILL, - CACHE_SEQUENCE, - AxisNames, CACHE_BATCH, + CACHE_BATCH_PREFILL, CACHE_HEADS, - CACHE_SCALE_BATCH, CACHE_KV, - CACHE_SCALE_SEQUENCE, + CACHE_SCALE_BATCH, CACHE_SCALE_HEADS, CACHE_SCALE_KV, - AxisIdxes, - LENGTH, - LENGTH_NO_EXP, - DType, + CACHE_SCALE_SEQUENCE, + CACHE_SEQUENCE, Config, - Array, - Q_LENGTH, - Q_LENGTH_NO_EXP, - DECODE_LENGTH, DECODE_BATCH, - MODEL_MODE_AUTOREGRESSIVE, + DECODE_LENGTH, DECODING_ACTIVE_SEQUENCE_INDICATOR, - MODEL_MODE_TRAIN, - MODEL_MODE_PREFILL, + DEFAULT_MASK_VALUE, + DType, + D_KV, EP_AS_CONTEXT, - AttentionType, + EP_AS_FSDP, + HEAD, + KV_LENGTH, + LENGTH, + LENGTH_NO_EXP, + MODEL_MODE_AUTOREGRESSIVE, + MODEL_MODE_PREFILL, + MODEL_MODE_TRAIN, + PREFILL_LENGTH, + Q_LENGTH, + Q_LENGTH_NO_EXP, ) from MaxText.inference import page_manager from MaxText.inference.kvcache import KVQuant, KVTensor @@ -86,6 +76,11 @@ from MaxText.layers import nnx_wrappers from MaxText.layers.initializers import variable_to_logically_partitioned from MaxText.layers.quantizations import AqtQuantization as Quant +from MaxText.sharding import maybe_shard_with_name +from MaxText.kernels import jax_flash_attention +import numpy as np +from tokamax._src.ops.experimental.tpu.splash_attention import splash_attention_kernel as tokamax_splash_kernel +from tokamax._src.ops.experimental.tpu.splash_attention import splash_attention_mask as tokamax_splash_mask # pylint: disable=line-too-long, g-doc-args, g-doc-return-or-yield, bad-continuation, g-inconsistent-quotes # pytype: disable=attribute-error @@ -1200,6 +1195,17 @@ def wrap_splash_kernel(single_head_mask, shard_head_size=1): segment_axis_names_splash_kernel = nn.logical_to_mesh_axes((Q_LENGTH,)) else: segment_axis_names_splash_kernel = nn.logical_to_mesh_axes((Q_LENGTH_NO_EXP,)) + elif ( + self.config.use_jax_splash + and self.config.expert_shard_attention_option == EP_AS_FSDP + ): + if self.config.use_max_logit_estimate > 0: + sa_config = dataclasses.replace( + sa_config, max_logit_const=self.config.use_max_logit_estimate + ) + segment_axis_names_splash_kernel = nn.logical_to_mesh_axes(( + Q_LENGTH_NO_EXP, + )) else: # Create multi-head mask multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1]) @@ -1295,6 +1301,18 @@ def wrap_flash_attention( attention_output = jax.vmap(lambda q, k, v, d, s: kernel(q, k, v, d, sinks=s), in_axes=(0, 0, 0, 0, None))( query, key, value, decoder_segment_ids_tuple, sinks ) + elif self.config.use_jax_splash: + materialized_mask = jnp.asarray(mask[:, :]) + attention_output = jax_flash_attention.flash_attention_block_masked( + query, + key, + value, + decoder_segment_ids_tuple, + block_kv=self.config.sa_block_kv, + block_q=self.config.sa_block_q, + 
mask=materialized_mask, + mask_value=DEFAULT_MASK_VALUE, + ) else: attention_output = jax.vmap(splash_kernel, in_axes=(0, 0, 0, 0, None))( query, key, value, decoder_segment_ids_tuple, sinks @@ -1321,7 +1339,7 @@ def _maybe_shard_with_pspec(inputs, pspec: jax.sharding.PartitionSpec | None): value, decoder_segment_ids_q, decoder_segment_ids_kv, - splash_kernel, + None if self.config.use_jax_splash else splash_kernel, cp_size, load_balanced_context_parallel, sinks, diff --git a/tests/attention_test.py b/tests/attention_test.py index 1d7c755d0..020076bf8 100644 --- a/tests/attention_test.py +++ b/tests/attention_test.py @@ -14,6 +14,7 @@ """Tests for Attentions.""" +from maxtext.tests import attention_test_util import itertools import os.path import random @@ -1123,114 +1124,9 @@ def test_forward_serve_vllm(self, mock_sharded_ragged_paged_attention): self.assertEqual(output.shape, (self.global_batch_size, seq_len, self.embed_dim)) -class MLATest(parameterized.TestCase): +class MLATest(attention_test_util.MLATestBase): """Test for the Multi-Headed Latent Attention""" - config_arguments = { - "per_device_batch_size": 1.0, - "run_name": "test", - "enable_checkpointing": False, - "max_target_length": 128, - "max_prefill_predict_length": 16, - "attention_type": AttentionType.MLA.value, - "head_dim": 192, - "q_lora_rank": 10, - "kv_lora_rank": 20, - "qk_nope_head_dim": 128, - "qk_rope_head_dim": 64, - "v_head_dim": 192, - } - - def setUp(self): - """Initializes the configuration for each test""" - super().setUp() - jax.config.update("jax_remove_size_one_mesh_axis_from_type", True) - config = pyconfig.initialize( - [sys.argv[0], os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")], - **self.config_arguments, - ) - self.cfg = config - self.rng = jax.random.PRNGKey(0) - self.nnx_rng = nnx.Rngs(params=0, dropout=jax.random.PRNGKey(42)) - devices_array = maxtext_utils.create_device_mesh(self.cfg) - self.mesh = Mesh(devices_array, self.cfg.mesh_axes) - - def init_mla(self, config_arguments, rope_type): - """Helper function to initialize MLA with different model names.""" - cfg = pyconfig.initialize( - [sys.argv[0], os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")], - **config_arguments, - rope_type=rope_type, - ) - - devices_array = maxtext_utils.create_device_mesh(cfg) - mesh = Mesh(devices_array, cfg.mesh_axes) - - dummy_inputs_q = jnp.ones((cfg.global_batch_size_to_train_on, cfg.max_target_length, cfg.base_emb_dim)) - dummy_inputs_kv = jnp.ones((cfg.global_batch_size_to_train_on, cfg.max_target_length, cfg.base_emb_dim)) - - mla = MLA( - config=cfg, - num_query_heads=cfg.num_query_heads, - num_kv_heads=cfg.num_kv_heads, - head_dim=cfg.head_dim, - inputs_q_shape=dummy_inputs_q.shape, - inputs_kv_shape=dummy_inputs_kv.shape, - max_target_length=cfg.max_target_length, - max_prefill_predict_length=cfg.max_prefill_predict_length, - mesh=mesh, - attention_kernel="dot_product", - dtype=cfg.dtype, - dropout_rate=cfg.dropout_rate, - attention_type=cfg.attention_type, - q_lora_rank=cfg.q_lora_rank, - kv_lora_rank=cfg.kv_lora_rank, - qk_nope_head_dim=cfg.qk_nope_head_dim, - qk_rope_head_dim=cfg.qk_rope_head_dim, - v_head_dim=cfg.v_head_dim, - model_mode=MODEL_MODE_PREFILL, - rngs=self.nnx_rng, - ) - - return cfg, mla - - def get_data(self, cfg, dtype): - """get data""" - lnx = jax.random.normal( - self.rng, - shape=(cfg.global_batch_size_to_train_on, cfg.max_target_length, cfg.base_emb_dim), - dtype=dtype, - ) - - decoder_segment_ids = jax.random.randint(self.rng, (cfg.global_batch_size_to_train_on, 
cfg.max_target_length), 0, 4) - decoder_positions = jax.random.randint( - self.rng, (cfg.global_batch_size_to_train_on, cfg.max_target_length), 0, cfg.max_target_length - ) - - return lnx, decoder_segment_ids, decoder_positions - - def get_structured_data(self, cfg, dtype): - """get structured data""" - lnx = jax.random.normal( - self.rng, - shape=( - cfg.global_batch_size_to_train_on, - cfg.max_target_length, - cfg.base_emb_dim, - ), - dtype=dtype, - ) - - decoder_positions = jnp.stack( - [jnp.arange(cfg.max_target_length, dtype=jnp.int32) for _ in range(cfg.global_batch_size_to_train_on)] - ) - - decoder_segment_ids = ( - jax.numpy.zeros((cfg.global_batch_size_to_train_on, cfg.max_target_length)) + DECODING_ACTIVE_SEQUENCE_INDICATOR - ) - - return lnx, decoder_segment_ids, decoder_positions - @parameterized.named_parameters( {"testcase_name": "RoPE_Yarn_Autoregression", "rope_type": "yarn"}, {"testcase_name": "Default_Autoregression", "rope_type": "default"}, @@ -1463,8 +1359,13 @@ def test_tpu_flash_attention_context_parallel( rngs=self.nnx_rng, ) nnx.update(attention_as_mla_flash_cp, generic_state) - mla_generic_flash_cp_output = _forward_with_context_expert_parallelism( - cfg_cp, mesh_cp, attention_as_mla_flash_cp, lnx, decoder_segment_ids, decoder_positions + mla_generic_flash_cp_output = self.forward_with_context_expert_parallelism( + cfg_cp, + mesh_cp, + attention_as_mla_flash_cp, + lnx, + decoder_segment_ids, + decoder_positions, ) self.assertTrue( @@ -1475,51 +1376,5 @@ def test_tpu_flash_attention_context_parallel( ) -def _forward_with_context_expert_parallelism(cfg_cp, mesh_cp, attention_cp, lnx, decoder_segment_ids, decoder_positions): - """Get logits from attention under context/expert parallelism.""" - # If load balanced cp, shuffle along seq dim for input - # This corresponds to the pre-shuffle step in training - context_parallel_size = cfg_cp.context_parallel_size - if context_parallel_size > 1 and cfg_cp.context_parallel_load_balance: - batch = {"inputs": lnx, "inputs_segmentation": decoder_segment_ids, "inputs_position": decoder_positions} - with mesh_cp: - reordered_batch = maxtext_utils.get_reorder_callable(context_parallel_size, ShardMode.AUTO)(batch) - lnx = reordered_batch["inputs"] - decoder_segment_ids = reordered_batch["inputs_segmentation"] - decoder_positions = reordered_batch["inputs_position"] - # apply attention with sharding - with mesh_cp, nn_partitioning.axis_rules(cfg_cp.logical_axis_rules): - lnx_spec = nn_partitioning.logical_to_mesh_axes( - ("activation_batch_no_exp", "activation_length_no_exp", "activation_embed"), nn_partitioning.get_axis_rules() - ) - pos_spec = nn_partitioning.logical_to_mesh_axes( - ("activation_batch_no_exp", "activation_length_no_exp"), nn_partitioning.get_axis_rules() - ) - lnx_sharding = NamedSharding(mesh_cp, lnx_spec) - pos_sharding = NamedSharding(mesh_cp, pos_spec) - - lnx = jax.device_put(lnx, lnx_sharding) - decoder_segment_ids = jax.device_put(decoder_segment_ids, pos_sharding) - decoder_positions = jax.device_put(decoder_positions, pos_sharding) - - attention_cp_output, _ = attention_cp( - lnx, - lnx, - decoder_segment_ids=decoder_segment_ids, - inputs_positions=decoder_positions, - deterministic=True, - model_mode=MODEL_MODE_TRAIN, - ) - attention_cp_output = attention_cp_output[0] if isinstance(attention_cp_output, tuple) else attention_cp_output - - # If load balanced cp, de-shuffle and gather along seq dim for output - # Note training does not need post-shuffle. 
Since the target seq is also pre-shuffled, the loss remains correct - if context_parallel_size > 1 and cfg_cp.context_parallel_load_balance: - attention_cp_output = max_utils.reorder_sequence( - tensor=attention_cp_output, cp_size=context_parallel_size, seq_dim=1, to_contiguous=True - ) - return attention_cp_output - - if __name__ == "__main__": unittest.main() diff --git a/tests/attention_test_util.py b/tests/attention_test_util.py new file mode 100644 index 000000000..f761e3511 --- /dev/null +++ b/tests/attention_test_util.py @@ -0,0 +1,236 @@ +# Copyright 2023–2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Test util for attention tests.""" + +import os +import sys + +from absl.testing import parameterized +from flax import nnx +from flax.linen import partitioning as nn_partitioning +import jax +import jax.numpy as jnp +from jax.sharding import Mesh, NamedSharding +from MaxText import max_utils +from MaxText import maxtext_utils +from MaxText import pyconfig +from MaxText.common_types import AttentionType, DECODING_ACTIVE_SEQUENCE_INDICATOR, MODEL_MODE_PREFILL, MODEL_MODE_TRAIN, ShardMode +from MaxText.globals import MAXTEXT_PKG_DIR +from MaxText.layers.attention_mla import MLA + + +class MLATestBase(parameterized.TestCase): + """Test base for MLATest.""" + + config_arguments = { + "per_device_batch_size": 1.0, + "run_name": "test", + "enable_checkpointing": False, + "max_target_length": 128, + "max_prefill_predict_length": 16, + "attention_type": AttentionType.MLA.value, + "head_dim": 192, + "q_lora_rank": 10, + "kv_lora_rank": 20, + "qk_nope_head_dim": 128, + "qk_rope_head_dim": 64, + "v_head_dim": 192, + } + + def setUp(self): + """Initializes the configuration for each test""" + super().setUp() + jax.config.update("jax_remove_size_one_mesh_axis_from_type", True) + config = pyconfig.initialize( + [sys.argv[0], os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")], + **self.config_arguments, + ) + self.cfg = config + self.rng = jax.random.PRNGKey(0) + self.nnx_rng = nnx.Rngs(params=0, dropout=jax.random.PRNGKey(42)) + devices_array = maxtext_utils.create_device_mesh(self.cfg) + self.mesh = Mesh(devices_array, self.cfg.mesh_axes) + + def init_mla(self, config_arguments, rope_type): + """Helper function to initialize MLA with different model names.""" + cfg = pyconfig.initialize( + [sys.argv[0], os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")], + **config_arguments, + rope_type=rope_type, + ) + + devices_array = maxtext_utils.create_device_mesh(cfg) + mesh = Mesh(devices_array, cfg.mesh_axes) + + dummy_inputs_q = jnp.ones(( + cfg.global_batch_size_to_train_on, + cfg.max_target_length, + cfg.base_emb_dim, + )) + dummy_inputs_kv = jnp.ones(( + cfg.global_batch_size_to_train_on, + cfg.max_target_length, + cfg.base_emb_dim, + )) + + mla = MLA( + config=cfg, + num_query_heads=cfg.num_query_heads, + num_kv_heads=cfg.num_kv_heads, + head_dim=cfg.head_dim, + inputs_q_shape=dummy_inputs_q.shape, + inputs_kv_shape=dummy_inputs_kv.shape, + 
max_target_length=cfg.max_target_length, + max_prefill_predict_length=cfg.max_prefill_predict_length, + mesh=mesh, + attention_kernel="dot_product", + dtype=cfg.dtype, + dropout_rate=cfg.dropout_rate, + attention_type=cfg.attention_type, + q_lora_rank=cfg.q_lora_rank, + kv_lora_rank=cfg.kv_lora_rank, + qk_nope_head_dim=cfg.qk_nope_head_dim, + qk_rope_head_dim=cfg.qk_rope_head_dim, + v_head_dim=cfg.v_head_dim, + model_mode=MODEL_MODE_PREFILL, + rngs=self.nnx_rng, + ) + + return cfg, mla + + def get_data(self, cfg, dtype): + """get data""" + lnx = jax.random.normal( + self.rng, + shape=( + cfg.global_batch_size_to_train_on, + cfg.max_target_length, + cfg.base_emb_dim, + ), + dtype=dtype, + ) + + decoder_segment_ids = jax.random.randint( + self.rng, + (cfg.global_batch_size_to_train_on, cfg.max_target_length), + 0, + 4, + ) + # decoder_segment_ids = None + decoder_positions = jax.random.randint( + self.rng, + (cfg.global_batch_size_to_train_on, cfg.max_target_length), + 0, + cfg.max_target_length, + ) + + return lnx, decoder_segment_ids, decoder_positions + + def get_structured_data(self, cfg, dtype): + """get structured data""" + lnx = jax.random.normal( + self.rng, + shape=( + cfg.global_batch_size_to_train_on, + cfg.max_target_length, + cfg.base_emb_dim, + ), + dtype=dtype, + ) + + decoder_positions = jnp.stack([ + jnp.arange(cfg.max_target_length, dtype=jnp.int32) + for _ in range(cfg.global_batch_size_to_train_on) + ]) + + decoder_segment_ids = ( + jax.numpy.zeros( + (cfg.global_batch_size_to_train_on, cfg.max_target_length) + ) + + DECODING_ACTIVE_SEQUENCE_INDICATOR + ) + + return lnx, decoder_segment_ids, decoder_positions + + def forward_with_context_expert_parallelism( + self, + cfg_cp, + mesh_cp, + attention_cp, + lnx, + decoder_segment_ids, + decoder_positions, + ): + """Get logits from attention under context/expert parallelism.""" + # If load balanced cp, shuffle along seq dim for input + # This corresponds to the pre-shuffle step in training + context_parallel_size = cfg_cp.context_parallel_size + if context_parallel_size > 1 and cfg_cp.context_parallel_load_balance: + batch = { + "inputs": lnx, + "inputs_segmentation": decoder_segment_ids, + "inputs_position": decoder_positions, + } + with mesh_cp: + reordered_batch = maxtext_utils.get_reorder_callable( + context_parallel_size, ShardMode.AUTO + )(batch) + lnx = reordered_batch["inputs"] + decoder_segment_ids = reordered_batch["inputs_segmentation"] + decoder_positions = reordered_batch["inputs_position"] + # apply attention with sharding + with mesh_cp, nn_partitioning.axis_rules(cfg_cp.logical_axis_rules): + lnx_spec = nn_partitioning.logical_to_mesh_axes( + ( + "activation_batch_no_exp", + "activation_length_no_exp", + "activation_embed", + ), + nn_partitioning.get_axis_rules(), + ) + pos_spec = nn_partitioning.logical_to_mesh_axes( + ("activation_batch_no_exp", "activation_length_no_exp"), + nn_partitioning.get_axis_rules(), + ) + lnx_sharding = NamedSharding(mesh_cp, lnx_spec) + pos_sharding = NamedSharding(mesh_cp, pos_spec) + + lnx = jax.device_put(lnx, lnx_sharding) + decoder_segment_ids = jax.device_put(decoder_segment_ids, pos_sharding) + decoder_positions = jax.device_put(decoder_positions, pos_sharding) + + attention_cp_output, _ = attention_cp( + lnx, + lnx, + decoder_segment_ids=decoder_segment_ids, + inputs_positions=decoder_positions, + deterministic=True, + model_mode=MODEL_MODE_TRAIN, + ) + attention_cp_output = ( + attention_cp_output[0] + if isinstance(attention_cp_output, tuple) + else attention_cp_output + 
) + + # If load balanced cp, de-shuffle and gather along seq dim for output + # Note training does not need post-shuffle. Since the target seq is also pre-shuffled, the loss remains correct + if context_parallel_size > 1 and cfg_cp.context_parallel_load_balance: + attention_cp_output = max_utils.reorder_sequence( + tensor=attention_cp_output, + cp_size=context_parallel_size, + seq_dim=1, + to_contiguous=True, + ) + return attention_cp_output
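A minimal sketch of driving the new kernel directly, for reviewers who want to exercise it outside of attention_op.py. The module path, the flash_attention_block_masked signature, and the mask/mask_value semantics come from this patch; the shapes, dtypes, block sizes, and causal mask below are illustrative assumptions (num_q_heads == num_kv_heads is chosen so the final squeeze in the kernel applies).

import jax
import jax.numpy as jnp

from MaxText.kernels import jax_flash_attention

batch, heads, seq_len, head_dim = 2, 4, 256, 128
block_q = block_kv = 64  # must divide seq_len, mirroring sa_block_q / sa_block_kv in the config

rq, rk, rv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(rq, (batch, heads, seq_len, head_dim), dtype=jnp.bfloat16)
k = jax.random.normal(rk, (batch, heads, seq_len, head_dim), dtype=jnp.bfloat16)
v = jax.random.normal(rv, (batch, heads, seq_len, head_dim), dtype=jnp.bfloat16)

# 2D (q_seq_len, kv_seq_len) mask; the kernel broadcasts it over the batch
# and combines it with the optional segment ids.
causal_mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=jnp.bool_))
# A large negative fill value, in the spirit of the DEFAULT_MASK_VALUE that
# attention_op.py passes in.
mask_value = -0.7 * float(jnp.finfo(jnp.float32).max)

out = jax_flash_attention.flash_attention_block_masked(
    q, k, v,
    segment_ids=None,  # or a SegmentIds(q=..., kv=...) pair for packed sequences
    block_kv=block_kv,
    block_q=block_q,
    mask=causal_mask,
    mask_value=mask_value,
)
print(out.shape)  # (2, 4, 256, 128)

# With save_residuals=True the kernel additionally returns the per-row softmax
# statistics it tracked ("logsumexp" and "max_logits").
out, stats = jax_flash_attention.flash_attention_block_masked(
    q, k, v, None,
    block_kv=block_kv, block_q=block_q,
    mask=causal_mask, mask_value=mask_value,
    save_residuals=True,
)
print(stats["logsumexp"].shape, stats["max_logits"].shape)  # (2, 4, 256) each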
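The jax.lax.cond skip inside the kernel is driven by mask_blocker, which reduces the full mask to per-(block_q, block_kv) non-zero counts so that fully masked tiles (for example, blocks strictly above the diagonal of a causal mask) are never computed. A tiny illustration with hypothetical sizes:

import jax.numpy as jnp

from MaxText.kernels.jax_flash_attention import mask_blocker

causal = jnp.tril(jnp.ones((8, 8), dtype=jnp.bool_))
print(mask_blocker(causal, block_q=4, block_kv=4))
# [[10  0]
#  [16 10]]
# The zero entry marks an all-masked tile; the kernel's inner jax.lax.cond
# skips that block's computation for every batch element.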