@@ -982,9 +982,10 @@ ggml_tensor * llm_graph_context::build_attn_mha(
                   float   kq_scale) const {
     const bool v_trans = v->nb[1] > v->nb[2];
 
-    const auto n_seqs = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1;
+    // split the batch into streams if needed
+    const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
 
-    q = ggml_reshape_4d(ctx0, q, q->ne[0], q->ne[1], q->ne[2]/n_seqs, n_seqs);
+    q = ggml_reshape_4d(ctx0, q, q->ne[0], q->ne[1], q->ne[2]/n_stream, n_stream);
 
     q = ggml_permute(ctx0, q, 0, 2, 1, 3);
     k = ggml_permute(ctx0, k, 0, 2, 1, 3);
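A minimal standalone sketch (not part of the PR) of what the stream split above does to the shapes, using ggml's public tensor API. All sizes are made up for illustration; the caller is assumed to provide q as one merged batch of n_tokens*n_stream tokens, and the reshape pulls the stream dimension out into ne[3] before the usual permute.

// sketch only: illustrative sizes, metadata-only tensors (no_alloc), not PR code
#include "ggml.h"
#include <cstdint>
#include <cstdio>

int main() {
    ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ nullptr,
        /*.no_alloc   =*/ true, // shapes only, no tensor data
    };
    ggml_context * ctx = ggml_init(params);

    const int64_t n_embd_head = 128, n_head = 8, n_tokens = 32, n_stream = 4;

    // q as the caller provides it: one merged batch of n_tokens*n_stream tokens
    ggml_tensor * q = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd_head, n_head, n_tokens*n_stream);

    // split the batch into streams, mirroring the diff above
    q = ggml_reshape_4d(ctx, q, q->ne[0], q->ne[1], q->ne[2]/n_stream, n_stream);
    q = ggml_permute   (ctx, q, 0, 2, 1, 3); // -> [n_embd_head, n_tokens, n_head, n_stream]

    printf("q: [%lld, %lld, %lld, %lld]\n",
           (long long) q->ne[0], (long long) q->ne[1], (long long) q->ne[2], (long long) q->ne[3]);

    ggml_free(ctx);
    return 0;
}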
@@ -1034,7 +1035,8 @@ ggml_tensor * llm_graph_context::build_attn_mha(
 #endif
         }
 
-        cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens*n_seqs);
+        // recombine streams
+        cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens*n_stream);
     } else {
         ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
 
@@ -1079,7 +1081,8 @@ ggml_tensor * llm_graph_context::build_attn_mha(
 
         cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
 
-        cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens*n_seqs);
+        // recombine streams
+        cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens*n_stream);
 
         if (!cparams.offload_kqv) {
             // all nodes between the KV store and the attention output are run on the CPU
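And the inverse on the way out, again only a sketch: after the per-stream attention the result is permuted back and flattened so the streams fold back into the token dimension. These lines assume the same made-up sizes and ggml context as the sketch above (they would sit right before the ggml_free call).

    // sketch only: recombine streams back into a 2D [n_embd, n_tokens*n_stream] activation
    ggml_tensor * kqv = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_embd_head, n_tokens, n_head, n_stream);

    ggml_tensor * cur = ggml_permute(ctx, kqv, 0, 2, 1, 3);              // -> [n_embd_head, n_head, n_tokens, n_stream]
    cur = ggml_cont_2d(ctx, cur, n_embd_head*n_head, n_tokens*n_stream); // -> [n_embd, n_tokens*n_stream]

Using ggml_cont_2d here also makes the result contiguous again after the permute, which is why the diff uses it rather than a plain reshape on the non-flash-attention path.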
@@ -1162,12 +1165,12 @@ static std::unique_ptr<llm_graph_input_attn_kv_unified> build_attn_inp_kv_unifie
 
     const auto n_kv     = mctx_cur->get_n_kv();
     const auto n_tokens = ubatch.n_tokens;
-    const auto n_seqs   = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1;
+    const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
 
     inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
     inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
 
-    inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), 1, n_seqs);
+    inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
     ggml_set_input(inp->self_kq_mask);
 
     inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
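For the mask, a small standalone sketch of the shape arithmetic in the new ggml_new_tensor_4d call: the second dimension is the number of tokens per stream, rounded up by GGML_PAD, and the stream count moves into ne[3] so each stream gets its own mask slice. The sizes and the pad constant below are illustrative assumptions (the real GGML_KQ_MASK_PAD is defined by llama.cpp); only the rounding rule matches ggml's GGML_PAD macro.

// sketch only: mask-shape arithmetic with made-up sizes; kq_mask_pad stands in for GGML_KQ_MASK_PAD
#include <cstdint>
#include <cstdio>

// same round-up-to-multiple behaviour as ggml's GGML_PAD(x, n)
constexpr int64_t pad_to(int64_t x, int64_t n) { return (x + n - 1)/n*n; }

int main() {
    const int64_t n_kv = 1024, n_tokens = 48, n_stream = 4;
    const int64_t kq_mask_pad = 64; // assumption, stand-in for GGML_KQ_MASK_PAD

    // shape per the diff: [n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream]
    const int64_t ne1 = pad_to(n_tokens/n_stream, kq_mask_pad);

    printf("self_kq_mask: [%lld, %lld, 1, %lld]\n",
           (long long) n_kv, (long long) ne1, (long long) n_stream);
    // with these numbers: 12 tokens per stream, padded up to 64 mask rows per stream
    return 0;
}

The two iSWA hunks below allocate self_kq_mask and self_kq_mask_swa with this same layout.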
@@ -1367,15 +1370,15 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
 
     auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);
 
-    const auto n_seqs = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1;
+    const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
 
     {
         const auto n_kv = mctx_cur->get_base()->get_n_kv();
 
         inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
         inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
 
-        inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), 1, n_seqs);
+        inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
         ggml_set_input(inp->self_kq_mask);
 
         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1389,7 +1392,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
         inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
         inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
 
-        inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), 1, n_seqs);
+        inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
         ggml_set_input(inp->self_kq_mask_swa);
 
         inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;