@@ -412,7 +412,6 @@ struct ggml_tensor * llm_build_qwen3next::delta_net_recurrent(
     const int64_t S_v = v->ne[0];
     const int64_t H_v = v->ne[1];
 
-    GGML_ASSERT(n_tokens == 1); // Recurrent version only supports sequence_length = 1
     GGML_ASSERT(v->ne[2] == n_tokens);
     GGML_ASSERT(k->ne[2] == n_tokens);
     GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
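Note: the hunk below replaces the single-step update with a loop over tokens, which is why the n_tokens == 1 assertion above can be dropped. For reference, the per-token recurrence the loop builds (a sketch inferred from the code comments and the PyTorch reference they cite; S is the per-head state of shape S_v x S_k, g_t and beta_t are per-head scalars) is:

    S_t = e^{g_t} S_{t-1}, \qquad
    \Delta_t = \beta_t \, (v_t - S_t k_t), \qquad
    S_t \leftarrow S_t + \Delta_t k_t^{\top}, \qquad
    o_t = S_t q_t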
@@ -459,62 +458,85 @@ struct ggml_tensor * llm_build_qwen3next::delta_net_recurrent(
     g = ggml_cont(ctx, ggml_permute(ctx, g, 2, 0, 3, 1));
     cb(g, "g_permute", il);
 
-    ggml_tensor * q_t    = ggml_cont_4d(ctx, q, 1, S_k, H_k, n_seqs);
-    ggml_tensor * k_t    = ggml_cont_4d(ctx, k, 1, S_k, H_k, n_seqs);
-    ggml_tensor * v_t    = ggml_cont_4d(ctx, v, 1, S_v, H_k, n_seqs);
-    ggml_tensor * g_t    = ggml_cont_4d(ctx, g, 1, 1, H_k, n_seqs);
-    ggml_tensor * beta_t = ggml_cont_4d(ctx, beta, 1, 1, H_k, n_seqs);
+    ggml_tensor * q_tokens    = ggml_cont_4d(ctx, q, n_tokens, S_k, H_k, n_seqs);
+    ggml_tensor * k_tokens    = ggml_cont_4d(ctx, k, n_tokens, S_k, H_k, n_seqs);
+    ggml_tensor * v_tokens    = ggml_cont_4d(ctx, v, n_tokens, S_v, H_k, n_seqs);
+    ggml_tensor * g_tokens    = ggml_cont_4d(ctx, g, n_tokens, 1, H_k, n_seqs);
+    ggml_tensor * beta_tokens = ggml_cont_4d(ctx, beta, n_tokens, 1, H_k, n_seqs);
+
     state = ggml_cont_4d(ctx, state, S_v, S_v, H_k, n_seqs);
+    ggml_tensor * g_tokens_exp = ggml_exp(ctx, g_tokens);
+
+    ggml_tensor * final_output = nullptr;
+    ggml_tensor * q_t, * k_t, * v_t, * g_t_exp, * beta_t;
+    for (int i = 0; i < n_tokens; i++) { // this part is per token
+        if (n_tokens == 1) { // don't do unnecessary reshapes / views
+            q_t     = q_tokens;
+            k_t     = k_tokens;
+            v_t     = v_tokens;
+            g_t_exp = g_tokens_exp;
+            beta_t  = beta_tokens;
+        } else {
+            q_t     = ggml_view_4d(ctx, q_tokens, 1, S_k, H_k, n_seqs, q_tokens->nb[1], q_tokens->nb[2], q_tokens->nb[3], i * ggml_element_size(q_tokens));
+            k_t     = ggml_view_4d(ctx, k_tokens, 1, S_k, H_k, n_seqs, k_tokens->nb[1], k_tokens->nb[2], k_tokens->nb[3], i * ggml_element_size(k_tokens));
+            v_t     = ggml_view_4d(ctx, v_tokens, 1, S_v, H_k, n_seqs, v_tokens->nb[1], v_tokens->nb[2], v_tokens->nb[3], i * ggml_element_size(v_tokens));
+            g_t_exp = ggml_view_4d(ctx, g_tokens_exp, 1, 1, H_k, n_seqs, g_tokens_exp->nb[1], g_tokens_exp->nb[2], g_tokens_exp->nb[3], i * ggml_element_size(g_tokens_exp));
+            beta_t  = ggml_view_4d(ctx, beta_tokens, 1, 1, H_k, n_seqs, beta_tokens->nb[1], beta_tokens->nb[2], beta_tokens->nb[3], i * ggml_element_size(beta_tokens));
+        }
 
-    // Apply exponential to gate: exp(g)
-    ggml_tensor * g_exp = ggml_exp(ctx, g_t);
-    cb(g_exp, "g_exp", il);
+        // Apply gate to state: state = state * exp(g)
+        ggml_tensor * gated_state = ggml_mul(ctx, state, g_t_exp);
+        cb(gated_state, "gated_state", il);
 
-    // Apply gate to state: state = state * exp(g)
-    ggml_tensor * gated_state = ggml_mul(ctx, state, g_exp);
-    cb(gated_state, "gated_state", il);
+        // Compute kv_memory from state and key
+        // kv_mem = (state * k.unsqueeze(-1)).sum(dim=-2)
+
+        // Reshape gated_state from [S_v, S_v*H_v, 1, n_seqs] to [S_v, S_v, H_v, n_seqs]
+        // to make it compatible with k_expanded for element-wise multiplication
+        ggml_tensor * gated_state_reshaped = ggml_reshape_4d(ctx, gated_state, S_v, S_v, H_v, n_seqs);
+        cb(gated_state_reshaped, "gated_state_reshaped", il);
+
+        ggml_tensor * state_k_product = ggml_mul(ctx, gated_state_reshaped, k_t);
+        cb(state_k_product, "state_k_product", il);
 
-    // Compute kv_memory from state and key
-    // kv_mem = (state * k.unsqueeze(-1)).sum(dim=-2)
-
-    // Reshape gated_state from [S_v, S_v*H_v, 1, n_seqs] to [S_v, S_v, H_v, n_seqs]
-    // to make it compatible with k_expanded for element-wise multiplication
-    ggml_tensor * gated_state_reshaped = ggml_reshape_4d(ctx, gated_state, S_v, S_v, H_v, n_seqs);
-    cb(gated_state_reshaped, "gated_state_reshaped", il);
-
-    ggml_tensor * state_k_product = ggml_mul(ctx, gated_state_reshaped, k_t);
-    cb(state_k_product, "state_k_product", il);
+        ggml_tensor * kv_memory = ggml_sum_rows(ctx, ggml_cont(ctx, ggml_transpose(ctx, state_k_product)));
+        cb(kv_memory, "kv_memory", il);
 
-    ggml_tensor * kv_memory = ggml_sum_rows(ctx, ggml_cont(ctx, ggml_transpose(ctx, state_k_product)));
-    cb(kv_memory, "kv_memory", il);
+        // Compute delta = (v - kv_memory) * beta
+        ggml_tensor * v_diff = ggml_sub(ctx, v_t, kv_memory);
+        ggml_tensor * delta  = ggml_mul(ctx, v_diff, beta_t);
+        cb(delta, "delta", il);
 
-    // Compute delta = (v - kv_memory) * beta
-    ggml_tensor * v_diff = ggml_sub(ctx, v_t, kv_memory);
-    ggml_tensor * delta = ggml_mul(ctx, v_diff, beta_t);
-    cb(delta, "delta", il);
+        // Update state = state + k * delta
+        // In the reference: last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta.unsqueeze(-2)
+        ggml_tensor * delta_t = ggml_transpose(ctx, delta);
 
-    // Update state = state + k * delta
-    // In the reference: last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta.unsqueeze(-2)
-    ggml_tensor * delta_t = ggml_transpose(ctx, delta);
+        // Will need to broadcast here since GGML doesn't support auto-double-broadcasting on mul
+        ggml_tensor * delta_t_broadcast = ggml_repeat_4d(ctx, delta_t, S_v, S_v, H_v, n_seqs);
+        ggml_tensor * k_t_broadcast     = ggml_repeat_4d(ctx, k_t, S_v, S_v, H_v, n_seqs);
+        ggml_tensor * k_delta_product   = ggml_mul(ctx, k_t_broadcast, delta_t_broadcast);
+        cb(k_delta_product, "k_delta", il);
 
-    // Will need to broadcast here since GGML doesn't support auto-double-broadcasting on mul
-    ggml_tensor * delta_t_broadcast = ggml_repeat_4d(ctx, delta_t, S_v, S_v, H_v, n_seqs);
-    ggml_tensor * k_t_broadcast = ggml_repeat_4d(ctx, k_t, S_v, S_v, H_v, n_seqs);
-    ggml_tensor * k_delta_product = ggml_mul(ctx, k_t_broadcast, delta_t_broadcast);
-    cb(k_delta_product, "k_delta", il);
+        state = ggml_add(ctx, gated_state_reshaped, k_delta_product);
+        cb(state, "updated_state", il);
 
-    ggml_tensor * updated_state = ggml_add(ctx, gated_state_reshaped, k_delta_product);
-    cb(updated_state, "updated_state", il);
-
-    ggml_tensor * state_q_product = ggml_mul(ctx, updated_state, q_t);
-    cb(state_q_product, "state_q_product", il);
-    ggml_tensor * output = ggml_sum_rows(ctx, ggml_cont(ctx, ggml_transpose(ctx, state_q_product)));
-    cb(output, "output", il);
+        ggml_tensor * state_q_product = ggml_mul(ctx, state, q_t);
+        cb(state_q_product, "state_q_product", il);
+
+        ggml_tensor * output = ggml_sum_rows(ctx, ggml_cont(ctx, ggml_transpose(ctx, state_q_product)));
+        cb(output, "output", il);
 
+        if (final_output == nullptr) {
+            final_output = output;
+        } else {
+            final_output = ggml_concat(ctx, final_output, output, 0);
+        }
+    }
+
     // Concatenate output and updated_state into a single tensor
     // First, flatten both tensors to 1D
-    ggml_tensor * output_1d        = ggml_cont_1d(ctx, output, ggml_nelements(output));
-    ggml_tensor * updated_state_1d = ggml_cont_1d(ctx, updated_state, ggml_nelements(updated_state));
+    ggml_tensor * output_1d        = ggml_cont_1d(ctx, final_output, ggml_nelements(final_output));
+    ggml_tensor * updated_state_1d = ggml_cont_1d(ctx, state, ggml_nelements(state));
 
     // Concatenate them: [output, updated_state]
     ggml_tensor * result = ggml_concat(ctx, output_1d, updated_state_1d, 0);
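To make the loop's semantics concrete, here is a minimal CPU-only sketch of one step of the same gated delta rule for a single head and a single sequence, using plain arrays instead of ggml tensors. The helper name delta_net_step and the shapes are illustrative assumptions, not part of this commit:

#include <cmath>
#include <vector>

// One step of the gated delta rule for a single head (illustrative only).
// state is S_v x S_k (value dim x key dim); q, k have length S_k; v, out have length S_v.
static void delta_net_step(std::vector<std::vector<float>> & state,
                           const std::vector<float> & q,
                           const std::vector<float> & k,
                           const std::vector<float> & v,
                           float g, float beta,
                           std::vector<float> & out) {
    const size_t S_v = state.size();
    const size_t S_k = state[0].size();

    // state = state * exp(g): decay the whole per-head state
    const float g_exp = std::exp(g);
    for (auto & row : state) {
        for (auto & x : row) {
            x *= g_exp;
        }
    }

    for (size_t i = 0; i < S_v; i++) {
        // kv_mem = sum_j state[i][j] * k[j]: what the state currently recalls for this key
        float kv_mem = 0.0f;
        for (size_t j = 0; j < S_k; j++) {
            kv_mem += state[i][j] * k[j];
        }
        // delta = (v - kv_mem) * beta, then state += delta (outer product) k
        const float delta = (v[i] - kv_mem) * beta;
        for (size_t j = 0; j < S_k; j++) {
            state[i][j] += delta * k[j];
        }
        // out[i] = sum_j state[i][j] * q[j]: read the updated state with the query
        out[i] = 0.0f;
        for (size_t j = 0; j < S_k; j++) {
            out[i] += state[i][j] * q[j];
        }
    }
}

In the graph code above, the two row-wise dot products map to the ggml_mul + ggml_transpose + ggml_sum_rows chains (kv_memory and output), and the outer product maps to ggml_repeat_4d on delta_t and k_t followed by ggml_mul, since ggml_mul cannot broadcast both operands at once.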