@@ -953,6 +953,11 @@ void llama_model::load_hparams(llama_model_loader & ml) {
953953 case 46: type = LLM_TYPE_27B; break;
954954 default: type = LLM_TYPE_UNKNOWN;
955955 }
956+
957+ // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/config.py#L173
958+ hparams.f_attention_scale = type == LLM_TYPE_27B
959+ ? 1.0f / std::sqrt(float(hparams.n_embd / hparams.n_head(0)))
960+ : 1.0f / std::sqrt(float(hparams.n_embd_head_k));
956961 } break;
957962 case LLM_ARCH_GEMMA3:
958963 {
@@ -973,6 +978,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
973978 default: type = LLM_TYPE_UNKNOWN;
974979 }
975980
981+ // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/config.py#L289
976982 hparams.f_attention_scale = type == LLM_TYPE_27B
977983 ? 1.0f / std::sqrt(float(hparams.n_embd / hparams.n_head(0)))
978984 : 1.0f / std::sqrt(float(hparams.n_embd_head_k));
@@ -8481,14 +8487,7 @@ struct llm_build_gemma2_iswa : public llm_graph_context {
84818487 cb(Kcur, "Kcur", il);
84828488 cb(Vcur, "Vcur", il);
84838489
8484- // ref: https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e
8485- switch (model.type) {
8486- case LLM_TYPE_2B:
8487- case LLM_TYPE_9B:
8488- case LLM_TYPE_27B: Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head))); break;
8489- default: GGML_ABORT("fatal error");
8490- };
8491- cb(Qcur, "Qcur_scaled", il);
8490+ Qcur = ggml_scale(ctx0, Qcur, hparams.f_attention_scale);
84928491
84938492 cur = build_attn(inp_attn, gf,
84948493 model.layers[il].wo, NULL,
@@ -8629,6 +8628,7 @@ struct llm_build_gemma3_iswa : public llm_graph_context {
86298628 cb(Kcur, "Kcur", il);
86308629 cb(Vcur, "Vcur", il);
86318630
8631+ // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/model.py#L315
86328632 Qcur = ggml_scale(ctx0, Qcur, hparams.f_attention_scale);
86338633
86348634 cur = build_attn(inp_attn, gf,
0 commit comments