Skip to content

Commit 985cda6

Browse files
committed
test-model-random : add Mamba2
1 parent 48a5eba commit 985cda6

File tree

1 file changed: +65 additions, −3 deletions

tests/test-model-random.cpp

Lines changed: 65 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -628,6 +628,64 @@ struct model_variant {
628628
}
629629
}
630630
break;
631+
case LLM_ARCH_MAMBA2:
632+
{
633+
variants.push_back(model_variant(arch, "Mamba2"));
634+
model_variant & cur = variants.back();
635+
636+
n_embd = 64;
637+
638+
const uint32_t d_inner = 2 * n_embd;
639+
const uint32_t d_conv = 4;
640+
const uint32_t d_state = 128;
641+
const uint32_t n_group = 2;
642+
const uint32_t head_dim = 64;
643+
const uint32_t n_head = d_inner / head_dim;
644+
const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + n_head;
645+
646+
const auto init_A = [](std::mt19937 & rng) {
647+
return -std::uniform_real_distribution<float>(1, 16)(rng);
648+
};
649+
650+
cur.add_kv(LLM_KV_CONTEXT_LENGTH, (uint32_t) 1024 * 1024);
651+
cur.add_kv(LLM_KV_EMBEDDING_LENGTH, n_embd);
652+
cur.add_kv(LLM_KV_FEED_FORWARD_LENGTH, (uint32_t) 0);
653+
cur.add_kv(LLM_KV_ATTENTION_HEAD_COUNT, (uint32_t) 0);
654+
cur.add_kv(LLM_KV_BLOCK_COUNT, n_layer);
655+
cur.add_kv(LLM_KV_SSM_CONV_KERNEL, d_conv);
656+
cur.add_kv(LLM_KV_SSM_INNER_SIZE, d_inner);
657+
cur.add_kv(LLM_KV_SSM_STATE_SIZE, d_state);
658+
cur.add_kv(LLM_KV_SSM_TIME_STEP_RANK, n_head);
659+
cur.add_kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, 1e-5f);
660+
cur.add_kv(LLM_KV_SSM_GROUP_COUNT, n_group);
661+
662+
add_tokenizer(cur, n_vocab);
663+
664+
cur.add_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab });
665+
cur.add_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd });
666+
cur.add_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab });
667+
668+
for (uint32_t i = 0; i < n_layer; ++i) {
669+
cur.add_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
670+
671+
cur.add_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, d_in_proj});
672+
673+
cur.add_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner + 2*n_group*d_state});
674+
cur.add_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner + 2*n_group*d_state}, init_bias);
675+
676+
cur.add_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {n_head}, init_bias);
677+
678+
// no "weight" suffix for these
679+
cur.add_tensor(tn(LLM_TENSOR_SSM_A, i), {1, n_head}, init_A);
680+
cur.add_tensor(tn(LLM_TENSOR_SSM_D, i), {1, n_head}, []() { return 1.0f; });
681+
682+
cur.add_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {d_inner / n_group, n_group});
683+
684+
// out_proj
685+
cur.add_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd});
686+
}
687+
}
688+
break;
631689
case LLM_ARCH_XVERSE:
632690
case LLM_ARCH_COMMAND_R:
633691
case LLM_ARCH_COHERE2:
@@ -760,6 +818,7 @@ struct model_variant {
760818
case LLM_ARCH_BAILINGMOE:
761819
case LLM_ARCH_DOTS1:
762820
case LLM_ARCH_ARCEE:
821+
case LLM_ARCH_ERNIE4_5:
763822
case LLM_ARCH_UNKNOWN:
764823
break;
765824
}
@@ -1042,6 +1101,8 @@ int main(int argc, char ** argv) {
10421101

10431102
std::vector<reference_logits> ref_outputs;
10441103

1104+
fprintf(stdout, "Generating reference outputs for '%s', n_seq_max=%i...\n", variant.name.c_str(), n_seq_max);
1105+
10451106
{
10461107
llama_context_params ref_params = llama_context_default_params();
10471108
ref_params.n_batch = n_seq_len;
@@ -1092,6 +1153,10 @@ int main(int argc, char ** argv) {
10921153

10931154
float max_err = 0.0f;
10941155

1156+
fprintf(stdout,
1157+
"Comparing output for '%s', with shuffle=%i, n_seq_max=%i, n_ctx=%i, n_ubatch=%i: ",
1158+
variant.name.c_str(), shuffle, n_seq_max, n_ctx, n_ubatch);
1159+
10951160
// start filling the batch with prompts
10961161
while (std::any_of(seq_id_n_past.begin(), seq_id_n_past.end(),
10971162
[](llama_pos p) { return p < n_seq_len; })) {
@@ -1140,9 +1205,6 @@ int main(int argc, char ** argv) {
11401205
}
11411206
}
11421207

1143-
fprintf(stdout,
1144-
"Comparing output for '%s', with shuffle=%i, n_seq_max=%i, n_ctx=%i, n_ubatch=%i: ",
1145-
variant.name.c_str(), shuffle, n_seq_max, n_ctx, n_ubatch);
11461208
if (std::all_of(valid.begin(), valid.end(), [](bool v) { return v; })) {
11471209
fprintf(stdout, "\033[1;32mOK\033[0m (max err: %.2g)\n", max_err);
11481210
} else {

0 commit comments

Comments
 (0)