Skip to content

Commit a2c7b67

Browse files
committed
Proper handling for n_tokens > GGML_DELTA_NET_CHUNK
1 parent c1e46f6 commit a2c7b67

File tree

1 file changed

+10
-6
lines changed

1 file changed

+10
-6
lines changed

src/models/llm_build_qwen3next.cpp

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -252,24 +252,28 @@ struct ggml_tensor * llm_build_qwen3next::delta_net(
252252
GGML_ASSERT(beta->ne[0] % GGML_DELTA_NET_CHUNK == 0 && beta->ne[1] == H_k && beta->ne[2] == 1 && beta->ne[3] == n_seqs);
253253
GGML_ASSERT(g->ne[0] % GGML_DELTA_NET_CHUNK == 0 && g->ne[1] == H_k && g->ne[2] == 1 && g->ne[3] == n_seqs);
254254

255-
ggml_tensor * beta_unsq = ggml_cont_4d(ctx, beta, 1, num_chunks * GGML_DELTA_NET_CHUNK, H_k, n_seqs);
256-
ggml_tensor * beta_bcast = ggml_repeat_4d(ctx, beta_unsq, S_v, num_chunks * GGML_DELTA_NET_CHUNK, H_k, n_seqs);
255+
ggml_tensor * beta_unsq = ggml_cont_4d(ctx, beta, 1, GGML_DELTA_NET_CHUNK * num_chunks, H_k, n_seqs);
256+
ggml_tensor * beta_bcast = ggml_repeat_4d(ctx, beta_unsq, S_v, GGML_DELTA_NET_CHUNK * num_chunks, H_k, n_seqs);
257257
cb(beta_unsq, "beta_unsq", il);
258258
cb(beta_bcast, "beta_bcast", il);
259259

260260
struct ggml_tensor * v_beta = ggml_mul(ctx, v, beta_bcast);
261+
v_beta = ggml_reshape_4d(ctx, v_beta, S_v, GGML_DELTA_NET_CHUNK, H_k * num_chunks, n_seqs);
261262
cb(v_beta, "v_beta", il);
262263
struct ggml_tensor * k_beta = ggml_mul(ctx, k, beta_bcast);
264+
k_beta = ggml_reshape_4d(ctx, k_beta, S_v, GGML_DELTA_NET_CHUNK, H_k * num_chunks, n_seqs);
263265
cb(k_beta, "k_beta", il);
266+
k = ggml_reshape_4d(ctx, k, S_v, GGML_DELTA_NET_CHUNK, H_k * num_chunks, n_seqs);
267+
cb(k_beta, "k_reshape", il);
264268
struct ggml_tensor * g_cumsum = ggml_cumsum(ctx, g);
265269
cb(g_cumsum, "g_cumsum", il);
266270

267-
struct ggml_tensor * gcs_i = ggml_cont_4d(ctx, g_cumsum, num_chunks * GGML_DELTA_NET_CHUNK, 1, H_v, n_seqs); // [chunk_size, 1, n_tokens, n_seqs]
268-
struct ggml_tensor * gcs_j = ggml_cont_4d(ctx, g_cumsum, 1, num_chunks * GGML_DELTA_NET_CHUNK, H_v, n_seqs); // [1, chunk_size, n_tokens, n_seqs]
271+
struct ggml_tensor * gcs_i = ggml_cont_4d(ctx, g_cumsum, GGML_DELTA_NET_CHUNK, 1, num_chunks * H_v, n_seqs); // [chunk_size, 1, n_tokens, n_seqs]
272+
struct ggml_tensor * gcs_j = ggml_cont_4d(ctx, g_cumsum, 1, GGML_DELTA_NET_CHUNK, num_chunks * H_v, n_seqs); // [1, chunk_size, n_tokens, n_seqs]
269273

270274
// Broadcast both tensors to [chunk_size, chunk_size, H_v, n_seqs]
271-
struct ggml_tensor * gcs_i_broadcast = ggml_repeat_4d(ctx, gcs_i, num_chunks * GGML_DELTA_NET_CHUNK, num_chunks * GGML_DELTA_NET_CHUNK, H_v, n_seqs); // [chunk_size, 1, H_v, n_seqs] -> [chunk_size, chunk_size, H_v, n_seqs]
272-
struct ggml_tensor * gcs_j_broadcast = ggml_repeat_4d(ctx, gcs_j, num_chunks * GGML_DELTA_NET_CHUNK, num_chunks * GGML_DELTA_NET_CHUNK, H_v, n_seqs); // [1, chunk_size, H_v, n_seqs] -> [chunk_size, chunk_size, H_v, n_seqs]
275+
struct ggml_tensor * gcs_i_broadcast = ggml_repeat_4d(ctx, gcs_i, GGML_DELTA_NET_CHUNK, GGML_DELTA_NET_CHUNK, num_chunks * H_v, n_seqs); // [chunk_size, 1, H_v, n_seqs] -> [chunk_size, chunk_size, H_v, n_seqs]
276+
struct ggml_tensor * gcs_j_broadcast = ggml_repeat_4d(ctx, gcs_j, GGML_DELTA_NET_CHUNK, GGML_DELTA_NET_CHUNK, num_chunks * H_v, n_seqs); // [1, chunk_size, H_v, n_seqs] -> [chunk_size, chunk_size, H_v, n_seqs]
273277

274278
struct ggml_tensor * decay_mask = ggml_sub(ctx, gcs_j_broadcast, gcs_i_broadcast);
275279
cb(decay_mask, "sub", il);

0 commit comments

Comments
 (0)