Skip to content

Commit 20424d8

Browse files
committed
argh
1 parent 4136521 commit 20424d8

File tree

1 file changed

+10
-22
lines changed

1 file changed

+10
-22
lines changed

src/models/llm_build_qwen3next.cpp

Lines changed: 10 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -279,24 +279,21 @@ struct ggml_tensor * llm_build_qwen3next::delta_net(
279279
cb(q, "q_postscale", il);
280280
cb(beta, "beta_sigmoid", il);
281281

282-
// First, permute to chunked format: [S_k, n_tokens, H_k, n_seqs]
282+
// Pad first along the token dimension
283+
q = ggml_pad(ctx, q, 0, 0, pad_size, 0);
284+
k = ggml_pad(ctx, k, 0, 0, pad_size, 0);
285+
v = ggml_pad(ctx, v, 0, 0, pad_size, 0);
286+
283287
q = ggml_cont(ctx, ggml_permute(ctx, q, 0, 2, 1, 3));
284-
cb(q, "q_reshape", il);
285288
k = ggml_cont(ctx, ggml_permute(ctx, k, 0, 2, 1, 3));
286-
cb(k, "k_reshape", il);
287289
v = ggml_cont(ctx, ggml_permute(ctx, v, 0, 2, 1, 3));
288-
cb(v, "v_reshape", il);
289290

290291
beta = ggml_cont(ctx, ggml_permute(ctx, beta, 1, 2, 0, 3));
291292
cb(beta, "beta_reshape", il);
292293

293294
g = ggml_cont(ctx, ggml_permute(ctx, g, 2, 0, 3, 1));
294295
cb(g, "g_permute", il);
295-
296-
// Then, pad the second dimension (n_tokens) to chunk_size
297-
q = ggml_pad(ctx, q, 0, pad_size, 0, 0);
298-
k = ggml_pad(ctx, k, 0, pad_size, 0, 0);
299-
v = ggml_pad(ctx, v, 0, pad_size, 0, 0);
296+
300297
// ... except for beta and g, where we pad the last dimension
301298
beta = ggml_pad(ctx, beta, pad_size, 0, 0, 0);
302299
g = ggml_pad(ctx, g, pad_size, 0, 0, 0);
@@ -704,23 +701,14 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
704701
GGML_ASSERT(num_v_heads % num_k_heads == 0);
705702
int64_t repeat_factor = num_v_heads / num_k_heads;
706703

707-
// GGML tensor layout: [head_dim, num_heads, n_seq_tokens, n_seqs]
708-
709-
// Step 1: Flatten the sequence and batch dimensions to work with them more easily
710-
ggml_tensor * q_flat = ggml_reshape_2d(ctx0, q_conv, head_k_dim, num_k_heads * n_seq_tokens * n_seqs);
711-
ggml_tensor * k_flat = ggml_reshape_2d(ctx0, k_conv, head_k_dim, num_k_heads * n_seq_tokens * n_seqs);
712-
713-
// Step 2: Reshape to prepare for repeat_interleave
714-
// From [head_dim, num_k_heads * n_seq_tokens * n_seqs]
715-
// To [head_dim, num_k_heads, 1, n_seq_tokens * n_seqs]
716-
ggml_tensor * q_reshaped = ggml_reshape_4d(ctx0, q_flat, head_k_dim, num_k_heads, 1, n_seq_tokens * n_seqs);
717-
ggml_tensor * k_reshaped = ggml_reshape_4d(ctx0, k_flat, head_k_dim, num_k_heads, 1, n_seq_tokens * n_seqs);
704+
ggml_tensor * q_reshaped = ggml_reshape_4d(ctx0, q_conv, head_k_dim, num_k_heads, 1, n_seq_tokens * n_seqs);
705+
ggml_tensor * k_reshaped = ggml_reshape_4d(ctx0, k_conv, head_k_dim, num_k_heads, 1, n_seq_tokens * n_seqs);
718706

719-
// Step 3: Repeat along the third dimension (the new dimension with size 1)
707+
// Repeat along the third dimension (the new dimension with size 1)
720708
ggml_tensor * q_repeated = ggml_repeat_4d(ctx0, q_reshaped, head_k_dim, num_k_heads, repeat_factor, n_seq_tokens * n_seqs);
721709
ggml_tensor * k_repeated = ggml_repeat_4d(ctx0, k_reshaped, head_k_dim, num_k_heads, repeat_factor, n_seq_tokens * n_seqs);
722710

723-
// Step 4: Reshape back to merge the head and repeat dimensions
711+
// Reshape back to merge the head and repeat dimensions
724712
// From [head_dim, num_k_heads, repeat_factor, n_seq_tokens * n_seqs]
725713
// Back to [head_dim, num_k_heads * repeat_factor, n_seq_tokens, n_seqs]
726714
q_conv = ggml_reshape_4d(ctx0, q_repeated, head_k_dim, num_k_heads * repeat_factor, n_seq_tokens, n_seqs);

0 commit comments

Comments
 (0)