@@ -1435,6 +1435,49 @@ void llama_model::load_hparams(llama_model_loader & ml) {
14351435 // For Granite MoE Shared
14361436 ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, /* required */ false);
14371437 } break;
1438+ case LLM_ARCH_BAMBA:
1439+ {
1440+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
1441+ ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale, /* required */ false);
1442+ ml.get_key(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale, /* required */ false);
1443+ ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale, /* required */ false);
1444+ ml.get_key(LLM_KV_ATTENTION_SCALE, hparams.f_attention_scale, /* required */ false);
1445+
1446+ ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv);
1447+ ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner);
1448+ ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state);
1449+ ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank);
1450+ ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group);
1451+
1452+ // Zero out n_head_arr and n_head_kv_arr since SSM layers don't
1453+ // have attention heads. We'll set them correctly below once we
1454+ // know which layers are attention layers.
1455+ // NOTE: It's important that this happens after n_embd_head_[kv]
1456+ // are set above!
1457+ const auto n_head_attn = hparams.n_head();
1458+ const auto n_head_kv_attn = hparams.n_head_kv();
1459+ std::fill(hparams.n_head_arr.begin(), hparams.n_head_arr.end(), 0);
1460+ std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0);
1461+
1462+ // Attention params
1463+ std::fill(hparams.recurrent_layer_arr.begin(), hparams.recurrent_layer_arr.end(), true);
1464+ std::vector<uint32_t> attn_layer_indices;
1465+ ml.get_arr(LLM_KV_ATTENTION_LAYER_INDICES, attn_layer_indices);
1466+ for (const auto attn_idx : attn_layer_indices) {
1467+ GGML_ASSERT(attn_idx < hparams.n_layer);
1468+ hparams.recurrent_layer_arr[attn_idx] = false;
1469+ // Correctly set n_head and n_head_kv for attention layers
1470+ hparams.n_head_arr[attn_idx] = n_head_attn;
1471+ hparams.n_head_kv_arr[attn_idx] = n_head_kv_attn;
1472+ }
1473+
1474+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
1475+
1476+ switch (hparams.n_layer) {
1477+ // TODO: Add llm type label (not sure this is useful)
1478+ default: type = LLM_TYPE_UNKNOWN;
1479+ }
1480+ } break;
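The head-count bookkeeping in this case is what lets the rest of the loader treat the hybrid stack uniformly: every layer starts out marked recurrent with zero attention heads, and only the indices listed under LLM_KV_ATTENTION_LAYER_INDICES get n_head/n_head_kv restored. Below is a minimal, self-contained sketch of that pattern; the fixed-size arrays, layer count, and layer indices are hypothetical stand-ins, not the real llama_hparams:

#include <array>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    // hypothetical illustration values, not taken from any real GGUF
    constexpr uint32_t n_layer   = 8;
    constexpr uint32_t n_head    = 32;  // heads on attention layers
    constexpr uint32_t n_head_kv = 8;   // KV heads on attention layers

    std::array<uint32_t, n_layer> n_head_arr{};     // zero-initialized: SSM layers keep 0
    std::array<uint32_t, n_layer> n_head_kv_arr{};
    std::array<bool,     n_layer> recurrent_layer_arr;
    recurrent_layer_arr.fill(true);                 // default: every layer is recurrent

    // layers that use attention in this hypothetical model
    const std::vector<uint32_t> attn_layer_indices = {2, 5};
    for (const auto il : attn_layer_indices) {
        recurrent_layer_arr[il] = false;
        n_head_arr[il]          = n_head;
        n_head_kv_arr[il]       = n_head_kv;
    }

    for (uint32_t il = 0; il < n_layer; ++il) {
        std::printf("layer %u: %s, n_head=%u, n_head_kv=%u\n",
                    il, recurrent_layer_arr[il] ? "ssm" : "attn",
                    n_head_arr[il], n_head_kv_arr[il]);
    }
    return 0;
}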
14381481 case LLM_ARCH_CHAMELEON:
14391482 {
14401483 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
@@ -3049,6 +3092,83 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
30493092 layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0);
30503093 }
30513094 } break;
3095+ case LLM_ARCH_BAMBA:
3096+ {
3097+ // mamba2 Mixer SSM params
3098+ // NOTE: int64_t for tensor dimensions
3099+ const int64_t d_conv = hparams.ssm_d_conv;
3100+ const int64_t d_inner = hparams.ssm_d_inner;
3101+ const int64_t d_state = hparams.ssm_d_state;
3102+ const int64_t n_ssm_head = hparams.ssm_dt_rank;
3103+ const int64_t n_group = hparams.ssm_n_group;
3104+ const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + n_ssm_head;
3105+
3106+ // only an expansion factor of 2 is supported for now
3107+ GGML_ASSERT(2 * n_embd == d_inner);
3108+
3109+ // embeddings
3110+ tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
3111+
3112+ // output
3113+ {
3114+ output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
3115+ output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
3116+ // if output is NULL, init from the input tok embed, duplicated to allow offloading
3117+ if (output == NULL) {
3118+ output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
3119+ }
3120+ }
3121+
3122+ for (int i = 0; i < n_layer; ++i) {
3123+ auto & layer = layers[i];
3124+
3125+ // norm
3126+ layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
3127+
3128+ if (hparams.recurrent_layer(i)) {
3129+ // ssm layers
3130+ layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, d_in_proj}, 0);
3131+ layer.ssm_in_b = create_tensor(tn(LLM_TENSOR_SSM_IN, "bias", i), {n_embd, d_in_proj}, llama_model_loader::TENSOR_NOT_REQUIRED);
3132+
3133+ layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner + 2*n_group*d_state}, 0);
3134+ layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner + 2*n_group*d_state}, llama_model_loader::TENSOR_NOT_REQUIRED);
3135+
3136+ layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {n_ssm_head}, 0);
3137+
3138+ // no "weight" suffix for these
3139+ layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, n_ssm_head}, 0);
3140+ layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {1, n_ssm_head}, 0);
3141+
3142+ layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {d_inner / n_group, n_group}, 0);
3143+
3144+ // out_proj
3145+ layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0);
3146+ } else {
3147+ // attention layers (with optional bias)
3148+ const int64_t n_head_i = hparams.n_head(i);
3149+ const int64_t n_embd_k_gqa_i = hparams.n_embd_k_gqa(i);
3150+ const int64_t n_embd_v_gqa_i = hparams.n_embd_v_gqa(i);
3151+ layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head_i}, 0);
3152+ layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa_i}, 0);
3153+ layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa_i}, 0);
3154+ layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head_i, n_embd}, 0);
3155+ layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
3156+ layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_k_gqa_i}, llama_model_loader::TENSOR_NOT_REQUIRED);
3157+ layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_v_gqa_i}, llama_model_loader::TENSOR_NOT_REQUIRED);
3158+ layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
3159+ }
3160+
3161+ // feed forward (w/ optional biases)
3162+ layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
3163+ layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
3164+ layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
3165+ layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
3166+ layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
3167+ layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
3168+ layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
3169+ layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
3170+ }
3171+ } break;
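The d_in_proj computed above packs all of the Mamba-2 mixer inputs into a single in-projection: x and z (d_inner each), B and C (n_group*d_state each), and one dt value per SSM head. A quick illustrative calculation with hypothetical Bamba-like dimensions (the real values come from the SSM_* keys read in load_hparams):

#include <cassert>
#include <cstdint>
#include <cstdio>

int main() {
    // hypothetical dimensions for illustration only
    const int64_t n_embd     = 4096;
    const int64_t d_inner    = 2 * n_embd;  // expansion factor of 2, as asserted above
    const int64_t d_state    = 128;
    const int64_t n_group    = 1;
    const int64_t n_ssm_head = 128;

    assert(2 * n_embd == d_inner);

    // x + z (d_inner each) + B + C (n_group*d_state each) + dt (one per head)
    const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + n_ssm_head;

    // 16384 + 256 + 128 = 16768 for these made-up values
    std::printf("d_in_proj = %lld\n", (long long) d_in_proj);
    return 0;
}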
30523172 case LLM_ARCH_XVERSE:
30533173 {
30543174 tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@@ -12781,6 +12901,160 @@ struct llm_build_granite : public llm_graph_context {
1278112901 }
1278212902};
1278312903
12904+ struct llm_build_hybrid_mamba : public llm_graph_context {
12905+
12906+ llm_build_hybrid_mamba(
12907+ const llama_model & model,
12908+ const llm_graph_params & params,
12909+ ggml_cgraph * gf,
12910+ const bool use_mamba2 = true,
12911+ const bool use_rope = true)
12912+ : llm_graph_context(params) {
12913+ const int64_t n_embd_head = hparams.n_embd_head_v;
12914+ GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
12915+
12916+ ggml_tensor * cur;
12917+ ggml_tensor * inpL;
12918+
12919+ inpL = build_inp_embd(model.tok_embd);
12920+
12921+ // Build the inputs in the recurrent cache
12922+ ggml_tensor * state_copy = build_inp_s_copy();
12923+
12924+ // Build the inputs in the attention cache
12925+ auto * inp_attn = build_attn_inp_kv_unified();
12926+
12927+ // Positional embeddings populated if rope enabled
12928+ ggml_tensor * inp_pos = nullptr;
12929+ if (use_rope) {
12930+ inp_pos = build_inp_pos();
12931+ }
12932+
12933+ // Extract the recurrent cache from the hybrid parent
12934+ const auto * kv_recurrent = static_cast<const llama_kv_cache_hybrid *>(memory)->get_child_cache<llama_kv_cache_recurrent>();
12935+ GGML_ASSERT(kv_recurrent);
12936+
12937+ for (int il = 0; il < n_layer; ++il) {
12938+ struct ggml_tensor * inpSA = inpL;
12939+
12940+ // norm
12941+ cur = build_norm(inpL,
12942+ model.layers[il].attn_norm, NULL,
12943+ LLM_NORM_RMS, il);
12944+ cb(cur, "attn_norm", il);
12945+
12946+ if (hparams.recurrent_layer(il)) {
12947+ // ssm layer //
12948+ if (use_mamba2) {
12949+ cur = llm_build_mamba::build_mamba2_layer(this, gf, cur, state_copy, kv_recurrent, model, ubatch, il);
12950+ } else {
12951+ cur = llm_build_mamba::build_mamba_layer(this, gf, cur, state_copy, kv_recurrent, model, ubatch, il);
12952+ }
12953+ } else {
12954+ // attention layer //
12955+ cur = llm_build_granite::build_attention_layer(
12956+ this, gf, cur, inp_pos, inp_attn,
12957+ model, n_embd_head, use_rope, il);
12958+ }
12959+
12960+ if (il == n_layer - 1) {
12961+ // skip computing output for unused tokens
12962+ ggml_tensor * inp_out_ids = build_inp_out_ids();
12963+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
12964+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
12965+ }
12966+
12967+ // For Granite architectures - scale residual
12968+ cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
12969+ ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
12970+ cb(ffn_inp, "ffn_inp", il);
12971+
12972+ // feed-forward network (non-MoE)
12973+ if (model.layers[il].ffn_gate_inp == nullptr) {
12974+
12975+ cur = build_norm(ffn_inp,
12976+ model.layers[il].ffn_norm, NULL,
12977+ LLM_NORM_RMS, il);
12978+ cb(cur, "ffn_norm", il);
12979+
12980+ cur = build_ffn(cur,
12981+ model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL,
12982+ model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
12983+ model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
12984+ NULL,
12985+ LLM_FFN_SILU, LLM_FFN_PAR, il);
12986+ cb(cur, "ffn_out", il);
12987+
12988+ } else {
12989+ // MoE branch
12990+ cur = build_norm(ffn_inp,
12991+ model.layers[il].ffn_norm, NULL,
12992+ LLM_NORM_RMS, il);
12993+ cb(cur, "ffn_norm", il);
12994+
12995+ ggml_tensor * moe_out = build_moe_ffn(cur,
12996+ model.layers[il].ffn_gate_inp,
12997+ model.layers[il].ffn_up_exps,
12998+ model.layers[il].ffn_gate_exps,
12999+ model.layers[il].ffn_down_exps,
13000+ nullptr,
13001+ n_expert, n_expert_used,
13002+ LLM_FFN_SILU, true,
13003+ false, 0.0,
13004+ LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
13005+ il);
13006+ cb(moe_out, "ffn_moe_out", il);
13007+
13008+ // For Granite MoE Shared
13009+ if (hparams.n_ff_shexp > 0) {
13010+ ggml_tensor * ffn_shexp = build_ffn(cur,
13011+ model.layers[il].ffn_up_shexp, NULL, NULL,
13012+ model.layers[il].ffn_gate_shexp, NULL, NULL,
13013+ model.layers[il].ffn_down_shexp, NULL, NULL,
13014+ NULL,
13015+ LLM_FFN_SILU, LLM_FFN_PAR, il);
13016+ cb(ffn_shexp, "ffn_shexp", il);
13017+
13018+ cur = ggml_add(ctx0, moe_out, ffn_shexp);
13019+ cb(cur, "ffn_out", il);
13020+ } else {
13021+ cur = moe_out;
13022+ }
13023+ }
13024+
13025+ // For Granite architectures - scale residual
13026+ cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
13027+ cur = ggml_add(ctx0, cur, ffn_inp);
13028+ cb(cur, "ffn_out", il);
13029+
13030+ cur = build_cvec(cur, il);
13031+ cb(cur, "l_out", il);
13032+
13033+ // input for next layer
13034+ inpL = cur;
13035+ }
13036+
13037+ cur = inpL;
13038+
13039+ cur = build_norm(cur,
13040+ model.output_norm, NULL,
13041+ LLM_NORM_RMS, -1);
13042+
13043+ cb(cur, "result_norm", -1);
13044+ res->t_embd = cur;
13045+
13046+ // lm_head
13047+ cur = build_lora_mm(model.output, cur);
13048+
13049+ // For Granite architectures - scale logits
13050+ cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale);
13051+ cb(cur, "result_output", -1);
13052+ res->t_logits = cur;
13053+
13054+ ggml_build_forward_expand(gf, cur);
13055+ }
13056+ };
13057+
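The graph builder above reuses the Granite scaling conventions: each residual branch is multiplied by f_residual_scale before being added back, and the final logits are divided by f_logit_scale. A toy scalar sketch of that arithmetic, with hypothetical scale values standing in for whatever the GGUF provides:

#include <cstdio>

// Toy scalar stand-ins for the Granite-style scaling used above.
// The values are hypothetical; in the real model they come from
// LLM_KV_RESIDUAL_SCALE and LLM_KV_LOGIT_SCALE.
int main() {
    const float f_residual_scale = 0.22f;
    const float f_logit_scale    = 8.0f;

    float residual   = 1.0f;  // stand-in for inpSA / ffn_inp
    float branch_out = 0.5f;  // stand-in for the attention/SSM or FFN output

    // residual update: cur = residual + f_residual_scale * branch_out
    const float cur = residual + f_residual_scale * branch_out;

    // lm_head output: logits are divided by f_logit_scale
    const float logit = cur / f_logit_scale;

    std::printf("cur = %.4f, logit = %.4f\n", cur, logit);
    return 0;
}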
1278413058// ref: https://github.com/facebookresearch/chameleon
1278513059// based on the original build_llama() function, changes:
1278613060// * qk-norm
@@ -13774,6 +14048,13 @@ llm_graph_result_ptr llama_model::build_graph(
1377414048 {
1377514049 llm = std::make_unique<llm_build_granite>(*this, params, gf);
1377614050 } break;
14051+ case LLM_ARCH_BAMBA:
14052+ {
14053+ llm = std::make_unique<llm_build_hybrid_mamba>(
14054+ *this, params, gf,
14055+ /* use_mamba2 */ true,
14056+ /* use_rope */ true);
14057+ } break;
1377714058 case LLM_ARCH_CHAMELEON:
1377814059 {
1377914060 llm = std::make_unique<llm_build_chameleon>(*this, params, gf);
@@ -13922,6 +14203,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
1392214203 case LLM_ARCH_GLM4:
1392314204 case LLM_ARCH_GRANITE:
1392414205 case LLM_ARCH_GRANITE_MOE:
14206+ case LLM_ARCH_BAMBA:
1392514207 case LLM_ARCH_CHAMELEON:
1392614208 case LLM_ARCH_BAILINGMOE:
1392714209 return LLAMA_ROPE_TYPE_NORM;