
Commit f864a1c

feat: Isolate mamba(2) and granite attention layer building in static methods
This will allow these layer-builder methods to be used from other build structs without complex inheritance.

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <[email protected]>
1 parent 6fec9e3 commit f864a1c
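The idea, in isolation: instead of making other builders inherit from `llm_build_mamba` or `llm_build_granite` just to reach their layer helpers, the helpers become `static` and take the graph context as an explicit `self` argument. A self-contained toy illustration of the pattern (all names here are made up for illustration; this is not code from the diff):

```cpp
#include <iostream>

// Stand-in for llm_graph_context: shared state every builder already has.
struct graph_context {
    int n_layer = 4;
};

// Builder A owns some layer logic, exposed as a static helper that takes
// the context explicitly instead of relying on the implicit `this`.
struct builder_a : graph_context {
    static int build_layer(const graph_context * self, int il) {
        return self->n_layer * 100 + il; // reads context state through `self`
    }
};

// Builder B borrows that logic without inheriting from builder_a.
struct builder_b : graph_context {
    void build() {
        for (int il = 0; il < n_layer; ++il) {
            std::cout << builder_a::build_layer(this, il) << "\n";
        }
    }
};

int main() {
    builder_b{}.build();
    return 0;
}
```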

File tree: 1 file changed, +120 / -91 lines

src/llama-model.cpp

Lines changed: 120 additions & 91 deletions
@@ -8947,10 +8947,10 @@ struct llm_build_mamba : public llm_graph_context {
                     LLM_NORM_RMS, il);
             cb(cur, "attn_norm", il);

-            cur = build_mamba2_layer(gf, cur, state_copy, ubatch, il);
             if (use_mamba2) {
+                cur = build_mamba2_layer(this, gf, cur, state_copy, model, ubatch, il);
             } else {
-                cur = build_mamba_layer(gf, cur, state_copy, ubatch, il);
+                cur = build_mamba_layer(this, gf, cur, state_copy, model, ubatch, il);
             }

             if (il == n_layer - 1) {
@@ -8987,27 +8987,32 @@ struct llm_build_mamba : public llm_graph_context {
         ggml_build_forward_expand(gf, cur);
     }

-    ggml_tensor * build_mamba_layer(
-            ggml_cgraph * gf,
-            ggml_tensor * cur,
-            ggml_tensor * state_copy,
-            const llama_ubatch & ubatch,
-            int il) const {
-        const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state_i *>(mstate);
+    // static layer build function that enables other models to borrow this
+    // layer logic
+    static ggml_tensor * build_mamba_layer(
+            const llm_graph_context * self,
+            ggml_cgraph * gf,
+            ggml_tensor * cur,
+            ggml_tensor * state_copy,
+            const llama_model & model,
+            const llama_ubatch & ubatch,
+            int il) {
+        const auto * kv_state = self->get_state_recurrent();

         const auto kv_head = kv_state->get_head();
+        auto * ctx0 = self->ctx0;

-        const int64_t d_conv = hparams.ssm_d_conv;
-        const int64_t d_inner = hparams.ssm_d_inner;
-        const int64_t d_state = hparams.ssm_d_state;
-        const int64_t dt_rank = hparams.ssm_dt_rank;
+        const int64_t d_conv = self->hparams.ssm_d_conv;
+        const int64_t d_inner = self->hparams.ssm_d_inner;
+        const int64_t d_state = self->hparams.ssm_d_state;
+        const int64_t dt_rank = self->hparams.ssm_dt_rank;
         const int64_t n_head = d_inner;
         const int64_t head_dim = 1;
         const int64_t n_seqs = ubatch.n_seqs;
         // Some variants of Mamba arch (e.g. FalconMamba do apply layer norm on B and Dt layers)
-        const bool ssm_dt_b_c_rms = hparams.ssm_dt_b_c_rms;
+        const bool ssm_dt_b_c_rms = self->hparams.ssm_dt_b_c_rms;
         // Use the same RMS norm as the final layer norm
-        const float norm_rms_eps = hparams.f_norm_rms_eps;
+        const float norm_rms_eps = self->hparams.f_norm_rms_eps;

         const int64_t n_seq_tokens = ubatch.n_seq_tokens;

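For a caller outside `llm_build_mamba`, the change in call shape is the qualified static call with an explicit context; the same applies to `build_mamba2_layer` further down. A hedged caller-side sketch (assumes the surrounding builder derives from `llm_graph_context` and already has `gf`, `cur`, `state_copy`, `model`, `ubatch`, and `il` in scope):

```cpp
// before: only callable from inside llm_build_mamba, via the implicit `this`
//   cur = build_mamba_layer(gf, cur, state_copy, ubatch, il);
// after: callable from any llm_graph_context subclass
cur = llm_build_mamba::build_mamba_layer(this, gf, cur, state_copy, model, ubatch, il);
```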
@@ -9019,16 +9024,16 @@ struct llm_build_mamba : public llm_graph_context {
         ggml_tensor * ssm_states_all = kv_state->get_v_l(il);

         // (ab)using the KV cache to store the states
-        ggml_tensor * conv = build_recurrent_state(gf, conv_states_all, state_copy, hparams.n_embd_k_s(), n_seqs);
+        ggml_tensor * conv = self->build_recurrent_state(gf, conv_states_all, state_copy, self->hparams.n_embd_k_s(), n_seqs);
         conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner, n_seqs);
-        ggml_tensor * ssm = build_recurrent_state(gf, ssm_states_all, state_copy, hparams.n_embd_v_s(), n_seqs, true);
+        ggml_tensor * ssm = self->build_recurrent_state(gf, ssm_states_all, state_copy, self->hparams.n_embd_v_s(), n_seqs, true);
         ssm = ggml_reshape_4d(ctx0, ssm, d_state, head_dim, n_head, kv_state->get_size());

         // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
         cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs);

         // {n_embd, 2*d_inner} @ {n_embd, n_seq_tokens, n_seqs} => {2*d_inner, n_seq_tokens, n_seqs}
-        ggml_tensor * xz = build_lora_mm(model.layers[il].ssm_in, cur);
+        ggml_tensor * xz = self->build_lora_mm(model.layers[il].ssm_in, cur);
         // split the above in two
         // => {d_inner, n_seq_tokens, n_seqs}
         ggml_tensor * x = ggml_view_3d(ctx0, xz, d_inner, xz->ne[1], xz->ne[2], xz->nb[1], xz->nb[2], 0);
@@ -9068,7 +9073,7 @@ struct llm_build_mamba : public llm_graph_context {
         // ssm
         {
             // {d_inner, dt_rank + 2*d_state} @ {d_inner, n_seq_tokens, n_seqs} => {dt_rank + 2*d_state, n_seq_tokens, n_seqs}
-            ggml_tensor * x_db = build_lora_mm(model.layers[il].ssm_x, x);
+            ggml_tensor * x_db = self->build_lora_mm(model.layers[il].ssm_x, x);
             // split
             ggml_tensor * dt = ggml_view_3d(ctx0, x_db, dt_rank, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], 0);
             ggml_tensor * B = ggml_view_4d(ctx0, x_db, d_state, /* n_group */ 1, n_seq_tokens, n_seqs, d_state*x_db->nb[0], x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*dt_rank);
@@ -9082,7 +9087,7 @@ struct llm_build_mamba : public llm_graph_context {
             }

             // {dt_rank, d_inner} @ {dt_rank, n_seq_tokens, n_seqs} => {d_inner, n_seq_tokens, n_seqs}
-            dt = build_lora_mm(model.layers[il].ssm_dt, dt);
+            dt = self->build_lora_mm(model.layers[il].ssm_dt, dt);
             dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b);

             cur = x;
@@ -9108,7 +9113,7 @@ struct llm_build_mamba : public llm_graph_context {
             y = ggml_mul(ctx0, y, ggml_silu(ctx0, ggml_cont(ctx0, z)));

             // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs}
-            cur = build_lora_mm(model.layers[il].ssm_out, y);
+            cur = self->build_lora_mm(model.layers[il].ssm_out, y);
         }

         // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens}
@@ -9118,22 +9123,27 @@ struct llm_build_mamba : public llm_graph_context {
         return cur;
     }

-    ggml_tensor * build_mamba2_layer(
-            ggml_cgraph * gf,
-            ggml_tensor * cur,
-            ggml_tensor * state_copy,
-            const llama_ubatch & ubatch,
-            int il) const {
-        const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state_i *>(mstate);
+    // static layer build function that enables other models to borrow this
+    // layer logic
+    static ggml_tensor * build_mamba2_layer(
+            const llm_graph_context * self,
+            ggml_cgraph * gf,
+            ggml_tensor * cur,
+            ggml_tensor * state_copy,
+            const llama_model & model,
+            const llama_ubatch & ubatch,
+            int il) {
+        const auto * kv_state = self->get_state_recurrent();

         const auto kv_head = kv_state->get_head();
+        auto * ctx0 = self->ctx0;

-        const int64_t d_conv = hparams.ssm_d_conv;
-        const int64_t d_inner = hparams.ssm_d_inner;
-        const int64_t d_state = hparams.ssm_d_state;
-        const int64_t n_head = hparams.ssm_dt_rank;
+        const int64_t d_conv = self->hparams.ssm_d_conv;
+        const int64_t d_inner = self->hparams.ssm_d_inner;
+        const int64_t d_state = self->hparams.ssm_d_state;
+        const int64_t n_head = self->hparams.ssm_dt_rank;
         const int64_t head_dim = d_inner / n_head;
-        const int64_t n_group = hparams.ssm_n_group;
+        const int64_t n_group = self->hparams.ssm_n_group;
         const int64_t n_seqs = ubatch.n_seqs;

         const int64_t n_seq_tokens = ubatch.n_seq_tokens;
@@ -9146,9 +9156,9 @@ struct llm_build_mamba : public llm_graph_context {
         ggml_tensor * ssm_states_all = kv_state->get_v_l(il);

         // (ab)using the KV cache to store the states
-        ggml_tensor * conv = build_recurrent_state(gf, conv_states_all, state_copy, hparams.n_embd_k_s(), n_seqs);
+        ggml_tensor * conv = self->build_recurrent_state(gf, conv_states_all, state_copy, self->hparams.n_embd_k_s(), n_seqs);
         conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs);
-        ggml_tensor * ssm = build_recurrent_state(gf, ssm_states_all, state_copy, hparams.n_embd_v_s(), n_seqs, true);
+        ggml_tensor * ssm = self->build_recurrent_state(gf, ssm_states_all, state_copy, self->hparams.n_embd_v_s(), n_seqs, true);
         ssm = ggml_reshape_4d(ctx0, ssm, d_state, head_dim, n_head, kv_state->get_size());

         // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
@@ -9157,7 +9167,7 @@ struct llm_build_mamba : public llm_graph_context {
         // d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads

         // {n_embd, d_in_proj} @ {n_embd, n_seq_tokens, n_seqs} => {d_in_proj, n_seq_tokens, n_seqs}
-        ggml_tensor * zxBCdt = build_lora_mm(model.layers[il].ssm_in, cur);
+        ggml_tensor * zxBCdt = self->build_lora_mm(model.layers[il].ssm_in, cur);

         // split the above in three
         ggml_tensor * z = ggml_view_4d(ctx0, zxBCdt, head_dim, n_head, n_seq_tokens, n_seqs, head_dim*zxBCdt->nb[0], zxBCdt->nb[1], zxBCdt->nb[2], 0);
@@ -9225,11 +9235,11 @@ struct llm_build_mamba : public llm_graph_context {

             // grouped RMS norm
             y = ggml_reshape_4d(ctx0, y, d_inner / n_group, n_group, n_seq_tokens, n_seqs);
-            y = build_norm(y, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il);
+            y = self->build_norm(y, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il);
             y = ggml_reshape_3d(ctx0, y, d_inner, n_seq_tokens, n_seqs);

             // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs}
-            cur = build_lora_mm(model.layers[il].ssm_out, y);
+            cur = self->build_lora_mm(model.layers[il].ssm_out, y);
         }

         // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens}
@@ -12562,8 +12572,8 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base {
     }
 };

-
 struct llm_build_granite : public llm_graph_context {
+
     llm_build_granite(
         const llama_model & model,
         const llm_graph_params & params,
@@ -12589,7 +12599,6 @@ struct llm_build_granite : public llm_graph_context {

         auto * inp_attn = build_attn_inp_kv_unified();

-        const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
         for (int il = 0; il < n_layer; ++il) {
             ggml_tensor * inpSA = inpL;

@@ -12600,57 +12609,9 @@ struct llm_build_granite : public llm_graph_context {
             cb(cur, "attn_norm", il);

             // self-attention
-            {
-                // compute Q and K and (optionally) RoPE them
-                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
-                if (model.layers[il].bq) {
-                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
-                    cb(Qcur, "Qcur", il);
-                }
-
-                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-                if (model.layers[il].bk) {
-                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
-                    cb(Kcur, "Kcur", il);
-                }
-
-                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
-                if (model.layers[il].bv) {
-                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
-                    cb(Vcur, "Vcur", il);
-                }
-
-                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
-                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
-                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
-
-                if (use_rope) {
-                    ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
-                    Qcur = ggml_rope_ext(
-                        ctx0, Qcur, inp_pos, rope_factors,
-                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                        ext_factor, attn_factor, beta_fast, beta_slow
-                    );
-
-                    Kcur = ggml_rope_ext(
-                        ctx0, Kcur, inp_pos, rope_factors,
-                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                        ext_factor, attn_factor, beta_fast, beta_slow
-                    );
-                }
-
-                cb(Qcur, "Qcur", il);
-                cb(Kcur, "Kcur", il);
-                cb(Vcur, "Vcur", il);
-
-                cur = build_attn(inp_attn, gf,
-                        model.layers[il].wo, model.layers[il].bo,
-                        Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
-                cb(cur, "attn_out", il);
-            }
+            cur = build_attention_layer(
+                this, gf, cur, inp_pos, inp_attn,
+                model, n_embd_head, use_rope, il);

             if (il == n_layer - 1) {
                 // skip computing output for unused tokens
@@ -12748,6 +12709,74 @@ struct llm_build_granite : public llm_graph_context {

         ggml_build_forward_expand(gf, cur);
     }
+
+    // static layer build function that enables other models to borrow this
+    // layer logic
+    static ggml_tensor * build_attention_layer(
+            const llm_graph_context * self,
+            ggml_cgraph * gf,
+            ggml_tensor * cur,
+            ggml_tensor * inp_pos,
+            llm_graph_input_attn_kv_unified * inp_attn,
+            const llama_model & model,
+            const int64_t n_embd_head,
+            const bool use_rope,
+            const int il) {
+
+        auto * ctx0 = self->ctx0;
+
+        // compute Q and K and (optionally) RoPE them
+        ggml_tensor * Qcur = self->build_lora_mm(model.layers[il].wq, cur);
+        self->cb(Qcur, "Qcur", il);
+        if (model.layers[il].bq) {
+            Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+            self->cb(Qcur, "Qcur", il);
+        }
+
+        ggml_tensor * Kcur = self->build_lora_mm(model.layers[il].wk, cur);
+        self->cb(Kcur, "Kcur", il);
+        if (model.layers[il].bk) {
+            Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+            self->cb(Kcur, "Kcur", il);
+        }
+
+        ggml_tensor * Vcur = self->build_lora_mm(model.layers[il].wv, cur);
+        self->cb(Vcur, "Vcur", il);
+        if (model.layers[il].bv) {
+            Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+            self->cb(Vcur, "Vcur", il);
+        }
+
+        Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, self->n_head, self->n_tokens);
+        Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, self->n_head_kv, self->n_tokens);
+        Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, self->n_head_kv, self->n_tokens);
+
+        if (use_rope) {
+            ggml_tensor * rope_factors = model.get_rope_factors(self->cparams, il);
+            Qcur = ggml_rope_ext(
+                ctx0, Qcur, inp_pos, rope_factors,
+                self->n_rot, self->rope_type, self->n_ctx_orig, self->freq_base, self->freq_scale,
+                self->ext_factor, self->attn_factor, self->beta_fast, self->beta_slow
+            );
+
+            Kcur = ggml_rope_ext(
+                ctx0, Kcur, inp_pos, rope_factors,
+                self->n_rot, self->rope_type, self->n_ctx_orig, self->freq_base, self->freq_scale,
+                self->ext_factor, self->attn_factor, self->beta_fast, self->beta_slow
+            );
+        }
+
+        self->cb(Qcur, "Qcur", il);
+        self->cb(Kcur, "Kcur", il);
+        self->cb(Vcur, "Vcur", il);
+
+        const float kq_scale = self->hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : self->hparams.f_attention_scale;
+        cur = self->build_attn(inp_attn, gf,
+                model.layers[il].wo, model.layers[il].bo,
+                Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
+        self->cb(cur, "attn_out", il);
+        return cur;
+    }
 };

 // ref: https://github.com/facebookresearch/chameleon
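Put together, the three static builders let a hybrid architecture mix attention and mamba2 layers without multiple inheritance, which appears to be the point of the GraniteFour branch. A rough sketch of such a caller (hypothetical; the per-layer predicate and the surrounding builder are placeholders, not part of this commit):

```cpp
// Hypothetical hybrid layer loop, borrowing both static builders.
// layer_is_recurrent(il) is a placeholder for however the hybrid model
// decides which block type a layer uses; it does not exist in this commit.
for (int il = 0; il < n_layer; ++il) {
    if (layer_is_recurrent(il)) {
        cur = llm_build_mamba::build_mamba2_layer(
                this, gf, cur, state_copy, model, ubatch, il);
    } else {
        cur = llm_build_granite::build_attention_layer(
                this, gf, cur, inp_pos, inp_attn,
                model, n_embd_head, /*use_rope=*/true, il);
    }
}
```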
