Commit 22d35ac

[Model] Refactor the model arch into llama-model
1 parent 444dfe5 commit 22d35ac

File tree

2 files changed: +181 -9333 lines changed

src/llama-model.cpp

Lines changed: 180 additions & 0 deletions
@@ -10907,6 +10907,182 @@ struct llm_build_wavtokenizer_dec : public llm_graph_context {
     }
 };
 
+struct llm_build_plm : public llm_graph_context {
+    llm_build_plm(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const float kq_scale = 1.0f/sqrtf(float(hparams.n_embd_head_k));
+
+        const uint32_t n_embd_head_qk_rope = hparams.n_rot;
+        const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
+        const uint32_t kv_lora_rank = hparams.n_lora_kv;
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        // {n_embd, n_tokens}
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self_attention
+            {
+                ggml_tensor * q = NULL;
+                q = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+                cb(q, "q", il);
+
+                // split into {n_head * n_embd_head_qk_nope, n_tokens}
+                ggml_tensor * q_nope = ggml_view_3d(ctx0, q, n_embd_head_qk_nope, n_head, n_tokens,
+                        ggml_row_size(q->type, hparams.n_embd_head_k),
+                        ggml_row_size(q->type, hparams.n_embd_head_k * n_head),
+                        0);
+                cb(q_nope, "q_nope", il);
+
+                // and {n_head * n_embd_head_qk_rope, n_tokens}
+                ggml_tensor * q_pe = ggml_view_3d(ctx0, q, n_embd_head_qk_rope, n_head, n_tokens,
+                        ggml_row_size(q->type, hparams.n_embd_head_k),
+                        ggml_row_size(q->type, hparams.n_embd_head_k * n_head),
+                        ggml_row_size(q->type, n_embd_head_qk_nope));
+                cb(q_pe, "q_pe", il);
+
+                // {n_embd, kv_lora_rank + n_embd_head_qk_rope} * {n_embd, n_tokens} -> {kv_lora_rank + n_embd_head_qk_rope, n_tokens}
+                ggml_tensor * kv_pe_compresseed = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur);
+                cb(kv_pe_compresseed, "kv_pe_compresseed", il);
+
+                // split into {kv_lora_rank, n_tokens}
+                ggml_tensor * kv_compressed = ggml_view_2d(ctx0, kv_pe_compresseed, kv_lora_rank, n_tokens,
+                        kv_pe_compresseed->nb[1],
+                        0);
+                cb(kv_compressed, "kv_compressed", il);
+
+                // and {n_embd_head_qk_rope, n_tokens}
+                ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_pe_compresseed, n_embd_head_qk_rope, 1, n_tokens,
+                        kv_pe_compresseed->nb[1],
+                        kv_pe_compresseed->nb[1],
+                        ggml_row_size(kv_pe_compresseed->type, kv_lora_rank));
+                cb(k_pe, "k_pe", il);
+
+                // TODO: the CUDA backend used to not support non-cont. (RMS) norm, investigate removing ggml_cont
+                kv_compressed = ggml_cont(ctx0, kv_compressed);
+                kv_compressed = build_norm(kv_compressed,
+                        model.layers[il].attn_kv_a_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(kv_compressed, "kv_compressed", il);
+
+                // {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)} * {kv_lora_rank, n_tokens} -> {n_head * (n_embd_head_qk_nope + n_embd_head_v), n_tokens}
+                ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_compressed);
+                cb(kv, "kv", il);
+
+                // split into {n_head * n_embd_head_qk_nope, n_tokens}
+                ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens,
+                        ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v),
+                        ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
+                        0);
+                cb(k_nope, "k_nope", il);
+
+                // and {n_head * n_embd_head_v, n_tokens}
+                ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens,
+                        ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)),
+                        ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head),
+                        ggml_row_size(kv->type, (n_embd_head_qk_nope)));
+                cb(v_states, "v_states", il);
+
+                v_states = ggml_cont(ctx0, v_states);
+                cb(v_states, "v_states", il);
+
+                v_states = ggml_view_2d(ctx0, v_states, hparams.n_embd_head_v * n_head, n_tokens,
+                        ggml_row_size(kv->type, hparams.n_embd_head_v * n_head),
+                        0);
+                cb(v_states, "v_states", il);
+
+                q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend used to not support non-cont. RoPE, investigate removing this
+                q_pe = ggml_rope_ext(
+                        ctx0, q_pe, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(q_pe, "q_pe", il);
+
+                // shared RoPE key
+                k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend used to not support non-cont. RoPE, investigate removing this
+                k_pe = ggml_rope_ext(
+                        ctx0, k_pe, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(k_pe, "k_pe", il);
+
+                ggml_tensor * q_states = ggml_concat(ctx0, q_nope, q_pe, 0);
+                cb(q_states, "q_states", il);
+
+                ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0);
+                cb(k_states, "k_states", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, NULL,
+                        q_states, k_states, v_states, nullptr, kq_scale, il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            cur = build_norm(ffn_inp,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "ffn_norm", il);
+
+            cur = build_ffn(cur,
+                    model.layers[il].ffn_up, NULL, NULL,
+                    NULL, NULL, NULL,
+                    model.layers[il].ffn_down, NULL, NULL,
+                    NULL,
+                    LLM_FFN_RELU_SQR, LLM_FFN_SEQ, il);
+            cb(cur, "ffn_out", il);
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
 llama_memory_i * llama_model::create_memory() const {
     llama_memory_i * res;
 

@@ -11168,6 +11344,10 @@ llm_graph_result_ptr llama_model::build_graph(
             {
                 llm = std::make_unique<llm_build_wavtokenizer_dec>(*this, params, gf);
             } break;
+        case LLM_ARCH_PLM:
+            {
+                llm = std::make_unique<llm_build_plm>(*this, params, gf);
+            } break;
         default:
             GGML_ABORT("fatal error");
     }
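Not part of the commit: a minimal standalone sketch of the stride arithmetic behind the q_nope / q_pe views in llm_build_plm above. Each head's query row is laid out as [ nope | rope ], so the two ggml_view_3d calls share the same byte strides and differ only in the starting offset. All dimension values below are hypothetical placeholders, not the actual PLM hyperparameters.

// Hedged sketch of the view/offset arithmetic; dimensions are made up.
#include <cstdio>
#include <cstdint>

int main() {
    const uint32_t n_embd_head_k       = 192;            // assumed per-head q/k width
    const uint32_t n_embd_head_qk_rope = 64;             // assumed hparams.n_rot
    const uint32_t n_embd_head_qk_nope = n_embd_head_k - n_embd_head_qk_rope;
    const uint32_t n_head              = 16;             // assumed head count
    const size_t   elem_size           = sizeof(float);  // assuming F32 activations

    // q is stored as {n_embd_head_k * n_head, n_tokens}; both views reuse the
    // same strides (nb1, nb2) and differ only in the starting byte offset,
    // which is the last argument passed to ggml_view_3d in the diff.
    const size_t nb1 = elem_size * n_embd_head_k;          // byte stride head -> head
    const size_t nb2 = elem_size * n_embd_head_k * n_head; // byte stride token -> token

    const size_t q_nope_offset = 0;
    const size_t q_pe_offset   = elem_size * n_embd_head_qk_nope;

    std::printf("nb1=%zu nb2=%zu q_nope@%zu q_pe@%zu\n",
                nb1, nb2, q_nope_offset, q_pe_offset);
    return 0;
}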

0 commit comments
