@@ -1311,29 +1311,23 @@ ggml_tensor * llm_graph_context::build_attn(
13111311 return cur;
13121312}
13131313
1314- llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified (
1315- bool causal,
1316- bool swa) const {
1314+ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified () const {
13171315 const llama_kv_cache_unified * kv_self = static_cast <const llama_kv_cache_unified *>(memory);
13181316
13191317 auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_self);
13201318
13211319 const auto n_kv = kv_self->n ;
13221320
1323- inp->self_kq_mask = causal
1324- ? ggml_new_tensor_2d (ctx0, GGML_TYPE_F32, n_kv, GGML_PAD (n_tokens, GGML_KQ_MASK_PAD))
1325- : ggml_new_tensor_2d (ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD (n_tokens, GGML_KQ_MASK_PAD));
1321+ inp->self_kq_mask = ggml_new_tensor_2d (ctx0, GGML_TYPE_F32, n_kv, GGML_PAD (n_tokens, GGML_KQ_MASK_PAD));
13261322 // cb(inp->self_kq_mask, "KQ_mask", -1);
13271323 ggml_set_input (inp->self_kq_mask );
13281324
13291325 inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast (ctx0, inp->self_kq_mask , GGML_TYPE_F16) : inp->self_kq_mask ;
13301326
1331- if (swa ) {
1327+ if (hparams. n_swa_pattern > 1 ) {
13321328 GGML_ASSERT (hparams.n_swa > 0 );
13331329
1334- inp->self_kq_mask_swa = causal
1335- ? ggml_new_tensor_2d (ctx0, GGML_TYPE_F32, n_kv, GGML_PAD (n_tokens, GGML_KQ_MASK_PAD))
1336- : ggml_new_tensor_2d (ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD (n_tokens, GGML_KQ_MASK_PAD));
1330+ inp->self_kq_mask_swa = ggml_new_tensor_2d (ctx0, GGML_TYPE_F32, n_kv, GGML_PAD (n_tokens, GGML_KQ_MASK_PAD));
13371331 // cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
13381332 ggml_set_input (inp->self_kq_mask_swa );
13391333
0 commit comments