Skip to content

Commit 5306640

Browse files
committed
All's well that ends in a well
1 parent 232ec56 commit 5306640

File tree

3 files changed

+174
-8
lines changed

3 files changed

+174
-8
lines changed

src/models/llm_build_qwen3next.cpp

Lines changed: 156 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -382,6 +382,144 @@ struct ggml_tensor * llm_build_qwen3next::delta_net(
382382
return result;
383383
}
384384

385+
// delta_net_recurrent
386+
// Recurrent version of delta_net for sequence_length = 1
387+
struct ggml_tensor * llm_build_qwen3next::delta_net_recurrent(
388+
struct ggml_context * ctx,
389+
struct ggml_tensor * q,
390+
struct ggml_tensor * k,
391+
struct ggml_tensor * v,
392+
struct ggml_tensor * g,
393+
struct ggml_tensor * beta,
394+
struct ggml_tensor * state,
395+
bool use_qk_l2norm,
396+
float eps_norm,
397+
const int il
398+
) {
399+
GGML_ASSERT(ggml_is_contiguous(q));
400+
GGML_ASSERT(ggml_is_contiguous(k));
401+
GGML_ASSERT(ggml_is_contiguous(v));
402+
GGML_ASSERT(ggml_is_contiguous(g));
403+
GGML_ASSERT(ggml_is_contiguous(beta));
404+
GGML_ASSERT(ggml_is_contiguous(state));
405+
406+
const int64_t S_k = q->ne[0];
407+
const int64_t H_k = q->ne[1];
408+
const int64_t n_tokens = q->ne[2];
409+
const int64_t n_seqs = q->ne[3];
410+
411+
const int64_t S_v = v->ne[0];
412+
const int64_t H_v = v->ne[1];
413+
414+
GGML_ASSERT(n_tokens == 1); // Recurrent version only supports sequence_length = 1
415+
GGML_ASSERT(v->ne[2] == n_tokens);
416+
GGML_ASSERT(k->ne[2] == n_tokens);
417+
GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
418+
GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs);
419+
GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs);
420+
421+
GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
422+
GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && q->ne[3] == n_seqs);
423+
424+
GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case
425+
426+
cb(q, "q_prenorm", il);
427+
cb(k, "k_prenorm", il);
428+
429+
if (use_qk_l2norm) {
430+
q = ggml_l2_norm(ctx, q, eps_norm);
431+
k = ggml_l2_norm(ctx, k, eps_norm);
432+
}
433+
434+
cb(k, "k_postnorm", il);
435+
cb(q, "q_prescale", il);
436+
437+
float scale = 1.0f / sqrtf(S_v);
438+
q = ggml_scale(ctx, q, scale);
439+
440+
cb(beta, "beta_raw", il);
441+
beta = ggml_sigmoid(ctx, beta);
442+
443+
cb(q, "q_postscale", il);
444+
cb(beta, "beta_sigmoid", il);
445+
446+
// Reshape tensors for recurrent computation
447+
// From [S_k, H_k, n_tokens, n_seqs] to [S_k, n_tokens, H_k, n_seqs]
448+
q = ggml_cont(ctx, ggml_permute(ctx, q, 0, 2, 1, 3));
449+
cb(q, "q_reshape", il);
450+
k = ggml_cont(ctx, ggml_permute(ctx, k, 0, 2, 1, 3));
451+
cb(k, "k_reshape", il);
452+
v = ggml_cont(ctx, ggml_permute(ctx, v, 0, 2, 1, 3));
453+
cb(v, "v_reshape", il);
454+
455+
beta = ggml_cont(ctx, ggml_permute(ctx, beta, 1, 2, 0, 3));
456+
cb(beta, "beta_reshape", il);
457+
458+
g = ggml_cont(ctx, ggml_permute(ctx, g, 2, 0, 3, 1));
459+
cb(g, "g_permute", il);
460+
461+
ggml_tensor * q_t = ggml_cont_4d(ctx, q, 1, S_k, H_k, n_seqs);
462+
ggml_tensor * k_t = ggml_cont_4d(ctx, k, 1, S_k, H_k, n_seqs);
463+
ggml_tensor * v_t = ggml_cont_4d(ctx, v, 1, S_v, H_k, n_seqs);
464+
ggml_tensor * g_t = ggml_cont_4d(ctx, g, 1, 1, H_k, n_seqs);
465+
ggml_tensor * beta_t = ggml_cont_4d(ctx, beta, 1, 1, H_k, n_seqs);
466+
state = ggml_cont_4d(ctx, state, S_v, S_v, H_k, n_seqs);
467+
468+
// Apply exponential to gate: exp(g)
469+
ggml_tensor * g_exp = ggml_exp(ctx, g_t);
470+
cb(g_exp, "g_exp", il);
471+
472+
// Apply gate to state: state = state * exp(g)
473+
ggml_tensor * gated_state = ggml_mul(ctx, state, g_exp);
474+
cb(gated_state, "gated_state", il);
475+
476+
// Compute kv_memory from state and key
477+
// kv_mem = (state * k.unsqueeze(-1)).sum(dim=-2)
478+
479+
// Reshape gated_state from [S_v, S_v*H_v, 1, n_seqs] to [S_v, S_v, H_v, n_seqs]
480+
// to make it compatible with k_expanded for element-wise multiplication
481+
ggml_tensor * gated_state_reshaped = ggml_reshape_4d(ctx, gated_state, S_v, S_v, H_v, n_seqs);
482+
cb(gated_state_reshaped, "gated_state_reshaped", il);
483+
484+
ggml_tensor * state_k_product = ggml_mul(ctx, gated_state_reshaped, k_t);
485+
cb(state_k_product, "state_k_product", il);
486+
487+
ggml_tensor * kv_memory = ggml_sum_rows(ctx, ggml_cont(ctx, ggml_transpose(ctx, state_k_product)));
488+
cb(kv_memory, "kv_memory", il);
489+
490+
// Compute delta = (v - kv_memory) * beta
491+
ggml_tensor * v_diff = ggml_sub(ctx, v_t, kv_memory);
492+
ggml_tensor * delta = ggml_mul(ctx, v_diff, beta_t);
493+
cb(delta, "delta", il);
494+
495+
// Update state = state + k * delta
496+
// In the reference: last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta.unsqueeze(-2)
497+
ggml_tensor * delta_t = ggml_transpose(ctx, delta);
498+
499+
// Will need to broadcast here since GGML doesn't support auto-double-broadcasting on mul
500+
ggml_tensor * delta_t_broadcast = ggml_repeat_4d(ctx, delta_t, S_v, S_v, H_v, n_seqs);
501+
ggml_tensor * k_t_broadcast = ggml_repeat_4d(ctx, k_t, S_v, S_v, H_v, n_seqs);
502+
ggml_tensor * k_delta_product = ggml_mul(ctx, k_t_broadcast, delta_t_broadcast);
503+
cb(k_delta_product, "k_delta", il);
504+
505+
ggml_tensor * updated_state = ggml_add(ctx, gated_state_reshaped, k_delta_product);
506+
cb(updated_state, "updated_state", il);
507+
508+
ggml_tensor * state_q_product = ggml_mul(ctx, updated_state, q_t);
509+
cb(state_q_product, "state_q_product", il);
510+
ggml_tensor * output = ggml_sum_rows(ctx, ggml_cont(ctx, ggml_transpose(ctx, state_q_product)));
511+
cb(output, "output", il);
512+
513+
// Concatenate output and updated_state into a single tensor
514+
// First, flatten both tensors to 1D
515+
ggml_tensor * output_1d = ggml_cont_1d(ctx, output, ggml_nelements(output));
516+
ggml_tensor * updated_state_1d = ggml_cont_1d(ctx, updated_state, ggml_nelements(updated_state));
517+
518+
// Concatenate them: [output, updated_state]
519+
ggml_tensor * result = ggml_concat(ctx, output_1d, updated_state_1d, 0);
520+
return result;
521+
}
522+
385523

386524
ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_input_rs * inp,
387525
ggml_tensor * cur,
@@ -402,6 +540,8 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
402540

403541
const int64_t n_seq_tokens = ubatch.n_seq_tokens;
404542

543+
const auto kv_head = mctx_cur->get_head();
544+
405545
GGML_ASSERT(n_seqs != 0);
406546
GGML_ASSERT(ubatch.equal_seqs());
407547
GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
@@ -494,6 +634,8 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
494634
ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
495635
ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il);
496636

637+
bool is_generation = mctx_cur->get_rs_z() < 0;
638+
497639
// Build the convolution states tensor
498640
ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs);
499641
cb(conv_states, "conv_states", il);
@@ -528,7 +670,7 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
528670
cb(last_conv_states, "last_conv_states", il);
529671

530672
ggml_tensor * state_update_target = ggml_view_1d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels * n_seqs,
531-
mctx_cur->get_head() * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all));
673+
kv_head * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all));
532674
cb(state_update_target, "state_update_target", il);
533675

534676
ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target));
@@ -584,6 +726,7 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
584726

585727
ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs);
586728
state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim * num_v_heads, 1, n_seqs);
729+
cb(state, "state_predelta", il);
587730

588731
// if head keys and value keys are different, repeat to force tensors into matching shapes
589732
if (num_k_heads != num_v_heads) {
@@ -598,8 +741,15 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
598741
cb(k_conv, "k_conv_predelta", il);
599742
cb(v_conv, "v_conv_predelta", il);
600743

601-
// Call the new delta_net function with the corrected flow
602-
ggml_tensor * attn_out = delta_net(ctx0, q_conv, k_conv, v_conv, gate, beta, state, true, hparams.f_norm_rms_eps, il);
744+
// Choose between delta_net and delta_net_recurrent based on generation mode
745+
ggml_tensor * attn_out;
746+
if (is_generation) {
747+
// Use delta_net_recurrent for single token generation
748+
attn_out = delta_net_recurrent(ctx0, q_conv, k_conv, v_conv, gate, beta, state, true, hparams.f_norm_rms_eps, il);
749+
} else {
750+
// Use regular delta_net for prompt processing
751+
attn_out = delta_net(ctx0, q_conv, k_conv, v_conv, gate, beta, state, true, hparams.f_norm_rms_eps, il);
752+
}
603753
cb(attn_out, "attn_out", il);
604754

605755
// The tensors were concatenated 1d, so we need to extract them 1d as well
@@ -621,7 +771,9 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
621771
// Update the recurrent states
622772
ggml_build_forward_expand(gf,
623773
ggml_cpy(ctx0, state_1d, ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs,
624-
hparams.n_embd_s() * mctx_cur->get_head() * ggml_element_size(ssm_states_all))));
774+
kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all))));
775+
776+
GGML_ASSERT(ggml_nelements(attn_out_1d) + ggml_nelements(state_1d) == ggml_nelements(attn_out));
625777

626778
// Reshape both attn_out_final and z to 2D tensors for normalization
627779
// attn_out_final: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]

src/models/llm_build_qwen3next.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,19 @@ struct llm_build_qwen3next : public llm_graph_context_mamba {
2323
float eps_norm,
2424
const int il);
2525

26+
// delta_net_recurrent
27+
struct ggml_tensor * delta_net_recurrent(
28+
struct ggml_context * ctx,
29+
struct ggml_tensor * q,
30+
struct ggml_tensor * k,
31+
struct ggml_tensor * v,
32+
struct ggml_tensor * g,
33+
struct ggml_tensor * beta,
34+
struct ggml_tensor * state,
35+
bool use_qk_l2norm,
36+
float eps_norm,
37+
const int il);
38+
2639
ggml_tensor * build_qwen3next_attention_layer(ggml_tensor * cur,
2740
ggml_tensor * inp_pos,
2841
llm_graph_input_attn_kv * inp_attn,

tools/main/main.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,8 @@ static bool ggml_debug(struct ggml_tensor * t, bool ask, void * user_data) {
242242
if (!ggml_is_quantized(t->type)) {
243243
uint8_t * data = is_host ? (uint8_t *) t->data : cb_data->data.data();
244244
ggml_print_tensor(data, t->type, t->ne, t->nb, 3);
245-
if (std::string(t->name).substr(0, std::string("post_moe-").size()) == "post_moe-") {
245+
if (std::string(t->name).substr(0, std::string("post_moe-").size()) == "post_moe-" ||
246+
std::string(t->name).substr(0, std::string("state_1d-").size()) == "state_1d-") {
246247
if (cb_data->tensors.count(t->name) == 0) {
247248
cb_data->tensors[t->name] = 1;
248249
} else {
@@ -311,9 +312,9 @@ int main(int argc, char ** argv) {
311312
std::vector<common_chat_msg> chat_msgs;
312313

313314
// load the model and apply lora adapter, if any
314-
callback_data cb_data;
315-
params.cb_eval = ggml_debug;
316-
params.cb_eval_user_data = &cb_data;
315+
// callback_data cb_data;
316+
// params.cb_eval = ggml_debug;
317+
// params.cb_eval_user_data = &cb_data;
317318
LOG_INF("%s: load the model and apply lora adapter, if any\n", __func__);
318319
common_init_result llama_init = common_init_from_params(params);
319320

0 commit comments

Comments
 (0)