Skip to content

Commit a51d438

Browse files
committed
Like this.
1 parent 54bb6f1 commit a51d438

File tree

1 file changed

+8
-7
lines changed

1 file changed

+8
-7
lines changed

src/models/llm_build_qwen3next.cpp

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -279,9 +279,9 @@ struct ggml_tensor * llm_build_qwen3next::delta_net(
279279
cb(q, "q_postscale", il);
280280
cb(beta, "beta_sigmoid", il);
281281

282-
q = ggml_cont(ctx, ggml_permute(ctx, q, 0, 2, 1, 3));
283-
k = ggml_cont(ctx, ggml_permute(ctx, k, 0, 2, 1, 3));
284-
v = ggml_cont(ctx, ggml_permute(ctx, v, 0, 2, 1, 3));
282+
q = ggml_cont_4d(ctx, ggml_permute(ctx, ggml_reshape_4d(ctx, q, S_v, n_tokens, H_v, n_seqs), 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
283+
k = ggml_cont_4d(ctx, ggml_permute(ctx, ggml_reshape_4d(ctx, k, S_v, n_tokens, H_v, n_seqs), 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
284+
v = ggml_cont_4d(ctx, ggml_permute(ctx, ggml_reshape_4d(ctx, v, S_v, n_tokens, H_v, n_seqs), 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
285285

286286
q = ggml_pad(ctx, q, 0, pad_size, 0, 0);
287287
k = ggml_pad(ctx, k, 0, pad_size, 0, 0);
@@ -700,12 +700,13 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
700700
GGML_ASSERT(num_v_heads % num_k_heads == 0);
701701
int64_t repeat_factor = num_v_heads / num_k_heads;
702702

703-
ggml_tensor * q_reshaped = ggml_reshape_4d(ctx0, q_conv, head_k_dim, num_k_heads, 1, n_seq_tokens * n_seqs);
704-
ggml_tensor * k_reshaped = ggml_reshape_4d(ctx0, k_conv, head_k_dim, num_k_heads, 1, n_seq_tokens * n_seqs);
703+
// repeat interleave: reshape to (repeat part, 1, remaining part), do repeat, then reshape back
704+
ggml_tensor * q_reshaped = ggml_reshape_3d(ctx0, q_conv, head_k_dim, 1, num_k_heads * n_seq_tokens * n_seqs);
705+
ggml_tensor * k_reshaped = ggml_reshape_3d(ctx0, k_conv, head_k_dim, 1, num_k_heads * n_seq_tokens * n_seqs);
705706

706707
// Repeat along the third dimension (the new dimension with size 1)
707-
ggml_tensor * q_repeated = ggml_repeat_4d(ctx0, q_reshaped, head_k_dim, num_k_heads, repeat_factor, n_seq_tokens * n_seqs);
708-
ggml_tensor * k_repeated = ggml_repeat_4d(ctx0, k_reshaped, head_k_dim, num_k_heads, repeat_factor, n_seq_tokens * n_seqs);
708+
ggml_tensor * q_repeated = ggml_repeat_4d(ctx0, q_reshaped, head_k_dim, repeat_factor, num_k_heads * n_seq_tokens * n_seqs, 1);
709+
ggml_tensor * k_repeated = ggml_repeat_4d(ctx0, k_reshaped, head_k_dim, repeat_factor, num_k_heads * n_seq_tokens * n_seqs, 1);
709710

710711
// Reshape back to merge the head and repeat dimensions
711712
// From [head_dim, num_k_heads, repeat_factor, n_seq_tokens * n_seqs]

0 commit comments

Comments
 (0)