
Commit 6ffa2d3

feat: First (broken) pass at end-to-end Bamba implementation
It generates (garbage) tokens! Still lots of debugging to do.

Branch: GraniteFour
Signed-off-by: Gabe Goodhart <[email protected]>

1 parent 6958888 · commit 6ffa2d3

File tree

1 file changed (+282, -0 lines)

src/llama-model.cpp

Lines changed: 282 additions & 0 deletions
@@ -1435,6 +1435,49 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 // For Granite MoE Shared
                 ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, /* required */ false);
             } break;
+        case LLM_ARCH_BAMBA:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale, /* required */ false);
+                ml.get_key(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale, /* required */ false);
+                ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale, /* required */ false);
+                ml.get_key(LLM_KV_ATTENTION_SCALE, hparams.f_attention_scale, /* required */ false);
+
+                ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv);
+                ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner);
+                ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state);
+                ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank);
+                ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group);
+
+                // Zero-out n_head_arr and n_head_kv_arr since SSM layers don't
+                // have attention heads. We'll set them correctly below once we
+                // know which layers are attention layers
+                // NOTE: It's important that this happens after n_embd_head_[kv]
+                // are set above!
+                const auto n_head_attn    = hparams.n_head();
+                const auto n_head_kv_attn = hparams.n_head_kv();
+                std::fill(hparams.n_head_arr.begin(), hparams.n_head_arr.end(), 0);
+                std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0);
+
+                // Attention params
+                std::fill(hparams.recurrent_layer_arr.begin(), hparams.recurrent_layer_arr.end(), true);
+                std::vector<uint32_t> attn_layer_indices;
+                ml.get_arr(LLM_KV_ATTENTION_LAYER_INDICES, attn_layer_indices);
+                for (const auto attn_idx : attn_layer_indices) {
+                    GGML_ASSERT(attn_idx < hparams.n_layer);
+                    hparams.recurrent_layer_arr[attn_idx] = false;
+                    // Correctly set n_head and n_head_kv for attention layers
+                    hparams.n_head_arr[attn_idx]    = n_head_attn;
+                    hparams.n_head_kv_arr[attn_idx] = n_head_kv_attn;
+                }
+
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+
+                switch (hparams.n_layer) {
+                    // TODO: Add llm type label (not sure this is useful)
+                    default: type = LLM_TYPE_UNKNOWN;
+                }
+            } break;
         case LLM_ARCH_CHAMELEON:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
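
The hparams hunk above drives the hybrid layout from a single metadata array: every layer starts out recurrent with zero attention heads, and only the indices listed under LLM_KV_ATTENTION_LAYER_INDICES are flipped to attention layers and given head counts. Below is a minimal standalone sketch of that classification pattern, not part of the commit; the layer count, head counts, and attention indices are made-up illustrative values.

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    const uint32_t n_layer        = 9;   // hypothetical layer count
    const uint32_t n_head_attn    = 32;  // heads assigned to attention layers
    const uint32_t n_head_kv_attn = 8;   // KV heads assigned to attention layers

    // start from "all recurrent, no heads", as the BAMBA branch does
    std::vector<bool>     recurrent(n_layer, true);
    std::vector<uint32_t> n_head(n_layer, 0);
    std::vector<uint32_t> n_head_kv(n_layer, 0);

    // stand-in for the indices read from LLM_KV_ATTENTION_LAYER_INDICES
    const std::vector<uint32_t> attn_layer_indices = {3, 6};
    for (const auto idx : attn_layer_indices) {
        recurrent[idx] = false;
        n_head[idx]    = n_head_attn;
        n_head_kv[idx] = n_head_kv_attn;
    }

    for (uint32_t il = 0; il < n_layer; ++il) {
        std::printf("layer %u: %s (n_head=%u, n_head_kv=%u)\n",
                    il, recurrent[il] ? "recurrent" : "attention", n_head[il], n_head_kv[il]);
    }
    return 0;
}
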
@@ -3049,6 +3092,83 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0);
                     }
                 } break;
+            case LLM_ARCH_BAMBA:
+                {
+                    // mamba2 Mixer SSM params
+                    // NOTE: int64_t for tensor dimensions
+                    const int64_t d_conv     = hparams.ssm_d_conv;
+                    const int64_t d_inner    = hparams.ssm_d_inner;
+                    const int64_t d_state    = hparams.ssm_d_state;
+                    const int64_t n_ssm_head = hparams.ssm_dt_rank;
+                    const int64_t n_group    = hparams.ssm_n_group;
+                    const int64_t d_in_proj  = 2*d_inner + 2*n_group*d_state + n_ssm_head;
+
+                    // only an expansion factor of 2 is supported for now
+                    GGML_ASSERT(2 * n_embd == d_inner);
+
+                    // embeddings
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    {
+                        output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                        output      = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        // if output is NULL, init from the input tok embed, duplicated to allow offloading
+                        if (output == NULL) {
+                            output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
+                        }
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        // norm
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                        if (hparams.recurrent_layer(i)) {
+                            // ssm layers
+                            layer.ssm_in   = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, d_in_proj}, 0);
+                            layer.ssm_in_b = create_tensor(tn(LLM_TENSOR_SSM_IN, "bias", i), {n_embd, d_in_proj}, llama_model_loader::TENSOR_NOT_REQUIRED);
+
+                            layer.ssm_conv1d   = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner + 2*n_group*d_state}, 0);
+                            layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner + 2*n_group*d_state}, llama_model_loader::TENSOR_NOT_REQUIRED);
+
+                            layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {n_ssm_head}, 0);
+
+                            // no "weight" suffix for these
+                            layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, n_ssm_head}, 0);
+                            layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {1, n_ssm_head}, 0);
+
+                            layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {d_inner / n_group, n_group}, 0);
+
+                            // out_proj
+                            layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0);
+                        } else {
+                            // attention layers (with optional bias)
+                            const int64_t n_head_i       = hparams.n_head(i);
+                            const int64_t n_embd_k_gqa_i = hparams.n_embd_k_gqa(i);
+                            const int64_t n_embd_v_gqa_i = hparams.n_embd_v_gqa(i);
+                            layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head_i}, 0);
+                            layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa_i}, 0);
+                            layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa_i}, 0);
+                            layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head_i, n_embd}, 0);
+                            layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                            layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_k_gqa_i}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                            layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_v_gqa_i}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                            layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        }
+
+                        // feed forward (w/ optional biases)
+                        layer.ffn_norm   = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                        layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff}, 0);
+                        layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.ffn_up_b   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                    }
+                } break;
             case LLM_ARCH_XVERSE:
                 {
                     tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
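
The tensor shapes in this hunk follow the Mamba-2 mixer layout: the input projection packs 2*d_inner + 2*n_group*d_state + n_ssm_head output channels, the conv1d covers d_inner + 2*n_group*d_state of them, and the out projection maps d_inner back to n_embd. A small standalone sketch that just evaluates those widths; the hyperparameter values below are illustrative stand-ins, not Bamba's real GGUF metadata.

#include <cstdint>
#include <cstdio>

int main() {
    // illustrative values only; the real ones come from the SSM keys read in load_hparams
    const int64_t n_embd     = 4096;
    const int64_t d_conv     = 4;
    const int64_t d_inner    = 2 * n_embd;  // expansion factor of 2, as asserted above
    const int64_t d_state    = 128;
    const int64_t n_group    = 1;
    const int64_t n_ssm_head = 128;         // ssm_dt_rank

    const int64_t d_in_proj  = 2*d_inner + 2*n_group*d_state + n_ssm_head;
    const int64_t conv_chans = d_inner + 2*n_group*d_state;

    std::printf("ssm_in  weight: {%lld, %lld}\n", (long long) n_embd,  (long long) d_in_proj);
    std::printf("conv1d  weight: {%lld, %lld}\n", (long long) d_conv,  (long long) conv_chans);
    std::printf("ssm_out weight: {%lld, %lld}\n", (long long) d_inner, (long long) n_embd);
    return 0;
}
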
@@ -12781,6 +12901,160 @@ struct llm_build_granite : public llm_graph_context {
     }
 };
 
+struct llm_build_hybrid_mamba : public llm_graph_context {
+
+    llm_build_hybrid_mamba(
+            const llama_model & model,
+            const llm_graph_params & params,
+            ggml_cgraph * gf,
+            const bool use_mamba2 = true,
+            const bool use_rope = true)
+        : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // Build the inputs in the recurrent cache
+        ggml_tensor * state_copy = build_inp_s_copy();
+
+        // Build the inputs in the attention cache
+        auto * inp_attn = build_attn_inp_kv_unified();
+
+        // Positional embeddings populated if rope enabled
+        ggml_tensor * inp_pos = nullptr;
+        if (use_rope) {
+            inp_pos = build_inp_pos();
+        }
+
+        // Extract the recurrent cache from the hybrid parent
+        const auto * kv_recurrent = static_cast<const llama_kv_cache_hybrid *>(memory)->get_child_cache<llama_kv_cache_recurrent>();
+        GGML_ASSERT(kv_recurrent);
+
+        for (int il = 0; il < n_layer; ++il) {
+            struct ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            if (hparams.recurrent_layer(il)) {
+                // ssm layer //
+                if (use_mamba2) {
+                    cur = llm_build_mamba::build_mamba2_layer(this, gf, cur, state_copy, kv_recurrent, model, ubatch, il);
+                } else {
+                    cur = llm_build_mamba::build_mamba_layer(this, gf, cur, state_copy, kv_recurrent, model, ubatch, il);
+                }
+            } else {
+                // attention layer //
+                cur = llm_build_granite::build_attention_layer(
+                    this, gf, cur, inp_pos, inp_attn,
+                    model, n_embd_head, use_rope, il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0, cur,   inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            // For Granite architectures - scale residual
+            cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network (non-MoE)
+            if (model.layers[il].ffn_gate_inp == nullptr) {
+
+                cur = build_norm(ffn_inp,
+                        model.layers[il].ffn_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
+                        model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, il);
+                cb(cur, "ffn_out", il);
+
+            } else {
+                // MoE branch
+                cur = build_norm(ffn_inp,
+                        model.layers[il].ffn_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(cur, "ffn_norm", il);
+
+                ggml_tensor * moe_out = build_moe_ffn(cur,
+                        model.layers[il].ffn_gate_inp,
+                        model.layers[il].ffn_up_exps,
+                        model.layers[il].ffn_gate_exps,
+                        model.layers[il].ffn_down_exps,
+                        nullptr,
+                        n_expert, n_expert_used,
+                        LLM_FFN_SILU, true,
+                        false, 0.0,
+                        LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
+                        il);
+                cb(moe_out, "ffn_moe_out", il);
+
+                // For Granite MoE Shared
+                if (hparams.n_ff_shexp > 0) {
+                    ggml_tensor * ffn_shexp = build_ffn(cur,
+                            model.layers[il].ffn_up_shexp,   NULL, NULL,
+                            model.layers[il].ffn_gate_shexp, NULL, NULL,
+                            model.layers[il].ffn_down_shexp, NULL, NULL,
+                            NULL,
+                            LLM_FFN_SILU, LLM_FFN_PAR, il);
+                    cb(ffn_shexp, "ffn_shexp", il);
+
+                    cur = ggml_add(ctx0, moe_out, ffn_shexp);
+                    cb(cur, "ffn_out", il);
+                } else {
+                    cur = moe_out;
+                }
+            }
+
+            // For Granite architectures - scale residual
+            cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cb(cur, "ffn_out", il);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        // For Granite architectures - scale logits
+        cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale);
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
 // ref: https://github.com/facebookresearch/chameleon
 // based on the original build_llama() function, changes:
 // * qk-norm
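
The graph builder above reuses the Granite scaling conventions: each branch output (attention, SSM, or FFN) is multiplied by f_residual_scale before being added back to the residual stream, and the final logits are divided by f_logit_scale after the lm_head projection. A minimal scalar sketch of that ordering; the scale and stand-in activation values below are made up, and in the real code the factors come from the GGUF keys read in load_hparams.

#include <cstdio>

int main() {
    // illustrative scale factors, not real Bamba metadata
    const float f_residual_scale = 0.22f;
    const float f_logit_scale    = 8.0f;

    float residual = 1.0f;  // stand-in for the layer input (inpSA / ffn_inp)
    float branch   = 0.5f;  // stand-in for the attention/SSM or FFN branch output

    // per-layer pattern: scale the branch first, then add it to the residual stream
    const float hidden = residual + branch * f_residual_scale;

    // head pattern: project, then divide the logits by the logit scale
    float logit = hidden * 3.0f;          // stand-in for the lm_head projection
    logit *= 1.0f / f_logit_scale;

    std::printf("hidden = %f, logit = %f\n", hidden, logit);
    return 0;
}
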
@@ -13774,6 +14048,13 @@ llm_graph_result_ptr llama_model::build_graph(
             {
                 llm = std::make_unique<llm_build_granite>(*this, params, gf);
             } break;
+        case LLM_ARCH_BAMBA:
+            {
+                llm = std::make_unique<llm_build_hybrid_mamba>(
+                    *this, params, gf,
+                    /* use_mamba2 */ true,
+                    /* use_rope */ true);
+            } break;
         case LLM_ARCH_CHAMELEON:
             {
                 llm = std::make_unique<llm_build_chameleon>(*this, params, gf);
@@ -13922,6 +14203,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_GLM4:
         case LLM_ARCH_GRANITE:
         case LLM_ARCH_GRANITE_MOE:
+        case LLM_ARCH_BAMBA:
         case LLM_ARCH_CHAMELEON:
         case LLM_ARCH_BAILINGMOE:
             return LLAMA_ROPE_TYPE_NORM;
