@@ -1000,13 +1000,13 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
     {
         GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");

-        const auto n_kv   = inp->mctx->get_attn()->get_n_kv();
-        const auto n_seqs = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1;
+        const auto n_kv     = inp->mctx->get_attn()->get_n_kv();
+        const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;

         inp->self_k_idxs = mctx_cur->get_attn()->build_input_k_idxs(ctx0, ubatch);
         inp->self_v_idxs = mctx_cur->get_attn()->build_input_v_idxs(ctx0, ubatch);

-        inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs,   GGML_KQ_MASK_PAD), 1, n_seqs);
+        inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
         ggml_set_input(inp->self_kq_mask);

         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
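The mask now has one [n_kv, n_tokens/n_stream] slice per stream instead of a single shared one. As a shape sanity check, a minimal standalone sketch; the concrete values of n_kv, n_tokens, n_seqs_unq and the padding constant are illustrative assumptions, not taken from this change:

    // Sketch: KQ mask shape, unified KV vs. one stream per sequence.
    #include <cstdio>

    // round x up to a multiple of n (power-of-two n), like ggml's GGML_PAD
    #define PAD_UP(x, n) (((x) + (n) - 1) & ~((n) - 1))

    int main() {
        const int kq_mask_pad = 64;  // assumed padding constant, for illustration only
        const int n_kv        = 256; // assumed KV cache size
        const int n_tokens    = 24;  // assumed ubatch size
        const int n_seqs_unq  = 4;   // unique sequences in the ubatch

        for (int kv_unified = 1; kv_unified >= 0; --kv_unified) {
            const int n_stream = kv_unified ? 1 : n_seqs_unq;
            printf("kv_unified=%d -> mask: [%d, %d, 1, %d]\n",
                   kv_unified, n_kv, PAD_UP(n_tokens/n_stream, kq_mask_pad), n_stream);
        }
        return 0;
    }

With these numbers the mask goes from [256, 64, 1, 1] to [256, 64, 1, 4]: each stream masks only its own slice of the batch.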
@@ -1033,9 +1033,10 @@ ggml_tensor * llm_graph_context::build_attn_mha(
         float kq_scale) const {
     const bool v_trans = v->nb[1] > v->nb[2];

-    const auto n_seqs = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1;
+    // split the batch into streams if needed
+    const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;

-    q = ggml_reshape_4d(ctx0, q, q->ne[0], q->ne[1], q->ne[2]/n_seqs, n_seqs);
+    q = ggml_reshape_4d(ctx0, q, q->ne[0], q->ne[1], q->ne[2]/n_stream, n_stream);

     q = ggml_permute(ctx0, q, 0, 2, 1, 3);
     k = ggml_permute(ctx0, k, 0, 2, 1, 3);
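What the reshape/permute pair above does to Q, as a standalone sketch against the public ggml API (the toy dimensions are assumptions; no_alloc means only tensor metadata is created):

    // Sketch: splitting Q into streams before the batched attention matmul.
    #include <cstdio>
    #include "ggml.h"

    int main() {
        ggml_init_params params = {
            /*.mem_size   =*/ 16*1024*1024,
            /*.mem_buffer =*/ nullptr,
            /*.no_alloc   =*/ true,
        };
        ggml_context * ctx0 = ggml_init(params);

        const int64_t head_dim = 64, n_head = 8, n_tokens = 24, n_stream = 4;

        // Q as it arrives: [head_dim, n_head, n_tokens]
        ggml_tensor * q = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, head_dim, n_head, n_tokens);

        // split the token dimension into n_stream streams:
        // [head_dim, n_head, n_tokens/n_stream, n_stream]
        q = ggml_reshape_4d(ctx0, q, q->ne[0], q->ne[1], q->ne[2]/n_stream, n_stream);

        // swap heads and tokens so the matmul batches over (head, stream):
        // [head_dim, n_tokens/n_stream, n_head, n_stream]
        q = ggml_permute(ctx0, q, 0, 2, 1, 3);

        printf("q: [%lld, %lld, %lld, %lld]\n",
               (long long) q->ne[0], (long long) q->ne[1],
               (long long) q->ne[2], (long long) q->ne[3]);

        ggml_free(ctx0);
        return 0;
    }

Running it prints q: [64, 6, 8, 4]; the subsequent ggml_mul_mat with the similarly permuted K then batches over heads and streams at once.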
@@ -1085,7 +1086,8 @@ ggml_tensor * llm_graph_context::build_attn_mha(
 #endif
         }

-        cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens*n_seqs);
+        // recombine streams
+        cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens*n_stream);
     } else {
         ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);

@@ -1130,7 +1132,8 @@ ggml_tensor * llm_graph_context::build_attn_mha(

         cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);

-        cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens*n_seqs);
+        // recombine streams
+        cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens*n_stream);

         if (!cparams.offload_kqv) {
             // all nodes between the KV store and the attention output are run on the CPU
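The recombine step is the inverse of the split. One detail worth noting: at this point n_tokens in the diff evidently counts tokens per stream (otherwise n_tokens*n_stream would not match the element count), so the final 2D tensor covers the full ubatch again. A sketch under the same assumed toy dimensions:

    // Sketch: recombining the per-stream attention output into a 2D activation.
    #include <cstdio>
    #include "ggml.h"

    int main() {
        ggml_init_params params = {
            /*.mem_size   =*/ 16*1024*1024,
            /*.mem_buffer =*/ nullptr,
            /*.no_alloc   =*/ true,  // metadata only, no tensor data
        };
        ggml_context * ctx0 = ggml_init(params);

        const int64_t head_dim = 64, n_head = 8, n_tps = 6, n_stream = 4; // n_tps: tokens per stream

        // attention output per stream: [head_dim, n_tps, n_head, n_stream]
        ggml_tensor * kqv = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, head_dim, n_tps, n_head, n_stream);

        // bring heads back next to the embedding dimension:
        // [head_dim, n_head, n_tps, n_stream]
        ggml_tensor * cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);

        // fold heads into the rows and streams back into the token column:
        // [head_dim*n_head, n_tps*n_stream] == [n_embd, n_tokens of the full ubatch]
        cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tps*n_stream);

        printf("cur: [%lld, %lld]\n", (long long) cur->ne[0], (long long) cur->ne[1]);

        ggml_free(ctx0);
        return 0;
    }

This prints cur: [512, 24], i.e. the usual [n_embd, n_tokens] activation that the rest of the graph expects.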
@@ -1207,13 +1210,13 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
     {
         GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");

-        const auto n_kv   = mctx_cur->get_n_kv();
-        const auto n_seqs = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1;
+        const auto n_kv     = mctx_cur->get_n_kv();
+        const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;

         inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
         inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);

-        inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs,   GGML_KQ_MASK_PAD), 1, n_seqs);
+        inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
         ggml_set_input(inp->self_kq_mask);

         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1455,15 +1458,15 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif

     auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);

-    const auto n_seqs = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1;
+    const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;

     {
         const auto n_kv = mctx_cur->get_base()->get_n_kv();

         inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
         inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);

-        inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs,   GGML_KQ_MASK_PAD), 1, n_seqs);
+        inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
         ggml_set_input(inp->self_kq_mask);

         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1477,7 +1480,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
         inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
         inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);

-        inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs,   GGML_KQ_MASK_PAD), 1, n_seqs);
+        inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
         ggml_set_input(inp->self_kq_mask_swa);

         inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;