@@ -10024,7 +10024,7 @@ struct llm_build_starcoder2 : public llm_graph_context {
     }
 };
 
-struct llm_graph_context_mamba : public llm_graph_context {
+struct llm_graph_context_mamba : public virtual llm_graph_context {
     llm_graph_context_mamba(const llm_graph_params & params) : llm_graph_context(params) {}
 
     ggml_tensor * build_mamba_layer(
@@ -10298,7 +10298,8 @@ struct llm_graph_context_mamba : public llm_graph_context {
 };
 
 struct llm_build_mamba : public llm_graph_context_mamba {
-    llm_build_mamba(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context_mamba(params) {
+    llm_build_mamba(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf)
+        : llm_graph_context(params), llm_graph_context_mamba(params) {
         ggml_tensor * cur;
         ggml_tensor * inpL;
 
@@ -10355,7 +10356,8 @@ struct llm_build_mamba : public llm_graph_context_mamba {
 };
 
 struct llm_build_jamba : public llm_graph_context_mamba {
-    llm_build_jamba(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context_mamba(params) {
+    llm_build_jamba(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf)
+        : llm_graph_context(params), llm_graph_context_mamba(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
 
         ggml_tensor * cur;
@@ -13794,81 +13796,10 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base {
     }
 };
 
-struct llm_build_granite : public llm_graph_context {
-    llm_build_granite(
-        const llama_model & model,
-        const llm_graph_params & params,
-        ggml_cgraph * gf,
-        const bool use_rope = true)
-        : llm_graph_context(params) {
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-        GGML_ASSERT(n_embd_head == hparams.n_rot);
-
-        ggml_tensor * cur;
-        ggml_tensor * inpL;
-
-        inpL = build_inp_embd(model.tok_embd);
-
-        // inp_pos - built only if rope enabled
-        ggml_tensor * inp_pos = nullptr;
-        if (use_rope) {
-            inp_pos = build_inp_pos();
-        }
-
-        auto * inp_attn = build_attn_inp_kv_unified();
-
-        ggml_tensor * inp_out_ids = build_inp_out_ids();
-
-        for (int il = 0; il < n_layer; ++il) {
-            ggml_tensor * inpSA = inpL;
+struct llm_graph_context_granite : public virtual llm_graph_context {
+    llm_graph_context_granite(const llm_graph_params & params) : llm_graph_context(params) {}
 
-            // norm
-            cur = build_norm(inpL,
-                    model.layers[il].attn_norm, NULL,
-                    LLM_NORM_RMS, il);
-            cb(cur, "attn_norm", il);
-
-            // self-attention
-            cur = build_granite_attention_layer(
-                gf, cur, inp_pos, inp_attn,
-                model, n_embd_head, use_rope, il);
-
-            if (il == n_layer - 1 && inp_out_ids) {
-                cur = ggml_get_rows(ctx0, cur, inp_out_ids);
-                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
-            }
-
-            // ffn
-            cur = build_layer_ffn(cur, inpSA, model, il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = inpL;
-
-        cur = build_norm(cur,
-                model.output_norm, NULL,
-                LLM_NORM_RMS, -1);
-
-        cb(cur, "result_norm", -1);
-        res->t_embd = cur;
-
-        // lm_head
-        cur = build_lora_mm(model.output, cur);
-
-        // For Granite architectures - scale logits
-        cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale);
-        cb(cur, "result_output", -1);
-        res->t_logits = cur;
-
-        ggml_build_forward_expand(gf, cur);
-    }
-
-    ggml_tensor * build_granite_attention_layer(
+    ggml_tensor * build_attention_layer(
             ggml_cgraph * gf,
             ggml_tensor * cur,
             ggml_tensor * inp_pos,
@@ -14011,14 +13942,91 @@ struct llm_build_granite : public llm_graph_context {
     }
 };
 
-struct llm_build_granite_hybrid : public llm_graph_context_mamba {
+struct llm_build_granite : public llm_graph_context_granite {
+    llm_build_granite(
+        const llama_model & model,
+        const llm_graph_params & params,
+        ggml_cgraph * gf,
+        const bool use_rope = true)
+        : llm_graph_context(params), llm_graph_context_granite(params) {
+
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - built only if rope enabled
+        ggml_tensor * inp_pos = nullptr;
+        if (use_rope) {
+            inp_pos = build_inp_pos();
+        }
+
+        auto * inp_attn = build_attn_inp_kv_unified();
+
+        ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            cur = build_attention_layer(
+                gf, cur, inp_pos, inp_attn,
+                model, n_embd_head, use_rope, il);
+
+            if (il == n_layer - 1 && inp_out_ids) {
+                cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            // ffn
+            cur = build_layer_ffn(cur, inpSA, model, il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        // For Granite architectures - scale logits
+        cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale);
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_granite_hybrid : public llm_graph_context_mamba, public llm_graph_context_granite {
 
     llm_build_granite_hybrid(
         const llama_model & model,
         const llm_graph_params & params,
         ggml_cgraph * gf,
         const bool use_rope = true) :
-        llm_graph_context_mamba(params) {
+        llm_graph_context(params),
+        llm_graph_context_mamba(params),
+        llm_graph_context_granite(params) {
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -14056,7 +14064,7 @@ struct llm_build_granite_hybrid : public llm_graph_context_mamba {
             cur = build_mamba2_layer(inp_rs, gf, cur, model, ubatch, il);
         } else {
             // attention layer //
-            cur = build_granite_attention_layer (
+            cur = build_attention_layer (
                 gf, cur, inp_pos, inp_attn, model,
                 n_embd_head, use_rope, il);
         }
@@ -14094,148 +14102,6 @@ struct llm_build_granite_hybrid : public llm_graph_context_mamba {
 
         ggml_build_forward_expand(gf, cur);
     }
-
-    ggml_tensor * build_granite_attention_layer(
-            ggml_cgraph * gf,
-            ggml_tensor * cur,
-            ggml_tensor * inp_pos,
-            llm_graph_input_attn_kv_unified * inp,
-            const llama_model & model,
-            const int64_t n_embd_head,
-            const bool use_rope,
-            const int il) {
-
-        // compute Q and K and (optionally) RoPE them
-        ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
-        cb(Qcur, "Qcur", il);
-        if (model.layers[il].bq) {
-            Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
-            cb(Qcur, "Qcur", il);
-        }
-
-        ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
-        cb(Kcur, "Kcur", il);
-        if (model.layers[il].bk) {
-            Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
-            cb(Kcur, "Kcur", il);
-        }
-
-        ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
-        cb(Vcur, "Vcur", il);
-        if (model.layers[il].bv) {
-            Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
-            cb(Vcur, "Vcur", il);
-        }
-
-        Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, hparams.n_head(il), n_tokens);
-        Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, hparams.n_head_kv(il), n_tokens);
-        Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, hparams.n_head_kv(il), n_tokens);
-
-        if (use_rope) {
-            ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
-            Qcur = ggml_rope_ext(
-                ctx0, Qcur, inp_pos, rope_factors,
-                n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                ext_factor, attn_factor, beta_fast, beta_slow
-            );
-
-            Kcur = ggml_rope_ext(
-                ctx0, Kcur, inp_pos, rope_factors,
-                n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                ext_factor, attn_factor, beta_fast, beta_slow
-            );
-        }
-
-        cb(Qcur, "Qcur", il);
-        cb(Kcur, "Kcur", il);
-        cb(Vcur, "Vcur", il);
-
-        const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
-        cur = build_attn(inp, gf,
-                model.layers[il].wo, model.layers[il].bo,
-                Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
-        cb(cur, "attn_out", il);
-        return cur;
-    }
-
-    ggml_tensor * build_layer_ffn(
-            ggml_tensor * cur,
-            ggml_tensor * inpSA,
-            const llama_model & model,
-            const int il) {
-
-        // For Granite architectures - scale residual
-        if (hparams.f_residual_scale) {
-            cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
-        }
-        ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
-        cb(ffn_inp, "ffn_inp", il);
-
-        // feed-forward network (non-MoE)
-        if (model.layers[il].ffn_gate_inp == nullptr) {
-
-            cur = build_norm(ffn_inp,
-                    model.layers[il].ffn_norm, NULL,
-                    LLM_NORM_RMS, il);
-            cb(cur, "ffn_norm", il);
-
-            cur = build_ffn(cur,
-                    model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL,
-                    model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
-                    model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
-                    NULL,
-                    LLM_FFN_SILU, LLM_FFN_PAR, il);
-            cb(cur, "ffn_out", il);
-
-        } else {
-            // MoE branch
-            cur = build_norm(ffn_inp,
-                    model.layers[il].ffn_norm, NULL,
-                    LLM_NORM_RMS, il);
-            cb(cur, "ffn_norm", il);
-
-            ggml_tensor * moe_out = build_moe_ffn(cur,
-                    model.layers[il].ffn_gate_inp,
-                    model.layers[il].ffn_up_exps,
-                    model.layers[il].ffn_gate_exps,
-                    model.layers[il].ffn_down_exps,
-                    nullptr,
-                    n_expert, n_expert_used,
-                    LLM_FFN_SILU, true,
-                    false, 0.0,
-                    LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
-                    il);
-            cb(moe_out, "ffn_moe_out", il);
-
-            // For Granite MoE Shared
-            if (hparams.n_ff_shexp > 0) {
-                ggml_tensor * ffn_shexp = build_ffn(cur,
-                        model.layers[il].ffn_up_shexp, NULL, NULL,
-                        model.layers[il].ffn_gate_shexp, NULL, NULL,
-                        model.layers[il].ffn_down_shexp, NULL, NULL,
-                        NULL,
-                        LLM_FFN_SILU, LLM_FFN_PAR, il);
-                cb(ffn_shexp, "ffn_shexp", il);
-
-                cur = ggml_add(ctx0, moe_out, ffn_shexp);
-                cb(cur, "ffn_out", il);
-            } else {
-                cur = moe_out;
-            }
-        }
-
-        // For Granite architectures - scale residual
-        if (hparams.f_residual_scale) {
-            cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
-        }
-        cur = ggml_add(ctx0, cur, ffn_inp);
-        cb(cur, "ffn_out", il);
-
-        cur = build_cvec(cur, il);
-        cb(cur, "l_out", il);
-
-        return cur;
-    }
 };
 
 // ref: https://github.com/facebookresearch/chameleon
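
The refactor above hinges on C++ virtual inheritance: `llm_graph_context_mamba` and `llm_graph_context_granite` both derive from `llm_graph_context` with `public virtual`, so `llm_build_granite_hybrid` can inherit from both mixins while containing a single shared `llm_graph_context` subobject. A consequence of virtual bases is that the most-derived class constructs the virtual base itself, which is why each builder's initializer list now names `llm_graph_context(params)` explicitly. The sketch below is a minimal, standalone illustration of that rule; the names (Params, GraphContext, MambaMixin, GraniteMixin, GraniteHybrid) are illustrative stand-ins, not llama.cpp types.

#include <cstdio>

// Stand-in for llm_graph_params / llm_graph_context.
struct Params { int n_layer; };

struct GraphContext {
    int n_layer;
    explicit GraphContext(const Params & p) : n_layer(p.n_layer) {}
};

// Both mixins inherit virtually, so a class deriving from both still holds
// exactly one GraphContext subobject (no diamond duplication).
struct MambaMixin : public virtual GraphContext {
    explicit MambaMixin(const Params & p) : GraphContext(p) {}
    void build_mamba_layer() const { std::printf("mamba layer, n_layer=%d\n", n_layer); }
};

struct GraniteMixin : public virtual GraphContext {
    explicit GraniteMixin(const Params & p) : GraphContext(p) {}
    void build_attention_layer() const { std::printf("attn layer, n_layer=%d\n", n_layer); }
};

// The most-derived class must initialize the virtual base itself; the
// GraphContext(p) calls inside the mixin constructors are skipped here.
struct GraniteHybrid : public MambaMixin, public GraniteMixin {
    explicit GraniteHybrid(const Params & p)
        : GraphContext(p), MambaMixin(p), GraniteMixin(p) {}
};

int main() {
    GraniteHybrid h(Params{4});
    h.build_mamba_layer();      // both mixins see the same, single base
    h.build_attention_layer();
    return 0;
}

The same rule explains the smaller constructor changes in the diff: once `llm_graph_context` is a virtual base, even the single-parent builders (`llm_build_mamba`, `llm_build_jamba`, `llm_build_granite`) must name `llm_graph_context(params)` in their initializer lists, because their mixin's own call to the base constructor no longer runs when they are the most-derived type.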