Commit f3be74e

model: add support for EmbeddingGemma SentenceTransformers dense linear projections

- converting a model with the dense layers is optional
- introduced dense config params

1 parent 8ceff26 commit f3be74e

9 files changed: +128, -37 lines changed

convert_hf_to_gguf.py

Lines changed: 53 additions & 17 deletions
@@ -93,13 +93,15 @@ class ModelBase:
     # Mistral format specifics
     is_mistral_format: bool = False
     disable_mistral_community_chat_template: bool = False
+    sentence_transformers_dense_modules: bool = False

     def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, *, is_big_endian: bool = False,
                  use_temp_file: bool = False, eager: bool = False,
                  metadata_override: Path | None = None, model_name: str | None = None,
                  split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False,
                  small_first_shard: bool = False, hparams: dict[str, Any] | None = None, remote_hf_model_id: str | None = None,
-                 disable_mistral_community_chat_template: bool = False):
+                 disable_mistral_community_chat_template: bool = False,
+                 sentence_transformers_dense_modules: bool = False):
         if type(self) is ModelBase or \
             type(self) is TextModel or \
             type(self) is MmprojModel:
@@ -114,6 +116,7 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path,
         self.lazy = not eager or (remote_hf_model_id is not None)
         self.dry_run = dry_run
         self.remote_hf_model_id = remote_hf_model_id
+        self.sentence_transformers_dense_modules = sentence_transformers_dense_modules
         if remote_hf_model_id is not None:
             self.is_safetensors = True

@@ -5256,37 +5259,53 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
 class EmbeddingGemma(Gemma3Model):
     model_arch = gguf.MODEL_ARCH.GEMMA_EMBEDDING
     module_paths = []
-    dense_tensors = []
+    dense_features_dims = {}

     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        # read modules.json to determine if model has Dense layers
-        module_path = self.dir_model / "modules.json"
-        if module_path.is_file():
-            with open(module_path, encoding="utf-8") as f:
-                modules = json.load(f)
-            for mod in modules:
-                if mod["type"] == "sentence_transformers.models.Dense":
-                    module_path = mod["path"]
-                    tensors_file = self.dir_model / module_path / "model.safetensors"
-                    if tensors_file.is_file():
-                        self.module_paths.append(module_path)
-
+        if self.sentence_transformers_dense_modules:
+            # read modules.json to determine if model has Dense layers
+            modules_file = self.dir_model / "modules.json"
+            if modules_file.is_file():
+                with open(modules_file, encoding="utf-8") as modules_json_file:
+                    mods = json.load(modules_json_file)
+                    for mod in mods:
+                        if mod["type"] == "sentence_transformers.models.Dense":
+                            mod_path = mod["path"]
+                            # check if model.safetensors file for Dense layer exists
+                            model_tensors_file = self.dir_model / mod_path / "model.safetensors"
+                            if model_tensors_file.is_file():
+                                self.module_paths.append(mod_path)
+                            # read config.json of the Dense layer to get in/out features
+                            mod_conf_file = self.dir_model / mod_path / "config.json"
+                            if mod_conf_file.is_file():
+                                with open(mod_conf_file, encoding="utf-8") as mod_conf_json_file:
+                                    mod_conf = json.load(mod_conf_json_file)
+                                    # hparams dense_2_feat_out and dense_3_feat_in are required when loading model's dense weights
+                                    prefix = self._get_dense_prefix(mod_path)
+                                    if (mod_conf["in_features"] is not None
+                                            and mod_conf["out_features"] is not None):
+                                        self.dense_features_dims[prefix] = (mod_conf["in_features"], mod_conf["out_features"])

     def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
         from safetensors.torch import load_file
         module_paths = list(self.module_paths)
         for i, module_path in enumerate(module_paths):
             tensors_file = self.dir_model / module_path / "model.safetensors"
             local_tensors = load_file(tensors_file)
-            tensor_name = "dense_2" if module_path == "2_Dense" else "dense_3"
+            tensor_name = self._get_dense_prefix(module_path)
             for name, local_tensor in local_tensors.items():
                 if not name.endswith(".weight"):
                     continue
                 orig_name = name.replace("linear", tensor_name)
                 name = self.map_tensor_name(orig_name)
                 yield name, local_tensor.clone()

+    @staticmethod
+    def _get_dense_prefix(module_path) -> str:
+        """Get the tensor name prefix for the Dense layer from module path."""
+        tensor_name = "dense_2" if module_path == "2_Dense" else "dense_3"
+        return tensor_name

     def set_gguf_parameters(self):
         super().set_gguf_parameters()
@@ -5303,6 +5322,11 @@ def set_gguf_parameters(self):
             logger.info(f"Using original sliding_window from config: {orig_sliding_window} "
                         f"instead of {self.hparams['sliding_window']}")
             self.gguf_writer.add_sliding_window(orig_sliding_window)
+        if self.sentence_transformers_dense_modules:
+            for dense, dims in self.dense_features_dims.items():
+                logger.info(f"Setting dense layer {dense} in/out features to {dims}")
+                self.gguf_writer.add_dense_features_dims(dense, dims[0], dims[1])
+            self.gguf_writer.add_pooling_type_opt(False)

         self._try_set_pooling_type()

@@ -9247,6 +9271,13 @@ def parse_args() -> argparse.Namespace:
         )
     )

+    parser.add_argument(
+        "--sentence-transformers-dense-modules", action="store_true",
+        help=("Whether to include sentence-transformers dense modules."
+              "It can be used for sentence-transformers models, like google/embeddinggemma-300m"
+              "Default these modules are not included.")
+    )
+
     args = parser.parse_args()
     if not args.print_supported_models and args.model is None:
         parser.error("the following arguments are required: model")
@@ -9309,9 +9340,13 @@ def main() -> None:
    if args.remote:
        hf_repo_id = args.model
        from huggingface_hub import snapshot_download
+       allowed_patterns = ["LICENSE", "*.json", "*.md", "*.txt", "tokenizer.model"]
+       if args.sentence_transformers_dense_modules:
+           # include sentence-transformers dense modules safetensors files
+           allowed_patterns.append("*.safetensors")
        local_dir = snapshot_download(
            repo_id=hf_repo_id,
-           allow_patterns=["LICENSE", "*.json", "*.md", "*.txt", "tokenizer.model"])
+           allow_patterns=allowed_patterns)
        dir_model = Path(local_dir)
        logger.info(f"Downloaded config and tokenizer to {local_dir}")
    else:
@@ -9379,7 +9414,8 @@ def main() -> None:
            split_max_tensors=args.split_max_tensors,
            split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run,
            small_first_shard=args.no_tensor_first_split,
-           remote_hf_model_id=hf_repo_id, disable_mistral_community_chat_template=disable_mistral_community_chat_template
+           remote_hf_model_id=hf_repo_id, disable_mistral_community_chat_template=disable_mistral_community_chat_template,
+           sentence_transformers_dense_modules=args.sentence_transformers_dense_modules
        )

    if args.vocab_only:
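Note: the Dense-module discovery in EmbeddingGemma.__init__ follows the standard sentence-transformers layout: modules.json lists the pipeline modules, and each Dense module directory (2_Dense and 3_Dense for EmbeddingGemma) carries a config.json with in_features/out_features next to its model.safetensors. It only runs when --sentence-transformers-dense-modules is passed to convert_hf_to_gguf.py (for example together with --remote when pulling google/embeddinggemma-300m); without the flag the converter behaves exactly as before. A minimal standalone sketch of the same discovery step, assuming a locally downloaded snapshot (the directory path and the printed dimensions are illustrative):

    # Sketch only: mirrors the EmbeddingGemma.__init__ logic above; not part of the converter.
    import json
    from pathlib import Path

    def find_dense_modules(model_dir: Path) -> dict[str, tuple[int, int]]:
        """Return {module path: (in_features, out_features)} for usable Dense modules."""
        dense: dict[str, tuple[int, int]] = {}
        modules_file = model_dir / "modules.json"
        if not modules_file.is_file():
            return dense
        with open(modules_file, encoding="utf-8") as f:
            modules = json.load(f)
        for mod in modules:
            if mod["type"] != "sentence_transformers.models.Dense":
                continue
            mod_path = mod["path"]  # e.g. "2_Dense" or "3_Dense"
            conf_file = model_dir / mod_path / "config.json"
            has_weights = (model_dir / mod_path / "model.safetensors").is_file()
            if has_weights and conf_file.is_file():
                with open(conf_file, encoding="utf-8") as cf:
                    conf = json.load(cf)
                dense[mod_path] = (conf["in_features"], conf["out_features"])
        return dense

    # expected output for embeddinggemma-300m: {'2_Dense': (768, 3072), '3_Dense': (3072, 768)}
    print(find_dense_modules(Path("./embeddinggemma-300m")))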

gguf-py/gguf/constants.py

Lines changed: 3 additions & 0 deletions
@@ -128,6 +128,9 @@ class LLM:
         ALTUP_ACTIVE_IDX          = "{arch}.altup.active_idx"
         ALTUP_NUM_INPUTS          = "{arch}.altup.num_inputs"
         EMBD_LENGTH_PER_LAYER_INP = "{arch}.embedding_length_per_layer_input"
+        DENSE_FEAT_IN_SIZE        = "{arch}.{dense}_feat_in"
+        DENSE_FEAT_OUT_SIZE       = "{arch}.{dense}_feat_out"
+        POOLING_TYPE_OPT          = "{arch}.pooling_type_opt"

     class Attention:
         HEAD_COUNT = "{arch}.attention.head_count"
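The two dense keys are templated per module. A quick sketch of the concrete metadata names they expand to; the arch string "gemma-embedding" and the feature sizes are assumptions for illustration, matching the %s.dense_*_feat_* names registered on the C++ side below:

    # Sketch only: the KV names a converter run would write for EmbeddingGemma.
    DENSE_FEAT_IN_SIZE  = "{arch}.{dense}_feat_in"
    DENSE_FEAT_OUT_SIZE = "{arch}.{dense}_feat_out"
    POOLING_TYPE_OPT    = "{arch}.pooling_type_opt"

    arch = "gemma-embedding"  # assumed arch name, for illustration only
    for dense, (feat_in, feat_out) in {"dense_2": (768, 3072), "dense_3": (3072, 768)}.items():
        print(DENSE_FEAT_IN_SIZE.format(arch=arch, dense=dense), "=", feat_in)
        print(DENSE_FEAT_OUT_SIZE.format(arch=arch, dense=dense), "=", feat_out)
    print(POOLING_TYPE_OPT.format(arch=arch), "=", False)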

gguf-py/gguf/gguf_writer.py

Lines changed: 7 additions & 0 deletions
@@ -730,6 +730,13 @@ def add_shared_kv_layers(self, value: int) -> None:
     def add_sliding_window_pattern(self, value: Sequence[bool]) -> None:
         self.add_array(Keys.Attention.SLIDING_WINDOW_PATTERN.format(arch=self.arch), value)

+    def add_dense_features_dims(self, dense: str, in_f: int, out_f: int) -> None:
+        self.add_uint32(Keys.LLM.DENSE_FEAT_IN_SIZE.format(arch=self.arch, dense=dense), in_f)
+        self.add_uint32(Keys.LLM.DENSE_FEAT_OUT_SIZE.format(arch=self.arch, dense=dense), out_f)
+
+    def add_pooling_type_opt(self, enable: bool) -> None:
+        self.add_bool(Keys.LLM.POOLING_TYPE_OPT.format(arch=self.arch), enable)
+
     def add_logit_scale(self, value: float) -> None:
         self.add_float32(Keys.LLM.LOGIT_SCALE.format(arch=self.arch), value)
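After converting a model with the flag enabled, the new metadata can be sanity-checked with gguf-py's reader. A small sketch; the output file name is an assumption:

    # Sketch only: list the dense/pooling metadata keys present in a converted file.
    from gguf import GGUFReader

    reader = GGUFReader("embeddinggemma-300m-f16.gguf")  # assumed output file name
    for name in reader.fields:
        if "_feat_in" in name or "_feat_out" in name or name.endswith(".pooling_type_opt"):
            print(name)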

src/llama-arch.cpp

Lines changed: 6 additions & 0 deletions
@@ -217,6 +217,12 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_CLASSIFIER_OUTPUT_LABELS, "%s.classifier.output_labels" },

     { LLM_KV_SHORTCONV_L_CACHE, "%s.shortconv.l_cache" },
+    // sentence-transformers dense modules feature dims
+    { LLM_KV_DENSE_2_FEAT_IN,  "%s.dense_2_feat_in"  },
+    { LLM_KV_DENSE_2_FEAT_OUT, "%s.dense_2_feat_out" },
+    { LLM_KV_DENSE_3_FEAT_IN,  "%s.dense_3_feat_in"  },
+    { LLM_KV_DENSE_3_FEAT_OUT, "%s.dense_3_feat_out" },
+    { LLM_KV_POOLING_TYPE_OPT, "%s.pooling_type_opt" },

     { LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" },
     { LLM_KV_TOKENIZER_PRE,   "tokenizer.ggml.pre"   },

src/llama-arch.h

Lines changed: 7 additions & 0 deletions
@@ -264,6 +264,13 @@ enum llm_kv {
     LLM_KV_TOKENIZER_PREFIX_ID,
     LLM_KV_TOKENIZER_SUFFIX_ID,
     LLM_KV_TOKENIZER_MIDDLE_ID,
+
+    // sentence-transformers dense layers in and out features
+    LLM_KV_DENSE_2_FEAT_IN,
+    LLM_KV_DENSE_2_FEAT_OUT,
+    LLM_KV_DENSE_3_FEAT_IN,
+    LLM_KV_DENSE_3_FEAT_OUT,
+    LLM_KV_POOLING_TYPE_OPT,
 };

 enum llm_tensor {

src/llama-context.cpp

Lines changed: 9 additions & 0 deletions
@@ -2346,6 +2346,15 @@ llama_context * llama_init_from_model(
         return nullptr;
     }

+    // if setting pooling_type is disabled, set it to model default
+    // for sentence-transformers models (e.g. EmbeddingGemma) mean-pooling is required
+    // when dense layers are enabled
+    if (!model->hparams.pooling_type_opt) {
+        params.pooling_type = model->hparams.pooling_type;
+        LLAMA_LOG_INFO("%s: setting pooling_type to models default: %d\n", __func__, params.pooling_type);
+
+    }
+
     try {
         auto * ctx = new llama_context(*model, params);
         return ctx;

src/llama-graph.cpp

Lines changed: 9 additions & 5 deletions
@@ -1856,11 +1856,15 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
 void llm_graph_context::build_dense_out(
         ggml_tensor * dense_2,
         ggml_tensor * dense_3) const {
-    ggml_tensor * cur = res->get_embd_pooled();
-    cur = ggml_mul_mat(ctx0, dense_2, cur);
-    cb(cur, "result_embd_pooled", -1);
-    cur = ggml_mul_mat(ctx0, dense_3, cur);
-    cb(cur, "result_embd_pooled", -1);
+    ggml_tensor * cur = res->get_embd_pooled();
+    if (dense_2 != nullptr) {
+        cur = ggml_mul_mat(ctx0, dense_2, cur);
+        cb(cur, "result_embd_pooled", -1);
+    }
+    if (dense_3 != nullptr) {
+        cur = ggml_mul_mat(ctx0, dense_3, cur);
+        cb(cur, "result_embd_pooled", -1);
+    }
     res->t_embd_pooled = cur;
     ggml_build_forward_expand(gf, cur);
 }
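For intuition, the two optional tensors are plain linear projections applied to the pooled embedding: 2_Dense expands the n_embd-dimensional pooled vector to 4*n_embd and 3_Dense maps it back, which is what the previously hard-coded {n_embd, 4 * n_embd} / {4 * n_embd, n_embd} shapes encoded. A NumPy sketch of the equivalent math; 768/3072 are illustrative sizes for EmbeddingGemma-300m and random weights stand in for the real ones:

    # Sketch only: the math build_dense_out performs for a single pooled embedding.
    import numpy as np

    n_embd = 768
    pooled = np.random.randn(n_embd).astype(np.float32)                  # mean-pooled embedding
    dense_2_w = np.random.randn(4 * n_embd, n_embd).astype(np.float32)   # 2_Dense weight (out, in)
    dense_3_w = np.random.randn(n_embd, 4 * n_embd).astype(np.float32)   # 3_Dense weight (out, in)

    x = dense_2_w @ pooled    # project 768 -> 3072
    out = dense_3_w @ x       # project back 3072 -> 768
    print(out.shape)          # (768,)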

src/llama-hparams.h

Lines changed: 9 additions & 0 deletions
@@ -169,6 +169,15 @@ struct llama_hparams {
     uint32_t laurel_rank  = 64;
     uint32_t n_embd_altup = 256;

+    // needed for sentence-transformers dense layers
+    uint32_t dense_2_feat_in  = 0; // in_features of the 2_Dense
+    uint32_t dense_2_feat_out = 0; // out_features of the 2_Dense
+    uint32_t dense_3_feat_in  = 0; // in_features of the 3_Dense
+    uint32_t dense_3_feat_out = 0; // out_features of the 3_Dense
+
+    // whether pooling_type can be overridden by user
+    bool pooling_type_opt = true;
+
     // needed by encoder-decoder models (e.g. T5, FLAN-T5)
     // ref: https://github.com/ggerganov/llama.cpp/pull/8141
     llama_token dec_start_token_id = LLAMA_TOKEN_NULL;

src/llama-model.cpp

Lines changed: 25 additions & 15 deletions
@@ -1207,20 +1207,28 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 hparams.set_swa_pattern(6);

                 hparams.causal_attn = false; // embeddings do not use causal attention
-                hparams.rope_freq_base_train_swa = 10000.0f;
+                hparams.rope_freq_base_train_swa  = 10000.0f;
                 hparams.rope_freq_scale_train_swa = 1.0f;

-                ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
+                ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW,    hparams.n_swa);
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
-                ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type);
+                ml.get_key(LLM_KV_POOLING_TYPE,                hparams.pooling_type);

-                switch (hparams.n_layer) {
-                    case 24: type = LLM_TYPE_0_3B; break;
-                    default: type = LLM_TYPE_UNKNOWN;
-                }
-                hparams.f_attention_scale = 1.0f / std::sqrt(float(hparams.n_embd_head_k));
+                // applied only if model converted with --sentence-transformers-dense-modules
+                ml.get_key(LLM_KV_DENSE_2_FEAT_IN,  hparams.dense_2_feat_in,  false);
+                ml.get_key(LLM_KV_DENSE_2_FEAT_OUT, hparams.dense_2_feat_out, false);
+                ml.get_key(LLM_KV_DENSE_3_FEAT_IN,  hparams.dense_3_feat_in,  false);
+                ml.get_key(LLM_KV_DENSE_3_FEAT_OUT, hparams.dense_3_feat_out, false);
+                ml.get_key(LLM_KV_POOLING_TYPE_OPT, hparams.pooling_type_opt, false);

-            } break;
+
+                switch (hparams.n_layer) {
+                    case 24: type = LLM_TYPE_0_3B; break;
+                    default: type = LLM_TYPE_UNKNOWN;
+                }
+                hparams.f_attention_scale = 1.0f / std::sqrt(float(hparams.n_embd_head_k));
+            }
+            break;
         case LLM_ARCH_STARCODER2:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
@@ -3646,8 +3654,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                 }

                 // Dense linear weights
-                dense_2_out_layers = create_tensor(tn(LLM_TENSOR_DENSE_2_OUT, "weight"), {n_embd, 4 * n_embd}, 0);
-                dense_3_out_layers = create_tensor(tn(LLM_TENSOR_DENSE_3_OUT, "weight"), {4 * n_embd, n_embd}, 0);
+                dense_2_out_layers = create_tensor(tn(LLM_TENSOR_DENSE_2_OUT, "weight"), {n_embd, hparams.dense_2_feat_out}, TENSOR_NOT_REQUIRED);
+                dense_3_out_layers = create_tensor(tn(LLM_TENSOR_DENSE_3_OUT, "weight"), {hparams.dense_3_feat_in, n_embd}, TENSOR_NOT_REQUIRED);


                 for (int i = 0; i < n_layer; ++i) {
@@ -19633,10 +19641,12 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {

     // add on pooling layer
     llm->build_pooling(cls, cls_b, cls_out, cls_out_b);
-    // embeddinggemma specific
-    // sentence-transformer dense linear projections are applied after pooling
-    if (llm->arch == LLM_ARCH_GEMMA_EMBEDDING) {
-        llm->build_dense_out(dense_2_out_layers,dense_3_out_layers);
+
+    // if the gguf model was converted with --sentence-transformers-dense-modules
+    // there will be two additional dense projection layers
+    // dense linear projections are applied after pooling
+    if (dense_2_out_layers != nullptr || dense_3_out_layers != nullptr) {
+        llm->build_dense_out(dense_2_out_layers, dense_3_out_layers);
     }

     return llm->res->get_gf();
