@@ -1009,8 +1009,7 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
     inp->self_k_idxs = mctx_cur->get_attn()->build_input_k_idxs(ctx0, ubatch);
     inp->self_v_idxs = mctx_cur->get_attn()->build_input_v_idxs(ctx0, ubatch);

-    inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-    //cb(inp->self_kq_mask, "KQ_mask", -1);
+    inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
     ggml_set_input(inp->self_kq_mask);

     inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1147,8 +1146,7 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con
     auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams);

     // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
-    inp->kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-    //cb(inp_kq_mask, "KQ_mask", -1);
+    inp->kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
     ggml_set_input(inp->kq_mask);

     inp->kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->kq_mask, GGML_TYPE_F16) : inp->kq_mask;
@@ -1213,7 +1211,7 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
     inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
     inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);

-    inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+    inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
     ggml_set_input(inp->self_kq_mask);

     inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1347,7 +1345,7 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {

     const int32_t n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;

-    inp->cross_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+    inp->cross_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
     ggml_set_input(inp->cross_kq_mask);

     inp->cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->cross_kq_mask, GGML_TYPE_F16) : inp->cross_kq_mask;
@@ -1461,7 +1459,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
     inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
     inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);

-    inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+    inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
     ggml_set_input(inp->self_kq_mask);

     inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1475,7 +1473,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
     inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
     inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);

-    inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+    inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
     ggml_set_input(inp->self_kq_mask_swa);

     inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
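Not part of the commit above: a minimal standalone sketch of what the 2d -> 4d mask allocation amounts to at the ggml level. ggml_new_tensor_2d already leaves ne[2] and ne[3] at 1, so ggml_new_tensor_4d(..., 1, 1) produces an identically shaped tensor while spelling out the two outer dimensions explicitly. The n_kv, n_tokens, and GGML_KQ_MASK_PAD values below are illustrative stand-ins, not the ones llama.cpp uses at runtime.

// sketch.cpp - compare the 2d and 4d KQ mask allocations (illustrative only)
#include "ggml.h"
#include <cstdio>

// stand-in value: the real GGML_KQ_MASK_PAD is defined by llama.cpp's graph code
#ifndef GGML_KQ_MASK_PAD
#define GGML_KQ_MASK_PAD 64
#endif

int main() {
    const int64_t n_kv     = 256; // hypothetical number of KV cells
    const int64_t n_tokens = 13;  // hypothetical ubatch size

    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ nullptr,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    // old form: only two dimensions are spelled out, ne[2]/ne[3] default to 1
    struct ggml_tensor * mask_2d = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));

    // new form: same element count, but the trailing unit dimensions are explicit
    struct ggml_tensor * mask_4d = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);

    printf("2d mask: %lld x %lld x %lld x %lld (%lld elements)\n",
           (long long) mask_2d->ne[0], (long long) mask_2d->ne[1],
           (long long) mask_2d->ne[2], (long long) mask_2d->ne[3],
           (long long) ggml_nelements(mask_2d));
    printf("4d mask: %lld x %lld x %lld x %lld (%lld elements)\n",
           (long long) mask_4d->ne[0], (long long) mask_4d->ne[1],
           (long long) mask_4d->ne[2], (long long) mask_4d->ne[3],
           (long long) ggml_nelements(mask_4d));

    ggml_free(ctx);
    return 0;
}

Both printf lines report the same shape and element count, consistent with the diff only changing how the mask's dimensions are declared, not how much memory it occupies.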