Commit c5dc442

repeat_interleave
1 parent a4fe128 commit c5dc442

File tree

1 file changed: +11 -2 lines changed


src/models/llm_build_qwen3next.cpp

Lines changed: 11 additions & 2 deletions
@@ -704,8 +704,17 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
         GGML_ASSERT(num_v_heads % num_k_heads == 0);
         int64_t repeat_factor = num_v_heads / num_k_heads;
 
-        q_conv = ggml_repeat_4d(ctx0, q_conv, head_k_dim, num_k_heads * repeat_factor, n_seq_tokens, n_seqs);
-        k_conv = ggml_repeat_4d(ctx0, k_conv, head_k_dim, num_k_heads * repeat_factor, n_seq_tokens, n_seqs);
+        // Step 1: Reshape to add a new dimension for the repeats
+        ggml_tensor * q_reshaped = ggml_reshape_4d(ctx0, q_conv, head_k_dim, num_k_heads, 1, n_seq_tokens * n_seqs);
+        ggml_tensor * k_reshaped = ggml_reshape_4d(ctx0, k_conv, head_k_dim, num_k_heads, 1, n_seq_tokens * n_seqs);
+
+        // Step 2: Expand along the new dimension
+        ggml_tensor * q_expanded = ggml_repeat_4d(ctx0, q_reshaped, head_k_dim, num_k_heads, repeat_factor, n_seq_tokens * n_seqs);
+        ggml_tensor * k_expanded = ggml_repeat_4d(ctx0, k_reshaped, head_k_dim, num_k_heads, repeat_factor, n_seq_tokens * n_seqs);
+
+        // Step 3: Reshape back to merge the repeated dimensions
+        q_conv = ggml_reshape_4d(ctx0, q_expanded, head_k_dim, num_k_heads * repeat_factor, n_seq_tokens, n_seqs);
+        k_conv = ggml_reshape_4d(ctx0, k_expanded, head_k_dim, num_k_heads * repeat_factor, n_seq_tokens, n_seqs);
     }
 
     cb(q_conv, "q_conv_predelta", il);
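
For context on the commit title: torch.repeat_interleave duplicates each source head consecutively (head order 0 0 0 1 1 1 for a repeat factor of 3), whereas a plain block repeat tiles the whole set of heads (head order 0 1 0 1 0 1). The sketch below is not part of the commit and does not use ggml; it is a standalone illustration with made-up sizes and values that only contrasts the two orderings on a flat buffer of K heads, so the semantics the title refers to are clear.

// Standalone illustration only (hypothetical sizes, no ggml): contrast
// block-wise tiling with repeat_interleave ordering along the head dimension.
#include <cstdio>
#include <vector>

int main() {
    const int head_dim = 2, num_k_heads = 2, repeat_factor = 3;
    // Two K heads stored back to back: head 0 = {1, 2}, head 1 = {3, 4}.
    const std::vector<int> heads = {1, 2, 3, 4};

    // Tiling: copy the whole block of heads repeat_factor times
    // -> head order 0 1 0 1 0 1.
    std::vector<int> tiled;
    for (int r = 0; r < repeat_factor; ++r) {
        tiled.insert(tiled.end(), heads.begin(), heads.end());
    }

    // repeat_interleave: duplicate each head consecutively
    // -> head order 0 0 0 1 1 1.
    std::vector<int> interleaved;
    for (int h = 0; h < num_k_heads; ++h) {
        for (int r = 0; r < repeat_factor; ++r) {
            interleaved.insert(interleaved.end(),
                               heads.begin() +  h      * head_dim,
                               heads.begin() + (h + 1) * head_dim);
        }
    }

    printf("tiled      : ");
    for (int v : tiled)       { printf("%d ", v); }
    printf("\ninterleaved: ");
    for (int v : interleaved) { printf("%d ", v); }
    printf("\n");
    return 0;
}

With these toy values, the tiled buffer prints 1 2 3 4 repeated three times, while the interleaved buffer prints 1 2 1 2 1 2 3 4 3 4 3 4.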
