
Commit fb512aa

models : fix Phi-3 SWA parameters

ggml-ci

1 parent 847e9c8


3 files changed, 23 insertions(+), 10 deletions(-)


src/llama-graph.cpp

Lines changed: 2 additions & 3 deletions

@@ -1316,9 +1316,8 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
     }
 
-    {
-        GGML_ASSERT(hparams.n_swa_pattern > 1 && "Use llama_kv_cache_unified for non-SWA");
-        GGML_ASSERT(hparams.n_swa > 0 && "Use llama_kv_cache_unified for non-SWA");
+    if (hparams.n_swa_pattern > 1) {
+        GGML_ASSERT(hparams.n_swa > 0 && "Use llama_kv_cache_unified for non-SWA");
 
         const auto n_kv = kv_self->get_kv_swa()->get_n();
 
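For context on the hunk above: a minimal standalone sketch (an assumed convention, not the actual llama.cpp implementation; the helper name is_swa_layer is hypothetical) of how an n_swa_pattern value is commonly interpreted. With pattern P, every P-th layer keeps full attention and the remaining layers use the sliding window, so P == 1 means there are no SWA layers at all, which is why the assertion is now guarded instead of unconditional.

// sketch only: assumed layer-selection convention for n_swa_pattern
#include <cstdint>
#include <cstdio>

static bool is_swa_layer(uint32_t il, uint32_t n_swa_pattern) {
    if (n_swa_pattern <= 1) {
        return false; // pattern 1: every layer uses full attention
    }
    // assumed convention: the last layer of each group of n_swa_pattern
    // layers keeps full attention, the others use the sliding window
    return il % n_swa_pattern < n_swa_pattern - 1;
}

int main() {
    for (uint32_t il = 0; il < 6; ++il) {
        std::printf("layer %u: pattern=1 -> %s, pattern=2 -> %s\n",
                il,
                is_swa_layer(il, 1) ? "SWA" : "full",
                is_swa_layer(il, 2) ? "SWA" : "full");
    }
    return 0;
}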
src/llama-graph.h

Lines changed: 3 additions & 3 deletions

@@ -256,10 +256,10 @@ class llm_graph_input_attn_kv_unified : public llm_graph_input_i {
 
     void set_input(const llama_ubatch * ubatch) override;
 
-    ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
+    ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
 
-    ggml_tensor * self_kq_mask     = nullptr; // F32 [n_kv, n_batch]
-    ggml_tensor * self_kq_mask_cnv = nullptr; //     [n_kv, n_batch]
+    ggml_tensor * self_kq_mask     = nullptr; // F32 [n_kv, n_batch]
+    ggml_tensor * self_kq_mask_cnv = nullptr; //     [n_kv, n_batch]
 
     const llama_hparams & hparams;
     const llama_cparams & cparams;
 
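For context on the two members declared above: a minimal standalone sketch (assuming only the public ggml calls ggml_new_tensor_2d and ggml_cast; the tensor sizes and the flash_attn flag are made-up example values) of an F32 mask with shape [n_kv, n_batch] plus an F16 cast node when flash attention is enabled, mirroring the ggml_cast(..., GGML_TYPE_F16) line in the llama-graph.cpp hunk of this commit.

// sketch only: F32 KQ mask [n_kv, n_batch] and optional F16 cast node
#include "ggml.h"

int main() {
    const int64_t n_kv    = 256; // example KV cache size
    const int64_t n_batch = 32;  // example batch size
    const bool flash_attn = true;

    ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ nullptr,
        /*.no_alloc   =*/ false,
    };
    ggml_context * ctx = ggml_init(params);

    // F32 mask: one row of n_kv mask values per batch token
    ggml_tensor * kq_mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_kv, n_batch);

    // with flash attention the mask is consumed as F16, otherwise used as-is
    ggml_tensor * kq_mask_cnv = flash_attn ? ggml_cast(ctx, kq_mask, GGML_TYPE_F16) : kq_mask;

    GGML_ASSERT(kq_mask_cnv != nullptr);

    ggml_free(ctx);
    return 0;
}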
src/llama-model.cpp

Lines changed: 18 additions & 4 deletions

@@ -856,20 +856,34 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 // for backward compatibility ; see: https://github.com/ggerganov/llama.cpp/pull/8931
                 if ((hparams.n_layer == 32 || hparams.n_layer == 40) && hparams.n_ctx_train == 4096) {
                     // default value for Phi-3-mini-4k-instruct and Phi-3-medium-4k-instruct
+                    LLAMA_LOG_WARN("%s: assuming n_swa = 2047 for Phi-3-mini-4k-instruct and Phi-3-medium-4k-instruct\n", __func__);
+
                     hparams.n_swa = 2047;
                 } else if (hparams.n_layer == 32 && hparams.n_head_kv(0) == 32 && hparams.n_ctx_train == 131072) {
                     // default value for Phi-3-mini-128k-instruct
-                    // note: this seems incorrect because the window is bigger than the train context?
-                    hparams.n_swa = 262144;
+                    LLAMA_LOG_WARN("%s: assuming n_swa = n_ctx_train for Phi-3-mini-128k-instruct\n", __func__);
+
+                    hparams.n_swa = hparams.n_ctx_train;
+                    hparams.n_swa_pattern = 1;
                 } else if (hparams.n_layer == 40 && hparams.n_ctx_train == 131072) {
                     // default value for Phi-3-medium-128k-instruct
-                    // note: this seems incorrect because the window is equal to the train context?
-                    hparams.n_swa = 131072;
+                    LLAMA_LOG_WARN("%s: assuming n_swa = n_ctx_train for Phi-3-medium-128k-instruct\n", __func__);
+
+                    hparams.n_swa = hparams.n_ctx_train;
+                    hparams.n_swa_pattern = 1;
                 }
+
                 bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
                 if (!found_swa && hparams.n_swa == 0) {
                     throw std::runtime_error("invalid value for sliding_window");
                 }
+
+                if (hparams.n_swa > hparams.n_ctx_train) {
+                    LLAMA_LOG_WARN("%s: unexpected n_swa: %d >= %d, setting to 0\n", __func__, hparams.n_swa, hparams.n_ctx_train);
+
+                    hparams.n_swa = hparams.n_ctx_train;
+                    hparams.n_swa_pattern = 1;
+                }
             } break;
         case LLM_ARCH_PHIMOE:
             {
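For context on the defaults chosen above: a minimal standalone sketch (an illustration of the sliding-window idea, not the llama.cpp mask-building code; the visibility rule and example positions are assumptions, and the exact boundary convention may differ by one). With n_swa = 2047 a token only attends to roughly the most recent 2k positions, while the 128k defaults (n_swa = n_ctx_train, n_swa_pattern = 1) make the window as large as the training context, effectively disabling SWA.

// sketch only: assumed sliding-window visibility rule
#include <cstdint>
#include <cstdio>

// a cached position is visible when it is causal and within n_swa positions back
static bool swa_visible(int64_t pos, int64_t pos_kv, uint32_t n_swa) {
    return pos_kv <= pos && pos - pos_kv < (int64_t) n_swa;
}

int main() {
    const uint32_t n_swa_4k   = 2047;   // default for Phi-3-mini/medium-4k-instruct
    const uint32_t n_swa_128k = 131072; // 128k models: n_swa = n_ctx_train, n_swa_pattern = 1

    const int64_t pos = 4000; // example current position

    std::printf("pos=%lld, pos_kv=0:    4k default -> %s, 128k default -> %s\n",
            (long long) pos,
            swa_visible(pos, 0, n_swa_4k)   ? "visible" : "masked",
            swa_visible(pos, 0, n_swa_128k) ? "visible" : "masked");

    std::printf("pos=%lld, pos_kv=3000: 4k default -> %s, 128k default -> %s\n",
            (long long) pos,
            swa_visible(pos, 3000, n_swa_4k)   ? "visible" : "masked",
            swa_visible(pos, 3000, n_swa_128k) ? "visible" : "masked");

    return 0;
}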