Skip to content

Commit 116ee89

Browse files
committed
add inference graph
1 parent 9bb46ee commit 116ee89

File tree

3 files changed

+170
-0
lines changed

3 files changed

+170
-0
lines changed

src/llama-arch.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
2020
{ LLM_ARCH_BERT, "bert" },
2121
{ LLM_ARCH_NOMIC_BERT, "nomic-bert" },
2222
{ LLM_ARCH_NOMIC_BERT_MOE, "nomic-bert-moe" },
23+
{ LLM_ARCH_NEO_BERT, "neo-bert" },
2324
{ LLM_ARCH_JINA_BERT_V2, "jina-bert-v2" },
2425
{ LLM_ARCH_BLOOM, "bloom" },
2526
{ LLM_ARCH_STABLELM, "stablelm" },
@@ -494,6 +495,21 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
494495
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
495496
},
496497
},
498+
{
499+
LLM_ARCH_NEO_BERT,
500+
{
501+
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
502+
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
503+
{ LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
504+
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
505+
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
506+
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
507+
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
508+
{ LLM_TENSOR_ENC_OUTPUT_NORM, "enc.output_norm" },
509+
{ LLM_TENSOR_CLS, "cls" },
510+
{ LLM_TENSOR_CLS_OUT, "cls.output" },
511+
},
512+
},
497513
{
498514
LLM_ARCH_JINA_BERT_V2,
499515
{

src/llama-arch.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ enum llm_arch {
2424
LLM_ARCH_BERT,
2525
LLM_ARCH_NOMIC_BERT,
2626
LLM_ARCH_NOMIC_BERT_MOE,
27+
LLM_ARCH_NEO_BERT,
2728
LLM_ARCH_JINA_BERT_V2,
2829
LLM_ARCH_BLOOM,
2930
LLM_ARCH_STABLELM,

src/llama-model.cpp

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -738,6 +738,16 @@ void llama_model::load_hparams(llama_model_loader & ml) {
738738
}
739739
}
740740
} break;
741+
case LLM_ARCH_NEO_BERT:
742+
{
743+
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
744+
ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
745+
ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type);
746+
747+
if (hparams.n_layer == 28) {
748+
type = LLM_TYPE_250M;
749+
}
750+
} break;
741751
case LLM_ARCH_BLOOM:
742752
{
743753
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
@@ -2187,6 +2197,32 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
21872197
layer.layer_out_norm_b = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i), {n_embd}, 0);
21882198
}
21892199
} break;
2200+
case LLM_ARCH_NEO_BERT:
2201+
{
2202+
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
2203+
2204+
cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, TENSOR_NOT_REQUIRED);
2205+
cls_b = create_tensor(tn(LLM_TENSOR_CLS, "bias"), {n_embd}, TENSOR_NOT_REQUIRED);
2206+
2207+
cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED);
2208+
cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"), {hparams.n_cls_out}, TENSOR_NOT_REQUIRED);
2209+
2210+
output_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_OUTPUT_NORM, "weight"), {n_embd}, 0);
2211+
2212+
for (int i = 0; i < n_layer; ++i) {
2213+
auto & layer = layers[i];
2214+
2215+
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
2216+
2217+
layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
2218+
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
2219+
2220+
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
2221+
2222+
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff*2}, 0);
2223+
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
2224+
}
2225+
} break;
21902226
case LLM_ARCH_JINA_BERT_V2:
21912227
{
21922228
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); // word_embeddings
@@ -6074,6 +6110,117 @@ struct llm_build_bert : public llm_graph_context {
60746110
}
60756111
};
60766112

6113+
struct llm_build_neo_bert : public llm_graph_context {
6114+
llm_build_neo_bert(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
6115+
const int64_t n_embd_head = hparams.n_embd_head_v;
6116+
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
6117+
6118+
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
6119+
6120+
ggml_tensor * cur;
6121+
ggml_tensor * inpL;
6122+
ggml_tensor * inp_pos = build_inp_pos();
6123+
6124+
// construct input embeddings (token, type, position)
6125+
inpL = build_inp_embd(model.tok_embd);
6126+
cb(inpL, "inp_embd", -1);
6127+
6128+
auto * inp_attn = build_attn_inp_no_cache();
6129+
6130+
// iterate layers
6131+
for (int il = 0; il < n_layer; ++il) {
6132+
ggml_tensor * cur = inpL;
6133+
6134+
ggml_tensor * Qcur;
6135+
ggml_tensor * Kcur;
6136+
ggml_tensor * Vcur;
6137+
6138+
// pre-norm
6139+
cur = build_norm(inpL,
6140+
model.layers[il].attn_norm, NULL,
6141+
LLM_NORM_RMS, il);
6142+
6143+
// self-attention
6144+
cur = build_lora_mm(model.layers[il].wqkv, cur);
6145+
cb(cur, "wqkv", il);
6146+
6147+
Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
6148+
Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
6149+
Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
6150+
6151+
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
6152+
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
6153+
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
6154+
6155+
// RoPE
6156+
Qcur = ggml_rope_ext(
6157+
ctx0, Qcur, inp_pos, nullptr,
6158+
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
6159+
ext_factor, attn_factor, beta_fast, beta_slow
6160+
);
6161+
6162+
Kcur = ggml_rope_ext(
6163+
ctx0, Kcur, inp_pos, nullptr,
6164+
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
6165+
ext_factor, attn_factor, beta_fast, beta_slow
6166+
);
6167+
6168+
cb(Qcur, "Qcur", il);
6169+
cb(Kcur, "Kcur", il);
6170+
cb(Vcur, "Vcur", il);
6171+
6172+
cur = build_attn(inp_attn, gf,
6173+
model.layers[il].wo, nullptr,
6174+
Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
6175+
cb(cur, "kqv_out", il);
6176+
6177+
if (il == n_layer - 1 && pooling_type == LLAMA_POOLING_TYPE_NONE) {
6178+
// skip computing output for unused tokens
6179+
ggml_tensor * inp_out_ids = build_inp_out_ids();
6180+
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
6181+
inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
6182+
}
6183+
6184+
// re-add the layer input
6185+
cur = ggml_add(ctx0, cur, inpL);
6186+
6187+
ggml_tensor * ffn_inp = cur;
6188+
cb(ffn_inp, "ffn_inp", il);
6189+
6190+
// pre-norm
6191+
cur = build_norm(ffn_inp,
6192+
model.layers[il].ffn_norm, NULL,
6193+
LLM_NORM_RMS, il);
6194+
cb(cur, "ffn_norm", il);
6195+
6196+
// feed-forward network
6197+
cur = build_ffn(cur,
6198+
model.layers[il].ffn_up,
6199+
NULL, NULL, NULL, NULL, NULL,
6200+
model.layers[il].ffn_down,
6201+
NULL, NULL, NULL,
6202+
LLM_FFN_SWIGLU, LLM_FFN_SEQ, il);
6203+
6204+
// attentions bypass the intermediate layer
6205+
cur = ggml_add(ctx0, cur, ffn_inp);
6206+
6207+
// input for next layer
6208+
inpL = cur;
6209+
}
6210+
6211+
cur = inpL;
6212+
6213+
cur = build_norm(cur,
6214+
model.output_norm_enc, NULL,
6215+
LLM_NORM_RMS, -1);
6216+
6217+
cb(cur, "result_embd", -1);
6218+
res->t_embd = cur;
6219+
6220+
ggml_build_forward_expand(gf, cur);
6221+
}
6222+
};
6223+
60776224
struct llm_build_bloom : public llm_graph_context {
60786225
llm_build_bloom(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
60796226
const int64_t n_embd_head = hparams.n_embd_head_v;
@@ -13202,6 +13349,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
1320213349
case LLM_ARCH_JINA_BERT_V2:
1320313350
case LLM_ARCH_NOMIC_BERT:
1320413351
case LLM_ARCH_NOMIC_BERT_MOE:
13352+
case LLM_ARCH_NEO_BERT:
1320513353
case LLM_ARCH_WAVTOKENIZER_DEC:
1320613354
{
1320713355
res = nullptr;
@@ -13310,6 +13458,10 @@ llm_graph_result_ptr llama_model::build_graph(
1331013458
{
1331113459
llm = std::make_unique<llm_build_bert>(*this, params, gf);
1331213460
} break;
13461+
case LLM_ARCH_NEO_BERT:
13462+
{
13463+
llm = std::make_unique<llm_build_neo_bert>(*this, params, gf);
13464+
} break;
1331313465
case LLM_ARCH_BLOOM:
1331413466
{
1331513467
llm = std::make_unique<llm_build_bloom>(*this, params, gf);
@@ -13681,6 +13833,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
1368113833
case LLM_ARCH_GRANITE_MOE:
1368213834
case LLM_ARCH_CHAMELEON:
1368313835
case LLM_ARCH_BAILINGMOE:
13836+
case LLM_ARCH_NEO_BERT:
1368413837
return LLAMA_ROPE_TYPE_NORM;
1368513838

1368613839
// the pairs of head values are offset by n_rot/2

0 commit comments

Comments
 (0)