
Commit ff5be4a

Merge pull request #2529 from AI-Hypercomputer:rbierneni-qwen3-next-fullattention
PiperOrigin-RevId: 827996001
2 parents 0337876 + ee4b38a commit ff5be4a

8 files changed (+703, -142 lines)

src/MaxText/configs/base.yml

Lines changed: 2 additions & 0 deletions
@@ -905,6 +905,8 @@ gdn_num_value_heads: 32
 gdn_chunk_size: 64
 # Whether to apply L2 normalization to query and key tensors inside the Gated Delta Rule kernel.
 use_qk_norm_in_gdn: True
+# The ratio of dimension to apply ROPE on
+partial_rotary_factor: 1.0
 
 # Use tokamax library for gmm kernel implementation
 use_tokamax_gmm: false

src/MaxText/configs/models/qwen3-next-80b-a3b.yml

Lines changed: 1 addition & 0 deletions
@@ -45,3 +45,4 @@ gdn_chunk_size: 64
 
 # RoPE Settings
 rope_max_timescale: 10000000
+partial_rotary_factor: 0.25
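For orientation, `partial_rotary_factor` controls what fraction of each attention head's dimensions receives RoPE; base.yml defaults it to 1.0 (full RoPE) and this model config overrides it to 0.25. A minimal sketch of the resulting split; the head_dim value is illustrative and not taken from this commit:

# Illustrative numbers only; head_dim is an assumption, not from this diff.
head_dim = 256
partial_rotary_factor = 0.25   # from qwen3-next-80b-a3b.yml above
rotary_dim = int(head_dim * partial_rotary_factor)
print(rotary_dim)              # 64 dims are rotated; the remaining 192 pass through unchanged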

src/MaxText/layers/attentions.py

Lines changed: 64 additions & 11 deletions
@@ -65,10 +65,11 @@
     LlamaVisionRotaryEmbedding,
     RotaryEmbedding,
     YarnRotaryEmbedding,
+    Qwen3NextRotaryEmbedding,
 )
 from MaxText.layers.initializers import nd_dense_init, NdInitializer, variable_to_logically_partitioned, default_bias_init
 from MaxText.layers.linears import DenseGeneral, canonicalize_tuple, normalize_axes
-from MaxText.layers.normalizations import RMSNorm
+from MaxText.layers.normalizations import RMSNorm, Qwen3NextRMSNorm
 from MaxText.layers.quantizations import AqtQuantization as Quant
 
 # pylint: disable=line-too-long, g-doc-args, g-doc-return-or-yield, bad-continuation, g-inconsistent-quotes
@@ -416,6 +417,8 @@ def __init__(
     self.model_mode = model_mode
     self.rngs = rngs
 
+    self.is_qwen3_next = self.config.decoder_block == DecoderBlockType.QWEN3_NEXT
+
     # Module attribute names must match names previously passed to Linen for checkpointing
     self.KVCache_0 = (
         self.init_kv_caches(inputs_kv_shape=inputs_kv_shape)
@@ -478,6 +481,9 @@ def __init__(
     else:
       self.sinks = None
 
+    self.query_norm = None
+    self.key_norm = None
+
     is_llama4_decoder_block = self.config.decoder_block == DecoderBlockType.LLAMA4
     if self.use_qk_norm and not is_llama4_decoder_block:
      self.query_norm = RMSNorm(
@@ -498,9 +504,21 @@ def __init__(
          kernel_axes=("norm",),
          rngs=self.rngs,
      )
-    else:
-      self.query_norm = None
-      self.key_norm = None
+    elif self.is_qwen3_next:
+      self.query_norm = Qwen3NextRMSNorm(
+          num_features=self.config.head_dim,
+          eps=self.config.normalization_layer_epsilon,
+          dtype=self.config.dtype,
+          weight_dtype=self.config.weight_dtype,
+          rngs=self.rngs,
+      )
+      self.key_norm = Qwen3NextRMSNorm(
+          num_features=self.config.head_dim,
+          eps=self.config.normalization_layer_epsilon,
+          dtype=self.config.dtype,
+          weight_dtype=self.config.weight_dtype,
+          rngs=self.rngs,
+      )
 
     self._maybe_shard_with_logical = functools.partial(
         maybe_shard_with_logical,
@@ -538,9 +556,15 @@ def query_init(*args):
     kernel_axes = (
         (None, None, None) if self.config.ici_context_autoregressive_parallelism > 1 else ("embed", "q_heads", "kv")
     )
+    in_features = self.convert_dense_general_inputs_shape(inputs_q_shape)
+    out_features = (self.num_query_heads, self.head_dim)
+
+    if self.is_qwen3_next:
+      out_features = (self.num_query_heads, self.head_dim * 2)
+
     return DenseGeneral(
-        in_features_shape=self.convert_dense_general_inputs_shape(inputs_q_shape),
-        out_features_shape=(self.num_query_heads, self.head_dim),
+        in_features_shape=in_features,
+        out_features_shape=out_features,
         axis=-1,
         kernel_init=query_init,
         kernel_axes=kernel_axes,
@@ -642,13 +666,22 @@ def qkv_projection(self, inputs: Array, proj_name: str, out_sharding: NamedShard
 
   def init_out_w(self, output_dim: int) -> nnx.Module:
     """out projection"""
+    in_features = (self.num_query_heads, self.head_dim)
+    out_features = output_dim
     out_kernel_axis = (
         (None, None, None) if self.config.ici_context_autoregressive_parallelism > 1 else ("heads", "kv", "embed")
     )
+    axis = (-2, -1)
+
+    if self.is_qwen3_next:
+      in_features = self.num_query_heads * self.head_dim
+      out_kernel_axis = ("mlp", "embed")
+      axis = (-1,)
+
     return DenseGeneral(
-        in_features_shape=(self.num_query_heads, self.head_dim),
-        out_features_shape=output_dim,
-        axis=(-2, -1),
+        in_features_shape=in_features,
+        out_features_shape=out_features,
+        axis=axis,
         kernel_init=self.kernel_init,
         kernel_axes=out_kernel_axis,  # trade speed with memory
         dtype=self.dtype,
@@ -720,6 +753,16 @@ def init_rotary_embedding(self):
          attention_scaling=self.config.rope_attention_scaling,
          rngs=self.rngs,
      )
+    elif self.is_qwen3_next:
+      rotary_embedding = Qwen3NextRotaryEmbedding(
+          min_timescale=self.config.rope_min_timescale,
+          max_timescale=self.config.rope_max_timescale,
+          embedding_dims=self.config.head_dim,
+          partial_rotary_factor=self.config.partial_rotary_factor,
+          cast_as_fprop_dtype=True,
+          fprop_dtype=self.config.dtype,
+          rngs=self.rngs,
+      )
     else:
       max_timescale = self.config.rope_max_timescale
       # For local attention use local_rope_max_timescale if it's is positive
@@ -890,9 +933,17 @@ def __call__(
      value_sharding = NamedSharding(self.mesh, nn.logical_to_mesh_axes(self.value_axis_names))
      value = self.kv_projection(inputs_kv, proj_name="value", out_sharding=value_sharding)
 
+    gate = None
+    if self.is_qwen3_next:
+      # Split query into query & gate.
+      query, gate = jnp.split(query, 2, axis=-1)
+      batch_size, seq_len, _, _ = gate.shape
+      gate = gate.reshape(batch_size, seq_len, self.config.num_query_heads * self.config.head_dim)
+
     is_llama4_decoder_block = self.config.decoder_block == DecoderBlockType.LLAMA4
     # NOTE: llama 4 does L2 normalization after RoPE
-    if self.use_qk_norm and not is_llama4_decoder_block:
+    # Apply Qwen3Next specific RMS Norm
+    if (self.use_qk_norm and not is_llama4_decoder_block) or self.is_qwen3_next:
      query = self.query_norm(query)
      key = self.key_norm(key)
 
@@ -964,7 +1015,9 @@ def __call__(
        bidirectional_mask,
        self.sinks,
    )
-
+    if self.is_qwen3_next:
+      out = out.reshape(batch_size, seq_len, self.config.num_query_heads * self.config.head_dim)
+      out = out * jax.nn.sigmoid(gate)
    if model_mode == MODEL_MODE_PREFILL:
      out = self._maybe_shard_with_logical(out, self.prefill_out_axis_names)
    elif model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT:
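
Taken together, these attention changes give Qwen3-Next a gated full-attention path: the query projection emits 2x head_dim per head, the second half is flattened into a per-feature gate, and the attention output is multiplied by sigmoid(gate) before the (now flattened) output projection. A self-contained sketch of that gating path, with a placeholder attn_fn standing in for MaxText's attention op (names and shapes here are illustrative, not the MaxText API):

import jax
import jax.numpy as jnp

def gated_attention_output(query_and_gate, attn_fn, batch, seq, num_heads, head_dim):
  # query_and_gate: [B, S, H, 2*D] -- the doubled query projection.
  query, gate = jnp.split(query_and_gate, 2, axis=-1)      # each [B, S, H, D]
  gate = gate.reshape(batch, seq, num_heads * head_dim)    # flatten heads for the gate
  out = attn_fn(query)                                     # attention output, [B, S, H, D]
  out = out.reshape(batch, seq, num_heads * head_dim)      # flatten heads to match the gate
  return out * jax.nn.sigmoid(gate)                        # elementwise output gating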

src/MaxText/layers/embeddings.py

Lines changed: 91 additions & 0 deletions
@@ -380,6 +380,97 @@ def llama_rotary_embedding_as_linen(
   )
 
 
+def qwen3_next_rotary_embedding_as_linen(
+    *,
+    min_timescale: int,
+    max_timescale: int,
+    embedding_dims: int = 0,
+    partial_rotary_factor: float = 0.25,
+    cast_as_fprop_dtype: bool = True,
+    fprop_dtype: DType = jnp.bfloat16,
+    name: str | None = None,
+):
+  """Initializes the Qwen3NextRotaryEmbedding module and returns it as a Linen module.
+
+  Args:
+    min_timescale: Start of the geometric index. Determines the periodicity of
+      the added signal.
+    max_timescale: End of the geometric index. Determines the frequency of the
+      added signal.
+    embedding_dims: Dimension of the embedding to be generated.
+    partial_rotary_factor: Ratio of dimensions to apply ROPE to.
+    cast_as_fprop_dtype: Whether to cast the output to the fprop dtype.
+    fprop_dtype: The dtype of the output.
+    name: Name of the Linen module.
+  """
+  return nnx_wrappers.to_linen(
+      Qwen3NextRotaryEmbedding,
+      min_timescale=min_timescale,
+      max_timescale=max_timescale,
+      embedding_dims=embedding_dims,
+      partial_rotary_factor=partial_rotary_factor,
+      cast_as_fprop_dtype=cast_as_fprop_dtype,
+      fprop_dtype=fprop_dtype,
+      metadata_fn=variable_to_logically_partitioned,
+      name=name,
+  )
+
+
+class Qwen3NextRotaryEmbedding(RotaryEmbedding):
+  """Qwen3 Next variant of ROPE (partial ROPE)."""
+
+  def __init__(
+      self,
+      min_timescale: int,
+      max_timescale: int,
+      embedding_dims: int = 0,
+      cast_as_fprop_dtype: bool = True,
+      fprop_dtype: DType = jnp.bfloat16,
+      partial_rotary_factor: float = 0.25,
+      rngs: nnx.Rngs = None,
+  ):
+    """Initializes the Qwen3NextRotaryEmbedding module.
+
+    Args:
+      min_timescale: Start of the geometric index. Determines the periodicity of
+        the added signal.
+      max_timescale: End of the geometric index. Determines the frequency of the
+        added signal.
+      embedding_dims: Dimension of the embedding to be generated.
+      partial_rotary_factor: Ratio of dimensions to apply ROPE to.
+      rngs: rng keys passed in by nnx.bridge.to_linen.
+    """
+    self.head_dim = embedding_dims
+    self.partial_rotary_factor = partial_rotary_factor
+    self.rotary_dim = int(self.head_dim * self.partial_rotary_factor)
+
+    super().__init__(
+        min_timescale=min_timescale,
+        max_timescale=max_timescale,
+        embedding_dims=self.rotary_dim,
+        cast_as_fprop_dtype=cast_as_fprop_dtype,
+        fprop_dtype=fprop_dtype,
+        rngs=rngs,
+    )
+
+  def __call__(self, inputs: jax.Array, position: None | jax.Array = None) -> jax.Array:
+    """Applies the Qwen3-Next (partial) variant of rotary position embedding.
+
+    Args:
+      inputs: The input sequence on which to apply the rotary position
+        embedding. It is assumed of shape [B, S, H, D].
+      position: Optional position array [B, S]. Only needed when the sequence
+        is packed.
+
+    Returns:
+      A jax.Array of shape [B, S, H, D] with rotary position embeddings applied to the first rotary_dim features.
+    """
+    inputs_rot, inputs_pass = jnp.split(inputs, [self.rotary_dim], axis=-1)
+    inputs_rot = super().__call__(inputs_rot, position)
+    inputs = jnp.concatenate([inputs_rot, inputs_pass], axis=-1)
+    return inputs
+
+
 class LLaMARotaryEmbedding(RotaryEmbedding):
   """LLaMA variant of ROPE."""
 
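
The new class keeps the rotary math in the parent RotaryEmbedding and only restricts it to the first rotary_dim features of each head. A standalone sketch of that split-rotate-concatenate pattern, with a stand-in apply_rope function in place of the parent class:

import jax.numpy as jnp

def partial_rope(x, rotary_dim, apply_rope):
  # x: [B, S, H, D]; only the first rotary_dim features receive RoPE.
  x_rot, x_pass = jnp.split(x, [rotary_dim], axis=-1)
  x_rot = apply_rope(x_rot)                         # standard RoPE on the rotated slice
  return jnp.concatenate([x_rot, x_pass], axis=-1)  # output shape is unchanged: [B, S, H, D]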

src/MaxText/layers/normalizations.py

Lines changed: 80 additions & 2 deletions
@@ -18,6 +18,7 @@
 
 from flax import linen as nn
 from flax import nnx
+from flax.linen import initializers as linen_initializers
 from jax import lax
 import jax
 import jax.numpy as jnp
@@ -26,7 +27,7 @@
 from MaxText import max_utils
 from MaxText.layers import nnx_wrappers
 from MaxText.layers.initializers import Initializer, variable_to_logically_partitioned
-from MaxText.common_types import Array, ShardMode
+from MaxText.common_types import Array, DType, ShardMode
 
 
 class RMSNorm(nnx.Module):
@@ -42,6 +43,7 @@ def __init__(
       kernel_axes: tuple[None | str, ...] = (),
       scale_init: Initializer = nn.initializers.ones,
       parameter_memory_host_offload: bool = False,
+      scale_offset: float = 0.0,
       *,
       rngs: nnx.Rngs,
   ):
@@ -53,6 +55,7 @@ def __init__(
     self.kernel_axes = kernel_axes
     self.scale_init = scale_init
     self.parameter_memory_host_offload = parameter_memory_host_offload
+    self.scale_offset = scale_offset
     self.scale = nnx.Param(
         scale_init(rngs.params(), (num_features,), weight_dtype),
         sharding=kernel_axes,
@@ -73,8 +76,83 @@ def __call__(self, x: jnp.ndarray, out_sharding: NamedSharding | None = None) ->
       out_sharding = None
 
     scale = jnp.asarray(scale, self.dtype)
+    effective_scale = scale + self.scale_offset  # Apply offset
     # broadcast 2nd input then element-wise mul
-    return jnp.einsum("i...k,...k->i...k", y, scale, out_sharding=out_sharding)
+    return jnp.einsum("i...k,...k->i...k", y, effective_scale, out_sharding=out_sharding)
+
+
+def Qwen3NextRMSNorm(num_features: int, eps: float, dtype: DType, weight_dtype: DType, *, rngs: nnx.Rngs):
+  """
+  Used for input and post attention layernorms
+  in Qwen3NextDecoderLayer.
+
+  This normalization layer is specific to Qwen3-Next. Key characteristics:
+  1. The learnable scale parameter `scale` is initialized to ZEROS.
+  2. The scale is applied as `(1.0 + self.scale)`, making the initial scale effectively 1.0.
+  This matches the PyTorch implementation of Qwen3NextRMSNorm.
+  """
+  return nnx.data(
+      RMSNorm(
+          num_features=num_features,
+          epsilon=eps,
+          dtype=dtype,
+          weight_dtype=weight_dtype,
+          scale_init=linen_initializers.zeros,
+          scale_offset=1.0,
+          rngs=rngs,
+      )
+  )
+
+
+class Qwen3NextRMSNormGated(nnx.Module):
+  """
+  This applies RMS Normalization and then a gated activation function (SiLU).
+  This is used within the Qwen3NextGatedDeltaNet.
+
+  The normalization is performed by an internal `RMSNorm` instance (`self.rms_norm`),
+  which has its own learnable `scale` parameter, initialized to ONES.
+
+  Attributes:
+    num_features: The number of features in the input.
+    eps: A small epsilon value to prevent division by zero in RMSNorm.
+    dtype: The datatype of the computation.
+    weight_dtype: The datatype of the internal RMSNorm scale.
+  """
+
+  def __init__(self, num_features: int, eps: float, dtype: DType, weight_dtype: DType, *, rngs: nnx.Rngs):
+    self.num_features = num_features
+    self.eps = eps
+    self.dtype = dtype
+    self.weight_dtype = weight_dtype
+    self.rms_norm = nnx.data(
+        RMSNorm(
+            num_features=num_features,
+            epsilon=eps,
+            dtype=dtype,
+            weight_dtype=weight_dtype,
+            scale_init=nnx.initializers.ones,
+            rngs=rngs,
+        )
+    )
+
+  def __call__(self, hidden_states: Array, gate: Array) -> Array:
+    """
+    Applies RMSNorm and then a SiLU gate.
+
+    Args:
+      hidden_states: The input array to be normalized (o). Shape: (..., F)
+      gate: The gating array for the activation (z). Shape: (..., F),
+        where F is num_features.
+
+    Returns:
+      The normalized and gated output array. Shape: (..., F)
+    """
+    normalized_states = self.rms_norm(hidden_states)
+
+    # Gated Activation using SiLU (Sigmoid-weighted Linear Unit)
+    gated_states = normalized_states * jax.nn.silu(gate.astype(jnp.float32))
+
+    return gated_states.astype(self.dtype)
 
 
 def rms_norm(
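
The two Qwen3-Next norms above differ only in how the learned scale enters: Qwen3NextRMSNorm zero-initializes `scale` and applies it as `scale + 1.0` (via the new `scale_offset`), while Qwen3NextRMSNormGated keeps a ones-initialized scale and gates the normalized activations with SiLU. A simplified sketch of both formulas (not the MaxText RMSNorm internals, which also handle sharding and host offload):

import jax
import jax.numpy as jnp

def simple_rms_norm(x, scale, eps=1e-6, scale_offset=0.0):
  # RMS-normalize the last axis, then apply the (optionally offset) learned scale.
  rms = jnp.sqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + eps)
  return (x / rms) * (scale + scale_offset)

x = jnp.ones((2, 8))

# Qwen3NextRMSNorm: zeros-initialized scale plus offset 1.0 -> effective scale starts at 1.0.
y = simple_rms_norm(x, scale=jnp.zeros(8), scale_offset=1.0)

# Qwen3NextRMSNormGated: ones-initialized scale, followed by a SiLU gate.
gate = jnp.zeros((2, 8))
y_gated = simple_rms_norm(x, scale=jnp.ones(8)) * jax.nn.silu(gate)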
