@@ -382,6 +382,144 @@ struct ggml_tensor * llm_build_qwen3next::delta_net(
     return result;
 }
 
+// delta_net_recurrent
+// Recurrent version of delta_net for sequence_length = 1
+struct ggml_tensor * llm_build_qwen3next::delta_net_recurrent(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * q,
+        struct ggml_tensor  * k,
+        struct ggml_tensor  * v,
+        struct ggml_tensor  * g,
+        struct ggml_tensor  * beta,
+        struct ggml_tensor  * state,
+        bool                  use_qk_l2norm,
+        float                 eps_norm,
+        const int             il) {
+    GGML_ASSERT(ggml_is_contiguous(q));
+    GGML_ASSERT(ggml_is_contiguous(k));
+    GGML_ASSERT(ggml_is_contiguous(v));
+    GGML_ASSERT(ggml_is_contiguous(g));
+    GGML_ASSERT(ggml_is_contiguous(beta));
+    GGML_ASSERT(ggml_is_contiguous(state));
+
+    const int64_t S_k      = q->ne[0];
+    const int64_t H_k      = q->ne[1];
+    const int64_t n_tokens = q->ne[2];
+    const int64_t n_seqs   = q->ne[3];
+
+    const int64_t S_v = v->ne[0];
+    const int64_t H_v = v->ne[1];
+
+    GGML_ASSERT(n_tokens == 1); // Recurrent version only supports sequence_length = 1
+    GGML_ASSERT(v->ne[2] == n_tokens);
+    GGML_ASSERT(k->ne[2] == n_tokens);
+    GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
+    GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs);
+    GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs);
+
+    GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
+    GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
+
+    GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case
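+    // Single-step gated delta rule, applied independently per head
+    // (delta_net() above implements the chunked multi-token version):
+    //   S      = S * exp(g)              (decay the recurrent state)
+    //   kv_mem = S @ k                   (read the memory for this key)
+    //   delta  = (v - kv_mem) * beta     (gated correction term)
+    //   S      = S + delta @ k^T         (rank-1 state update)
+    //   out    = S @ q                   (query the updated state)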
+
+    cb(q, "q_prenorm", il);
+    cb(k, "k_prenorm", il);
+
+    if (use_qk_l2norm) {
+        q = ggml_l2_norm(ctx, q, eps_norm);
+        k = ggml_l2_norm(ctx, k, eps_norm);
+    }
+
+    cb(k, "k_postnorm", il);
+    cb(q, "q_prescale", il);
+
+    float scale = 1.0f / sqrtf(S_v);
+    q = ggml_scale(ctx, q, scale);
+
+    cb(beta, "beta_raw", il);
+    beta = ggml_sigmoid(ctx, beta);
+
+    cb(q, "q_postscale", il);
+    cb(beta, "beta_sigmoid", il);
+
+    // Reshape tensors for recurrent computation
+    // From [S_k, H_k, n_tokens, n_seqs] to [S_k, n_tokens, H_k, n_seqs]
+    q = ggml_cont(ctx, ggml_permute(ctx, q, 0, 2, 1, 3));
+    cb(q, "q_reshape", il);
+    k = ggml_cont(ctx, ggml_permute(ctx, k, 0, 2, 1, 3));
+    cb(k, "k_reshape", il);
+    v = ggml_cont(ctx, ggml_permute(ctx, v, 0, 2, 1, 3));
+    cb(v, "v_reshape", il);
+
+    beta = ggml_cont(ctx, ggml_permute(ctx, beta, 1, 2, 0, 3));
+    cb(beta, "beta_reshape", il);
+
+    g = ggml_cont(ctx, ggml_permute(ctx, g, 2, 0, 3, 1));
+    cb(g, "g_permute", il);
+
+    ggml_tensor * q_t    = ggml_cont_4d(ctx, q, 1, S_k, H_k, n_seqs);
+    ggml_tensor * k_t    = ggml_cont_4d(ctx, k, 1, S_k, H_k, n_seqs);
+    ggml_tensor * v_t    = ggml_cont_4d(ctx, v, 1, S_v, H_k, n_seqs);
+    ggml_tensor * g_t    = ggml_cont_4d(ctx, g, 1, 1, H_k, n_seqs);
+    ggml_tensor * beta_t = ggml_cont_4d(ctx, beta, 1, 1, H_k, n_seqs);
+    state = ggml_cont_4d(ctx, state, S_v, S_v, H_k, n_seqs);
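+    // shapes after the reshapes above (n_tokens == 1):
+    //   q_t, k_t:    [1, S_k, H_k, n_seqs]
+    //   v_t:         [1, S_v, H_v, n_seqs]
+    //   g_t, beta_t: [1, 1,   H_v, n_seqs]  (one scalar per head)
+    //   state:       [S_v, S_v, H_v, n_seqs]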
+
+    // Apply exponential to gate: exp(g)
+    ggml_tensor * g_exp = ggml_exp(ctx, g_t);
+    cb(g_exp, "g_exp", il);
+
+    // Apply gate to state: state = state * exp(g)
+    ggml_tensor * gated_state = ggml_mul(ctx, state, g_exp);
+    cb(gated_state, "gated_state", il);
+
+    // Compute kv_memory from state and key
+    // kv_mem = (state * k.unsqueeze(-1)).sum(dim=-2)
+
+    // state was already reshaped to [S_v, S_v, H_v, n_seqs] above; keep the shape explicit
+    // so that gated_state lines up with k_t for the element-wise multiplication
+    ggml_tensor * gated_state_reshaped = ggml_reshape_4d(ctx, gated_state, S_v, S_v, H_v, n_seqs);
+    cb(gated_state_reshaped, "gated_state_reshaped", il);
+
+    ggml_tensor * state_k_product = ggml_mul(ctx, gated_state_reshaped, k_t);
+    cb(state_k_product, "state_k_product", il);
+
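+    // ggml_sum_rows() reduces along dim 0, so transpose first to sum over the key
+    // dimension: kv_memory[i] = sum_j state[i][j] * k[j]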
+    ggml_tensor * kv_memory = ggml_sum_rows(ctx, ggml_cont(ctx, ggml_transpose(ctx, state_k_product)));
+    cb(kv_memory, "kv_memory", il);
+
+    // Compute delta = (v - kv_memory) * beta
+    ggml_tensor * v_diff = ggml_sub(ctx, v_t, kv_memory);
+    ggml_tensor * delta  = ggml_mul(ctx, v_diff, beta_t);
+    cb(delta, "delta", il);
+
+    // Update state = state + k * delta
+    // In the reference: last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta.unsqueeze(-2)
+    ggml_tensor * delta_t = ggml_transpose(ctx, delta);
+
+    // ggml_mul() can only broadcast its second operand, and here both operands need
+    // broadcasting, so expand k and delta explicitly to the full [S_v, S_v, H_v, n_seqs] shape
+    ggml_tensor * delta_t_broadcast = ggml_repeat_4d(ctx, delta_t, S_v, S_v, H_v, n_seqs);
+    ggml_tensor * k_t_broadcast     = ggml_repeat_4d(ctx, k_t, S_v, S_v, H_v, n_seqs);
+    ggml_tensor * k_delta_product   = ggml_mul(ctx, k_t_broadcast, delta_t_broadcast);
+    cb(k_delta_product, "k_delta", il);
+
+    ggml_tensor * updated_state = ggml_add(ctx, gated_state_reshaped, k_delta_product);
+    cb(updated_state, "updated_state", il);
+
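+    // query the updated state per head: output[i] = sum_j updated_state[i][j] * q[j]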
+    ggml_tensor * state_q_product = ggml_mul(ctx, updated_state, q_t);
+    cb(state_q_product, "state_q_product", il);
+    ggml_tensor * output = ggml_sum_rows(ctx, ggml_cont(ctx, ggml_transpose(ctx, state_q_product)));
+    cb(output, "output", il);
+
+    // Concatenate output and updated_state into a single tensor
+    // First, flatten both tensors to 1D
+    ggml_tensor * output_1d        = ggml_cont_1d(ctx, output, ggml_nelements(output));
+    ggml_tensor * updated_state_1d = ggml_cont_1d(ctx, updated_state, ggml_nelements(updated_state));
+
+    // Concatenate them: [output, updated_state]
+    ggml_tensor * result = ggml_concat(ctx, output_1d, updated_state_1d, 0);
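+    // the first ggml_nelements(output) values are the attention output, the rest is the
+    // updated recurrent state; the caller splits them back apart with 1d views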
+    return result;
+}
+
 
 ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_input_rs * inp,
                                                                      ggml_tensor * cur,
@@ -402,6 +540,8 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
 
     const int64_t n_seq_tokens = ubatch.n_seq_tokens;
 
+    const auto kv_head = mctx_cur->get_head();
+
     GGML_ASSERT(n_seqs != 0);
     GGML_ASSERT(ubatch.equal_seqs());
     GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
@@ -494,6 +634,8 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
     ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
     ggml_tensor * ssm_states_all  = mctx_cur->get_s_l(il);
 
+    bool is_generation = mctx_cur->get_rs_z() < 0;
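+    // used below to pick the single-token recurrent delta_net path instead of the chunked one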
+
     // Build the convolution states tensor
     ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs);
     cb(conv_states, "conv_states", il);
@@ -528,7 +670,7 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
     cb(last_conv_states, "last_conv_states", il);
 
     ggml_tensor * state_update_target = ggml_view_1d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels * n_seqs,
-        mctx_cur->get_head() * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all));
+        kv_head * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all));
     cb(state_update_target, "state_update_target", il);
 
     ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target));
@@ -584,6 +726,7 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
 
     ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs);
     state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim * num_v_heads, 1, n_seqs);
+    cb(state, "state_predelta", il);
 
     // if the number of key heads and value heads differs, repeat the key tensors to force matching shapes
     if (num_k_heads != num_v_heads) {
@@ -598,8 +741,15 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
     cb(k_conv, "k_conv_predelta", il);
     cb(v_conv, "v_conv_predelta", il);
 
-    // Call the new delta_net function with the corrected flow
-    ggml_tensor * attn_out = delta_net(ctx0, q_conv, k_conv, v_conv, gate, beta, state, true, hparams.f_norm_rms_eps, il);
+    // Choose between delta_net and delta_net_recurrent based on the generation mode
+    ggml_tensor * attn_out;
+    if (is_generation) {
+        // Use delta_net_recurrent for single-token generation
+        attn_out = delta_net_recurrent(ctx0, q_conv, k_conv, v_conv, gate, beta, state, true, hparams.f_norm_rms_eps, il);
+    } else {
+        // Use the regular delta_net for prompt processing
+        attn_out = delta_net(ctx0, q_conv, k_conv, v_conv, gate, beta, state, true, hparams.f_norm_rms_eps, il);
+    }
     cb(attn_out, "attn_out", il);
 
     // The tensors were concatenated 1d, so we need to extract them 1d as well
@@ -621,7 +771,9 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
     // Update the recurrent states
     ggml_build_forward_expand(gf,
         ggml_cpy(ctx0, state_1d, ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs,
-            hparams.n_embd_s() * mctx_cur->get_head() * ggml_element_size(ssm_states_all))));
+            kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all))));
+
+    GGML_ASSERT(ggml_nelements(attn_out_1d) + ggml_nelements(state_1d) == ggml_nelements(attn_out));
 
     // Reshape both attn_out_final and z to 2D tensors for normalization
     // attn_out_final: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]