@@ -1005,8 +1005,7 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
         inp->self_k_idxs = mctx_cur->get_attn()->build_input_k_idxs(ctx0, ubatch);
         inp->self_v_idxs = mctx_cur->get_attn()->build_input_v_idxs(ctx0, ubatch);

-        inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-        //cb(inp->self_kq_mask, "KQ_mask", -1);
+        inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
         ggml_set_input(inp->self_kq_mask);

         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
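
Note: ggml tensors always carry four dimensions internally, so replacing ggml_new_tensor_2d with ggml_new_tensor_4d and trailing dims of 1 does not change the mask's element count, byte size, or layout; it only spells out ne[2] and ne[3] explicitly. Below is a minimal standalone sketch illustrating that equivalence, assuming a local ggml build; the n_kv and n_tokens values are made-up placeholders, not taken from this diff.

// mask_shape_check.cpp -- standalone sketch, assuming a local ggml build.
// Shows that ggml_new_tensor_4d(..., ne0, ne1, 1, 1) describes the same
// buffer as ggml_new_tensor_2d(..., ne0, ne1): the 2D constructor already
// leaves ne[2] = ne[3] = 1 implicitly.
#include "ggml.h"
#include <cstdio>

int main() {
    // placeholder sizes standing in for n_kv and GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)
    const int64_t n_kv     = 256;
    const int64_t n_tokens = 64;

    ggml_init_params params = { /*.mem_size   =*/ 16*1024*1024,
                                /*.mem_buffer =*/ nullptr,
                                /*.no_alloc   =*/ false };
    ggml_context * ctx = ggml_init(params);

    ggml_tensor * mask_2d = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_kv, n_tokens);
    ggml_tensor * mask_4d = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_kv, n_tokens, 1, 1);

    // identical shapes and byte sizes
    printf("2d: ne = [%lld, %lld, %lld, %lld], bytes = %zu\n",
           (long long) mask_2d->ne[0], (long long) mask_2d->ne[1],
           (long long) mask_2d->ne[2], (long long) mask_2d->ne[3], ggml_nbytes(mask_2d));
    printf("4d: ne = [%lld, %lld, %lld, %lld], bytes = %zu\n",
           (long long) mask_4d->ne[0], (long long) mask_4d->ne[1],
           (long long) mask_4d->ne[2], (long long) mask_4d->ne[3], ggml_nbytes(mask_4d));

    ggml_free(ctx);
    return 0;
}
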
@@ -1143,8 +1142,7 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con
     auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams);

     // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
-    inp->kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-    //cb(inp_kq_mask, "KQ_mask", -1);
+    inp->kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
     ggml_set_input(inp->kq_mask);

     inp->kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->kq_mask, GGML_TYPE_F16) : inp->kq_mask;
@@ -1209,7 +1207,7 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
         inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
         inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);

-        inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
         ggml_set_input(inp->self_kq_mask);

         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1343,7 +1341,7 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {

     const int32_t n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;

-    inp->cross_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+    inp->cross_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
     ggml_set_input(inp->cross_kq_mask);

     inp->cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->cross_kq_mask, GGML_TYPE_F16) : inp->cross_kq_mask;
@@ -1457,7 +1455,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
         inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
         inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);

-        inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
         ggml_set_input(inp->self_kq_mask);

         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1471,7 +1469,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
         inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
         inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);

-        inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
         ggml_set_input(inp->self_kq_mask_swa);

         inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
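
Every hunk above applies the same pattern: allocate the F32 KQ mask with explicit unit trailing dimensions, mark it as a graph input, and cast it to F16 only when flash attention is enabled (the _cnv lines in the surrounding context). A condensed sketch of that pattern in isolation follows, again assuming a local ggml build; flash_attn, n_kv, and n_tokens are stand-in values, not part of this change.

// kq_mask_input.cpp -- minimal sketch of the shared input-mask pattern,
// assuming a local ggml build; values below are placeholders.
#include "ggml.h"

int main() {
    const bool    flash_attn = true;   // stands in for cparams.flash_attn
    const int64_t n_kv       = 256;    // stand-in for the KV cache size
    const int64_t n_tokens   = 64;     // stand-in for GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)

    ggml_init_params params = { /*.mem_size   =*/ 16*1024*1024,
                                /*.mem_buffer =*/ nullptr,
                                /*.no_alloc   =*/ false };
    ggml_context * ctx = ggml_init(params);

    // F32 mask allocated with explicit unit trailing dims, as in the patch
    ggml_tensor * kq_mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_kv, n_tokens, 1, 1);
    ggml_set_input(kq_mask);

    // flash attention consumes an F16 mask, so cast on demand; otherwise reuse the F32 tensor
    ggml_tensor * kq_mask_cnv = flash_attn ? ggml_cast(ctx, kq_mask, GGML_TYPE_F16) : kq_mask;
    (void) kq_mask_cnv;

    ggml_free(ctx);
    return 0;
}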