@@ -99,6 +99,12 @@ struct ggml_tensor * llm_build_qwen3next::build_q3n_norm(struct ggml_tensor * in
     return build_norm(input, input_norm, nullptr, LLM_NORM_RMS, layer);
 }
 
+struct ggml_tensor * llm_build_qwen3next::build_q3n_gated_norm(struct ggml_tensor * input, struct ggml_tensor * weights, struct ggml_tensor * gate, int layer) {
+    ggml_tensor * normalized = build_norm(input, weights, nullptr, LLM_NORM_RMS, layer);
+    ggml_tensor * gated_silu = ggml_silu(ctx0, gate);
+    return ggml_mul(ctx0, normalized, gated_silu);
+}
+
 struct ggml_tensor * llm_build_qwen3next::build_qwen3next_attention_layer(ggml_tensor * cur,
                                                                           ggml_tensor * inp_pos,
                                                                           llm_graph_input_attn_kv * inp_attn,
@@ -550,7 +556,7 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
         ggml_view_1d(ctx0, attn_out, output_flat_size, 0);
     cb(attn_out_1d, "attn_out_1d", il);
 
-    ggml_tensor * attn_out_final = ggml_cont_4d(ctx0, attn_out_1d, head_dim, n_tokens, n_heads, n_seqs);
+    ggml_tensor * attn_out_final = ggml_cont(ctx0, ggml_permute(ctx0, ggml_cont_4d(ctx0, attn_out_1d, head_dim, n_tokens, n_heads, n_seqs), 0, 2, 1, 3));
     cb(attn_out_final, "attn_out_final", il);
 
     // Extract the state part (second part of the concatenated tensor)
@@ -574,18 +580,10 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
     ggml_tensor * z_2d = ggml_reshape_2d(ctx0, z_reshaped, head_dim, n_heads * n_tokens * n_seqs);
 
     // Apply gated normalization: self.norm(core_attn_out, z)
-    // This is Qwen3NextRMSNormGated which applies: RMSNorm(x) * silu(gate)
-    ggml_tensor * attn_out_norm = build_norm(attn_out_2d_final, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il);
-    cb(attn_out_norm, "attn_out_norm", il);
-
-    // Apply silu gate: attn_out_norm * silu(z_2d)
-    ggml_tensor * z_silu = ggml_silu(ctx0, z_2d);
-    cb(z_silu, "z_silu", il);
-    ggml_tensor * gated_output = ggml_mul(ctx0, attn_out_norm, z_silu);
-    cb(gated_output, "gated_output", il);
-
+    ggml_tensor * attn_out_norm = build_q3n_gated_norm(attn_out_2d_final, model.layers[il].ssm_norm, z_2d, il);
+
     // Reshape back to original dimensions: [n_heads * n_tokens * n_seqs, head_dim] -> [head_dim, n_heads, n_tokens, n_seqs]
-    ggml_tensor * gated_output_4d = ggml_reshape_4d(ctx0, gated_output, head_dim, n_heads, n_tokens, n_seqs);
+    ggml_tensor * gated_output_4d = ggml_reshape_4d(ctx0, attn_out_norm, head_dim, n_heads, n_tokens, n_seqs);
 
     // Final reshape: [head_dim, n_heads, n_tokens, n_seqs] -> [n_tokens, n_seqs, n_heads * head_dim]
     ggml_tensor * final_output = ggml_reshape_3d(ctx0, gated_output_4d, n_heads * head_dim, n_tokens, n_seqs);
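
For reference, a minimal standalone sketch (not part of the diff) of what the new build_q3n_gated_norm helper records, assuming a valid ggml_context and the Qwen3NextRMSNormGated semantics out = RMSNorm(x) * w * silu(g); the function name and eps parameter below are illustrative only:

    // Sketch of the gated RMS norm built by build_q3n_gated_norm above.
    // ctx is an existing ggml_context; x, w, g are assumed shaped so that
    // w broadcasts over x (as ssm_norm does) and g matches x.
    static ggml_tensor * gated_rms_norm_sketch(ggml_context * ctx,
                                               ggml_tensor  * x,    // activations, e.g. attn_out_2d_final
                                               ggml_tensor  * w,    // norm weight, e.g. model.layers[il].ssm_norm
                                               ggml_tensor  * g,    // gate, e.g. z_2d
                                               float          eps) {
        ggml_tensor * n = ggml_rms_norm(ctx, x, eps);   // x / sqrt(mean(x^2) + eps)
        n = ggml_mul(ctx, n, w);                        // apply the learned per-channel scale
        return ggml_mul(ctx, n, ggml_silu(ctx, g));     // gate the normalized output with silu(g)
    }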