@@ -328,6 +328,7 @@ enum llm_kv {
     LLM_KV_SSM_CONV_KERNEL,
     LLM_KV_SSM_STATE_SIZE,
     LLM_KV_SSM_TIME_STEP_RANK,
+    LLM_KV_SSM_DT_B_C_RMS,

     LLM_KV_TOKENIZER_MODEL,
     LLM_KV_TOKENIZER_PRE,
@@ -426,6 +427,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_SSM_INNER_SIZE,                "%s.ssm.inner_size"     },
     { LLM_KV_SSM_STATE_SIZE,                "%s.ssm.state_size"     },
     { LLM_KV_SSM_TIME_STEP_RANK,            "%s.ssm.time_step_rank" },
+    { LLM_KV_SSM_DT_B_C_RMS,                "%s.ssm.dt_b_c_rms"     },

     { LLM_KV_TOKENIZER_MODEL,                "tokenizer.ggml.model"                    },
     { LLM_KV_TOKENIZER_PRE,                  "tokenizer.ggml.pre"                      },
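For context, the `%s` prefix in `LLM_KV_NAMES` is substituted with the model's architecture name when the key is looked up, so the new flag is read from an arch-qualified GGUF key. A minimal sketch of that expansion, assuming the usual printf-style substitution and a Mamba-family model whose architecture string is `mamba` (the `resolve_kv` helper below is illustrative, not llama.cpp's actual code):

```cpp
// Sketch: expand a "%s"-prefixed key template into the concrete GGUF metadata key.
#include <cstdio>
#include <string>

static std::string resolve_kv(const char * tmpl, const char * arch) {
    char buf[128];
    std::snprintf(buf, sizeof(buf), tmpl, arch); // "%s" -> architecture name
    return std::string(buf);
}

int main() {
    // Prints "mamba.ssm.dt_b_c_rms" for a Mamba-family model.
    std::printf("%s\n", resolve_kv("%s.ssm.dt_b_c_rms", "mamba").c_str());
    return 0;
}
```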
@@ -2237,6 +2239,7 @@ struct llama_hparams {
     uint32_t ssm_d_inner = 0;
     uint32_t ssm_d_state = 0;
     uint32_t ssm_dt_rank = 0;
+    bool ssm_dt_b_c_rms = false;

     float f_clamp_kqv      = 0.0f;
     float f_max_alibi_bias = 0.0f;
@@ -2286,6 +2289,7 @@ struct llama_hparams {
         if (this->ssm_d_inner != other.ssm_d_inner) return true;
         if (this->ssm_d_state != other.ssm_d_state) return true;
         if (this->ssm_dt_rank != other.ssm_dt_rank) return true;
+        if (this->ssm_dt_b_c_rms != other.ssm_dt_b_c_rms) return true;

         if (this->dec_start_token_id != other.dec_start_token_id) return true;

@@ -5052,6 +5056,7 @@ static void llm_load_hparams(
                 ml.get_key(LLM_KV_SSM_INNER_SIZE,     hparams.ssm_d_inner);
                 ml.get_key(LLM_KV_SSM_STATE_SIZE,     hparams.ssm_d_state);
                 ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank);
+                ml.get_key(LLM_KV_SSM_DT_B_C_RMS,     hparams.ssm_dt_b_c_rms, false);

                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);

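Note that the trailing `false` passed to `ml.get_key` marks the key as optional, so Mamba GGUFs converted before this key existed simply keep the default `ssm_dt_b_c_rms = false` and load unchanged.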
@@ -5907,6 +5912,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
         LLAMA_LOG_INFO("%s: ssm_d_inner      = %u\n",     __func__, hparams.ssm_d_inner);
         LLAMA_LOG_INFO("%s: ssm_d_state      = %u\n",     __func__, hparams.ssm_d_state);
         LLAMA_LOG_INFO("%s: ssm_dt_rank      = %u\n",     __func__, hparams.ssm_dt_rank);
+        LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms   = %d\n",     __func__, hparams.ssm_dt_b_c_rms);
     }

     LLAMA_LOG_INFO("%s: model type       = %s\n",     __func__, llama_model_type_name(model.type));
@@ -12165,6 +12171,10 @@ struct llm_build_context {
         GGML_ASSERT(2 * d_model == d_inner);
         const int64_t d_state = hparams.ssm_d_state;
         const int64_t dt_rank = hparams.ssm_dt_rank;
+        // Some Mamba variants (e.g. FalconMamba) apply RMS norm on the dt, B and C projections
+        const bool ssm_dt_b_c_rms = hparams.ssm_dt_b_c_rms;
+        // Use the same RMS norm eps as the final layer norm
+        const float norm_rms_eps = hparams.f_norm_rms_eps;

         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
@@ -12245,6 +12255,13 @@ struct llm_build_context {
                 struct ggml_tensor * B  = ggml_view_2d(ctx0, x_db, d_state, n_tokens, x_db->nb[1], ggml_element_size(x_db)*dt_rank);
                 struct ggml_tensor * C  = ggml_view_2d(ctx0, x_db, d_state, n_tokens, x_db->nb[1], ggml_element_size(x_db)*(dt_rank+d_state));

+                // Some Mamba variants (e.g. FalconMamba) apply RMS norm on the dt, B and C projections
+                if (ssm_dt_b_c_rms) {
+                    dt = ggml_rms_norm(ctx0, dt, norm_rms_eps);
+                    B  = ggml_rms_norm(ctx0, B,  norm_rms_eps);
+                    C  = ggml_rms_norm(ctx0, C,  norm_rms_eps);
+                }
+
                 // {dt_rank, d_inner} * {dt_rank, n_tokens} => {d_inner, n_tokens}
                 dt = llm_build_lora_mm(lctx, ctx0, model.layers[il].ssm_dt, dt);
                 dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b);
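For reference, `ggml_rms_norm` scales each row by the reciprocal of its root mean square and applies no learned weight, which is why a bare call with `norm_rms_eps` suffices here. A standalone scalar sketch of the operation the three graph nodes above compute (not the ggml implementation itself):

```cpp
// Scalar reference of RMS normalization over one row of length n:
//   y[i] = x[i] / sqrt(mean(x[j]^2) + eps)
#include <cmath>
#include <cstddef>

static void rms_norm_row(const float * x, float * y, size_t n, float eps) {
    float sum_sq = 0.0f;
    for (size_t i = 0; i < n; ++i) {
        sum_sq += x[i] * x[i];
    }
    const float scale = 1.0f / std::sqrt(sum_sq / n + eps);
    for (size_t i = 0; i < n; ++i) {
        y[i] = x[i] * scale;
    }
}
```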
@@ -16109,6 +16126,9 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
             case GGML_TYPE_Q6_K:   new_type = GGML_TYPE_Q8_0;   break;
             default: throw std::runtime_error("\nUnsupported tensor size encountered\n");
         }
+        if (tensor->ne[0] % ggml_blck_size(new_type) != 0) {
+            new_type = GGML_TYPE_F16;
+        }
         LLAMA_LOG_WARN(" - using fallback quantization %s\n", ggml_type_name(new_type));
         ++qs.n_fallback;
     }
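This guard pairs with the removal of the `ssm_x`/`ssm_dt` exclusions in the next hunk: once those 2D weights are quantized like everything else, a row width that is not a whole number of quantization blocks cannot be encoded even by the fallback type, so the tensor drops to F16 (block size 1). A toy illustration of the rule, using assumed block sizes (32 for `Q8_0`, 256 for the K-quants) and a hypothetical row width of 288:

```cpp
// Toy version of the divisibility rule above. Block sizes here are assumptions
// written out for illustration; the real code queries ggml_blck_size(new_type).
#include <cstdio>

static const char * choose(long row_width, long blck_size) {
    // A block-quantized type can only encode rows made of whole blocks.
    return (row_width % blck_size == 0) ? "keep block-quantized fallback" : "fall back to F16";
}

int main() {
    std::printf("row 288, block 256 (K-quants): %s\n", choose(288, 256)); // fall back to F16
    std::printf("row 288, block  32 (Q8_0):     %s\n", choose(288, 32));  // keep block-quantized fallback
    return 0;
}
```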
@@ -16437,8 +16457,6 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
         // do not quantize Mamba's small yet 2D weights
         // NOTE: can't use LLM_TN here because the layer number is not known
         quantize &= name.find("ssm_conv1d.weight") == std::string::npos;
-        quantize &= name.find("ssm_x.weight")      == std::string::npos;
-        quantize &= name.find("ssm_dt.weight")     == std::string::npos;

         // do not quantize relative position bias (T5)
         quantize &= name.find("attn_rel_b.weight") == std::string::npos;