@@ -311,46 +311,6 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
311311 int64_t ba_new_dim = 2 * num_v_heads / num_k_heads;
312312 ggml_tensor * mixed_ba_reshaped = ggml_reshape_4d (ctx0, mixed_ba, ba_new_dim, num_k_heads, n_tokens, n_seqs);
313313
314- // Split mixed_qkvz into query, key, value, z
315- int64_t split_sizes_qkvz[4 ] = {
316- head_k_dim, // query size
317- head_k_dim, // key size
318- head_v_dim * num_v_heads / num_k_heads, // value size
319- head_v_dim * num_v_heads / num_k_heads // z size
320- };
321-
322- ggml_tensor * query = ggml_cont (ctx0, ggml_view_4d (ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[0 ], num_k_heads,
323- n_tokens, n_seqs, split_sizes_qkvz[0 ] * sizeof (float ),
324- mixed_qkvz_reshaped->nb [1 ], mixed_qkvz_reshaped->nb [2 ], 0 ));
325- cb (query, " q" , il);
326-
327- ggml_tensor * key =
328- ggml_cont (ctx0, ggml_view_4d (ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[1 ], num_k_heads, n_tokens, n_seqs,
329- split_sizes_qkvz[1 ] * sizeof (float ), mixed_qkvz_reshaped->nb [1 ],
330- mixed_qkvz_reshaped->nb [2 ], split_sizes_qkvz[0 ] * sizeof (float )));
331- cb (query, " k" , il);
332-
333- ggml_tensor * value =
334- ggml_view_4d (ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[2 ], num_k_heads, n_tokens, n_seqs,
335- split_sizes_qkvz[2 ] * sizeof (float ), mixed_qkvz_reshaped->nb [1 ], mixed_qkvz_reshaped->nb [2 ],
336- (split_sizes_qkvz[0 ] + split_sizes_qkvz[1 ]) * sizeof (float ));
337- cb (query, " v" , il);
338-
339- ggml_tensor * z =
340- ggml_view_4d (ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[3 ], num_k_heads, n_tokens, n_seqs,
341- split_sizes_qkvz[3 ] * sizeof (float ), mixed_qkvz_reshaped->nb [1 ], mixed_qkvz_reshaped->nb [2 ],
342- (split_sizes_qkvz[0 ] + split_sizes_qkvz[1 ] + split_sizes_qkvz[2 ]) * sizeof (float ));
343- cb (query, " z" , il);
344-
345- // Reshape value and z to merge head dimensions: [batch, seq_len, num_k_heads, head_v_dim*num_v_heads/num_k_heads] -> [batch, seq_len, num_v_heads, head_v_dim]
346- ggml_tensor * value_reshaped =
347- ggml_reshape_4d (ctx0, ggml_cont (ctx0, value), head_v_dim, num_v_heads, n_tokens, n_seqs);
348- ggml_tensor * z_reshaped = ggml_reshape_4d (ctx0, ggml_cont (ctx0, z), head_v_dim, num_v_heads, n_tokens, n_seqs);
349-
350- GGML_ASSERT (ggml_nelements (query) + ggml_nelements (key) + ggml_nelements (value_reshaped) +
351- ggml_nelements (z_reshaped) ==
352- ggml_nelements (mixed_qkvz));
353-
354314 // Split mixed_ba into b and a (beta and alpha parameters)
355315 int64_t split_sizes_ba[2 ] = {
356316 num_v_heads / num_k_heads, // beta size
@@ -360,12 +320,12 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
360320 ggml_tensor * b =
361321 ggml_view_4d (ctx0, mixed_ba_reshaped, split_sizes_ba[0 ], num_k_heads, n_tokens, n_seqs,
362322 split_sizes_ba[0 ] * sizeof (float ), mixed_ba_reshaped->nb [1 ], mixed_ba_reshaped->nb [2 ], 0 );
363- cb (query , " b" , il);
323+ cb (b , " b" , il);
364324
365325 ggml_tensor * a = ggml_view_4d (ctx0, mixed_ba_reshaped, split_sizes_ba[1 ], num_k_heads, n_tokens, n_seqs,
366326 split_sizes_ba[1 ] * sizeof (float ), mixed_ba_reshaped->nb [1 ],
367327 mixed_ba_reshaped->nb [2 ], split_sizes_ba[0 ] * sizeof (float ));
368- cb (query , " a" , il);
328+ cb (a , " a" , il);
369329
370330 // Reshape b and a to merge head dimensions: [batch, seq_len, num_k_heads, num_v_heads/num_k_heads] -> [batch, seq_len, num_v_heads]
371331 ggml_tensor * beta = ggml_reshape_3d (ctx0, ggml_cont (ctx0, b), num_v_heads, n_tokens, n_seqs);
@@ -390,6 +350,46 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
390350 ggml_tensor * conv_states = build_rs (inp, conv_states_all, hparams.n_embd_r (), n_seqs);
391351 cb (conv_states, " conv_states" , il);
392352
353+ // Split mixed_qkvz into query, key, value, z
354+ int64_t split_sizes_qkvz[4 ] = {
355+ head_k_dim, // query size
356+ head_k_dim, // key size
357+ head_v_dim * num_v_heads / num_k_heads, // value size
358+ head_v_dim * num_v_heads / num_k_heads // z size
359+ };
360+
361+ ggml_tensor * query = ggml_cont (ctx0, ggml_view_4d (ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[0 ], num_k_heads,
362+ n_tokens, n_seqs, split_sizes_qkvz[0 ] * sizeof (float ),
363+ mixed_qkvz_reshaped->nb [1 ], mixed_qkvz_reshaped->nb [2 ], 0 ));
364+ cb (query, " q" , il);
365+
366+ ggml_tensor * key =
367+ ggml_cont (ctx0, ggml_view_4d (ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[1 ], num_k_heads, n_tokens, n_seqs,
368+ split_sizes_qkvz[1 ] * sizeof (float ), mixed_qkvz_reshaped->nb [1 ],
369+ mixed_qkvz_reshaped->nb [2 ], split_sizes_qkvz[0 ] * sizeof (float )));
370+ cb (key, " k" , il);
371+
372+ ggml_tensor * value =
373+ ggml_view_4d (ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[2 ], num_k_heads, n_tokens, n_seqs,
374+ split_sizes_qkvz[2 ] * sizeof (float ), mixed_qkvz_reshaped->nb [1 ], mixed_qkvz_reshaped->nb [2 ],
375+ (split_sizes_qkvz[0 ] + split_sizes_qkvz[1 ]) * sizeof (float ));
376+ cb (value, " v" , il);
377+
378+ ggml_tensor * z =
379+ ggml_view_4d (ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[3 ], num_k_heads, n_tokens, n_seqs,
380+ split_sizes_qkvz[3 ] * sizeof (float ), mixed_qkvz_reshaped->nb [1 ], mixed_qkvz_reshaped->nb [2 ],
381+ (split_sizes_qkvz[0 ] + split_sizes_qkvz[1 ] + split_sizes_qkvz[2 ]) * sizeof (float ));
382+ cb (z, " z" , il);
383+
384+ // Reshape value and z to merge head dimensions: [batch, seq_len, num_k_heads, head_v_dim*num_v_heads/num_k_heads] -> [batch, seq_len, num_v_heads, head_v_dim]
385+ ggml_tensor * value_reshaped =
386+ ggml_reshape_4d (ctx0, ggml_cont (ctx0, value), head_v_dim, num_v_heads, n_tokens, n_seqs);
387+ ggml_tensor * z_reshaped = ggml_reshape_4d (ctx0, ggml_cont (ctx0, z), head_v_dim, num_v_heads, n_tokens, n_seqs);
388+
389+ GGML_ASSERT (ggml_nelements (query) + ggml_nelements (key) + ggml_nelements (value_reshaped) +
390+ ggml_nelements (z_reshaped) ==
391+ ggml_nelements (mixed_qkvz));
392+
393393 // After creating query, key, and value_reshaped, reshape each to flatten the head dimensions
394394 // query: [head_k_dim, num_k_heads, n_tokens, n_seqs] -> [head_k_dim * num_k_heads, n_tokens, n_seqs]
395395 ggml_tensor * query_flat = ggml_reshape_3d (ctx0, query, head_k_dim * num_k_heads, n_tokens, n_seqs);
0 commit comments