Skip to content

Commit 594c1f9

Browse files
committed
QKV splits done right
1 parent dbd4d97 commit 594c1f9

File tree

1 file changed: +20 additions, −22 deletions

src/models/llm_build_qwen3next.cpp

Lines changed: 20 additions & 22 deletions
Diff (unified format; original/new line-number columns removed):
@@ -305,11 +305,11 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
     cb(mixed_ba, "linear_attn_mixed_ba", il);

     int64_t qkvz_new_dim = 2 * head_k_dim + 2 * head_v_dim * num_v_heads / num_k_heads;
-    ggml_tensor * mixed_qkvz_reshaped = ggml_reshape_4d(ctx0, mixed_qkvz, qkvz_new_dim, num_k_heads, n_tokens, n_seqs);
+    ggml_tensor * mixed_qkvz_reshaped = ggml_cont_4d(ctx0, mixed_qkvz, qkvz_new_dim, num_k_heads, n_tokens, n_seqs);

     // Reshape mixed_ba: [batch, seq_len, hidden_size] -> [batch, seq_len, num_k_heads, 2*num_v_heads/num_k_heads]
     int64_t ba_new_dim = 2 * num_v_heads / num_k_heads;
-    ggml_tensor * mixed_ba_reshaped = ggml_reshape_4d(ctx0, mixed_ba, ba_new_dim, num_k_heads, n_tokens, n_seqs);
+    ggml_tensor * mixed_ba_reshaped = ggml_cont_4d(ctx0, mixed_ba, ba_new_dim, num_k_heads, n_tokens, n_seqs);

     // Split mixed_ba into b and a (beta and alpha parameters)
     int64_t split_sizes_ba[2] = {
@@ -358,27 +358,23 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
         head_v_dim * num_v_heads / num_k_heads  // z size
     };

-    ggml_tensor * query = ggml_cont(ctx0, ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[0], num_k_heads,
-                                                       n_tokens, n_seqs, split_sizes_qkvz[0] * sizeof(float),
-                                                       mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], 0));
+    ggml_tensor * query = ggml_cont(ctx0, ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[0], num_k_heads, n_tokens, n_seqs,
+                                                       mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3], 0));
     cb(query, "q", il);

-    ggml_tensor * key =
-        ggml_cont(ctx0, ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[1], num_k_heads, n_tokens, n_seqs,
-                                     split_sizes_qkvz[1] * sizeof(float), mixed_qkvz_reshaped->nb[1],
-                                     mixed_qkvz_reshaped->nb[2], split_sizes_qkvz[0] * sizeof(float)));
+    ggml_tensor * key = ggml_cont(ctx0, ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[1], num_k_heads, n_tokens, n_seqs,
+                                                     mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
+                                                     split_sizes_qkvz[0] * sizeof(float)));
     cb(key, "k", il);

-    ggml_tensor * value =
-        ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[2], num_k_heads, n_tokens, n_seqs,
-                     split_sizes_qkvz[2] * sizeof(float), mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2],
-                     (split_sizes_qkvz[0] + split_sizes_qkvz[1]) * sizeof(float));
+    ggml_tensor * value = ggml_cont(ctx0, ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[2], num_k_heads, n_tokens, n_seqs,
+                                                       mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
+                                                       (split_sizes_qkvz[0] + split_sizes_qkvz[1]) * sizeof(float)));
     cb(value, "v", il);

-    ggml_tensor * z =
-        ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[3], num_k_heads, n_tokens, n_seqs,
-                     split_sizes_qkvz[3] * sizeof(float), mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2],
-                     (split_sizes_qkvz[0] + split_sizes_qkvz[1] + split_sizes_qkvz[2]) * sizeof(float));
+    ggml_tensor * z = ggml_cont(ctx0, ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[3], num_k_heads, n_tokens, n_seqs,
+                                                   mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
+                                                   (split_sizes_qkvz[0] + split_sizes_qkvz[1] + split_sizes_qkvz[2]) * sizeof(float)));
     cb(z, "z", il);

     // Reshape value and z to merge head dimensions: [batch, seq_len, num_k_heads, head_v_dim*num_v_heads/num_k_heads] -> [batch, seq_len, num_v_heads, head_v_dim]
@@ -456,15 +452,17 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i

     // Extract the convolved Q, K, V from conv_output
     ggml_tensor * q_conv = ggml_cont(ctx0, ggml_view_4d(ctx0, conv_output, head_k_dim, num_k_heads, n_tokens, n_seqs,
-                                                        head_k_dim, conv_output->nb[1], conv_output->nb[2], 0));
+                                                        conv_output->nb[1], conv_output->nb[2], conv_output->nb[3], 0));
     cb(q_conv, "q_conv", il);
     ggml_tensor * k_conv = ggml_cont(
-        ctx0, ggml_view_4d(ctx0, conv_output, head_k_dim, num_k_heads, n_tokens, n_seqs, head_k_dim, conv_output->nb[1],
-                           conv_output->nb[2], head_k_dim * num_k_heads * ggml_element_size(conv_output)));
+        ctx0, ggml_view_4d(ctx0, conv_output, head_k_dim, num_k_heads, n_tokens, n_seqs,
+                           conv_output->nb[1], conv_output->nb[2], conv_output->nb[3],
+                           head_k_dim * num_k_heads * ggml_element_size(conv_output)));
     cb(q_conv, "k_conv", il);
     ggml_tensor * v_conv = ggml_cont(
-        ctx0, ggml_view_4d(ctx0, conv_output, head_v_dim, num_v_heads, n_tokens, n_seqs, head_v_dim, conv_output->nb[1],
-                           conv_output->nb[2], 2 * head_k_dim * num_k_heads * ggml_element_size(conv_output)));
+        ctx0, ggml_view_4d(ctx0, conv_output, head_v_dim, num_v_heads, n_tokens, n_seqs,
+                           conv_output->nb[1], conv_output->nb[2], conv_output->nb[3],
+                           2 * head_k_dim * num_k_heads * ggml_element_size(conv_output)));
     cb(q_conv, "v_conv", il);

     ggml_build_forward_expand(gf, ssm_states_all);

Comments (0)