
Commit 232ec56 (parent: aa8d6a2)

    Yes, I finally managed to implement it with ssm_conv :>
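For context on the commit message: ggml_ssm_conv is the fused causal depth-wise convolution that ggml already provides for Mamba-style SSM layers, and the diff below swaps it in for the previous padded ggml_conv_1d_dw_f32 call. For reference, a sketch of the declaration as it appears in ggml.h; the shape comments are my reading of how the Mamba graph feeds it, not something stated in this commit, so double-check against the ggml.h in this tree:

    // sx: conv input, {d_conv - 1 + n_tokens, d_inner, n_seqs}
    //     (the cached tail of the previous window prepended to the new tokens)
    // c : conv kernel, {d_conv, d_inner} -- one d_conv-tap filter per channel
    // returns {d_inner, n_tokens, n_seqs}: exactly one output per token
    GGML_API struct ggml_tensor * ggml_ssm_conv(
            struct ggml_context * ctx,
            struct ggml_tensor  * sx,
            struct ggml_tensor  * c);

That return shape is what lets the diff comment out the "Remove the padding" and "Take only the first n_seq_tokens values" views: the padded depth-wise conv produced extra positions that had to be cropped afterwards, while ggml_ssm_conv emits n_tokens outputs directly.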


src/models/llm_build_qwen3next.cpp (22 additions, 17 deletions)

@@ -501,8 +501,10 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
     // Now concatenate along the feature dimension (dim 0) to get [conv_dim, n_tokens, n_seqs]
     ggml_tensor * qkv_mixed = ggml_concat(ctx0, query_flat, key_flat, 0);
     qkv_mixed = ggml_concat(ctx0, qkv_mixed, value_flat, 0);
+    cb(qkv_mixed, "qkv_mixed", il);
+
     qkv_mixed = ggml_permute(ctx0, qkv_mixed, 1, 0, 2, 3);
-    cb(qkv_mixed, "qkv_mixed_concatenated", il);
+    cb(qkv_mixed, "qkv_mixed_permuted", il);

     // Calculate the total conv dimension
     int64_t qkv_dim = head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads;
@@ -511,7 +513,7 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
     ggml_tensor * conv_kernel = model.layers[il].ssm_conv1d;
     const int64_t conv_kernel_size = conv_kernel->ne[0];
     const int64_t conv_channels = d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state;
-    conv_kernel = ggml_permute(ctx0, conv_kernel, 0, 2, 1, 3);
+    //conv_kernel = ggml_permute(ctx0, conv_kernel, 0, 2, 1, 3);
     conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs);
     cb(conv_states, "conv_states_reshaped", il);

@@ -522,29 +524,32 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
     // Extract the last (conv_kernel_size - 1) states from conv_input
     ggml_tensor * last_conv_states =
         ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, conv_channels, n_seqs, conv_input->nb[1], conv_input->nb[2],
-                     n_seq_tokens * (conv_input->nb[0]));
+                     (conv_input->ne[0] - conv_states->ne[0]) * ggml_element_size(conv_input));
+    cb(last_conv_states, "last_conv_states", il);
+
+    ggml_tensor * state_update_target = ggml_view_1d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels * n_seqs,
+        mctx_cur->get_head() * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all));
+    cb(state_update_target, "state_update_target", il);

-    ggml_build_forward_expand(gf,
-        ggml_cpy(ctx0, last_conv_states, ggml_view_1d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels * n_seqs,
-            mctx_cur->get_head() * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all))));
+    ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target));
     cb(conv_states_all, "conv_states_updated", il);

     // Apply convolution
-    ggml_tensor * conv_output = ggml_conv_1d_dw_f32(ctx0, conv_kernel, conv_input, 1, conv_kernel_size - 1, 1);
-    cb(conv_output, "conv_output_raw", il);
+    ggml_tensor * conv_output_proper = ggml_ssm_conv(ctx0, conv_input, conv_kernel); //ggml_conv_1d_dw_f32(ctx0, conv_kernel, conv_input, 1, conv_kernel_size - 1, 1);
+    cb(conv_output_proper, "conv_output_raw", il);

     // Remove the padding
-    ggml_tensor * conv_output_no_padding = ggml_view_4d(ctx0, conv_output, conv_output->ne[0] - (conv_kernel_size - 1), conv_output->ne[1], conv_output->ne[2], conv_output->ne[3],
-        conv_output->nb[1], conv_output->nb[2], conv_output->nb[3],
-        (conv_kernel_size - 1) * ggml_element_size(conv_output));
-    cb(conv_output_no_padding, "conv_output_no_padding", il);
+    // ggml_tensor * conv_output_no_padding = ggml_view_4d(ctx0, conv_output, conv_output->ne[0] - (conv_kernel_size - 1), conv_output->ne[1], conv_output->ne[2], conv_output->ne[3],
+    //     conv_output->nb[1], conv_output->nb[2], conv_output->nb[3],
+    //     (conv_kernel_size - 1) * ggml_element_size(conv_output));
+    // cb(conv_output_no_padding, "conv_output_no_padding", il);

-    // Take only the first n_seq_tokens values
-    ggml_tensor * conv_output_proper = ggml_view_4d(ctx0, conv_output_no_padding, n_seq_tokens, conv_output_no_padding->ne[1], conv_output_no_padding->ne[2], conv_output_no_padding->ne[3],
-        conv_output_no_padding->nb[1], conv_output_no_padding->nb[2], conv_output_no_padding->nb[3], 0);
-    cb(conv_output_proper, "conv_output_proper", il);
+    // // Take only the first n_seq_tokens values
+    // ggml_tensor * conv_output_proper = ggml_view_4d(ctx0, conv_output_no_padding, n_seq_tokens, conv_output_no_padding->ne[1], conv_output_no_padding->ne[2], conv_output_no_padding->ne[3],
+    //     conv_output_no_padding->nb[1], conv_output_no_padding->nb[2], conv_output_no_padding->nb[3], 0);
+    // cb(conv_output_proper, "conv_output_proper", il);

-    conv_output_proper = ggml_permute(ctx0, conv_output_proper, 0, 1, 3, 2);
+    conv_output_proper = ggml_transpose(ctx0, conv_output_proper);
     conv_output_proper = ggml_cont_4d(ctx0, conv_output_proper, qkv_dim, 1, n_seq_tokens, n_seqs);

     ggml_tensor * conv_output_silu = ggml_silu(ctx0, conv_output_proper);
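Note also the restructured state update: splitting the ggml_view_1d destination out into a named state_update_target and expanding the ggml_cpy into the graph via ggml_build_forward_expand keeps the recurrent cache write in the graph even though nothing downstream consumes its result.

A cheap way to sanity-check the ggml_ssm_conv shape contract is a standalone probe against ggml. This is a sketch, not part of the commit; it assumes a standard ggml checkout and that you link against it (e.g. c++ shape_check.cpp -lggml):

    // shape_check.cpp -- probe the output shape of ggml_ssm_conv
    #include "ggml.h"
    #include <cstdint>
    #include <cstdio>

    int main() {
        // small scratch context; fields are {mem_size, mem_buffer, no_alloc}
        ggml_init_params params = { 16u * 1024 * 1024, nullptr, false };
        ggml_context * ctx = ggml_init(params);

        const int64_t d_conv = 4, channels = 8, n_tokens = 5, n_seqs = 2;

        // conv input: (d_conv - 1) cached states prepended to the new tokens
        ggml_tensor * sx = ggml_new_tensor_3d(ctx, GGML_TYPE_F32,
                                              d_conv - 1 + n_tokens, channels, n_seqs);
        // kernel: one d_conv-tap filter per channel
        ggml_tensor * c = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, d_conv, channels);

        ggml_tensor * out = ggml_ssm_conv(ctx, sx, c);

        // ggml fixes tensor shapes at graph-build time, so no compute is needed here
        std::printf("out: [%lld, %lld, %lld]\n",
                    (long long) out->ne[0], (long long) out->ne[1], (long long) out->ne[2]);
        // expected: [8, 5, 2] == {channels, n_tokens, n_seqs}

        ggml_free(ctx);
        return 0;
    }

If the middle dimension comes out as n_tokens with no (conv_kernel_size - 1) offset left to strip, the manual padding/cropping views this commit comments out are indeed redundant.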
