
Commit 4ac1380

initial jina-embeddings-v3 support
1 parent ba51f89 commit 4ac1380

File tree

4 files changed: +65 -2 lines

    src/llama-arch.cpp
    src/llama-arch.h
    src/llama-model.cpp
    src/llama-model.h

src/llama-arch.cpp

Lines changed: 15 additions & 0 deletions
@@ -21,6 +21,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_NOMIC_BERT,      "nomic-bert"      },
     { LLM_ARCH_NOMIC_BERT_MOE,  "nomic-bert-moe"  },
     { LLM_ARCH_JINA_BERT_V2,    "jina-bert-v2"    },
+    { LLM_ARCH_JINA_BERT_V3,    "jina-bert-v3"    },
     { LLM_ARCH_BLOOM,           "bloom"           },
     { LLM_ARCH_STABLELM,        "stablelm"        },
     { LLM_ARCH_QWEN,            "qwen"            },
@@ -513,6 +514,20 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_CLS,             "cls" },
         },
     },
+    {
+        LLM_ARCH_JINA_BERT_V3,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
+            { LLM_TENSOR_TOKEN_TYPES,     "token_types" },
+            { LLM_TENSOR_ATTN_OUT_NORM,   "blk.%d.attn_output_norm" },
+            { LLM_TENSOR_ATTN_QKV,        "blk.%d.attn_qkv" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
+            { LLM_TENSOR_LAYER_OUT_NORM,  "blk.%d.layer_output_norm" },
+        },
+    },
     {
         LLM_ARCH_BLOOM,
         {
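
The entries in the tensor-name table above are printf-style templates: per-layer tensors carry a "blk.%d" prefix whose "%d" receives the layer index, and the loader appends a ".weight" or ".bias" suffix when it resolves the tensor in the GGUF file. A minimal sketch of that expansion, using a hypothetical make_tensor_name helper rather than the actual name-formatting machinery in llama.cpp:

    #include <cstdio>
    #include <string>

    // Hypothetical helper: expands a template such as "blk.%d.attn_qkv" for layer i
    // and appends the requested suffix, producing names like "blk.3.attn_qkv.weight".
    static std::string make_tensor_name(const std::string & tmpl, const std::string & suffix, int i = -1) {
        char buf[256];
        std::snprintf(buf, sizeof(buf), tmpl.c_str(), i); // substitutes "%d" when present; harmless otherwise
        return std::string(buf) + "." + suffix;
    }

    int main() {
        std::printf("%s\n", make_tensor_name("token_embd",      "weight").c_str());    // token_embd.weight
        std::printf("%s\n", make_tensor_name("blk.%d.attn_qkv", "weight", 3).c_str()); // blk.3.attn_qkv.weight
        std::printf("%s\n", make_tensor_name("blk.%d.attn_qkv", "bias",   3).c_str()); // blk.3.attn_qkv.bias
        return 0;
    }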

src/llama-arch.h

Lines changed: 1 addition & 0 deletions
@@ -25,6 +25,7 @@ enum llm_arch {
     LLM_ARCH_NOMIC_BERT,
     LLM_ARCH_NOMIC_BERT_MOE,
     LLM_ARCH_JINA_BERT_V2,
+    LLM_ARCH_JINA_BERT_V3,
     LLM_ARCH_BLOOM,
     LLM_ARCH_STABLELM,
     LLM_ARCH_QWEN,

src/llama-model.cpp

Lines changed: 48 additions & 2 deletions
@@ -41,6 +41,7 @@ const char * llm_type_name(llm_type type) {
         case LLM_TYPE_410M: return "410M";
         case LLM_TYPE_450M: return "450M";
         case LLM_TYPE_475M: return "475M";
+        case LLM_TYPE_558M: return "558M";
         case LLM_TYPE_770M: return "770M";
         case LLM_TYPE_780M: return "780M";
         case LLM_TYPE_0_5B: return "0.5B";
@@ -710,6 +711,18 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     default: type = LLM_TYPE_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_JINA_BERT_V3:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
+                ml.get_key(LLM_KV_ATTENTION_CAUSAL,        hparams.causal_attn);
+                ml.get_key(LLM_KV_POOLING_TYPE,            hparams.pooling_type, false);
+
+                switch (hparams.n_layer) {
+                    case 24:
+                        type = LLM_TYPE_558M; break;
+                    default: type = LLM_TYPE_UNKNOWN;
+                }
+            } break;
         case LLM_ARCH_NOMIC_BERT:
         case LLM_ARCH_NOMIC_BERT_MOE:
             {
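
The trailing `false` on the LLM_KV_POOLING_TYPE read marks that key as optional, so a GGUF file that does not declare a pooling type still loads, while the layer-norm epsilon and causal-attention keys are required; a 24-layer checkpoint is then reported as the 558M type. A minimal sketch of that required/optional lookup pattern, with a hypothetical metadata map and get_key helper standing in for the real model loader:

    #include <cstdio>
    #include <map>
    #include <stdexcept>
    #include <string>

    // Hypothetical stand-in for GGUF metadata: key -> stringified value.
    using metadata = std::map<std::string, std::string>;

    // Reads a key into `out`. A required key throws when missing; an optional key
    // (required == false) simply leaves `out` at its previous value.
    static bool get_key(const metadata & md, const std::string & key, std::string & out, bool required = true) {
        auto it = md.find(key);
        if (it == md.end()) {
            if (required) {
                throw std::runtime_error("missing required metadata key: " + key);
            }
            return false;
        }
        out = it->second;
        return true;
    }

    int main() {
        // Illustrative key names only; the real key strings are defined by llama.cpp/GGUF.
        metadata md = {
            { "attention.layer_norm_epsilon", "1e-5"  },
            { "attention.causal",             "false" },
        };

        std::string eps, causal, pooling = "unspecified";
        get_key(md, "attention.layer_norm_epsilon", eps);            // required, present
        get_key(md, "attention.causal",             causal);         // required, present
        get_key(md, "pooling_type",                 pooling, false); // optional, absent: keeps "unspecified"
        std::printf("eps=%s causal=%s pooling=%s\n", eps.c_str(), causal.c_str(), pooling.c_str());
        return 0;
    }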
@@ -2215,6 +2228,36 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                     layer.ffn_down   = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
                     layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias",   i), {n_embd}, 0);

+                    layer.layer_out_norm   = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, 0);
+                    layer.layer_out_norm_b = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "bias",   i), {n_embd}, 0);
+                }
+            } break;
+        case LLM_ARCH_JINA_BERT_V3:
+            {
+                tok_embd  = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD,  "weight"), {n_embd, n_vocab}, 0);
+                type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, 0);
+
+                tok_norm   = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0);
+                tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"),   {n_embd}, 0);
+
+                for (int i = 0; i < n_layer; ++i) {
+                    auto & layer = layers[i];
+
+                    layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
+                    layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias",   i), {n_embd + 2*n_embd_gqa}, 0);
+
+                    layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+                    layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias",   i), {n_embd}, 0);
+
+                    layer.attn_out_norm   = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0);
+                    layer.attn_out_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "bias",   i), {n_embd}, 0);
+
+                    layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
+                    layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias",   i), {n_ff}, 0);
+
+                    layer.ffn_down   = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
+                    layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias",   i), {n_embd}, 0);
+
                     layer.layer_out_norm   = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, 0);
                     layer.layer_out_norm_b = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "bias",   i), {n_embd}, 0);
                 }
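
The fused QKV weight is created with shape {n_embd, n_embd + 2*n_embd_gqa}: its output packs one query projection of width n_embd plus key and value projections of width n_embd_gqa each. For a plain multi-head attention model with no grouped-query sharing, n_embd_gqa equals n_embd and the fused projection is simply n_embd x 3*n_embd. A short sketch of that arithmetic, using assumed example dimensions (hidden size 1024, 16 heads) chosen only for illustration:

    #include <cassert>
    #include <cstdio>

    int main() {
        // Assumed example dimensions, not read from the model's GGUF metadata.
        const int n_embd      = 1024;             // hidden size
        const int n_head      = 16;               // query heads
        const int n_head_kv   = 16;               // key/value heads (== n_head without GQA)
        const int n_embd_head = n_embd / n_head;  // per-head width: 64
        const int n_embd_gqa  = n_embd_head * n_head_kv;

        // Fused QKV output width, matching the tensor shape above: Q + K + V.
        const int n_qkv = n_embd + 2*n_embd_gqa;
        std::printf("Q: %d  K: %d  V: %d  fused: %d\n", n_embd, n_embd_gqa, n_embd_gqa, n_qkv);

        assert(n_qkv == 3*n_embd); // holds whenever n_head_kv == n_head
        return 0;
    }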
@@ -5931,7 +5974,7 @@ struct llm_build_bert : public llm_graph_context {
                 cur = build_lora_mm(model.layers[il].wqkv, cur);
                 cb(cur, "wqkv", il);

-                if (model.arch == LLM_ARCH_NOMIC_BERT_MOE) {
+                if (model.arch == LLM_ARCH_NOMIC_BERT_MOE || model.arch == LLM_ARCH_JINA_BERT_V3) {
                     cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
                     cb(cur, "bqkv", il);
                 }
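
jina-bert-v3 reuses the shared llm_build_bert graph, so the change here is only to include it in the branch that adds a bias after the fused QKV matmul (previously taken only by nomic-bert-moe). Conceptually the step is y = Wqkv * x, optionally followed by y += bqkv, after which y is sliced into the Q, K and V parts whose widths match the tensor shapes above. A simplified single-token sketch of that guarded bias add and split on flat vectors, not the ggml graph code:

    #include <vector>

    // One token's fused QKV activation, laid out as q | k | v.
    struct qkv_split { std::vector<float> q, k, v; };

    // Adds the bias only when the architecture provides one (as in the branch above),
    // then slices the fused vector into Q (n_embd), K and V (n_embd_gqa each).
    static qkv_split split_qkv(std::vector<float> fused, const std::vector<float> * bias,
                               int n_embd, int n_embd_gqa) {
        if (bias) {
            for (size_t i = 0; i < fused.size(); ++i) {
                fused[i] += (*bias)[i];
            }
        }
        qkv_split out;
        out.q.assign(fused.begin(),                       fused.begin() + n_embd);
        out.k.assign(fused.begin() + n_embd,              fused.begin() + n_embd + n_embd_gqa);
        out.v.assign(fused.begin() + n_embd + n_embd_gqa, fused.end());
        return out;
    }

    int main() {
        const int n_embd = 4, n_embd_gqa = 4;                  // toy sizes
        std::vector<float> fused(n_embd + 2*n_embd_gqa, 1.0f);
        std::vector<float> bias (n_embd + 2*n_embd_gqa, 0.5f);
        qkv_split s = split_qkv(fused, &bias, n_embd, n_embd_gqa);
        return (int) s.q.size() == n_embd ? 0 : 1;
    }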
@@ -6003,7 +6046,7 @@ struct llm_build_bert : public llm_graph_context {
                         0.0f,
                         LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il);
                 cb(cur, "ffn_moe_out", il);
-            } else if (model.arch == LLM_ARCH_BERT || model.arch == LLM_ARCH_NOMIC_BERT_MOE) {
+            } else if (model.arch == LLM_ARCH_BERT || model.arch == LLM_ARCH_NOMIC_BERT_MOE || model.arch == LLM_ARCH_JINA_BERT_V3) {
                 cur = build_ffn(cur,
                         model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
                         NULL,                      NULL,                        NULL,
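
The second graph change routes jina-bert-v3 through the same dense FFN branch as bert and nomic-bert-moe, using the per-layer ffn_up/ffn_down weights and biases loaded earlier. In a generic BERT-style block that path is an up-projection with bias, a nonlinearity (commonly GELU in BERT-family models; the actual activation is chosen inside build_ffn and is not visible in this hunk), and a down-projection with bias. A schematic single-token sketch, not the ggml implementation:

    #include <cmath>
    #include <vector>

    // BERT-style feed-forward for one token vector x (size n_embd):
    //   y = W_down * act(W_up * x + b_up) + b_down
    // Weights are row-major: w_up is n_ff x n_embd, w_down is n_embd x n_ff.
    static std::vector<float> ffn(const std::vector<float> & x,
                                  const std::vector<float> & w_up,   const std::vector<float> & b_up,
                                  const std::vector<float> & w_down, const std::vector<float> & b_down,
                                  int n_embd, int n_ff) {
        std::vector<float> h(n_ff), y(n_embd);
        for (int i = 0; i < n_ff; ++i) {
            float acc = b_up[i];
            for (int j = 0; j < n_embd; ++j) acc += w_up[i*n_embd + j] * x[j];
            // GELU, tanh approximation; assumed here for illustration only
            h[i] = 0.5f*acc*(1.0f + std::tanh(0.7978845608f*(acc + 0.044715f*acc*acc*acc)));
        }
        for (int i = 0; i < n_embd; ++i) {
            float acc = b_down[i];
            for (int j = 0; j < n_ff; ++j) acc += w_down[i*n_ff + j] * h[j];
            y[i] = acc;
        }
        return y;
    }

    int main() {
        const int n_embd = 2, n_ff = 4;
        std::vector<float> x(n_embd, 1.0f);
        std::vector<float> w_up(n_ff*n_embd, 0.1f), b_up(n_ff, 0.0f);
        std::vector<float> w_down(n_embd*n_ff, 0.1f), b_down(n_embd, 0.0f);
        return (int) ffn(x, w_up, b_up, w_down, b_down, n_embd, n_ff).size() == n_embd ? 0 : 1;
    }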
@@ -13187,6 +13230,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
     switch (arch) {
         case LLM_ARCH_BERT:
         case LLM_ARCH_JINA_BERT_V2:
+        case LLM_ARCH_JINA_BERT_V3:
         case LLM_ARCH_NOMIC_BERT:
         case LLM_ARCH_NOMIC_BERT_MOE:
             {
@@ -13292,6 +13336,7 @@ llm_graph_result_ptr llama_model::build_graph(
             } break;
         case LLM_ARCH_BERT:
         case LLM_ARCH_JINA_BERT_V2:
+        case LLM_ARCH_JINA_BERT_V3:
         case LLM_ARCH_NOMIC_BERT:
         case LLM_ARCH_NOMIC_BERT_MOE:
             {
@@ -13658,6 +13703,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_GROK:
         case LLM_ARCH_DBRX:
         case LLM_ARCH_BERT:
+        case LLM_ARCH_JINA_BERT_V3:
         case LLM_ARCH_NOMIC_BERT:
         case LLM_ARCH_NOMIC_BERT_MOE:
         case LLM_ARCH_STABLELM:
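
The remaining hunks simply add LLM_ARCH_JINA_BERT_V3 to the case groups already used by the other BERT-family encoders for memory creation, graph construction, and rope-type selection, so it inherits their non-causal, embedding-oriented handling. The causal_attn hyperparameter read earlier drives the attention mask: a bidirectional encoder lets every token attend to every other token in the sequence, whereas causal masking hides future positions. A toy sketch of that difference (an illustration only, unrelated to ggml's actual mask representation):

    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <vector>

    // Builds an n x n additive attention mask: 0.0f where attention is allowed,
    // -INFINITY where it is blocked. causal == false is the BERT-style setting.
    static std::vector<float> build_mask(int n, bool causal) {
        std::vector<float> mask(n*n, 0.0f);
        if (causal) {
            for (int i = 0; i < n; ++i) {
                for (int j = i + 1; j < n; ++j) {
                    mask[i*n + j] = -INFINITY; // hide future tokens
                }
            }
        }
        return mask;
    }

    int main() {
        auto bidir = build_mask(4, /*causal=*/false); // all zeros
        auto caus  = build_mask(4, /*causal=*/true);  // strict upper triangle blocked
        std::printf("blocked: bidirectional=%d causal=%d\n",
                    (int) std::count(bidir.begin(), bidir.end(), -INFINITY),
                    (int) std::count(caus.begin(),  caus.end(),  -INFINITY));
        return 0;
    }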

src/llama-model.h

Lines changed: 1 addition & 0 deletions
@@ -37,6 +37,7 @@ enum llm_type {
     LLM_TYPE_410M,
     LLM_TYPE_450M,
     LLM_TYPE_475M,
+    LLM_TYPE_558M,
     LLM_TYPE_770M,
     LLM_TYPE_780M,
     LLM_TYPE_0_5B,
