Commit adcbd94

Linear layer output convergence

1 parent 666fc05 commit adcbd94

2 files changed: +11 -12 lines changed

src/models/llm_build_qwen3next.cpp

Lines changed: 10 additions & 12 deletions
@@ -99,6 +99,12 @@ struct ggml_tensor * llm_build_qwen3next::build_q3n_norm(struct ggml_tensor * in
     return build_norm(input, input_norm, nullptr, LLM_NORM_RMS, layer);
 }
 
+struct ggml_tensor * llm_build_qwen3next::build_q3n_gated_norm(struct ggml_tensor * input, struct ggml_tensor * weights, struct ggml_tensor * gate, int layer) {
+    ggml_tensor * normalized = build_norm(input, weights, nullptr, LLM_NORM_RMS, layer);
+    ggml_tensor * gated_silu = ggml_silu(ctx0, gate);
+    return ggml_mul(ctx0, normalized, gated_silu);
+}
+
 struct ggml_tensor * llm_build_qwen3next::build_qwen3next_attention_layer(ggml_tensor * cur,
         ggml_tensor * inp_pos,
         llm_graph_input_attn_kv * inp_attn,
@@ -550,7 +556,7 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
         ggml_view_1d(ctx0, attn_out, output_flat_size, 0);
     cb(attn_out_1d, "attn_out_1d", il);
 
-    ggml_tensor * attn_out_final = ggml_cont_4d(ctx0, attn_out_1d, head_dim, n_tokens, n_heads, n_seqs);
+    ggml_tensor * attn_out_final = ggml_cont(ctx0, ggml_permute(ctx0, ggml_cont_4d(ctx0, attn_out_1d, head_dim, n_tokens, n_heads, n_seqs), 0, 2, 1, 3));
     cb(attn_out_final, "attn_out_final", il);
 
     // Extract the state part (second part of the concatenated tensor)
@@ -574,18 +580,10 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
     ggml_tensor * z_2d = ggml_reshape_2d(ctx0, z_reshaped, head_dim, n_heads * n_tokens * n_seqs);
 
     // Apply gated normalization: self.norm(core_attn_out, z)
-    // This is Qwen3NextRMSNormGated which applies: RMSNorm(x) * silu(gate)
-    ggml_tensor * attn_out_norm = build_norm(attn_out_2d_final, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il);
-    cb(attn_out_norm, "attn_out_norm", il);
-
-    // Apply silu gate: attn_out_norm * silu(z_2d)
-    ggml_tensor * z_silu = ggml_silu(ctx0, z_2d);
-    cb(z_silu, "z_silu", il);
-    ggml_tensor * gated_output = ggml_mul(ctx0, attn_out_norm, z_silu);
-    cb(gated_output, "gated_output", il);
-
+    ggml_tensor * attn_out_norm = build_q3n_gated_norm(attn_out_2d_final, model.layers[il].ssm_norm, z_2d, il);
+
     // Reshape back to original dimensions: [n_heads * n_tokens * n_seqs, head_dim] -> [head_dim, n_heads, n_tokens, n_seqs]
-    ggml_tensor * gated_output_4d = ggml_reshape_4d(ctx0, gated_output, head_dim, n_heads, n_tokens, n_seqs);
+    ggml_tensor * gated_output_4d = ggml_reshape_4d(ctx0, attn_out_norm, head_dim, n_heads, n_tokens, n_seqs);
 
     // Final reshape: [head_dim, n_heads, n_tokens, n_seqs] -> [n_tokens, n_seqs, n_heads * head_dim]
     ggml_tensor * final_output = ggml_reshape_3d(ctx0, gated_output_4d, n_heads * head_dim, n_tokens, n_seqs);
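
Note (not part of the commit): the attn_out_final change inserts a ggml_permute before the downstream reshapes. ggml_permute(ctx0, t, 0, 2, 1, 3) moves axis 1 of t to position 2 and axis 2 to position 1, so the intermediate goes from [head_dim, n_tokens, n_heads, n_seqs] to [head_dim, n_heads, n_tokens, n_seqs], the layout that the later ggml_reshape_4d and final ggml_reshape_3d assume. A minimal standalone C++ sketch of the equivalent element shuffle on a contiguous buffer, purely for illustration (the function name and loop structure are not from the repository):

// Standalone sketch: the reordering performed by ggml_cont(ggml_permute(t, 0, 2, 1, 3))
// on a tensor stored as [head_dim, n_tokens, n_heads, n_seqs] (dim 0 fastest-varying),
// producing a contiguous [head_dim, n_heads, n_tokens, n_seqs] buffer.
#include <cstddef>
#include <vector>

std::vector<float> permute_tokens_heads(const std::vector<float> & src,
                                        std::size_t head_dim, std::size_t n_tokens,
                                        std::size_t n_heads,  std::size_t n_seqs) {
    std::vector<float> dst(src.size());
    for (std::size_t s = 0; s < n_seqs; ++s) {
        for (std::size_t h = 0; h < n_heads; ++h) {
            for (std::size_t t = 0; t < n_tokens; ++t) {
                for (std::size_t d = 0; d < head_dim; ++d) {
                    // source index in [head_dim, n_tokens, n_heads, n_seqs] order
                    const std::size_t i_src = d + head_dim * (t + n_tokens * (h + n_heads * s));
                    // destination index in [head_dim, n_heads, n_tokens, n_seqs] order
                    const std::size_t i_dst = d + head_dim * (h + n_heads * (t + n_tokens * s));
                    dst[i_dst] = src[i_src];
                }
            }
        }
    }
    return dst;
}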

src/models/llm_build_qwen3next.h

Lines changed: 1 addition & 0 deletions

@@ -41,5 +41,6 @@ struct llm_build_qwen3next : public llm_graph_context_mamba {
     ggml_tensor * softplus(ggml_tensor * alpha, ggml_tensor * dt_bias);
 
     ggml_tensor * build_q3n_norm(struct ggml_tensor * input, struct ggml_tensor * weights, int layer);
+    ggml_tensor * build_q3n_gated_norm(struct ggml_tensor * input, struct ggml_tensor * weights, struct ggml_tensor * gate, int layer);
 
 };
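
For reference (an illustration, not code from the commit): the helper declared here composes RMSNorm with a SiLU gate, matching the removed comment "This is Qwen3NextRMSNormGated which applies: RMSNorm(x) * silu(gate)". A minimal scalar sketch of that math on one row, assuming an element-wise norm weight and an epsilon of 1e-6 (both illustrative assumptions, not values taken from the repository):

// Sketch of gated RMS norm on a single row: RMSNorm(x) * w * silu(z).
#include <cmath>
#include <cstddef>
#include <vector>

std::vector<float> gated_rms_norm(const std::vector<float> & x,
                                  const std::vector<float> & w,
                                  const std::vector<float> & z,
                                  float eps = 1e-6f) {       // illustrative epsilon
    float sum_sq = 0.0f;
    for (float v : x) {
        sum_sq += v * v;
    }
    const float inv_rms = 1.0f / std::sqrt(sum_sq / x.size() + eps);

    std::vector<float> out(x.size());
    for (std::size_t i = 0; i < x.size(); ++i) {
        const float normed = x[i] * inv_rms * w[i];              // RMSNorm with weight
        const float silu_z = z[i] / (1.0f + std::exp(-z[i]));    // silu(z) = z * sigmoid(z)
        out[i] = normed * silu_z;                                // gated output
    }
    return out;
}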
