@@ -54,6 +54,7 @@ llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_gr
5454 llm_graph_context_mamba(params) {
5555 const int64_t n_embd_head = hparams.n_embd_head_v ;
5656 GGML_ASSERT (n_embd_head == hparams.n_embd_head_k );
57+ // GGML_ASSERT(n_embd_head == hparams.n_rot);
5758
5859 ggml_tensor * cur;
5960 ggml_tensor * inpL;
@@ -142,7 +143,8 @@ struct ggml_tensor * llm_build_qwen3next::build_qwen3next_attention_layer(ggml_t
142143 const llama_model & model,
143144 const int64_t n_embd_head,
144145 const int il) {
145- // compute Q and K and RoPE them
146+ // Order: joint QG projection, QG split, Q norm, KV projection, K norm, RoPE, attention
147+
146148 // Qwen3Next uses a single Q projection that outputs query + gate
147149 struct ggml_tensor * Qcur_full = build_lora_mm (model.layers [il].wq , cur);
148150 cb (Qcur_full, " Qcur_full" , il);
@@ -159,28 +161,28 @@ struct ggml_tensor * llm_build_qwen3next::build_qwen3next_attention_layer(ggml_t
159161 Qcur = ggml_cont_3d (ctx0, Qcur, n_embd_head, n_head, n_tokens);
160162 cb (Qcur, " Qcur_reshaped" , il);
161163
162- // Apply Q normalization only to the query part
164+ // Apply Q normalization
163165 Qcur = build_q3n_norm (Qcur, model.layers [il].attn_q_norm , il);
164166 cb (Qcur, " Qcur_normed" , il);
165-
166- // Reshape gate to [n_embd, n_tokens] for the sigmoid gating (flatten the heads)
167- gate = ggml_cont_2d (ctx0, gate, n_embd_head * n_head, n_tokens);
168- cb (gate, " gate_reshaped" , il);
169167
170168 struct ggml_tensor * Kcur = build_lora_mm (model.layers [il].wk , cur);
171169 cb (Kcur, " Kcur" , il);
172170
173171 struct ggml_tensor * Vcur = build_lora_mm (model.layers [il].wv , cur);
174172 cb (Vcur, " Vcur" , il);
175173
174+ // Apply K normalization
175+ Kcur = build_q3n_norm (Kcur, model.layers [il].attn_k_norm , il);
176+ cb (Kcur, " Kcur_normed" , il);
177+
178+ // Reshape gate to [n_embd, n_tokens] for the sigmoid gating (flatten the heads)
179+ gate = ggml_cont_2d (ctx0, gate, n_embd_head * n_head, n_tokens);
180+ cb (gate, " gate_reshaped" , il);
181+
176182 Qcur = ggml_cont_3d (ctx0, Qcur, n_embd_head, n_head, n_tokens);
177183 Kcur = ggml_reshape_3d (ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
178184 Vcur = ggml_reshape_3d (ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
179185
180- // Apply Q/K normalization
181- Kcur = build_q3n_norm (Kcur, model.layers [il].attn_k_norm , il);
182- cb (Kcur, " Kcur_normed" , il);
183-
184186 // Apply RoPE
185187 Qcur = ggml_rope_ext (ctx0, Qcur, inp_pos, nullptr , n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor,
186188 attn_factor, beta_fast, beta_slow);
0 commit comments