
Commit dbd4d97

Fix cb calls
1 parent 32dcee4 commit dbd4d97

File tree

1 file changed, +42 -42 lines changed

src/models/llm_build_qwen3next.cpp

Lines changed: 42 additions & 42 deletions
@@ -311,46 +311,6 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
     int64_t ba_new_dim = 2 * num_v_heads / num_k_heads;
     ggml_tensor * mixed_ba_reshaped = ggml_reshape_4d(ctx0, mixed_ba, ba_new_dim, num_k_heads, n_tokens, n_seqs);
 
-    // Split mixed_qkvz into query, key, value, z
-    int64_t split_sizes_qkvz[4] = {
-        head_k_dim,                              // query size
-        head_k_dim,                              // key size
-        head_v_dim * num_v_heads / num_k_heads,  // value size
-        head_v_dim * num_v_heads / num_k_heads   // z size
-    };
-
-    ggml_tensor * query = ggml_cont(ctx0, ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[0], num_k_heads,
-                                                       n_tokens, n_seqs, split_sizes_qkvz[0] * sizeof(float),
-                                                       mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], 0));
-    cb(query, "q", il);
-
-    ggml_tensor * key =
-        ggml_cont(ctx0, ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[1], num_k_heads, n_tokens, n_seqs,
-                                     split_sizes_qkvz[1] * sizeof(float), mixed_qkvz_reshaped->nb[1],
-                                     mixed_qkvz_reshaped->nb[2], split_sizes_qkvz[0] * sizeof(float)));
-    cb(query, "k", il);
-
-    ggml_tensor * value =
-        ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[2], num_k_heads, n_tokens, n_seqs,
-                     split_sizes_qkvz[2] * sizeof(float), mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2],
-                     (split_sizes_qkvz[0] + split_sizes_qkvz[1]) * sizeof(float));
-    cb(query, "v", il);
-
-    ggml_tensor * z =
-        ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[3], num_k_heads, n_tokens, n_seqs,
-                     split_sizes_qkvz[3] * sizeof(float), mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2],
-                     (split_sizes_qkvz[0] + split_sizes_qkvz[1] + split_sizes_qkvz[2]) * sizeof(float));
-    cb(query, "z", il);
-
-    // Reshape value and z to merge head dimensions: [batch, seq_len, num_k_heads, head_v_dim*num_v_heads/num_k_heads] -> [batch, seq_len, num_v_heads, head_v_dim]
-    ggml_tensor * value_reshaped =
-        ggml_reshape_4d(ctx0, ggml_cont(ctx0, value), head_v_dim, num_v_heads, n_tokens, n_seqs);
-    ggml_tensor * z_reshaped = ggml_reshape_4d(ctx0, ggml_cont(ctx0, z), head_v_dim, num_v_heads, n_tokens, n_seqs);
-
-    GGML_ASSERT(ggml_nelements(query) + ggml_nelements(key) + ggml_nelements(value_reshaped) +
-                    ggml_nelements(z_reshaped) ==
-                ggml_nelements(mixed_qkvz));
-
     // Split mixed_ba into b and a (beta and alpha parameters)
     int64_t split_sizes_ba[2] = {
         num_v_heads / num_k_heads,  // beta size
@@ -360,12 +320,12 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
     ggml_tensor * b =
         ggml_view_4d(ctx0, mixed_ba_reshaped, split_sizes_ba[0], num_k_heads, n_tokens, n_seqs,
                      split_sizes_ba[0] * sizeof(float), mixed_ba_reshaped->nb[1], mixed_ba_reshaped->nb[2], 0);
-    cb(query, "b", il);
+    cb(b, "b", il);
 
     ggml_tensor * a = ggml_view_4d(ctx0, mixed_ba_reshaped, split_sizes_ba[1], num_k_heads, n_tokens, n_seqs,
                                    split_sizes_ba[1] * sizeof(float), mixed_ba_reshaped->nb[1],
                                    mixed_ba_reshaped->nb[2], split_sizes_ba[0] * sizeof(float));
-    cb(query, "a", il);
+    cb(a, "a", il);
 
     // Reshape b and a to merge head dimensions: [batch, seq_len, num_k_heads, num_v_heads/num_k_heads] -> [batch, seq_len, num_v_heads]
     ggml_tensor * beta = ggml_reshape_3d(ctx0, ggml_cont(ctx0, b), num_v_heads, n_tokens, n_seqs);
@@ -390,6 +350,46 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
     ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs);
     cb(conv_states, "conv_states", il);
 
+    // Split mixed_qkvz into query, key, value, z
+    int64_t split_sizes_qkvz[4] = {
+        head_k_dim,                              // query size
+        head_k_dim,                              // key size
+        head_v_dim * num_v_heads / num_k_heads,  // value size
+        head_v_dim * num_v_heads / num_k_heads   // z size
+    };
+
+    ggml_tensor * query = ggml_cont(ctx0, ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[0], num_k_heads,
+                                                       n_tokens, n_seqs, split_sizes_qkvz[0] * sizeof(float),
+                                                       mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], 0));
+    cb(query, "q", il);
+
+    ggml_tensor * key =
+        ggml_cont(ctx0, ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[1], num_k_heads, n_tokens, n_seqs,
+                                     split_sizes_qkvz[1] * sizeof(float), mixed_qkvz_reshaped->nb[1],
+                                     mixed_qkvz_reshaped->nb[2], split_sizes_qkvz[0] * sizeof(float)));
+    cb(key, "k", il);
+
+    ggml_tensor * value =
+        ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[2], num_k_heads, n_tokens, n_seqs,
+                     split_sizes_qkvz[2] * sizeof(float), mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2],
+                     (split_sizes_qkvz[0] + split_sizes_qkvz[1]) * sizeof(float));
+    cb(value, "v", il);
+
+    ggml_tensor * z =
+        ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[3], num_k_heads, n_tokens, n_seqs,
+                     split_sizes_qkvz[3] * sizeof(float), mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2],
+                     (split_sizes_qkvz[0] + split_sizes_qkvz[1] + split_sizes_qkvz[2]) * sizeof(float));
+    cb(z, "z", il);
+
+    // Reshape value and z to merge head dimensions: [batch, seq_len, num_k_heads, head_v_dim*num_v_heads/num_k_heads] -> [batch, seq_len, num_v_heads, head_v_dim]
+    ggml_tensor * value_reshaped =
+        ggml_reshape_4d(ctx0, ggml_cont(ctx0, value), head_v_dim, num_v_heads, n_tokens, n_seqs);
+    ggml_tensor * z_reshaped = ggml_reshape_4d(ctx0, ggml_cont(ctx0, z), head_v_dim, num_v_heads, n_tokens, n_seqs);
+
+    GGML_ASSERT(ggml_nelements(query) + ggml_nelements(key) + ggml_nelements(value_reshaped) +
+                    ggml_nelements(z_reshaped) ==
+                ggml_nelements(mixed_qkvz));
+
     // After creating query, key, and value_reshaped, reshape each to flatten the head dimensions
     // query: [head_k_dim, num_k_heads, n_tokens, n_seqs] -> [head_k_dim * num_k_heads, n_tokens, n_seqs]
     ggml_tensor * query_flat = ggml_reshape_3d(ctx0, query, head_k_dim * num_k_heads, n_tokens, n_seqs);
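For reference, the fix is purely about which tensor is handed to the graph callback: each view produced by a split should be tagged under its own name ("q", "k", "v", "z", "b", "a") instead of reusing query for every call. The snippet below is a small standalone sketch, not code from this commit: it imitates the qkvz split with plain ggml calls, uses ggml_set_name() where the builder uses cb(), picks toy dimensions, and takes the view strides directly from the parent tensor; all of those choices are assumptions made only for illustration.

// Hypothetical standalone sketch (not from this commit): split a fused
// "qkvz" tensor into four views and tag each view with its own name,
// which is the pattern the cb() fix restores in the builder code.
#include "ggml.h"

#include <cassert>
#include <cstdint>
#include <cstdio>

int main() {
    // Toy dimensions standing in for head_k_dim, head_v_dim, num_k_heads, num_v_heads.
    const int64_t head_k_dim  = 4;
    const int64_t head_v_dim  = 4;
    const int64_t num_k_heads = 2;
    const int64_t num_v_heads = 4;
    const int64_t n_tokens    = 3;
    const int64_t n_seqs      = 1;

    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ nullptr,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    // Stand-in for mixed_qkvz_reshaped: [qkvz_dim, num_k_heads, n_tokens, n_seqs].
    const int64_t qkvz_dim = 2*head_k_dim + 2*head_v_dim*num_v_heads/num_k_heads;
    struct ggml_tensor * mixed = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, qkvz_dim, num_k_heads, n_tokens, n_seqs);

    const int64_t split_sizes[4] = {
        head_k_dim,                              // query
        head_k_dim,                              // key
        head_v_dim * num_v_heads / num_k_heads,  // value
        head_v_dim * num_v_heads / num_k_heads,  // z
    };
    const char * names[4] = { "q", "k", "v", "z" };

    struct ggml_tensor * split[4];
    int64_t offset = 0;  // element offset of each slice along dim 0
    for (int i = 0; i < 4; ++i) {
        // Each view keeps the parent's strides and starts at the slice offset.
        split[i] = ggml_view_4d(ctx, mixed, split_sizes[i], num_k_heads, n_tokens, n_seqs,
                                mixed->nb[1], mixed->nb[2], mixed->nb[3],
                                offset * ggml_element_size(mixed));
        // The point of the fix: name the view itself (the buggy code named `query` four times).
        ggml_set_name(split[i], names[i]);
        offset += split_sizes[i];
    }

    // Same invariant as the GGML_ASSERT in the diff: the four views cover the fused tensor.
    assert(ggml_nelements(split[0]) + ggml_nelements(split[1]) +
           ggml_nelements(split[2]) + ggml_nelements(split[3]) == ggml_nelements(mixed));

    printf("views: %s %s %s %s\n", split[0]->name, split[1]->name, split[2]->name, split[3]->name);

    ggml_free(ctx);
    return 0;
}

Built against ggml, the assert mirrors the GGML_ASSERT in the diff: the four named views together account for every element of the fused projection, and each one shows up under its own name when the compute graph is inspected.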
