@@ -252,24 +252,28 @@ struct ggml_tensor * llm_build_qwen3next::delta_net(
252252 GGML_ASSERT (beta->ne [0 ] % GGML_DELTA_NET_CHUNK == 0 && beta->ne [1 ] == H_k && beta->ne [2 ] == 1 && beta->ne [3 ] == n_seqs);
253253 GGML_ASSERT (g->ne [0 ] % GGML_DELTA_NET_CHUNK == 0 && g->ne [1 ] == H_k && g->ne [2 ] == 1 && g->ne [3 ] == n_seqs);
254254
255- ggml_tensor * beta_unsq = ggml_cont_4d (ctx, beta, 1 , num_chunks * GGML_DELTA_NET_CHUNK , H_k, n_seqs);
256- ggml_tensor * beta_bcast = ggml_repeat_4d (ctx, beta_unsq, S_v, num_chunks * GGML_DELTA_NET_CHUNK , H_k, n_seqs);
255+ ggml_tensor * beta_unsq = ggml_cont_4d (ctx, beta, 1 , GGML_DELTA_NET_CHUNK * num_chunks , H_k, n_seqs);
256+ ggml_tensor * beta_bcast = ggml_repeat_4d (ctx, beta_unsq, S_v, GGML_DELTA_NET_CHUNK * num_chunks , H_k, n_seqs);
257257 cb (beta_unsq, " beta_unsq" , il);
258258 cb (beta_bcast, " beta_bcast" , il);
259259
260260 struct ggml_tensor * v_beta = ggml_mul (ctx, v, beta_bcast);
261+ v_beta = ggml_reshape_4d (ctx, v_beta, S_v, GGML_DELTA_NET_CHUNK, H_k * num_chunks, n_seqs);
261262 cb (v_beta, " v_beta" , il);
262263 struct ggml_tensor * k_beta = ggml_mul (ctx, k, beta_bcast);
264+ k_beta = ggml_reshape_4d (ctx, k_beta, S_v, GGML_DELTA_NET_CHUNK, H_k * num_chunks, n_seqs);
263265 cb (k_beta, " k_beta" , il);
266+ k = ggml_reshape_4d (ctx, k, S_v, GGML_DELTA_NET_CHUNK, H_k * num_chunks, n_seqs);
267+ cb (k_beta, " k_reshape" , il);
264268 struct ggml_tensor * g_cumsum = ggml_cumsum (ctx, g);
265269 cb (g_cumsum, " g_cumsum" , il);
266270
267- struct ggml_tensor * gcs_i = ggml_cont_4d (ctx, g_cumsum, num_chunks * GGML_DELTA_NET_CHUNK, 1 , H_v, n_seqs); // [chunk_size, 1, n_tokens, n_seqs]
268- struct ggml_tensor * gcs_j = ggml_cont_4d (ctx, g_cumsum, 1 , num_chunks * GGML_DELTA_NET_CHUNK, H_v, n_seqs); // [1, chunk_size, n_tokens, n_seqs]
271+ struct ggml_tensor * gcs_i = ggml_cont_4d (ctx, g_cumsum, GGML_DELTA_NET_CHUNK, 1 , num_chunks * H_v, n_seqs); // [chunk_size, 1, n_tokens, n_seqs]
272+ struct ggml_tensor * gcs_j = ggml_cont_4d (ctx, g_cumsum, 1 , GGML_DELTA_NET_CHUNK, num_chunks * H_v, n_seqs); // [1, chunk_size, n_tokens, n_seqs]
269273
270274 // Broadcast both tensors to [chunk_size, chunk_size, H_v, n_seqs]
271- struct ggml_tensor * gcs_i_broadcast = ggml_repeat_4d (ctx, gcs_i, num_chunks * GGML_DELTA_NET_CHUNK, num_chunks * GGML_DELTA_NET_CHUNK, H_v, n_seqs); // [chunk_size, 1, H_v, n_seqs] -> [chunk_size, chunk_size, H_v, n_seqs]
272- struct ggml_tensor * gcs_j_broadcast = ggml_repeat_4d (ctx, gcs_j, num_chunks * GGML_DELTA_NET_CHUNK, num_chunks * GGML_DELTA_NET_CHUNK, H_v, n_seqs); // [1, chunk_size, H_v, n_seqs] -> [chunk_size, chunk_size, H_v, n_seqs]
275+ struct ggml_tensor * gcs_i_broadcast = ggml_repeat_4d (ctx, gcs_i, GGML_DELTA_NET_CHUNK, GGML_DELTA_NET_CHUNK, num_chunks * H_v, n_seqs); // [chunk_size, 1, H_v, n_seqs] -> [chunk_size, chunk_size, H_v, n_seqs]
276+ struct ggml_tensor * gcs_j_broadcast = ggml_repeat_4d (ctx, gcs_j, GGML_DELTA_NET_CHUNK, GGML_DELTA_NET_CHUNK, num_chunks * H_v, n_seqs); // [1, chunk_size, H_v, n_seqs] -> [chunk_size, chunk_size, H_v, n_seqs]
273277
274278 struct ggml_tensor * decay_mask = ggml_sub (ctx, gcs_j_broadcast, gcs_i_broadcast);
275279 cb (decay_mask, " sub" , il);
0 commit comments