
Commit 7604ddd

models : fix Phi-3 SWA parameters
ggml-ci
1 parent 8d8c8c4 commit 7604ddd

3 files changed: +23 -10 lines changed

src/llama-graph.cpp

Lines changed: 2 additions & 3 deletions

@@ -1315,9 +1315,8 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
     }
 
-    {
-        GGML_ASSERT(hparams.n_swa_pattern > 1 && "Use llama_kv_cache_unified for non-SWA");
-        GGML_ASSERT(hparams.n_swa > 0 && "Use llama_kv_cache_unified for non-SWA");
+    if (hparams.n_swa_pattern > 1) {
+        GGML_ASSERT(hparams.n_swa > 0 && "Use llama_kv_cache_unified for non-SWA");
 
         const auto n_kv = kv_self->get_kv_swa()->get_n();
 
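For context, a minimal standalone sketch of the control-flow change in this hunk (hypothetical simplified types, not the actual llama.cpp code): the SWA attention inputs are now built only when the model actually interleaves sliding-window layers (n_swa_pattern > 1), instead of asserting that every model reaching this path uses SWA.

// Minimal, self-contained sketch -- NOT the real llama.cpp types or API.
// It only illustrates the guard introduced above.
#include <cassert>
#include <cstdint>
#include <cstdio>

struct hparams_t {              // stand-in for llama_hparams
    uint32_t n_swa         = 0; // sliding-window size; 0 = no window
    uint32_t n_swa_pattern = 1; // every n-th layer is full-attention; 1 = no SWA layers at all
};

static void build_attn_inputs(const hparams_t & hparams) {
    std::printf("build full-attention KQ mask\n");

    // Before this commit the code asserted n_swa_pattern > 1 here; the guard now
    // simply skips the SWA mask when there are no sliding-window layers.
    if (hparams.n_swa_pattern > 1) {
        assert(hparams.n_swa > 0 && "Use llama_kv_cache_unified for non-SWA");
        std::printf("build SWA KQ mask (window = %u tokens)\n", hparams.n_swa);
    }
}

int main() {
    build_attn_inputs({/*n_swa=*/2047,   /*n_swa_pattern=*/2}); // SWA active
    build_attn_inputs({/*n_swa=*/131072, /*n_swa_pattern=*/1}); // SWA effectively disabled
    return 0;
}
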
src/llama-graph.h

Lines changed: 3 additions & 3 deletions

@@ -256,10 +256,10 @@ class llm_graph_input_attn_kv_unified : public llm_graph_input_i {
 
     void set_input(const llama_ubatch * ubatch) override;
 
-    ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
+    ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
 
-    ggml_tensor * self_kq_mask     = nullptr; // F32 [n_kv, n_batch]
-    ggml_tensor * self_kq_mask_cnv = nullptr; //     [n_kv, n_batch]
+    ggml_tensor * self_kq_mask     = nullptr; // F32 [n_kv, n_batch]
+    ggml_tensor * self_kq_mask_cnv = nullptr; //     [n_kv, n_batch]
 
     const llama_hparams & hparams;
     const llama_cparams & cparams;

src/llama-model.cpp

Lines changed: 18 additions & 4 deletions

@@ -852,20 +852,34 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 // for backward compatibility ; see: https://github.com/ggerganov/llama.cpp/pull/8931
                 if ((hparams.n_layer == 32 || hparams.n_layer == 40) && hparams.n_ctx_train == 4096) {
                     // default value for Phi-3-mini-4k-instruct and Phi-3-medium-4k-instruct
+                    LLAMA_LOG_WARN("%s: assuming n_swa = 2047 for Phi-3-mini-4k-instruct and Phi-3-medium-4k-instruct\n", __func__);
+
                     hparams.n_swa = 2047;
                 } else if (hparams.n_layer == 32 && hparams.n_head_kv(0) == 32 && hparams.n_ctx_train == 131072) {
                     // default value for Phi-3-mini-128k-instruct
-                    // note: this seems incorrect because the window is bigger than the train context?
-                    hparams.n_swa = 262144;
+                    LLAMA_LOG_WARN("%s: assuming n_swa = n_ctx_train for Phi-3-mini-128k-instruct\n", __func__);
+
+                    hparams.n_swa = hparams.n_ctx_train;
+                    hparams.n_swa_pattern = 1;
                 } else if (hparams.n_layer == 40 && hparams.n_ctx_train == 131072) {
                     // default value for Phi-3-medium-128k-instruct
-                    // note: this seems incorrect because the window is equal to the train context?
-                    hparams.n_swa = 131072;
+                    LLAMA_LOG_WARN("%s: assuming n_swa = n_ctx_train for Phi-3-medium-128k-instruct\n", __func__);
+
+                    hparams.n_swa = hparams.n_ctx_train;
+                    hparams.n_swa_pattern = 1;
                 }
+
                 bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
                 if (!found_swa && hparams.n_swa == 0) {
                     throw std::runtime_error("invalid value for sliding_window");
                 }
+
+                if (hparams.n_swa > hparams.n_ctx_train) {
+                    LLAMA_LOG_WARN("%s: unexpected n_swa: %d >= %d, setting to 0\n", __func__, hparams.n_swa, hparams.n_ctx_train);
+
+                    hparams.n_swa = hparams.n_ctx_train;
+                    hparams.n_swa_pattern = 1;
+                }
             } break;
         case LLM_ARCH_PHIMOE:
             {

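Taken together, the llama-model.cpp hunk changes the Phi-3 defaults so that the 4k variants keep their 2047-token window, the 128k variants effectively disable SWA (window = training context, n_swa_pattern = 1), and any GGUF-supplied window larger than the training context is clamped afterwards. A compilable, hypothetical sketch of that selection logic (stand-in struct and function names, simplified 128k branch condition; not the llama.cpp API):

// Hypothetical sketch only -- names and the simplified 128k condition are not
// the real llama.cpp code; the value flow mirrors the diff above.
#include <cstdint>
#include <cstdio>

struct phi3_hparams {
    uint32_t n_layer       = 0;
    uint32_t n_ctx_train   = 0;
    uint32_t n_swa         = 0; // sliding-window size read from GGUF, 0 if absent
    uint32_t n_swa_pattern = 2; // stand-in default; the real value comes from the arch setup
};

static void apply_phi3_swa_defaults(phi3_hparams & hp) {
    if ((hp.n_layer == 32 || hp.n_layer == 40) && hp.n_ctx_train == 4096) {
        // Phi-3-mini/medium-4k-instruct: genuine 2047-token sliding window
        hp.n_swa = 2047;
    } else if (hp.n_ctx_train == 131072) {
        // Phi-3-mini/medium-128k-instruct: window covers the whole training
        // context, so sliding-window attention is effectively a no-op
        hp.n_swa         = hp.n_ctx_train;
        hp.n_swa_pattern = 1;
    }

    // mirrors the clamp added after reading LLM_KV_ATTENTION_SLIDING_WINDOW
    if (hp.n_swa > hp.n_ctx_train) {
        std::printf("unexpected n_swa %u > n_ctx_train %u, clamping\n", hp.n_swa, hp.n_ctx_train);
        hp.n_swa         = hp.n_ctx_train;
        hp.n_swa_pattern = 1;
    }
}

int main() {
    phi3_hparams mini4k  {/*n_layer=*/32, /*n_ctx_train=*/4096};
    phi3_hparams mini128k{/*n_layer=*/32, /*n_ctx_train=*/131072};

    apply_phi3_swa_defaults(mini4k);
    apply_phi3_swa_defaults(mini128k);

    std::printf("4k:   n_swa = %6u, n_swa_pattern = %u\n", mini4k.n_swa,   mini4k.n_swa_pattern);
    std::printf("128k: n_swa = %6u, n_swa_pattern = %u\n", mini128k.n_swa, mini128k.n_swa_pattern);
    return 0;
}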