Skip to content

Commit 4136521

Browse files
committed
attempt 2
1 parent c5dc442 commit 4136521

File tree

1 file changed

+19
-9
lines changed

1 file changed

+19
-9
lines changed

src/models/llm_build_qwen3next.cpp

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -704,17 +704,27 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
704704
GGML_ASSERT(num_v_heads % num_k_heads == 0);
705705
int64_t repeat_factor = num_v_heads / num_k_heads;
706706

707-
// Step 1: Reshape to add a new dimension for the repeats
708-
ggml_tensor * q_reshaped = ggml_reshape_4d(ctx0, q_conv, head_k_dim, num_k_heads, 1, n_seq_tokens * n_seqs);
709-
ggml_tensor * k_reshaped = ggml_reshape_4d(ctx0, k_conv, head_k_dim, num_k_heads, 1, n_seq_tokens * n_seqs);
707+
// GGML tensor layout: [head_dim, num_heads, n_seq_tokens, n_seqs]
710708

711-
// Step 2: Expand along the new dimension
712-
ggml_tensor * q_expanded = ggml_repeat_4d(ctx0, q_reshaped, head_k_dim, num_k_heads, repeat_factor, n_seq_tokens * n_seqs);
713-
ggml_tensor * k_expanded = ggml_repeat_4d(ctx0, k_reshaped, head_k_dim, num_k_heads, repeat_factor, n_seq_tokens * n_seqs);
709+
// Step 1: Flatten the sequence and batch dimensions to work with them more easily
710+
ggml_tensor * q_flat = ggml_reshape_2d(ctx0, q_conv, head_k_dim, num_k_heads * n_seq_tokens * n_seqs);
711+
ggml_tensor * k_flat = ggml_reshape_2d(ctx0, k_conv, head_k_dim, num_k_heads * n_seq_tokens * n_seqs);
714712

715-
// Step 3: Reshape back to merge the repeated dimensions
716-
q_conv = ggml_reshape_4d(ctx0, q_expanded, head_k_dim, num_k_heads * repeat_factor, n_seq_tokens, n_seqs);
717-
k_conv = ggml_reshape_4d(ctx0, k_expanded, head_k_dim, num_k_heads * repeat_factor, n_seq_tokens, n_seqs);
713+
// Step 2: Reshape to prepare for repeat_interleave
714+
// From [head_dim, num_k_heads * n_seq_tokens * n_seqs]
715+
// To [head_dim, num_k_heads, 1, n_seq_tokens * n_seqs]
716+
ggml_tensor * q_reshaped = ggml_reshape_4d(ctx0, q_flat, head_k_dim, num_k_heads, 1, n_seq_tokens * n_seqs);
717+
ggml_tensor * k_reshaped = ggml_reshape_4d(ctx0, k_flat, head_k_dim, num_k_heads, 1, n_seq_tokens * n_seqs);
718+
719+
// Step 3: Repeat along the third dimension (the new dimension with size 1)
720+
ggml_tensor * q_repeated = ggml_repeat_4d(ctx0, q_reshaped, head_k_dim, num_k_heads, repeat_factor, n_seq_tokens * n_seqs);
721+
ggml_tensor * k_repeated = ggml_repeat_4d(ctx0, k_reshaped, head_k_dim, num_k_heads, repeat_factor, n_seq_tokens * n_seqs);
722+
723+
// Step 4: Reshape back to merge the head and repeat dimensions
724+
// From [head_dim, num_k_heads, repeat_factor, n_seq_tokens * n_seqs]
725+
// Back to [head_dim, num_k_heads * repeat_factor, n_seq_tokens, n_seqs]
726+
q_conv = ggml_reshape_4d(ctx0, q_repeated, head_k_dim, num_k_heads * repeat_factor, n_seq_tokens, n_seqs);
727+
k_conv = ggml_reshape_4d(ctx0, k_repeated, head_k_dim, num_k_heads * repeat_factor, n_seq_tokens, n_seqs);
718728
}
719729

720730
cb(q_conv, "q_conv_predelta", il);

0 commit comments

Comments
 (0)