@@ -982,6 +982,10 @@ ggml_tensor * llm_graph_context::build_attn_mha(
               float   kq_scale) const {
     const bool v_trans = v->nb[1] > v->nb[2];
 
+    const auto n_seqs = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1;
+
+    q = ggml_reshape_4d(ctx0, q, q->ne[0], q->ne[1], q->ne[2]/n_seqs, n_seqs);
+
     q = ggml_permute(ctx0, q, 0, 2, 1, 3);
     k = ggml_permute(ctx0, k, 0, 2, 1, 3);
     v = ggml_permute(ctx0, v, 0, 2, 1, 3);
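
The added reshape is the crux of the change: the flat token dimension of q is split so that each virtual sequence becomes its own batch entry, which the existing permutes then carry through the attention computation. A minimal standalone sketch of the shape arithmetic, not part of the patch (sizes are hypothetical; q is assumed to arrive as [n_embd_head, n_head, n_tokens], and n_seqs must divide n_tokens evenly):

    #include "ggml.h"
    #include <cstdio>

    int main() {
        // metadata-only context: enough to inspect shapes without allocating tensor data
        ggml_init_params params = { /*mem_size*/ 1024*1024, /*mem_buffer*/ nullptr, /*no_alloc*/ true };
        ggml_context * ctx = ggml_init(params);

        const int64_t n_embd_head = 128, n_head = 32, n_tokens = 8, n_seqs = 2;

        // q as built upstream: [n_embd_head, n_head, n_tokens]
        ggml_tensor * q = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd_head, n_head, n_tokens);

        // split the token dimension across the virtual sequences:
        // [n_embd_head, n_head, n_tokens] -> [n_embd_head, n_head, n_tokens/n_seqs, n_seqs]
        q = ggml_reshape_4d(ctx, q, q->ne[0], q->ne[1], q->ne[2]/n_seqs, n_seqs);

        // the existing permute then yields [n_embd_head, n_tokens/n_seqs, n_head, n_seqs],
        // i.e. one attention batch per virtual sequence
        q = ggml_permute(ctx, q, 0, 2, 1, 3);

        printf("q: [%lld, %lld, %lld, %lld]\n",
               (long long) q->ne[0], (long long) q->ne[1],
               (long long) q->ne[2], (long long) q->ne[3]);

        ggml_free(ctx);
        return 0;
    }
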
@@ -1030,7 +1034,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
 #endif
         }
 
-        cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
+        cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens*n_seqs);
     } else {
         ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
 
@@ -1075,7 +1079,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
 
         cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
 
-        cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
+        cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens*n_seqs);
 
         if (!cparams.offload_kqv) {
             // all nodes between the KV store and the attention output are run on the CPU
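
Both output paths get the mirrored fix when flattening back to 2-D. Since the local n_tokens is taken from q after the permute, it now holds the per-sequence token count, so multiplying by n_seqs restores the full ubatch row count. A standalone sketch of the non-flash-attention flatten under the same hypothetical sizes:

    #include "ggml.h"
    #include <cstdio>

    int main() {
        ggml_init_params params = { /*mem_size*/ 1024*1024, /*mem_buffer*/ nullptr, /*no_alloc*/ true };
        ggml_context * ctx = ggml_init(params);

        // per-sequence token count: 8 ubatch tokens split across 2 virtual sequences
        const int64_t n_embd_head = 128, n_head = 32, n_tokens = 4, n_seqs = 2;

        // kqv as produced in the non-FA branch: [n_embd_head, n_tokens, n_head, n_seqs]
        ggml_tensor * kqv = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_embd_head, n_tokens, n_head, n_seqs);

        // [n_embd_head, n_tokens, n_head, n_seqs] -> [n_embd_head, n_head, n_tokens, n_seqs]
        ggml_tensor * cur = ggml_permute(ctx, kqv, 0, 2, 1, 3);

        // fold the heads into the rows and the virtual sequences back into the token dim:
        // [n_embd_head*n_head, n_tokens*n_seqs] == [n_embd, n_tokens_full]
        cur = ggml_cont_2d(ctx, cur, cur->ne[0]*n_head, n_tokens*n_seqs);

        printf("cur: [%lld, %lld]\n", (long long) cur->ne[0], (long long) cur->ne[1]);

        ggml_free(ctx);
        return 0;
    }
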
@@ -1156,13 +1160,14 @@ static std::unique_ptr<llm_graph_input_attn_kv_unified> build_attn_inp_kv_unifie
 {
     GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
 
-    const auto n_kv = mctx_cur->get_n_kv();
+    const auto n_kv     = mctx_cur->get_n_kv();
     const auto n_tokens = ubatch.n_tokens;
+    const auto n_seqs   = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1;
 
     inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
     inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
 
-    inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
+    inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), 1, n_seqs);
     ggml_set_input(inp->self_kq_mask);
 
     inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
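
The KQ mask now carries one slice per virtual sequence in its last dimension, padded per sequence rather than across the whole ubatch, so each slice lines up with one batch entry of the reshaped q above. A standalone sketch of the allocation (hypothetical sizes; GGML_PAD and GGML_KQ_MASK_PAD are the padding helpers from ggml.h):

    #include "ggml.h"
    #include <cstdio>

    int main() {
        ggml_init_params params = { /*mem_size*/ 1024*1024, /*mem_buffer*/ nullptr, /*no_alloc*/ true };
        ggml_context * ctx = ggml_init(params);

        const int64_t n_kv = 256, n_tokens = 8, n_seqs = 2;

        // before: [n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1]
        // after:  [n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), 1, n_seqs]
        // (GGML_KQ_MASK_PAD is the row-padding granularity required by flash attention)
        ggml_tensor * kq_mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F32,
                n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), 1, n_seqs);

        printf("kq_mask: [%lld, %lld, %lld, %lld]\n",
               (long long) kq_mask->ne[0], (long long) kq_mask->ne[1],
               (long long) kq_mask->ne[2], (long long) kq_mask->ne[3]);

        ggml_free(ctx);
        return 0;
    }

The two hunks below apply the same per-sequence mask shape to the iSWA cache, for both its base and SWA masks.
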
@@ -1362,13 +1367,15 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
 
     auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);
 
+    const auto n_seqs = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1;
+
     {
         const auto n_kv = mctx_cur->get_base()->get_n_kv();
 
         inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
         inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
 
-        inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
+        inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), 1, n_seqs);
         ggml_set_input(inp->self_kq_mask);
 
         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1382,7 +1389,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
         inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
         inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
 
-        inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
+        inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), 1, n_seqs);
         ggml_set_input(inp->self_kq_mask_swa);
 
         inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;