From 93864cda8a024227dec297085f6152662738ea0d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stanis=C5=82aw=20Szymczyk?= Date: Wed, 22 Jan 2025 15:19:34 +0100 Subject: [PATCH 01/17] llama : experimental DeepSeek2 MLA implementation that caches latent kv representations --- src/llama-kv-cache.cpp | 16 ++++++- src/llama-kv-cache.h | 7 +++ src/llama.cpp | 99 +++++++++++++++++++++++++++++++++++------- 3 files changed, 106 insertions(+), 16 deletions(-) diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 90b6c56ed068c..99fd1d8df1c32 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -53,7 +53,7 @@ bool llama_kv_cache_init( auto it = ctx_map.find(buft); if (it == ctx_map.end()) { struct ggml_init_params params = { - /*.mem_size =*/ size_t(2u*n_layer*ggml_tensor_overhead()), + /*.mem_size =*/ size_t(4u*n_layer*ggml_tensor_overhead()), /*.mem_buffer =*/ NULL, /*.no_alloc =*/ true, }; @@ -71,6 +71,10 @@ bool llama_kv_cache_init( cache.k_l.reserve(n_layer); cache.v_l.reserve(n_layer); + // DeepSeek MLA + cache.kr_l.reserve(n_layer); + cache.kv_l.reserve(n_layer); + for (int i = 0; i < n_layer; i++) { const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(); const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(); @@ -97,6 +101,16 @@ bool llama_kv_cache_init( ggml_format_name(v, "cache_v_l%d", i); cache.k_l.push_back(k); cache.v_l.push_back(v); + + // DeepSeek MLA + const uint32_t n_embd_head_qk_rope = hparams.n_rot; + const uint32_t kv_lora_rank = hparams.n_lora_kv; + ggml_tensor * kr = ggml_new_tensor_1d(ctx, cache.type_kr, n_embd_head_qk_rope*kv_size); + ggml_tensor * kv = ggml_new_tensor_1d(ctx, cache.type_kv, kv_lora_rank*kv_size); + ggml_format_name(kr, "cache_kr_l%d", i); + ggml_format_name(kv, "cache_kv_l%d", i); + cache.kr_l.push_back(kr); + cache.kv_l.push_back(kv); } // allocate tensors and initialize the buffers to avoid NaNs in the padding diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index dca6f3998c645..7f2e1b3e7b144 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -49,11 +49,18 @@ struct llama_kv_cache { ggml_type type_k = GGML_TYPE_F16; ggml_type type_v = GGML_TYPE_F16; + ggml_type type_kr = GGML_TYPE_F32; + ggml_type type_kv = GGML_TYPE_F32; + std::vector cells; std::vector k_l; // per layer std::vector v_l; + // DeepSeek MLA + std::vector kr_l; // per layer + std::vector kv_l; + std::vector ctxs; std::vector bufs; diff --git a/src/llama.cpp b/src/llama.cpp index 60728e5bb91ca..99af190e1474b 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -8860,32 +8860,37 @@ struct llm_build_context { LLM_NORM_RMS, cb, il); cb(kv_compressed, "kv_compressed", il); + struct ggml_tensor * kv_cache_view = ggml_view_1d(ctx0, kv_self.kv_l[il], n_tokens*kv_lora_rank, ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank)*kv_head); + cb(kv_cache_view, "kv_cache_view", il); + + // note: storing c^KV in the KV cache + ggml_build_forward_expand(gf, ggml_cpy(ctx0, kv_compressed, kv_cache_view)); + + struct ggml_tensor * kv_cache = + ggml_view_2d(ctx0, kv_self.kv_l[il], + kv_lora_rank, n_kv, + ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank), + 0); + cb(kv_cache, "kv_cache", il); + // {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)} * {kv_lora_rank, n_tokens} -> {n_head * (n_embd_head_qk_nope + n_embd_head_v), n_tokens} - struct ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_compressed); + struct ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_cache); cb(kv, "kv", il); // split into {n_head * n_embd_head_qk_nope, n_tokens} - struct ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens, + struct ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_kv, ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v), ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)), 0); cb(k_nope, "k_nope", il); // and {n_head * n_embd_head_v, n_tokens} - struct ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens, + struct ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_kv, ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)), ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head), ggml_row_size(kv->type, (n_embd_head_qk_nope))); cb(v_states, "v_states", il); - v_states = ggml_cont(ctx0, v_states); - cb(v_states, "v_states", il); - - v_states = ggml_view_2d(ctx0, v_states, hparams.n_embd_head_v * n_head, n_tokens, - ggml_row_size(kv->type, hparams.n_embd_head_v * n_head), - 0); - cb(v_states, "v_states", il); - q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend does not support non-contiguous RoPE q_pe = ggml_rope_ext( ctx0, q_pe, inp_pos, nullptr, @@ -8903,15 +8908,61 @@ struct llm_build_context { ); cb(k_pe, "k_pe", il); + struct ggml_tensor * kr_cache_view = ggml_view_1d(ctx0, kv_self.kr_l[il], n_tokens*n_embd_head_qk_rope, ggml_row_size(kv_self.kr_l[il]->type, n_embd_head_qk_rope)*kv_head); + cb(kr_cache_view, "kr_cache_view", il); + + // note: storing RoPE-ed version of K^R in the KV cache + ggml_build_forward_expand(gf, ggml_cpy(ctx0, k_pe, kr_cache_view)); + struct ggml_tensor * q_states = ggml_concat(ctx0, q_nope, q_pe, 0); cb(q_states, "q_states", il); - struct ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0); + struct ggml_tensor * kr_cache = + ggml_view_2d(ctx0, kv_self.kr_l[il], + n_embd_head_qk_rope, n_kv, + ggml_row_size(kv_self.kr_l[il]->type, n_embd_head_qk_rope), + 0); + cb(kr_cache, "kr_cache", il); + + // TODO is there a better way? + struct ggml_tensor * kr_rep_shape = ggml_new_tensor_3d(ctx0, kr_cache->type, kr_cache->ne[0], kr_cache->ne[1], n_head); + struct ggml_tensor * kr_rep = ggml_repeat(ctx0, kr_cache, kr_rep_shape); + kr_rep = ggml_permute(ctx0, kr_rep, 0, 2, 1, 3); + struct ggml_tensor * k_states = ggml_concat(ctx0, k_nope, kr_rep, 0); cb(k_states, "k_states", il); - cur = llm_build_kv(ctx0, lctx, kv_self, gf, - model.layers[il].wo, NULL, - k_states, v_states, q_states, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il); + q_states = ggml_permute(ctx0, q_states, 0, 2, 1, 3); + cb(q_states, "q_states", il); + + k_states = ggml_permute(ctx0, k_states, 0, 2, 1, 3); + cb(k_states, "k_states", il); + + struct ggml_tensor * kq = ggml_mul_mat(ctx0, k_states, q_states); + cb(kq, "kq", il); + + kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias); + cb(kq, "kq_soft_max_ext", il); + + v_states = ggml_permute(ctx0, v_states, 1, 2, 0, 3); + cb(v_states, "v_states", il); + + v_states = ggml_cont(ctx0, v_states); + + struct ggml_tensor * kqv = ggml_mul_mat(ctx0, v_states, kq); + cb(kqv, "kqv", il); + + GGML_ASSERT(kv_self.size == n_ctx); + + struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3); + cb(kqv_merged, "kqv_merged", il); + + cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens); + cb(cur, "kqv_merged_cont", il); + + ggml_build_forward_expand(gf, cur); + + cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, cur); + cb(cur, "kqv_out", il); } if (il == n_layer - 1) { @@ -12004,6 +12055,24 @@ struct llama_context * llama_new_context_with_model( ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f)); } + { + size_t memory_size_kr = 0; + size_t memory_size_kv = 0; + + for (auto & kr : ctx->kv_self.kr_l) { + memory_size_kr += ggml_nbytes(kr); + } + + for (auto & kv : ctx->kv_self.kv_l) { + memory_size_kv += ggml_nbytes(kv); + } + + LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K^R (%s): %7.2f MiB, c^KV (%s): %7.2f MiB\n", __func__, + (float)(memory_size_kr + memory_size_kv) / (1024.0f * 1024.0f), + ggml_type_name(type_k), (float)memory_size_kr / (1024.0f * 1024.0f), + ggml_type_name(type_k), (float)memory_size_kv / (1024.0f * 1024.0f)); + } + // graph outputs buffer { // resized during inference when a batch uses more outputs From f07c2ec505f2ba93c3ec8246b258a9c97c7c1660 Mon Sep 17 00:00:00 2001 From: slaren Date: Fri, 24 Jan 2025 20:56:09 +0100 Subject: [PATCH 02/17] llama : add option to override tensor buffers --- common/arg.cpp | 38 ++++++++++++++++++++++++++++++++++++++ common/common.cpp | 10 ++++++++++ common/common.h | 1 + include/llama.h | 8 ++++++++ src/llama-model-loader.cpp | 5 ++++- src/llama-model-loader.h | 8 +++++--- src/llama-model.cpp | 21 +++++++++++++++++++-- src/llama-quant.cpp | 2 +- src/llama.cpp | 2 +- 9 files changed, 87 insertions(+), 8 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index a6226a34b1860..d746f832e541d 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1,5 +1,6 @@ #include "arg.h" +#include "common.h" #include "log.h" #include "sampling.h" @@ -321,6 +322,10 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context params.kv_overrides.back().key[0] = 0; } + if (!params.tensor_buft_overrides.empty()) { + params.tensor_buft_overrides.push_back({nullptr, nullptr}); + } + if (params.reranking && params.embedding) { throw std::invalid_argument("error: either --embedding or --reranking can be specified, but not both"); } @@ -1477,6 +1482,39 @@ common_params_context common_params_parser_init(common_params & params, llama_ex exit(0); } )); + add_opt(common_arg( + {"--override-tensor", "-ot"}, "=,...", + "override tensor buffer type", [](common_params & params, const std::string & value) { + static std::map buft_list; + if (buft_list.empty()) { + // enumerate all the devices and add their buffer types to the list + for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { + auto * dev = ggml_backend_dev_get(i); + auto * buft = ggml_backend_dev_buffer_type(dev); + buft_list[ggml_backend_buft_name(buft)] = buft; + } + } + + for (const auto & override : string_split(value, ',')) { + std::string::size_type pos = override.find('='); + if (pos == std::string::npos) { + throw std::invalid_argument("invalid value"); + } + std::string tensor_name = override.substr(0, pos); + std::string buffer_type = override.substr(pos + 1); + + if (buft_list.find(buffer_type) == buft_list.end()) { + printf("Available buffer types:\n"); + for (const auto & it : buft_list) { + printf(" %s\n", ggml_backend_buft_name(it.second)); + } + throw std::invalid_argument("unknown buffer type"); + } + // FIXME: this leaks memory + params.tensor_buft_overrides.push_back({strdup(tensor_name.c_str()), buft_list.at(buffer_type)}); + } + } + )); add_opt(common_arg( {"-ngl", "--gpu-layers", "--n-gpu-layers"}, "N", "number of layers to store in VRAM", diff --git a/common/common.cpp b/common/common.cpp index 6dea8e3d25238..1af628625ffe1 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1083,15 +1083,18 @@ struct llama_model_params common_model_params_to_llama(common_params & params) { if (!params.devices.empty()) { mparams.devices = params.devices.data(); } + if (params.n_gpu_layers != -1) { mparams.n_gpu_layers = params.n_gpu_layers; } + mparams.main_gpu = params.main_gpu; mparams.split_mode = params.split_mode; mparams.tensor_split = params.tensor_split; mparams.use_mmap = params.use_mmap; mparams.use_mlock = params.use_mlock; mparams.check_tensors = params.check_tensors; + if (params.kv_overrides.empty()) { mparams.kv_overrides = NULL; } else { @@ -1099,6 +1102,13 @@ struct llama_model_params common_model_params_to_llama(common_params & params) { mparams.kv_overrides = params.kv_overrides.data(); } + if (params.tensor_buft_overrides.empty()) { + mparams.tensor_buft_overrides = NULL; + } else { + GGML_ASSERT(params.tensor_buft_overrides.back().pattern == nullptr && "Tensor buffer overrides not terminated with empty pattern"); + mparams.tensor_buft_overrides = params.tensor_buft_overrides.data(); + } + return mparams; } diff --git a/common/common.h b/common/common.h index 571260372090f..9b42a8944d618 100644 --- a/common/common.h +++ b/common/common.h @@ -256,6 +256,7 @@ struct common_params { std::vector in_files; // all input files std::vector antiprompt; // strings upon which more user input is prompted (a.k.a. reverse prompts) std::vector kv_overrides; + std::vector tensor_buft_overrides; bool lora_init_without_apply = false; // only load lora to memory, but do not apply it to ctx (user can manually apply lora later using llama_adapter_lora_apply) std::vector lora_adapters; // lora adapter path with user defined scale diff --git a/include/llama.h b/include/llama.h index 3b75e760780ef..26c6dd12828c5 100644 --- a/include/llama.h +++ b/include/llama.h @@ -275,10 +275,18 @@ extern "C" { }; }; + struct llama_model_tensor_buft_override { + const char * pattern; + ggml_backend_buffer_type_t buft; + }; + struct llama_model_params { // NULL-terminated list of devices to use for offloading (if NULL, all available devices are used) ggml_backend_dev_t * devices; + // NULL-terminated list of buffer types to use for tensors that match a pattern + const struct llama_model_tensor_buft_override * tensor_buft_overrides; + int32_t n_gpu_layers; // number of layers to store in VRAM enum llama_split_mode split_mode; // how to split the model across multiple GPUs diff --git a/src/llama-model-loader.cpp b/src/llama-model-loader.cpp index 75073bf610ac3..c64e974a94f57 100644 --- a/src/llama-model-loader.cpp +++ b/src/llama-model-loader.cpp @@ -445,7 +445,8 @@ llama_model_loader::llama_model_loader( std::vector & splits, bool use_mmap, bool check_tensors, - const struct llama_model_kv_override * param_overrides_p) { + const llama_model_kv_override * param_overrides_p, + const llama_model_tensor_buft_override * param_tensor_buft_overrides_p) { int trace = 0; if (getenv("LLAMA_TRACE")) { trace = atoi(getenv("LLAMA_TRACE")); @@ -457,6 +458,8 @@ llama_model_loader::llama_model_loader( } } + tensor_buft_overrides = param_tensor_buft_overrides_p; + // Load the main GGUF struct ggml_context * ctx = NULL; struct gguf_init_params params = { diff --git a/src/llama-model-loader.h b/src/llama-model-loader.h index fe35404b26889..0f52b011b6986 100644 --- a/src/llama-model-loader.h +++ b/src/llama-model-loader.h @@ -77,8 +77,9 @@ struct llama_model_loader { llama_mmaps mappings; - std::map weights_map; - std::unordered_map kv_overrides; + std::map weights_map; + std::unordered_map kv_overrides; + const llama_model_tensor_buft_override * tensor_buft_overrides; gguf_context_ptr meta; std::vector contexts; @@ -95,7 +96,8 @@ struct llama_model_loader { std::vector & splits, // optional, only need if the split does not follow naming scheme bool use_mmap, bool check_tensors, - const struct llama_model_kv_override * param_overrides_p); + const llama_model_kv_override * param_overrides_p, + const llama_model_tensor_buft_override * param_tensor_buft_overrides_p); template typename std::enable_if::value, bool>::type diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 031b4c30b75dd..6b1653536f39e 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1444,9 +1444,25 @@ bool llama_model::load_tensors(llama_model_loader & ml) { GGML_ABORT("invalid layer %d for tensor %s", info.layer, tn.str().c_str()); } - ggml_backend_buffer_type_t buft = select_weight_buft(hparams, t_meta, op, *buft_list); + ggml_backend_buffer_type_t buft = nullptr; + + // check overrides + if (ml.tensor_buft_overrides) { + std::string tensor_name = tn.str(); + for (const auto * overrides = ml.tensor_buft_overrides; overrides->pattern != nullptr; ++overrides) { + if (tensor_name.find(overrides->pattern) != std::string::npos) { + LLAMA_LOG_DEBUG("tensor %s buffer type overriden to %s\n", tensor_name.c_str(), ggml_backend_buft_name(overrides->buft)); + buft = overrides->buft; + break; + } + } + } + if (!buft) { - throw std::runtime_error(format("failed to find a compatible buffer type for tensor %s", tn.str().c_str())); + buft = select_weight_buft(hparams, t_meta, op, *buft_list); + if (!buft) { + throw std::runtime_error(format("failed to find a compatible buffer type for tensor %s", tn.str().c_str())); + } } // avoid using a host buffer when using mmap @@ -3757,6 +3773,7 @@ const struct ggml_tensor * llama_model::get_tensor(const char * name) const { struct llama_model_params llama_model_default_params() { struct llama_model_params result = { /*.devices =*/ nullptr, + /*.tensor_buft_overrides =*/ nullptr, /*.n_gpu_layers =*/ 0, /*.split_mode =*/ LLAMA_SPLIT_MODE_LAYER, /*.main_gpu =*/ 0, diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp index fb7982655a373..ab50c5d179a29 100644 --- a/src/llama-quant.cpp +++ b/src/llama-quant.cpp @@ -527,7 +527,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: } std::vector splits = {}; - llama_model_loader ml(fname_inp, splits, use_mmap, /*check_tensors*/ true, kv_overrides); + llama_model_loader ml(fname_inp, splits, use_mmap, /*check_tensors*/ true, kv_overrides, nullptr); ml.init_mappings(false); // no prefetching llama_model model(llama_model_default_params()); diff --git a/src/llama.cpp b/src/llama.cpp index e8cfe5012819c..e2ca1d7b45c47 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -40,7 +40,7 @@ static int llama_model_load(const std::string & fname, std::vector model.t_start_us = tm.t_start_us; try { - llama_model_loader ml(fname, splits, params.use_mmap, params.check_tensors, params.kv_overrides); + llama_model_loader ml(fname, splits, params.use_mmap, params.check_tensors, params.kv_overrides, params.tensor_buft_overrides); ml.print_info(); From de538aa32929a10555097f01cad91639dfbe84ac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stanis=C5=82aw=20Szymczyk?= Date: Sat, 25 Jan 2025 18:10:22 +0100 Subject: [PATCH 03/17] llama : optimize DeepSeek MLA implementation --- convert_hf_to_gguf.py | 23 ++++++++++ gguf-py/gguf/constants.py | 6 +++ gguf-py/gguf/tensor_mapping.py | 8 ++++ src/llama-arch.cpp | 6 +++ src/llama-arch.h | 2 + src/llama-kv-cache.cpp | 1 + src/llama-kv-cache.h | 4 +- src/llama-model.cpp | 2 + src/llama-model.h | 2 + src/llama.cpp | 83 ++++++++++++++++++---------------- 10 files changed, 96 insertions(+), 41 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 63b54a9cf6b48..4df55e7b15b93 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -4136,6 +4136,29 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter else: return [] + if name.endswith("kv_b_proj.weight"): + name_kb = name.replace("kv_b_proj", "k_b_proj") + name_vb = name.replace("kv_b_proj", "v_b_proj") + + n_head_kv = self.hparams["num_key_value_heads"] + v_head_dim = self.hparams["v_head_dim"] + qk_nope_head_dim = self.hparams["qk_nope_head_dim"] + + assert data_torch.shape[0] == n_head_kv * (v_head_dim + qk_nope_head_dim) + + kv_b = data_torch.view(n_head_kv, v_head_dim + qk_nope_head_dim, data_torch.shape[-1]) + k_b, v_b = torch.split(kv_b, [qk_nope_head_dim, v_head_dim], dim=1) + k_b = k_b.transpose(1, 2); + k_b = k_b.reshape(n_head_kv * data_torch.shape[-1], qk_nope_head_dim) + v_b = v_b.reshape(n_head_kv * v_head_dim, data_torch.shape[-1]) + + return [ + (self.map_tensor_name(name), data_torch), + (self.map_tensor_name(name_kb), k_b), + (self.map_tensor_name(name_vb), v_b) + ] + + return [(self.map_tensor_name(name), data_torch)] def prepare_tensors(self): diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 8fe84df21ea20..12522928a8c28 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -356,6 +356,8 @@ class MODEL_TENSOR(IntEnum): ATTN_Q_B = auto() ATTN_KV_A_MQA = auto() ATTN_KV_B = auto() + ATTN_K_B = auto() + ATTN_V_B = auto() ATTN_Q_A_NORM = auto() ATTN_KV_A_NORM = auto() FFN_SUB_NORM = auto() @@ -543,6 +545,8 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.ATTN_Q_B: "blk.{bid}.attn_q_b", MODEL_TENSOR.ATTN_KV_A_MQA: "blk.{bid}.attn_kv_a_mqa", MODEL_TENSOR.ATTN_KV_B: "blk.{bid}.attn_kv_b", + MODEL_TENSOR.ATTN_K_B: "blk.{bid}.attn_k_b", + MODEL_TENSOR.ATTN_V_B: "blk.{bid}.attn_v_b", MODEL_TENSOR.ATTN_Q_A_NORM: "blk.{bid}.attn_q_a_norm", MODEL_TENSOR.ATTN_KV_A_NORM: "blk.{bid}.attn_kv_a_norm", MODEL_TENSOR.ATTN_SUB_NORM: "blk.{bid}.attn_sub_norm", @@ -1333,6 +1337,8 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.ATTN_Q_B, MODEL_TENSOR.ATTN_KV_A_MQA, MODEL_TENSOR.ATTN_KV_B, + MODEL_TENSOR.ATTN_K_B, + MODEL_TENSOR.ATTN_V_B, MODEL_TENSOR.ATTN_Q_A_NORM, MODEL_TENSOR.ATTN_KV_A_NORM, MODEL_TENSOR.ATTN_OUT, diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 617791e240b60..df831ba70594c 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -586,6 +586,14 @@ class TensorNameMap: "model.layers.{bid}.self_attn.kv_b_proj", # deepseek2 ), + MODEL_TENSOR.ATTN_K_B: ( + "model.layers.{bid}.self_attn.k_b_proj", # deepseek2 + ), + + MODEL_TENSOR.ATTN_V_B: ( + "model.layers.{bid}.self_attn.v_b_proj", # deepseek2 + ), + MODEL_TENSOR.ATTN_Q_A_NORM: ( "model.layers.{bid}.self_attn.q_a_layernorm", # deepseek2 ), diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index a7260f495d945..e6daa1bc4b5ce 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -999,6 +999,8 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_ATTN_Q_B, "blk.%d.attn_q_b" }, { LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" }, { LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" }, + { LLM_TENSOR_ATTN_K_B, "blk.%d.attn_k_b" }, + { LLM_TENSOR_ATTN_V_B, "blk.%d.attn_v_b" }, { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, @@ -1330,6 +1332,8 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_ATTN_KV_A_MQA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_ATTN_KV_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ATTN_K_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ATTN_V_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_DEC_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_DEC_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, @@ -1347,6 +1351,8 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_ATTN_KV_A_MQA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_ATTN_KV_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ATTN_K_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ATTN_V_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_DEC_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_DEC_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_DEC_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, diff --git a/src/llama-arch.h b/src/llama-arch.h index 122fdcebe0af6..c6105d59ac1f3 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -277,6 +277,8 @@ enum llm_tensor { LLM_TENSOR_ATTN_Q_B, LLM_TENSOR_ATTN_KV_A_MQA, LLM_TENSOR_ATTN_KV_B, + LLM_TENSOR_ATTN_K_B, + LLM_TENSOR_ATTN_V_B, LLM_TENSOR_ATTN_Q_A_NORM, LLM_TENSOR_ATTN_KV_A_NORM, LLM_TENSOR_ATTN_SUB_NORM, diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 8a836c784eca5..51e71437c1391 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -105,6 +105,7 @@ bool llama_kv_cache_init( // DeepSeek MLA const uint32_t n_embd_head_qk_rope = hparams.n_rot; const uint32_t kv_lora_rank = hparams.n_lora_kv; + LLAMA_LOG_DEBUG("%s: layer %d: n_embd_head_qk_rope = %d, kv_lora_rank = %d\n", __func__, i, n_embd_head_qk_rope, kv_lora_rank); ggml_tensor * kr = ggml_new_tensor_1d(ctx, cache.type_kr, n_embd_head_qk_rope*kv_size); ggml_tensor * kv = ggml_new_tensor_1d(ctx, cache.type_kv, kv_lora_rank*kv_size); ggml_format_name(kr, "cache_kr_l%d", i); diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index 7f2e1b3e7b144..a87344c849235 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -49,8 +49,8 @@ struct llama_kv_cache { ggml_type type_k = GGML_TYPE_F16; ggml_type type_v = GGML_TYPE_F16; - ggml_type type_kr = GGML_TYPE_F32; - ggml_type type_kv = GGML_TYPE_F32; + ggml_type type_kr = GGML_TYPE_F16; + ggml_type type_kv = GGML_TYPE_F16; std::vector cells; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 031b4c30b75dd..8007e730d04f8 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -2870,6 +2870,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + (n_embd_head_qk_rope)}, 0); layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)}, 0); + layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K_B, "weight", i), {n_embd_head_qk_nope, n_head * kv_lora_rank}, 0); + layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V_B, "weight", i), {kv_lora_rank, n_head * n_embd_head_v}, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_head * ( n_embd_head_v), n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); diff --git a/src/llama-model.h b/src/llama-model.h index a7c30444786fd..1fdbd3721d630 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -161,6 +161,8 @@ struct llama_layer { struct ggml_tensor * wq_b = nullptr; struct ggml_tensor * wkv_a_mqa = nullptr; struct ggml_tensor * wkv_b = nullptr; + struct ggml_tensor * wk_b = nullptr; + struct ggml_tensor * wv_b = nullptr; struct ggml_tensor * wq_cross = nullptr; struct ggml_tensor * wk_cross = nullptr; struct ggml_tensor * wv_cross = nullptr; diff --git a/src/llama.cpp b/src/llama.cpp index 5a9518a8e93e2..cb9fe8c9714f5 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -6483,24 +6483,6 @@ struct llm_build_context { 0); cb(kv_cache, "kv_cache", il); - // {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)} * {kv_lora_rank, n_tokens} -> {n_head * (n_embd_head_qk_nope + n_embd_head_v), n_tokens} - struct ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_cache); - cb(kv, "kv", il); - - // split into {n_head * n_embd_head_qk_nope, n_tokens} - struct ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_kv, - ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v), - ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)), - 0); - cb(k_nope, "k_nope", il); - - // and {n_head * n_embd_head_v, n_tokens} - struct ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_kv, - ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)), - ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head), - ggml_row_size(kv->type, (n_embd_head_qk_nope))); - cb(v_states, "v_states", il); - q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend does not support non-contiguous RoPE q_pe = ggml_rope_ext( ctx0, q_pe, inp_pos, nullptr, @@ -6524,9 +6506,6 @@ struct llm_build_context { // note: storing RoPE-ed version of K^R in the KV cache ggml_build_forward_expand(gf, ggml_cpy(ctx0, k_pe, kr_cache_view)); - struct ggml_tensor * q_states = ggml_concat(ctx0, q_nope, q_pe, 0); - cb(q_states, "q_states", il); - struct ggml_tensor * kr_cache = ggml_view_2d(ctx0, kv_self.kr_l[il], n_embd_head_qk_rope, n_kv, @@ -6534,36 +6513,62 @@ struct llm_build_context { 0); cb(kr_cache, "kr_cache", il); - // TODO is there a better way? - struct ggml_tensor * kr_rep_shape = ggml_new_tensor_3d(ctx0, kr_cache->type, kr_cache->ne[0], kr_cache->ne[1], n_head); - struct ggml_tensor * kr_rep = ggml_repeat(ctx0, kr_cache, kr_rep_shape); - kr_rep = ggml_permute(ctx0, kr_rep, 0, 2, 1, 3); - struct ggml_tensor * k_states = ggml_concat(ctx0, k_nope, kr_rep, 0); - cb(k_states, "k_states", il); + struct ggml_tensor * wk_b = ggml_view_3d(ctx0, model.layers[il].wk_b, n_embd_head_qk_nope, kv_lora_rank, n_head, ggml_row_size(model.layers[il].wk_b->type, n_embd_head_qk_nope), ggml_row_size(model.layers[il].wk_b->type, kv_lora_rank * n_embd_head_qk_nope), 0); + cb(wk_b, "wk_b", il); - q_states = ggml_permute(ctx0, q_states, 0, 2, 1, 3); - cb(q_states, "q_states", il); + struct ggml_tensor * q_nope_perm = ggml_permute(ctx0, q_nope, 0, 2, 3, 1); + cb(q_nope_perm, "q_nope_perm", il); - k_states = ggml_permute(ctx0, k_states, 0, 2, 1, 3); - cb(k_states, "k_states", il); + struct ggml_tensor * q_nope2 = ggml_mul_mat(ctx0, wk_b, q_nope_perm); + cb(q_nope2, "q_nope2", il); - struct ggml_tensor * kq = ggml_mul_mat(ctx0, k_states, q_states); - cb(kq, "kq", il); + struct ggml_tensor * q_nope2_perm = ggml_permute(ctx0, q_nope2, 0, 1, 3, 2); + cb(q_nope2_perm, "q_nope2_perm", il); + + struct ggml_tensor * kv_cache_perm = ggml_cont(ctx0, ggml_permute(ctx0, kv_cache, 1, 0, 2, 3)); + cb(kv_cache_perm, "kv_cache_perm", il); + + struct ggml_tensor * scores1 = ggml_mul_mat(ctx0, kv_cache, q_nope2_perm); + cb(scores1, "scores1", il); + + struct ggml_tensor * q_pe_perm = ggml_permute(ctx0, q_pe, 0, 3, 2, 1); + cb(q_pe_perm, "q_pe_perm", il); + + struct ggml_tensor * kr_cache_perm = ggml_permute(ctx0, kr_cache, 0, 2, 3, 1); + cb(kr_cache_perm, "kr_cache_perm", il); + + struct ggml_tensor * scores2 = ggml_mul_mat(ctx0, kr_cache, q_pe_perm); + cb(scores2, "scores2", il); + + struct ggml_tensor * scores = ggml_add(ctx0, scores1, scores2); + cb(scores, "scores", il); + + struct ggml_tensor * kq = ggml_permute(ctx0, scores, 0, 3, 1, 2); + + struct ggml_tensor * wv_b = ggml_view_3d(ctx0, model.layers[il].wv_b, kv_lora_rank, n_embd_head_v, n_head, ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank), ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank * n_embd_head_v), 0); + cb(wv_b, "wv_b", il); kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias); cb(kq, "kq_soft_max_ext", il); - v_states = ggml_permute(ctx0, v_states, 1, 2, 0, 3); - cb(v_states, "v_states", il); + struct ggml_tensor * kq_perm = ggml_permute(ctx0, kq, 0, 2, 3, 1); + cb(kq_perm, "kq_perm", il); - v_states = ggml_cont(ctx0, v_states); + struct ggml_tensor * kqv1 = ggml_mul_mat(ctx0, kv_cache_perm, kq_perm); + cb(kqv1, "kqv1", il); - struct ggml_tensor * kqv = ggml_mul_mat(ctx0, v_states, kq); - cb(kqv, "kqv", il); + struct ggml_tensor * kqv1_trans = ggml_permute(ctx0, kqv1, 0, 1, 3, 2); + cb(kqv1_trans, "kqv1_trans", il); + + struct ggml_tensor * kqv2 = ggml_mul_mat(ctx0, wv_b, kqv1_trans); + cb(kqv2, "kqv2", il); + + struct ggml_tensor * kqv2_trans = ggml_permute(ctx0, kqv2, 0, 3, 2, 1); + cb(kqv2_trans, "kqv2_trans", il); GGML_ASSERT(kv_self.size == n_ctx); - struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3); + struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv2_trans, 0, 2, 1, 3); cb(kqv_merged, "kqv_merged", il); cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens); From ce730637e8fe1b86de7d6e1758f33d716c6c7781 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stanis=C5=82aw=20Szymczyk?= Date: Sun, 26 Jan 2025 12:50:17 +0100 Subject: [PATCH 04/17] llama : Update tensor names in DeepSeek2 MLA implementation. --- src/llama.cpp | 46 +++++++++++++++++++++------------------------- 1 file changed, 21 insertions(+), 25 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index cb9fe8c9714f5..08b27b33add97 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -6525,11 +6525,8 @@ struct llm_build_context { struct ggml_tensor * q_nope2_perm = ggml_permute(ctx0, q_nope2, 0, 1, 3, 2); cb(q_nope2_perm, "q_nope2_perm", il); - struct ggml_tensor * kv_cache_perm = ggml_cont(ctx0, ggml_permute(ctx0, kv_cache, 1, 0, 2, 3)); - cb(kv_cache_perm, "kv_cache_perm", il); - - struct ggml_tensor * scores1 = ggml_mul_mat(ctx0, kv_cache, q_nope2_perm); - cb(scores1, "scores1", il); + struct ggml_tensor * kq_nope = ggml_mul_mat(ctx0, kv_cache, q_nope2_perm); + cb(kq_nope, "kq_nope", il); struct ggml_tensor * q_pe_perm = ggml_permute(ctx0, q_pe, 0, 3, 2, 1); cb(q_pe_perm, "q_pe_perm", il); @@ -6537,13 +6534,14 @@ struct llm_build_context { struct ggml_tensor * kr_cache_perm = ggml_permute(ctx0, kr_cache, 0, 2, 3, 1); cb(kr_cache_perm, "kr_cache_perm", il); - struct ggml_tensor * scores2 = ggml_mul_mat(ctx0, kr_cache, q_pe_perm); - cb(scores2, "scores2", il); + struct ggml_tensor * kq_pe = ggml_mul_mat(ctx0, kr_cache, q_pe_perm); + cb(kq_pe, "kq_pe", il); - struct ggml_tensor * scores = ggml_add(ctx0, scores1, scores2); - cb(scores, "scores", il); + struct ggml_tensor * kq = ggml_add(ctx0, kq_nope, kq_pe); + cb(kq, "kq", il); - struct ggml_tensor * kq = ggml_permute(ctx0, scores, 0, 3, 1, 2); + kq = ggml_permute(ctx0, kq, 0, 3, 1, 2); + cb(kq, "kq_perm", il); struct ggml_tensor * wv_b = ggml_view_3d(ctx0, model.layers[il].wv_b, kv_lora_rank, n_embd_head_v, n_head, ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank), ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank * n_embd_head_v), 0); cb(wv_b, "wv_b", il); @@ -6552,27 +6550,25 @@ struct llm_build_context { cb(kq, "kq_soft_max_ext", il); struct ggml_tensor * kq_perm = ggml_permute(ctx0, kq, 0, 2, 3, 1); - cb(kq_perm, "kq_perm", il); - - struct ggml_tensor * kqv1 = ggml_mul_mat(ctx0, kv_cache_perm, kq_perm); - cb(kqv1, "kqv1", il); + cb(kq_perm, "kq_soft_max_ext_perm", il); - struct ggml_tensor * kqv1_trans = ggml_permute(ctx0, kqv1, 0, 1, 3, 2); - cb(kqv1_trans, "kqv1_trans", il); + struct ggml_tensor * kv_cache_trans = ggml_cont(ctx0, ggml_transpose(ctx0, kv_cache)); + cb(kv_cache_trans, "kv_cache_trans", il); - struct ggml_tensor * kqv2 = ggml_mul_mat(ctx0, wv_b, kqv1_trans); - cb(kqv2, "kqv2", il); + struct ggml_tensor * kqv_compressed = ggml_mul_mat(ctx0, kv_cache_trans, kq_perm); + cb(kqv_compressed, "kqv_compressed", il); - struct ggml_tensor * kqv2_trans = ggml_permute(ctx0, kqv2, 0, 3, 2, 1); - cb(kqv2_trans, "kqv2_trans", il); + kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 1, 3, 2); + cb(kqv_compressed, "kqv_compressed_perm", il); - GGML_ASSERT(kv_self.size == n_ctx); + struct ggml_tensor * kqv = ggml_mul_mat(ctx0, wv_b, kqv_compressed); + cb(kqv, "kqv", il); - struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv2_trans, 0, 2, 1, 3); - cb(kqv_merged, "kqv_merged", il); + kqv = ggml_permute(ctx0, kqv, 0, 3, 1, 2); + cb(kqv, "kqv_perm", il); - cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens); - cb(cur, "kqv_merged_cont", il); + cur = ggml_view_2d(ctx0, kqv, n_embd_head_v*n_head, n_tokens, ggml_row_size(kqv->type, n_embd_head_v*n_head), 0); + cb(cur, "kqv_2d", il); ggml_build_forward_expand(gf, cur); From 202f323e66809bb1df192245caddc49471660466 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stanis=C5=82aw=20Szymczyk?= Date: Sun, 26 Jan 2025 18:29:54 +0100 Subject: [PATCH 05/17] llama : add a second copy of c^KV cache in DeepSeek2 MLA to avoid transposing the cache during inference --- src/llama-kv-cache.cpp | 6 +++++- src/llama-kv-cache.h | 1 + src/llama.cpp | 16 +++++++++++++--- 3 files changed, 19 insertions(+), 4 deletions(-) diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 51e71437c1391..57ccbeeae7e26 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -53,7 +53,7 @@ bool llama_kv_cache_init( auto it = ctx_map.find(buft); if (it == ctx_map.end()) { struct ggml_init_params params = { - /*.mem_size =*/ size_t(4u*n_layer*ggml_tensor_overhead()), + /*.mem_size =*/ size_t(5u*n_layer*ggml_tensor_overhead()), /*.mem_buffer =*/ NULL, /*.no_alloc =*/ true, }; @@ -74,6 +74,7 @@ bool llama_kv_cache_init( // DeepSeek MLA cache.kr_l.reserve(n_layer); cache.kv_l.reserve(n_layer); + cache.kvt_l.reserve(n_layer); for (int i = 0; i < n_layer; i++) { const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(); @@ -108,10 +109,13 @@ bool llama_kv_cache_init( LLAMA_LOG_DEBUG("%s: layer %d: n_embd_head_qk_rope = %d, kv_lora_rank = %d\n", __func__, i, n_embd_head_qk_rope, kv_lora_rank); ggml_tensor * kr = ggml_new_tensor_1d(ctx, cache.type_kr, n_embd_head_qk_rope*kv_size); ggml_tensor * kv = ggml_new_tensor_1d(ctx, cache.type_kv, kv_lora_rank*kv_size); + ggml_tensor * kvt = ggml_new_tensor_1d(ctx, cache.type_kv, kv_lora_rank*kv_size); ggml_format_name(kr, "cache_kr_l%d", i); ggml_format_name(kv, "cache_kv_l%d", i); + ggml_format_name(kvt, "cache_kvt_l%d", i); cache.kr_l.push_back(kr); cache.kv_l.push_back(kv); + cache.kvt_l.push_back(kvt); } // allocate tensors and initialize the buffers to avoid NaNs in the padding diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index a87344c849235..b10540d76442e 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -60,6 +60,7 @@ struct llama_kv_cache { // DeepSeek MLA std::vector kr_l; // per layer std::vector kv_l; + std::vector kvt_l; std::vector ctxs; std::vector bufs; diff --git a/src/llama.cpp b/src/llama.cpp index 08b27b33add97..d9fe40102b346 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -6476,6 +6476,12 @@ struct llm_build_context { // note: storing c^KV in the KV cache ggml_build_forward_expand(gf, ggml_cpy(ctx0, kv_compressed, kv_cache_view)); + struct ggml_tensor * kv_cache_trans_view = ggml_view_2d(ctx0, kv_self.kvt_l[il], n_tokens, kv_lora_rank, ggml_row_size(kv_self.kv_l[il]->type, kv_self.size), ggml_row_size(kv_self.kv_l[il]->type, kv_head)); + cb(kv_cache_trans_view, "kv_cache_trans_view", il); + + // note: storing transposed c^KV in the transposed KV cache + ggml_build_forward_expand(gf, ggml_cpy(ctx0, ggml_transpose(ctx0, kv_compressed), kv_cache_trans_view)); + struct ggml_tensor * kv_cache = ggml_view_2d(ctx0, kv_self.kv_l[il], kv_lora_rank, n_kv, @@ -6483,6 +6489,13 @@ struct llm_build_context { 0); cb(kv_cache, "kv_cache", il); + struct ggml_tensor * kv_cache_trans = + ggml_view_2d(ctx0, kv_self.kvt_l[il], + n_kv, kv_lora_rank, + ggml_row_size(kv_self.kv_l[il]->type, kv_self.size), + 0); + cb(kv_cache_trans, "kv_cache_trans", il); + q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend does not support non-contiguous RoPE q_pe = ggml_rope_ext( ctx0, q_pe, inp_pos, nullptr, @@ -6552,9 +6565,6 @@ struct llm_build_context { struct ggml_tensor * kq_perm = ggml_permute(ctx0, kq, 0, 2, 3, 1); cb(kq_perm, "kq_soft_max_ext_perm", il); - struct ggml_tensor * kv_cache_trans = ggml_cont(ctx0, ggml_transpose(ctx0, kv_cache)); - cb(kv_cache_trans, "kv_cache_trans", il); - struct ggml_tensor * kqv_compressed = ggml_mul_mat(ctx0, kv_cache_trans, kq_perm); cb(kqv_compressed, "kqv_compressed", il); From 93c5937249c313bf825d020f4a5213e32c94737c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stanis=C5=82aw=20Szymczyk?= Date: Sun, 26 Jan 2025 22:23:13 +0100 Subject: [PATCH 06/17] llama : modified tensor permutations to multiply larger matrices during inference --- src/llama.cpp | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index d9fe40102b346..3df9896922254 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -6529,13 +6529,13 @@ struct llm_build_context { struct ggml_tensor * wk_b = ggml_view_3d(ctx0, model.layers[il].wk_b, n_embd_head_qk_nope, kv_lora_rank, n_head, ggml_row_size(model.layers[il].wk_b->type, n_embd_head_qk_nope), ggml_row_size(model.layers[il].wk_b->type, kv_lora_rank * n_embd_head_qk_nope), 0); cb(wk_b, "wk_b", il); - struct ggml_tensor * q_nope_perm = ggml_permute(ctx0, q_nope, 0, 2, 3, 1); + struct ggml_tensor * q_nope_perm = ggml_permute(ctx0, q_nope, 0, 2, 1, 3); cb(q_nope_perm, "q_nope_perm", il); struct ggml_tensor * q_nope2 = ggml_mul_mat(ctx0, wk_b, q_nope_perm); cb(q_nope2, "q_nope2", il); - struct ggml_tensor * q_nope2_perm = ggml_permute(ctx0, q_nope2, 0, 1, 3, 2); + struct ggml_tensor * q_nope2_perm = ggml_permute(ctx0, q_nope2, 0, 2, 1, 3); cb(q_nope2_perm, "q_nope2_perm", il); struct ggml_tensor * kq_nope = ggml_mul_mat(ctx0, kv_cache, q_nope2_perm); @@ -6547,34 +6547,34 @@ struct llm_build_context { struct ggml_tensor * kr_cache_perm = ggml_permute(ctx0, kr_cache, 0, 2, 3, 1); cb(kr_cache_perm, "kr_cache_perm", il); - struct ggml_tensor * kq_pe = ggml_mul_mat(ctx0, kr_cache, q_pe_perm); + struct ggml_tensor * kq_pe = ggml_mul_mat(ctx0, kr_cache, q_pe); cb(kq_pe, "kq_pe", il); struct ggml_tensor * kq = ggml_add(ctx0, kq_nope, kq_pe); cb(kq, "kq", il); - kq = ggml_permute(ctx0, kq, 0, 3, 1, 2); + kq = ggml_cont(ctx0, ggml_permute(ctx0, kq, 0, 2, 1, 3)); cb(kq, "kq_perm", il); - struct ggml_tensor * wv_b = ggml_view_3d(ctx0, model.layers[il].wv_b, kv_lora_rank, n_embd_head_v, n_head, ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank), ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank * n_embd_head_v), 0); - cb(wv_b, "wv_b", il); - kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias); cb(kq, "kq_soft_max_ext", il); - struct ggml_tensor * kq_perm = ggml_permute(ctx0, kq, 0, 2, 3, 1); + struct ggml_tensor * kq_perm = ggml_permute(ctx0, kq, 0, 2, 1, 3); cb(kq_perm, "kq_soft_max_ext_perm", il); struct ggml_tensor * kqv_compressed = ggml_mul_mat(ctx0, kv_cache_trans, kq_perm); cb(kqv_compressed, "kqv_compressed", il); - kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 1, 3, 2); + kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3); cb(kqv_compressed, "kqv_compressed_perm", il); + struct ggml_tensor * wv_b = ggml_view_3d(ctx0, model.layers[il].wv_b, kv_lora_rank, n_embd_head_v, n_head, ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank), ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank * n_embd_head_v), 0); + cb(wv_b, "wv_b", il); + struct ggml_tensor * kqv = ggml_mul_mat(ctx0, wv_b, kqv_compressed); cb(kqv, "kqv", il); - kqv = ggml_permute(ctx0, kqv, 0, 3, 1, 2); + kqv = ggml_cont(ctx0, ggml_permute(ctx0, kqv, 0, 2, 1, 3)); cb(kqv, "kqv_perm", il); cur = ggml_view_2d(ctx0, kqv, n_embd_head_v*n_head, n_tokens, ggml_row_size(kqv->type, n_embd_head_v*n_head), 0); From 1eee98f01fca721a889defac3d38e9ada7abb617 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stanis=C5=82aw=20Szymczyk?= Date: Mon, 27 Jan 2025 09:32:25 +0100 Subject: [PATCH 07/17] llama : removed unnecessary code in DeepSeek V2 implementation --- src/llama.cpp | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index 3df9896922254..a4c78240b265e 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -6544,9 +6544,6 @@ struct llm_build_context { struct ggml_tensor * q_pe_perm = ggml_permute(ctx0, q_pe, 0, 3, 2, 1); cb(q_pe_perm, "q_pe_perm", il); - struct ggml_tensor * kr_cache_perm = ggml_permute(ctx0, kr_cache, 0, 2, 3, 1); - cb(kr_cache_perm, "kr_cache_perm", il); - struct ggml_tensor * kq_pe = ggml_mul_mat(ctx0, kr_cache, q_pe); cb(kq_pe, "kq_pe", il); From 8ff0991eed65e4041e6e3dfa2e3c98aee7fa2c21 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stanis=C5=82aw=20Szymczyk?= Date: Tue, 28 Jan 2025 11:02:52 +0100 Subject: [PATCH 08/17] convert : make lint happy --- convert_hf_to_gguf.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 4df55e7b15b93..2be7de5a59bbe 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -4148,7 +4148,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter kv_b = data_torch.view(n_head_kv, v_head_dim + qk_nope_head_dim, data_torch.shape[-1]) k_b, v_b = torch.split(kv_b, [qk_nope_head_dim, v_head_dim], dim=1) - k_b = k_b.transpose(1, 2); + k_b = k_b.transpose(1, 2) k_b = k_b.reshape(n_head_kv * data_torch.shape[-1], qk_nope_head_dim) v_b = v_b.reshape(n_head_kv * v_head_dim, data_torch.shape[-1]) @@ -4158,7 +4158,6 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter (self.map_tensor_name(name_vb), v_b) ] - return [(self.map_tensor_name(name), data_torch)] def prepare_tensors(self): From 8a887decd35083c1542534cfadc3a5ee592da964 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stanis=C5=82aw=20Szymczyk?= Date: Tue, 28 Jan 2025 19:26:54 +0100 Subject: [PATCH 09/17] llama : prompt processing optimizations in DeepSeek V2 --- src/llama.cpp | 38 ++++++++++++++++++++++++++------------ 1 file changed, 26 insertions(+), 12 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index a4c78240b265e..5768f9215fea9 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -6403,6 +6403,10 @@ struct llm_build_context { // KQ_mask (mask for 1 head, it will be broadcasted to all heads) struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); + // whether to use n_tokens as the matrix dimension during multiplication or n_head + // n_tokens is higher during prompt processing, this allows to optimize for this case + bool pp_opt = n_tokens > n_head; + for (int il = 0; il < n_layer; ++il) { struct ggml_tensor * inpSA = inpL; @@ -6535,14 +6539,18 @@ struct llm_build_context { struct ggml_tensor * q_nope2 = ggml_mul_mat(ctx0, wk_b, q_nope_perm); cb(q_nope2, "q_nope2", il); - struct ggml_tensor * q_nope2_perm = ggml_permute(ctx0, q_nope2, 0, 2, 1, 3); - cb(q_nope2_perm, "q_nope2_perm", il); + if (!pp_opt) { + q_nope2 = ggml_permute(ctx0, q_nope2, 0, 2, 1, 3); + cb(q_nope2, "q_nope2_perm", il); + } - struct ggml_tensor * kq_nope = ggml_mul_mat(ctx0, kv_cache, q_nope2_perm); + struct ggml_tensor * kq_nope = ggml_mul_mat(ctx0, kv_cache, q_nope2); cb(kq_nope, "kq_nope", il); - struct ggml_tensor * q_pe_perm = ggml_permute(ctx0, q_pe, 0, 3, 2, 1); - cb(q_pe_perm, "q_pe_perm", il); + if (pp_opt) { + q_pe = ggml_permute(ctx0, q_pe, 0, 2, 1, 3); + cb(q_pe, "q_pe_perm", il); + } struct ggml_tensor * kq_pe = ggml_mul_mat(ctx0, kr_cache, q_pe); cb(kq_pe, "kq_pe", il); @@ -6550,20 +6558,26 @@ struct llm_build_context { struct ggml_tensor * kq = ggml_add(ctx0, kq_nope, kq_pe); cb(kq, "kq", il); - kq = ggml_cont(ctx0, ggml_permute(ctx0, kq, 0, 2, 1, 3)); - cb(kq, "kq_perm", il); + if (!pp_opt) { + kq = ggml_cont(ctx0, ggml_permute(ctx0, kq, 0, 2, 1, 3)); + cb(kq, "kq_perm", il); + } kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias); cb(kq, "kq_soft_max_ext", il); - struct ggml_tensor * kq_perm = ggml_permute(ctx0, kq, 0, 2, 1, 3); - cb(kq_perm, "kq_soft_max_ext_perm", il); + if (!pp_opt) { + kq = ggml_permute(ctx0, kq, 0, 2, 1, 3); + cb(kq, "kq_soft_max_ext_perm", il); + } - struct ggml_tensor * kqv_compressed = ggml_mul_mat(ctx0, kv_cache_trans, kq_perm); + struct ggml_tensor * kqv_compressed = ggml_mul_mat(ctx0, kv_cache_trans, kq); cb(kqv_compressed, "kqv_compressed", il); - kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3); - cb(kqv_compressed, "kqv_compressed_perm", il); + if (!pp_opt) { + kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3); + cb(kqv_compressed, "kqv_compressed_perm", il); + } struct ggml_tensor * wv_b = ggml_view_3d(ctx0, model.layers[il].wv_b, kv_lora_rank, n_embd_head_v, n_head, ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank), ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank * n_embd_head_v), 0); cb(wv_b, "wv_b", il); From 76543311acc85e1d77575728000f1979faa7591f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stanis=C5=82aw=20Szymczyk?= Date: Thu, 30 Jan 2025 18:25:36 +0100 Subject: [PATCH 10/17] llama : avoid ggml_cont() is possible in DeepSeek V2 implementation --- src/llama.cpp | 29 ++++++++++++++++++----------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index 5768f9215fea9..1a3d1d0bda9d2 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -6533,10 +6533,10 @@ struct llm_build_context { struct ggml_tensor * wk_b = ggml_view_3d(ctx0, model.layers[il].wk_b, n_embd_head_qk_nope, kv_lora_rank, n_head, ggml_row_size(model.layers[il].wk_b->type, n_embd_head_qk_nope), ggml_row_size(model.layers[il].wk_b->type, kv_lora_rank * n_embd_head_qk_nope), 0); cb(wk_b, "wk_b", il); - struct ggml_tensor * q_nope_perm = ggml_permute(ctx0, q_nope, 0, 2, 1, 3); - cb(q_nope_perm, "q_nope_perm", il); + q_nope = ggml_permute(ctx0, q_nope, 0, 2, 1, 3); + cb(q_nope, "q_nope_perm", il); - struct ggml_tensor * q_nope2 = ggml_mul_mat(ctx0, wk_b, q_nope_perm); + struct ggml_tensor * q_nope2 = ggml_mul_mat(ctx0, wk_b, q_nope); cb(q_nope2, "q_nope2", il); if (!pp_opt) { @@ -6547,6 +6547,11 @@ struct llm_build_context { struct ggml_tensor * kq_nope = ggml_mul_mat(ctx0, kv_cache, q_nope2); cb(kq_nope, "kq_nope", il); + if (!pp_opt) { + kq_nope = ggml_permute(ctx0, kq_nope, 0, 2, 1, 3); + cb(kq_nope, "kq_nope_perm", il); + } + if (pp_opt) { q_pe = ggml_permute(ctx0, q_pe, 0, 2, 1, 3); cb(q_pe, "q_pe_perm", il); @@ -6555,14 +6560,14 @@ struct llm_build_context { struct ggml_tensor * kq_pe = ggml_mul_mat(ctx0, kr_cache, q_pe); cb(kq_pe, "kq_pe", il); - struct ggml_tensor * kq = ggml_add(ctx0, kq_nope, kq_pe); - cb(kq, "kq", il); - if (!pp_opt) { - kq = ggml_cont(ctx0, ggml_permute(ctx0, kq, 0, 2, 1, 3)); - cb(kq, "kq_perm", il); + kq_pe = ggml_permute(ctx0, kq_pe, 0, 2, 1, 3); + cb(kq_pe, "kq_pe_perm", il); } + struct ggml_tensor * kq = ggml_add(ctx0, kq_nope, kq_pe); + cb(kq, "kq", il); + kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias); cb(kq, "kq_soft_max_ext", il); @@ -6575,7 +6580,7 @@ struct llm_build_context { cb(kqv_compressed, "kqv_compressed", il); if (!pp_opt) { - kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3); + kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 3, 1); cb(kqv_compressed, "kqv_compressed_perm", il); } @@ -6585,8 +6590,10 @@ struct llm_build_context { struct ggml_tensor * kqv = ggml_mul_mat(ctx0, wv_b, kqv_compressed); cb(kqv, "kqv", il); - kqv = ggml_cont(ctx0, ggml_permute(ctx0, kqv, 0, 2, 1, 3)); - cb(kqv, "kqv_perm", il); + if (pp_opt) { + kqv = ggml_cont(ctx0, ggml_permute(ctx0, kqv, 0, 2, 1, 3)); + cb(kqv, "kqv_perm", il); + } cur = ggml_view_2d(ctx0, kqv, n_embd_head_v*n_head, n_tokens, ggml_row_size(kqv->type, n_embd_head_v*n_head), 0); cb(cur, "kqv_2d", il); From 83a473a00133fe9ba66fec54cea3cae8df275ca4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stanis=C5=82aw=20Szymczyk?= Date: Sat, 1 Feb 2025 10:32:06 +0100 Subject: [PATCH 11/17] llama : use all experts during warmup --- src/llama.cpp | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index 192b20a27e5ca..a8258becdfeb4 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -1092,7 +1092,8 @@ struct llm_build_context { llama_context & lctx, const llama_ubatch & ubatch, const llm_build_cb & cb, - bool worst_case) : + bool worst_case, + bool warmup) : model (lctx.model), lctx (lctx), hparams (model.hparams), @@ -1110,7 +1111,7 @@ struct llm_build_context { n_embd_head_v (hparams.n_embd_head_v), n_embd_v_gqa (hparams.n_embd_v_gqa()), n_expert (hparams.n_expert), - n_expert_used (hparams.n_expert_used), + n_expert_used (warmup ? hparams.n_expert : hparams.n_expert_used), freq_base (cparams.rope_freq_base), freq_scale (cparams.rope_freq_scale), ext_factor (cparams.yarn_ext_factor), @@ -8103,7 +8104,7 @@ static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const llm_build_cb cb = [&](struct ggml_tensor * , const char * , int ) { }; - struct llm_build_context llm(lctx, dummy, cb, false); + struct llm_build_context llm(lctx, dummy, cb, false, false); llm.init(); @@ -8120,7 +8121,7 @@ static struct ggml_cgraph * llama_build_graph_k_shift(llama_context & lctx) { llm_build_cb cb = [&](struct ggml_tensor * , const char * , int ) { }; - struct llm_build_context llm(lctx, dummy, cb, false); + struct llm_build_context llm(lctx, dummy, cb, false, false); llm.init(); @@ -8171,7 +8172,11 @@ static struct ggml_cgraph * llama_build_graph( struct ggml_cgraph * result = NULL; - struct llm_build_context llm(lctx, ubatch, cb, worst_case); + const llama_vocab * vocab = llama_model_get_vocab(&model); + llama_token bos = llama_vocab_bos(vocab); + llama_token eos = llama_vocab_eos(vocab); + bool is_warming_up = (ubatch.n_tokens == 2 && ubatch.token[0] == bos && ubatch.token[1] == eos); + struct llm_build_context llm(lctx, ubatch, cb, worst_case, is_warming_up); llm.init(); From c8bc6e4ff4b9f1cb1e94eb56ddd10a95bd0108da Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stanis=C5=82aw=20Szymczyk?= Date: Sat, 1 Feb 2025 12:43:14 +0100 Subject: [PATCH 12/17] llama : increased max_nodes as large MoE models use massive amounts of nodes during warmup --- src/llama-model.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 18bd0b071bb90..c958edb873a03 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -3552,7 +3552,7 @@ size_t llama_model::size() const { } size_t llama_model::max_nodes() const { - return std::max(8192, tensors_by_name.size()*5); + return std::max(65536, tensors_by_name.size()*5); } size_t llama_model::n_devices() const { From 6c8d01a8bbe0d64491608089027c26ac85cce262 Mon Sep 17 00:00:00 2001 From: slaren Date: Sun, 2 Feb 2025 17:23:32 +0100 Subject: [PATCH 13/17] add regex support --- src/llama-model.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index f1cba4f39a676..f134d1bf1ef2a 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include #include @@ -1464,7 +1465,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { if (ml.tensor_buft_overrides) { std::string tensor_name = tn.str(); for (const auto * overrides = ml.tensor_buft_overrides; overrides->pattern != nullptr; ++overrides) { - if (tensor_name.find(overrides->pattern) != std::string::npos) { + std::regex pattern(overrides->pattern); + if (std::regex_search(tensor_name, pattern)) { LLAMA_LOG_DEBUG("tensor %s buffer type overriden to %s\n", tensor_name.c_str(), ggml_backend_buft_name(overrides->buft)); buft = overrides->buft; break; From 538f60934abd36f19598d74518cdef0ccd18a023 Mon Sep 17 00:00:00 2001 From: slaren Date: Thu, 6 Feb 2025 01:32:04 +0100 Subject: [PATCH 14/17] ggml : fix possible underflow in ggml_nbytes --- ggml/src/ggml.c | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 3b48615421187..52c553e76b29f 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -1151,6 +1151,12 @@ int64_t ggml_nrows(const struct ggml_tensor * tensor) { } size_t ggml_nbytes(const struct ggml_tensor * tensor) { + for (int i = 0; i < GGML_MAX_DIMS; ++i) { + if (tensor->ne[i] <= 0) { + return 0; + } + } + size_t nbytes; const size_t blck_size = ggml_blck_size(tensor->type); if (blck_size == 1) { From 8770ffa60c0d0eac481f199f2da1bb6b622a8207 Mon Sep 17 00:00:00 2001 From: slaren Date: Sun, 9 Feb 2025 00:32:52 +0100 Subject: [PATCH 15/17] rebuild buft list on every call --- common/arg.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index 5a98c4baf3a83..e796d0e85f946 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1485,13 +1485,15 @@ common_params_context common_params_parser_init(common_params & params, llama_ex add_opt(common_arg( {"--override-tensor", "-ot"}, "=,...", "override tensor buffer type", [](common_params & params, const std::string & value) { - static std::map buft_list; + /* static */ std::map buft_list; if (buft_list.empty()) { // enumerate all the devices and add their buffer types to the list for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { auto * dev = ggml_backend_dev_get(i); auto * buft = ggml_backend_dev_buffer_type(dev); - buft_list[ggml_backend_buft_name(buft)] = buft; + if (buft) { + buft_list[ggml_backend_buft_name(buft)] = buft; + } } } From 0d4ff95b8270b481c4131795925a0b7abdc657bd Mon Sep 17 00:00:00 2001 From: Orca Date: Tue, 25 Feb 2025 20:41:08 +0800 Subject: [PATCH 16/17] can shift --- examples/server/server.cpp | 29 ++++++++++++++++++----------- src/llama-kv-cache.cpp | 2 +- 2 files changed, 19 insertions(+), 12 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 2306dc26fe431..c4db6642e9ef6 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1671,9 +1671,8 @@ struct server_response { } void add_waiting_tasks(const std::vector & tasks) { - std::unique_lock lock(mutex_results); - for (const auto & task : tasks) { + std::unique_lock lock(mutex_results); SRV_DBG("add task %d to waiting list. current waiting = %d (before add)\n", task.id, (int) waiting_task_ids.size()); waiting_task_ids.insert(task.id); } @@ -1683,20 +1682,24 @@ struct server_response { void remove_waiting_task_id(int id_task) { SRV_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task, (int) waiting_task_ids.size()); - std::unique_lock lock(mutex_results); - waiting_task_ids.erase(id_task); + { + std::unique_lock lock(mutex_results); + waiting_task_ids.erase(id_task); + } // make sure to clean up all pending results - queue_results.erase( - std::remove_if(queue_results.begin(), queue_results.end(), [id_task](const server_task_result_ptr & res) { - return res->id == id_task; - }), - queue_results.end()); + { + std::unique_lock lock(mutex_results); + queue_results.erase( + std::remove_if(queue_results.begin(), queue_results.end(), [id_task](const server_task_result_ptr & res) { + return res->id == id_task; + }), + queue_results.end()); + } } void remove_waiting_task_ids(const std::unordered_set & id_tasks) { - std::unique_lock lock(mutex_results); - for (const auto & id_task : id_tasks) { + std::unique_lock lock(mutex_results); SRV_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task, (int) waiting_task_ids.size()); waiting_task_ids.erase(id_task); } @@ -3841,6 +3844,10 @@ int main(int argc, char ** argv) { // TODO: this log can become very long, put it behind a flag or think about a more compact format //SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get().c_str() : prompt.dump(2).c_str()); + if (prompt.contains("chat_history")) { + return; + } + std::vector tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true); tasks.reserve(tokenized_prompts.size()); for (size_t i = 0; i < tokenized_prompts.size(); i++) { diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index feffdf0de52cf..b5fbb3a25f0b6 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -32,7 +32,7 @@ bool llama_kv_cache_init( cache.recurrent = llama_model_is_recurrent(&model); cache.v_trans = !cache.recurrent && !cparams.flash_attn; - cache.can_shift = !cache.recurrent && model.arch != LLM_ARCH_DEEPSEEK2; // not supported due to MLA + cache.can_shift = !cache.recurrent; // not supported due to MLA LLAMA_LOG_INFO("%s: kv_size = %d, offload = %d, type_k = '%s', type_v = '%s', n_layer = %d, can_shift = %d\n", __func__, kv_size, offload, ggml_type_name(type_k), ggml_type_name(type_v), n_layer, cache.can_shift); From d256aa04cecdebe286bee5036c60cff852144e34 Mon Sep 17 00:00:00 2001 From: Orca Date: Tue, 25 Feb 2025 20:43:49 +0800 Subject: [PATCH 17/17] tmp --- ggml/src/ggml-cuda/ggml-cuda.cu | 2 +- src/llama-kv-cache.cpp | 4 ++-- src/llama-quant.cpp | 2 +- src/llama.cpp | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index ebb2ccae04065..5725fc375cc89 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3248,7 +3248,7 @@ static int64_t get_op_batch_size(const ggml_tensor * op) { } static bool ggml_backend_cuda_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) { - const int min_batch_size = 32; + const int min_batch_size = 9999999; return get_op_batch_size(op) >= min_batch_size; diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 7ded592c289e1..4d43c692ea4bd 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -96,8 +96,8 @@ bool llama_kv_cache_init( return false; } - ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size); - ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size); + ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, 1); + ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, 1); ggml_format_name(k, "cache_k_l%d", i); ggml_format_name(v, "cache_v_l%d", i); cache.k_l.push_back(k); diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp index ab50c5d179a29..cabb5f8f8cdd5 100644 --- a/src/llama-quant.cpp +++ b/src/llama-quant.cpp @@ -776,7 +776,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: // get more optimal quantization type based on the tensor shape, layer, etc. if (!params->pure && ggml_is_quantized(default_type)) { - new_type = llama_tensor_get_type(qs, new_type, tensor, ftype); + new_type = name.find("_exps") != std::string::npos ? name.find("ffn_down") != std::string::npos ? GGML_TYPE_Q6_K : GGML_TYPE_Q5_K : GGML_TYPE_BF16; } if (params->token_embedding_type < GGML_TYPE_COUNT && strcmp(tensor->name, "token_embd.weight") == 0) { new_type = params->token_embedding_type; diff --git a/src/llama.cpp b/src/llama.cpp index 6cefcc7912eb5..c1aa5380498c6 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -6407,7 +6407,7 @@ struct llm_build_context { // whether to use n_tokens as the matrix dimension during multiplication or n_head // n_tokens is higher during prompt processing, this allows to optimize for this case - bool pp_opt = n_tokens > n_head; + bool pp_opt = true; for (int il = 0; il < n_layer; ++il) { struct ggml_tensor * inpSA = inpL;