@@ -98,7 +98,7 @@ llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_gr
         cb(attn_post_norm, "attn_post_norm", il);
 
         // FFN layer (MoE or dense) - without residual connection
-        cur = build_layer_ffn(attn_post_norm, model, il, false);
+        cur = build_layer_ffn(attn_post_norm, model, il);
         cb(cur, "ffn_out", il);
 
         // Residual connection for FFN - add to the tensor BEFORE post_attention_layernorm
@@ -120,7 +120,6 @@ llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_gr
     cur = build_lora_mm(model.output, cur);
 
     cb(cur, "result_output", -1);
-    ggml_set_output(cur);
     res->t_logits = cur;
 
     ggml_build_forward_expand(gf, cur);
@@ -511,13 +510,25 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
     // Calculate convolution kernel size
     ggml_tensor * conv_kernel = model.layers[il].ssm_conv1d;
     const int64_t conv_kernel_size = conv_kernel->ne[0];
+    const int64_t conv_channels = d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state;
     conv_kernel = ggml_permute(ctx0, conv_kernel, 0, 2, 1, 3);
-    conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state, n_seqs);
+    conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs);
     cb(conv_states, "conv_states_reshaped", il);
 
     ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0);
     cb(conv_input, "conv_input", il);
 
+    // Update convolution state cache
+    // Extract the last (conv_kernel_size - 1) states from conv_input
+    ggml_tensor * last_conv_states =
+        ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, conv_channels, n_seqs, conv_input->nb[1], conv_input->nb[2],
+                     n_seq_tokens * (conv_input->nb[0]));
+
+    ggml_build_forward_expand(gf,
+        ggml_cpy(ctx0, last_conv_states, ggml_view_1d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels * n_seqs,
+                                                      mctx_cur->get_head() * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all))));
+    cb(conv_states_all, "conv_states_updated", il);
+
     // Apply convolution
     ggml_tensor * conv_output = ggml_conv_1d_dw_f32(ctx0, conv_kernel, conv_input, 1, conv_kernel_size - 1, 1);
     cb(conv_output, "conv_output_raw", il);
@@ -539,19 +550,6 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
     ggml_tensor * conv_output_silu = ggml_silu(ctx0, conv_output_proper);
     cb(conv_output_silu, "conv_output_silu", il);
 
-    // Update convolution state cache
-    // Extract the last (conv_kernel_size - 1) states from conv_input
-    ggml_tensor * last_conv_states =
-        ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, qkv_dim, n_seqs, conv_input->nb[1], conv_input->nb[2],
-                     n_seq_tokens * conv_input->nb[0]);
-
-    ggml_build_forward_expand(gf,
-        ggml_cpy(ctx0, last_conv_states,
-                 ggml_view_1d(ctx0, conv_states_all, (conv_kernel_size - 1) * qkv_dim * n_seqs,
-                              mctx_cur->get_head() * (conv_kernel_size - 1) * qkv_dim *
-                              ggml_element_size(conv_states_all))));
-    cb(conv_states_all, "conv_states_updated", il);
-
     conv_output_proper = ggml_reshape_2d(ctx0, conv_output_silu, n_seq_tokens * n_seqs, qkv_dim);
     cb(conv_output_proper, "conv_output_final", il);
 
@@ -615,12 +613,10 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
     ggml_tensor * state_1d = ggml_view_1d(ctx0, attn_out, state_flat_size, output_flat_size * ggml_element_size(attn_out));
     cb(state_1d, "state_1d", il);
 
-    ggml_tensor * new_state = ggml_reshape_4d(ctx0, state_1d, head_dim, head_dim, n_heads, n_seqs);
-    cb(new_state, "new_state", il);
-
     // Update the recurrent states
-    ggml_build_forward_expand(gf, ggml_view_1d(ctx0, mctx_cur->get_s_l(il), hparams.n_embd_s() * n_seqs,
-                                               hparams.n_embd_s() * mctx_cur->get_head() * ggml_element_size(mctx_cur->get_s_l(il))));
+    ggml_build_forward_expand(gf,
+        ggml_cpy(ctx0, state_1d, ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs,
+                                              hparams.n_embd_s() * mctx_cur->get_head() * ggml_element_size(ssm_states_all))));
 
     // Reshape both attn_out_final and z to 2D tensors for normalization
     // attn_out_final: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]
@@ -648,7 +644,7 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
     return cur;
 }
 
-ggml_tensor * llm_build_qwen3next::build_layer_ffn(ggml_tensor * cur, const llama_model & model, const int il, bool do_residual) {
+ggml_tensor * llm_build_qwen3next::build_layer_ffn(ggml_tensor * cur, const llama_model & model, const int il) {
 
     // Check if this is an MoE layer
     if (model.layers[il].ffn_gate_inp != nullptr) {
@@ -697,15 +693,6 @@ ggml_tensor * llm_build_qwen3next::build_layer_ffn(ggml_tensor * cur, const llam
                 model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, il);
         cb(cur, "ffn_out", il);
     }
-    // Residual connection (only if requested)
-    if (do_residual) {
-        cur = ggml_add(ctx0, cur, cur);
-        cb(cur, "ffn_residual", il);
-    }
-
-    cur = build_cvec(cur, il);
-    cb(cur, "l_out", il);
-
     return cur;
 };
 