@@ -501,8 +501,10 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
501501 // Now concatenate along the feature dimension (dim 0) to get [conv_dim, n_tokens, n_seqs]
502502 ggml_tensor * qkv_mixed = ggml_concat (ctx0, query_flat, key_flat, 0 );
503503 qkv_mixed = ggml_concat (ctx0, qkv_mixed, value_flat, 0 );
504+ cb (qkv_mixed, " qkv_mixed" , il);
505+
504506 qkv_mixed = ggml_permute (ctx0, qkv_mixed, 1 , 0 , 2 , 3 );
505- cb (qkv_mixed, " qkv_mixed_concatenated " , il);
507+ cb (qkv_mixed, " qkv_mixed_permuted " , il);
506508
507509 // Calculate the total conv dimension
508510 int64_t qkv_dim = head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads;
@@ -511,7 +513,7 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
511513 ggml_tensor * conv_kernel = model.layers [il].ssm_conv1d ;
512514 const int64_t conv_kernel_size = conv_kernel->ne [0 ];
513515 const int64_t conv_channels = d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state ;
514- conv_kernel = ggml_permute (ctx0, conv_kernel, 0 , 2 , 1 , 3 );
516+ // conv_kernel = ggml_permute(ctx0, conv_kernel, 0, 2, 1, 3);
515517 conv_states = ggml_reshape_3d (ctx0, conv_states, conv_kernel_size - 1 , conv_channels, n_seqs);
516518 cb (conv_states, " conv_states_reshaped" , il);
517519
@@ -522,29 +524,32 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
522524 // Extract the last (conv_kernel_size - 1) states from conv_input
523525 ggml_tensor * last_conv_states =
524526 ggml_view_3d (ctx0, conv_input, conv_kernel_size - 1 , conv_channels, n_seqs, conv_input->nb [1 ], conv_input->nb [2 ],
525- n_seq_tokens * (conv_input->nb [0 ]));
527+ (conv_input->ne [0 ] - conv_states->ne [0 ]) * ggml_element_size (conv_input));
528+ cb (last_conv_states, " last_conv_states" , il);
529+
530+ ggml_tensor * state_update_target = ggml_view_1d (ctx0, conv_states_all, (conv_kernel_size - 1 ) * conv_channels * n_seqs,
531+ mctx_cur->get_head () * (conv_kernel_size - 1 ) * conv_channels * ggml_element_size (conv_states_all));
532+ cb (state_update_target, " state_update_target" , il);
526533
527- ggml_build_forward_expand (gf,
528- ggml_cpy (ctx0, last_conv_states, ggml_view_1d (ctx0, conv_states_all, (conv_kernel_size - 1 ) * conv_channels * n_seqs,
529- mctx_cur->get_head () * (conv_kernel_size - 1 ) * conv_channels * ggml_element_size (conv_states_all))));
534+ ggml_build_forward_expand (gf, ggml_cpy (ctx0, last_conv_states, state_update_target));
530535 cb (conv_states_all, " conv_states_updated" , il);
531536
532537 // Apply convolution
533- ggml_tensor * conv_output = ggml_conv_1d_dw_f32 (ctx0, conv_kernel, conv_input, 1 , conv_kernel_size - 1 , 1 );
534- cb (conv_output , " conv_output_raw" , il);
538+ ggml_tensor * conv_output_proper = ggml_ssm_conv (ctx0, conv_input, conv_kernel); // ggml_conv_1d_dw_f32(ctx0, conv_kernel, conv_input, 1, conv_kernel_size - 1, 1);
539+ cb (conv_output_proper , " conv_output_raw" , il);
535540
536541 // Remove the padding
537- ggml_tensor * conv_output_no_padding = ggml_view_4d (ctx0, conv_output, conv_output->ne [0 ] - (conv_kernel_size - 1 ), conv_output->ne [1 ], conv_output->ne [2 ], conv_output->ne [3 ],
538- conv_output->nb [1 ], conv_output->nb [2 ], conv_output->nb [3 ],
539- (conv_kernel_size - 1 ) * ggml_element_size (conv_output));
540- cb (conv_output_no_padding, " conv_output_no_padding" , il);
542+ // ggml_tensor * conv_output_no_padding = ggml_view_4d(ctx0, conv_output, conv_output->ne[0] - (conv_kernel_size - 1), conv_output->ne[1], conv_output->ne[2], conv_output->ne[3],
543+ // conv_output->nb[1], conv_output->nb[2], conv_output->nb[3],
544+ // (conv_kernel_size - 1) * ggml_element_size(conv_output));
545+ // cb(conv_output_no_padding, "conv_output_no_padding", il);
541546
542- // Take only the first n_seq_tokens values
543- ggml_tensor * conv_output_proper = ggml_view_4d (ctx0, conv_output_no_padding, n_seq_tokens, conv_output_no_padding->ne [1 ], conv_output_no_padding->ne [2 ], conv_output_no_padding->ne [3 ],
544- conv_output_no_padding->nb [1 ], conv_output_no_padding->nb [2 ], conv_output_no_padding->nb [3 ], 0 );
545- cb (conv_output_proper, " conv_output_proper" , il);
547+ // // Take only the first n_seq_tokens values
548+ // ggml_tensor * conv_output_proper = ggml_view_4d(ctx0, conv_output_no_padding, n_seq_tokens, conv_output_no_padding->ne[1], conv_output_no_padding->ne[2], conv_output_no_padding->ne[3],
549+ // conv_output_no_padding->nb[1], conv_output_no_padding->nb[2], conv_output_no_padding->nb[3], 0);
550+ // cb(conv_output_proper, "conv_output_proper", il);
546551
547- conv_output_proper = ggml_permute (ctx0, conv_output_proper, 0 , 1 , 3 , 2 );
552+ conv_output_proper = ggml_transpose (ctx0, conv_output_proper);
548553 conv_output_proper = ggml_cont_4d (ctx0, conv_output_proper, qkv_dim, 1 , n_seq_tokens, n_seqs);
549554
550555 ggml_tensor * conv_output_silu = ggml_silu (ctx0, conv_output_proper);
0 commit comments