| 
#include "../src/llama-arch.h"
#include "../src/llama-hparams.h"
#include "../src/llama-impl.h"
#include "../src/llama-kv-cache.h"
#include "../src/llama-model.h"

#include "llama.h"

#include <algorithm>
#include <cstdio>
#include <memory>
#include <vector>
 | 12 | + | 
 | 13 | +/*- Helpers ------------------------------------------------------------------*/  | 
 | 14 | + | 
 | 15 | +static std::shared_ptr<llama_model> _make_model() {  | 
 | 16 | +    llama_model_params params;  | 
 | 17 | +    params.tensor_buft_overrides = nullptr;  | 
 | 18 | +    std::shared_ptr<llama_model> model(new llama_model(params));  | 
 | 19 | +    model->hparams = llama_hparams();  | 
 | 20 | +    model->arch = LLM_ARCH_LLAMA;  | 
 | 21 | +    return model;  | 
 | 22 | +}  | 
 | 23 | + | 
 | 24 | +struct log_scope {  | 
 | 25 | +    const char * name;  | 
 | 26 | +    explicit log_scope(const char * name) : name(name) {  | 
 | 27 | +        LLAMA_LOG_INFO("--------\n");  | 
 | 28 | +        LLAMA_LOG_INFO("START: %s\n", name);  | 
 | 29 | +    }  | 
 | 30 | +    ~log_scope() {  | 
 | 31 | +        LLAMA_LOG_INFO("END: %s\n", name);  | 
 | 32 | +        LLAMA_LOG_INFO("--------\n");  | 
 | 33 | +    }  | 
 | 34 | +};  | 
 | 35 | + | 
 | 36 | +#define LOG_SCOPE() log_scope __log_scope(__func__)  | 
 | 37 | + | 
 | 38 | +/*- Unified Cache ------------------------------------------------------------*/  | 
 | 39 | + | 
 | 40 | +/* Test that the unified cache can be constructed and destructed safely */  | 
 | 41 | +static void test_llama_kv_cache_unified_constructor() {  | 
 | 42 | +    LOG_SCOPE();  | 
 | 43 | +    auto model = _make_model();  | 
 | 44 | +    llama_kv_cache_unified cache(  | 
 | 45 | +        /* model   */ *model,  | 
 | 46 | +        /* type_k  */ GGML_TYPE_F32,  | 
 | 47 | +        /* type_v  */ GGML_TYPE_F16,  | 
 | 48 | +        /* v_trans */ false,  | 
 | 49 | +        /* offload */ false,  | 
 | 50 | +        /* kv_size */ 10,  | 
 | 51 | +        /* padding */ 10  | 
 | 52 | +    );  | 
 | 53 | +}  | 
 | 54 | + | 
 | 55 | +/*- Recurrent Cache ----------------------------------------------------------*/  | 
 | 56 | + | 
 | 57 | +/* Test that the recurrent cache can be constructed and destructed safely */  | 
 | 58 | +static void test_llama_kv_cache_recurrent_constructor() {  | 
 | 59 | +    LOG_SCOPE();  | 
 | 60 | +    auto model = _make_model();  | 
 | 61 | +    llama_kv_cache_recurrent cache(  | 
 | 62 | +        /* model   */ *model,  | 
 | 63 | +        /* type_k  */ GGML_TYPE_F32,  | 
 | 64 | +        /* type_v  */ GGML_TYPE_F16,  | 
 | 65 | +        /* offload */ false,  | 
 | 66 | +        /* kv_size */ 10  | 
 | 67 | +    );  | 
 | 68 | +}  | 
 | 69 | + | 
 | 70 | +/*- Hybrid Cache -------------------------------------------------------------*/  | 
 | 71 | + | 
 | 72 | +/* Test that the hybrid cache can be constructed and destructed safely */  | 
 | 73 | +static void test_llama_kv_cache_hybrid_constructor() {  | 
 | 74 | +    LOG_SCOPE();  | 
 | 75 | +    auto model = _make_model();  | 
 | 76 | +    model->hparams.n_layer = 4;  | 
 | 77 | +    model->hparams.n_embd_head_k = 4;  | 
 | 78 | +    model->hparams.n_embd_head_v = 4;  | 
 | 79 | +    auto& recurrent_layer_arr = model->hparams.recurrent_layer_arr;  | 
 | 80 | +    recurrent_layer_arr[0] = 1;  | 
 | 81 | +    recurrent_layer_arr[1] = 0;  | 
 | 82 | +    recurrent_layer_arr[2] = 1;  | 
 | 83 | +    recurrent_layer_arr[3] = 0;  | 
 | 84 | +    auto& n_head_kv_arr = model->hparams.n_head_kv_arr;  | 
 | 85 | +    n_head_kv_arr[0] = 16;  | 
 | 86 | +    n_head_kv_arr[1] = 8;  | 
 | 87 | +    n_head_kv_arr[2] = 16;  | 
 | 88 | +    n_head_kv_arr[3] = 8;  | 
 | 89 | + | 
 | 90 | +    std::unique_ptr<llama_kv_cache_unified> u_cache(  | 
 | 91 | +        new llama_kv_cache_unified(  | 
 | 92 | +            /* model   */ *model,  | 
 | 93 | +            /* type_k  */ GGML_TYPE_F32,  | 
 | 94 | +            /* type_v  */ GGML_TYPE_F16,  | 
 | 95 | +            /* v_trans */ false,  | 
 | 96 | +            /* offload */ false,  | 
 | 97 | +            /* kv_size */ 20,  | 
 | 98 | +            /* padding */ 2  | 
 | 99 | +        )  | 
 | 100 | +    );  | 
 | 101 | +    auto * u_cache_ptr = u_cache.get();  | 
 | 102 | +    std::unique_ptr<llama_kv_cache_recurrent> r_cache (  | 
 | 103 | +        new llama_kv_cache_recurrent(  | 
 | 104 | +            /* model   */ *model,  | 
 | 105 | +            /* type_k  */ GGML_TYPE_F32,  | 
 | 106 | +            /* type_v  */ GGML_TYPE_F16,  | 
 | 107 | +            /* offload */ false,  | 
 | 108 | +            /* kv_size */ 10  | 
 | 109 | +        )  | 
 | 110 | +    );  | 
 | 111 | +    auto * r_cache_ptr = r_cache.get();  | 
 | 112 | + | 
 | 113 | +    std::vector<llama_kv_cache_hybrid::child_cache> children;  | 
 | 114 | +    children.emplace_back(std::move(u_cache), std::vector<size_t>{1, 3});  | 
 | 115 | +    children.emplace_back(std::move(r_cache), std::vector<size_t>{0, 2});  | 
 | 116 | + | 
 | 117 | +    llama_kv_cache_hybrid cache(model->hparams, std::move(children));  | 
 | 118 | + | 
 | 119 | +    GGML_ASSERT(cache.get_child_cache<llama_kv_cache_unified>() == u_cache_ptr);  | 
 | 120 | +    GGML_ASSERT(cache.get_child_cache<llama_kv_cache_recurrent>() == r_cache_ptr);  | 
 | 121 | +}  | 
 | 122 | + | 
 | 123 | +/*- Main ---------------------------------------------------------------------*/  | 
 | 124 | + | 
 | 125 | +int main() {  | 
 | 126 | +    // Unified Cache Tests  | 
 | 127 | +    test_llama_kv_cache_unified_constructor();  | 
 | 128 | +    // Recurrent Cache Tests  | 
 | 129 | +    test_llama_kv_cache_recurrent_constructor();  | 
 | 130 | +    // Hybrid Cache Tests  | 
 | 131 | +    test_llama_kv_cache_hybrid_constructor();  | 
 | 132 | +    return 0;  | 
 | 133 | +}  | 