@@ -19240,7 +19240,6 @@ struct llm_build_qwen3next : public llm_graph_context_mamba {
         ggml_tensor * mixed_ba = build_lora_mm(model.layers[il].ssm_beta_alpha, cur);
         cb(mixed_ba, "linear_attn_mixed_ba", il);
 
-        // Reshape mixed_qkvz: [batch, seq_len, hidden_size] -> [batch, seq_len, num_k_heads, 2*head_k_dim + 2*head_v_dim*num_v_heads/num_k_heads]
         int64_t qkvz_new_dim = 2 * head_k_dim + 2 * head_v_dim * num_v_heads / num_k_heads;
         ggml_tensor * mixed_qkvz_reshaped =
             ggml_reshape_4d(ctx0, mixed_qkvz, qkvz_new_dim, num_k_heads, n_tokens, n_seqs);
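Note on the reshape above: the fused qkvz projection packs, for each of the `num_k_heads` groups, one Q head, one K head, and `num_v_heads / num_k_heads` V and Z (gate) heads, which is where `qkvz_new_dim` comes from. A standalone sketch of that arithmetic; the head sizes match the published Qwen3-Next-80B-A3B config but are placeholders here, not values read from a model:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
    // Assumed Qwen3-Next head geometry (placeholders, not loaded hparams).
    const int64_t head_k_dim  = 128; // linear_key_head_dim
    const int64_t head_v_dim  = 128; // linear_value_head_dim
    const int64_t num_k_heads = 16;  // linear_num_key_heads
    const int64_t num_v_heads = 32;  // linear_num_value_heads

    // Same expression as the kernel: Q + K (head_k_dim each) plus
    // (num_v_heads / num_k_heads) V heads and as many Z gate heads.
    const int64_t qkvz_new_dim = 2 * head_k_dim + 2 * head_v_dim * num_v_heads / num_k_heads;

    printf("per-group width: %lld\n", (long long) qkvz_new_dim);                 // 768
    printf("full row width : %lld\n", (long long) (qkvz_new_dim * num_k_heads)); // 12288
    return 0;
}
```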
@@ -19327,23 +19326,20 @@ struct llm_build_qwen3next : public llm_graph_context_mamba {
         // Build the convolution states tensor
         ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs);
         cb(conv_states, "conv_states", il);
+
+        // Combine query, key, value for convolution input
+        ggml_tensor * qkv_mixed = ggml_concat(ctx0, query, key, 1);
+        qkv_mixed = ggml_concat(ctx0, qkv_mixed, value_reshaped, 1);
+
+        int64_t qkv_dim = head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads;
 
         // Calculate convolution kernel size
         const int64_t conv_kernel_size = model.layers[il].ssm_conv1d->ne[0];
-
-        // Calculate input dimensions for Qwen3Next
-        const int64_t input_dim = (head_k_dim * num_k_heads * 2) + (head_v_dim * num_v_heads);
-
-        // Reshape conv_states to [conv_kernel_size - 1, input_dim, n_seqs]
-        conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, input_dim, n_seqs);
+        conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state, n_seqs);
         cb(conv_states, "conv_states_reshaped", il);
 
-        // Combine query, key, value for convolution input
-        ggml_tensor * qkv_mixed = ggml_concat(ctx0, query, key, 1);
-        qkv_mixed = ggml_concat(ctx0, qkv_mixed, value_reshaped, 1);
-
         // Reshape to [input_dim, n_seq_tokens, n_seqs] for concatenation
-        qkv_mixed = ggml_reshape_3d(ctx0, qkv_mixed, input_dim , n_seq_tokens, n_seqs);
+        qkv_mixed = ggml_reshape_3d(ctx0, qkv_mixed, qkv_dim , n_seq_tokens, n_seqs);
         cb(qkv_mixed, "qkv_mixed_for_conv", il);
 
         // Concatenate cached conv states with current input
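The `conv_states` reshape now takes its channel count from the GGUF hparams (`d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state`) while the current tokens still use the locally computed `qkv_dim`, so the concatenation that follows only lines up if the two expressions denote the same number of channels. A minimal check of that premise; the mapping of head geometry onto the ssm hparams below is an assumption about how the converter populates these fields, not something this diff states:

```cpp
#include <cassert>
#include <cstdint>
#include <cstdio>

int main() {
    // Assumed head geometry (same placeholders as in the earlier sketch).
    const int64_t head_k_dim  = 128, head_v_dim  = 128;
    const int64_t num_k_heads = 16,  num_v_heads = 32;

    // Assumed hparams mapping: d_inner = V width, n_group = K heads,
    // d_state = K head size. Treat this as a premise to check, not a fact.
    const int64_t d_inner     = head_v_dim * num_v_heads;
    const int64_t ssm_n_group = num_k_heads;
    const int64_t ssm_d_state = head_k_dim;

    const int64_t qkv_dim  = head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads;
    const int64_t conv_dim = d_inner + 2 * ssm_n_group * ssm_d_state;

    // The cached states and the current tokens are joined later (outside
    // this hunk); that only works if both describe the same channel count.
    assert(qkv_dim == conv_dim);
    printf("conv channels: %lld\n", (long long) qkv_dim); // 8192
    return 0;
}
```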
@@ -19367,18 +19363,18 @@ struct llm_build_qwen3next : public llm_graph_context_mamba {
         // Update convolution state cache
         // Extract the last (conv_kernel_size - 1) states from conv_input
         ggml_tensor * last_conv_states =
-            ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, input_dim , n_seqs, conv_input->nb[1],
+            ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, qkv_dim , n_seqs, conv_input->nb[1],
                          conv_input->nb[2], n_seq_tokens * conv_input->nb[0]);
 
         ggml_build_forward_expand(
             gf, ggml_cpy(ctx0, last_conv_states,
-                         ggml_view_1d(ctx0, conv_states_all, (conv_kernel_size - 1) * input_dim * n_seqs,
-                                      mctx_cur->get_head() * (conv_kernel_size - 1) * input_dim *
+                         ggml_view_1d(ctx0, conv_states_all, (conv_kernel_size - 1) * qkv_dim * n_seqs,
+                                      mctx_cur->get_head() * (conv_kernel_size - 1) * qkv_dim *
                                           ggml_element_size(conv_states_all))));
         cb(conv_states_all, "conv_states_updated", il);
 
         // Reshape conv_output back to proper dimensions
-        conv_output = ggml_reshape_4d(ctx0, conv_output, input_dim , n_seqs, n_seq_tokens, 1);
+        conv_output = ggml_reshape_4d(ctx0, conv_output, qkv_dim , n_seqs, n_seq_tokens, 1);
         cb(conv_output, "conv_output_reshaped", il);
         conv_output = ggml_permute(ctx0, conv_output, 0, 2, 1, 3);
         cb(conv_output, "conv_output_final", il);
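The cache update treats `conv_input` as a shift register: of its `conv_kernel_size - 1 + n_seq_tokens` timesteps along dim 0, the last `conv_kernel_size - 1` are copied back into `conv_states_all` at the element offset `mctx_cur->get_head() * (conv_kernel_size - 1) * qkv_dim`, becoming the cached prefix for the next ubatch. A toy model of that update with plain arrays, where `K`, `C`, and `T` stand in for `conv_kernel_size`, `qkv_dim`, and `n_seq_tokens`:

```cpp
#include <cstdio>
#include <cstring>

int main() {
    const int K = 4, C = 3, T = 5; // kernel 4 -> cache holds K - 1 = 3 steps

    // [channels][cached prefix + current tokens], time contiguous per
    // channel, mirroring ggml's dim-0 = time layout for conv_input.
    float conv_input[C][K - 1 + T];
    for (int c = 0; c < C; ++c)
        for (int t = 0; t < K - 1 + T; ++t)
            conv_input[c][t] = 100.0f * c + t;

    // Equivalent of the ggml_view_3d + ggml_cpy pair: take the last K - 1
    // timesteps (offset T, like n_seq_tokens * nb[0]) as the next state.
    float next_state[C][K - 1];
    for (int c = 0; c < C; ++c)
        memcpy(next_state[c], &conv_input[c][T], (K - 1) * sizeof(float));

    for (int c = 0; c < C; ++c)
        printf("ch %d cached: %.0f %.0f %.0f\n", c,
               next_state[c][0], next_state[c][1], next_state[c][2]);
    return 0;
}
```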