     LlamaVisionRotaryEmbedding,
     RotaryEmbedding,
     YarnRotaryEmbedding,
+    Qwen3NextRotaryEmbedding,
 )
 from MaxText.layers.initializers import nd_dense_init, NdInitializer, variable_to_logically_partitioned, default_bias_init
 from MaxText.layers.linears import DenseGeneral, canonicalize_tuple, normalize_axes
-from MaxText.layers.normalizations import RMSNorm
+from MaxText.layers.normalizations import RMSNorm, Qwen3NextRMSNorm
 from MaxText.layers.quantizations import AqtQuantization as Quant
 
 # pylint: disable=line-too-long, g-doc-args, g-doc-return-or-yield, bad-continuation, g-inconsistent-quotes
@@ -416,6 +417,8 @@ def __init__(
     self.model_mode = model_mode
     self.rngs = rngs
 
+    self.is_qwen3_next = self.config.decoder_block == DecoderBlockType.QWEN3_NEXT
+
     # Module attribute names must match names previously passed to Linen for checkpointing
     self.KVCache_0 = (
         self.init_kv_caches(inputs_kv_shape=inputs_kv_shape)
@@ -478,6 +481,9 @@ def __init__(
     else:
       self.sinks = None
 
+    self.query_norm = None
+    self.key_norm = None
+
     is_llama4_decoder_block = self.config.decoder_block == DecoderBlockType.LLAMA4
     if self.use_qk_norm and not is_llama4_decoder_block:
       self.query_norm = RMSNorm(
@@ -498,9 +504,21 @@ def __init__(
           kernel_axes=("norm",),
           rngs=self.rngs,
       )
-    else:
-      self.query_norm = None
-      self.key_norm = None
+    elif self.is_qwen3_next:
+      self.query_norm = Qwen3NextRMSNorm(
+          num_features=self.config.head_dim,
+          eps=self.config.normalization_layer_epsilon,
+          dtype=self.config.dtype,
+          weight_dtype=self.config.weight_dtype,
+          rngs=self.rngs,
+      )
+      self.key_norm = Qwen3NextRMSNorm(
+          num_features=self.config.head_dim,
+          eps=self.config.normalization_layer_epsilon,
+          dtype=self.config.dtype,
+          weight_dtype=self.config.weight_dtype,
+          rngs=self.rngs,
+      )
 
     self._maybe_shard_with_logical = functools.partial(
         maybe_shard_with_logical,
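
For orientation, the q/k norms added above operate per attention head over head_dim. A minimal sketch, assuming Qwen3NextRMSNorm reduces over the trailing head_dim axis like a standard RMSNorm (the actual learned-scale handling lives in MaxText.layers.normalizations and may differ):

# Illustrative per-head RMS normalization; not the MaxText implementation.
import jax.numpy as jnp

def per_head_rms_norm(x, scale, eps=1e-6):
  # x: [batch, seq, heads, head_dim]; scale: [head_dim]
  var = jnp.mean(jnp.square(x.astype(jnp.float32)), axis=-1, keepdims=True)
  return (x / jnp.sqrt(var + eps) * scale).astype(x.dtype)
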
@@ -538,9 +556,15 @@ def query_init(*args):
     kernel_axes = (
         (None, None, None) if self.config.ici_context_autoregressive_parallelism > 1 else ("embed", "q_heads", "kv")
     )
+    in_features = self.convert_dense_general_inputs_shape(inputs_q_shape)
+    out_features = (self.num_query_heads, self.head_dim)
+
+    if self.is_qwen3_next:
+      out_features = (self.num_query_heads, self.head_dim * 2)
+
     return DenseGeneral(
-        in_features_shape=self.convert_dense_general_inputs_shape(inputs_q_shape),
-        out_features_shape=(self.num_query_heads, self.head_dim),
+        in_features_shape=in_features,
+        out_features_shape=out_features,
         axis=-1,
         kernel_init=query_init,
         kernel_axes=kernel_axes,
@@ -642,13 +666,22 @@ def qkv_projection(self, inputs: Array, proj_name: str, out_sharding: NamedShard
 
   def init_out_w(self, output_dim: int) -> nnx.Module:
     """out projection"""
+    in_features = (self.num_query_heads, self.head_dim)
+    out_features = output_dim
     out_kernel_axis = (
         (None, None, None) if self.config.ici_context_autoregressive_parallelism > 1 else ("heads", "kv", "embed")
     )
+    axis = (-2, -1)
+
+    if self.is_qwen3_next:
+      in_features = self.num_query_heads * self.head_dim
+      out_kernel_axis = ("mlp", "embed")
+      axis = (-1,)
+
     return DenseGeneral(
-        in_features_shape=(self.num_query_heads, self.head_dim),
-        out_features_shape=output_dim,
-        axis=(-2, -1),
+        in_features_shape=in_features,
+        out_features_shape=out_features,
+        axis=axis,
         kernel_init=self.kernel_init,
         kernel_axes=out_kernel_axis,  # trade speed with memory
         dtype=self.dtype,
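
The is_qwen3_next branch above flattens the head dimensions before the output projection (the gated output computed later in __call__ is already shaped [batch, seq, heads * head_dim]), so the kernel contracts a single axis instead of two. A quick equivalence sketch with made-up shapes shows the two forms are numerically the same:

# Hypothetical shapes; demonstrates why axis=(-1,) with flattened
# in_features matches the original axis=(-2, -1) projection.
import jax
import jax.numpy as jnp

b, s, h, d, e = 2, 3, 4, 8, 16
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(k1, (b, s, h, d))
w = jax.random.normal(k2, (h, d, e))

out_2axis = jnp.einsum("bshd,hde->bse", x, w)                  # axis=(-2, -1)
out_flat = jnp.einsum("bsf,fe->bse", x.reshape(b, s, h * d),   # axis=(-1,)
                      w.reshape(h * d, e))
assert jnp.allclose(out_2axis, out_flat, atol=1e-5)
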
@@ -720,6 +753,16 @@ def init_rotary_embedding(self):
           attention_scaling=self.config.rope_attention_scaling,
           rngs=self.rngs,
       )
+    elif self.is_qwen3_next:
+      rotary_embedding = Qwen3NextRotaryEmbedding(
+          min_timescale=self.config.rope_min_timescale,
+          max_timescale=self.config.rope_max_timescale,
+          embedding_dims=self.config.head_dim,
+          partial_rotary_factor=self.config.partial_rotary_factor,
+          cast_as_fprop_dtype=True,
+          fprop_dtype=self.config.dtype,
+          rngs=self.rngs,
+      )
     else:
       max_timescale = self.config.rope_max_timescale
       # For local attention use local_rope_max_timescale if it is positive
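
partial_rotary_factor is the one parameter above that the plain RotaryEmbedding path does not take: the intent (as in other partial-RoPE models) is that only the leading fraction of each head's dimensions gets rotated and the rest passes through unchanged. A sketch under that assumption, with hypothetical helper names, not the Qwen3NextRotaryEmbedding code:

# Illustrative partial RoPE; timescale handling simplified.
import jax.numpy as jnp

def apply_partial_rope(x, positions, partial_rotary_factor=0.25, max_timescale=10000.0):
  # x: [batch, seq, heads, head_dim]; positions: [batch, seq]
  head_dim = x.shape[-1]
  rot_dim = int(head_dim * partial_rotary_factor)
  x_rot, x_pass = x[..., :rot_dim], x[..., rot_dim:]
  half = rot_dim // 2
  inv_freq = 1.0 / (max_timescale ** (jnp.arange(half) / half))
  angles = positions[..., None] * inv_freq               # [batch, seq, half]
  sin, cos = jnp.sin(angles)[..., None, :], jnp.cos(angles)[..., None, :]
  x1, x2 = x_rot[..., :half], x_rot[..., half:]
  rotated = jnp.concatenate([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)
  return jnp.concatenate([rotated, x_pass], axis=-1)     # untouched tail kept as-is
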
@@ -890,9 +933,17 @@ def __call__(
       value_sharding = NamedSharding(self.mesh, nn.logical_to_mesh_axes(self.value_axis_names))
       value = self.kv_projection(inputs_kv, proj_name="value", out_sharding=value_sharding)
 
+    gate = None
+    if self.is_qwen3_next:
+      # Split query into query & gate.
+      query, gate = jnp.split(query, 2, axis=-1)
+      batch_size, seq_len, _, _ = gate.shape
+      gate = gate.reshape(batch_size, seq_len, self.config.num_query_heads * self.config.head_dim)
+
     is_llama4_decoder_block = self.config.decoder_block == DecoderBlockType.LLAMA4
     # NOTE: llama 4 does L2 normalization after RoPE
-    if self.use_qk_norm and not is_llama4_decoder_block:
+    # Apply Qwen3Next specific RMS Norm
+    if (self.use_qk_norm and not is_llama4_decoder_block) or self.is_qwen3_next:
       query = self.query_norm(query)
       key = self.key_norm(key)
 
@@ -964,7 +1015,9 @@ def __call__(
         bidirectional_mask,
         self.sinks,
     )
-
+    if self.is_qwen3_next:
+      out = out.reshape(batch_size, seq_len, self.config.num_query_heads * self.config.head_dim)
+      out = out * jax.nn.sigmoid(gate)
     if model_mode == MODEL_MODE_PREFILL:
      out = self._maybe_shard_with_logical(out, self.prefill_out_axis_names)
     elif model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT:
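
Putting the pieces together: the doubled query projection (head_dim * 2) feeds the split in __call__, and this last hunk applies the resulting gate to the attention output before init_out_w's flattened projection. A condensed, self-contained trace of just the gating data flow (shapes are illustrative):

# Illustrative shape trace of the Qwen3-Next gated-attention output path.
import jax
import jax.numpy as jnp

b, s, h, d = 2, 5, 4, 8
qg = jax.random.normal(jax.random.PRNGKey(0), (b, s, h, 2 * d))    # query projection emits query + gate
query, gate = jnp.split(qg, 2, axis=-1)                            # each [b, s, h, d]
gate = gate.reshape(b, s, h * d)

attn_out = jax.random.normal(jax.random.PRNGKey(1), (b, s, h, d))  # stand-in for attention_op output
out = attn_out.reshape(b, s, h * d) * jax.nn.sigmoid(gate)         # sigmoid gating, then out projection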