@@ -12,12 +12,40 @@
 
 /* - Helpers ------------------------------------------------------------------*/
 
-static std::shared_ptr<llama_model> _make_model() {
+static std::shared_ptr<llama_model> _make_model(
+        llm_arch arch          = LLM_ARCH_LLAMA,
+        uint32_t n_layer       = 4,
+        uint32_t n_embd_head_k = 4,
+        uint32_t n_embd_head_v = 4,
+        uint32_t n_head        = 8,
+        uint32_t n_head_kv     = 2) {
+
     llama_model_params params;
     params.tensor_buft_overrides = nullptr;
     std::shared_ptr<llama_model> model(new llama_model(params));
     model->hparams = llama_hparams();
-    model->arch = LLM_ARCH_LLAMA;
+    model->arch = arch;
+
+    model->hparams.n_layer       = n_layer;
+    model->hparams.n_embd_head_k = n_embd_head_k;
+    model->hparams.n_embd_head_v = n_embd_head_v;
+
+    auto & recurrent_layer_arr = model->hparams.recurrent_layer_arr;
+    std::fill(
+        recurrent_layer_arr.begin(),
+        recurrent_layer_arr.end(),
+        llm_arch_is_recurrent(arch));
+
+    // If set to 0, assume the test will fill out the array elementwise (hybrid)
+    if (n_head > 0) {
+        auto & n_head_arr = model->hparams.n_head_arr;
+        std::fill(n_head_arr.begin(), n_head_arr.end(), n_head);
+    }
+    if (n_head_kv > 0) {
+        auto & n_head_kv_arr = model->hparams.n_head_kv_arr;
+        std::fill(n_head_kv_arr.begin(), n_head_kv_arr.end(), n_head_kv);
+    }
+
     return model;
 }
 
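For context, a minimal sketch of how the reworked helper is meant to be called; every form below is inferred from this diff (the defaults keep existing attention-only call sites working, while zeroed head counts defer per-layer setup to the caller):

    // defaults: LLM_ARCH_LLAMA, 4 layers, head dims 4/4, n_head = 8, n_head_kv = 2
    auto unified   = _make_model();
    // recurrent arch: llm_arch_is_recurrent(arch) marks every layer recurrent
    auto recurrent = _make_model(LLM_ARCH_MAMBA);
    // n_head / n_head_kv of 0: the test fills the per-layer arrays itself (hybrid)
    auto hybrid    = _make_model(LLM_ARCH_LLAMA, 4, 4, 4, 0, 0);
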
@@ -57,7 +85,7 @@ static void test_llama_kv_cache_unified_constructor() {
 /* Test that the recurrent cache can be constructed and destructed safely */
 static void test_llama_kv_cache_recurrent_constructor() {
     LOG_SCOPE();
-    auto model = _make_model();
+    auto model = _make_model(LLM_ARCH_MAMBA);
     llama_kv_cache_recurrent cache(
         /* model   */ *model,
         /* type_k  */ GGML_TYPE_F32,
@@ -72,15 +100,24 @@ static void test_llama_kv_cache_recurrent_constructor() {
 /* Test that the hybrid cache can be constructed and destructed safely */
 static void test_llama_kv_cache_hybrid_constructor() {
     LOG_SCOPE();
-    auto model = _make_model();
-    model->hparams.n_layer = 4;
-    model->hparams.n_embd_head_k = 4;
-    model->hparams.n_embd_head_v = 4;
+    auto model = _make_model(
+        /* arch          =*/ LLM_ARCH_LLAMA,
+        /* n_layer       =*/ 4,
+        /* n_embd_head_k =*/ 4,
+        /* n_embd_head_v =*/ 4,
+        /* n_head        =*/ 0,
+        /* n_head_kv     =*/ 0
+    );
     auto & recurrent_layer_arr = model->hparams.recurrent_layer_arr;
     recurrent_layer_arr[0] = 1;
     recurrent_layer_arr[1] = 0;
     recurrent_layer_arr[2] = 1;
     recurrent_layer_arr[3] = 0;
+    auto & n_head_arr = model->hparams.n_head_arr;
+    n_head_arr[0] = 16;
+    n_head_arr[1] = 32;
+    n_head_arr[2] = 16;
+    n_head_arr[3] = 32;
     auto & n_head_kv_arr = model->hparams.n_head_kv_arr;
     n_head_kv_arr[0] = 16;
     n_head_kv_arr[1] = 8;
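
Taken together, the arrays above describe an interleaved hybrid model: layers 0 and 2 are recurrent, layers 1 and 3 are attention layers with their own per-layer head counts. A minimal sketch, using only the hparams fields touched in this diff, of how a hybrid cache can route each layer off recurrent_layer_arr:

    for (uint32_t il = 0; il < model->hparams.n_layer; ++il) {
        if (model->hparams.recurrent_layer_arr[il]) {
            // layers 0 and 2: the recurrent state cache handles this layer
        } else {
            // layers 1 and 3: the unified KV cache handles this layer,
            // sized from n_head_kv_arr[il] and n_embd_head_k / n_embd_head_v
        }
    }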