Commit ba73a0f

feat: Support GRANITE_MOE_HYBRID in llama-model
This re-uses the Bamba code paths heavily and simply adds the missing parts for loading MoE and the shared expert.

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <[email protected]>
1 parent ef1630c commit ba73a0f
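
For context (not part of this commit): the shared expert enabled by hparams.n_ff_shexp is conventionally a small SwiGLU feed-forward that runs on the same input as the routed experts, with its output added to theirs. A minimal ggml-style sketch of that idea, assuming the llama_layer fields added below, with cur and moe_out supplied by the surrounding graph code:

// Illustrative sketch only; not the code added by this commit.
// Assumes "ggml.h" and the internal llama_layer struct with the
// ffn_*_shexp fields that this change loads.
static ggml_tensor * add_shared_expert(
        ggml_context      * ctx,
        const llama_layer & layer,
        ggml_tensor       * cur,       // FFN input (post-norm hidden state)
        ggml_tensor       * moe_out) { // output of the routed experts
    // SwiGLU shared expert: silu(W_gate * x) element-wise-multiplied with (W_up * x),
    // then projected back down to n_embd
    ggml_tensor * gate = ggml_silu(ctx, ggml_mul_mat(ctx, layer.ffn_gate_shexp, cur));
    ggml_tensor * up   = ggml_mul_mat(ctx, layer.ffn_up_shexp,   cur);
    ggml_tensor * out  = ggml_mul_mat(ctx, layer.ffn_down_shexp, ggml_mul(ctx, gate, up));
    // the shared expert runs for every token and is summed with the routed result
    return ggml_add(ctx, moe_out, out);
}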

File tree

1 file changed (+40, -9)

src/llama-model.cpp

Lines changed: 40 additions & 9 deletions
@@ -1436,6 +1436,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, /* required */ false);
             } break;
         case LLM_ARCH_BAMBA:
+        case LLM_ARCH_GRANITE_MOE_HYBRID:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
                 ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale, /* required */ false);
@@ -1477,6 +1478,9 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     // TODO: Add llm type label (not sure this is useful)
                     default: type = LLM_TYPE_UNKNOWN;
                 }
+
+                // For Granite MoE Shared
+                ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, /* required */ false);
             } break;
         case LLM_ARCH_CHAMELEON:
             {
@@ -3089,6 +3093,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                     }
                 } break;
             case LLM_ARCH_BAMBA:
+            case LLM_ARCH_GRANITE_MOE_HYBRID:
                 {
                     // mamba2 Mixer SSM params
                     // NOTE: int64_t for tensor dimensions
@@ -3155,14 +3160,31 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         }
 
                         // feed forward (w/ optional biases)
-                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-                        layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
-                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
-                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
-                        layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
-                        layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        if (n_expert > 0) {
+                            // MoE FFN
+                            layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                            layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
+                            layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
+                            layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, TENSOR_NOT_REQUIRED);
+                            layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0);
+                            layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0);
+
+                            // For Granite MoE Shared
+                            if (hparams.n_ff_shexp > 0) {
+                                layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0);
+                                layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0);
+                                layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, 0);
+                            }
+                        } else {
+                            layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                            layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
+                            layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
+                            layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
+                            layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
+                            layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                            layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                            layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        }
                     }
                 } break;
             case LLM_ARCH_XVERSE:
@@ -4611,7 +4633,9 @@ void llama_model::print_info() const {
 
     if (arch == LLM_ARCH_MINICPM ||
         arch == LLM_ARCH_GRANITE ||
-        arch == LLM_ARCH_GRANITE_MOE) {
+        arch == LLM_ARCH_GRANITE_MOE ||
+        arch == LLM_ARCH_GRANITE_MOE_HYBRID ||
+        arch == LLM_ARCH_BAMBA) {
         LLAMA_LOG_INFO("%s: f_embedding_scale = %f\n", __func__, hparams.f_embedding_scale);
         LLAMA_LOG_INFO("%s: f_residual_scale = %f\n", __func__, hparams.f_residual_scale);
         LLAMA_LOG_INFO("%s: f_attention_scale = %f\n", __func__, hparams.f_attention_scale);
@@ -14042,6 +14066,12 @@ llm_graph_result_ptr llama_model::build_graph(
             {
                 llm = std::make_unique<llm_build_granite>(*this, params, gf);
             } break;
+        case LLM_ARCH_GRANITE_MOE_HYBRID:
+            {
+                llm = std::make_unique<llm_build_hybrid_mamba>(*this, params, gf,
+                    /* use_mamba2 */ true,
+                    /* use_rope */ false);
+            } break;
         case LLM_ARCH_BAMBA:
             {
                 llm = std::make_unique<llm_build_hybrid_mamba>(
@@ -14197,6 +14227,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_GLM4:
         case LLM_ARCH_GRANITE:
         case LLM_ARCH_GRANITE_MOE:
+        case LLM_ARCH_GRANITE_MOE_HYBRID:
         case LLM_ARCH_BAMBA:
         case LLM_ARCH_CHAMELEON:
         case LLM_ARCH_BAILINGMOE:
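
Usage note (not part of the diff): the paths touched here (load_hparams, load_tensors, print_info, build_graph) are exercised as soon as a converted Granite MoE Hybrid GGUF is loaded through the public API. A minimal load-only sketch; the model path is a placeholder:

// Load-only smoke test sketch against the public llama.h API.
// "granite-moe-hybrid.gguf" is a placeholder path for a converted model.
#include "llama.h"
#include <cstdio>

int main() {
    llama_backend_init();

    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_model_load_from_file("granite-moe-hybrid.gguf", mparams);
    if (model == NULL) {
        fprintf(stderr, "failed to load model\n");
        return 1;
    }

    // the hparams printed by print_info() appear in the load log;
    // llama_model_desc() gives a one-line architecture/size summary
    char desc[256];
    llama_model_desc(model, desc, sizeof(desc));
    printf("loaded: %s\n", desc);

    llama_model_free(model);
    llama_backend_free();
    return 0;
}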
