@@ -8947,10 +8947,10 @@ struct llm_build_mamba : public llm_graph_context {
                     LLM_NORM_RMS, il);
             cb(cur, "attn_norm", il);
 
-            cur = build_mamba2_layer(gf, cur, state_copy, ubatch, il);
             if (use_mamba2) {
+                cur = build_mamba2_layer(this, gf, cur, state_copy, model, ubatch, il);
             } else {
-                cur = build_mamba_layer(gf, cur, state_copy, ubatch, il);
+                cur = build_mamba_layer(this, gf, cur, state_copy, model, ubatch, il);
             }
 
             if (il == n_layer - 1) {
@@ -8987,27 +8987,32 @@ struct llm_build_mamba : public llm_graph_context {
         ggml_build_forward_expand(gf, cur);
     }
 
-    ggml_tensor * build_mamba_layer(
-            ggml_cgraph * gf,
-            ggml_tensor * cur,
-            ggml_tensor * state_copy,
-            const llama_ubatch & ubatch,
-            int il) const {
-        const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state_i *>(mstate);
+    // static layer build function that enables other models to borrow this
+    // layer logic
+    static ggml_tensor * build_mamba_layer(
+            const llm_graph_context * self,
+            ggml_cgraph * gf,
+            ggml_tensor * cur,
+            ggml_tensor * state_copy,
+            const llama_model & model,
+            const llama_ubatch & ubatch,
+            int il) {
+        const auto * kv_state = self->get_state_recurrent();
 
         const auto kv_head = kv_state->get_head();
+        auto * ctx0 = self->ctx0;
 
-        const int64_t d_conv  = hparams.ssm_d_conv;
-        const int64_t d_inner = hparams.ssm_d_inner;
-        const int64_t d_state = hparams.ssm_d_state;
-        const int64_t dt_rank = hparams.ssm_dt_rank;
+        const int64_t d_conv  = self->hparams.ssm_d_conv;
+        const int64_t d_inner = self->hparams.ssm_d_inner;
+        const int64_t d_state = self->hparams.ssm_d_state;
+        const int64_t dt_rank = self->hparams.ssm_dt_rank;
         const int64_t n_head  = d_inner;
         const int64_t head_dim = 1;
         const int64_t n_seqs  = ubatch.n_seqs;
         // Some variants of Mamba arch (e.g. FalconMamba do apply layer norm on B and Dt layers)
-        const bool ssm_dt_b_c_rms = hparams.ssm_dt_b_c_rms;
+        const bool ssm_dt_b_c_rms = self->hparams.ssm_dt_b_c_rms;
         // Use the same RMS norm as the final layer norm
-        const float norm_rms_eps = hparams.f_norm_rms_eps;
+        const float norm_rms_eps = self->hparams.f_norm_rms_eps;
 
         const int64_t n_seq_tokens = ubatch.n_seq_tokens;
 
@@ -9019,16 +9024,16 @@ struct llm_build_mamba : public llm_graph_context {
         ggml_tensor * ssm_states_all = kv_state->get_v_l(il);
 
         // (ab)using the KV cache to store the states
-        ggml_tensor * conv = build_recurrent_state(gf, conv_states_all, state_copy, hparams.n_embd_k_s(), n_seqs);
+        ggml_tensor * conv = self->build_recurrent_state(gf, conv_states_all, state_copy, self->hparams.n_embd_k_s(), n_seqs);
         conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner, n_seqs);
-        ggml_tensor * ssm = build_recurrent_state(gf, ssm_states_all, state_copy, hparams.n_embd_v_s(), n_seqs, true);
+        ggml_tensor * ssm = self->build_recurrent_state(gf, ssm_states_all, state_copy, self->hparams.n_embd_v_s(), n_seqs, true);
         ssm = ggml_reshape_4d(ctx0, ssm, d_state, head_dim, n_head, kv_state->get_size());
 
         // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
         cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs);
 
         // {n_embd, 2*d_inner} @ {n_embd, n_seq_tokens, n_seqs} => {2*d_inner, n_seq_tokens, n_seqs}
-        ggml_tensor * xz = build_lora_mm(model.layers[il].ssm_in, cur);
+        ggml_tensor * xz = self->build_lora_mm(model.layers[il].ssm_in, cur);
         // split the above in two
         // => {d_inner, n_seq_tokens, n_seqs}
         ggml_tensor * x = ggml_view_3d(ctx0, xz, d_inner, xz->ne[1], xz->ne[2], xz->nb[1], xz->nb[2], 0);
@@ -9068,7 +9073,7 @@ struct llm_build_mamba : public llm_graph_context {
         // ssm
         {
             // {d_inner, dt_rank + 2*d_state} @ {d_inner, n_seq_tokens, n_seqs} => {dt_rank + 2*d_state, n_seq_tokens, n_seqs}
-            ggml_tensor * x_db = build_lora_mm(model.layers[il].ssm_x, x);
+            ggml_tensor * x_db = self->build_lora_mm(model.layers[il].ssm_x, x);
             // split
             ggml_tensor * dt = ggml_view_3d(ctx0, x_db, dt_rank, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], 0);
             ggml_tensor * B = ggml_view_4d(ctx0, x_db, d_state, /* n_group */ 1, n_seq_tokens, n_seqs, d_state*x_db->nb[0], x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*dt_rank);
@@ -9082,7 +9087,7 @@ struct llm_build_mamba : public llm_graph_context {
             }
 
             // {dt_rank, d_inner} @ {dt_rank, n_seq_tokens, n_seqs} => {d_inner, n_seq_tokens, n_seqs}
-            dt = build_lora_mm(model.layers[il].ssm_dt, dt);
+            dt = self->build_lora_mm(model.layers[il].ssm_dt, dt);
             dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b);
 
             cur = x;
@@ -9108,7 +9113,7 @@ struct llm_build_mamba : public llm_graph_context {
             y = ggml_mul(ctx0, y, ggml_silu(ctx0, ggml_cont(ctx0, z)));
 
             // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs}
-            cur = build_lora_mm(model.layers[il].ssm_out, y);
+            cur = self->build_lora_mm(model.layers[il].ssm_out, y);
         }
 
         // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens}
@@ -9118,22 +9123,27 @@ struct llm_build_mamba : public llm_graph_context {
         return cur;
     }
 
-    ggml_tensor * build_mamba2_layer(
-            ggml_cgraph * gf,
-            ggml_tensor * cur,
-            ggml_tensor * state_copy,
-            const llama_ubatch & ubatch,
-            int il) const {
-        const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state_i *>(mstate);
+    // static layer build function that enables other models to borrow this
+    // layer logic
+    static ggml_tensor * build_mamba2_layer(
+            const llm_graph_context * self,
+            ggml_cgraph * gf,
+            ggml_tensor * cur,
+            ggml_tensor * state_copy,
+            const llama_model & model,
+            const llama_ubatch & ubatch,
+            int il) {
+        const auto * kv_state = self->get_state_recurrent();
 
         const auto kv_head = kv_state->get_head();
+        auto * ctx0 = self->ctx0;
 
-        const int64_t d_conv  = hparams.ssm_d_conv;
-        const int64_t d_inner = hparams.ssm_d_inner;
-        const int64_t d_state = hparams.ssm_d_state;
-        const int64_t n_head  = hparams.ssm_dt_rank;
+        const int64_t d_conv  = self->hparams.ssm_d_conv;
+        const int64_t d_inner = self->hparams.ssm_d_inner;
+        const int64_t d_state = self->hparams.ssm_d_state;
+        const int64_t n_head  = self->hparams.ssm_dt_rank;
         const int64_t head_dim = d_inner / n_head;
-        const int64_t n_group = hparams.ssm_n_group;
+        const int64_t n_group = self->hparams.ssm_n_group;
         const int64_t n_seqs  = ubatch.n_seqs;
 
         const int64_t n_seq_tokens = ubatch.n_seq_tokens;
@@ -9146,9 +9156,9 @@ struct llm_build_mamba : public llm_graph_context {
         ggml_tensor * ssm_states_all = kv_state->get_v_l(il);
 
         // (ab)using the KV cache to store the states
-        ggml_tensor * conv = build_recurrent_state(gf, conv_states_all, state_copy, hparams.n_embd_k_s(), n_seqs);
+        ggml_tensor * conv = self->build_recurrent_state(gf, conv_states_all, state_copy, self->hparams.n_embd_k_s(), n_seqs);
         conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs);
-        ggml_tensor * ssm = build_recurrent_state(gf, ssm_states_all, state_copy, hparams.n_embd_v_s(), n_seqs, true);
+        ggml_tensor * ssm = self->build_recurrent_state(gf, ssm_states_all, state_copy, self->hparams.n_embd_v_s(), n_seqs, true);
         ssm = ggml_reshape_4d(ctx0, ssm, d_state, head_dim, n_head, kv_state->get_size());
 
         // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
@@ -9157,7 +9167,7 @@ struct llm_build_mamba : public llm_graph_context {
         // d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
 
         // {n_embd, d_in_proj} @ {n_embd, n_seq_tokens, n_seqs} => {d_in_proj, n_seq_tokens, n_seqs}
-        ggml_tensor * zxBCdt = build_lora_mm(model.layers[il].ssm_in, cur);
+        ggml_tensor * zxBCdt = self->build_lora_mm(model.layers[il].ssm_in, cur);
 
         // split the above in three
         ggml_tensor * z = ggml_view_4d(ctx0, zxBCdt, head_dim, n_head, n_seq_tokens, n_seqs, head_dim*zxBCdt->nb[0], zxBCdt->nb[1], zxBCdt->nb[2], 0);
@@ -9225,11 +9235,11 @@ struct llm_build_mamba : public llm_graph_context {
 
             // grouped RMS norm
             y = ggml_reshape_4d(ctx0, y, d_inner / n_group, n_group, n_seq_tokens, n_seqs);
-            y = build_norm(y, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il);
+            y = self->build_norm(y, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il);
            y = ggml_reshape_3d(ctx0, y, d_inner, n_seq_tokens, n_seqs);
 
             // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs}
-            cur = build_lora_mm(model.layers[il].ssm_out, y);
+            cur = self->build_lora_mm(model.layers[il].ssm_out, y);
         }
 
         // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens}
@@ -12562,8 +12572,8 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base {
     }
 };
 
-
 struct llm_build_granite : public llm_graph_context {
+
     llm_build_granite(
         const llama_model & model,
         const llm_graph_params & params,
@@ -12589,7 +12599,6 @@ struct llm_build_granite : public llm_graph_context {
 
         auto * inp_attn = build_attn_inp_kv_unified();
 
-        const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
         for (int il = 0; il < n_layer; ++il) {
             ggml_tensor * inpSA = inpL;
 
@@ -12600,57 +12609,9 @@ struct llm_build_granite : public llm_graph_context {
             cb(cur, "attn_norm", il);
 
             // self-attention
-            {
-                // compute Q and K and (optionally) RoPE them
-                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
-                if (model.layers[il].bq) {
-                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
-                    cb(Qcur, "Qcur", il);
-                }
-
-                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-                if (model.layers[il].bk) {
-                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
-                    cb(Kcur, "Kcur", il);
-                }
-
-                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
-                if (model.layers[il].bv) {
-                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
-                    cb(Vcur, "Vcur", il);
-                }
-
-                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
-                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
-                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
-
-                if (use_rope) {
-                    ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
-                    Qcur = ggml_rope_ext(
-                        ctx0, Qcur, inp_pos, rope_factors,
-                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                        ext_factor, attn_factor, beta_fast, beta_slow
-                    );
-
-                    Kcur = ggml_rope_ext(
-                        ctx0, Kcur, inp_pos, rope_factors,
-                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                        ext_factor, attn_factor, beta_fast, beta_slow
-                    );
-                }
-
-                cb(Qcur, "Qcur", il);
-                cb(Kcur, "Kcur", il);
-                cb(Vcur, "Vcur", il);
-
-                cur = build_attn(inp_attn, gf,
-                        model.layers[il].wo, model.layers[il].bo,
-                        Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
-                cb(cur, "attn_out", il);
-            }
+            cur = build_attention_layer(
+                this, gf, cur, inp_pos, inp_attn,
+                model, n_embd_head, use_rope, il);
 
             if (il == n_layer - 1) {
                 // skip computing output for unused tokens
@@ -12748,6 +12709,74 @@ struct llm_build_granite : public llm_graph_context {
 
         ggml_build_forward_expand(gf, cur);
     }
+
+    // static layer build function that enables other models to borrow this
+    // layer logic
+    static ggml_tensor * build_attention_layer(
+            const llm_graph_context * self,
+            ggml_cgraph * gf,
+            ggml_tensor * cur,
+            ggml_tensor * inp_pos,
+            llm_graph_input_attn_kv_unified * inp_attn,
+            const llama_model & model,
+            const int64_t n_embd_head,
+            const bool use_rope,
+            const int il) {
+
+        auto * ctx0 = self->ctx0;
+
+        // compute Q and K and (optionally) RoPE them
+        ggml_tensor * Qcur = self->build_lora_mm(model.layers[il].wq, cur);
+        self->cb(Qcur, "Qcur", il);
+        if (model.layers[il].bq) {
+            Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+            self->cb(Qcur, "Qcur", il);
+        }
+
+        ggml_tensor * Kcur = self->build_lora_mm(model.layers[il].wk, cur);
+        self->cb(Kcur, "Kcur", il);
+        if (model.layers[il].bk) {
+            Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+            self->cb(Kcur, "Kcur", il);
+        }
+
+        ggml_tensor * Vcur = self->build_lora_mm(model.layers[il].wv, cur);
+        self->cb(Vcur, "Vcur", il);
+        if (model.layers[il].bv) {
+            Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+            self->cb(Vcur, "Vcur", il);
+        }
+
+        Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, self->n_head, self->n_tokens);
+        Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, self->n_head_kv, self->n_tokens);
+        Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, self->n_head_kv, self->n_tokens);
+
+        if (use_rope) {
+            ggml_tensor * rope_factors = model.get_rope_factors(self->cparams, il);
+            Qcur = ggml_rope_ext(
+                ctx0, Qcur, inp_pos, rope_factors,
+                self->n_rot, self->rope_type, self->n_ctx_orig, self->freq_base, self->freq_scale,
+                self->ext_factor, self->attn_factor, self->beta_fast, self->beta_slow
+            );
+
+            Kcur = ggml_rope_ext(
+                ctx0, Kcur, inp_pos, rope_factors,
+                self->n_rot, self->rope_type, self->n_ctx_orig, self->freq_base, self->freq_scale,
+                self->ext_factor, self->attn_factor, self->beta_fast, self->beta_slow
+            );
+        }
+
+        self->cb(Qcur, "Qcur", il);
+        self->cb(Kcur, "Kcur", il);
+        self->cb(Vcur, "Vcur", il);
+
+        const float kq_scale = self->hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : self->hparams.f_attention_scale;
+        cur = self->build_attn(inp_attn, gf,
+                model.layers[il].wo, model.layers[il].bo,
+                Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
+        self->cb(cur, "attn_out", il);
+        return cur;
+    }
 };
 
 // ref: https://github.com/facebookresearch/chameleon
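The same borrowing pattern applies to the attention layer factored out of `llm_build_granite` above. A minimal sketch, assuming a hypothetical hybrid builder derived from `llm_graph_context` (names and surrounding locals such as `inpL` and `model` invented for illustration, set up as in the other builders):

// Hypothetical caller (illustration only, not in this diff): reuse the
// Granite attention layer from another builder without duplicating the
// Q/K/V projection and RoPE logic.
const int64_t n_embd_head = hparams.n_embd_head_v;        // assumed, as in other builders
ggml_tensor * inp_pos  = build_inp_pos();                 // positions for RoPE
auto        * inp_attn = build_attn_inp_kv_unified();     // unified KV-cache attention input

for (int il = 0; il < n_layer; ++il) {
    ggml_tensor * cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
    cb(cur, "attn_norm", il);

    cur = llm_build_granite::build_attention_layer(
        this, gf, cur, inp_pos, inp_attn,
        model, n_embd_head, /* use_rope */ true, il);

    // residual add, FFN, etc. would follow as in llm_build_granite above
}

Because `kq_scale` is now computed inside `build_attention_layer` from `hparams.f_attention_scale`, callers do not need to recompute it, which is why the hoisted local in the constructor was removed.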