
Commit 8dc6649

Initial support for Gemma 3 models
(partially?) resolves #711
1 parent 9d92413 commit 8dc6649

File tree

1 file changed

llama.cpp/llama.cpp

Lines changed: 225 additions & 0 deletions
@@ -162,6 +162,7 @@ enum llm_arch {
     LLM_ARCH_MINICPM,
     LLM_ARCH_GEMMA,
     LLM_ARCH_GEMMA2,
+    LLM_ARCH_GEMMA3,
     LLM_ARCH_STARCODER2,
     LLM_ARCH_MAMBA,
     LLM_ARCH_XVERSE,
@@ -209,6 +210,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_MINICPM,    "minicpm"    },
     { LLM_ARCH_GEMMA,      "gemma"      },
     { LLM_ARCH_GEMMA2,     "gemma2"     },
+    { LLM_ARCH_GEMMA3,     "gemma3"     },
     { LLM_ARCH_STARCODER2, "starcoder2" },
     { LLM_ARCH_MAMBA,      "mamba"      },
     { LLM_ARCH_XVERSE,     "xverse"     },
@@ -998,6 +1000,26 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
             { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
         },
     },
+    {
+        LLM_ARCH_GEMMA3,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,     "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,    "output_norm" },
+            { LLM_TENSOR_ATTN_NORM,      "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q,         "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_Q_NORM,    "blk.%d.attn_q_norm" },
+            { LLM_TENSOR_ATTN_K,         "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_K_NORM,    "blk.%d.attn_k_norm" },
+            { LLM_TENSOR_ATTN_V,         "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,       "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
+            { LLM_TENSOR_FFN_NORM,       "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE,       "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN,       "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,         "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_POST_NORM,  "blk.%d.post_ffw_norm" },
+        },
+    },
     {
         LLM_ARCH_STARCODER2,
         {
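Note on the tensor-name templates in the hunk above: the "%d" placeholders are substituted with the block index when the loader looks a tensor up, and the "weight" suffix passed to tn() in the loading code further down is appended, so for block 0 the Gemma 3 per-head norms resolve to names such as "blk.0.attn_q_norm.weight". A minimal standalone sketch of that expansion (the tensor_name() helper below is a hypothetical stand-in for illustration, not the loader's actual tn() helper):

#include <cstdio>
#include <string>

// Hypothetical helper: substitute the block index into a template taken from
// LLM_TENSOR_NAMES and append the suffix, mimicking how per-layer GGUF tensor
// names are formed.
static std::string tensor_name(const char * tmpl, int block, const char * suffix) {
    char buf[256];
    std::snprintf(buf, sizeof(buf), tmpl, block);   // "blk.%d.attn_q_norm" -> "blk.0.attn_q_norm"
    return std::string(buf) + "." + suffix;         // -> "blk.0.attn_q_norm.weight"
}

int main() {
    // the two per-head norm tensors used by Gemma 3, for block 0:
    std::printf("%s\n", tensor_name("blk.%d.attn_q_norm", 0, "weight").c_str());
    std::printf("%s\n", tensor_name("blk.%d.attn_k_norm", 0, "weight").c_str());
    return 0;
}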
@@ -1917,7 +1939,9 @@ struct llama_hparams {
 
     float rope_attn_factor = 1.0f;
     float rope_freq_base_train;
+    float rope_freq_base_train_swa;
     float rope_freq_scale_train;
+    float rope_freq_scale_train_swa;
     uint32_t n_ctx_orig_yarn;
     float rope_yarn_log_mul;
 
@@ -1931,6 +1955,8 @@ struct llama_hparams {
     float f_max_alibi_bias = 0.0f;
     float f_logit_scale = 0.0f;
 
+    float f_attention_scale = 0.0f;
+
     bool causal_attn = true;
     bool use_alibi = false;
     bool attn_soft_cap = false;
@@ -4393,6 +4419,10 @@ static void llm_load_hparams(
     }
     hparams.rope_freq_scale_train = ropescale == 0.0f ? 1.0f : 1.0f/ropescale;
 
+    // by default assume that the sliding-window layers use the same scaling type as the non-sliding-window layers
+    hparams.rope_freq_base_train_swa  = hparams.rope_freq_base_train;
+    hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train;
+
     ml.get_key(LLM_KV_ROPE_SCALING_ATTN_FACTOR, hparams.rope_attn_factor, false);
 
     // non-transformer models do not have attention heads
@@ -4739,6 +4769,28 @@ static void llm_load_hparams(
                     default: model.type = e_model::MODEL_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_GEMMA3:
+            {
+                hparams.n_swa = 1024;
+
+                hparams.rope_freq_base_train_swa  = 10000.0f;
+                hparams.rope_freq_scale_train_swa = 1.0f;
+
+                ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW,    hparams.n_swa, false);
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+
+                switch (hparams.n_layer) {
+                    case 26: model.type = e_model::MODEL_1B;  break;
+                    case 34: model.type = e_model::MODEL_4B;  break;
+                    case 48: model.type = e_model::MODEL_12B; break;
+                    case 62: model.type = e_model::MODEL_27B; break;
+                    default: model.type = e_model::MODEL_UNKNOWN;
+                }
+
+                hparams.f_attention_scale = model.type == e_model::MODEL_27B
+                    ? 1.0f / std::sqrt(float(hparams.n_embd / hparams.n_head(0)))
+                    : 1.0f / std::sqrt(float(hparams.n_embd_head_k));
+            } break;
        case LLM_ARCH_STARCODER2:
            {
                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
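On the f_attention_scale lines just above: the 27B variant scales attention by 1/sqrt(n_embd / n_head) rather than by the usual 1/sqrt(key head size) used for the smaller variants. A worked sketch of the two branches, using hypothetical dimensions purely for illustration (the real values come from the model's GGUF metadata, not from this commit):

#include <cmath>
#include <cstdio>

int main() {
    // Hypothetical example dimensions, NOT read from any GGUF file.
    const int n_embd        = 4608;  // model width
    const int n_head        = 32;    // attention heads
    const int n_embd_head_k = 128;   // key head size

    // 27B branch: scale by the per-head slice of the model width
    const float scale_27b   = 1.0f / std::sqrt(float(n_embd / n_head));  // 1/sqrt(144) ~= 0.0833
    // other sizes: the conventional 1/sqrt(head_dim) scaling
    const float scale_other = 1.0f / std::sqrt(float(n_embd_head_k));    // 1/sqrt(128) ~= 0.0884

    std::printf("27B-style scale: %f\n", scale_27b);
    std::printf("default scale:   %f\n", scale_other);
    return 0;
}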
@@ -6767,6 +6819,38 @@ static bool llm_load_tensors(
                         layer.ffn_post_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd});
                     }
                 } break;
+            case LLM_ARCH_GEMMA3:
+                {
+                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    model.output      = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD,  "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        ggml_context * ctx_layer = ctx_for_layer(i);
+                        ggml_context * ctx_split = ctx_for_layer_split(i);
+
+                        auto & layer = model.layers[i];
+
+                        layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
+                        layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
+                        layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa}, 0);
+                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
+
+                        layer.attn_post_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
+                        layer.attn_k_norm    = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM,    "weight", i), {n_embd_head_k}, 0);
+                        layer.attn_q_norm    = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM,    "weight", i), {n_embd_head_k}, 0);
+
+                        layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                        layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
+                        layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff}, 0);
+                        layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                        layer.ffn_post_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
+                    }
+                } break;
            case LLM_ARCH_STARCODER2:
                {
                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
@@ -11740,6 +11824,142 @@ struct llm_build_context {
         return gf;
     }
 
+    struct ggml_cgraph * build_gemma3() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+
+        const int64_t n_embd_head_k = hparams.n_embd_head_k;
+
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
+
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+
+        // TODO: do not normalize weights for raw embeddings input (i.e. encoded image emdeddings)
+        inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd));
+        cb(inpL, "inp_scaled", -1);
+
+        // inp_pos - contains the positions
+        struct ggml_tensor * inp_pos = build_inp_pos();
+
+        struct ggml_tensor * KQ_mask     = build_inp_KQ_mask(true);
+        struct ggml_tensor * KQ_mask_swa = build_inp_KQ_mask_swa(true);
+
+        for (int il = 0; il < n_layer; ++il) {
+            const bool is_swa = il % 6 < 5;
+
+            const float freq_base_l  = is_swa ? hparams.rope_freq_base_train_swa  : cparams.rope_freq_base;
+            const float freq_scale_l = is_swa ? hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale;
+
+            struct ggml_tensor * KQ_mask_l = is_swa ? KQ_mask_swa : KQ_mask;
+
+            // norm
+            cur = llm_build_norm(ctx0, inpL, hparams,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                // compute Q and K and RoPE them
+                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+
+                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+
+                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head, n_tokens);
+                Qcur = llm_build_norm(ctx0, Qcur, hparams,
+                        model.layers[il].attn_q_norm, NULL,
+                        LLM_NORM_RMS, cb, il);
+                cb(Qcur, "Qcur_normed", il);
+
+                Qcur = ggml_rope_ext(
+                        ctx0, Qcur, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
+                        ext_factor, attn_factor, beta_fast, beta_slow);
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens);
+                Kcur = llm_build_norm(ctx0, Kcur, hparams,
+                        model.layers[il].attn_k_norm, NULL,
+                        LLM_NORM_RMS, cb, il);
+                cb(Kcur, "Kcur_normed", il);
+
+                Kcur = ggml_rope_ext(
+                        ctx0, Kcur, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
+                        ext_factor, attn_factor, beta_fast, beta_slow);
+                cb(Kcur, "Kcur", il);
+
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
+                        model.layers[il].wo, NULL,
+                        Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, hparams.f_attention_scale, cb, il);
+            }
+
+            cur = llm_build_norm(ctx0, cur, hparams,
+                    model.layers[il].attn_post_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "attn_post_norm", il);
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur  = ggml_get_rows(ctx0,  cur, inp_out_ids);
+                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+            }
+
+            struct ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL);
+            cb(sa_out, "sa_out", il);
+
+            cur = llm_build_norm(ctx0, sa_out, hparams,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "ffn_norm", il);
+
+            // feed-forward network
+            {
+                cur = llm_build_ffn(ctx0, lctx, cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
+                        NULL,
+                        LLM_FFN_GELU, LLM_FFN_PAR, cb, il);
+                cb(cur, "ffn_out", il);
+            }
+
+            cur = llm_build_norm(ctx0, cur, hparams,
+                    model.layers[il].ffn_post_norm, NULL,
+                    LLM_NORM_RMS, cb, -1);
+            cb(cur, "ffn_post_norm", -1);
+
+            cur = ggml_add(ctx0, cur, sa_out);
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = llm_build_norm(ctx0, cur, hparams,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, cb, -1);
+
+        cb(cur, "result_norm", -1);
+
+        // lm_head
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
+
+        cb(cur, "result_output", -1);
+
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
 
     struct ggml_cgraph * build_starcoder2() {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
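A note on the layer schedule in build_gemma3() above: the check il % 6 < 5 marks five out of every six layers as sliding-window layers, which take the SWA KQ mask (window n_swa = 1024) and the fixed SWA RoPE base of 10000, while every sixth layer attends globally using the RoPE base/scale from cparams. A small standalone sketch of that schedule (the layer count and the global-layer base value below are illustrative placeholders only):

#include <cstdio>

int main() {
    const int n_layer = 12;  // arbitrary illustrative depth

    for (int il = 0; il < n_layer; ++il) {
        // same predicate as in build_gemma3(): 5 sliding-window layers,
        // then 1 global-attention layer, repeating every 6 layers
        const bool is_swa = il % 6 < 5;

        // SWA layers use the hard-coded 10000 base from llm_load_hparams();
        // the global-layer value here is a placeholder, in the real code it
        // comes from cparams.rope_freq_base
        const float freq_base = is_swa ? 10000.0f : 1000000.0f;

        std::printf("layer %2d: %-14s rope base %.0f\n",
                    il, is_swa ? "sliding-window" : "global", freq_base);
    }
    return 0;
}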
@@ -14035,6 +14255,10 @@ static struct ggml_cgraph * llama_build_graph(
             {
                 result = llm.build_gemma2();
             } break;
+        case LLM_ARCH_GEMMA3:
+            {
+                result = llm.build_gemma3();
+            } break;
        case LLM_ARCH_STARCODER2:
            {
                result = llm.build_starcoder2();
@@ -17212,6 +17436,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
         case LLM_ARCH_PHI3:
         case LLM_ARCH_GEMMA:
         case LLM_ARCH_GEMMA2:
+        case LLM_ARCH_GEMMA3:
         case LLM_ARCH_STARCODER2:
         case LLM_ARCH_OPENELM:
         case LLM_ARCH_GPTNEOX:
