
Commit fbc6df0

cont : consistent attention scaling
1 parent 67c4346 commit fbc6df0

1 file changed: +8 −8 lines changed


src/llama-model.cpp

Lines changed: 8 additions & 8 deletions
@@ -953,6 +953,11 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     case 46: type = LLM_TYPE_27B; break;
                     default: type = LLM_TYPE_UNKNOWN;
                 }
+
+                // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/config.py#L173
+                hparams.f_attention_scale = type == LLM_TYPE_27B
+                    ? 1.0f / std::sqrt(float(hparams.n_embd / hparams.n_head(0)))
+                    : 1.0f / std::sqrt(float(hparams.n_embd_head_k));
             } break;
         case LLM_ARCH_GEMMA3:
             {
@@ -973,6 +978,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     default: type = LLM_TYPE_UNKNOWN;
                 }

+                // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/config.py#L289
                 hparams.f_attention_scale = type == LLM_TYPE_27B
                     ? 1.0f / std::sqrt(float(hparams.n_embd / hparams.n_head(0)))
                     : 1.0f / std::sqrt(float(hparams.n_embd_head_k));
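Per the config.py lines referenced above, the gemma_pytorch reference sets query_pre_attn_scalar for the 27B model to hidden_size / num_attention_heads rather than head_dim, which is why only LLM_TYPE_27B takes the first branch. A minimal sketch of the two formulas, using assumed Gemma-2-27B-style dimensions (the concrete numbers below are illustrative, not read from a model file):

// Sketch of the two scale formulas above, with assumed dimensions
// (n_embd = 4608, n_head = 32, n_embd_head_k = 128).
#include <cmath>
#include <cstdio>

int main() {
    const int n_embd        = 4608; // hidden size (assumed)
    const int n_head        = 32;   // attention heads (assumed)
    const int n_embd_head_k = 128;  // per-head key dim (assumed)

    // 27B path: 1/sqrt(n_embd / n_head) = 1/sqrt(144)
    const float scale_27b   = 1.0f / std::sqrt(float(n_embd / n_head));
    // default path: 1/sqrt(n_embd_head_k) = 1/sqrt(128)
    const float scale_other = 1.0f / std::sqrt(float(n_embd_head_k));

    // For most model sizes n_embd / n_head == n_embd_head_k, so the
    // two formulas agree; 27B is the exception (144 vs 128).
    std::printf("27B scale: %f, default scale: %f\n", scale_27b, scale_other);
    return 0;
}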
@@ -8481,14 +8487,7 @@ struct llm_build_gemma2_iswa : public llm_graph_context {
                 cb(Kcur, "Kcur", il);
                 cb(Vcur, "Vcur", il);

-                // ref: https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e
-                switch (model.type) {
-                    case LLM_TYPE_2B:
-                    case LLM_TYPE_9B:
-                    case LLM_TYPE_27B: Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head))); break;
-                    default: GGML_ABORT("fatal error");
-                };
-                cb(Qcur, "Qcur_scaled", il);
+                Qcur = ggml_scale(ctx0, Qcur, hparams.f_attention_scale);

                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, NULL,
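This hunk replaces the per-type switch in the graph builder with the scale precomputed in load_hparams. Because ggml_scale multiplies every element of Qcur, scaling Q before the QK^T matmul is equivalent to scaling the resulting attention logits. A plain C++ sketch of that equivalence (toy values, no ggml):

// Sketch: scaling the query row before the dot product equals
// scaling the attention logit afterwards, which is why the scale
// can be folded into a single ggml_scale on Qcur.
#include <cmath>
#include <cstdio>

int main() {
    const float q[4] = {0.5f, -1.0f, 2.0f, 0.25f}; // toy query row
    const float k[4] = {1.0f,  0.5f, -0.5f, 2.0f}; // toy key row
    const float scale = 1.0f / std::sqrt(4.0f);    // 1/sqrt(d), d = 4

    float logit_scale_q = 0.0f, logit_scale_after = 0.0f;
    for (int i = 0; i < 4; ++i) {
        logit_scale_q     += (q[i] * scale) * k[i]; // scale Q first
        logit_scale_after +=  q[i] * k[i];          // scale logit after
    }
    logit_scale_after *= scale;

    // Both orders give the same logit (up to float rounding).
    std::printf("%f == %f\n", logit_scale_q, logit_scale_after);
    return 0;
}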
@@ -8629,6 +8628,7 @@ struct llm_build_gemma3_iswa : public llm_graph_context {
                 cb(Kcur, "Kcur", il);
                 cb(Vcur, "Vcur", il);

+                // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/model.py#L315
                 Qcur = ggml_scale(ctx0, Qcur, hparams.f_attention_scale);

                 cur = build_attn(inp_attn, gf,
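Note that the Gemma 3 builder already applied hparams.f_attention_scale, so this hunk only adds the reference comment; the behavioral change is confined to the Gemma 2 path above.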
