
Commit 84ab83c

model : jina-embeddings-v3 support (ggml-org#13693)
* initial jina-embeddings-v3 support
* initial jina-embeddings-v3 support
* initial jina-embeddings-v3 support
* fix vocab parsing with only tokenizer.json
* set mask token lstrip attribute
* additional unk_token_id fallback just in case [no ci]
* revert vocab_size() change [no ci]
* merge tensor loading into general bert
* rope
* add lora embedding and loading (non-functional)
* export separate lora ggufs instead
* add adapter metadata api
* use std::string
* convert_hf_to_lora compatibility
* fix assert
* apply suggestions from review
* apply suggestion from review
1 parent 55042b3 commit 84ab83c

14 files changed: +246 −24 lines

common/arg.cpp

Lines changed: 2 additions & 2 deletions
@@ -2555,15 +2555,15 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         {"--lora"}, "FNAME",
         "path to LoRA adapter (can be repeated to use multiple adapters)",
         [](common_params & params, const std::string & value) {
-            params.lora_adapters.push_back({ std::string(value), 1.0, nullptr });
+            params.lora_adapters.push_back({ std::string(value), 1.0, "", "", nullptr });
         }
         // we define this arg on both COMMON and EXPORT_LORA, so when showing help message of export-lora, it will be categorized as "example-specific" arg
     ).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_EXPORT_LORA}));
     add_opt(common_arg(
         {"--lora-scaled"}, "FNAME", "SCALE",
         "path to LoRA adapter with user defined scaling (can be repeated to use multiple adapters)",
         [](common_params & params, const std::string & fname, const std::string & scale) {
-            params.lora_adapters.push_back({ fname, std::stof(scale), nullptr });
+            params.lora_adapters.push_back({ fname, std::stof(scale), "", "", nullptr });
         }
         // we define this arg on both COMMON and EXPORT_LORA, so when showing help message of export-lora, it will be categorized as "example-specific" arg
     ).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_EXPORT_LORA}));

common/common.cpp

Lines changed: 5 additions & 0 deletions
@@ -988,7 +988,12 @@ struct common_init_result common_init_from_params(common_params & params) {
             return iparams;
         }

+        char buf[1024];
         la.ptr = lora.get();
+        llama_adapter_meta_val_str(la.ptr, "adapter.lora.task_name", buf, sizeof(buf));
+        la.task_name = buf;
+        llama_adapter_meta_val_str(la.ptr, "adapter.lora.prompt_prefix", buf, sizeof(buf));
+        la.prompt_prefix = buf;
         iparams.lora.emplace_back(std::move(lora)); // copy to list of loaded adapters
     }

common/common.h

Lines changed: 3 additions & 0 deletions
@@ -34,6 +34,9 @@ struct common_adapter_lora_info {
     std::string path;
     float scale;

+    std::string task_name;
+    std::string prompt_prefix;
+
     struct llama_adapter_lora * ptr;
 };
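
The two new fields carry the adapter metadata read in common/common.cpp above. A minimal sketch of how a caller might consume them, assuming common.h is available; the helper name is hypothetical and not part of this commit:

    #include <string>

    #include "common.h" // assumed include for common_adapter_lora_info

    // Hypothetical helper: jina-embeddings-v3 expects the task-specific
    // instruction (prompt_prefix) to be prepended to the text being embedded.
    static std::string build_embedding_input(const common_adapter_lora_info & la, const std::string & text) {
        return la.prompt_prefix.empty() ? text : la.prompt_prefix + text;
    }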

convert_hf_to_gguf.py

Lines changed: 77 additions & 2 deletions
@@ -72,6 +72,7 @@ class ModelBase:
     endianess: gguf.GGUFEndian
     use_temp_file: bool
     lazy: bool
+    dry_run: bool
     part_names: list[str]
     is_safetensors: bool
     hparams: dict[str, Any]
@@ -111,6 +112,7 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path,
         self.endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE
         self.use_temp_file = use_temp_file
         self.lazy = not eager or (remote_hf_model_id is not None)
+        self.dry_run = dry_run
         self.remote_hf_model_id = remote_hf_model_id
         if remote_hf_model_id is not None:
             self.is_safetensors = True
@@ -4871,11 +4873,35 @@ def modify_tensors(self, data_torch, name, bid):
 @ModelBase.register("XLMRobertaModel", "XLMRobertaForSequenceClassification")
 class XLMRobertaModel(BertModel):
     model_arch = gguf.MODEL_ARCH.BERT
+    _lora_files = {}
+    _lora_names = []

-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
+    def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, **kwargs: Any):
+        hparams = kwargs.pop("hparams", None)
+        if hparams is None:
+            hparams = ModelBase.load_hparams(dir_model, False)
+
+        if lora_names := hparams.get("lora_adaptations"):
+            self._lora_names = lora_names
+            self.model_arch = gguf.MODEL_ARCH.JINA_BERT_V3
+
+        super().__init__(dir_model, ftype, fname_out, hparams=hparams, **kwargs)
         self._xlmroberta_tokenizer_init()

+    def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
+        if self._lora_names:
+            for name in self._lora_names:
+                fname = self.add_prefix_to_filename(self.fname_out, f"lora-{name}-")
+                self._lora_files[name] = gguf.GGUFWriter(fname, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file, dry_run=self.dry_run)
+
+        return super().generate_extra_tensors()
+
+    def set_type(self):
+        for lora_writer in self._lora_files.values():
+            lora_writer.add_type(gguf.GGUFType.ADAPTER)
+            lora_writer.add_string(gguf.Keys.Adapter.TYPE, "lora")
+        super().set_type()
+
     def set_vocab(self):
         self._xlmroberta_set_vocab()

@@ -4885,13 +4911,62 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         if name.startswith("roberta."):
             name = name[8:]

+        # jina-embeddings-v3
+        if ".parametrizations." in name:
+            name = name.replace(".parametrizations.", ".")
+            if name.endswith(".original"):
+                name = name[:-9]
+
         # position embeddings start at pad_token_id + 1, so just chop down the weight tensor
         if name == "embeddings.position_embeddings.weight":
             if self._position_offset is not None:
                 data_torch = data_torch[self._position_offset:,:]

+        if name.endswith(".0.lora_A") or name.endswith(".0.lora_B"):
+            if name.startswith("pooler.dense"):
+                return []
+
+            num_loras = data_torch.size(0)
+            assert num_loras == len(self._lora_names)
+
+            # Split out each LoRA in their own GGUF
+            for i, lora_writer in enumerate(self._lora_files.values()):
+                new_name = self.map_tensor_name(name[:-9]) + name[-7:].lower()
+                data = data_torch[i, :, :]
+                # Transpose/flip token_embd/types into correct shape
+                if new_name == "token_embd.weight.lora_b":
+                    data = data.T
+                elif new_name.startswith("token_types.weight."):
+                    new_name = new_name[:-1] + ("a" if new_name[-1:] == "b" else "b")
+                lora_writer.add_tensor(new_name, data.float().numpy(), raw_dtype=gguf.GGMLQuantizationType.F32)
+
+            return []
+
         return super().modify_tensors(data_torch, name, bid)

+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+
+        # jina-embeddings-v3
+        if rotary_emb_base := self.hparams.get("rotary_emb_base"):
+            self.gguf_writer.add_rope_freq_base(rotary_emb_base)
+        lora_alpha = self.hparams.get("lora_alpha")
+        if lora_prompt_prefixes := self.hparams.get("task_instructions"):
+            assert self._lora_files and all(lora_name in lora_prompt_prefixes for lora_name in self._lora_files.keys())
+        for lora_name, lora_writer in self._lora_files.items():
+            lora_writer.add_float32(gguf.Keys.Adapter.LORA_ALPHA, lora_alpha if lora_alpha is not None else 1.0)
+            lora_writer.add_string(gguf.Keys.Adapter.LORA_TASK_NAME, lora_name)
+            if lora_prompt_prefixes:
+                lora_writer.add_string(gguf.Keys.Adapter.LORA_PROMPT_PREFIX, lora_prompt_prefixes[lora_name])
+
+    def write(self):
+        super().write()
+        for lora_writer in self._lora_files.values():
+            lora_writer.write_header_to_file()
+            lora_writer.write_kv_data_to_file()
+            lora_writer.write_tensors_to_file(progress=True)
+            lora_writer.close()
+

 @ModelBase.register("GemmaForCausalLM")
 class GemmaModel(TextModel):

gguf-py/gguf/constants.py

Lines changed: 18 additions & 2 deletions
@@ -231,8 +231,10 @@ class Tokenizer:
     MIDDLE_ID = "tokenizer.ggml.middle_token_id"

 class Adapter:
-    TYPE       = "adapter.type"
-    LORA_ALPHA = "adapter.lora.alpha"
+    TYPE               = "adapter.type"
+    LORA_ALPHA         = "adapter.lora.alpha"
+    LORA_TASK_NAME     = "adapter.lora.task_name"
+    LORA_PROMPT_PREFIX = "adapter.lora.prompt_prefix"

 class IMatrix:
     CHUNK_COUNT = "imatrix.chunk_count"
@@ -315,6 +317,7 @@ class MODEL_ARCH(IntEnum):
     NOMIC_BERT_MOE = auto()
     NEO_BERT       = auto()
     JINA_BERT_V2   = auto()
+    JINA_BERT_V3   = auto()
     BLOOM          = auto()
     STABLELM       = auto()
     QWEN           = auto()
@@ -647,6 +650,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.NOMIC_BERT_MOE: "nomic-bert-moe",
     MODEL_ARCH.NEO_BERT:       "neo-bert",
     MODEL_ARCH.JINA_BERT_V2:   "jina-bert-v2",
+    MODEL_ARCH.JINA_BERT_V3:   "jina-bert-v3",
     MODEL_ARCH.BLOOM:          "bloom",
     MODEL_ARCH.STABLELM:       "stablelm",
     MODEL_ARCH.QWEN:           "qwen",
@@ -1234,6 +1238,18 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.LAYER_OUT_NORM,
         MODEL_TENSOR.CLS,
     ],
+    MODEL_ARCH.JINA_BERT_V3: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.TOKEN_EMBD_NORM,
+        MODEL_TENSOR.TOKEN_TYPES,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.ATTN_OUT_NORM,
+        MODEL_TENSOR.ATTN_QKV,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.LAYER_OUT_NORM,
+    ],
     MODEL_ARCH.MPT: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,

include/llama.h

Lines changed: 18 additions & 0 deletions
@@ -553,6 +553,24 @@ extern "C" {
             struct llama_model * model,
             const char * path_lora);

+    // Functions to access the adapter's GGUF metadata scalar values
+    // - The functions return the length of the string on success, or -1 on failure
+    // - The output string is always null-terminated and cleared on failure
+    // - When retrieving a string, an extra byte must be allocated to account for the null terminator
+    // - GGUF array values are not supported by these functions
+
+    // Get metadata value as a string by key name
+    LLAMA_API int32_t llama_adapter_meta_val_str(const struct llama_adapter_lora * adapter, const char * key, char * buf, size_t buf_size);
+
+    // Get the number of metadata key/value pairs
+    LLAMA_API int32_t llama_adapter_meta_count(const struct llama_adapter_lora * adapter);
+
+    // Get metadata key name by index
+    LLAMA_API int32_t llama_adapter_meta_key_by_index(const struct llama_adapter_lora * adapter, int32_t i, char * buf, size_t buf_size);
+
+    // Get metadata value as a string by index
+    LLAMA_API int32_t llama_adapter_meta_val_str_by_index(const struct llama_adapter_lora * adapter, int32_t i, char * buf, size_t buf_size);
+
     // Manually free a LoRA adapter
     // Note: loaded adapters will be free when the associated model is deleted
     LLAMA_API void llama_adapter_lora_free(struct llama_adapter_lora * adapter);
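
A usage sketch for the new metadata API, assuming an adapter already loaded via llama_adapter_lora_init; the fixed buffer sizes are arbitrary, and values longer than the buffer are truncated:

    #include <cstdio>

    #include "llama.h"

    // Print every metadata key/value pair of a loaded LoRA adapter.
    static void dump_adapter_meta(const struct llama_adapter_lora * adapter) {
        const int32_t n_kv = llama_adapter_meta_count(adapter);
        for (int32_t i = 0; i < n_kv; i++) {
            char key[128];
            char val[256];
            if (llama_adapter_meta_key_by_index(adapter, i, key, sizeof(key)) < 0) {
                continue; // index out of range; should not happen inside this loop
            }
            llama_adapter_meta_val_str_by_index(adapter, i, val, sizeof(val));
            printf("%s = %s\n", key, val);
        }
    }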

src/llama-adapter.cpp

Lines changed: 68 additions & 4 deletions
@@ -163,13 +163,38 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_

     // check metadata
     {
+        const gguf_context * gguf_ctx = ctx_gguf.get();
+
+        LLAMA_LOG_INFO("%s: Dumping metadata keys/values.\n", __func__);
+
+        // get metadata as string
+        for (int i = 0; i < gguf_get_n_kv(gguf_ctx); i++) {
+            gguf_type type = gguf_get_kv_type(gguf_ctx, i);
+            const std::string type_name =
+                type == GGUF_TYPE_ARRAY
+                ? format("%s[%s,%zu]", gguf_type_name(type), gguf_type_name(gguf_get_arr_type(gguf_ctx, i)), gguf_get_arr_n(gguf_ctx, i))
+                : gguf_type_name(type);
+            const char * name           = gguf_get_key(gguf_ctx, i);
+            const std::string value     = gguf_kv_to_str(gguf_ctx, i);
+
+            if (type != GGUF_TYPE_ARRAY) {
+                adapter.gguf_kv.emplace(name, value);
+            }
+
+            const size_t MAX_VALUE_LEN = 40;
+            std::string print_value = value.size() > MAX_VALUE_LEN ? format("%s...", value.substr(0, MAX_VALUE_LEN - 3).c_str()) : value;
+            replace_all(print_value, "\n", "\\n");
+
+            LLAMA_LOG_INFO("%s: - kv %3d: %42s %-16s = %s\n", __func__, i, name, type_name.c_str(), print_value.c_str());
+        }
+
         auto get_kv_str = [&](const std::string & key) -> std::string {
-            int id = gguf_find_key(ctx_gguf.get(), key.c_str());
-            return id < 0 ? "" : std::string(gguf_get_val_str(ctx_gguf.get(), id));
+            int id = gguf_find_key(gguf_ctx, key.c_str());
+            return id < 0 ? "" : std::string(gguf_get_val_str(gguf_ctx, id));
         };
         auto get_kv_f32 = [&](const std::string & key) -> float {
-            int id = gguf_find_key(ctx_gguf.get(), key.c_str());
-            return id < 0 ? 0.0f : gguf_get_val_f32(ctx_gguf.get(), id);
+            int id = gguf_find_key(gguf_ctx, key.c_str());
+            return id < 0 ? 0.0f : gguf_get_val_f32(gguf_ctx, id);
         };
         LLM_KV llm_kv = LLM_KV(LLM_ARCH_UNKNOWN);

@@ -383,6 +408,45 @@ llama_adapter_lora * llama_adapter_lora_init(llama_model * model, const char * p
     return nullptr;
 }

+int32_t llama_adapter_meta_val_str(const llama_adapter_lora * adapter, const char * key, char * buf, size_t buf_size) {
+    const auto & it = adapter->gguf_kv.find(key);
+    if (it == adapter->gguf_kv.end()) {
+        if (buf_size > 0) {
+            buf[0] = '\0';
+        }
+        return -1;
+    }
+    return snprintf(buf, buf_size, "%s", it->second.c_str());
+}
+
+int32_t llama_adapter_meta_count(const llama_adapter_lora * adapter) {
+    return (int)adapter->gguf_kv.size();
+}
+
+int32_t llama_adapter_meta_key_by_index(const llama_adapter_lora * adapter, int i, char * buf, size_t buf_size) {
+    if (i < 0 || i >= (int)adapter->gguf_kv.size()) {
+        if (buf_size > 0) {
+            buf[0] = '\0';
+        }
+        return -1;
+    }
+    auto it = adapter->gguf_kv.begin();
+    std::advance(it, i);
+    return snprintf(buf, buf_size, "%s", it->first.c_str());
+}
+
+int32_t llama_adapter_meta_val_str_by_index(const llama_adapter_lora * adapter, int32_t i, char * buf, size_t buf_size) {
+    if (i < 0 || i >= (int)adapter->gguf_kv.size()) {
+        if (buf_size > 0) {
+            buf[0] = '\0';
+        }
+        return -1;
+    }
+    auto it = adapter->gguf_kv.begin();
+    std::advance(it, i);
+    return snprintf(buf, buf_size, "%s", it->second.c_str());
+}
+
 void llama_adapter_lora_free(llama_adapter_lora * adapter) {
     delete adapter;
 }
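
Because the accessors return snprintf's result, the return value is the full length of the stored value even when the output buffer is too small. A sketch of a two-pass read built on that; it assumes a null buffer with size 0 is acceptable on the first call, which the success path above forwards straight to snprintf (well-defined for a null buffer of size 0), and the helper name is hypothetical:

    #include <string>
    #include <vector>

    #include "llama.h"

    // Hypothetical helper: fetch a metadata value of arbitrary length.
    static bool adapter_meta_get(const struct llama_adapter_lora * adapter, const char * key, std::string & out) {
        const int32_t len = llama_adapter_meta_val_str(adapter, key, nullptr, 0); // first pass: length only
        if (len < 0) {
            return false; // key not present
        }
        std::vector<char> buf(len + 1); // +1 for the null terminator
        llama_adapter_meta_val_str(adapter, key, buf.data(), buf.size()); // second pass: copy the value
        out.assign(buf.data(), (size_t) len);
        return true;
    }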

src/llama-adapter.h

Lines changed: 3 additions & 0 deletions
@@ -67,6 +67,9 @@ struct llama_adapter_lora {

     float alpha;

+    // gguf metadata
+    std::unordered_map<std::string, std::string> gguf_kv;
+
     llama_adapter_lora() = default;
     ~llama_adapter_lora() = default;

src/llama-arch.cpp

Lines changed: 19 additions & 2 deletions
@@ -22,6 +22,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_NOMIC_BERT_MOE,  "nomic-bert-moe"  },
     { LLM_ARCH_NEO_BERT,        "neo-bert"        },
     { LLM_ARCH_JINA_BERT_V2,    "jina-bert-v2"    },
+    { LLM_ARCH_JINA_BERT_V3,    "jina-bert-v3"    },
     { LLM_ARCH_BLOOM,           "bloom"           },
     { LLM_ARCH_STABLELM,        "stablelm"        },
     { LLM_ARCH_QWEN,            "qwen"            },
@@ -234,8 +235,10 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_TOKENIZER_FIM_REP_ID, "tokenizer.ggml.fim_rep_token_id" },
     { LLM_KV_TOKENIZER_FIM_SEP_ID, "tokenizer.ggml.fim_sep_token_id" },

-    { LLM_KV_ADAPTER_TYPE,       "adapter.type"       },
-    { LLM_KV_ADAPTER_LORA_ALPHA, "adapter.lora.alpha" },
+    { LLM_KV_ADAPTER_TYPE,               "adapter.type"               },
+    { LLM_KV_ADAPTER_LORA_ALPHA,         "adapter.lora.alpha"         },
+    { LLM_KV_ADAPTER_LORA_TASK_NAME,     "adapter.lora.task_name"     },
+    { LLM_KV_ADAPTER_LORA_PROMPT_PREFIX, "adapter.lora.prompt_prefix" },

     // deprecated
     { LLM_KV_TOKENIZER_PREFIX_ID, "tokenizer.ggml.prefix_token_id" },
@@ -575,6 +578,20 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_CLS, "cls" },
         },
     },
+    {
+        LLM_ARCH_JINA_BERT_V3,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
+            { LLM_TENSOR_TOKEN_TYPES,     "token_types" },
+            { LLM_TENSOR_ATTN_OUT_NORM,   "blk.%d.attn_output_norm" },
+            { LLM_TENSOR_ATTN_QKV,        "blk.%d.attn_qkv" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
+            { LLM_TENSOR_LAYER_OUT_NORM,  "blk.%d.layer_output_norm" },
+        },
+    },
     {
         LLM_ARCH_BLOOM,
         {

src/llama-arch.h

Lines changed: 3 additions & 0 deletions
@@ -26,6 +26,7 @@ enum llm_arch {
    LLM_ARCH_NOMIC_BERT_MOE,
    LLM_ARCH_NEO_BERT,
    LLM_ARCH_JINA_BERT_V2,
+   LLM_ARCH_JINA_BERT_V3,
    LLM_ARCH_BLOOM,
    LLM_ARCH_STABLELM,
    LLM_ARCH_QWEN,
@@ -230,6 +231,8 @@ enum llm_kv {

    LLM_KV_ADAPTER_TYPE,
    LLM_KV_ADAPTER_LORA_ALPHA,
+   LLM_KV_ADAPTER_LORA_TASK_NAME,
+   LLM_KV_ADAPTER_LORA_PROMPT_PREFIX,

    LLM_KV_POSNET_EMBEDDING_LENGTH,
    LLM_KV_POSNET_BLOCK_COUNT,
