
Commit e02aa18

NuojCheng authored and Google-ML-Automation committed
Copybara import of the project:

-- e43c2e1 by NuojCheng <[email protected]>:

remove size one mesh axis

COPYBARA_INTEGRATE_REVIEW=#2737 from AI-Hypercomputer:chengnuojin-remove-sizeone e43c2e1

PiperOrigin-RevId: 842434394
1 parent 5489efd commit e02aa18

File tree

12 files changed, +152 -126 lines changed


src/MaxText/inference/paged_attention.py

Lines changed: 3 additions & 2 deletions
@@ -30,6 +30,7 @@
 
 from MaxText.inference import page_manager
 from MaxText.inference import paged_attention_kernel_v2
+from MaxText.sharding import logical_to_mesh_axes
 from MaxText.common_types import Array, DType, AxisNames, BATCH, LENGTH, HEAD, D_KV, MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE
 from MaxText.layers.initializers import variable_to_logically_partitioned
 
@@ -322,8 +323,8 @@ def paged_attention_v1_decode(
       page_state: page_manager.PageState,
   ) -> Array:
     """Apply Paged Attention v1 in decode only."""
-    kv_pages_pspec = nn.logical_to_mesh_axes(("paged_kv_heads", None, None, None))
-    q_pspec = nn.logical_to_mesh_axes((None, None, "paged_kv_heads", None))
+    kv_pages_pspec = logical_to_mesh_axes(("paged_kv_heads", None, None, None), self.mesh)
+    q_pspec = logical_to_mesh_axes((None, None, "paged_kv_heads", None), self.mesh)
 
     @functools.partial(
         jax.shard_map,
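
Note: the new logical_to_mesh_axes imported from MaxText.sharding takes the mesh explicitly, unlike nn.logical_to_mesh_axes. Given the commit title ("remove size one mesh axis"), a plausible reading is that the helper prunes mesh axes of size one from the resulting PartitionSpec, so sharding specs never reference trivial axes. The following is a minimal sketch under that assumption, not the actual MaxText implementation; the _sketch name is hypothetical.

# Hypothetical sketch only: one way a mesh-aware mapping could drop size-one mesh axes.
from flax import linen as nn
from jax.sharding import Mesh, PartitionSpec


def logical_to_mesh_axes_sketch(logical_names, mesh: Mesh, rules=None):
  """Map logical axis names to mesh axes, skipping mesh axes of size 1."""
  spec = nn.logical_to_mesh_axes(logical_names, rules=rules)  # PartitionSpec over mesh axis names
  pruned = []
  for entry in spec:
    if entry is None:
      pruned.append(None)
      continue
    axes = entry if isinstance(entry, (tuple, list)) else (entry,)
    kept = tuple(a for a in axes if mesh.shape[a] > 1)  # drop trivial (size-1) mesh axes
    pruned.append(kept if kept else None)
  return PartitionSpec(*pruned)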

src/MaxText/layers/attention_mla.py

Lines changed: 6 additions & 6 deletions
@@ -21,7 +21,6 @@
 from jax.sharding import Mesh, NamedSharding
 import jax.numpy as jnp
 
-from flax import linen as nn
 from flax import nnx
 
 from MaxText.common_types import (
@@ -56,6 +55,7 @@
 from MaxText.inference import page_manager
 from MaxText.inference import paged_attention
 from MaxText.inference.kvcache import KVQuant
+from MaxText.sharding import create_sharding
 from MaxText.layers import nnx_wrappers
 from MaxText.layers.attentions import Attention
 from MaxText.layers.initializers import nd_dense_init, NdInitializer, variable_to_logically_partitioned
@@ -515,8 +515,8 @@ def mla_query_projection(self, inputs_q: Array, inputs_positions: Array, model_m
     else:
       query_logical_name = self.query_axis_names
       wqa_logical_name = (KV_BATCH, LENGTH_NO_EXP, Q_LORA_UP_PROJ)
-    query_sharding = NamedSharding(self.mesh, nn.logical_to_mesh_axes(query_logical_name))
-    wqa_out_sharding = NamedSharding(self.mesh, nn.logical_to_mesh_axes(wqa_logical_name))
+    query_sharding = create_sharding(self.mesh, query_logical_name)
+    wqa_out_sharding = create_sharding(self.mesh, wqa_logical_name)
     # Set softmax scaling.
     self.qk_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim
     self.softmax_scale = self.qk_head_dim**-0.5
@@ -555,7 +555,7 @@ def mla_get_key_value(self, low_rank_main, key_rope, model_mode):
     key_logical_name = self.key_axis_names
     value_logical_name = self.value_axis_names
 
-    wkva_out_sharding = NamedSharding(self.mesh, nn.logical_to_mesh_axes(key_logical_name))
+    wkva_out_sharding = create_sharding(self.mesh, key_logical_name)
     kv_out = self.wkv_b(low_rank_main, out_sharding=wkva_out_sharding)
 
     # Split kv_out into key_nope and value parts.
@@ -664,7 +664,7 @@ def mla_kv_projection(self, inputs: Array, inputs_positions: Array, decoder_segm
       wka_logical_name = (KV_BATCH_NO_EXP, LENGTH, KV_LORA_UP_PROJ)
     else:
       wka_logical_name = (KV_BATCH, LENGTH_NO_EXP, KV_LORA_UP_PROJ)
-    wkva_out_sharding = NamedSharding(self.mesh, nn.logical_to_mesh_axes(wka_logical_name))
+    wkva_out_sharding = create_sharding(self.mesh, wka_logical_name)
     low_rank = self.wkv_a(inputs, out_sharding=wkva_out_sharding)
     low_rank_main, low_rank_rope = jnp.split(low_rank, [self.kv_lora_rank], axis=-1)
     low_rank_main = self.kv_norm(low_rank_main)
@@ -759,7 +759,7 @@ def __call__(
     else:
      out = self._maybe_shard_with_logical(out, self.out_axis_names)
 
-    out_sharding = NamedSharding(self.mesh, nn.logical_to_mesh_axes(out_logical_name))
+    out_sharding = create_sharding(self.mesh, out_logical_name)
     out = self.out_projection(out, out_sharding=out_sharding)
     out = checkpoint_name(out, "out_proj")
     return out, kv_cache
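
Here create_sharding(self.mesh, logical_name) replaces the repeated NamedSharding(self.mesh, nn.logical_to_mesh_axes(...)) two-step at each call site. A minimal sketch of what such a wrapper could look like, assuming it simply pairs the logical-to-mesh mapping with a NamedSharding; the _sketch name is hypothetical, and the real MaxText.sharding.create_sharding presumably also drops size-one mesh axes per the commit title.

# Hypothetical sketch only: folds the old two-step pattern into one call.
from flax import linen as nn
from jax.sharding import Mesh, NamedSharding


def create_sharding_sketch(mesh: Mesh, logical_names, rules=None):
  """Build a NamedSharding for `logical_names` on `mesh` in one call."""
  spec = nn.logical_to_mesh_axes(logical_names, rules=rules)
  return NamedSharding(mesh, spec)

Call sites then shrink to a single line, mirroring query_sharding = create_sharding(self.mesh, query_logical_name) in the hunks above.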

src/MaxText/layers/attention_op.py

Lines changed: 23 additions & 20 deletions
@@ -43,7 +43,7 @@
 
 
 from MaxText import max_utils
-from MaxText.sharding import maybe_shard_with_name
+from MaxText.sharding import maybe_shard_with_name, logical_to_mesh_axes
 from MaxText.common_types import (
     DEFAULT_MASK_VALUE,
     BATCH,
@@ -530,6 +530,9 @@ def maybe_create_nnx(einsum, *args):
       self.AqtEinsum_2 = jnp.einsum
       self.AqtEinsum_3 = jnp.einsum
 
+  def _logical_to_mesh_axes(self, logical_name):
+    return logical_to_mesh_axes(logical_name, mesh=self.mesh, rules=self.config.logical_axis_rules)
+
   def check_attention_inputs(self, query: Array, key: Array | KVTensor, value: Array | KVTensor) -> None:
     """Check attention inputs."""
 
@@ -950,10 +953,10 @@ def gpu_ragged_attention(self, q: Array, k: Array | KVTensor, v: Array | KVTenso
       q_for_gqa = q.squeeze(axis=1)
 
     # Define logical axis names - clearer and avoids repeated calls.
-    b = nn.logical_to_mesh_axes(self.ragged_lengths_names)
-    bsnd = nn.logical_to_mesh_axes(self.cache_logical_axis_names)
-    bnd = nn.logical_to_mesh_axes((CACHE_BATCH, CACHE_HEADS, CACHE_KV))
-    bn = nn.logical_to_mesh_axes((CACHE_BATCH, CACHE_HEADS))
+    b = self._logical_to_mesh_axes(self.ragged_lengths_names)
+    bsnd = self._logical_to_mesh_axes(self.cache_logical_axis_names)
+    bnd = self._logical_to_mesh_axes((CACHE_BATCH, CACHE_HEADS, CACHE_KV))
+    bn = self._logical_to_mesh_axes((CACHE_BATCH, CACHE_HEADS))
 
     @functools.partial(
         jax.shard_map,
@@ -1006,8 +1009,8 @@ def tpu_ragged_attention(
     """Ragged Attention."""
     if isinstance(query, KVTensor):
       raise TypeError("Ragged attention does not currently support quantized tensors.")
-    b = nn.logical_to_mesh_axes(self.ragged_lengths_names)
-    bsnd = nn.logical_to_mesh_axes(self.cache_logical_axis_names)
+    b = self._logical_to_mesh_axes(self.ragged_lengths_names)
+    bsnd = self._logical_to_mesh_axes(self.cache_logical_axis_names)
 
     @functools.partial(
         jax.shard_map,
@@ -1050,23 +1053,23 @@ def tpu_flash_attention(
       value = jnp.transpose(value, axes=(0, 2, 1, 3))
     segment_axis_names_q = None
     segment_axis_names_kv = None
-    sink_axis_names = nn.logical_to_mesh_axes((HEAD,))
+    sink_axis_names = self._logical_to_mesh_axes((HEAD,))
     if decoder_segment_ids is not None:
       if self.config.expert_shard_attention_option == EP_AS_CONTEXT:
-        segment_axis_names_q = nn.logical_to_mesh_axes((BATCH_NO_EXP, Q_LENGTH))
-        segment_axis_names_kv = nn.logical_to_mesh_axes((BATCH_NO_EXP, KV_LENGTH))
+        segment_axis_names_q = self._logical_to_mesh_axes((BATCH_NO_EXP, Q_LENGTH))
+        segment_axis_names_kv = self._logical_to_mesh_axes((BATCH_NO_EXP, KV_LENGTH))
       else:
-        segment_axis_names_q = nn.logical_to_mesh_axes((BATCH, Q_LENGTH_NO_EXP))
-        segment_axis_names_kv = nn.logical_to_mesh_axes((BATCH, KV_LENGTH))
+        segment_axis_names_q = self._logical_to_mesh_axes((BATCH, Q_LENGTH_NO_EXP))
+        segment_axis_names_kv = self._logical_to_mesh_axes((BATCH, KV_LENGTH))
 
     if self.config.expert_shard_attention_option == EP_AS_CONTEXT:
-      axis_names_splash_kernel = nn.logical_to_mesh_axes(self.flash_axis_names_splash_kernel_ep)
-      axis_names_q = nn.logical_to_mesh_axes(self.flash_axis_names_q_ep)
-      axis_names_kv = nn.logical_to_mesh_axes(self.flash_axis_names_kv_ep)
+      axis_names_splash_kernel = self._logical_to_mesh_axes(self.flash_axis_names_splash_kernel_ep)
+      axis_names_q = self._logical_to_mesh_axes(self.flash_axis_names_q_ep)
+      axis_names_kv = self._logical_to_mesh_axes(self.flash_axis_names_kv_ep)
     else:
-      axis_names_splash_kernel = nn.logical_to_mesh_axes(self.flash_axis_names_splash_kernel)
-      axis_names_q = nn.logical_to_mesh_axes(self.flash_axis_names_q)
-      axis_names_kv = nn.logical_to_mesh_axes(self.flash_axis_names_kv)
+      axis_names_splash_kernel = self._logical_to_mesh_axes(self.flash_axis_names_splash_kernel)
+      axis_names_q = self._logical_to_mesh_axes(self.flash_axis_names_q)
+      axis_names_kv = self._logical_to_mesh_axes(self.flash_axis_names_kv)
 
     global global_block_q, global_block_kv, global_block_kv_compute, global_block_q_dkv, global_block_kv_dkv
     global global_block_kv_dkv_compute, global_block_q_dq, global_block_kv_dq, global_use_fused_bwd_kernel
@@ -1197,9 +1200,9 @@ def wrap_splash_kernel(single_head_mask, shard_head_size=1):
       shard_head_size = np.prod(logical_axis_rules_head)
       splash_kernel = wrap_splash_kernel(single_head_mask, int(shard_head_size))
      if self.config.expert_shard_attention_option == EP_AS_CONTEXT:
-        segment_axis_names_splash_kernel = nn.logical_to_mesh_axes((Q_LENGTH,))
+        segment_axis_names_splash_kernel = self._logical_to_mesh_axes((Q_LENGTH,))
       else:
-        segment_axis_names_splash_kernel = nn.logical_to_mesh_axes((Q_LENGTH_NO_EXP,))
+        segment_axis_names_splash_kernel = self._logical_to_mesh_axes((Q_LENGTH_NO_EXP,))
     else:
       # Create multi-head mask
       multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
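
The new _logical_to_mesh_axes method binds the layer's mesh and config.logical_axis_rules once, and the PartitionSpecs it returns (b, bsnd, bnd, bn, and the segment/axis names) feed the shard_map in_specs/out_specs used by these attention kernels. A self-contained toy sketch of that consumption pattern follows; the mesh layout, spec values, and identity body are placeholders, not the MaxText kernels.

# Hypothetical usage sketch: PartitionSpecs as shard_map in_specs/out_specs.
import functools

import jax
import numpy as np
from jax.sharding import Mesh, PartitionSpec

mesh = Mesh(np.array(jax.devices()).reshape(-1), axis_names=("data",))
bsnd = PartitionSpec("data", None, None, None)  # e.g. cache layout [batch, seq, heads, dim]
b = PartitionSpec("data")                       # e.g. per-sequence lengths [batch]


@functools.partial(jax.shard_map, mesh=mesh, in_specs=(bsnd, b), out_specs=bsnd)
def per_shard_identity(kv, lengths):
  # Stand-in for the ragged-attention kernel body: runs once per mesh shard.
  del lengths
  return kv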

src/MaxText/layers/attentions.py

Lines changed: 3 additions & 3 deletions
@@ -23,7 +23,7 @@
 import jax
 import jax.numpy as jnp
 
-from flax import nnx, linen as nn
+from flax import nnx
 
 from MaxText.common_types import (
     DecoderBlockType,
@@ -53,7 +53,7 @@
     EP_AS_CONTEXT,
     AttentionType,
 )
-from MaxText.sharding import maybe_shard_with_logical
+from MaxText.sharding import maybe_shard_with_logical, create_sharding
 from MaxText.inference import kvcache
 from MaxText.inference import page_manager
 from MaxText.inference import paged_attention
@@ -1003,7 +1003,7 @@ def __call__(
 
     inputs_q = self._maybe_shard_with_logical(inputs_q, input_axis_names)
     inputs_kv = self._maybe_shard_with_logical(inputs_kv, input_axis_names)
-    qkv_sharding = NamedSharding(self.mesh, nn.logical_to_mesh_axes(input_axis_names))
+    qkv_sharding = create_sharding(self.mesh, input_axis_names)
 
     # apply projection.
     if self.config.fused_qkv:

src/MaxText/layers/decoders.py

Lines changed: 6 additions & 21 deletions
@@ -22,7 +22,7 @@
 import jax
 import jax.numpy as jnp
 from jax.ad_checkpoint import checkpoint_name
-from jax.sharding import Mesh, NamedSharding
+from jax.sharding import Mesh
 
 from flax import linen as nn
 from flax import nnx
@@ -32,6 +32,7 @@
 from MaxText.common_types import MODEL_MODE_TRAIN, MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE
 from MaxText import max_logging
 from MaxText import max_utils
+from MaxText.sharding import create_sharding
 from MaxText.inference import page_manager
 from MaxText.layers import linears
 from MaxText.layers import quantizations
@@ -607,16 +608,7 @@ def apply_output_head(self, shared_embedding: nn.Module | nnx.Module, y, determi
 
     cfg = self.config
     if cfg.shard_mode == ShardMode.EXPLICIT:
-      norm_out_sharding = NamedSharding(
-          self.mesh,
-          nn.logical_to_mesh_axes(
-              (
-                  "activation_batch",
-                  "activation_length_no_exp",
-                  "activation_embed",
-              )
-          ),
-      )
+      norm_out_sharding = create_sharding(self.mesh, ("activation_batch", "activation_length_no_exp", "activation_embed"))
     else:
       norm_out_sharding = None
 
@@ -631,17 +623,10 @@ def apply_output_head(self, shared_embedding: nn.Module | nnx.Module, y, determi
     y = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(y, deterministic=deterministic)
 
     if model_mode in (MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE):
-      out_sharding = NamedSharding(self.mesh, nn.logical_to_mesh_axes((None, None, "activation_vocab")))
+      out_sharding = create_sharding(self.mesh, (None, None, "activation_vocab"))
     else:
-      out_sharding = NamedSharding(
-          self.mesh,
-          nn.logical_to_mesh_axes(
-              (
-                  "activation_embed_and_logits_batch",
-                  "activation_length_no_exp",
-                  "activation_vocab",
-              )
-          ),
+      out_sharding = create_sharding(
+          self.mesh, ("activation_embed_and_logits_batch", "activation_length_no_exp", "activation_vocab")
       )
 
     # [batch, length, emb_dim] -> [batch, length, vocab_size]

src/MaxText/layers/deepseek.py

Lines changed: 7 additions & 7 deletions
@@ -19,7 +19,7 @@
 from functools import partial
 
 from jax.ad_checkpoint import checkpoint_name
-from jax.sharding import Mesh, NamedSharding
+from jax.sharding import Mesh
 import jax.numpy as jnp
 
 from flax import linen as nn
@@ -32,7 +32,7 @@
 from MaxText.layers import moe
 from MaxText.layers import quantizations
 from MaxText.layers.quantizations import AqtQuantization as Quant
-from MaxText.sharding import maybe_shard_with_logical
+from MaxText.sharding import maybe_shard_with_logical, create_sharding
 from MaxText.inference import page_manager
 from MaxText.common_types import MODEL_MODE_PREFILL
 
@@ -75,7 +75,7 @@ def self_attention_with_norm(
       mesh=mesh,
      shard_mode=cfg.shard_mode,
   )
-  lnx_sharding = NamedSharding(mesh, nn.logical_to_mesh_axes(logical_axis_names))
+  lnx_sharding = create_sharding(mesh, logical_axis_names)
   lnx = _maybe_shard_with_logical(lnx, logical_axis_names)
 
   attention_layer = attention_mla.mla_as_linen(
@@ -189,8 +189,8 @@ def __call__(
       inputs = inputs[0]
 
     _maybe_shard_with_logical = partial(maybe_shard_with_logical, mesh=self.mesh, shard_mode=self.config.shard_mode)
-    lnx_out_sharding = NamedSharding(self.mesh, nn.logical_to_mesh_axes(logical_axis_names))
-    mlp_intermediate_sharding = NamedSharding(self.mesh, nn.logical_to_mesh_axes(mlp_logical_axis_names))
+    lnx_out_sharding = create_sharding(self.mesh, logical_axis_names)
+    mlp_intermediate_sharding = create_sharding(self.mesh, mlp_logical_axis_names)
     inputs = _maybe_shard_with_logical(inputs, logical_axis_names)
     inputs = checkpoint_name(inputs, "decoder_layer_input")
 
@@ -273,8 +273,8 @@ def __call__(
      inputs = inputs[0]
 
     _maybe_shard_with_logical = partial(maybe_shard_with_logical, mesh=self.mesh, shard_mode=self.config.shard_mode)
-    lnx_out_sharding = NamedSharding(self.mesh, nn.logical_to_mesh_axes(logical_axis_names))
-    lnx_intermediate_sharding = NamedSharding(self.mesh, nn.logical_to_mesh_axes(mlp_logical_axis_names))
+    lnx_out_sharding = create_sharding(self.mesh, logical_axis_names)
+    lnx_intermediate_sharding = create_sharding(self.mesh, mlp_logical_axis_names)
     inputs = _maybe_shard_with_logical(inputs, logical_axis_names)
     inputs = checkpoint_name(inputs, "decoder_layer_input")
 

src/MaxText/layers/embeddings.py

Lines changed: 4 additions & 4 deletions
@@ -22,11 +22,11 @@
 import jax.numpy as jnp
 from jax.sharding import Mesh, NamedSharding
 
-from flax import linen as nn
 from flax import nnx
 
 from MaxText import max_logging
 from MaxText import max_utils
+from MaxText.sharding import logical_to_mesh_axes, create_sharding
 from MaxText.common_types import ShardMode, MODEL_MODE_PREFILL, MODEL_MODE_TRAIN, Array, Config, DType
 from MaxText.layers import nnx_wrappers
 from MaxText.layers.initializers import Initializer, default_embed_init, variable_to_logically_partitioned
@@ -169,7 +169,7 @@ def __call__(self, inputs: Array, model_mode: str = MODEL_MODE_TRAIN) -> Array:
             "activation_embed",
         )
     )
-    out_pspec = nn.logical_to_mesh_axes(output_axis_names)
+    out_pspec = logical_to_mesh_axes(output_axis_names, self.mesh)
 
     out_sharding = NamedSharding(self.mesh, out_pspec) if self.config.shard_mode == ShardMode.EXPLICIT else None
 
@@ -751,7 +751,7 @@ def __init__(
     self.attention_scaling = attention_scaling
 
     self.freqs_sharding = (
-        NamedSharding(mesh, nn.logical_to_mesh_axes(("activation_batch", "activation_length_no_exp", "q_heads")))
+        create_sharding(mesh, ("activation_batch", "activation_length_no_exp", "q_heads"))
         if shard_mode == ShardMode.EXPLICIT
         else None
     )
@@ -877,7 +877,7 @@ def __call__(self, inputs: Array, position: None | Array = None) -> Array:
     inputs_complex = first_half + 1j * second_half  # shape: [B, S, N, half_dim]
     # Apply the rotary transformation via complex multiplication.
     rotated_sharding = (
-        NamedSharding(self.mesh, nn.logical_to_mesh_axes(("activation_batch", "activation_length_no_exp", None, None)))
+        create_sharding(self.mesh, ("activation_batch", "activation_length_no_exp", None, None))
         if self.shard_mode == ShardMode.EXPLICIT
         else None
     )

src/MaxText/layers/llama2.py

Lines changed: 5 additions & 7 deletions
@@ -19,15 +19,14 @@
 import functools
 import jax.numpy as jnp
 from jax.ad_checkpoint import checkpoint_name
-from jax.sharding import Mesh, NamedSharding
+from jax.sharding import Mesh
 
-from flax import linen as nn
 from flax import nnx
 
 from MaxText.inference import page_manager
 from MaxText.common_types import Config
 from MaxText import max_utils
-from MaxText.sharding import maybe_shard_with_logical
+from MaxText.sharding import maybe_shard_with_logical, create_sharding
 from MaxText.layers.linears import Dropout, MlpBlock
 from MaxText.layers import initializers
 from MaxText.layers import nnx_wrappers
@@ -157,7 +156,7 @@ def __call__(
       inputs = inputs[0]
     inputs = self._maybe_shard_with_logical(inputs, self.activation_axis_names)
     inputs = checkpoint_name(inputs, "decoder_layer_input")
-    lnx_sharding = NamedSharding(self.mesh, nn.logical_to_mesh_axes(self.activation_axis_names))
+    lnx_sharding = create_sharding(self.mesh, self.activation_axis_names)
     lnx = self.pre_self_attention_layer_norm(inputs, out_sharding=lnx_sharding)
     lnx = self._maybe_shard_with_logical(lnx, self.activation_axis_names)
 
@@ -185,9 +184,8 @@ def __call__(
     hidden_states = self._maybe_shard_with_logical(hidden_states, self.activation_axis_names)
 
     # MLP block.
-    mlp_intermediate_sharding = NamedSharding(
-        self.mesh,
-        nn.logical_to_mesh_axes(("activation_batch", "activation_length_no_exp", "activation_mlp")),
+    mlp_intermediate_sharding = create_sharding(
+        self.mesh, ("activation_batch", "activation_length_no_exp", "activation_mlp")
     )
     mlp_lnx = self.mlp(
         hidden_states,
