@@ -704,17 +704,27 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
704704 GGML_ASSERT (num_v_heads % num_k_heads == 0 );
705705 int64_t repeat_factor = num_v_heads / num_k_heads;
706706
707- // Step 1: Reshape to add a new dimension for the repeats
708- ggml_tensor * q_reshaped = ggml_reshape_4d (ctx0, q_conv, head_k_dim, num_k_heads, 1 , n_seq_tokens * n_seqs);
709- ggml_tensor * k_reshaped = ggml_reshape_4d (ctx0, k_conv, head_k_dim, num_k_heads, 1 , n_seq_tokens * n_seqs);
707+ // GGML tensor layout: [head_dim, num_heads, n_seq_tokens, n_seqs]
710708
711- // Step 2: Expand along the new dimension
712- ggml_tensor * q_expanded = ggml_repeat_4d (ctx0, q_reshaped , head_k_dim, num_k_heads, repeat_factor, n_seq_tokens * n_seqs);
713- ggml_tensor * k_expanded = ggml_repeat_4d (ctx0, k_reshaped , head_k_dim, num_k_heads, repeat_factor, n_seq_tokens * n_seqs);
709+ // Step 1: Flatten the sequence and batch dimensions to work with them more easily
710+ ggml_tensor * q_flat = ggml_reshape_2d (ctx0, q_conv , head_k_dim, num_k_heads * n_seq_tokens * n_seqs);
711+ ggml_tensor * k_flat = ggml_reshape_2d (ctx0, k_conv , head_k_dim, num_k_heads * n_seq_tokens * n_seqs);
714712
715- // Step 3: Reshape back to merge the repeated dimensions
716- q_conv = ggml_reshape_4d (ctx0, q_expanded, head_k_dim, num_k_heads * repeat_factor, n_seq_tokens, n_seqs);
717- k_conv = ggml_reshape_4d (ctx0, k_expanded, head_k_dim, num_k_heads * repeat_factor, n_seq_tokens, n_seqs);
713+ // Step 2: Reshape to prepare for repeat_interleave
714+ // From [head_dim, num_k_heads * n_seq_tokens * n_seqs]
715+ // To [head_dim, num_k_heads, 1, n_seq_tokens * n_seqs]
716+ ggml_tensor * q_reshaped = ggml_reshape_4d (ctx0, q_flat, head_k_dim, num_k_heads, 1 , n_seq_tokens * n_seqs);
717+ ggml_tensor * k_reshaped = ggml_reshape_4d (ctx0, k_flat, head_k_dim, num_k_heads, 1 , n_seq_tokens * n_seqs);
718+
719+ // Step 3: Repeat along the third dimension (the new dimension with size 1)
720+ ggml_tensor * q_repeated = ggml_repeat_4d (ctx0, q_reshaped, head_k_dim, num_k_heads, repeat_factor, n_seq_tokens * n_seqs);
721+ ggml_tensor * k_repeated = ggml_repeat_4d (ctx0, k_reshaped, head_k_dim, num_k_heads, repeat_factor, n_seq_tokens * n_seqs);
722+
723+ // Step 4: Reshape back to merge the head and repeat dimensions
724+ // From [head_dim, num_k_heads, repeat_factor, n_seq_tokens * n_seqs]
725+ // Back to [head_dim, num_k_heads * repeat_factor, n_seq_tokens, n_seqs]
726+ q_conv = ggml_reshape_4d (ctx0, q_repeated, head_k_dim, num_k_heads * repeat_factor, n_seq_tokens, n_seqs);
727+ k_conv = ggml_reshape_4d (ctx0, k_repeated, head_k_dim, num_k_heads * repeat_factor, n_seq_tokens, n_seqs);
718728 }
719729
720730 cb (q_conv, " q_conv_predelta" , il);
0 commit comments