Skip to content

Commit f6b3831

Browse files
committed
Initial support for Gemma 3 models
Tested only on text-to-text.
1 parent 17d7f4a commit f6b3831

File tree

1 file changed

+225
-0
lines changed

1 file changed

+225
-0
lines changed

llama.cpp/llama.cpp

Lines changed: 225 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,7 @@ enum llm_arch {
162162
LLM_ARCH_MINICPM,
163163
LLM_ARCH_GEMMA,
164164
LLM_ARCH_GEMMA2,
165+
LLM_ARCH_GEMMA3,
165166
LLM_ARCH_STARCODER2,
166167
LLM_ARCH_MAMBA,
167168
LLM_ARCH_XVERSE,
@@ -211,6 +212,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
211212
{ LLM_ARCH_MINICPM, "minicpm" },
212213
{ LLM_ARCH_GEMMA, "gemma" },
213214
{ LLM_ARCH_GEMMA2, "gemma2" },
215+
{ LLM_ARCH_GEMMA3, "gemma3" },
214216
{ LLM_ARCH_STARCODER2, "starcoder2" },
215217
{ LLM_ARCH_MAMBA, "mamba" },
216218
{ LLM_ARCH_XVERSE, "xverse" },
@@ -1008,6 +1010,26 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
10081010
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
10091011
},
10101012
},
1013+
{
1014+
LLM_ARCH_GEMMA3,
1015+
{
1016+
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
1017+
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
1018+
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
1019+
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
1020+
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
1021+
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
1022+
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
1023+
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
1024+
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
1025+
{ LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
1026+
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
1027+
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
1028+
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
1029+
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
1030+
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
1031+
},
1032+
},
10111033
{
10121034
LLM_ARCH_STARCODER2,
10131035
{
@@ -1962,7 +1984,9 @@ struct llama_hparams {
19621984

19631985
float rope_attn_factor = 1.0f;
19641986
float rope_freq_base_train;
1987+
float rope_freq_base_train_swa;
19651988
float rope_freq_scale_train;
1989+
float rope_freq_scale_train_swa;
19661990
uint32_t n_ctx_orig_yarn;
19671991
float rope_yarn_log_mul;
19681992

@@ -2035,6 +2059,8 @@ struct llama_hparams {
20352059
if (!is_float_close(this->rope_attn_factor, other.rope_attn_factor, EPSILON)) return true;
20362060
if (!is_float_close(this->rope_freq_base_train, other.rope_freq_base_train, EPSILON)) return true;
20372061
if (!is_float_close(this->rope_freq_scale_train, other.rope_freq_scale_train, EPSILON)) return true;
2062+
if (!is_float_close(this->rope_freq_base_train_swa, other.rope_freq_base_train_swa, EPSILON)) return true;
2063+
if (!is_float_close(this->rope_freq_scale_train_swa, other.rope_freq_scale_train_swa, EPSILON)) return true;
20382064
if (!is_float_close(this->expert_weights_scale, other.expert_weights_scale, EPSILON)) return true;
20392065
if (!is_float_close(this->rope_yarn_log_mul, other.rope_yarn_log_mul, EPSILON)) return true;
20402066
if (!is_float_close(this->f_residual_scale, other.f_residual_scale, EPSILON)) return true;
@@ -4446,6 +4472,10 @@ static void llm_load_hparams(
44464472
}
44474473
hparams.rope_freq_scale_train = ropescale == 0.0f ? 1.0f : 1.0f/ropescale;
44484474

4475+
// by default assume that the sliding-window layers use the same scaling type as the non-sliding-window layers
4476+
hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train;
4477+
hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train;
4478+
44494479
ml.get_key(LLM_KV_ROPE_SCALING_ATTN_FACTOR, hparams.rope_attn_factor, false);
44504480

44514481
// non-transformer models do not have attention heads
@@ -4792,6 +4822,28 @@ static void llm_load_hparams(
47924822
default: model.type = e_model::MODEL_UNKNOWN;
47934823
}
47944824
} break;
4825+
case LLM_ARCH_GEMMA3:
4826+
{
4827+
hparams.n_swa = 1024;
4828+
4829+
hparams.rope_freq_base_train_swa = 10000.0f;
4830+
hparams.rope_freq_scale_train_swa = 1.0f;
4831+
4832+
ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
4833+
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
4834+
4835+
switch (hparams.n_layer) {
4836+
case 26: model.type = e_model::MODEL_1B; break;
4837+
case 34: model.type = e_model::MODEL_4B; break;
4838+
case 48: model.type = e_model::MODEL_12B; break;
4839+
case 62: model.type = e_model::MODEL_27B; break;
4840+
default: model.type = e_model::MODEL_UNKNOWN;
4841+
}
4842+
4843+
hparams.f_attention_scale = model.type == e_model::MODEL_27B
4844+
? 1.0f / std::sqrt(float(hparams.n_embd / hparams.n_head(0)))
4845+
: 1.0f / std::sqrt(float(hparams.n_embd_head_k));
4846+
} break;
47954847
case LLM_ARCH_STARCODER2:
47964848
{
47974849
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
@@ -6844,6 +6896,38 @@ static bool llm_load_tensors(
68446896
layer.ffn_post_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd});
68456897
}
68466898
} break;
6899+
case LLM_ARCH_GEMMA3:
6900+
{
6901+
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
6902+
6903+
// output
6904+
model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
6905+
model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading
6906+
6907+
for (int i = 0; i < n_layer; ++i) {
6908+
ggml_context * ctx_layer = ctx_for_layer(i);
6909+
ggml_context * ctx_split = ctx_for_layer_split(i);
6910+
6911+
auto & layer = model.layers[i];
6912+
6913+
layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
6914+
6915+
layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
6916+
layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0);
6917+
layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0);
6918+
layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
6919+
6920+
layer.attn_post_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
6921+
layer.attn_k_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0);
6922+
layer.attn_q_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0);
6923+
6924+
layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
6925+
layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
6926+
layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
6927+
layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
6928+
layer.ffn_post_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
6929+
}
6930+
} break;
68476931
case LLM_ARCH_STARCODER2:
68486932
{
68496933
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
@@ -11839,6 +11923,142 @@ struct llm_build_context {
1183911923
return gf;
1184011924
}
1184111925

11926+
struct ggml_cgraph * build_gemma3() {
11927+
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
11928+
11929+
const int64_t n_embd_head_k = hparams.n_embd_head_k;
11930+
11931+
struct ggml_tensor * cur;
11932+
struct ggml_tensor * inpL;
11933+
11934+
inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
11935+
11936+
// TODO: do not normalize weights for raw embeddings input (i.e. encoded image emdeddings)
11937+
inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd));
11938+
cb(inpL, "inp_scaled", -1);
11939+
11940+
// inp_pos - contains the positions
11941+
struct ggml_tensor * inp_pos = build_inp_pos();
11942+
11943+
struct ggml_tensor * KQ_mask = build_inp_KQ_mask(true);
11944+
struct ggml_tensor * KQ_mask_swa = build_inp_KQ_mask_swa(true);
11945+
11946+
for (int il = 0; il < n_layer; ++il) {
11947+
const bool is_swa = il % 6 < 5;
11948+
11949+
const float freq_base_l = is_swa ? hparams.rope_freq_base_train_swa : cparams.rope_freq_base;
11950+
const float freq_scale_l = is_swa ? hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale;
11951+
11952+
struct ggml_tensor * KQ_mask_l = is_swa ? KQ_mask_swa : KQ_mask;
11953+
11954+
// norm
11955+
cur = llm_build_norm(ctx0, inpL, hparams,
11956+
model.layers[il].attn_norm, NULL,
11957+
LLM_NORM_RMS, cb, il);
11958+
cb(cur, "attn_norm", il);
11959+
11960+
// self-attention
11961+
{
11962+
// compute Q and K and RoPE them
11963+
struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
11964+
cb(Qcur, "Qcur", il);
11965+
11966+
struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
11967+
cb(Kcur, "Kcur", il);
11968+
11969+
struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
11970+
cb(Vcur, "Vcur", il);
11971+
11972+
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head, n_tokens);
11973+
Qcur = llm_build_norm(ctx0, Qcur, hparams,
11974+
model.layers[il].attn_q_norm, NULL,
11975+
LLM_NORM_RMS, cb, il);
11976+
cb(Qcur, "Qcur_normed", il);
11977+
11978+
Qcur = ggml_rope_ext(
11979+
ctx0, Qcur, inp_pos, nullptr,
11980+
n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
11981+
ext_factor, attn_factor, beta_fast, beta_slow);
11982+
cb(Qcur, "Qcur", il);
11983+
11984+
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens);
11985+
Kcur = llm_build_norm(ctx0, Kcur, hparams,
11986+
model.layers[il].attn_k_norm, NULL,
11987+
LLM_NORM_RMS, cb, il);
11988+
cb(Kcur, "Kcur_normed", il);
11989+
11990+
Kcur = ggml_rope_ext(
11991+
ctx0, Kcur, inp_pos, nullptr,
11992+
n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
11993+
ext_factor, attn_factor, beta_fast, beta_slow);
11994+
cb(Kcur, "Kcur", il);
11995+
11996+
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
11997+
model.layers[il].wo, NULL,
11998+
Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, hparams.f_attention_scale, cb, il);
11999+
}
12000+
12001+
cur = llm_build_norm(ctx0, cur, hparams,
12002+
model.layers[il].attn_post_norm, NULL,
12003+
LLM_NORM_RMS, cb, il);
12004+
cb(cur, "attn_post_norm", il);
12005+
12006+
if (il == n_layer - 1) {
12007+
// skip computing output for unused tokens
12008+
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
12009+
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
12010+
inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
12011+
}
12012+
12013+
struct ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL);
12014+
cb(sa_out, "sa_out", il);
12015+
12016+
cur = llm_build_norm(ctx0, sa_out, hparams,
12017+
model.layers[il].ffn_norm, NULL,
12018+
LLM_NORM_RMS, cb, il);
12019+
cb(cur, "ffn_norm", il);
12020+
12021+
// feed-forward network
12022+
{
12023+
cur = llm_build_ffn(ctx0, lctx, cur,
12024+
model.layers[il].ffn_up, NULL, NULL,
12025+
model.layers[il].ffn_gate, NULL, NULL,
12026+
model.layers[il].ffn_down, NULL, NULL,
12027+
NULL,
12028+
LLM_FFN_GELU, LLM_FFN_PAR, cb, il);
12029+
cb(cur, "ffn_out", il);
12030+
}
12031+
12032+
cur = llm_build_norm(ctx0, cur, hparams,
12033+
model.layers[il].ffn_post_norm, NULL,
12034+
LLM_NORM_RMS, cb, -1);
12035+
cb(cur, "ffn_post_norm", -1);
12036+
12037+
cur = ggml_add(ctx0, cur, sa_out);
12038+
cur = lctx.cvec.apply_to(ctx0, cur, il);
12039+
cb(cur, "l_out", il);
12040+
12041+
// input for next layer
12042+
inpL = cur;
12043+
}
12044+
12045+
cur = inpL;
12046+
12047+
cur = llm_build_norm(ctx0, cur, hparams,
12048+
model.output_norm, NULL,
12049+
LLM_NORM_RMS, cb, -1);
12050+
12051+
cb(cur, "result_norm", -1);
12052+
12053+
// lm_head
12054+
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
12055+
12056+
cb(cur, "result_output", -1);
12057+
12058+
ggml_build_forward_expand(gf, cur);
12059+
12060+
return gf;
12061+
}
1184212062

1184312063
struct ggml_cgraph * build_starcoder2() {
1184412064
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
@@ -14136,6 +14356,10 @@ static struct ggml_cgraph * llama_build_graph(
1413614356
{
1413714357
result = llm.build_gemma2();
1413814358
} break;
14359+
case LLM_ARCH_GEMMA3:
14360+
{
14361+
result = llm.build_gemma3();
14362+
} break;
1413914363
case LLM_ARCH_STARCODER2:
1414014364
{
1414114365
result = llm.build_starcoder2();
@@ -17315,6 +17539,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
1731517539
case LLM_ARCH_PHI3:
1731617540
case LLM_ARCH_GEMMA:
1731717541
case LLM_ARCH_GEMMA2:
17542+
case LLM_ARCH_GEMMA3:
1731817543
case LLM_ARCH_STARCODER2:
1731917544
case LLM_ARCH_OPENELM:
1732017545
case LLM_ARCH_GPTNEOX:

0 commit comments

Comments
 (0)