Skip to content

Commit 5728642

Browse files
committed
Handle case with more than one token per seq with elegant loop plus completely not crazy change to max nodes ;)
1 parent c2a82a1 commit 5728642

File tree

2 files changed: +68 additions, −46 deletions

src/llama-context.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1362,7 +1362,7 @@ void llama_context::output_reorder() {
13621362
//
13631363

13641364
uint32_t llama_context::graph_max_nodes() const {
1365-
return std::max<uint32_t>(1024u, 32u*model.n_tensors());
1365+
return std::max<uint32_t>(16384, 512u*model.n_tensors());
13661366
}
13671367

13681368
llm_graph_result * llama_context::get_gf_res_reserve() const {

src/models/llm_build_qwen3next.cpp

Lines changed: 67 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -412,7 +412,6 @@ struct ggml_tensor * llm_build_qwen3next::delta_net_recurrent(
412412
const int64_t S_v = v->ne[0];
413413
const int64_t H_v = v->ne[1];
414414

415-
GGML_ASSERT(n_tokens == 1); // Recurrent version only supports sequence_length = 1
416415
GGML_ASSERT(v->ne[2] == n_tokens);
417416
GGML_ASSERT(k->ne[2] == n_tokens);
418417
GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
@@ -459,62 +458,85 @@ struct ggml_tensor * llm_build_qwen3next::delta_net_recurrent(
459458
g = ggml_cont(ctx, ggml_permute(ctx, g, 2, 0, 3, 1));
460459
cb(g, "g_permute", il);
461460

462-
ggml_tensor * q_t = ggml_cont_4d(ctx, q, 1, S_k, H_k, n_seqs);
463-
ggml_tensor * k_t = ggml_cont_4d(ctx, k, 1, S_k, H_k, n_seqs);
464-
ggml_tensor * v_t = ggml_cont_4d(ctx, v, 1, S_v, H_k, n_seqs);
465-
ggml_tensor * g_t = ggml_cont_4d(ctx, g, 1, 1, H_k, n_seqs);
466-
ggml_tensor * beta_t = ggml_cont_4d(ctx, beta, 1, 1, H_k, n_seqs);
461+
ggml_tensor * q_tokens = ggml_cont_4d(ctx, q, n_tokens, S_k, H_k, n_seqs);
462+
ggml_tensor * k_tokens = ggml_cont_4d(ctx, k, n_tokens, S_k, H_k, n_seqs);
463+
ggml_tensor * v_tokens = ggml_cont_4d(ctx, v, n_tokens, S_v, H_k, n_seqs);
464+
ggml_tensor * g_tokens = ggml_cont_4d(ctx, g, n_tokens, 1, H_k, n_seqs);
465+
ggml_tensor * beta_tokens = ggml_cont_4d(ctx, beta, n_tokens, 1, H_k, n_seqs);
466+
467467
state = ggml_cont_4d(ctx, state, S_v, S_v, H_k, n_seqs);
468+
ggml_tensor * g_tokens_exp = ggml_exp(ctx, g_tokens);
469+
470+
ggml_tensor * final_output = nullptr;
471+
ggml_tensor * q_t, * k_t, * v_t, * g_t_exp, * beta_t;
472+
for (int i = 0; i < n_tokens; i++) { // this part is per token
473+
if (n_tokens == 1) { // don't do unnecessary reshapes / views
474+
q_t = q_tokens;
475+
k_t = k_tokens;
476+
v_t = v_tokens;
477+
g_t_exp = g_tokens_exp;
478+
beta_t = beta_tokens;
479+
} else {
480+
q_t = ggml_view_4d(ctx, q_tokens, 1, S_k, H_k, n_seqs, q_tokens->nb[1], q_tokens->nb[2], q_tokens->nb[3], i * ggml_element_size(q_tokens));
481+
k_t = ggml_view_4d(ctx, k_tokens, 1, S_k, H_k, n_seqs, k_tokens->nb[1], k_tokens->nb[2], k_tokens->nb[3], i * ggml_element_size(k_tokens));
482+
v_t = ggml_view_4d(ctx, v_tokens, 1, S_v, H_k, n_seqs, v_tokens->nb[1], v_tokens->nb[2], v_tokens->nb[3], i * ggml_element_size(v_tokens));
483+
g_t_exp = ggml_view_4d(ctx, g_tokens_exp, 1, 1, H_k, n_seqs, g_tokens_exp->nb[1], g_tokens_exp->nb[2], g_tokens_exp->nb[3], i * ggml_element_size(g_tokens_exp));
484+
beta_t = ggml_view_4d(ctx, beta_tokens, 1, 1, H_k, n_seqs, beta_tokens->nb[1], beta_tokens->nb[2], beta_tokens->nb[3], i * ggml_element_size(beta_tokens));
485+
}
468486

469-
// Apply exponential to gate: exp(g)
470-
ggml_tensor * g_exp = ggml_exp(ctx, g_t);
471-
cb(g_exp, "g_exp", il);
487+
// Apply gate to state: state = state * exp(g)
488+
ggml_tensor * gated_state = ggml_mul(ctx, state, g_t_exp);
489+
cb(gated_state, "gated_state", il);
472490

473-
// Apply gate to state: state = state * exp(g)
474-
ggml_tensor * gated_state = ggml_mul(ctx, state, g_exp);
475-
cb(gated_state, "gated_state", il);
491+
// Compute kv_memory from state and key
492+
// kv_mem = (state * k.unsqueeze(-1)).sum(dim=-2)
493+
494+
// Reshape gated_state from [S_v, S_v*H_v, 1, n_seqs] to [S_v, S_v, H_v, n_seqs]
495+
// to make it compatible with k_expanded for element-wise multiplication
496+
ggml_tensor * gated_state_reshaped = ggml_reshape_4d(ctx, gated_state, S_v, S_v, H_v, n_seqs);
497+
cb(gated_state_reshaped, "gated_state_reshaped", il);
498+
499+
ggml_tensor * state_k_product = ggml_mul(ctx, gated_state_reshaped, k_t);
500+
cb(state_k_product, "state_k_product", il);
476501

477-
// Compute kv_memory from state and key
478-
// kv_mem = (state * k.unsqueeze(-1)).sum(dim=-2)
479-
480-
// Reshape gated_state from [S_v, S_v*H_v, 1, n_seqs] to [S_v, S_v, H_v, n_seqs]
481-
// to make it compatible with k_expanded for element-wise multiplication
482-
ggml_tensor * gated_state_reshaped = ggml_reshape_4d(ctx, gated_state, S_v, S_v, H_v, n_seqs);
483-
cb(gated_state_reshaped, "gated_state_reshaped", il);
484-
485-
ggml_tensor * state_k_product = ggml_mul(ctx, gated_state_reshaped, k_t);
486-
cb(state_k_product, "state_k_product", il);
502+
ggml_tensor * kv_memory = ggml_sum_rows(ctx, ggml_cont(ctx, ggml_transpose(ctx, state_k_product)));
503+
cb(kv_memory, "kv_memory", il);
487504

488-
ggml_tensor * kv_memory = ggml_sum_rows(ctx, ggml_cont(ctx, ggml_transpose(ctx, state_k_product)));
489-
cb(kv_memory, "kv_memory", il);
505+
// Compute delta = (v - kv_memory) * beta
506+
ggml_tensor * v_diff = ggml_sub(ctx, v_t, kv_memory);
507+
ggml_tensor * delta = ggml_mul(ctx, v_diff, beta_t);
508+
cb(delta, "delta", il);
490509

491-
// Compute delta = (v - kv_memory) * beta
492-
ggml_tensor * v_diff = ggml_sub(ctx, v_t, kv_memory);
493-
ggml_tensor * delta = ggml_mul(ctx, v_diff, beta_t);
494-
cb(delta, "delta", il);
510+
// Update state = state + k * delta
511+
// In the reference: last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta.unsqueeze(-2)
512+
ggml_tensor * delta_t = ggml_transpose(ctx, delta);
495513

496-
// Update state = state + k * delta
497-
// In the reference: last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta.unsqueeze(-2)
498-
ggml_tensor * delta_t = ggml_transpose(ctx, delta);
514+
// Will need to broadcast here since GGML doesn't support auto-double-broadcasting on mul
515+
ggml_tensor * delta_t_broadcast = ggml_repeat_4d(ctx, delta_t, S_v, S_v, H_v, n_seqs);
516+
ggml_tensor * k_t_broadcast = ggml_repeat_4d(ctx, k_t, S_v, S_v, H_v, n_seqs);
517+
ggml_tensor * k_delta_product = ggml_mul(ctx, k_t_broadcast, delta_t_broadcast);
518+
cb(k_delta_product, "k_delta", il);
499519

500-
// Will need to broadcast here since GGML doesn't support auto-double-broadcasting on mul
501-
ggml_tensor * delta_t_broadcast = ggml_repeat_4d(ctx, delta_t, S_v, S_v, H_v, n_seqs);
502-
ggml_tensor * k_t_broadcast = ggml_repeat_4d(ctx, k_t, S_v, S_v, H_v, n_seqs);
503-
ggml_tensor * k_delta_product = ggml_mul(ctx, k_t_broadcast, delta_t_broadcast);
504-
cb(k_delta_product, "k_delta", il);
520+
state = ggml_add(ctx, gated_state_reshaped, k_delta_product);
521+
cb(state, "updated_state", il);
505522

506-
ggml_tensor * updated_state = ggml_add(ctx, gated_state_reshaped, k_delta_product);
507-
cb(updated_state, "updated_state", il);
508-
509-
ggml_tensor * state_q_product = ggml_mul(ctx, updated_state, q_t);
510-
cb(state_q_product, "state_q_product", il);
511-
ggml_tensor * output = ggml_sum_rows(ctx, ggml_cont(ctx, ggml_transpose(ctx, state_q_product)));
512-
cb(output, "output", il);
523+
ggml_tensor * state_q_product = ggml_mul(ctx, state, q_t);
524+
cb(state_q_product, "state_q_product", il);
525+
526+
ggml_tensor * output = ggml_sum_rows(ctx, ggml_cont(ctx, ggml_transpose(ctx, state_q_product)));
527+
cb(output, "output", il);
513528

529+
if (final_output == nullptr) {
530+
final_output = output;
531+
} else {
532+
final_output = ggml_concat(ctx, final_output, output, 0);
533+
}
534+
}
535+
514536
// Concatenate output and updated_state into a single tensor
515537
// First, flatten both tensors to 1D
516-
ggml_tensor * output_1d = ggml_cont_1d(ctx, output, ggml_nelements(output));
517-
ggml_tensor * updated_state_1d = ggml_cont_1d(ctx, updated_state, ggml_nelements(updated_state));
538+
ggml_tensor * output_1d = ggml_cont_1d(ctx, final_output, ggml_nelements(final_output));
539+
ggml_tensor * updated_state_1d = ggml_cont_1d(ctx, state, ggml_nelements(state));
518540

519541
// Concatenate them: [output, updated_state]
520542
ggml_tensor * result = ggml_concat(ctx, output_1d, updated_state_1d, 0);

0 commit comments

Comments
 (0)