
Commit 17240ea

Order stuff around
1 parent 1579bcb commit 17240ea

File tree: 1 file changed (+33, -39 lines)

src/models/llm_build_qwen3next.cpp

Lines changed: 33 additions & 39 deletions
@@ -386,14 +386,13 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
     const auto * mctx_cur = inp->mctx;
 
     const int64_t d_inner = hparams.ssm_d_inner;
-    const int64_t n_heads = hparams.ssm_dt_rank;
-    const int64_t head_dim = d_inner / n_heads;
+
     const int64_t n_seqs = ubatch.n_seqs;
 
     const int64_t head_k_dim = hparams.ssm_d_state;
-    const int64_t head_v_dim = hparams.ssm_d_state;
     const int64_t num_k_heads = hparams.ssm_n_group;
     const int64_t num_v_heads = hparams.ssm_dt_rank;
+    const int64_t head_v_dim = d_inner / num_v_heads;
 
     const int64_t n_seq_tokens = ubatch.n_seq_tokens;
 
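Note on this hunk: head_v_dim is now derived from the inner dimension and the value-head count instead of being aliased to hparams.ssm_d_state, and the unused n_heads/head_dim pair is dropped. Since n_heads was also hparams.ssm_dt_rank, the old head_dim and the new head_v_dim are numerically identical, so the later hunks are renames onto the surviving names. A minimal sketch of the relationship, with illustrative values only (the real sizes come from the model's hparams):

    // Illustrative values, not the actual Qwen3-Next configuration.
    const int64_t d_inner     = 4096;                  // hparams.ssm_d_inner (assumed)
    const int64_t num_v_heads = 32;                    // hparams.ssm_dt_rank (assumed)
    const int64_t head_v_dim  = d_inner / num_v_heads; // = 128, no longer tied to ssm_d_state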
@@ -408,7 +407,7 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
     ggml_tensor * mixed_ba = build_lora_mm(model.layers[il].ssm_beta_alpha, cur);
     cb(mixed_ba, "linear_attn_mixed_ba", il);
 
-    int64_t qkvz_new_dim = 2 * head_k_dim + 2 * head_v_dim * num_v_heads / num_k_heads;
+    int64_t qkvz_new_dim = 2 * head_k_dim + 2 * head_v_dim * (num_v_heads / num_k_heads);
     ggml_tensor * mixed_qkvz_reshaped = ggml_cont_4d(ctx0, mixed_qkvz, qkvz_new_dim, num_k_heads, n_seq_tokens, n_seqs);
 
     // Reshape mixed_ba: [batch, seq_len, hidden_size] -> [batch, seq_len, num_k_heads, 2*num_v_heads/num_k_heads]
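The added parentheses make the grouped-head ratio explicit. Under integer arithmetic the two expressions agree only when num_k_heads divides num_v_heads, which the grouped layout guarantees; a self-contained check with assumed head counts:

    #include <cassert>
    #include <cstdint>

    int main() {
        // Assumed illustrative counts; the real values come from hparams.
        const int64_t head_k_dim = 128, head_v_dim = 128;
        const int64_t num_v_heads = 32, num_k_heads = 16;
        const int64_t old_dim = 2 * head_k_dim + 2 * head_v_dim * num_v_heads / num_k_heads;
        const int64_t new_dim = 2 * head_k_dim + 2 * head_v_dim * (num_v_heads / num_k_heads);
        assert(old_dim == new_dim && new_dim == 768); // identical when the division is exact
        return 0;
    }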
@@ -441,63 +440,58 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
     ggml_tensor * gate = ggml_mul(ctx0, alpha_softplus, model.layers[il].ssm_a); // -A_log.exp() * softplus
     cb(gate, "gate", il);
 
-    // Get convolution states from cache
-    ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
-    ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il);
-
-    // Build the convolution states tensor
-    ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs);
-    cb(conv_states, "conv_states", il);
-
-    // Split mixed_qkvz into query, key, value, z
+    // Split mixed_qkvz into query, key, value, z
     int64_t split_sizes_qkvz[4] = {
         head_k_dim,                             // query size
         head_k_dim,                             // key size
         head_v_dim * num_v_heads / num_k_heads, // value size
         head_v_dim * num_v_heads / num_k_heads  // z size
     };
 
-    ggml_tensor * query = ggml_cont(ctx0, ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[0], num_k_heads, n_seq_tokens, n_seqs,
-            mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3], 0));
+    ggml_tensor * query = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[0], num_k_heads, n_seq_tokens, n_seqs,
+            mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3], 0);
     cb(query, "q", il);
 
-    ggml_tensor * key = ggml_cont(ctx0, ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[1], num_k_heads, n_seq_tokens, n_seqs,
+    ggml_tensor * key = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[1], num_k_heads, n_seq_tokens, n_seqs,
            mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
-            split_sizes_qkvz[0] * sizeof(float)));
+            split_sizes_qkvz[0] * sizeof(float));
     cb(key, "k", il);
 
-    ggml_tensor * value = ggml_cont(ctx0, ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[2], num_k_heads, n_seq_tokens, n_seqs,
+    ggml_tensor * value = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[2], num_k_heads, n_seq_tokens, n_seqs,
            mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
-            (split_sizes_qkvz[0] + split_sizes_qkvz[1]) * sizeof(float)));
+            (split_sizes_qkvz[0] + split_sizes_qkvz[1]) * sizeof(float));
     cb(value, "v", il);
 
-    ggml_tensor * z = ggml_cont(ctx0, ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[3], num_k_heads, n_seq_tokens, n_seqs,
+    ggml_tensor * z = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[3], num_k_heads, n_seq_tokens, n_seqs,
            mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
-            (split_sizes_qkvz[0] + split_sizes_qkvz[1] + split_sizes_qkvz[2]) * sizeof(float)));
+            (split_sizes_qkvz[0] + split_sizes_qkvz[1] + split_sizes_qkvz[2]) * sizeof(float));
     cb(z, "z", il);
 
-    // Reshape value and z to merge head dimensions: [batch, seq_len, num_k_heads, head_v_dim*num_v_heads/num_k_heads] -> [batch, seq_len, num_v_heads, head_v_dim]
-    ggml_tensor * value_reshaped =
-        ggml_reshape_4d(ctx0, ggml_cont(ctx0, value), head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
-    ggml_tensor * z_reshaped = ggml_reshape_4d(ctx0, ggml_cont(ctx0, z), head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
-
-    GGML_ASSERT(ggml_nelements(query) + ggml_nelements(key) + ggml_nelements(value_reshaped) +
-                    ggml_nelements(z_reshaped) ==
+    GGML_ASSERT(ggml_nelements(query) + ggml_nelements(key) + ggml_nelements(value) +
+                    ggml_nelements(z) ==
                 ggml_nelements(mixed_qkvz));
 
     // After creating query, key, and value_reshaped, reshape each to flatten the head dimensions
     // query: [head_k_dim, num_k_heads, n_tokens, n_seqs] -> [head_k_dim * num_k_heads, n_tokens, n_seqs]
-    ggml_tensor * query_flat = ggml_reshape_3d(ctx0, query, head_k_dim * num_k_heads, n_seq_tokens, n_seqs);
+    ggml_tensor * query_flat = ggml_cont_3d(ctx0, query, head_k_dim * num_k_heads, n_seq_tokens, n_seqs);
     cb(query_flat, "query_flat", il);
 
     // key: [head_k_dim, num_k_heads, n_tokens, n_seqs] -> [head_k_dim * num_k_heads, n_tokens, n_seqs]
-    ggml_tensor * key_flat = ggml_reshape_3d(ctx0, key, head_k_dim * num_k_heads, n_seq_tokens, n_seqs);
+    ggml_tensor * key_flat = ggml_cont_3d(ctx0, key, head_k_dim * num_k_heads, n_seq_tokens, n_seqs);
     cb(key_flat, "key_flat", il);
 
     // value_reshaped: [head_v_dim, num_v_heads, n_tokens, n_seqs] -> [head_v_dim * num_v_heads, n_tokens, n_seqs]
-    ggml_tensor * value_flat = ggml_reshape_3d(ctx0, value_reshaped, head_v_dim * num_v_heads, n_seq_tokens, n_seqs);
+    ggml_tensor * value_flat = ggml_cont_3d(ctx0, value, head_v_dim * num_v_heads, n_seq_tokens, n_seqs);
     cb(value_flat, "value_flat", il);
 
+    // Get convolution states from cache
+    ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
+    ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il);
+
+    // Build the convolution states tensor
+    ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs);
+    cb(conv_states, "conv_states", il);
+
     // Now concatenate along the feature dimension (dim 0) to get [conv_dim, n_tokens, n_seqs]
     ggml_tensor * qkv_mixed = ggml_concat(ctx0, query_flat, key_flat, 0);
     qkv_mixed = ggml_concat(ctx0, qkv_mixed, value_flat, 0);
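Two things happen in this hunk. First, the recurrent-cache fetch (get_r_l/get_s_l plus build_rs) moves below the q/k/v flattening so it sits next to its first use, which appears to be the reordering the commit message describes. Second, the splits switch from eager copies to lazy views: ggml_view_4d returns a zero-copy strided view into mixed_qkvz_reshaped, and ggml_cont_3d later materializes and flattens it in a single graph node, replacing the old view -> ggml_cont -> ggml_reshape_3d chain and the extra value_reshaped/z_reshaped copies. A before/after sketch for the query slice, reusing the diff's names:

    // Before: materialize the strided view, then reshape the copy.
    ggml_tensor * q_old = ggml_cont(ctx0, ggml_view_4d(ctx0, mixed_qkvz_reshaped,
            head_k_dim, num_k_heads, n_seq_tokens, n_seqs,
            mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3], 0));
    ggml_tensor * q_old_flat = ggml_reshape_3d(ctx0, q_old, head_k_dim * num_k_heads, n_seq_tokens, n_seqs);

    // After: keep the view lazy; ggml_cont_3d performs the single copy and the reshape together.
    ggml_tensor * q_view = ggml_view_4d(ctx0, mixed_qkvz_reshaped,
            head_k_dim, num_k_heads, n_seq_tokens, n_seqs,
            mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3], 0);
    ggml_tensor * q_flat = ggml_cont_3d(ctx0, q_view, head_k_dim * num_k_heads, n_seq_tokens, n_seqs);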
@@ -578,7 +572,7 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
     beta = ggml_cont_4d(ctx0, b, num_v_heads, 1, n_seq_tokens, n_seqs);
 
     ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs);
-    state = ggml_reshape_4d(ctx0, state, head_dim, head_dim * n_heads, 1, n_seqs);
+    state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim * num_v_heads, 1, n_seqs);
 
     // if head keys and value keys are different, repeat to force tensors into matching shapes
     if (num_k_heads != num_v_heads) {
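This line is a pure rename: head_dim and n_heads held the same values as head_v_dim and num_v_heads (both derive from d_inner and hparams.ssm_dt_rank), so the state layout is unchanged. Each value head keeps a square head_v_dim x head_v_dim recurrent state per sequence, which is the invariant the reshape relies on; a sketch, derived from the reshape's shape rather than stated anywhere in the diff:

    // Flat per-sequence state size implied by the reshape above, viewed as
    //   [head_v_dim, head_v_dim * num_v_heads, 1, n_seqs]
    GGML_ASSERT(hparams.n_embd_s() == head_v_dim * head_v_dim * num_v_heads);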
@@ -598,17 +592,17 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
     cb(attn_out, "attn_out", il);
 
     // The tensors were concatenated 1d, so we need to extract them 1d as well
-    const int64_t output_flat_size = head_dim * n_heads * n_seq_tokens * n_seqs;
+    const int64_t output_flat_size = head_v_dim * num_v_heads * n_seq_tokens * n_seqs;
     ggml_tensor * attn_out_1d =
         ggml_view_1d(ctx0, attn_out, output_flat_size, 0);
     cb(attn_out_1d, "attn_out_1d", il);
 
-    ggml_tensor * attn_out_final = ggml_cont(ctx0, ggml_permute(ctx0, ggml_cont_4d(ctx0, attn_out_1d, head_dim, n_seq_tokens, n_heads, n_seqs), 0, 2, 1, 3));
+    ggml_tensor * attn_out_final = ggml_cont(ctx0, ggml_permute(ctx0, ggml_cont_4d(ctx0, attn_out_1d, head_v_dim, n_seq_tokens, num_v_heads, n_seqs), 0, 2, 1, 3));
     cb(attn_out_final, "attn_out_final", il);
 
     // Extract the state part (second part of the concatenated tensor)
     // State starts after n_tokens elements along dimension 1
-    const int64_t state_flat_size = head_dim * head_dim * n_heads * n_seqs;
+    const int64_t state_flat_size = head_v_dim * head_v_dim * num_v_heads * n_seqs;
 
     ggml_tensor * state_1d = ggml_view_1d(ctx0, attn_out, state_flat_size, output_flat_size * ggml_element_size(attn_out));
     cb(state_1d, "state_1d", il);
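For context: attn_out is a single flat buffer holding the attention output followed by the updated recurrent state, so both parts are recovered with 1-D views. The view lengths are in elements while the view offset is in bytes, hence the ggml_element_size scaling; a condensed sketch of the pattern used above:

    const int64_t out_elems   = head_v_dim * num_v_heads * n_seq_tokens * n_seqs;
    const int64_t state_elems = head_v_dim * head_v_dim * num_v_heads * n_seqs;

    // First out_elems elements are the attention output; the state follows.
    ggml_tensor * out_part   = ggml_view_1d(ctx0, attn_out, out_elems, 0);
    ggml_tensor * state_part = ggml_view_1d(ctx0, attn_out, state_elems,
                                            out_elems * ggml_element_size(attn_out));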
@@ -620,19 +614,19 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
 
     // Reshape both attn_out_final and z to 2D tensors for normalization
     // attn_out_final: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]
-    ggml_tensor * attn_out_2d_final = ggml_reshape_2d(ctx0, ggml_cont(ctx0, attn_out_final), head_dim, n_heads * n_seq_tokens * n_seqs);
+    ggml_tensor * attn_out_2d_final = ggml_reshape_2d(ctx0, ggml_cont(ctx0, attn_out_final), head_v_dim, num_v_heads * n_seq_tokens * n_seqs);
 
     // z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]
-    ggml_tensor * z_2d = ggml_reshape_2d(ctx0, z_reshaped, head_dim, n_heads * n_seq_tokens * n_seqs);
+    ggml_tensor * z_2d = ggml_cont_2d(ctx0, z, head_v_dim, num_v_heads * n_seq_tokens * n_seqs);
 
     // Apply gated normalization: self.norm(core_attn_out, z)
     ggml_tensor * attn_out_norm = build_q3n_gated_norm(attn_out_2d_final, model.layers[il].ssm_norm, z_2d, il);
 
     // Reshape back to original dimensions: [n_heads * n_tokens * n_seqs, head_dim] -> [head_dim, n_heads, n_tokens, n_seqs]
-    ggml_tensor * gated_output_4d = ggml_reshape_4d(ctx0, attn_out_norm, head_dim, n_heads, n_seq_tokens, n_seqs);
+    ggml_tensor * gated_output_4d = ggml_reshape_4d(ctx0, attn_out_norm, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
 
     // Final reshape: [head_dim, n_heads, n_tokens, n_seqs] -> [n_tokens, n_seqs, n_heads * head_dim]
-    ggml_tensor * final_output = ggml_reshape_3d(ctx0, gated_output_4d, n_heads * head_dim, n_seq_tokens, n_seqs);
+    ggml_tensor * final_output = ggml_reshape_3d(ctx0, gated_output_4d, head_v_dim * num_v_heads, n_seq_tokens, n_seqs);
     cb(final_output, "final_output", il);
 
     // Output projection
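Same idea as the q/k/v split: z is still the lazy strided view created in the earlier hunk, so ggml_cont_2d materializes it directly into the 2-D shape the gated norm expects, replacing the removed z_reshaped intermediate:

    // One node instead of cont + reshape_4d + reshape_2d:
    ggml_tensor * z_2d = ggml_cont_2d(ctx0, z, head_v_dim, num_v_heads * n_seq_tokens * n_seqs);

Net effect of the commit: the same computation (under the model's head-dimension constraints) with fewer ggml_cont materializations, and head_v_dim/num_v_heads as the single source of truth for the value-head geometry.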
