@@ -577,7 +577,7 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
     n_embd_head_v    (hparams.n_embd_head_v),
     n_embd_v_gqa     (hparams.n_embd_v_gqa()),
     n_expert         (hparams.n_expert),
-    n_expert_used    (hparams.n_expert_used),
+    n_expert_used    (cparams.warmup ? hparams.n_expert : hparams.n_expert_used),
     freq_base        (cparams.rope_freq_base),
     freq_scale       (cparams.rope_freq_scale),
     ext_factor       (cparams.yarn_ext_factor),
@@ -1311,29 +1311,23 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }
 
-llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified(
-        bool causal,
-        bool swa) const {
+llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
     const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
 
     auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_self);
 
     const auto n_kv = kv_self->n;
 
-    inp->self_kq_mask = causal
-        ? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv,     GGML_PAD(n_tokens, GGML_KQ_MASK_PAD))
-        : ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+    inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
     //cb(inp->self_kq_mask, "KQ_mask", -1);
     ggml_set_input(inp->self_kq_mask);
 
     inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
 
-    if (swa) {
+    if (hparams.n_swa_pattern > 1) {
         GGML_ASSERT(hparams.n_swa > 0);
 
-        inp->self_kq_mask_swa = causal
-            ? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv,     GGML_PAD(n_tokens, GGML_KQ_MASK_PAD))
-            : ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
         ggml_set_input(inp->self_kq_mask_swa);
 
@@ -1403,9 +1397,9 @@ ggml_tensor * llm_graph_context::build_attn(
         ggml_build_forward_expand(gf, ggml_cpy(ctx0, v_cur, v_cache_view));
     }
 
-    const bool is_sliding = hparams.is_sliding(il);
+    const bool is_swa = hparams.is_swa(il);
 
-    const auto & kq_mask = is_sliding ? inp->get_kq_mask_swa() : inp->get_kq_mask();
+    const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
 
     const auto n_kv = kv_self->n;
 
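Note: the causal/swa flags are gone from build_attn_inp_kv_unified(); whether a layer gets the sliding-window mask is now derived from the model hyperparameters (n_swa_pattern at input-build time, is_swa(il) per layer). Below is a minimal standalone sketch of that per-layer selection, assuming a Gemma-3-style layout in which every n_swa_pattern-th layer falls back to full attention; the is_swa_layer helper and the modulo rule are illustrative assumptions, not the literal llama.cpp implementation.

    // Illustrative sketch only -- not the llama.cpp implementation.
    // Assumes that within each group of n_swa_pattern layers, all but the
    // last use sliding-window attention (SWA) and the last uses full attention.
    #include <cstdint>
    #include <cstdio>

    static bool is_swa_layer(uint32_t il, uint32_t n_swa_pattern) {
        if (n_swa_pattern <= 1) {
            return false; // no SWA layers, mirroring the n_swa_pattern > 1 check above
        }
        return il % n_swa_pattern < n_swa_pattern - 1;
    }

    int main() {
        const uint32_t n_swa_pattern = 6; // hypothetical value for illustration
        for (uint32_t il = 0; il < 12; ++il) {
            std::printf("layer %2u -> %s\n", il,
                        is_swa_layer(il, n_swa_pattern) ? "SWA mask" : "full mask");
        }
        return 0;
    }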