Commit fb15d64

llama : add support for EmbeddingGemma 300m (#15798)
This commit adds support for EmbeddingGemma 300m. This model supports sliding window attention (SWA), and a new swa_type is introduced to support symmetric SWA masking. This commit also extracts the code from the function llama_is_masked_swa in llama-impl.h so that the logic can be shared by both llm_graph_input_attn_no_cache::set_input and llama_kv_cache::set_input_kq_mask.

With this commit the EmbeddingGemma 300m model can be converted to GGUF and used with llama.cpp. Once the model has been uploaded to HuggingFace it can be used like this:

```console
./build/bin/llama-cli -hf ggml-org/embeddinggemma-300m-GGUF:Q8_0
```
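For context on the new masking mode: with symmetric SWA, a query at position p1 can attend to a key/value at position p0 only when the two positions are within half the window of each other, i.e. |p1 - p0| <= n_swa / 2; the pair is masked otherwise. Unlike standard SWA, which only limits how far back a token can look, the symmetric variant limits both directions, which fits the non-causal attention used for embeddings. A minimal standalone sketch of that check, mirroring the llama_hparams::is_masked_swa logic added below (the helper name and the example window size are illustrative, not part of the patch):

```cpp
#include <cstdint>

// Sketch of the symmetric SWA rule added in this commit: with a window of
// n_swa positions, query position p1 may attend to key/value position p0
// only if |p1 - p0| <= n_swa / 2 (half the window on each side).
static bool is_masked_symmetric_swa(int32_t p0, int32_t p1, uint32_t n_swa) {
    const int32_t half_n_swa = (int32_t) n_swa / 2;
    const int32_t pos_diff   = p1 - p0;
    return pos_diff < -half_n_swa || pos_diff > half_n_swa; // true = masked
}

int main() {
    const uint32_t n_swa = 8; // illustrative window size, not the model's configured value
    // |5 - 1| = 4 <= 4 -> can attend; |9 - 1| = 8 > 4 -> masked
    return is_masked_symmetric_swa(1, 5, n_swa) == false &&
           is_masked_symmetric_swa(1, 9, n_swa) == true ? 0 : 1;
}
```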
1 parent 856ed09 · commit fb15d64 · 15 files changed: +328, -47 lines

convert_hf_to_gguf.py

Lines changed: 9 additions & 0 deletions
```diff
@@ -5122,6 +5122,15 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         return [(self.map_tensor_name(name), data_torch)]
 
 
+@ModelBase.register("Gemma3TextModel")
+class EmbeddingGemma(Gemma3Model):
+    model_arch = gguf.MODEL_ARCH.GEMMA_EMBEDDING
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self._try_set_pooling_type()
+
+
 @ModelBase.register("Gemma3ForConditionalGeneration")
 class Gemma3VisionModel(MmprojModel):
     def set_gguf_parameters(self):
```
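With the new EmbeddingGemma class registered, a local checkout of the model can be converted along these lines; the model path, output name, and quantization type below are illustrative:

```console
python convert_hf_to_gguf.py /path/to/embeddinggemma-300m --outfile embeddinggemma-300m-Q8_0.gguf --outtype q8_0
```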

gguf-py/gguf/constants.py

Lines changed: 20 additions & 0 deletions
```diff
@@ -340,6 +340,7 @@ class MODEL_ARCH(IntEnum):
     GEMMA2 = auto()
     GEMMA3 = auto()
     GEMMA3N = auto()
+    GEMMA_EMBEDDING = auto()
     STARCODER2 = auto()
     RWKV6 = auto()
     RWKV6QWEN2 = auto()
@@ -674,6 +675,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.GEMMA2: "gemma2",
     MODEL_ARCH.GEMMA3: "gemma3",
     MODEL_ARCH.GEMMA3N: "gemma3n",
+    MODEL_ARCH.GEMMA_EMBEDDING: "gemma-embedding",
     MODEL_ARCH.STARCODER2: "starcoder2",
     MODEL_ARCH.RWKV6: "rwkv6",
     MODEL_ARCH.RWKV6QWEN2: "rwkv6qwen2",
@@ -1719,6 +1721,24 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.LAUREL_R,
         MODEL_TENSOR.LAUREL_POST_NORM,
     ],
+    MODEL_ARCH.GEMMA_EMBEDDING: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_Q_NORM,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_K_NORM,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_POST_NORM,
+        MODEL_TENSOR.FFN_PRE_NORM,
+        MODEL_TENSOR.FFN_POST_NORM,
+    ],
     MODEL_ARCH.STARCODER2: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,
```

gguf-py/gguf/tensor_mapping.py

Lines changed: 14 additions & 0 deletions
```diff
@@ -14,6 +14,7 @@ class TensorNameMap:
             "transformer.word_embeddings", # falcon
             "word_embeddings", # bloom
             "model.embed_tokens", # llama-hf nemotron olmoe olmo2 rwkv6qwen2 glm4-0414 plamo2 granite-hybrid
+            "embed_tokens", # embeddinggemma
             "tok_embeddings", # llama-pth
             "embeddings.word_embeddings", # bert nomic-bert
             "language_model.embedding.word_embeddings", # persimmon
@@ -141,6 +142,7 @@
             "rwkv.blocks.{bid}.ln1", # rwkv6
             "model.layers.{bid}.ln1", # rwkv7
             "model.layers.{bid}.input_layernorm", # llama4
+            "layers.{bid}.input_layernorm", # embeddinggemma
             "transformer_encoder.{bid}.attention_norm", # neobert
             "model.layers.{bid}.operator_norm", # lfm2
             "model.transformer.blocks.{bid}.attn_norm", # llada
@@ -179,6 +181,7 @@
         # Attention query
         MODEL_TENSOR.ATTN_Q: (
             "model.layers.{bid}.self_attn.q_proj", # llama-hf nemotron olmoe olmo2 phimoe
+            "layers.{bid}.self_attn.q_proj", # embeddinggemma
             "model.layers.{bid}.self_attn.q_proj_no_perm", # llama-custom
             "layers.{bid}.attention.wq", # llama-pth
             "encoder.layer.{bid}.attention.self.query", # bert
@@ -197,6 +200,7 @@
         # Attention key
         MODEL_TENSOR.ATTN_K: (
             "model.layers.{bid}.self_attn.k_proj", # llama-hf nemotron olmoe olmo2 phimoe
+            "layers.{bid}.self_attn.k_proj", # embeddinggemma
             "model.layers.{bid}.self_attn.k_proj_no_perm", # llama-custom
             "layers.{bid}.attention.wk", # llama-pth
             "encoder.layer.{bid}.attention.self.key", # bert
@@ -216,6 +220,7 @@
         # Attention value
         MODEL_TENSOR.ATTN_V: (
             "model.layers.{bid}.self_attn.v_proj", # llama-hf nemotron olmoe olmo2 phimoe
+            "layers.{bid}.self_attn.v_proj", # embeddinggemma
             "layers.{bid}.attention.wv", # llama-pth
             "encoder.layer.{bid}.attention.self.value", # bert
             "transformer.layer.{bid}.attention.v_lin", # distillbert
@@ -239,6 +244,7 @@
             "transformer.h.{bid}.self_attention.dense", # falcon
             "h.{bid}.self_attention.dense", # bloom
             "model.layers.{bid}.self_attn.o_proj", # llama-hf nemotron olmoe olmo2 phimoe
+            "layers.{bid}.self_attn.o_proj", # embeddinggemma
             "model.layers.{bid}.self_attn.out_proj", # lfm2
             "model.layers.{bid}.self_attn.linear_attn", # deci
             "layers.{bid}.attention.wo", # llama-pth
@@ -277,6 +283,7 @@
 
         MODEL_TENSOR.ATTN_POST_NORM: (
             "model.layers.{bid}.post_attention_layernorm", # gemma2 olmo2 # ge
+            "layers.{bid}.post_attention_layernorm", # embeddinggemma
             "model.layers.{bid}.post_self_attn_layernorm", # glm-4-0414
             "model.layers.layers.{bid}.post_mixer_norm.weight", # plamo2
         ),
@@ -320,12 +327,14 @@
         # Post feed-forward norm
         MODEL_TENSOR.FFN_PRE_NORM: (
             "model.layers.{bid}.pre_feedforward_layernorm", # gemma2
+            "layers.{bid}.pre_feedforward_layernorm", # embeddinggemma
             "model.layers.{bid}.pre_ff_layernorm.weight",
         ),
 
         # Post feed-forward norm
         MODEL_TENSOR.FFN_POST_NORM: (
             "model.layers.{bid}.post_feedforward_layernorm", # gemma2 olmo2
+            "layers.{bid}.post_feedforward_layernorm", # embeddinggemma
             "model.layers.{bid}.post_mlp_layernorm", # glm-4-0414
             "model.layers.layers.{bid}.post_mlp_norm.weight", # plamo2
             "model.layers.{bid}.feed_forward.up_proj",
@@ -362,6 +371,7 @@
             "transformer.h.{bid}.mlp.dense_h_to_4h", # falcon
             "h.{bid}.mlp.dense_h_to_4h", # bloom
             "model.layers.{bid}.mlp.up_proj", # llama-hf refact nemotron olmo2
+            "layers.{bid}.mlp.up_proj", # embeddinggemma
             "layers.{bid}.feed_forward.w3", # llama-pth
             "encoder.layer.{bid}.intermediate.dense", # bert
             "transformer.layer.{bid}.ffn.lin1", # distillbert
@@ -421,6 +431,7 @@
         # Feed-forward gate
         MODEL_TENSOR.FFN_GATE: (
             "model.layers.{bid}.mlp.gate_proj", # llama-hf refact olmo2
+            "layers.{bid}.mlp.gate_proj", # embeddinggemma
             "layers.{bid}.feed_forward.w1", # llama-pth
             "transformer.h.{bid}.mlp.w2", # qwen
             "transformer.h.{bid}.mlp.c_fc2", # jais
@@ -461,6 +472,7 @@
             "transformer.h.{bid}.mlp.dense_4h_to_h", # falcon
             "h.{bid}.mlp.dense_4h_to_h", # bloom
             "model.layers.{bid}.mlp.down_proj", # llama-hf nemotron olmo2
+            "layers.{bid}.mlp.down_proj", # embeddinggemma
             "layers.{bid}.feed_forward.w2", # llama-pth
             "encoder.layer.{bid}.output.dense", # bert
             "transformer.layer.{bid}.ffn.lin2", # distillbert
@@ -513,6 +525,7 @@
             "model.layers.{bid}.self_attn.q_layernorm", # persimmon
             "model.layers.{bid}.self_attn.query_layernorm", # hunyuan
             "model.layers.{bid}.self_attn.q_norm", # cohere olmoe chameleon olmo2
+            "layers.{bid}.self_attn.q_norm", # embeddinggemma
             "transformer.blocks.{bid}.attn.q_ln", # sea-lion
             "encoder.layer.{bid}.attention.self.layer_norm_q", # jina-bert-v2
             "transformer.layers.{bid}.attn.q_norm", # openelm
@@ -525,6 +538,7 @@
             "model.layers.{bid}.self_attn.k_layernorm", # persimmon
             "model.layers.{bid}.self_attn.key_layernorm", # hunyuan
             "model.layers.{bid}.self_attn.k_norm", # cohere olmoe chameleon olmo2
+            "layers.{bid}.self_attn.k_norm", # embeddinggemma
             "transformer.blocks.{bid}.attn.k_ln", # sea-lion
             "encoder.layer.{bid}.attention.self.layer_norm_k", # jina-bert-v2
             "transformer.layers.{bid}.attn.k_norm", # openelm
```

src/llama-arch.cpp

Lines changed: 22 additions & 0 deletions
```diff
@@ -45,6 +45,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_GEMMA2, "gemma2" },
     { LLM_ARCH_GEMMA3, "gemma3" },
     { LLM_ARCH_GEMMA3N, "gemma3n" },
+    { LLM_ARCH_GEMMA_EMBEDDING, "gemma-embedding" },
     { LLM_ARCH_STARCODER2, "starcoder2" },
     { LLM_ARCH_MAMBA, "mamba" },
     { LLM_ARCH_MAMBA2, "mamba2" },
@@ -1038,6 +1039,27 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_LAUREL_POST_NORM, "blk.%d.laurel_post_norm" },
         },
     },
+    {
+        LLM_ARCH_GEMMA_EMBEDDING,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_OUTPUT, "output" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
+            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
+            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
+            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
+        },
+    },
     {
         LLM_ARCH_STARCODER2,
         {
```

src/llama-arch.h

Lines changed: 1 addition & 0 deletions
```diff
@@ -49,6 +49,7 @@ enum llm_arch {
     LLM_ARCH_GEMMA2,
     LLM_ARCH_GEMMA3,
     LLM_ARCH_GEMMA3N,
+    LLM_ARCH_GEMMA_EMBEDDING,
     LLM_ARCH_STARCODER2,
     LLM_ARCH_MAMBA,
     LLM_ARCH_MAMBA2,
```

src/llama-graph.cpp

Lines changed: 50 additions & 9 deletions
```diff
@@ -258,6 +258,36 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
     }
 }
 
+static void print_mask(float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) {
+    LLAMA_LOG_DEBUG("%s: === Attention mask ===\n", __func__);
+    const char * swa_type_str = (swa_type == LLAMA_SWA_TYPE_NONE) ? "LLAMA_SWA_TYPE_NONE" :
+                                (swa_type == LLAMA_SWA_TYPE_STANDARD) ? "LLAMA_SWA_TYPE_STANDARD" :
+                                (swa_type == LLAMA_SWA_TYPE_CHUNKED) ? "LLAMA_SWA_TYPE_CHUNKED" :
+                                (swa_type == LLAMA_SWA_TYPE_SYMMETRIC) ? "LLAMA_SWA_TYPE_SYMMETRIC" : "unknown";
+    LLAMA_LOG_DEBUG("%s: n_swa : %d, n_kv: %d, swq_type: %s\n", __func__, (int)n_swa, (int)n_kv, swa_type_str);
+    LLAMA_LOG_DEBUG("%s: '0' = can attend, '∞' = masked\n", __func__);
+    LLAMA_LOG_DEBUG("%s: Rows = query tokens, Columns = key/value tokens\n\n", __func__);
+
+    LLAMA_LOG_DEBUG(" ");
+    for (int j = 0; j < std::min((int64_t)20, n_kv); ++j) {
+        LLAMA_LOG_DEBUG("%2d", j);
+    }
+    LLAMA_LOG_DEBUG("\n");
+
+    for (int i = 0; i < std::min((int64_t)20, n_tokens); ++i) {
+        LLAMA_LOG_DEBUG(" %2d ", i);
+        for (int j = 0; j < std::min((int64_t)20, n_kv); ++j) {
+            float val = data[i * n_kv + j];
+            if (val == -INFINITY) {
+                LLAMA_LOG_DEBUG(" ∞");
+            } else {
+                LLAMA_LOG_DEBUG(" 0");
+            }
+        }
+        LLAMA_LOG_DEBUG("\n");
+    }
+}
+
 void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
     const int64_t n_kv = ubatch->n_tokens;
     const int64_t n_tokens = ubatch->n_tokens;
@@ -277,21 +307,32 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
                 for (int s = 0; s < ubatch->n_seq_id[i0]; ++s) {
                     const llama_seq_id s0 = ubatch->seq_id[i0][0];
 
-                    // TODO: reimplement this like in llama_kv_cache
-                    if (s0 == s1 && (!cparams.causal_attn || ubatch->pos[i0] <= ubatch->pos[i1])) {
-                        if (hparams.use_alibi) {
-                            f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]);
-                        } else {
-                            f = 0.0f;
-                        }
-                        break;
+                    if (s0 != s1) {
+                        continue; // skip different sequences
                     }
-                }
 
+                    if (cparams.causal_attn && ubatch->pos[i0] > ubatch->pos[i1]) {
+                        continue; // skip future tokens for causal attention
+                    }
+
+                    if (hparams.is_masked_swa(ubatch->pos[i0], ubatch->pos[i1])) {
+                        continue; // skip masked tokens for SWA
+                    }
+
+                    // TODO: reimplement this like in llama_kv_cache_unified
+                    if (hparams.use_alibi) {
+                        f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]);
+                    } else {
+                        f = 0.0f;
+                    }
+                }
                 data[h*(n_kv*n_tokens) + i1*n_kv + i0] = f;
             }
         }
     }
+    if (debug) {
+        print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type);
+    }
 }
 
 void llm_graph_input_attn_kv::set_input(const llama_ubatch * ubatch) {
```

src/llama-graph.h

Lines changed: 8 additions & 0 deletions
```diff
@@ -78,6 +78,11 @@ struct llm_graph_params;
 
 class llm_graph_input_i {
 public:
+    llm_graph_input_i() {
+        const char * LLAMA_GRAPH_INPUT_DEBUG = getenv("LLAMA_GRAPH_INPUT_DEBUG");
+        debug = LLAMA_GRAPH_INPUT_DEBUG ? atoi(LLAMA_GRAPH_INPUT_DEBUG) : 0;
+    }
+
     virtual ~llm_graph_input_i() = default;
 
     virtual void set_input(const llama_ubatch * ubatch) = 0;
@@ -90,6 +95,9 @@ class llm_graph_input_i {
         GGML_UNUSED(params);
         return false;
     }
+protected:
+    // env: LLAMA_GRAPH_INPUT_DEBUG
+    int debug = 0;
 };
 
 using llm_graph_input_ptr = std::unique_ptr<llm_graph_input_i>;
```
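The debug member is initialized from the LLAMA_GRAPH_INPUT_DEBUG environment variable, so the attention-mask dump from print_mask above can be enabled without recompiling. A hypothetical run (the binary, model file, and prompt are illustrative; since the mask is printed with LLAMA_LOG_DEBUG, verbose logging also needs to be enabled):

```console
LLAMA_GRAPH_INPUT_DEBUG=1 ./build/bin/llama-embedding -m embeddinggemma-300m-Q8_0.gguf -p "hello world" --verbose
```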

src/llama-hparams.cpp

Lines changed: 37 additions & 0 deletions
```diff
@@ -1,6 +1,7 @@
 #include "llama-hparams.h"
 
 #include "ggml.h"
+#include <cassert>
 
 void llama_hparams::set_swa_pattern(uint32_t n_pattern, bool dense_first) {
     if (dense_first) {
@@ -178,3 +179,39 @@ uint32_t llama_hparams::n_layer_kv() const {
 
     return res;
 }
+
+bool llama_hparams::is_masked_swa(llama_pos p0, llama_pos p1) const {
+    assert(p0 >= 0 && p1 >= 0);
+
+    switch (swa_type) {
+        case LLAMA_SWA_TYPE_NONE:
+            {
+            } break;
+        case LLAMA_SWA_TYPE_STANDARD:
+            {
+                if (p1 - p0 >= (int32_t) n_swa) {
+                    return true;
+                }
+            } break;
+        case LLAMA_SWA_TYPE_CHUNKED:
+            {
+                const llama_pos pos_chunk_start = (p1 / n_swa) * n_swa;
+
+                if (p0 < pos_chunk_start) {
+                    return true;
+                }
+            } break;
+        case LLAMA_SWA_TYPE_SYMMETRIC:
+            {
+                const int32_t half_n_swa = (int32_t) n_swa / 2;
+                const int32_t pos_diff = p1 - p0;
+
+                // Mask if outside the symmetric window
+                if (pos_diff < -half_n_swa || pos_diff > half_n_swa) {
+                    return true;
+                }
+            } break;
+    }
+
+    return false;
+}
```

src/llama-hparams.h

Lines changed: 6 additions & 3 deletions
```diff
@@ -16,9 +16,10 @@ enum llama_expert_gating_func_type {
 };
 
 enum llama_swa_type {
-    LLAMA_SWA_TYPE_NONE     = 0,
-    LLAMA_SWA_TYPE_STANDARD = 1,
-    LLAMA_SWA_TYPE_CHUNKED  = 2,
+    LLAMA_SWA_TYPE_NONE      = 0,
+    LLAMA_SWA_TYPE_STANDARD  = 1,
+    LLAMA_SWA_TYPE_CHUNKED   = 2,
+    LLAMA_SWA_TYPE_SYMMETRIC = 3,
 };
 
 struct llama_hparams_posnet {
@@ -227,6 +228,8 @@ struct llama_hparams {
 
     // number of layers for which has_kv() returns true
     uint32_t n_layer_kv() const;
+
+    bool is_masked_swa(llama_pos p0, llama_pos p1) const;
 };
 
 static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable");
```
