@@ -386,14 +386,13 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
     const auto * mctx_cur = inp->mctx;
 
     const int64_t d_inner = hparams.ssm_d_inner;
-    const int64_t n_heads = hparams.ssm_dt_rank;
-    const int64_t head_dim = d_inner / n_heads;
+
     const int64_t n_seqs = ubatch.n_seqs;
 
     const int64_t head_k_dim = hparams.ssm_d_state;
-    const int64_t head_v_dim = hparams.ssm_d_state;
     const int64_t num_k_heads = hparams.ssm_n_group;
     const int64_t num_v_heads = hparams.ssm_dt_rank;
+    const int64_t head_v_dim = d_inner / num_v_heads;
 
     const int64_t n_seq_tokens = ubatch.n_seq_tokens;
 
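Note on this hunk: head_v_dim previously aliased hparams.ssm_d_state (the key head size), and the head_dim/n_heads pair duplicated the value-head geometry; the value head width is now derived from the inner dimension instead. A minimal shape-arithmetic sketch, with purely illustrative numbers rather than values read from any real GGUF:

    // illustrative hparams only; real models may differ
    const int64_t d_inner     = 4096;                  // hparams.ssm_d_inner
    const int64_t num_v_heads = 32;                    // hparams.ssm_dt_rank
    const int64_t head_v_dim  = d_inner / num_v_heads; // 128, no longer tied to ssm_d_state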
@@ -408,7 +407,7 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
     ggml_tensor * mixed_ba = build_lora_mm(model.layers[il].ssm_beta_alpha, cur);
     cb(mixed_ba, "linear_attn_mixed_ba", il);
 
-    int64_t qkvz_new_dim = 2 * head_k_dim + 2 * head_v_dim * num_v_heads / num_k_heads;
+    int64_t qkvz_new_dim = 2 * head_k_dim + 2 * head_v_dim * (num_v_heads / num_k_heads);
     ggml_tensor * mixed_qkvz_reshaped = ggml_cont_4d(ctx0, mixed_qkvz, qkvz_new_dim, num_k_heads, n_seq_tokens, n_seqs);
 
     // Reshape mixed_ba: [batch, seq_len, hidden_size] -> [batch, seq_len, num_k_heads, 2*num_v_heads/num_k_heads]
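Note: both forms of qkvz_new_dim use integer division, and they agree whenever num_k_heads divides num_v_heads evenly; the new grouping makes the heads-per-k-group factor explicit. A worked check with hypothetical sizes:

    // hypothetical: head_k_dim = 128, head_v_dim = 128, num_k_heads = 16, num_v_heads = 32
    // old: 2*128 + 2*128*32/16   = 256 + 512 = 768
    // new: 2*128 + 2*128*(32/16) = 256 + 512 = 768   // identical while 16 divides 32
    // the forms diverge only for non-divisible head counts, e.g. num_v_heads = 33:
    //   2*128*33/16 = 528   vs   2*128*(33/16) = 512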
@@ -441,63 +440,58 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
     ggml_tensor * gate = ggml_mul(ctx0, alpha_softplus, model.layers[il].ssm_a); // -A_log.exp() * softplus
     cb(gate, "gate", il);
 
-    // Get convolution states from cache
-    ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
-    ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il);
-
-    // Build the convolution states tensor
-    ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs);
-    cb(conv_states, "conv_states", il);
-
-    // Split mixed_qkvz into query, key, value, z
+    // Split mixed_qkvz into query, key, value, z
     int64_t split_sizes_qkvz[4] = {
         head_k_dim,                             // query size
         head_k_dim,                             // key size
         head_v_dim * num_v_heads / num_k_heads, // value size
         head_v_dim * num_v_heads / num_k_heads  // z size
     };
 
-    ggml_tensor * query = ggml_cont(ctx0, ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[0], num_k_heads, n_seq_tokens, n_seqs,
-        mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3], 0));
+    ggml_tensor * query = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[0], num_k_heads, n_seq_tokens, n_seqs,
+        mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3], 0);
     cb(query, "q", il);
 
-    ggml_tensor * key = ggml_cont(ctx0, ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[1], num_k_heads, n_seq_tokens, n_seqs,
+    ggml_tensor * key = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[1], num_k_heads, n_seq_tokens, n_seqs,
         mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
-        split_sizes_qkvz[0] * sizeof(float)));
+        split_sizes_qkvz[0] * sizeof(float));
     cb(key, "k", il);
 
-    ggml_tensor * value = ggml_cont(ctx0, ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[2], num_k_heads, n_seq_tokens, n_seqs,
+    ggml_tensor * value = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[2], num_k_heads, n_seq_tokens, n_seqs,
         mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
-        (split_sizes_qkvz[0] + split_sizes_qkvz[1]) * sizeof(float)));
+        (split_sizes_qkvz[0] + split_sizes_qkvz[1]) * sizeof(float));
     cb(value, "v", il);
 
-    ggml_tensor * z = ggml_cont(ctx0, ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[3], num_k_heads, n_seq_tokens, n_seqs,
+    ggml_tensor * z = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[3], num_k_heads, n_seq_tokens, n_seqs,
         mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
-        (split_sizes_qkvz[0] + split_sizes_qkvz[1] + split_sizes_qkvz[2]) * sizeof(float)));
+        (split_sizes_qkvz[0] + split_sizes_qkvz[1] + split_sizes_qkvz[2]) * sizeof(float));
     cb(z, "z", il);
 
-    // Reshape value and z to merge head dimensions: [batch, seq_len, num_k_heads, head_v_dim*num_v_heads/num_k_heads] -> [batch, seq_len, num_v_heads, head_v_dim]
-    ggml_tensor * value_reshaped =
-        ggml_reshape_4d(ctx0, ggml_cont(ctx0, value), head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
-    ggml_tensor * z_reshaped = ggml_reshape_4d(ctx0, ggml_cont(ctx0, z), head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
-
-    GGML_ASSERT(ggml_nelements(query) + ggml_nelements(key) + ggml_nelements(value_reshaped) +
-                ggml_nelements(z_reshaped) ==
+    GGML_ASSERT(ggml_nelements(query) + ggml_nelements(key) + ggml_nelements(value) +
+                ggml_nelements(z) ==
                 ggml_nelements(mixed_qkvz));
 
     // After creating query, key, and value_reshaped, reshape each to flatten the head dimensions
     // query: [head_k_dim, num_k_heads, n_tokens, n_seqs] -> [head_k_dim * num_k_heads, n_tokens, n_seqs]
-    ggml_tensor * query_flat = ggml_reshape_3d(ctx0, query, head_k_dim * num_k_heads, n_seq_tokens, n_seqs);
+    ggml_tensor * query_flat = ggml_cont_3d(ctx0, query, head_k_dim * num_k_heads, n_seq_tokens, n_seqs);
     cb(query_flat, "query_flat", il);
 
     // key: [head_k_dim, num_k_heads, n_tokens, n_seqs] -> [head_k_dim * num_k_heads, n_tokens, n_seqs]
-    ggml_tensor * key_flat = ggml_reshape_3d(ctx0, key, head_k_dim * num_k_heads, n_seq_tokens, n_seqs);
+    ggml_tensor * key_flat = ggml_cont_3d(ctx0, key, head_k_dim * num_k_heads, n_seq_tokens, n_seqs);
     cb(key_flat, "key_flat", il);
 
     // value_reshaped: [head_v_dim, num_v_heads, n_tokens, n_seqs] -> [head_v_dim * num_v_heads, n_tokens, n_seqs]
-    ggml_tensor * value_flat = ggml_reshape_3d(ctx0, value_reshaped, head_v_dim * num_v_heads, n_seq_tokens, n_seqs);
+    ggml_tensor * value_flat = ggml_cont_3d(ctx0, value, head_v_dim * num_v_heads, n_seq_tokens, n_seqs);
     cb(value_flat, "value_flat", il);
 
+    // Get convolution states from cache
+    ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
+    ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il);
+
+    // Build the convolution states tensor
+    ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs);
+    cb(conv_states, "conv_states", il);
+
     // Now concatenate along the feature dimension (dim 0) to get [conv_dim, n_tokens, n_seqs]
     ggml_tensor * qkv_mixed = ggml_concat(ctx0, query_flat, key_flat, 0);
     qkv_mixed = ggml_concat(ctx0, qkv_mixed, value_flat, 0);
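Note: the split now keeps query/key/value/z as strided views into mixed_qkvz_reshaped and lets ggml_cont_3d do the single copy at flatten time, where the old code paid a ggml_cont per view plus a separate reshape (and, for value/z, a second cont). The equivalence relied on, sketched:

    // ggml_cont_3d(ctx0, t, ne0, ne1, ne2) behaves like
    //     ggml_reshape_3d(ctx0, ggml_cont(ctx0, t), ne0, ne1, ne2)
    // so each *_flat tensor still comes out dense, with one copy instead of two.
    //
    // Byte offsets of the four views inside one k-head row (float32 storage,
    // which the sizeof(float) offsets already assume):
    //   query: 0
    //   key:   head_k_dim * sizeof(float)
    //   value: 2 * head_k_dim * sizeof(float)
    //   z:     (2 * head_k_dim + head_v_dim * num_v_heads / num_k_heads) * sizeof(float)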
@@ -578,7 +572,7 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
     beta = ggml_cont_4d(ctx0, b, num_v_heads, 1, n_seq_tokens, n_seqs);
 
     ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs);
-    state = ggml_reshape_4d(ctx0, state, head_dim, head_dim * n_heads, 1, n_seqs);
+    state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim * num_v_heads, 1, n_seqs);
 
     // if head keys and value keys are different, repeat to force tensors into matching shapes
     if (num_k_heads != num_v_heads) {
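Note: with head_dim/n_heads removed, the recurrent state is reshaped purely from the value-head geometry, which presumes n_embd_s() stores one head_v_dim x head_v_dim matrix per value head. A defensive assert, illustrative only and not part of this commit:

    // hypothetical sanity check; assumes n_embd_s() == head_v_dim^2 * num_v_heads
    GGML_ASSERT(hparams.n_embd_s() == head_v_dim * head_v_dim * num_v_heads);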
@@ -598,17 +592,17 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
     cb(attn_out, "attn_out", il);
 
     // The tensors were concatenated 1d, so we need to extract them 1d as well
-    const int64_t output_flat_size = head_dim * n_heads * n_seq_tokens * n_seqs;
+    const int64_t output_flat_size = head_v_dim * num_v_heads * n_seq_tokens * n_seqs;
     ggml_tensor * attn_out_1d =
         ggml_view_1d(ctx0, attn_out, output_flat_size, 0);
     cb(attn_out_1d, "attn_out_1d", il);
 
-    ggml_tensor * attn_out_final = ggml_cont(ctx0, ggml_permute(ctx0, ggml_cont_4d(ctx0, attn_out_1d, head_dim, n_seq_tokens, n_heads, n_seqs), 0, 2, 1, 3));
+    ggml_tensor * attn_out_final = ggml_cont(ctx0, ggml_permute(ctx0, ggml_cont_4d(ctx0, attn_out_1d, head_v_dim, n_seq_tokens, num_v_heads, n_seqs), 0, 2, 1, 3));
     cb(attn_out_final, "attn_out_final", il);
 
     // Extract the state part (second part of the concatenated tensor)
     // State starts after n_tokens elements along dimension 1
-    const int64_t state_flat_size = head_dim * head_dim * n_heads * n_seqs;
+    const int64_t state_flat_size = head_v_dim * head_v_dim * num_v_heads * n_seqs;
 
     ggml_tensor * state_1d = ggml_view_1d(ctx0, attn_out, state_flat_size, output_flat_size * ggml_element_size(attn_out));
     cb(state_1d, "state_1d", il);
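Note: this extraction assumes attn_out packs the per-token output first and the updated recurrent state immediately after it in one flat buffer; the renamed sizes re-express that layout in value-head terms. Sketched layout:

    // assumed flat layout of attn_out:
    //   elements [0, output_flat_size)                                   -> attention output
    //   elements [output_flat_size, output_flat_size + state_flat_size)  -> updated state
    // hence the state view starts at byte offset
    //   output_flat_size * ggml_element_size(attn_out)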
@@ -620,19 +614,19 @@ ggml_tensor * llm_build_qwen3next::build_qwen3next_linear_attn_layer(llm_graph_i
 
     // Reshape both attn_out_final and z to 2D tensors for normalization
     // attn_out_final: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]
-    ggml_tensor * attn_out_2d_final = ggml_reshape_2d(ctx0, ggml_cont(ctx0, attn_out_final), head_dim, n_heads * n_seq_tokens * n_seqs);
+    ggml_tensor * attn_out_2d_final = ggml_reshape_2d(ctx0, ggml_cont(ctx0, attn_out_final), head_v_dim, num_v_heads * n_seq_tokens * n_seqs);
 
     // z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]
-    ggml_tensor * z_2d = ggml_reshape_2d(ctx0, z_reshaped, head_dim, n_heads * n_seq_tokens * n_seqs);
+    ggml_tensor * z_2d = ggml_cont_2d(ctx0, z, head_v_dim, num_v_heads * n_seq_tokens * n_seqs);
 
     // Apply gated normalization: self.norm(core_attn_out, z)
     ggml_tensor * attn_out_norm = build_q3n_gated_norm(attn_out_2d_final, model.layers[il].ssm_norm, z_2d, il);
 
     // Reshape back to original dimensions: [n_heads * n_tokens * n_seqs, head_dim] -> [head_dim, n_heads, n_tokens, n_seqs]
-    ggml_tensor * gated_output_4d = ggml_reshape_4d(ctx0, attn_out_norm, head_dim, n_heads, n_seq_tokens, n_seqs);
+    ggml_tensor * gated_output_4d = ggml_reshape_4d(ctx0, attn_out_norm, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
 
     // Final reshape: [head_dim, n_heads, n_tokens, n_seqs] -> [n_tokens, n_seqs, n_heads * head_dim]
-    ggml_tensor * final_output = ggml_reshape_3d(ctx0, gated_output_4d, n_heads * head_dim, n_seq_tokens, n_seqs);
+    ggml_tensor * final_output = ggml_reshape_3d(ctx0, gated_output_4d, head_v_dim * num_v_heads, n_seq_tokens, n_seqs);
     cb(final_output, "final_output", il);
 
     // Output projection
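Note: the 2D round-trip feeds build_q3n_gated_norm rows of width head_v_dim, one row per (head, token, sequence) triple, gated by the matching row of z; ggml_cont_2d also materializes z here, since it stayed a strided view up to this point. A shape trace with the same hypothetical sizes as above (head_v_dim = 128, num_v_heads = 32, T tokens, S sequences):

    // attn_out_final   [128, 32, T, S] -> attn_out_2d_final [128, 32*T*S]
    // z (strided view)                 -> z_2d              [128, 32*T*S]
    // attn_out_norm                       [128, 32*T*S]
    // gated_output_4d                     [128, 32, T, S]
    // final_output                        [4096, T, S]      // 4096 = 128 * 32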