@@ -628,6 +628,64 @@ struct model_variant {
                     }
                 }
                 break;
+            case LLM_ARCH_MAMBA2:
+                {
+                    variants.push_back(model_variant(arch, "Mamba2"));
+                    model_variant & cur = variants.back();
+
+                    n_embd = 64;
+
+                    const uint32_t d_inner  = 2 * n_embd;
+                    const uint32_t d_conv   = 4;
+                    const uint32_t d_state  = 128;
+                    const uint32_t n_group  = 2;
+                    const uint32_t head_dim = 64;
+                    const uint32_t n_head   = d_inner / head_dim;
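+                    // d_in_proj packs z (d_inner), xBC (d_inner + 2*n_group*d_state) and dt (n_head) for the input projection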
+                    const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + n_head;
+
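+                    // A is stored negated; drawing from U(1, 16) matches Mamba's usual A initialization range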
+                    const auto init_A = [](std::mt19937 & rng) {
+                        return -std::uniform_real_distribution<float>(1, 16)(rng);
+                    };
+
+                    cur.add_kv(LLM_KV_CONTEXT_LENGTH, (uint32_t) 1024 * 1024);
+                    cur.add_kv(LLM_KV_EMBEDDING_LENGTH, n_embd);
+                    cur.add_kv(LLM_KV_FEED_FORWARD_LENGTH, (uint32_t) 0);
+                    cur.add_kv(LLM_KV_ATTENTION_HEAD_COUNT, (uint32_t) 0);
+                    cur.add_kv(LLM_KV_BLOCK_COUNT, n_layer);
+                    cur.add_kv(LLM_KV_SSM_CONV_KERNEL, d_conv);
+                    cur.add_kv(LLM_KV_SSM_INNER_SIZE, d_inner);
+                    cur.add_kv(LLM_KV_SSM_STATE_SIZE, d_state);
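+                    // Mamba-2 uses one time step per head, so the dt rank equals n_head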
+                    cur.add_kv(LLM_KV_SSM_TIME_STEP_RANK, n_head);
+                    cur.add_kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, 1e-5f);
+                    cur.add_kv(LLM_KV_SSM_GROUP_COUNT, n_group);
+
+                    add_tokenizer(cur, n_vocab);
+
+                    cur.add_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab });
+                    cur.add_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd });
+                    cur.add_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab });
+
+                    for (uint32_t i = 0; i < n_layer; ++i) {
+                        cur.add_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
+
+                        cur.add_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, d_in_proj});
+
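+                        // the depthwise conv runs over the concatenated x, B and C streams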
+                        cur.add_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner + 2*n_group*d_state});
+                        cur.add_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner + 2*n_group*d_state}, init_bias);
+
+                        cur.add_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {n_head}, init_bias);
+
+                        // no "weight" suffix for these
+                        cur.add_tensor(tn(LLM_TENSOR_SSM_A, i), {1, n_head}, init_A);
+                        cur.add_tensor(tn(LLM_TENSOR_SSM_D, i), {1, n_head}, [](std::mt19937 & /*rng*/) { return 1.0f; });
+
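+                        // grouped RMS norm over the gated SSM output, stored as {d_inner / n_group, n_group}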
+                        cur.add_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {d_inner / n_group, n_group});
+
+                        // out_proj
+                        cur.add_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd});
+                    }
+                }
+                break;
             case LLM_ARCH_XVERSE:
             case LLM_ARCH_COMMAND_R:
             case LLM_ARCH_COHERE2:
@@ -760,6 +818,7 @@ struct model_variant {
             case LLM_ARCH_BAILINGMOE:
             case LLM_ARCH_DOTS1:
             case LLM_ARCH_ARCEE:
+            case LLM_ARCH_ERNIE4_5:
             case LLM_ARCH_UNKNOWN:
                 break;
         }
@@ -1042,6 +1101,8 @@ int main(int argc, char ** argv) {
 
     std::vector<reference_logits> ref_outputs;
 
+    fprintf(stdout, "Generating reference outputs for '%s', n_seq_max=%i...\n", variant.name.c_str(), n_seq_max);
+
     {
         llama_context_params ref_params = llama_context_default_params();
         ref_params.n_batch = n_seq_len;
@@ -1092,6 +1153,10 @@ int main(int argc, char ** argv) {
 
         float max_err = 0.0f;
 
+        fprintf(stdout,
+                "Comparing output for '%s', with shuffle=%i, n_seq_max=%i, n_ctx=%i, n_ubatch=%i: ",
+                variant.name.c_str(), shuffle, n_seq_max, n_ctx, n_ubatch);
+
         // start filling the batch with prompts
         while (std::any_of(seq_id_n_past.begin(), seq_id_n_past.end(),
                            [](llama_pos p) { return p < n_seq_len; })) {
@@ -1140,9 +1205,6 @@ int main(int argc, char ** argv) {
             }
         }
 
-        fprintf(stdout,
-                "Comparing output for '%s', with shuffle=%i, n_seq_max=%i, n_ctx=%i, n_ubatch=%i: ",
-                variant.name.c_str(), shuffle, n_seq_max, n_ctx, n_ubatch);
         if (std::all_of(valid.begin(), valid.end(), [](bool v) { return v; })) {
             fprintf(stdout, "\033[1;32mOK\033[0m (max err: %.2g)\n", max_err);
         } else {