@@ -57,20 +57,29 @@ llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_gr
5757 // Full attention layer
5858 cur = build_qwen3next_attention_layer (cur, inp_pos, inp->get_attn (), model, n_embd_head, il);
5959 }
60- // Post-attention norm
61- cur = build_q3n_norm (cur, model.layers [il].attn_post_norm , il);
62- cb (cur, " attn_post_norm" , il);
6360
6461 if (il == n_layer - 1 && inp_out_ids) {
6562 cur = ggml_get_rows (ctx0, cur, inp_out_ids);
6663 inpSA = ggml_get_rows (ctx0, inpSA, inp_out_ids);
6764 }
65+
6866 // Residual connection
6967 cur = ggml_add (ctx0, cur, inpSA);
7068 cb (cur, " attn_residual" , il);
7169
72- // FFN layer (MoE or dense)
73- cur = build_layer_ffn (cur, model, il);
70+ // Save the tensor before post-attention norm for residual connection
71+ ggml_tensor * ffn_residual = cur;
72+
73+ // Post-attention norm
74+ ggml_tensor * attn_post_norm = build_q3n_norm (cur, model.layers [il].attn_post_norm , il);
75+ cb (attn_post_norm, " attn_post_norm" , il);
76+
77+ // FFN layer (MoE or dense) - without residual connection
78+ cur = build_layer_ffn (attn_post_norm, model, il, false );
79+ cb (cur, " ffn_out" , il);
80+
81+ // Residual connection for FFN - add to the tensor BEFORE post_attention_layernorm
82+ cur = ggml_add (ctx0, cur, ffn_residual);
7483 cb (cur, " post_moe" , il);
7584
7685 // Input for next layer
@@ -111,26 +120,43 @@ struct ggml_tensor * llm_build_qwen3next::build_qwen3next_attention_layer(ggml_t
111120 const llama_model & model,
112121 const int64_t n_embd_head,
113122 const int il) {
114- ggml_tensor * gate = build_lora_mm (model.layers [il].wq_gate , cur);
115-
116123 // compute Q and K and RoPE them
117- struct ggml_tensor * Qcur = build_lora_mm (model.layers [il].wq , cur);
124+ // Qwen3Next uses a single Q projection that outputs query + gate
125+ struct ggml_tensor * Qcur_full = build_lora_mm (model.layers [il].wq , cur);
126+ cb (Qcur_full, " Qcur_full" , il);
127+ Qcur_full = ggml_reshape_4d (ctx0, Qcur_full, n_embd_head * 2 , n_head, n_tokens, 1 );
128+ // Split Q projection into query and gate
129+ // The split should be along dimension 0 (the feature dimension)
130+ struct ggml_tensor * Qcur = ggml_view_4d (ctx0, Qcur_full, n_embd_head, n_head, n_tokens, 1 , Qcur_full->nb [1 ], Qcur_full->nb [2 ], Qcur_full->nb [3 ], 0 );
131+ struct ggml_tensor * gate = ggml_view_4d (ctx0, Qcur_full, n_embd_head, n_head, n_tokens, 1 , Qcur_full->nb [1 ], Qcur_full->nb [2 ], Qcur_full->nb [3 ],
132+ n_embd_head * ggml_element_size (Qcur_full));
118133 cb (Qcur, " Qcur" , il);
134+ cb (gate, " gate" , il);
135+
136+ // Now reshape Qcur to [n_embd_head, n_head, n_tokens] for multi-head attention
137+ Qcur = ggml_cont_3d (ctx0, Qcur, n_embd_head, n_head, n_tokens);
138+ cb (Qcur, " Qcur_reshaped" , il);
139+
140+ // Apply Q normalization only to the query part
141+ Qcur = build_q3n_norm (Qcur, model.layers [il].attn_q_norm , il);
142+ cb (Qcur, " Qcur_normed" , il);
143+
144+ // Flatten the heads: make gate contiguous as [n_embd_head * n_head, n_tokens] for the sigmoid gating
145+ gate = ggml_cont_2d (ctx0, gate, n_embd_head * n_head, n_tokens);
146+ cb (gate, " gate_reshaped" , il);
119147
120148 struct ggml_tensor * Kcur = build_lora_mm (model.layers [il].wk , cur);
121149 cb (Kcur, " Kcur" , il);
122150
123151 struct ggml_tensor * Vcur = build_lora_mm (model.layers [il].wv , cur);
124152 cb (Vcur, " Vcur" , il);
125153
126- Qcur = ggml_reshape_3d (ctx0, Qcur, n_embd_head, n_head, n_tokens);
154+ Qcur = ggml_cont_3d (ctx0, Qcur, n_embd_head, n_head, n_tokens);
127155 Kcur = ggml_reshape_3d (ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
128156 Vcur = ggml_reshape_3d (ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
129157
130158 // Apply Q/K normalization
131- Qcur = build_norm (Qcur, model.layers [il].attn_q_norm , NULL , LLM_NORM_RMS, il);
132- Kcur = build_norm (Kcur, model.layers [il].attn_k_norm , NULL , LLM_NORM_RMS, il);
133- cb (Kcur, " Qcur_normed" , il);
159+ Kcur = build_q3n_norm (Kcur, model.layers [il].attn_k_norm , il);
134160 cb (Kcur, " Kcur_normed" , il);
135161
136162 // Apply RoPE
@@ -149,8 +175,8 @@ struct ggml_tensor * llm_build_qwen3next::build_qwen3next_attention_layer(ggml_t
149175 hparams.f_attention_scale == 0 .0f ? 1 .0f / sqrtf (float (n_embd_head)) : hparams.f_attention_scale ;
150176 cur = build_attn (inp_attn, nullptr , nullptr , Qcur, Kcur, Vcur, nullptr , nullptr , nullptr , kq_scale, il);
151177
152- // Apply gating
153- cur = ggml_cont (ctx0, ggml_mul (ctx0, cur, ggml_sigmoid (ctx0, gate) ));
178+ // Apply sigmoid gating using the gate half split from the Q projection (made contiguous and flattened above)
179+ cur = ggml_mul (ctx0, cur, ggml_sigmoid (ctx0, gate));
154180 cb (cur, " attn_gated" , il);
155181
156182 cur = build_lora_mm (model.layers [il].wo , cur);
@@ -598,7 +624,8 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
598624 return cur;
599625}
600626
601- ggml_tensor * llm_build_qwen3next::build_layer_ffn (ggml_tensor * cur, const llama_model & model, const int il) {
627+ ggml_tensor * llm_build_qwen3next::build_layer_ffn (ggml_tensor * cur, const llama_model & model, const int il, bool do_residual) {
628+
602629 // Check if this is an MoE layer
603630 if (model.layers [il].ffn_gate_inp != nullptr ) {
604631 // MoE branch
@@ -608,13 +635,33 @@ ggml_tensor * llm_build_qwen3next::build_layer_ffn(ggml_tensor * cur, const llam
608635 n_expert_used, LLM_FFN_SILU, true , false , 0.0 , LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il);
609636 cb (moe_out, " ffn_moe_out" , il);
610637
611- // Add shared experts if present
638+ // Add shared experts if present - following Qwen3Next reference implementation
612639 if (model.layers [il].ffn_up_shexp != nullptr ) {
613640 ggml_tensor * ffn_shexp =
614641 build_ffn (cur, model.layers [il].ffn_up_shexp , NULL , NULL , model.layers [il].ffn_gate_shexp , NULL , NULL ,
615642 model.layers [il].ffn_down_shexp , NULL , NULL , NULL , LLM_FFN_SILU, LLM_FFN_PAR, il);
616643 cb (ffn_shexp, " ffn_shexp" , il);
617644
645+ // Apply shared expert gating as in the reference implementation
646+ // The shared expert has its own gate that is sigmoided
647+ // Note: ffn_gate_inp_shexp is the shared expert gate (outputs 1 value per token)
648+ ggml_tensor * shared_gate = build_lora_mm (model.layers [il].ffn_gate_inp_shexp , cur);
649+ cb (shared_gate, " shared_expert_gate" , il);
650+
651+ // Apply sigmoid to the gate
652+ shared_gate = ggml_sigmoid (ctx0, shared_gate);
653+ cb (shared_gate, " shared_expert_gate_sigmoid" , il);
654+
655+ // The gate needs to be broadcast to match the dimensions of ffn_shexp
656+ // ffn_shexp is [n_embd, n_tokens, 1, 1] and shared_gate is [1, n_tokens, 1, 1]
657+ // We need to repeat the gate along the feature dimension
658+ shared_gate = ggml_repeat (ctx0, shared_gate, ffn_shexp);
659+ cb (shared_gate, " shared_expert_gate_broadcast" , il);
660+
661+ // Apply the gate to the shared expert output
662+ ffn_shexp = ggml_mul (ctx0, ffn_shexp, shared_gate);
663+ cb (ffn_shexp, " ffn_shexp_gated" , il);
664+
618665 cur = ggml_add (ctx0, moe_out, ffn_shexp);
619666 cb (cur, " ffn_out" , il);
620667 } else {
@@ -626,9 +673,14 @@ ggml_tensor * llm_build_qwen3next::build_layer_ffn(ggml_tensor * cur, const llam
626673 model.layers [il].ffn_down , NULL , NULL , NULL , LLM_FFN_SILU, LLM_FFN_PAR, il);
627674 cb (cur, " ffn_out" , il);
628675 }
629- // Residual connection
630- cur = ggml_add (ctx0, cur, cur); // This should be the residual from before FFN
631- cb (cur, " ffn_residual" , il);
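676+ // Residual connection (only if requested). NOTE(review): ggml_add(cur, cur) adds the FFN output to itself (doubling it), not to the pre-FFN input - confirm and pass the real residual tensor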
677+ if (do_residual) {
678+ cur = ggml_add (ctx0, cur, cur);
679+ cb (cur, " ffn_residual" , il);
680+ }
681+
682+ cur = build_cvec (cur, il);
683+ cb (cur, " l_out" , il);
632684
633685 return cur;
634686};
0 commit comments