diff --git a/README.md b/README.md
index 540c29a4f1847..576332bc5d33d 100644
--- a/README.md
+++ b/README.md
@@ -130,6 +130,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
 Bindings

+- Python: [ddh0/easy-llama](https://github.com/ddh0/easy-llama)
 - Python: [abetlen/llama-cpp-python](https://github.com/abetlen/llama-cpp-python)
 - Go: [go-skynet/go-llama.cpp](https://github.com/go-skynet/go-llama.cpp)
 - Node.js: [withcatai/node-llama-cpp](https://github.com/withcatai/node-llama-cpp)
diff --git a/common/arg.cpp b/common/arg.cpp
index cfa9878f90730..0d0daa3610105 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -2869,6 +2869,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             "(default: deepseek)",
             [](common_params & params, const std::string & value) {
                 /**/ if (value == "deepseek") { params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; }
+                else if (value == "deepseek-legacy") { params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY; }
                 else if (value == "none") { params.reasoning_format = COMMON_REASONING_FORMAT_NONE; }
                 else { throw std::invalid_argument("invalid value"); }
             }
diff --git a/common/chat.cpp b/common/chat.cpp
index f1ab4c85a913e..1d6974a8c563b 100644
--- a/common/chat.cpp
+++ b/common/chat.cpp
@@ -82,10 +82,10 @@ json common_chat_msg::to_json_oaicompat() const

 std::vector<common_chat_msg_diff> common_chat_msg_diff::compute_diffs(const common_chat_msg & previous_msg, const common_chat_msg & new_msg) {
     std::vector<common_chat_msg_diff> diffs;
-    // if (previous_msg.reasoning_content != current.reasoning_content) {
-    //     auto & diff = diffs.emplace_back();
-    //     diff.reasoning_content_delta = string_diff(previous_msg.reasoning_content, current.reasoning_content);
-    // }
+    if (previous_msg.reasoning_content != new_msg.reasoning_content) {
+        auto & diff = diffs.emplace_back();
+        diff.reasoning_content_delta = string_diff(previous_msg.reasoning_content, new_msg.reasoning_content);
+    }
     if (previous_msg.content != new_msg.content) {
         auto & diff = diffs.emplace_back();
         diff.content_delta = string_diff(previous_msg.content, new_msg.content);
@@ -385,9 +385,9 @@ json common_chat_tools_to_json_oaicompat(const std::vector<common_chat_tool> & t
 template <> json common_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff) {
     json delta = json::object();
-    // if (!diff.reasoning_content_delta.empty()) {
-    //     delta["reasoning_content"] = msg.reasoning_content;
-    // }
+    if (!diff.reasoning_content_delta.empty()) {
+        delta["reasoning_content"] = diff.reasoning_content_delta;
+    }
     if (!diff.content_delta.empty()) {
         delta["content"] = diff.content_delta;
     }
@@ -598,6 +598,7 @@ const char * common_reasoning_format_name(common_reasoning_format format) {
     switch (format) {
         case COMMON_REASONING_FORMAT_NONE:     return "none";
         case COMMON_REASONING_FORMAT_DEEPSEEK: return "deepseek";
+        case COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY: return "deepseek-legacy";
         default: throw std::runtime_error("Unknown reasoning format");
     }
diff --git a/common/chat.h b/common/chat.h
index f6b1d0ffcc989..9f59e6b08738d 100644
--- a/common/chat.h
+++ b/common/chat.h
@@ -70,7 +70,7 @@ struct common_chat_msg {
 };

 struct common_chat_msg_diff {
-    // std::string reasoning_content_delta;
+    std::string reasoning_content_delta;
     std::string content_delta;
     size_t tool_call_index = std::string::npos;
     common_chat_tool_call tool_call_delta;
diff --git a/common/common.h b/common/common.h
index cee1e3039cf9e..f26724b6e1495 100644
--- a/common/common.h
+++ b/common/common.h
@@ -215,7 +215,8 @@ struct common_params_vocoder {

 enum common_reasoning_format {
     COMMON_REASONING_FORMAT_NONE,
-    COMMON_REASONING_FORMAT_DEEPSEEK, // Extract thinking tag contents and return as `message.reasoning_content`
+    COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY, // Extract thinking tag contents and return as `message.reasoning_content`, or leave inline in <think> tags in stream mode
+    COMMON_REASONING_FORMAT_DEEPSEEK,        // Extract thinking tag contents and return as `message.reasoning_content`, including in streaming deltas.
 };

 struct common_params {
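A note on the streaming change above: reasoning deltas now flow through the same `string_diff` path as content deltas. The sketch below is a minimal illustration of the assumed contract of `string_diff` (that the new message state extends the old one and the delta is the appended suffix); `string_diff_sketch` is a hypothetical stand-in, not the actual helper from `common/chat.cpp`.

```cpp
#include <cassert>
#include <string>

// Assumed behavior of string_diff: new text extends the previous text,
// and the delta is only the newly appended suffix.
static std::string string_diff_sketch(const std::string & prev, const std::string & cur) {
    assert(cur.compare(0, prev.size(), prev) == 0 && "new content must extend the old");
    return cur.substr(prev.size());
}

int main() {
    // With this change, each streamed chunk carries a `reasoning_content`
    // delta, i.e. only the text added since the previous message state.
    const std::string prev  = "Let me think";
    const std::string cur   = "Let me think about the units.";
    const std::string delta = string_diff_sketch(prev, cur); // " about the units."
    return delta.empty();
}
```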
diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index ab0f0e0ea087e..ec3b5697d8f6f 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -3814,7 +3814,7 @@ def _xlmroberta_set_vocab(self) -> None:
             remove_whitespaces = tokenizer.clean_up_tokenization_spaces
             precompiled_charsmap = b64decode(tokenizer_json["normalizer"]["precompiled_charsmap"])

-            vocab_size = self.hparams.get("vocab_size", tokenizer.vocab_size)
+            vocab_size = max(self.hparams.get("vocab_size", 0), tokenizer.vocab_size)
         else:
             sentencepiece_model = model.ModelProto()  # pyright: ignore[reportAttributeAccessIssue]
             sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read())
@@ -3827,7 +3827,7 @@ def _xlmroberta_set_vocab(self) -> None:
             tokenizer = SentencePieceProcessor()
             tokenizer.LoadFromFile(str(tokenizer_path))

-            vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size())
+            vocab_size = max(self.hparams.get("vocab_size", 0), tokenizer.vocab_size())

         tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)]
         scores: list[float] = [-10000.0] * vocab_size
@@ -3857,33 +3857,26 @@ def _xlmroberta_set_vocab(self) -> None:
             unk_token = tokenizer_config_json.get("unk_token")
             unk_token_id = added_vocab.get(unk_token, tokenizer_json["model"].get("unk_id", 3))

-            for token_id in range(vocab_size):
+            for token_id in range(tokenizer.vocab_size):
                 piece = tokenizer._convert_id_to_token(token_id)
-                text = piece.encode("utf-8")
-                score = tokenizer_json["model"]["vocab"][token_id][1]
-
-                toktype = SentencePieceTokenTypes.NORMAL
-                if token_id == unk_token_id:
-                    toktype = SentencePieceTokenTypes.UNKNOWN
-                elif token_id in tokenizer.all_special_ids:
-                    toktype = SentencePieceTokenTypes.CONTROL
-                elif token_id in added_vocab.values():
-                    toktype = SentencePieceTokenTypes.USER_DEFINED
-                # No reliable way to detect this, but jina doesn't have any
-                # elif tokenizer.IsByte(token_id):
-                #     toktype = SentencePieceTokenTypes.BYTE
-
-                tokens[token_id] = text
-                scores[token_id] = score
-                toktypes[token_id] = toktype
-
-            if vocab_size > len(tokens):
-                pad_count = vocab_size - len(tokens)
-                logger.debug(f"Padding vocab with {pad_count} token(s) - [PAD1] through [PAD{pad_count}]")
-                for i in range(1, pad_count + 1):
-                    tokens.append(bytes(f"[PAD{i}]", encoding="utf-8"))
-                    scores.append(-1000.0)
-                    toktypes.append(SentencePieceTokenTypes.UNUSED)
+                if (piece := tokenizer._convert_id_to_token(token_id)) is not None:
+                    text = piece.encode("utf-8")
+                    score = tokenizer_json["model"]["vocab"][token_id][1]
+
+                    toktype = SentencePieceTokenTypes.NORMAL
+                    if token_id == unk_token_id:
+                        toktype = SentencePieceTokenTypes.UNKNOWN
+                    elif token_id in tokenizer.all_special_ids:
+                        toktype = SentencePieceTokenTypes.CONTROL
+                    elif token_id in added_vocab.values():
+                        toktype = SentencePieceTokenTypes.USER_DEFINED
+                    # No reliable way to detect this, but jina doesn't have any
+                    # elif tokenizer.IsByte(token_id):
+                    #     toktype = SentencePieceTokenTypes.BYTE
+
+                    tokens[token_id] = text
+                    scores[token_id] = score
+                    toktypes[token_id] = toktype

         if isinstance(tokenizer, SentencePieceProcessor):
             # realign tokens (see HF tokenizer code)
@@ -3896,6 +3889,12 @@ def _xlmroberta_set_vocab(self) -> None:
                 SentencePieceTokenTypes.UNKNOWN,
             ] + toktypes[3:-1]

+        if self.model_arch == gguf.MODEL_ARCH.NOMIC_BERT_MOE:
+            # Add mask token missing from sentencepiece.bpe.model
+            tokens[250001] = b'<mask>'
+            scores[250001] = 0.0
+            toktypes[250001] = SentencePieceTokenTypes.CONTROL
+
         self.gguf_writer.add_tokenizer_model("t5")
         self.gguf_writer.add_tokenizer_pre("default")
         self.gguf_writer.add_token_list(tokens)
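The converter change sizes the token table by the larger of the declared and actual vocab sizes and pre-fills it with `[PAD{i}]` placeholders, so declared-but-missing slots (like the `<mask>` slot above) stay valid. A minimal C++ sketch of that idea, with purely illustrative sizes and a `tok_` stand-in for real pieces:

```cpp
#include <algorithm>
#include <cstdio>
#include <string>
#include <vector>

int main() {
    // Hypothetical numbers: the model header may declare a larger vocab
    // than the tokenizer file actually stores.
    const size_t hparams_vocab_size   = 250048;
    const size_t tokenizer_vocab_size = 250001;

    // Mirrors the change: size the table by the larger of the two.
    const size_t vocab_size = std::max(hparams_vocab_size, tokenizer_vocab_size);

    // Pre-fill every slot with a placeholder, then overwrite only the ids the
    // tokenizer actually knows; unknown slots keep their placeholder.
    std::vector<std::string> tokens(vocab_size);
    for (size_t i = 0; i < vocab_size; ++i) {
        tokens[i] = "[PAD" + std::to_string(i) + "]";
    }
    for (size_t id = 0; id < tokenizer_vocab_size; ++id) {
        tokens[id] = "tok_" + std::to_string(id); // stand-in for the real piece
    }

    printf("table size: %zu, last slot: %s\n", tokens.size(), tokens.back().c_str());
}
```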
diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp
index d7b269df0dea2..cd85bea9ac857 100644
--- a/examples/parallel/parallel.cpp
+++ b/examples/parallel/parallel.cpp
@@ -158,7 +158,7 @@ int main(int argc, char ** argv) {
     common_params params;

     params.n_predict = 128;
-    params.n_junk    = 0;
+    params.n_junk    = 1;

     if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_PARALLEL)) {
         return 1;
@@ -182,7 +182,7 @@ int main(int argc, char ** argv) {
     const bool is_pp_shared = params.is_pp_shared;

     // extra text to insert in each client's prompt in order to make it larger
-    const int32_t n_junk = params.n_junk;
+    const int32_t n_junk = std::max(1, params.n_junk);

     // init llama.cpp
     llama_backend_init();
diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index 2226aadcff893..1a57f1cd75a31 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -2095,9 +2095,6 @@ extern "C" {
     GGML_API struct ggml_tensor * ggml_graph_get_grad    (const struct ggml_cgraph * cgraph, const struct ggml_tensor * node);
     GGML_API struct ggml_tensor * ggml_graph_get_grad_acc(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node);

-    GGML_API void                 ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname);
-    GGML_API struct ggml_cgraph * ggml_graph_import(const char * fname, struct ggml_context ** ctx_data, struct ggml_context ** ctx_eval);
-
     // print info and performance information for the graph
     GGML_API void ggml_graph_print(const struct ggml_cgraph * cgraph);
diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt
index 83daa69eed5b1..edcea68f20fb9 100644
--- a/ggml/src/CMakeLists.txt
+++ b/ggml/src/CMakeLists.txt
@@ -196,6 +196,7 @@ add_library(ggml-base
     ../include/ggml-opt.h
     ../include/gguf.h
     ggml.c
+    ggml.cpp
     ggml-alloc.c
     ggml-backend.cpp
     ggml-opt.cpp
@@ -234,6 +235,7 @@ function(ggml_add_backend_library backend)
         set_target_properties(${backend} PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CMAKE_RUNTIME_OUTPUT_DIRECTORY})
         target_compile_definitions(${backend} PRIVATE GGML_BACKEND_DL)
         add_dependencies(ggml ${backend})
+        install(TARGETS ${backend} LIBRARY DESTINATION ${CMAKE_INSTALL_BINDIR})
     else()
         add_library(${backend} ${ARGN})
         target_link_libraries(ggml PUBLIC ${backend})
diff --git a/ggml/src/ggml-blas/CMakeLists.txt b/ggml/src/ggml-blas/CMakeLists.txt
index 41b104b892ad8..3cb43436163f6 100644
--- a/ggml/src/ggml-blas/CMakeLists.txt
+++ b/ggml/src/ggml-blas/CMakeLists.txt
@@ -96,5 +96,4 @@ else()
     message(ERROR "BLAS not found, please refer to "
                   "https://cmake.org/cmake/help/latest/module/FindBLAS.html#blas-lapack-vendors"
                   " to set correct GGML_BLAS_VENDOR")
-    endif()
 endif()
diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt
index f63051bd3db12..9998b3afb6875 100644
--- a/ggml/src/ggml-cpu/CMakeLists.txt
+++ b/ggml/src/ggml-cpu/CMakeLists.txt
@@ -345,7 +345,8 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
             execute_process(COMMAND bash -c "prtconf |grep 'Implementation' | head -n 1" OUTPUT_VARIABLE POWER10_M)
         endif()

-        string(REGEX MATCHALL "POWER *([0-9]+)" MATCHED_STRING "${POWER10_M}")
+        string(TOUPPER "${POWER10_M}" POWER10_M_UPPER)
+        string(REGEX MATCHALL "POWER *([0-9]+)" MATCHED_STRING "${POWER10_M_UPPER}")
         string(REGEX REPLACE "POWER *([0-9]+)" "\\1" EXTRACTED_NUMBER "${MATCHED_STRING}")

         if (EXTRACTED_NUMBER GREATER_EQUAL 10)
diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h
index 89b59d9aadc7e..6dc5ce0d92fd8 100644
--- a/ggml/src/ggml-impl.h
+++ b/ggml/src/ggml-impl.h
@@ -32,6 +32,8 @@ extern "C" {
 #endif

+void ggml_print_backtrace(void);
+
 #ifndef MIN
 #    define MIN(a, b) ((a) < (b) ? (a) : (b))
 #endif
diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
index f78e7eee553b6..bc93bc633a49b 100644
--- a/ggml/src/ggml-metal/ggml-metal.m
+++ b/ggml/src/ggml-metal/ggml-metal.m
@@ -4766,6 +4766,8 @@ static bool ggml_metal_encode_node(
                 GGML_ASSERT(nqptg % 8  == 0);
                 GGML_ASSERT(ncpsg % 32 == 0);

+                const int is_q = ggml_is_quantized(src1->type) ? 1 : 0;
+
                 // 2*(2*ncpsg + nqptg)*(nsg)
                 // ncpsg soft_max values + ncpsg mask values + a diagonal scaling matrix (in float)
                 //
                 // the shared memory needed for the simdgroups to load the KV cache
                 // each thread loads (dequantizes) 16 head elements, there are 32 threads in the SG
                 //
-#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*(2*ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16))
+#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(2*ne00 + 2*(2*ncpsg + nqptg)*(nsg)) + is_q*(16*32*(nsg)))*(sizeof(float)/2), 16))

                 int64_t nsgmax = 2;
@@ -4810,9 +4812,9 @@ static bool ggml_metal_encode_node(
                 // and store the soft_max values and the mask
                 //
                 // ne00*(nsg)
-                // each simdgroup has a full f16 head vector in shared mem to accumulate results
+                // each simdgroup has a full f32 head vector in shared mem to accumulate results
                 //
-#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + ne20*(nsg))*(sizeof(float)/2), 16))
+#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + 2*ne20*(nsg))*(sizeof(float)/2), 16))

                 int64_t nsgmax = 2;

                 while (true) {
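To make the shared-memory change concrete, here is a small worked computation of the new `FATTN_SMEM` formula. The values are illustrative only (head size 128, the kernel's usual 8 queries per threadgroup and 32 cache values per simdgroup); the `2*ne00` term reflects the accumulators now being kept in f32 (two halves per element), and the 16x32 dequantization scratch is only charged when the KV cache is quantized.

```cpp
#include <cstddef>
#include <cstdio>

// Round x up to a multiple of n, as GGML_PAD does.
static size_t pad(size_t x, size_t n) { return ((x + n - 1) / n) * n; }

int main() {
    // Illustrative values only.
    const size_t ne00 = 128, nqptg = 8, ncpsg = 32, nsg = 2;

    for (int is_q = 0; is_q <= 1; ++is_q) { // 0: f16/bf16 KV cache, 1: quantized
        const size_t smem = pad((nqptg*(2*ne00 + 2*(2*ncpsg + nqptg)*nsg)
                                 + is_q*(16*32*nsg)) * (sizeof(float)/2), 16);
        printf("is_q=%d -> %zu bytes of threadgroup memory\n", is_q, smem);
    }
}
```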
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index 59899550ed38c..58763e39e8353 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -3328,14 +3328,14 @@ kernel void kernel_flash_attn_ext(
     constexpr short NW = N_SIMDWIDTH;
     constexpr short SH = (2*C + Q); // shared memory per simdgroup (s_t == float)

-    const short TS = nsg*SH;    // shared memory size per query in (s_t == float)
-    const short T  = DK + 2*TS; // shared memory size per query in (half)
+    const short TS = nsg*SH;      // shared memory size per query in (s_t == float)
+    const short T  = 2*DK + 2*TS; // shared memory size per query in (half)

-    threadgroup q_t  * sq  = (threadgroup q_t  *) (shmem_f16 +              0*DK); // holds the query data
-    threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 +              0*DK); // same as above but in q4_t
-    threadgroup o_t  * so  = (threadgroup o_t  *) (shmem_f16 +              0*DK); // reuse query data for accumulation
-    threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 +              0*DK); // same as above but in o4_t
-    threadgroup s_t  * ss  = (threadgroup s_t  *) (shmem_f16 + 2*sgitg*SH +  Q*DK); // scratch buffer for attention, mask and diagonal matrix
+    threadgroup q_t  * sq  = (threadgroup q_t  *) (shmem_f16 +               0*DK); // holds the query data
+    threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 +               0*DK); // same as above but in q4_t
+    threadgroup o_t  * so  = (threadgroup o_t  *) (shmem_f16 +               0*DK); // reuse query data for accumulation
+    threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 +               0*DK); // same as above but in o4_t
+    threadgroup s_t  * ss  = (threadgroup s_t  *) (shmem_f16 + 2*sgitg*SH + 2*Q*DK); // scratch buffer for attention, mask and diagonal matrix

     threadgroup k_t    * sk    = (threadgroup k_t    *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory
     threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t
@@ -3354,7 +3354,7 @@ kernel void kernel_flash_attn_ext(
                 if (iq1 + j < args.ne01) {
                     sq4[j*DK4 + i] = (q4_t) q4[i];
                 } else {
-                    sq4[j*DK4 + i] = (q4_t) 0.0f;
+                    sq4[j*DK4 + i] = 0;
                 }
             }
         }
@@ -3634,9 +3634,6 @@ kernel void kernel_flash_attn_ext(

         // reduce the warps sequentially
         for (ushort sg = 1; sg < nsg; ++sg) {
-            float S = { 0.0f };
-            float M = { -__FLT_MAX__/2 };
-
             threadgroup_barrier(mem_flags::mem_threadgroup);

             // each simdgroup stores its output to shared memory, reusing sq
@@ -3657,12 +3654,12 @@ kernel void kernel_flash_attn_ext(
                     const float M0 = ss[j*TS + 1];
                     const float M1 = ss[j*TS + sg*SH + 1];

-                    M = max(M0, M1);
+                    const float M = max(M0, M1);

                     const float ms0 = exp(M0 - M);
                     const float ms1 = exp(M1 - M);

-                    S = S0*ms0 + S1*ms1;
+                    const float S = S0*ms0 + S1*ms1;

                     if (tiisg == 0) {
                         ss[j*TS + 0] = S;
@@ -3701,16 +3698,18 @@ kernel void kernel_flash_attn_ext(
             }
         }

-        device float4 * dst4 = (device float4 *) dst;
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        threadgroup s_t * sf = (threadgroup s_t *) (shmem_f16 + 2*Q*DK);

         // final rescale with 1/S and store to global memory
-        if (sgitg == 0) {
-            for (short j = 0; j < Q && iq1 + j < args.ne01; ++j) {
-                const float S = ss[j*TS + 0];
+        for (short j = sgitg; j < Q && iq1 + j < args.ne01; j += nsg) {
+            const float S = 1.0f/sf[j*TS + 0];

-                for (short i = tiisg; i < DV4; i += NW) {
-                    dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*DV4 + i] = (float4) so4[j*DV4 + i]/S;
-                }
+            device float4 * dst4 = (device float4 *) dst + ((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*DV4;
+
+            for (short i = tiisg; i < DV4; i += NW) {
+                dst4[i] = (float4) so4[j*DV4 + i]*S;
             }
         }
     }
@@ -3719,12 +3718,22 @@ kernel void kernel_flash_attn_ext(
 // template to be able to explore different combinations
 //
 #define FA_TYPES \
-    half,  half4,   simdgroup_half8x8,  \
-    half,  half4x4, simdgroup_half8x8,  \
-    half,  half4x4, simdgroup_half8x8,  \
-    float,          simdgroup_float8x8, \
-    float,          simdgroup_float8x8, \
-    half,  half4,   simdgroup_half8x8
+    float, float4,  simdgroup_float8x8, \
+    half,  half4x4, simdgroup_half8x8,  \
+    half,  half4x4, simdgroup_half8x8,  \
+    float,          simdgroup_float8x8, \
+    float,          simdgroup_float8x8, \
+    float, float4,  simdgroup_float8x8
+    //half, half4,   simdgroup_half8x8
+
+#define FA_TYPES_BF \
+    bfloat, bfloat4,   simdgroup_bfloat8x8, \
+    bfloat, bfloat4x4, simdgroup_bfloat8x8, \
+    bfloat, bfloat4x4, simdgroup_bfloat8x8, \
+    float,             simdgroup_float8x8,  \
+    float,             simdgroup_float8x8,  \
+    float,  float4,    simdgroup_float8x8
+    //half, half4,   simdgroup_half8x8

 typedef decltype(kernel_flash_attn_ext) flash_attn_ext_t;

@@ -3739,15 +3748,15 @@ template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_at
 template [[host_name("kernel_flash_attn_ext_f16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext;

 #if defined(GGML_METAL_USE_BF16)
-template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
-template [[host_name("kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
[[host_name("kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; #endif template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -3801,6 +3810,7 @@ template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q8_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; #undef FA_TYPES +#undef FA_TYPES_BF template< typename q4_t, // query types in shared memory @@ -3847,12 +3857,12 @@ kernel void kernel_flash_attn_ext_vec( const short T = DK + nsg*SH; // shared memory size per query in (half) - //threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data - threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t - threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*DK); // scratch buffer for attention - threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*DK); // same as above but in s4_t - threadgroup float * sm = (threadgroup float *) (shmem_f16 + sgitg*SH + 2*C + Q*DK); // scratch buffer for mask - threadgroup o4_t * sr4 = (threadgroup o4_t *) (shmem_f16 + sgitg*DV + Q*T); // scratch buffer for the results + //threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data + threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t + threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*DK); // scratch buffer for attention + threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*DK); // same as above but in s4_t + threadgroup float * sm = (threadgroup float *) (shmem_f16 + sgitg*SH + 2*C + Q*DK); // scratch buffer for mask + threadgroup o4_t * sr4 = (threadgroup o4_t *) (shmem_f16 + 2*sgitg*DV + Q*T); // scratch buffer for the results // store the result for all 
diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt
index 6b86d1f63e132..a1d1c43e7345e 100644
--- a/ggml/src/ggml-opencl/CMakeLists.txt
+++ b/ggml/src/ggml-opencl/CMakeLists.txt
@@ -96,6 +96,12 @@ set(GGML_OPENCL_KERNELS
     sub
     sum_rows
     transpose
+    concat
+    tsembd
+    upscale
+    tanh
+    pad
+    repeat
 )

 foreach (K ${GGML_OPENCL_KERNELS})
diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp
index b0c1c2ebf2032..e2fe3e978058f 100644
--- a/ggml/src/ggml-opencl/ggml-opencl.cpp
+++ b/ggml/src/ggml-opencl/ggml-opencl.cpp
@@ -315,6 +315,12 @@ struct ggml_backend_opencl_context {
     cl_program program_softmax_4_f16;
     cl_program program_argsort_f32_i32;
     cl_program program_sum_rows_f32;
+    cl_program program_repeat;
+    cl_program program_pad;
+    cl_program program_tanh;
+    cl_program program_upscale;
+    cl_program program_concat;
+    cl_program program_tsembd;

     cl_kernel kernel_add, kernel_add_row;
     cl_kernel kernel_mul, kernel_mul_row;
@@ -351,6 +357,15 @@ struct ggml_backend_opencl_context {
     cl_kernel kernel_im2col_f32, kernel_im2col_f16;
     cl_kernel kernel_argsort_f32_i32;
     cl_kernel kernel_sum_rows_f32;
+    cl_kernel kernel_repeat;
+    cl_kernel kernel_pad;
+    cl_kernel kernel_tanh_f32_nd;
+    cl_kernel kernel_tanh_f16_nd;
+    cl_kernel kernel_upscale;
+    cl_kernel kernel_upscale_bilinear;
+    cl_kernel kernel_concat_f32_contiguous;
+    cl_kernel kernel_concat_f32_non_contiguous;
+    cl_kernel kernel_timestep_embedding;

 #ifdef GGML_OPENCL_USE_ADRENO_KERNELS
     // Transpose kernels
@@ -1097,6 +1112,150 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
         GGML_LOG_CONT(".");
     }

+    // repeat
+    {
+#ifdef GGML_OPENCL_EMBED_KERNELS
+        const std::string kernel_src {
+            #include "repeat.cl.h"
+        };
+#else
+        const std::string kernel_src = read_file("repeat.cl");
+#endif
+        if (!kernel_src.empty()) {
+            backend_ctx->program_repeat =
+                build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
+            CL_CHECK((backend_ctx->kernel_repeat = clCreateKernel(backend_ctx->program_repeat, "kernel_repeat", &err), err));
+            GGML_LOG_CONT(".");
+        } else {
+            GGML_LOG_WARN("ggml_opencl: repeat kernel source not found or empty. Repeat operations will not be available.\n");
+            backend_ctx->program_repeat = nullptr;
+            backend_ctx->kernel_repeat = nullptr;
+        }
+    }
+
+    // pad
+    {
+#ifdef GGML_OPENCL_EMBED_KERNELS
+        const std::string kernel_src {
+            #include "pad.cl.h"
+        };
+#else
+        const std::string kernel_src = read_file("pad.cl");
+#endif
+        if (!kernel_src.empty()) {
+            backend_ctx->program_pad =
+                build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
+            CL_CHECK((backend_ctx->kernel_pad = clCreateKernel(backend_ctx->program_pad, "kernel_pad", &err), err));
+            GGML_LOG_CONT(".");
+        } else {
+            GGML_LOG_WARN("ggml_opencl: pad kernel source not found or empty. Pad operations will not be available.\n");
+            backend_ctx->program_pad = nullptr;
+            backend_ctx->kernel_pad = nullptr;
+        }
+    }
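Each loader block above follows the same pattern: read or embed the `.cl` source, build it, create the kernels, and on failure leave the handles null so `ggml_opencl_supports_op` simply never claims the op. A condensed, hypothetical helper showing the same graceful-fallback idea (the names `load_kernel_or_null` and the empty build options are illustrative, not the backend's actual API):

```cpp
#include <CL/cl.h>
#include <cstdio>
#include <fstream>
#include <sstream>
#include <string>

// Try to load and compile one kernel from a .cl file. On any failure return
// nullptr instead of aborting, so the op can be declined at supports_op time.
static cl_kernel load_kernel_or_null(cl_context ctx, cl_device_id dev,
                                     const char * path, const char * name) {
    std::ifstream f(path);
    if (!f) { fprintf(stderr, "kernel source %s not found\n", path); return nullptr; }
    std::stringstream ss; ss << f.rdbuf();
    const std::string src  = ss.str();
    const char *      csrc = src.c_str();
    cl_int err = CL_SUCCESS;
    cl_program prog = clCreateProgramWithSource(ctx, 1, &csrc, nullptr, &err);
    if (err != CL_SUCCESS) return nullptr;
    if (clBuildProgram(prog, 1, &dev, "", nullptr, nullptr) != CL_SUCCESS) return nullptr;
    cl_kernel k = clCreateKernel(prog, name, &err);
    return err == CL_SUCCESS ? k : nullptr;
}
```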
+
+    // tanh
+    {
+#ifdef GGML_OPENCL_EMBED_KERNELS
+        const std::string kernel_src {
+            #include "tanh.cl.h"
+        };
+#else
+        const std::string kernel_src = read_file("tanh.cl");
+#endif
+        if (!kernel_src.empty()) {
+            backend_ctx->program_tanh =
+                build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
+            CL_CHECK((backend_ctx->kernel_tanh_f32_nd = clCreateKernel(backend_ctx->program_tanh, "kernel_tanh_f32_nd", &err), err));
+            CL_CHECK((backend_ctx->kernel_tanh_f16_nd = clCreateKernel(backend_ctx->program_tanh, "kernel_tanh_f16_nd", &err), err));
+            GGML_LOG_CONT(".");
+        } else {
+            GGML_LOG_WARN("ggml_opencl: tanh kernel source not found or empty. Tanh operation will not be available.\n");
+            backend_ctx->program_tanh = nullptr;
+            backend_ctx->kernel_tanh_f32_nd = nullptr;
+            backend_ctx->kernel_tanh_f16_nd = nullptr;
+        }
+    }
+
+    // upscale
+    {
+#ifdef GGML_OPENCL_EMBED_KERNELS
+        const std::string kernel_src {
+            #include "upscale.cl.h"
+        };
+#else
+        const std::string kernel_src = read_file("upscale.cl");
+#endif
+        if (!kernel_src.empty()) {
+            backend_ctx->program_upscale =
+                build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
+            CL_CHECK((backend_ctx->kernel_upscale = clCreateKernel(backend_ctx->program_upscale, "kernel_upscale", &err), err));
+            if (backend_ctx->program_upscale) {
+                cl_int err_bilinear;
+                backend_ctx->kernel_upscale_bilinear = clCreateKernel(backend_ctx->program_upscale, "kernel_upscale_bilinear", &err_bilinear);
+                if (err_bilinear != CL_SUCCESS) {
+                    GGML_LOG_WARN("ggml_opencl: kernel_upscale_bilinear not found in upscale.cl. Bilinear upscale will not be available. Error: %d\n", err_bilinear);
+                    backend_ctx->kernel_upscale_bilinear = nullptr;
+                }
+            } else {
+                backend_ctx->kernel_upscale_bilinear = nullptr;
+            }
+            GGML_LOG_CONT(".");
+        } else {
+            GGML_LOG_WARN("ggml_opencl: upscale kernel source not found or empty. Upscale operations will not be available.\n");
+            backend_ctx->program_upscale = nullptr;
+            backend_ctx->kernel_upscale = nullptr;
+            backend_ctx->kernel_upscale_bilinear = nullptr;
+        }
+    }
+
+    // concat
+    {
+#ifdef GGML_OPENCL_EMBED_KERNELS
+        const std::string kernel_src {
+            #include "concat.cl.h"
+        };
+#else
+
+        const std::string kernel_src = read_file("concat.cl");
+#endif
+        if (!kernel_src.empty()) {
+            backend_ctx->program_concat =
+                build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
+
+            CL_CHECK((backend_ctx->kernel_concat_f32_contiguous = clCreateKernel(backend_ctx->program_concat, "kernel_concat_f32_contiguous", &err), err));
+            CL_CHECK((backend_ctx->kernel_concat_f32_non_contiguous = clCreateKernel(backend_ctx->program_concat, "kernel_concat_f32_non_contiguous", &err), err));
+            GGML_LOG_CONT(".");
+        } else {
+            GGML_LOG_WARN("ggml_opencl: concat kernel source not found or empty. Concat operations will not be available.\n");
+            backend_ctx->program_concat = nullptr;
+            backend_ctx->kernel_concat_f32_contiguous = nullptr;
+            backend_ctx->kernel_concat_f32_non_contiguous = nullptr;
+        }
+    }
+
+    // timestep_embedding
+    {
+#ifdef GGML_OPENCL_EMBED_KERNELS
+        const std::string kernel_src {
+            #include "tsembd.cl.h"
+        };
+#else
+
+        const std::string kernel_src = read_file("tsembd.cl");
+#endif
+        if (!kernel_src.empty()) {
+            backend_ctx->program_tsembd =
+                build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
+            CL_CHECK((backend_ctx->kernel_timestep_embedding = clCreateKernel(backend_ctx->program_tsembd, "kernel_timestep_embedding", &err), err));
+            GGML_LOG_CONT(".");
+        } else {
+            GGML_LOG_WARN("ggml_opencl: timestep_embedding kernel source not found or empty. This op will not be available.\n");
+            backend_ctx->program_tsembd = nullptr;
+            backend_ctx->kernel_timestep_embedding = nullptr;
+        }
+    }
+
     // Adreno kernels
 #ifdef GGML_OPENCL_USE_ADRENO_KERNELS
     // transpose
@@ -1863,7 +2022,12 @@ static bool ggml_backend_opencl_cpy_tensor_async(ggml_backend_t backend, const g
 }

 static void ggml_backend_opencl_synchronize(ggml_backend_t backend) {
-    GGML_UNUSED(backend);
+    auto * backend_ctx = static_cast<ggml_backend_opencl_context *>(backend->context);
+
+    cl_event evt;
+    CL_CHECK(clEnqueueBarrierWithWaitList(backend_ctx->queue, 0, nullptr, &evt));
+    CL_CHECK(clWaitForEvents(1, &evt));
+    CL_CHECK(clReleaseEvent(evt));
 }

 // Synchronizes the 'backend_ctx's device with others so that commands
@@ -1976,9 +2140,12 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
             case GGML_UNARY_OP_SILU:
             case GGML_UNARY_OP_RELU:
             case GGML_UNARY_OP_GELU_QUICK:
-                return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
+                return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
             case GGML_UNARY_OP_SIGMOID:
                 return ggml_is_contiguous(op->src[0]);
+            case GGML_UNARY_OP_TANH:
+                return (op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) ||
+                       (op->src[0]->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16);
             default:
                 return false;
         }
@@ -1988,6 +2155,17 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
         case GGML_OP_NORM:
         case GGML_OP_RMS_NORM:
             return true;
+        case GGML_OP_REPEAT:
+            return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; // Assuming F32 for now, can be expanded
+        case GGML_OP_PAD:
+            return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32 &&
+                   op->src[0]->ne[3] == 1 && op->ne[3] == 1;
+        case GGML_OP_UPSCALE:
+            return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
+        case GGML_OP_CONCAT:
+            return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
+        case GGML_OP_TIMESTEP_EMBEDDING:
+            return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
         case GGML_OP_GROUP_NORM:
             return ggml_is_contiguous(op->src[0]);
         case GGML_OP_MUL_MAT:
@@ -2052,7 +2230,7 @@ static ggml_backend_i ggml_backend_opencl_i = {
    /* .set_tensor_async        = */ NULL,  /* ggml_backend_opencl_set_tensor_async */
    /* .get_tensor_async        = */ NULL,  /* ggml_backend_opencl_get_tensor_async */
    /* .cpy_tensor_async        = */ NULL,  /* ggml_backend_opencl_cpy_tensor_async */
-   /* .synchronize             = */ NULL,  /* ggml_backend_opencl_synchronize */
+   /* .synchronize             = */ ggml_backend_opencl_synchronize,
    /* .graph_plan_create       = */ NULL,
    /* .graph_plan_free         = */ NULL,
    /* .graph_plan_update       = */ NULL,
@@ -4108,6 +4286,536 @@ static void ggml_cl_group_norm(ggml_backend_t backend, const ggml_tensor * src0,
 #endif
 }

+static void ggml_cl_tanh(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0);
+    GGML_ASSERT(src0->extra);
+    GGML_ASSERT(dst);
+    GGML_ASSERT(dst->extra);
+
+    UNUSED(src1);
+
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+    cl_command_queue queue = backend_ctx->queue;
+
+    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
+    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+
+    cl_ulong offset0_abs = extra0->offset + src0->view_offs;
+    cl_ulong offsetd_abs = extrad->offset + dst->view_offs;
+
+    cl_kernel kernel;
+    if (dst->type == GGML_TYPE_F32) {
+        kernel = backend_ctx->kernel_tanh_f32_nd;
+    } else if (dst->type == GGML_TYPE_F16) {
+        kernel = backend_ctx->kernel_tanh_f16_nd;
+    } else {
+        GGML_ASSERT(false && "Unsupported type for ggml_cl_tanh");
+    }
+    GGML_ASSERT(kernel != nullptr);
+
+    const int ne00 = src0->ne[0]; const int ne01 = src0->ne[1]; const int ne02 = src0->ne[2]; const int ne03 = src0->ne[3];
+    const cl_ulong nb00 = src0->nb[0]; const cl_ulong nb01 = src0->nb[1]; const cl_ulong nb02 = src0->nb[2]; const cl_ulong nb03 = src0->nb[3];
+
+    const int ne10 = dst->ne[0]; const int ne11 = dst->ne[1]; const int ne12 = dst->ne[2]; const int ne13 = dst->ne[3];
+    const cl_ulong nb10 = dst->nb[0]; const cl_ulong nb11 = dst->nb[1]; const cl_ulong nb12 = dst->nb[2]; const cl_ulong nb13 = dst->nb[3];
+
+    CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
+    CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0_abs));
+    CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
+    CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd_abs));
+
+    CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00));
+    CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01));
+    CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne02));
+    CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne03));
+    CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb00));
+    CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb01));
+    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong),&nb02));
+    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong),&nb03));
+
+    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne10));
+    CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne11));
+    CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne12));
+    CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne13));
+    CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong),&nb10));
+    CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong),&nb11));
+    CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong),&nb12));
+    CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong),&nb13));
+
+    size_t global_work_size[3];
+    if (ne10 == 0 || ne11 == 0 || ne12 == 0 || ne13 == 0) { // Handle case of 0 elements
+        return;
+    }
+    global_work_size[0] = (size_t)ne10;
+    global_work_size[1] = (size_t)ne11;
+    global_work_size[2] = (size_t)ne12;
+
+    size_t lws0 = 16, lws1 = 4, lws2 = 1;
+    if (ne10 < 16) lws0 = ne10;
+    if (ne11 < 4) lws1 = ne11;
+    if (ne12 < 1) lws2 = ne12 > 0 ? ne12 : 1;
+
+    while (lws0 * lws1 * lws2 > 256 && lws0 > 1) lws0 /= 2;
+    while (lws0 * lws1 * lws2 > 256 && lws1 > 1) lws1 /= 2;
+    while (lws0 * lws1 * lws2 > 256 && lws2 > 1) lws2 /= 2;
+
+    size_t local_work_size[] = {lws0, lws1, lws2};
+
+    size_t* local_work_size_ptr = local_work_size;
+    if (!backend_ctx->non_uniform_workgroups) {
+        if (global_work_size[0] % local_work_size[0] != 0 ||
+            global_work_size[1] % local_work_size[1] != 0 ||
+            global_work_size[2] % local_work_size[2] != 0) {
+            local_work_size_ptr = NULL;
+        }
+    }
+    if (global_work_size[0] == 0 || global_work_size[1] == 0 || global_work_size[2] == 0) return;
+
+#ifdef GGML_OPENCL_PROFILING
+    cl_event evt;
+    CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, &evt));
+
+    g_profiling_info.emplace_back();
+    populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size_ptr ? local_work_size : (size_t[3]){0,0,0}, dst);
+#else
+    CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, NULL));
+#endif
+}
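The local-work-size selection in `ggml_cl_tanh` starts from a preferred 16x4x1 shape, clamps each axis to the problem extents, then halves axes until the workgroup fits an assumed 256 work-item budget. A standalone sketch of that loop:

```cpp
#include <cstdio>

int main() {
    size_t ne[3]  = {10, 3, 7};  // hypothetical tensor extents
    size_t lws[3] = {16, 4, 1};  // preferred workgroup shape

    // Clamp each axis to the problem size (never below 1).
    for (int i = 0; i < 3; ++i) {
        if (ne[i] < lws[i]) lws[i] = ne[i] ? ne[i] : 1;
    }
    // Halve axes in order until the total fits the assumed 256-item cap.
    for (int i = 0; i < 3; ++i) {
        while (lws[0]*lws[1]*lws[2] > 256 && lws[i] > 1) lws[i] /= 2;
    }
    printf("local work size: %zux%zux%zu\n", lws[0], lws[1], lws[2]);
}
```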
+
+static void ggml_cl_repeat(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1_shape_def, ggml_tensor * dst) {
+    GGML_ASSERT(src0);
+    GGML_ASSERT(src0->extra);
+    GGML_ASSERT(dst);
+    GGML_ASSERT(dst->extra);
+    GGML_ASSERT(dst->type == src0->type);
+
+    UNUSED(src1_shape_def);
+
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+    cl_command_queue queue = backend_ctx->queue;
+
+    if (backend_ctx->kernel_repeat == nullptr) {
+        GGML_LOG_WARN("%s: repeat kernel not available, skipping OpenCL execution.\n", __func__);
+        return;
+    }
+
+    ggml_tensor_extra_cl * extra_src0 = (ggml_tensor_extra_cl *)src0->extra;
+    ggml_tensor_extra_cl * extra_dst  = (ggml_tensor_extra_cl *)dst->extra;
+
+    cl_ulong off_src0 = extra_src0->offset + src0->view_offs;
+    cl_ulong off_dst  = extra_dst->offset + dst->view_offs;
+
+    const int src0_ne0 = src0->ne[0]; const int src0_ne1 = src0->ne[1]; const int src0_ne2 = src0->ne[2]; const int src0_ne3 = src0->ne[3];
+    const cl_ulong src0_nb0 = src0->nb[0]; const cl_ulong src0_nb1 = src0->nb[1]; const cl_ulong src0_nb2 = src0->nb[2]; const cl_ulong src0_nb3 = src0->nb[3];
+
+    const int dst_ne0 = dst->ne[0]; const int dst_ne1 = dst->ne[1]; const int dst_ne2 = dst->ne[2]; const int dst_ne3 = dst->ne[3];
+    const cl_ulong dst_nb0 = dst->nb[0]; const cl_ulong dst_nb1 = dst->nb[1]; const cl_ulong dst_nb2 = dst->nb[2]; const cl_ulong dst_nb3 = dst->nb[3];
+
+    cl_kernel kernel = backend_ctx->kernel_repeat;
+
+    CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra_src0->data_device));
+    CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra_dst->data_device));
+    CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_ulong), &off_src0));
+    CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &off_dst));
+    CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &src0_ne0));
+    CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &src0_ne1));
+    CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &src0_ne2));
+    CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &src0_ne3));
+    CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &src0_nb0));
+    CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &src0_nb1));
+    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &src0_nb2));
+    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &src0_nb3));
+    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &dst_ne0));
+    CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &dst_ne1));
+    CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &dst_ne2));
+    CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &dst_ne3));
+    CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &dst_nb0));
+    CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &dst_nb1));
+    CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &dst_nb2));
+    CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &dst_nb3));
+
+    size_t gws0 = dst_ne1 > 0 ? (size_t)dst_ne1 : 1;
+    size_t gws1 = dst_ne2 > 0 ? (size_t)dst_ne2 : 1;
+    size_t gws2 = dst_ne3 > 0 ? (size_t)dst_ne3 : 1;
+
+    size_t global_work_size[] = { gws0, gws1, gws2 };
+
+#ifdef GGML_OPENCL_PROFILING
+    cl_event evt;
+    CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, NULL, 0, NULL, &evt));
+
+    g_profiling_info.emplace_back();
+    populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, (size_t[3]){0,0,0}, dst);
+#else
+    CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, NULL, 0, NULL, NULL));
+#endif
+}
+
+static void ggml_cl_pad(ggml_backend_t backend, const ggml_tensor * src0, ggml_tensor * dst) {
+    GGML_ASSERT(src0);
+    GGML_ASSERT(src0->extra);
+    GGML_ASSERT(dst);
+    GGML_ASSERT(dst->extra);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1);
+
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+    cl_command_queue queue = backend_ctx->queue;
+
+    if (backend_ctx->kernel_pad == nullptr) {
+        GGML_LOG_WARN("%s: pad kernel not available, skipping OpenCL execution.\n", __func__);
+        return;
+    }
+
+    ggml_tensor_extra_cl * extra_src0 = (ggml_tensor_extra_cl *)src0->extra;
+    ggml_tensor_extra_cl * extra_dst  = (ggml_tensor_extra_cl *)dst->extra;
+
+    cl_ulong off_src0 = extra_src0->offset + src0->view_offs;
+    cl_ulong off_dst  = extra_dst->offset + dst->view_offs;
+
+    const int s_ne0 = src0->ne[0];
+    const int s_ne1 = src0->ne[1];
+    const int s_ne2 = src0->ne[2];
+
+    const int d_ne0 = dst->ne[0];
+    const int d_ne1 = dst->ne[1];
+    const int d_ne2 = dst->ne[2];
+
+    cl_kernel kernel = backend_ctx->kernel_pad;
+
+    CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra_src0->data_device));
+    CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &off_src0));
+    CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra_dst->data_device));
+    CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &off_dst));
+    CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &s_ne0));
+    CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &s_ne1));
+    CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &s_ne2));
+    CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &d_ne0));
+    CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &d_ne1));
+    CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &d_ne2));
+
+    size_t lws0 = 64;
+    size_t gws0 = (( (size_t)d_ne0 + lws0 - 1 ) / lws0) * lws0;
+
+    size_t global_work_size[] = { gws0, (size_t)d_ne1, (size_t)d_ne2 };
+    size_t local_work_size[]  = { lws0, 1, 1 };
+
+    size_t * local_work_size_ptr = local_work_size;
+    if (d_ne0 % lws0 != 0 && !backend_ctx->non_uniform_workgroups) {
+        local_work_size_ptr = nullptr;
+    }
+
+#ifdef GGML_OPENCL_PROFILING
+    cl_event evt;
+    CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, &evt));
+
+    g_profiling_info.emplace_back();
+    populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size_ptr ? local_work_size : (size_t[3]){0,0,0}, dst);
+#else
+    CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, NULL));
+#endif
+}
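`ggml_cl_pad` launches `ceil(d_ne0 / 64) * 64` work items on axis 0 and lets the kernel bounds-check the overhang; this is the standard round-up-to-multiple idiom:

```cpp
#include <cstdio>

int main() {
    const size_t d_ne0 = 1000, lws0 = 64;
    // Round the global size up to a multiple of the local size.
    const size_t gws0 = ((d_ne0 + lws0 - 1) / lws0) * lws0; // 1024
    printf("%zu work items for %zu elements (last %zu masked off in-kernel)\n",
           gws0, d_ne0, gws0 - d_ne0);
}
```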
+
+static void ggml_cl_upscale(ggml_backend_t backend, const ggml_tensor * src0, ggml_tensor * dst) {
+    GGML_ASSERT(src0);
+    GGML_ASSERT(src0->extra);
+    GGML_ASSERT(dst);
+    GGML_ASSERT(dst->extra);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+    cl_command_queue queue = backend_ctx->queue;
+
+    const ggml_scale_mode mode = (ggml_scale_mode) ggml_get_op_params_i32(dst, 0);
+    cl_kernel kernel = nullptr;
+
+    if (mode == GGML_SCALE_MODE_NEAREST) {
+        kernel = backend_ctx->kernel_upscale;
+        if (kernel == nullptr) {
+            GGML_LOG_WARN("%s: nearest upscale kernel not available, skipping OpenCL execution.\n", __func__);
+            return;
+        }
+    } else if (mode == GGML_SCALE_MODE_BILINEAR) {
+        kernel = backend_ctx->kernel_upscale_bilinear;
+        if (kernel == nullptr) {
+            GGML_LOG_WARN("%s: bilinear upscale kernel not available, skipping OpenCL execution.\n", __func__);
+            return;
+        }
+    } else {
+        GGML_LOG_WARN("%s: unsupported upscale mode %d, skipping OpenCL execution.\n", __func__, mode);
+        return;
+    }
+
+    ggml_tensor_extra_cl * extra_src0 = (ggml_tensor_extra_cl *)src0->extra;
+    ggml_tensor_extra_cl * extra_dst  = (ggml_tensor_extra_cl *)dst->extra;
+
+    cl_ulong off_src0 = extra_src0->offset + src0->view_offs;
+    cl_ulong off_dst  = extra_dst->offset + dst->view_offs;
+
+    const cl_ulong nb00 = src0->nb[0];
+    const cl_ulong nb01 = src0->nb[1];
+    const cl_ulong nb02 = src0->nb[2];
+    const cl_ulong nb03 = src0->nb[3];
+
+    const int ne00_src = src0->ne[0];
+    const int ne01_src = src0->ne[1];
+
+    const int ne10_dst = dst->ne[0];
+    const int ne11_dst = dst->ne[1];
+    const int ne12_dst = dst->ne[2];
+    const int ne13_dst = dst->ne[3];
+
+    const float sf0 = (float)dst->ne[0] / src0->ne[0];
+    const float sf1 = (float)dst->ne[1] / src0->ne[1];
+    const float sf2 = (float)dst->ne[2] / src0->ne[2];
+    const float sf3 = (float)dst->ne[3] / src0->ne[3];
+
+    CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra_src0->data_device));
+    CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &off_src0));
+    CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra_dst->data_device));
+    CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &off_dst));
+    CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &nb00));
+    CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &nb01));
+    CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &nb02));
+    CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb03));
+
+    if (mode == GGML_SCALE_MODE_NEAREST) {
+        CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne10_dst));
+        CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne11_dst));
+        CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12_dst));
+        CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne13_dst));
+        CL_CHECK(clSetKernelArg(kernel, 12, sizeof(float), &sf0));
+        CL_CHECK(clSetKernelArg(kernel, 13, sizeof(float), &sf1));
+        CL_CHECK(clSetKernelArg(kernel, 14, sizeof(float), &sf2));
+        CL_CHECK(clSetKernelArg(kernel, 15, sizeof(float), &sf3));
+    } else if (mode == GGML_SCALE_MODE_BILINEAR) {
+        CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00_src));
+        CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne01_src));
+        CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne10_dst));
+        CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne11_dst));
+        CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne12_dst));
+        CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne13_dst));
+        CL_CHECK(clSetKernelArg(kernel, 14, sizeof(float), &sf0));
+        CL_CHECK(clSetKernelArg(kernel, 15, sizeof(float), &sf1));
+        CL_CHECK(clSetKernelArg(kernel, 16, sizeof(float), &sf2));
+        CL_CHECK(clSetKernelArg(kernel, 17, sizeof(float), &sf3));
+    }
+
+    size_t dst_total_elements = (size_t)ne10_dst * ne11_dst * ne12_dst * ne13_dst;
+    if (dst_total_elements == 0) {
+        return;
+    }
+    size_t global_work_size[]   = { dst_total_elements, 1, 1 };
+    size_t local_work_size_pref = 256;
+    size_t local_work_size[]    = { MIN(local_work_size_pref, dst_total_elements), 1, 1};
+
+    size_t * local_work_size_ptr = local_work_size;
+    if (dst_total_elements % local_work_size[0] != 0 && !backend_ctx->non_uniform_workgroups) {
+        local_work_size_ptr = nullptr;
+    }
+
+#ifdef GGML_OPENCL_PROFILING
+    cl_event evt;
+    CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 1, NULL, global_work_size, local_work_size_ptr, 0, NULL, &evt));
+
+    g_profiling_info.emplace_back();
+    size_t profiling_gws[3] = {global_work_size[0], 1, 1};
+    size_t profiling_lws[3] = {local_work_size_ptr ? local_work_size[0] : 0, 1, 1};
+    populateProfilingInfo(g_profiling_info.back(), evt, kernel, profiling_gws, profiling_lws, dst);
+#else
+    CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 1, NULL, global_work_size, local_work_size_ptr, 0, NULL, NULL));
+#endif
+}
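The host passes per-axis scale factors `sf = dst_extent / src_extent` to the kernels. The `.cl` sources are not shown here, so the mapping below is the assumed convention (it matches how ggml's nearest upscale behaves on other backends): nearest mode reads `src[(int)(i / sf)]`, while bilinear blends the surrounding source texels.

```cpp
#include <cstdio>

int main() {
    const int src_w = 4, dst_w = 10;
    const float sf0 = (float) dst_w / src_w; // matches the sf0 computed above
    for (int i = 0; i < dst_w; ++i) {
        const int i_src = (int) (i / sf0);   // assumed nearest-neighbor mapping
        printf("dst[%d] <- src[%d]\n", i, i_src);
    }
}
```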
+
+static void ggml_cl_concat(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0);
+    GGML_ASSERT(src0->extra);
+    GGML_ASSERT(src1);
+    GGML_ASSERT(src1->extra);
+    GGML_ASSERT(dst);
+    GGML_ASSERT(dst->extra);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+    cl_command_queue queue = backend_ctx->queue;
+
+    if (backend_ctx->kernel_concat_f32_contiguous == nullptr || backend_ctx->kernel_concat_f32_non_contiguous == nullptr) {
+        GGML_LOG_WARN("%s: concat kernels not available, skipping OpenCL execution.\n", __func__);
+        return;
+    }
+
+    ggml_tensor_extra_cl * extra0_cl = (ggml_tensor_extra_cl *)src0->extra;
+    ggml_tensor_extra_cl * extra1_cl = (ggml_tensor_extra_cl *)src1->extra;
+    ggml_tensor_extra_cl * extrad_cl = (ggml_tensor_extra_cl *)dst->extra;
+
+    cl_ulong off_src0 = extra0_cl->offset + src0->view_offs;
+    cl_ulong off_src1 = extra1_cl->offset + src1->view_offs;
+    cl_ulong off_dst  = extrad_cl->offset + dst->view_offs;
+
+    const int32_t dim = ((const int32_t *) dst->op_params)[0];
+    GGML_ASSERT(dim >= 0 && dim <= 3);
+
+    if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) {
+        if (dim == 3) {
+
+            size_t nbytes_src0 = ggml_nbytes(src0);
+            size_t nbytes_src1 = ggml_nbytes(src1);
+
+            CL_CHECK(clEnqueueCopyBuffer(queue, extra0_cl->data_device, extrad_cl->data_device,
+                                         off_src0, off_dst, nbytes_src0, 0, NULL, NULL));
+            CL_CHECK(clEnqueueCopyBuffer(queue, extra1_cl->data_device, extrad_cl->data_device,
+                                         off_src1, off_dst + nbytes_src0, nbytes_src1, 0, NULL, NULL));
+        } else {
+
+            cl_kernel kernel = backend_ctx->kernel_concat_f32_contiguous;
+            size_t global_work_size[3];
+
+            for (int i3 = 0; i3 < dst->ne[3]; ++i3) {
+                cl_ulong current_off_src0 = off_src0 + (i3 * src0->nb[3]);
+                cl_ulong current_off_src1 = off_src1 + (i3 * src1->nb[3]);
+                cl_ulong current_off_dst  = off_dst  + (i3 * dst->nb[3]);
+
+                int d_ne00 = src0->ne[0]; int d_ne01 = src0->ne[1]; int d_ne02 = src0->ne[2];
+                int d_ne10 = src1->ne[0]; int d_ne11 = src1->ne[1]; int d_ne12 = src1->ne[2];
+                int d_ne0  = dst->ne[0];  int d_ne1  = dst->ne[1];  int d_ne2  = dst->ne[2];
+
+                CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_cl->data_device));
+                CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &current_off_src0));
+                CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1_cl->data_device));
+                CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &current_off_src1));
+                CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad_cl->data_device));
+                CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &current_off_dst));
+                CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &d_ne00));
+                CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &d_ne01));
+                CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &d_ne02));
+                CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &d_ne10));
+                CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &d_ne11));
+                CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &d_ne12));
+                CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &d_ne0));
+                CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &d_ne1));
+                CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &d_ne2));
+                CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &dim));
+
+                global_work_size[0] = d_ne0;
+                global_work_size[1] = d_ne1;
+                global_work_size[2] = d_ne2;
+
+                CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, NULL, 0, NULL, NULL));
+            }
+        }
+    } else {
+        cl_kernel kernel = backend_ctx->kernel_concat_f32_non_contiguous;
+
+        long ne00 = src0->ne[0], ne01 = src0->ne[1], ne02 = src0->ne[2], ne03 = src0->ne[3];
+        cl_ulong nb00 = src0->nb[0], nb01 = src0->nb[1], nb02 = src0->nb[2], nb03 = src0->nb[3];
+
+        cl_ulong nb10 = src1->nb[0], nb11 = src1->nb[1], nb12 = src1->nb[2], nb13 = src1->nb[3];
+
+        long d_ne0 = dst->ne[0], d_ne1 = dst->ne[1], d_ne2 = dst->ne[2], d_ne3 = dst->ne[3];
+        cl_ulong d_nb0 = dst->nb[0], d_nb1 = dst->nb[1], d_nb2 = dst->nb[2], d_nb3 = dst->nb[3];
+
+        CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_cl->data_device));
+        CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &off_src0));
+        CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1_cl->data_device));
+        CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &off_src1));
+        CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad_cl->data_device));
+        CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &off_dst));
+
+        CL_CHECK(clSetKernelArg(kernel, 6, sizeof(long), &ne00));
+        CL_CHECK(clSetKernelArg(kernel, 7, sizeof(long), &ne01));
+        CL_CHECK(clSetKernelArg(kernel, 8, sizeof(long), &ne02));
+        CL_CHECK(clSetKernelArg(kernel, 9, sizeof(long), &ne03));
+        CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb00));
+        CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb01));
+        CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb02));
+        CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb03));
+
+        CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb10));
+        CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb11));
+        CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb12));
+        CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb13));
+
+        CL_CHECK(clSetKernelArg(kernel, 18, sizeof(long), &d_ne0));
+        CL_CHECK(clSetKernelArg(kernel, 19, sizeof(long), &d_ne1));
+        CL_CHECK(clSetKernelArg(kernel, 20, sizeof(long), &d_ne2));
+        CL_CHECK(clSetKernelArg(kernel, 21, sizeof(long), &d_ne3));
+        CL_CHECK(clSetKernelArg(kernel, 22, sizeof(cl_ulong), &d_nb0));
+        CL_CHECK(clSetKernelArg(kernel, 23, sizeof(cl_ulong), &d_nb1));
+        CL_CHECK(clSetKernelArg(kernel, 24, sizeof(cl_ulong), &d_nb2));
+        CL_CHECK(clSetKernelArg(kernel, 25, sizeof(cl_ulong), &d_nb3));
+        CL_CHECK(clSetKernelArg(kernel, 26, sizeof(int), &dim));
+
+        size_t global_work_size_nc[] = { d_ne1 > 0 ? (size_t)d_ne1 : 1,
+                                         d_ne2 > 0 ? (size_t)d_ne2 : 1,
+                                         d_ne3 > 0 ? (size_t)d_ne3 : 1 };
+
+        CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size_nc, NULL, 0, NULL, NULL));
+    }
+}
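Both concat kernels implement the same selection rule (visible in `concat.cl` further down): along the concatenated dimension `dim`, destination indices below src0's extent read src0, and the rest read src1 at `index - extent`; all other dimensions pass through unchanged. In miniature:

```cpp
#include <cstdio>

int main() {
    // Concat two tensors of 3 and 2 rows along dim 1 into a 5-row result.
    const int ne01 = 3, d_ne1 = 5;
    for (int i1 = 0; i1 < d_ne1; ++i1) {
        if (i1 < ne01) printf("dst row %d <- src0 row %d\n", i1, i1);
        else           printf("dst row %d <- src1 row %d\n", i1, i1 - ne01);
    }
}
```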
+
+static void ggml_cl_timestep_embedding(ggml_backend_t backend, const ggml_tensor * src0, ggml_tensor * dst) {
+    GGML_ASSERT(src0);
+    GGML_ASSERT(src0->extra);
+    GGML_ASSERT(dst);
+    GGML_ASSERT(dst->extra);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+    cl_command_queue queue = backend_ctx->queue;
+
+    if (backend_ctx->kernel_timestep_embedding == nullptr) {
+        GGML_LOG_WARN("%s: timestep_embedding kernel not available, skipping OpenCL execution.\n", __func__);
+        return;
+    }
+
+    ggml_tensor_extra_cl * extra_src0 = (ggml_tensor_extra_cl *)src0->extra;
+    ggml_tensor_extra_cl * extra_dst  = (ggml_tensor_extra_cl *)dst->extra;
+
+    cl_ulong off_src0 = extra_src0->offset + src0->view_offs;
+    cl_ulong off_dst  = extra_dst->offset + dst->view_offs;
+
+    const int logical_dim   = dst->op_params[0];
+    const int max_period    = dst->op_params[1];
+    const int dst_nb1_bytes = dst->nb[1];
+
+    cl_kernel kernel = backend_ctx->kernel_timestep_embedding;
+
+    CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra_src0->data_device));
+    CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &off_src0));
+    CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra_dst->data_device));
+    CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &off_dst));
+    CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &dst_nb1_bytes));
+    CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &logical_dim));
+    CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &max_period));
+
+    size_t gws0 = (size_t)(((logical_dim + 1) / 2) + 1);
+
+    size_t gws1 = (size_t)src0->ne[0];
+
+    size_t global_work_size[] = {gws0, gws1, 1};
+
+#ifdef GGML_OPENCL_PROFILING
+    cl_event evt;
+    CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 2, NULL, global_work_size, NULL, 0, NULL, &evt)); // Pass 2 for 2D problem
+
+    g_profiling_info.emplace_back();
+    size_t profiling_gws[3] = {global_work_size[0], global_work_size[1], 1};
+    size_t profiling_lws[3] = {0,0,0}; // Reflects NULL LWS
+    populateProfilingInfo(g_profiling_info.back(), evt, kernel, profiling_gws, profiling_lws, dst);
+#else
+    CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 2, NULL, global_work_size, NULL, 0, NULL, NULL)); // Pass 2 for 2D problem
+#endif
+}
+
 static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(src0);
     GGML_ASSERT(src0->extra);
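For reference, the math behind `GGML_OP_TIMESTEP_EMBEDDING`: the `tsembd.cl` source is not shown in this patch, so treat the sketch below as the assumed contract, mirroring how other ggml backends compute it. For timestep `t` and dimension `dim`, half the outputs are cosines and half are sines of `t` scaled by log-spaced frequencies derived from `max_period`:

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    const int   dim        = 8;       // logical_dim from the op params
    const int   max_period = 10000;
    const float t          = 42.0f;   // one timestep from src0

    const int half = dim / 2;
    std::vector<float> emb(dim);
    for (int j = 0; j < half; ++j) {
        // Log-spaced frequencies from 1 down to 1/max_period.
        const float freq = expf(-logf((float) max_period) * j / half);
        emb[j]        = cosf(t * freq);
        emb[j + half] = sinf(t * freq);
    }
    for (int j = 0; j < dim; ++j) printf("%s%f", j ? " " : "", emb[j]);
    printf("\n");
}
```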
@@ -5667,6 +6375,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
                 }
                 func = ggml_cl_sigmoid;
                 break;
+            case GGML_UNARY_OP_TANH:
+                if (!any_on_device) {
+                    return false;
+                }
+                func = ggml_cl_tanh;
+                break;
             default:
                 return false;
         } break;
@@ -5694,6 +6408,36 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
             }
             func = ggml_cl_group_norm;
             break;
+        case GGML_OP_REPEAT:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cl_repeat;
+            break;
+        case GGML_OP_PAD:
+            if (!any_on_device) {
+                return false;
+            }
+            ggml_cl_pad(backend, tensor->src[0], tensor);
+            return true;
+        case GGML_OP_UPSCALE:
+            if (!any_on_device) {
+                return false;
+            }
+            ggml_cl_upscale(backend, tensor->src[0], tensor);
+            return true;
+        case GGML_OP_CONCAT:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cl_concat;
+            break;
+        case GGML_OP_TIMESTEP_EMBEDDING:
+            if (!any_on_device) {
+                return false;
+            }
+            ggml_cl_timestep_embedding(backend, tensor->src[0], tensor);
+            return true;
         case GGML_OP_MUL_MAT:
             if (!any_on_device && !ggml_cl_can_mul_mat(tensor->src[0], tensor->src[1], tensor)) {
                 return false;
diff --git a/ggml/src/ggml-opencl/kernels/concat.cl b/ggml/src/ggml-opencl/kernels/concat.cl
new file mode 100644
index 0000000000000..132758469c6fa
--- /dev/null
+++ b/ggml/src/ggml-opencl/kernels/concat.cl
@@ -0,0 +1,109 @@
+kernel void kernel_concat_f32_contiguous(
+    global const char * p_src0, ulong off_src0,
+    global const char * p_src1, ulong off_src1,
+    global char * p_dst, ulong off_dst,
+    int d_ne00, int d_ne01, int d_ne02, // src0->ne[0..2] for the slice
+    int d_ne10, int d_ne11, int d_ne12, // src1->ne[0..2] for the slice (d_ne1X must match d_ne0X on non-concat axes)
+    int d_ne0, int d_ne1, int d_ne2,    // dst->ne[0..2] for the slice
+    int dim
+) {
+    global const float * src0 = (global const float*)((global char*)p_src0 + off_src0);
+    global const float * src1 = (global const float*)((global char*)p_src1 + off_src1);
+    global float * dst = (global float*)((global char*)p_dst + off_dst);
+
+    int i0 = get_global_id(0); // Index along dst's 0th dimension
+    int i1 = get_global_id(1); // Index along dst's 1st dimension
+    int i2 = get_global_id(2); // Index along dst's 2nd dimension
+
+    if (i0 >= d_ne0 || i1 >= d_ne1 || i2 >= d_ne2) {
+        return;
+    }
+
+    ulong dst_idx = (ulong)i2 * d_ne0 * d_ne1 + (ulong)i1 * d_ne0 + i0;
+    ulong src_idx;
+
+    if (dim == 0) {
+        if (i0 < d_ne00) { // Data from src0
+            src_idx = (ulong)i2 * d_ne00 * d_ne01 + (ulong)i1 * d_ne00 + i0;
+            dst[dst_idx] = src0[src_idx];
+        } else { // Data from src1
+            src_idx = (ulong)i2 * d_ne10 * d_ne11 + (ulong)i1 * d_ne10 + (i0 - d_ne00);
+            dst[dst_idx] = src1[src_idx];
+        }
+    } else if (dim == 1) {
+        if (i1 < d_ne01) { // Data from src0
+            src_idx = (ulong)i2 * d_ne00 * d_ne01 + (ulong)i1 * d_ne00 + i0;
+            dst[dst_idx] = src0[src_idx];
+        } else { // Data from src1
+            src_idx = (ulong)i2 * d_ne10 * d_ne11 + (ulong)(i1 - d_ne01) * d_ne10 + i0;
+            dst[dst_idx] = src1[src_idx];
+        }
+    } else if (dim == 2) {
+        if (i2 < d_ne02) { // Data from src0
+            src_idx = (ulong)i2 * d_ne00 * d_ne01 + (ulong)i1 * d_ne00 + i0;
+            dst[dst_idx] = src0[src_idx];
+        } else { // Data from src1
+
+            src_idx = (ulong)(i2 - d_ne02) * d_ne10 * d_ne11 + (ulong)i1 * d_ne10 + i0;
+            dst[dst_idx] = src1[src_idx];
+        }
+    }
+}
+
+kernel void kernel_concat_f32_non_contiguous(
+    global const char * p_src0, ulong off_src0,
+    global const char * p_src1, ulong off_src1,
+    global char * p_dst, ulong off_dst,
+
+    long ne00, long ne01, long ne02, long ne03,
+    ulong nb00, ulong nb01, ulong nb02, ulong nb03,
+
+    ulong nb10, ulong nb11, ulong nb12, ulong nb13, // Strides for src1
+
+    long d_ne0, long d_ne1, long d_ne2, long d_ne3,
+    ulong d_nb0, ulong d_nb1, ulong d_nb2, ulong d_nb3,
+    int dim
+) {
+    global const char * src0_base = p_src0 + off_src0;
+    global const char * src1_base = p_src1 + off_src1;
+    global char * dst_base = p_dst + off_dst;
+
+    long current_i1 = get_global_id(0); // Index for dst_dim_1
+    long current_i2 = get_global_id(1); // Index for dst_dim_2
+    long current_i3 = get_global_id(2); // Index for dst_dim_3
for dst_dim_3 + + if (current_i1 >= d_ne1 || current_i2 >= d_ne2 || current_i3 >= d_ne3) { + return; + } + + global const float * x_val_ptr; + global float * y_val_ptr; + + for (long current_i0 = 0; current_i0 < d_ne0; ++current_i0) { + bool use_src0; + long s_i0 = current_i0, s_i1 = current_i1, s_i2 = current_i2, s_i3 = current_i3; + + if (dim == 0) { + use_src0 = (current_i0 < ne00); + if (!use_src0) { s_i0 = current_i0 - ne00; } + } else if (dim == 1) { + use_src0 = (current_i1 < ne01); + if (!use_src0) { s_i1 = current_i1 - ne01; } + } else if (dim == 2) { + use_src0 = (current_i2 < ne02); + if (!use_src0) { s_i2 = current_i2 - ne02; } + } else { // dim == 3 + use_src0 = (current_i3 < ne03); + if (!use_src0) { s_i3 = current_i3 - ne03; } + } + + if (use_src0) { + x_val_ptr = (global const float *)(src0_base + (ulong)s_i3*nb03 + (ulong)s_i2*nb02 + (ulong)s_i1*nb01 + (ulong)s_i0*nb00); + } else { + x_val_ptr = (global const float *)(src1_base + (ulong)s_i3*nb13 + (ulong)s_i2*nb12 + (ulong)s_i1*nb11 + (ulong)s_i0*nb10); + } + + y_val_ptr = (global float *)(dst_base + (ulong)current_i3*d_nb3 + (ulong)current_i2*d_nb2 + (ulong)current_i1*d_nb1 + (ulong)current_i0*d_nb0); + *y_val_ptr = *x_val_ptr; + } +} diff --git a/ggml/src/ggml-opencl/kernels/pad.cl b/ggml/src/ggml-opencl/kernels/pad.cl new file mode 100644 index 0000000000000..747fa7febcc74 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/pad.cl @@ -0,0 +1,30 @@ +kernel void kernel_pad( + global const void * src0_ptr, + ulong src0_offset, + global void * dst_ptr, + ulong dst_offset, + int s_ne0, int s_ne1, int s_ne2, + int d_ne0, int d_ne1, int d_ne2 +) { + global const float * src0 = (global const float *)((global const char *)src0_ptr + src0_offset); + global float * dst = (global float *)((global char *)dst_ptr + dst_offset); + + int nidx = get_global_id(0); + int idx_d1 = get_group_id(1); + int idx_d2 = get_group_id(2); + + if (nidx >= d_ne0) { + return; + } + + int dst_el_offset = nidx + idx_d1 * d_ne0 + idx_d2 * d_ne0 * d_ne1; + + bool in_src_bounds = (nidx < s_ne0) && (idx_d1 < s_ne1) && (idx_d2 < s_ne2); + + if (in_src_bounds) { + int src_el_offset = nidx + idx_d1 * s_ne0 + idx_d2 * s_ne0 * s_ne1; + dst[dst_el_offset] = src0[src_el_offset]; + } else { + dst[dst_el_offset] = 0.0f; + } +} diff --git a/ggml/src/ggml-opencl/kernels/repeat.cl b/ggml/src/ggml-opencl/kernels/repeat.cl new file mode 100644 index 0000000000000..079498f5ab947 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/repeat.cl @@ -0,0 +1,39 @@ +kernel void kernel_repeat( + global const char * src0_data_in, + global char * dst_data_in, + ulong src0_offset, + ulong dst_offset, + int src0_ne0, int src0_ne1, int src0_ne2, int src0_ne3, + ulong src0_nb0, ulong src0_nb1, ulong src0_nb2, ulong src0_nb3, + int dst_ne0, int dst_ne1, int dst_ne2, int dst_ne3, + ulong dst_nb0, ulong dst_nb1, ulong dst_nb2, ulong dst_nb3 +) { + global const char * src0_data = src0_data_in + src0_offset; + global char * dst_data = dst_data_in + dst_offset; + + const int d3 = get_global_id(2); + const int d2 = get_global_id(1); + const int d1 = get_global_id(0); + + if (d3 >= dst_ne3 || d2 >= dst_ne2 || d1 >= dst_ne1) { + return; + } + + const int s3 = d3 % src0_ne3; + const int s2 = d2 % src0_ne2; + const int s1 = d1 % src0_ne1; + + const global char * p_src0_slice = src0_data + (ulong)s3*src0_nb3 + (ulong)s2*src0_nb2 + (ulong)s1*src0_nb1; + global char * p_dst_slice = dst_data + (ulong)d3*dst_nb3 + (ulong)d2*dst_nb2 + (ulong)d1*dst_nb1; + + for (int d0 = 0; d0 < dst_ne0; ++d0) { + // Determine 
source index for dimension 0 based on tiling/broadcasting. + const int s0 = d0 % src0_ne0; + + const global char * restrict current_src_el_ptr = p_src0_slice + (ulong)s0*src0_nb0; + global char * restrict current_dst_el_ptr = p_dst_slice + (ulong)d0*dst_nb0; + for (int k = 0; k < src0_nb0; ++k) { + current_dst_el_ptr[k] = current_src_el_ptr[k]; + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/tanh.cl b/ggml/src/ggml-opencl/kernels/tanh.cl new file mode 100644 index 0000000000000..d9da86b148921 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/tanh.cl @@ -0,0 +1,63 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +kernel void kernel_tanh_f32_nd( + global void * p_src0_base, ulong off_src0_abs, + global void * p_dst_base, ulong off_dst_abs, + int ne00, int ne01, int ne02, int ne03, + ulong nb00, ulong nb01, ulong nb02, ulong nb03, + int ne10, int ne11, int ne12, int ne13, + ulong nb10, ulong nb11, ulong nb12, ulong nb13 +) { + int i0 = get_global_id(0); + int i1 = get_global_id(1); + int i2 = get_global_id(2); + + if (i0 < ne10 && i1 < ne11 && i2 < ne12) { + for (int i3 = 0; i3 < ne13; ++i3) { + ulong src_offset_in_tensor = (ulong)i0*nb00 + (ulong)i1*nb01 + (ulong)i2*nb02 + (ulong)i3*nb03; + global const float *src_val_ptr = (global const float *)((global char *)p_src0_base + off_src0_abs + src_offset_in_tensor); + + ulong dst_offset_in_tensor = (ulong)i0*nb10 + (ulong)i1*nb11 + (ulong)i2*nb12 + (ulong)i3*nb13; + global float *dst_val_ptr = (global float *)((global char *)p_dst_base + off_dst_abs + dst_offset_in_tensor); + + *dst_val_ptr = tanh(*src_val_ptr); + } + } +} + +kernel void kernel_tanh_f16_nd( + global void * p_src0_base, ulong off_src0_abs, + global void * p_dst_base, ulong off_dst_abs, + int ne00, int ne01, int ne02, int ne03, + ulong nb00, ulong nb01, ulong nb02, ulong nb03, + int ne10, int ne11, int ne12, int ne13, + ulong nb10, ulong nb11, ulong nb12, ulong nb13 +) { + int i0 = get_global_id(0); + int i1 = get_global_id(1); + int i2 = get_global_id(2); + + if (i0 < ne10 && i1 < ne11 && i2 < ne12) { + for (int i3 = 0; i3 < ne13; ++i3) { + ulong src_offset_in_tensor = (ulong)i0*nb00 + (ulong)i1*nb01 + (ulong)i2*nb02 + (ulong)i3*nb03; + global const half *src_val_ptr = (global const half *)((global char *)p_src0_base + off_src0_abs + src_offset_in_tensor); + + ulong dst_offset_in_tensor = (ulong)i0*nb10 + (ulong)i1*nb11 + (ulong)i2*nb12 + (ulong)i3*nb13; + global half *dst_val_ptr = (global half *)((global char *)p_dst_base + off_dst_abs + dst_offset_in_tensor); + + *dst_val_ptr = tanh(*src_val_ptr); + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/tsembd.cl b/ggml/src/ggml-opencl/kernels/tsembd.cl new file mode 100644 index 0000000000000..4b1107f70ba7a --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/tsembd.cl @@ -0,0 +1,48 @@ +kernel void kernel_timestep_embedding( + global const void * p_timesteps, + ulong off_timesteps, + global void * p_dst, + ulong off_dst, + int dst_nb1_bytes, + int 
logical_dim, + int max_period +) { + int local_i; + int local_j; + int local_half_dim; + float local_timestep_val; + float local_freq; + float local_arg; + global float * local_embed_data_ptr; + global const float * local_timesteps_input_ptr; + global float * local_dst_output_base_ptr; + + local_timesteps_input_ptr = (global const float *)((global char *)p_timesteps + off_timesteps); + local_dst_output_base_ptr = (global float *)((global char *)p_dst + off_dst); + + local_i = get_global_id(1); + local_j = get_global_id(0); + + local_half_dim = logical_dim / 2; + local_embed_data_ptr = (global float *)((global char *)local_dst_output_base_ptr + local_i * dst_nb1_bytes); + + if (logical_dim % 2 != 0 && local_j == ((logical_dim + 1) / 2)) { + local_embed_data_ptr[logical_dim] = 0.0f; + } + + if (local_j >= local_half_dim) { + return; + } + + local_timestep_val = local_timesteps_input_ptr[local_i]; + + if (local_half_dim == 0) { + local_freq = 1.0f; + } else { + local_freq = exp(-log((float)max_period) * (float)local_j / (float)local_half_dim); + } + + local_arg = local_timestep_val * local_freq; + local_embed_data_ptr[local_j] = cos(local_arg); + local_embed_data_ptr[local_j + local_half_dim] = sin(local_arg); +} diff --git a/ggml/src/ggml-opencl/kernels/upscale.cl b/ggml/src/ggml-opencl/kernels/upscale.cl new file mode 100644 index 0000000000000..219d31dbb9248 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/upscale.cl @@ -0,0 +1,121 @@ +kernel void kernel_upscale( + global const void * p_src0, + ulong off_src0, + global void * p_dst, + ulong off_dst, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne10, + int ne11, + int ne12, + int ne13, + float sf0, + float sf1, + float sf2, + float sf3 +) { + global const char * src_base = (global const char *)p_src0 + off_src0; + global float * dst_base = (global float *)((global char *)p_dst + off_dst); + + int index = get_global_id(0); + int dst_total_elements = ne10 * ne11 * ne12 * ne13; + + if (index >= dst_total_elements) { + return; + } + + int i10 = index % ne10; + int i11 = (index / ne10) % ne11; + int i12 = (index / (ne10 * ne11)) % ne12; + int i13 = index / (ne10 * ne11 * ne12); + + int i00 = (int)(i10 / sf0); + int i01 = (int)(i11 / sf1); + int i02 = (int)(i12 / sf2); + int i03 = (int)(i13 / sf3); + + ulong offset_src_element = (ulong)i03 * nb03 + (ulong)i02 * nb02 + (ulong)i01 * nb01 + (ulong)i00 * nb00; + global const float * src_element_ptr = (global const float *)(src_base + offset_src_element); + + dst_base[index] = *src_element_ptr; +} + +kernel void kernel_upscale_bilinear( + global const void * p_src0, + ulong off_src0, + global void * p_dst, + ulong off_dst, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne00_src, + int ne01_src, + int ne10_dst, + int ne11_dst, + int ne12_dst, + int ne13_dst, + float sf0, + float sf1, + float sf2, + float sf3 +) { + global const char * src_base = (global const char *)p_src0 + off_src0; + global float * dst_base = (global float *)((global char *)p_dst + off_dst); + + int index = get_global_id(0); + int dst_total_elements = ne10_dst * ne11_dst * ne12_dst * ne13_dst; + + if (index >= dst_total_elements) { + return; + } + + int i10_dst = index % ne10_dst; + int i11_dst = (index / ne10_dst) % ne11_dst; + int i12_dst = (index / (ne10_dst * ne11_dst)) % ne12_dst; + int i13_dst = index / (ne10_dst * ne11_dst * ne12_dst); + + int i02_src = (int)(i12_dst / sf2); + int i03_src = (int)(i13_dst / sf3); + + const float pixel_offset = 0.5f; + + float y_src_f = ((float)i11_dst + 
pixel_offset) / sf1 - pixel_offset; + long y0_src = (long)floor(y_src_f); + long y1_src = y0_src + 1; + + y0_src = max(0L, min(y0_src, (long)ne01_src - 1)); + y1_src = max(0L, min(y1_src, (long)ne01_src - 1)); + + float dy = y_src_f - (float)y0_src; + dy = max(0.0f, min(dy, 1.0f)); + + float x_src_f = ((float)i10_dst + pixel_offset) / sf0 - pixel_offset; + long x0_src = (long)floor(x_src_f); + long x1_src = x0_src + 1; + + x0_src = max(0L, min(x0_src, (long)ne00_src - 1)); + x1_src = max(0L, min(x1_src, (long)ne00_src - 1)); + + float dx = x_src_f - (float)x0_src; + dx = max(0.0f, min(dx, 1.0f)); + + global const float * p_a = (global const float *)(src_base + (ulong)x0_src * nb00 + (ulong)y0_src * nb01 + (ulong)i02_src * nb02 + (ulong)i03_src * nb03); + global const float * p_b = (global const float *)(src_base + (ulong)x1_src * nb00 + (ulong)y0_src * nb01 + (ulong)i02_src * nb02 + (ulong)i03_src * nb03); + global const float * p_c = (global const float *)(src_base + (ulong)x0_src * nb00 + (ulong)y1_src * nb01 + (ulong)i02_src * nb02 + (ulong)i03_src * nb03); + global const float * p_d = (global const float *)(src_base + (ulong)x1_src * nb00 + (ulong)y1_src * nb01 + (ulong)i02_src * nb02 + (ulong)i03_src * nb03); + + const float val_a = *p_a; + const float val_b = *p_b; + const float val_c = *p_c; + const float val_d = *p_d; + + float result = val_a * (1.0f - dx) * (1.0f - dy) + + val_b * dx * (1.0f - dy) + + val_c * (1.0f - dx) * dy + + val_d * dx * dy; + + dst_base[index] = result; +} diff --git a/ggml/src/ggml-sycl/CMakeLists.txt b/ggml/src/ggml-sycl/CMakeLists.txt index a2e26124802b2..2a0045bcc158e 100644 --- a/ggml/src/ggml-sycl/CMakeLists.txt +++ b/ggml/src/ggml-sycl/CMakeLists.txt @@ -13,7 +13,7 @@ elseif(SUPPORTS_SYCL) If you expected the oneAPI Release compiler, please install oneAPI & source it, like: source /opt/intel/oneapi/setvars.sh") else() - message(FATAL_ERROR, "C++ compiler lacks SYCL support.") + message(FATAL_ERROR "C++ compiler lacks SYCL support.") endif() message(STATUS "SYCL found") #todo: AOT @@ -170,7 +170,7 @@ else() target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_NVIDIA) elseif (GGML_SYCL_TARGET STREQUAL "AMD") if (NOT GGML_SYCL_DEVICE_ARCH) - message(ERROR "Can't enable SYCL hip backend, GGML_SYCL_DEVICE_ARCH has not been set.") + message(FATAL_ERROR "Can't enable SYCL hip backend, GGML_SYCL_DEVICE_ARCH has not been set.") endif() target_link_libraries(ggml-sycl PRIVATE ONEMATH::onemath_blas_rocblas) target_compile_options(ggml-sycl PRIVATE "-fsycl-targets=amdgcn-amd-amdhsa") diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index bcd2ea5366f76..78513114c55f3 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -1434,6 +1434,59 @@ static void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, reinterpret_cast(y[ib].ds.y()) = sum; } +template +static __dpct_inline__ void quantize_and_reorder_q8_1(const float * __restrict__ x, void * reordered_q8_tensor, + const int kx, const int kx_padded, const sycl::nd_item<1> & it) { + /* + Quantizes and reorders the resultant q8 tensor in a per row fashion + Each sub-group calculates one quant block. i.e. 
QK8_1 quant values and the d and sum values + */ + + auto subgroup_id = it.get_group(0); + auto wi_id = it.get_local_id(0); + + const int num_blocks_per_row = kx / QK8_1; + auto row = subgroup_id / num_blocks_per_row; + auto col = subgroup_id % num_blocks_per_row; + + auto row_offset = row * (kx_padded / QK8_1) * sizeof(block_q8_1); + auto col_offset = QK8_1 * col + wi_id * ElementsPerWI; + + auto quant_ptr = (int8_t *) ((char *) reordered_q8_tensor + row_offset + col_offset); + auto ds_ptr = (sycl::half2 *) ((char *) reordered_q8_tensor + row_offset + kx + col * sizeof(sycl::half2)); + + sycl::vec wi_f32_vals; + sycl::vec quantized_values; + + auto float_ptr_offset = subgroup_id * QK8_1 + ElementsPerWI * wi_id; + wi_f32_vals = *reinterpret_cast *>(x + float_ptr_offset); + + float sum = 0.0f; + float amax = 0.0f; + +#pragma unroll(ElementsPerWI) + for (int i = 0; i < ElementsPerWI; i++) { + sum += wi_f32_vals[i]; + amax = sycl::fmax(amax, sycl::fabs(wi_f32_vals[i])); + quantized_values[i] = 0; + } + sum = sycl::reduce_over_group(it.get_group(), sum, sycl::plus()); + amax = sycl::reduce_over_group(it.get_group(), amax, sycl::maximum()); + float d = amax == 0 ? 1 : amax / 127; + +#pragma unroll(ElementsPerWI) + for (int i = 0; i < ElementsPerWI; i++) { + quantized_values[i] = sycl::round(wi_f32_vals[i] / d); + } + + d = amax == 0 ? 0 : d; + + *reinterpret_cast *>(quant_ptr) = quantized_values; + if (wi_id == 0) { + *ds_ptr = sycl::half2(sycl::half(d), sycl::half(sum)); + } +} + static void mul_mat_p021_f16_f32( const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x, const int nchannels_x, const int nchannels_y, @@ -1718,23 +1771,30 @@ static void pool2d_nchw_kernel( o_ptr[cur_oh * ow + cur_ow] = res; } -static void quantize_row_q8_1_sycl(const float *x, void *vy, const int kx, - const int ky, const int kx_padded, - queue_ptr stream) { - const int block_num_x = (kx_padded + SYCL_QUANTIZE_BLOCK_SIZE - 1) / SYCL_QUANTIZE_BLOCK_SIZE; - const sycl::range<3> num_blocks(1, ky, block_num_x); - int constexpr QUANT_BLOCK_TILE = QK8_1 / WARP_SIZE; - static_assert(QK8_1 % WARP_SIZE == 0); - const sycl::range<3> block_size(1, 1, SYCL_QUANTIZE_BLOCK_SIZE / QUANT_BLOCK_TILE); - { - dpct::has_capability_or_fail(stream->get_device(), - {sycl::aspect::fp16}); +static void quantize_row_q8_1_sycl(const float * x, void * vy, const int kx, const int ky, const int kx_padded, + bool reorder_q8_tensor, queue_ptr stream) { + if (reorder_q8_tensor) { + auto local_range = std::size_t(WARP_SIZE); + auto num_quant_blocks = ky * (kx / QK8_1); + auto global_range = num_quant_blocks * local_range; + stream->parallel_for(sycl::nd_range<1>({ global_range }, { local_range }), + [=](sycl::nd_item<1> it) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + quantize_and_reorder_q8_1(x, vy, kx, kx_padded, it); + }); + } else { + const int block_num_x = (kx_padded + SYCL_QUANTIZE_BLOCK_SIZE - 1) / SYCL_QUANTIZE_BLOCK_SIZE; + const sycl::range<3> num_blocks(1, ky, block_num_x); + int constexpr QUANT_BLOCK_TILE = QK8_1 / WARP_SIZE; + static_assert(QK8_1 % WARP_SIZE == 0); + const sycl::range<3> block_size(1, 1, SYCL_QUANTIZE_BLOCK_SIZE / QUANT_BLOCK_TILE); + { + dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 }); - stream->parallel_for( - sycl::nd_range<3>(num_blocks * block_size, block_size), - [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - quantize_q8_1(x, vy, kx, kx_padded, item_ct1); - }); + 
stream->parallel_for(sycl::nd_range<3>(num_blocks * block_size, block_size), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + quantize_q8_1(x, vy, kx, kx_padded, item_ct1); + }); + } } } @@ -2446,9 +2506,10 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten dev[i].src1_ddq = dev[i].src1_ddq_alloc.alloc(ctx.pool(i), nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs); if (src1_on_device && src1_is_contiguous) { + bool reorder_q8_tensor = src0->extra && ((ggml_tensor_extra_gpu *)src0->extra)->optimized_feature.reorder; scope_op_debug_print scope_dbg_print(__func__, "/quantize_row_q8_1_sycl", dst, /*num_src=*/2, " : converting src1 to Q8_1"); - quantize_row_q8_1_sycl(dev[i].src1_ddf, dev[i].src1_ddq, ne10, nrows1, src1_padded_col_size, stream); + quantize_row_q8_1_sycl(dev[i].src1_ddf, dev[i].src1_ddq, ne10, nrows1, src1_padded_col_size, reorder_q8_tensor, stream); /* DPCT1010:90: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to @@ -2554,7 +2615,7 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten if (convert_src1_to_q8_1 && !src1_is_contiguous) { scope_op_debug_print scope_dbg_print(__func__, "/quantize_row_q8_1_sycl", dst, /*num_src=*/2, " : converting src1 to Q8_1"); - quantize_row_q8_1_sycl(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, src1_padded_col_size, stream); + quantize_row_q8_1_sycl(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, src1_padded_col_size, false, stream); /* DPCT1010:92: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You diff --git a/ggml/src/ggml-sycl/mmvq.cpp b/ggml/src/ggml-sycl/mmvq.cpp index cb70f83a4f9a6..80c780b209998 100644 --- a/ggml/src/ggml-sycl/mmvq.cpp +++ b/ggml/src/ggml-sycl/mmvq.cpp @@ -29,8 +29,6 @@ static void mul_mat_vec_q_reorder(const void * __restrict__ vx, const void * __r static_assert(blocks_per_subgroup > 0); static_assert(block_elements_per_subgroup > 0); - const block_q8_1 * y = (const block_q8_1 *) vy; - float partial_sum = 0.0f; for (int i = sg.get_local_linear_id() / block_elements_per_subgroup; i < blocks_per_row; i += blocks_per_subgroup) { const int ibx = row * blocks_per_row + i; // x block index @@ -40,13 +38,15 @@ static void mul_mat_vec_q_reorder(const void * __restrict__ vx, const void * __r // Y block index that aligns with ibx const int iby = i * block_type::block_to_q8_1_ratio(); + const int8_t* q8_1_quant_ptr = (const int8_t*)vy + iby * QK8_1; + const sycl::half2* q8_1_ds_ptr = (const sycl::half2*)((const char*)vy + ncols + iby * sizeof(sycl::half2)); #pragma unroll for (int elem = 0; elem < block_elements_per_subgroup; elem += WARP_SIZE) { // x block quant index when casting the quants to int const int iqs = elem + block_traits::vdr_mmvq * (sg.get_local_linear_id() % block_elements_per_subgroup); - partial_sum += reorder_vec_dot_q_sycl()(vx, bx_offset, d_offset, &y[iby], iqs, nblocks); + partial_sum += reorder_vec_dot_q_sycl()(vx, bx_offset, d_offset, q8_1_quant_ptr, q8_1_ds_ptr, iqs, nblocks); } } diff --git a/ggml/src/ggml-sycl/vecdotq.hpp b/ggml/src/ggml-sycl/vecdotq.hpp index ed3699313466b..fa258e4d4d106 100644 --- a/ggml/src/ggml-sycl/vecdotq.hpp +++ b/ggml/src/ggml-sycl/vecdotq.hpp @@ -285,21 +285,21 @@ template <> struct reorder_vec_dot_q_sycl { } __dpct_inline__ float operator()(const void * __restrict__ vbq, const int ibx_offset, const int d_offset, - const block_q8_1 * __restrict__ bq8_1, const int & iqs, int /* nblocks 
*/) { + const int8_t* q8_1_quant_ptr, const sycl::half2* q8_1_ds, const int & iqs, int /* nblocks */) { const uint8_t * bq4_0 = static_cast(vbq) + ibx_offset; const ggml_half d = *(reinterpret_cast(static_cast(vbq) + d_offset)); int v[q4_0_traits::vdr_mmvq]; int u[2 * q4_0_traits::vdr_mmvq]; -#pragma unroll +#pragma unroll for (size_t i = 0; i < q4_0_traits::vdr_mmvq; ++i) { v[i] = get_int_from_uint8(bq4_0, iqs + i); - u[2 * i + 0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i); - u[2 * i + 1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + q4_0_traits::qi); + u[2 * i + 0] = get_int_from_int8_aligned(q8_1_quant_ptr, iqs + i); + u[2 * i + 1] = get_int_from_int8_aligned(q8_1_quant_ptr, iqs + i + q4_0_traits::qi); } - return vec_dot_q4_0_q8_1_impl(v, u, d, bq8_1->ds); + return vec_dot_q4_0_q8_1_impl(v, u, d, *q8_1_ds); }; }; @@ -347,7 +347,7 @@ template <> struct reorder_vec_dot_q_sycl { using q4_k_traits = typename q4_k_block::traits; float operator()(const void * __restrict__ vbq, const int ibx_offset, const int d_offset, - const block_q8_1 * __restrict__ bq8_1, const int & iqs, int nblocks) { + const int8_t* q8_1_quant_ptr, const sycl::half2* q8_1_ds, const int & iqs, int nblocks) { const int ib = ibx_offset / (QK_K / 2); const uint8_t * base = static_cast(vbq); @@ -360,7 +360,38 @@ template <> struct reorder_vec_dot_q_sycl { const int * q4 = (const int *) (qs + 16 * bq8_offset + 4 * ((iqs / 2) % 4)); const uint16_t * scales = (const uint16_t *) scs; - return vec_dot_q4_K_q8_1_common(q4, scales, *dms, bq8_1, iqs); + int v[2]; + int u[2 * QR4_K]; + float d8[QR4_K]; + + v[0] = q4[0]; + v[1] = q4[4]; + + uint16_t aux[2]; + const int j = (QR4_K * ((iqs / 2) / (QI8_1 / 2))) / 2; + if (j < 2) { + aux[0] = scales[j + 0] & 0x3f3f; + aux[1] = scales[j + 2] & 0x3f3f; + } else { + aux[0] = ((scales[j + 2] >> 0) & 0x0f0f) | ((scales[j - 2] & 0xc0c0) >> 2); + aux[1] = ((scales[j + 2] >> 4) & 0x0f0f) | ((scales[j - 0] & 0xc0c0) >> 2); + } + + const uint8_t * sc = (const uint8_t *) aux; + const uint8_t * m = sc + 2; + + for (int i = 0; i < QR4_K; ++i) { + const int8_t* quant_base_ptr = q8_1_quant_ptr + (bq8_offset + i) * QK8_1; + sycl::half2 ds_values = *(q8_1_ds + bq8_offset + i); + + d8[i] = ds_values[0]; + + const int * q8 = (const int *) quant_base_ptr + ((iqs / 2) % 4); + u[2 * i + 0] = q8[0]; + u[2 * i + 1] = q8[4]; + } + + return vec_dot_q4_K_q8_1_impl_vmmq(v, u, sc, m, *dms, d8); } }; diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index f6717b1e34b0c..ff9dc871a3f9d 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -1652,7 +1652,7 @@ static std::array fa_rows_cols(FaCodePath path, uint32_t D, uint32_ return {64, 32}; } return {64, 64}; -}; +} static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector& warptile, bool mul_mat_id, ggml_type src0_type) { diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 25bfa04640b9f..6fc9949fdac1c 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -133,7 +133,7 @@ static void ggml_print_backtrace_symbols(void) { } #endif -static void ggml_print_backtrace(void) { +void ggml_print_backtrace(void) { const char * GGML_NO_BACKTRACE = getenv("GGML_NO_BACKTRACE"); if (GGML_NO_BACKTRACE) { return; @@ -160,6 +160,10 @@ static void ggml_print_backtrace(void) { const int parent_pid = getpid(); const int child_pid = fork(); if (child_pid < 0) { // error +#if defined(__linux__) + close(lock[1]); + close(lock[0]); +#endif return; } else if (child_pid == 0) { // 
child char attach[32]; @@ -167,6 +171,7 @@ static void ggml_print_backtrace(void) { #if defined(__linux__) close(lock[1]); (void) !read(lock[0], lock, 1); + close(lock[0]); #endif // try gdb execlp("gdb", "gdb", "--batch", @@ -195,7 +200,7 @@ static void ggml_print_backtrace(void) { } } #else -static void ggml_print_backtrace(void) { +void ggml_print_backtrace(void) { // platform not supported } #endif @@ -235,6 +240,8 @@ void ggml_abort(const char * file, int line, const char * fmt, ...) { } #endif +// ggml_print_backtrace is registered with std::set_terminate by ggml.cpp + // // logging // diff --git a/ggml/src/ggml.cpp b/ggml/src/ggml.cpp new file mode 100644 index 0000000000000..0d388d45536d1 --- /dev/null +++ b/ggml/src/ggml.cpp @@ -0,0 +1,26 @@ +#include "ggml-impl.h" + +#include <cstdlib> +#include <exception> + +static std::terminate_handler previous_terminate_handler; + +GGML_NORETURN static void ggml_uncaught_exception() { + ggml_print_backtrace(); + if (previous_terminate_handler) { + previous_terminate_handler(); + } + abort(); // unreachable unless previous_terminate_handler was nullptr +} + +static bool ggml_uncaught_exception_init = []{ + const char * GGML_NO_BACKTRACE = getenv("GGML_NO_BACKTRACE"); + if (GGML_NO_BACKTRACE) { + return false; + } + const auto prev{std::get_terminate()}; + GGML_ASSERT(prev != ggml_uncaught_exception); + previous_terminate_handler = prev; + std::set_terminate(ggml_uncaught_exception); + return true; +}(); diff --git a/ggml/src/gguf.cpp b/ggml/src/gguf.cpp index 8667a80bd0685..a0a318a29f5b9 100644 --- a/ggml/src/gguf.cpp +++ b/ggml/src/gguf.cpp @@ -347,11 +347,28 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par int64_t n_tensors = 0; if (ok && gr.read(ctx->version)) { - if (ctx->version == 1) { + if (ok && ctx->version == 0) { + GGML_LOG_ERROR("%s: bad GGUF version: %" PRIu32 "\n", __func__, ctx->version); + ok = false; + } + + /* + * bit layout is different when reading non-native endian models. + * assuming that the GGUF version is 3, the non-native endian model + * would read it as 0x03000000. we can use the AND operation against + * the last 4 hexadecimal digits to check if the model is the same + * endianness as the host system.
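+     * worked example: version 3 is stored on disk as the little-endian
+     * bytes 03 00 00 00, which a big-endian host reads back as 0x03000000.
+     * since every defined GGUF version is far below 0x10000, a byte-swapped
+     * read always ends up with zero in its low 16 bits, which is what the
+     * (version & 0x0000FFFF) == 0 check below detects.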
+ */ + if (ok && (ctx->version & 0x0000FFFF) == 0x00000000) { + GGML_LOG_ERROR("%s: failed to load model: this GGUF file version %" PRIu32 " is extremely large, is there a mismatch between the host and model endianness?\n", __func__, ctx->version); + ok = false; + } + + if (ok && ctx->version == 1) { GGML_LOG_ERROR("%s: GGUFv1 is no longer supported, please use a more up-to-date version\n", __func__); ok = false; } - if (ctx->version > GGUF_VERSION) { + if (ok && ctx->version > GGUF_VERSION) { GGML_LOG_ERROR("%s: this GGUF file is version %" PRIu32 " but this software only supports up to version %d\n", __func__, ctx->version, GGUF_VERSION); ok = false; diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index 88a1219f57908..aa0fb8fb02001 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -06b715f4c170232af261425240914fa49c44f982 +94a83ba5a725ae2aee79df75dd99b2119d0478cc diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index e93e8f9cc0d55..d20bd4fe2333d 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -21,6 +21,9 @@ add_library(llama llama-impl.cpp llama-io.cpp llama-kv-cache.cpp + llama-kv-cache-unified.cpp + llama-kv-cache-unified-iswa.cpp + llama-kv-cache-recurrent.cpp llama-memory.cpp llama-mmap.cpp llama-model-loader.cpp diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index b30f6fb4f4145..727e119e334f6 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -3,7 +3,10 @@ #include "llama-impl.h" #include "llama-batch.h" #include "llama-cparams.h" -#include "llama-kv-cache.h" + +#include "llama-kv-cache-unified.h" +#include "llama-kv-cache-unified-iswa.h" +#include "llama-kv-cache-recurrent.h" #include #include diff --git a/src/llama-kv-cache-recurrent.cpp b/src/llama-kv-cache-recurrent.cpp new file mode 100644 index 0000000000000..641eab2f316ce --- /dev/null +++ b/src/llama-kv-cache-recurrent.cpp @@ -0,0 +1,1132 @@ +#include "llama-kv-cache-recurrent.h" + +#include "llama-impl.h" +#include "llama-batch.h" +#include "llama-model.h" + +#include +#include +#include +#include +#include + +// +// llama_kv_cache_recurrent +// + +llama_kv_cache_recurrent::llama_kv_cache_recurrent( + const llama_model & model, + ggml_type type_k, + ggml_type type_v, + bool offload, + uint32_t kv_size, + uint32_t n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) { + const int32_t n_layer = hparams.n_layer; + + LLAMA_LOG_INFO("%s: kv_size = %u, n_seq_max = %u, type_k = '%s', type_v = '%s', n_layer = %d\n", + __func__, kv_size, n_seq_max, ggml_type_name(type_k), ggml_type_name(type_v), n_layer); + + head = 0; + size = kv_size; + used = 0; + + cells.clear(); + cells.resize(kv_size); + + // create a context for each buffer type + std::map ctx_map; + auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * { + auto it = ctx_map.find(buft); + if (it == ctx_map.end()) { + ggml_init_params params = { + /*.mem_size =*/ size_t(2u*n_layer*ggml_tensor_overhead()), + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + + ggml_context * ctx = ggml_init(params); + if (!ctx) { + return nullptr; + } + + ctx_map[buft] = ctx; + ctxs.emplace_back(ctx); + + return ctx; + } + + return it->second; + }; + + k_l.reserve(n_layer); + v_l.reserve(n_layer); + + for (int i = 0; i < n_layer; i++) { + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(); + + const char * dev_name = "CPU"; + + ggml_backend_buffer_type_t buft = 
ggml_backend_cpu_buffer_type(); + + if (offload) { + auto * dev = model.dev_layer(i); + buft = ggml_backend_dev_buffer_type(dev); + + dev_name = ggml_backend_dev_name(dev); + } + + LLAMA_LOG_DEBUG("%s, layer %3d: dev = %s\n", __func__, i, dev_name); + + ggml_context * ctx = ctx_for_buft(buft); + if (!ctx) { + throw std::runtime_error("failed to create ggml context for kv cache"); + } + + ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size); + ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size); + ggml_format_name(k, "cache_k_l%d", i); + ggml_format_name(v, "cache_v_l%d", i); + k_l.push_back(k); + v_l.push_back(v); + } + + // allocate tensors and initialize the buffers to avoid NaNs in the padding + for (auto it : ctx_map) { + auto * buft = it.first; + auto * ctx = it.second; + + ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); + if (!buf) { + throw std::runtime_error("failed to allocate buffer for kv cache"); + } + ggml_backend_buffer_clear(buf, 0); + LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0); + bufs.emplace_back(buf); + } + + { + const size_t memory_size_k = size_k_bytes(); + const size_t memory_size_v = size_v_bytes(); + + LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__, + (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), + ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f), + ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f)); + } +} + +void llama_kv_cache_recurrent::clear() { + for (int32_t i = 0; i < (int32_t) size; ++i) { + cells[i].pos = -1; + cells[i].seq_id.clear(); + cells[i].src = -1; + cells[i].tail = -1; + } + head = 0; + used = 0; + + for (auto & buf : bufs) { + ggml_backend_buffer_clear(buf.get(), 0); + } +} + +bool llama_kv_cache_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + uint32_t new_head = size; + + if (p0 < 0) { + p0 = 0; + } + + if (p1 < 0) { + p1 = std::numeric_limits::max(); + } + + // models like Mamba or RWKV can't have a state partially erased + if (seq_id >= (int64_t) size) { + // could be fatal + return false; + } + if (0 <= seq_id) { + int32_t & tail_id = cells[seq_id].tail; + if (tail_id >= 0) { + const kv_cell & cell = cells[tail_id]; + // partial intersection is invalid + if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) { + return false; + } + // invalidate tails which will be cleared + if (p0 <= cell.pos && cell.pos < p1) { + tail_id = -1; + } + } + } else { + // seq_id is negative, then the range should include everything or nothing + if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits::max())) { + return false; + } + } + + for (uint32_t i = 0; i < size; ++i) { + if (cells[i].pos >= p0 && cells[i].pos < p1) { + if (seq_id < 0) { + cells[i].seq_id.clear(); + } else if (cells[i].has_seq_id(seq_id)) { + cells[i].seq_id.erase(seq_id); + } else { + continue; + } + if (cells[i].is_empty()) { + // keep count of the number of used cells + if (cells[i].pos >= 0) { + used--; + } + cells[i].pos = -1; + cells[i].src = -1; + if (new_head == size) { + new_head = i; + } + } + } + } + + // If we freed up a slot, set head to it so searching can start there. 
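+    // (for example, if erasing a range freed cell 2 while head was 7, head
+    // drops back to 2 so the next find_slot() scan starts at the free cell.)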
+ if (new_head != size && new_head < head) { + head = new_head; + } + + return true; +} + +void llama_kv_cache_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { + if (seq_id_src == seq_id_dst) { + return; + } + + if (p0 < 0) { + p0 = 0; + } + + if (p1 < 0) { + p1 = std::numeric_limits<llama_pos>::max(); + } + + if ((uint32_t) seq_id_dst < size && (uint32_t) seq_id_src < size) { + kv_cell & tail_src = cells[seq_id_src]; + kv_cell & tail_dst = cells[seq_id_dst]; + if (tail_dst.tail >= 0) { + // clear destination seq_id if it wasn't empty + kv_cell & cell_dst = cells[tail_dst.tail]; + + cell_dst.seq_id.erase(seq_id_dst); + tail_dst.tail = -1; + if (cell_dst.seq_id.empty()) { + cell_dst.pos = -1; + cell_dst.src = -1; + used -= 1; + } + } + if (tail_src.tail >= 0) { + kv_cell & cell_src = cells[tail_src.tail]; + + cell_src.seq_id.insert(seq_id_dst); + tail_dst.tail = tail_src.tail; + } + } +} + +void llama_kv_cache_recurrent::seq_keep(llama_seq_id seq_id) { + uint32_t new_head = size; + + for (uint32_t i = 0; i < size; ++i) { + if ((llama_seq_id) i != seq_id) { + cells[i].tail = -1; + } + + if (!cells[i].has_seq_id(seq_id)) { + if (cells[i].pos >= 0) { + used--; + } + + cells[i].pos = -1; + cells[i].src = -1; + cells[i].seq_id.clear(); + + if (new_head == size) { + new_head = i; + } + } else { + cells[i].seq_id.clear(); + cells[i].seq_id.insert(seq_id); + } + } + + // If we freed up a slot, set head to it so searching can start there. + if (new_head != size && new_head < head) { + head = new_head; + } +} + +void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) { + if (shift == 0) { + return; + } + + if (p0 < 0) { + p0 = 0; + } + + if (p1 < 0) { + p1 = std::numeric_limits<llama_pos>::max(); + } + + // If there is no range then return early to avoid looping over the cache. + if (p0 == p1) { + return; + } + + // for Mamba-like or RWKV models, only the pos needs to be shifted + if (0 <= seq_id && seq_id < (int64_t) size) { + const int32_t tail_id = cells[seq_id].tail; + if (tail_id >= 0) { + kv_cell & cell = cells[tail_id]; + if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) { + cell.pos += shift; + } + } + } +} + +void llama_kv_cache_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { + if (d == 1) { + return; + } + + if (p0 < 0) { + p0 = 0; + } + + if (p1 < 0) { + p1 = std::numeric_limits<llama_pos>::max(); + } + + // If there is no range then return early to avoid looping over the cache.
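+    // (p0 and p1 form the half-open interval [p0, p1); negative bounds were
+    // normalized above, so p0 == p1 can only describe an empty range.)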
+ if (p0 == p1) { + return; + } + + // for Mamba-like or RWKV models, only the pos needs to be changed + if (0 <= seq_id && seq_id < (int64_t) size) { + const int32_t tail_id = cells[seq_id].tail; + if (tail_id >= 0) { + kv_cell & cell = cells[tail_id]; + if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) { + cell.pos /= d; + } + } + } +} + +llama_pos llama_kv_cache_recurrent::seq_pos_min(llama_seq_id seq_id) const { + llama_pos result = std::numeric_limits::max(); + + for (uint32_t i = 0; i < size; ++i) { + if (cells[i].has_seq_id(seq_id)) { + result = std::min(result, cells[i].pos); + } + } + + if (result == std::numeric_limits::max()) { + result = -1; + } + + return result; +} + +llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const { + llama_pos result = -1; + + for (uint32_t i = 0; i < size; ++i) { + if (cells[i].has_seq_id(seq_id)) { + result = std::max(result, cells[i].pos); + } + } + + return result; +} + +llama_memory_state_ptr llama_kv_cache_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) { + GGML_UNUSED(embd_pooled); + + auto sbatch = llama_sbatch(batch, hparams.n_embd, false, logits_all); + + std::vector ubatches; + + while (sbatch.n_tokens > 0) { + llama_ubatch ubatch; + + if (embd_pooled) { + // Pooled embeddings cannot be split across ubatches (yet) + ubatch = sbatch.split_seq(n_ubatch); + } else { + ubatch = sbatch.split_equal(n_ubatch); + } + + ubatches.push_back(ubatch); + } + + if (!prepare(ubatches)) { + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); + } + + return std::make_unique(LLAMA_MEMORY_STATUS_SUCCESS, this, std::move(sbatch), std::move(ubatches)); +} + +llama_memory_state_ptr llama_kv_cache_recurrent::init_full() { + return std::make_unique(LLAMA_MEMORY_STATUS_SUCCESS, this); +} + +bool llama_kv_cache_recurrent::prepare(const std::vector & ubatches) { + // simply remember the full state because it is very small for this type of cache + // TODO: optimize + auto org_cells = cells; + auto org_used = used; + auto org_head = head; + + bool success = true; + + // TODO: here we have to verify that all ubatches can fit in the cells + // however, the current implementation is broken because it relies on s_copy() and s_mask() to update the cells + // during the compute of each ubatch. 
to reproduce, uncomment the following loop and run: + // + // $ llama-parallel -m ./mamba-130m/ggml-model-f16.gguf -np 5 -ns 8 + // + // recovery from failures when the batch does not fit in the KV cache will not work correctly until this is fixed + // + GGML_UNUSED(ubatches); + //for (const auto & ubatch : ubatches) { + // if (!find_slot(ubatch)) { + // success = false; + // break; + // } + //} + + // restore the original state + cells = std::move(org_cells); + used = org_used; + head = org_head; + + return success; +} + +bool llama_kv_cache_recurrent::update(llama_context & lctx) { + GGML_UNUSED(lctx); + // noop + return false; +} + +void llama_kv_cache_recurrent::defrag_sched(float thold) { + GGML_UNUSED(thold); + // noop +} + +bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) { + const uint32_t n_tokens = ubatch.n_tokens; + const uint32_t n_seqs = ubatch.n_seqs; + + const uint32_t n_seq_tokens = ubatch.n_seq_tokens; + + // if we have enough unused cells before the current head -> + // better to start searching from the beginning of the cache, hoping to fill it + if (head > used + 2*n_tokens) { + head = 0; + } + + // For recurrent state architectures (like Mamba or RWKV), + // each cache cell can store the state for a whole sequence. + // A slot should always be contiguous. + + // can only process batches with an equal number of new tokens in each sequence + GGML_ASSERT(ubatch.equal_seqs); + + int32_t min = size - 1; + int32_t max = 0; + + // everything should fit if all seq_ids are smaller than the max + for (uint32_t s = 0; s < n_seqs; ++s) { + const uint32_t n_seq_id = ubatch.n_seq_id[s]; + for (uint32_t j = 0; j < n_seq_id; ++j) { + const llama_seq_id seq_id = ubatch.seq_id[s][j]; + + if (seq_id < 0 || (uint32_t) seq_id >= size) { + // seq_id is too big + // TODO: would it be possible to resize the cache instead?
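+                // (unlike the unified cache, this cache addresses cells
+                // directly by seq_id, e.g. cells[seq_id].tail below, so a
+                // seq_id is never valid once it reaches the cell count.)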
+ LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%u Try using a bigger --parallel value\n", __func__, seq_id, n_seq_max); + return false; + } + if (j > 0) { + kv_cell & seq = cells[seq_id]; + if (seq.tail >= 0) { + kv_cell & cell = cells[seq.tail]; + // clear cells from seq_ids that become shared + // (should not normally happen, but let's handle it anyway) + cell.seq_id.erase(seq_id); + seq.tail = -1; + if (cell.seq_id.empty()) { + cell.pos = -1; + cell.src = -1; + used -= 1; + } + } + } + } + } + +#ifndef NDEBUG + { + std::vector tails_verif; + tails_verif.assign(size, -1); + for (uint32_t i = 0; i < size; ++i) { + kv_cell & cell = cells[i]; + for (llama_seq_id seq_id : cell.seq_id) { + if (tails_verif[seq_id] != -1) { + LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tails_verif[seq_id]); + } + tails_verif[seq_id] = i; + } + } + for (uint32_t i = 0; i < size; ++i) { + if (tails_verif[i] != cells[i].tail) { + LLAMA_LOG_ERROR("%s: wrong tail for seq_id %d, (%d instead of %d)\n", __func__, i, cells[i].tail, tails_verif[i]); + } + } + } +#endif + + // find next empty cell + uint32_t next_empty_cell = head; + + for (uint32_t i = 0; i < size; ++i) { + if (next_empty_cell >= size) { next_empty_cell -= size; } + kv_cell & cell = cells[next_empty_cell]; + if (cell.is_empty()) { break; } + next_empty_cell += 1; + } + + // find usable cell range + for (uint32_t s = 0; s < n_seqs; ++s) { + const llama_seq_id seq_id = ubatch.seq_id[s][0]; + kv_cell & seq_meta = cells[seq_id]; + bool has_cell = false; + if (seq_meta.tail >= 0) { + kv_cell & cell = cells[seq_meta.tail]; + GGML_ASSERT(cell.has_seq_id(seq_id)); + // does this seq_id "own" the cell? + if (cell.seq_id.size() == 1) { has_cell = true; } + } + if (!has_cell) { + kv_cell & empty_cell = cells[next_empty_cell]; + GGML_ASSERT(empty_cell.is_empty()); + // copy old tail into the empty cell + if (seq_meta.tail >= 0) { + kv_cell & orig_cell = cells[seq_meta.tail]; + empty_cell.pos = orig_cell.pos; + empty_cell.src = orig_cell.src; + orig_cell.seq_id.erase(seq_id); + empty_cell.seq_id.insert(seq_id); // will be overwritten + } + seq_meta.tail = next_empty_cell; + // find next empty cell + if (s + 1 < n_seqs) { + next_empty_cell += 1; + for (uint32_t i = 0; i < size; ++i) { + if (next_empty_cell >= size) { next_empty_cell -= size; } + kv_cell & cell = cells[next_empty_cell]; + if (cell.is_empty()) { break; } + next_empty_cell += 1; + } + } + } + if (min > seq_meta.tail) { min = seq_meta.tail; } + if (max < seq_meta.tail) { max = seq_meta.tail; } + } + + // gather and re-order + for (uint32_t s = 0; s < n_seqs; ++s) { + int32_t dst_id = s + min; + int32_t src_id = cells[ubatch.seq_id[s][0]].tail; + if (dst_id != src_id) { + kv_cell & dst_cell = cells[dst_id]; + kv_cell & src_cell = cells[src_id]; + + std::swap(dst_cell.pos, src_cell.pos); + std::swap(dst_cell.src, src_cell.src); + std::swap(dst_cell.seq_id, src_cell.seq_id); + + // swap tails (assuming they NEVER overlap) + for (const llama_seq_id seq_id : src_cell.seq_id) { + cells[seq_id].tail = src_id; + } + for (const llama_seq_id seq_id : dst_cell.seq_id) { + cells[seq_id].tail = dst_id; + } + } + } + + // update the pos of the used seqs + for (uint32_t s = 0; s < n_seqs; ++s) { + const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1]; + int32_t cell_id = s + min; + kv_cell & cell = cells[cell_id]; + + if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) { + // What should happen when the pos backtracks or skips a 
value? + // Clearing the state mid-batch would require special-casing which isn't done. + LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n", + __func__, last_pos, cell.pos, ubatch.seq_id[s][0], n_seq_tokens); + } + cell.pos = last_pos; + cell.seq_id.clear(); + for (int32_t j = 0; j < ubatch.n_seq_id[s]; ++j) { + const llama_seq_id seq_id = ubatch.seq_id[s][j]; + cell.seq_id.insert(seq_id); + cells[seq_id].tail = cell_id; + } + } + + // allow getting the range of used cells, from head to head + n + head = min; + n = max - min + 1; + used = std::count_if(cells.begin(), cells.end(), + [](const kv_cell & cell){ return !cell.is_empty(); }); + + // sanity check + return n >= n_seqs; +} + +bool llama_kv_cache_recurrent::get_can_shift() const { + return false; +} + +int32_t llama_kv_cache_recurrent::s_copy(int i) const { + const uint32_t cell_id = i + head; + + ////////////////////////////////////////////// + // TODO: this should not mutate the KV cache ! + kv_cell & cell = const_cast(cells[cell_id]); + + // prevent out-of-bound sources + if (cell.src < 0 || (uint32_t) cell.src >= size) { + cell.src = cell_id; + } + + int32_t res = cell.src; + + // TODO: do not mutate the KV cache + // ensure copy only happens once + if (cell.src != (int32_t) cell_id) { + cell.src = cell_id; + } + + return res; +} + +float llama_kv_cache_recurrent::s_mask(int i) const { + const uint32_t cell_id = i + head; + + ////////////////////////////////////////////// + // TODO: this should not mutate the KV cache ! + kv_cell & cell = const_cast(cells[cell_id]); + + float res = (float) (cell.src >= 0); + + // only clear once + if (cell.src < 0) { + cell.src = cell_id; + } + + return res; +} + +size_t llama_kv_cache_recurrent::total_size() const { + size_t size = 0; + for (const auto & buf : bufs) { + size += ggml_backend_buffer_get_size(buf.get()); + } + + return size; +} + +size_t llama_kv_cache_recurrent::size_k_bytes() const { + size_t size_k_bytes = 0; + + for (const auto & k : k_l) { + size_k_bytes += ggml_nbytes(k); + } + + return size_k_bytes; +} + +size_t llama_kv_cache_recurrent::size_v_bytes() const { + size_t size_v_bytes = 0; + + for (const auto & v : v_l) { + size_v_bytes += ggml_nbytes(v); + } + + return size_v_bytes; +} + +void llama_kv_cache_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id) const { + std::vector> cell_ranges; // ranges, from inclusive, to exclusive + uint32_t cell_count = 0; + + // Count the number of cells with the specified seq_id + // Find all the ranges of cells with this seq id (or all, when -1) + uint32_t cell_range_begin = size; + for (uint32_t i = 0; i < size; ++i) { + const auto & cell = cells[i]; + if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) { + ++cell_count; + if (cell_range_begin == size) { + cell_range_begin = i; + } + } else { + if (cell_range_begin != size) { + cell_ranges.emplace_back(cell_range_begin, i); + cell_range_begin = size; + } + } + } + if (cell_range_begin != size) { + cell_ranges.emplace_back(cell_range_begin, size); + } + + // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count + uint32_t cell_count_check = 0; + for (const auto & range : cell_ranges) { + cell_count_check += range.second - range.first; + } + GGML_ASSERT(cell_count == cell_count_check); + + io.write(&cell_count, sizeof(cell_count)); + + state_write_meta(io, cell_ranges, seq_id); + state_write_data(io, cell_ranges); +} + +void llama_kv_cache_recurrent::state_read(llama_io_read_i & io, 
llama_seq_id seq_id) { + uint32_t cell_count; + io.read_to(&cell_count, sizeof(cell_count)); + + bool res = true; + + res = res && state_read_meta(io, cell_count, seq_id); + res = res && state_read_data(io, cell_count); + + if (!res) { + if (seq_id == -1) { + clear(); + } else { + seq_rm(seq_id, -1, -1); + } + throw std::runtime_error("failed to restore kv cache"); + } +} + +void llama_kv_cache_recurrent::state_write_meta(llama_io_write_i & io, const std::vector> & cell_ranges, llama_seq_id seq_id) const { + for (const auto & range : cell_ranges) { + for (uint32_t i = range.first; i < range.second; ++i) { + const auto & cell = cells[i]; + const llama_pos pos = cell.pos; + const uint32_t n_seq_id = seq_id == -1 ? cell.seq_id.size() : 0; + + io.write(&pos, sizeof(pos)); + io.write(&n_seq_id, sizeof(n_seq_id)); + + if (n_seq_id) { + for (auto seq_id : cell.seq_id) { + io.write(&seq_id, sizeof(seq_id)); + } + } + } + } +} + +void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std::vector> & cell_ranges) const { + const uint32_t v_trans = 0; + const uint32_t n_layer = hparams.n_layer; + + io.write(&v_trans, sizeof(v_trans)); + io.write(&n_layer, sizeof(n_layer)); + + std::vector tmp_buf; + + // Iterate and write all the keys first, each row is a cell + // Get whole range at a time + for (uint32_t il = 0; il < n_layer; ++il) { + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); + + // Write key type + const int32_t k_type_i = (int32_t)k_l[il]->type; + io.write(&k_type_i, sizeof(k_type_i)); + + // Write row size of key + const uint64_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa); + io.write(&k_size_row, sizeof(k_size_row)); + + // Read each range of cells of k_size length each into tmp_buf and write out + for (const auto & range : cell_ranges) { + const size_t range_size = range.second - range.first; + const size_t buf_size = range_size * k_size_row; + io.write_tensor(k_l[il], range.first * k_size_row, buf_size); + } + } + + if (!v_trans) { + for (uint32_t il = 0; il < n_layer; ++il) { + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + + // Write value type + const int32_t v_type_i = (int32_t)v_l[il]->type; + io.write(&v_type_i, sizeof(v_type_i)); + + // Write row size of value + const uint64_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa); + io.write(&v_size_row, sizeof(v_size_row)); + + // Read each range of cells of v_size length each into tmp_buf and write out + for (const auto & range : cell_ranges) { + const size_t range_size = range.second - range.first; + const size_t buf_size = range_size * v_size_row; + io.write_tensor(v_l[il], range.first * v_size_row, buf_size); + } + } + } else { + // When v is transposed, we also need the element size and get the element ranges from each row + const uint32_t kv_size = size; + for (uint32_t il = 0; il < n_layer; ++il) { + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + + // Write value type + const int32_t v_type_i = (int32_t)v_l[il]->type; + io.write(&v_type_i, sizeof(v_type_i)); + + // Write element size + const uint32_t v_size_el = ggml_type_size(v_l[il]->type); + io.write(&v_size_el, sizeof(v_size_el)); + + // Write GQA embedding size + io.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa)); + + // For each row, we get the element values of each cell + for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { + // Read each range of cells of v_size_el length each into tmp_buf and write out + for (const auto & range : 
cell_ranges) { + const size_t range_size = range.second - range.first; + const size_t src_offset = (range.first + j * kv_size) * v_size_el; + const size_t buf_size = range_size * v_size_el; + io.write_tensor(v_l[il], src_offset, buf_size); + } + } + } + } +} + +bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) { + if (dest_seq_id != -1) { + // single sequence + + seq_rm(dest_seq_id, -1, -1); + + llama_sbatch sbatch; + llama_ubatch batch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false); + + batch.n_tokens = cell_count; + batch.n_seq_tokens = cell_count; + batch.n_seqs = 1; + + for (uint32_t i = 0; i < cell_count; ++i) { + llama_pos pos; + uint32_t n_seq_id; + + io.read_to(&pos, sizeof(pos)); + io.read_to(&n_seq_id, sizeof(n_seq_id)); + + if (n_seq_id != 0) { + LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__); + return false; + } + + batch.pos[i] = pos; + } + batch.n_seq_id[0] = 1; + batch.seq_id[0] = &dest_seq_id; + + if (!find_slot(batch)) { + LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__); + return false; + } + + // DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values) + // Assume that this is one contiguous block of cells + GGML_ASSERT(head + cell_count <= size); + GGML_ASSERT(cells[head].pos == batch.pos[0]); + GGML_ASSERT(cells[head + cell_count - 1].pos == batch.pos[cell_count - 1]); + GGML_ASSERT(cells[head].has_seq_id(dest_seq_id)); + GGML_ASSERT(cells[head + cell_count - 1].has_seq_id(dest_seq_id)); + } else { + // whole KV cache restore + + if (cell_count > size) { + LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__); + return false; + } + + clear(); + + for (uint32_t i = 0; i < cell_count; ++i) { + kv_cell & cell = cells[i]; + + llama_pos pos; + uint32_t n_seq_id; + + io.read_to(&pos, sizeof(pos)); + io.read_to(&n_seq_id, sizeof(n_seq_id)); + + cell.pos = pos; + + for (uint32_t j = 0; j < n_seq_id; ++j) { + llama_seq_id seq_id; + io.read_to(&seq_id, sizeof(seq_id)); + + // TODO: llama_kv_cache_recurrent should have a notion of max sequences + //if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { + if (seq_id < 0) { + //LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx)); + LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, inf)\n", __func__, seq_id); + return false; + } + + cell.seq_id.insert(seq_id); + + int32_t & tail = cells[seq_id].tail; + if (tail != -1) { + LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tail); + return false; + } + tail = i; + } + } + + head = 0; + used = cell_count; + } + + for (uint32_t i = 0; i < cell_count; ++i) { + uint32_t cell_id = head + i; + // make sure the recurrent states will keep their restored state + cells[cell_id].src = cell_id; + } + + return true; +} + +bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell_count) { + uint32_t v_trans; + uint32_t n_layer; + io.read_to(&v_trans, sizeof(v_trans)); + io.read_to(&n_layer, sizeof(n_layer)); + + if (n_layer != hparams.n_layer) { + LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, hparams.n_layer); + return false; + } + if (cell_count > size) { + LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, size); + return false; + } + if (false != (bool) v_trans) { 
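+        // the recurrent cache always serializes v_trans == 0 (see
+        // state_write_data above), so a transposed-V state file cannot have
+        // been produced by this cache type.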
+ LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__); + return false; + } + + // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block + for (uint32_t il = 0; il < n_layer; ++il) { + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); + + // Read type of key + int32_t k_type_i_ref; + io.read_to(&k_type_i_ref, sizeof(k_type_i_ref)); + const int32_t k_type_i = (int32_t) k_l[il]->type; + if (k_type_i != k_type_i_ref) { + LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il); + return false; + } + + // Read row size of key + uint64_t k_size_row_ref; + io.read_to(&k_size_row_ref, sizeof(k_size_row_ref)); + const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa); + if (k_size_row != k_size_row_ref) { + LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il); + return false; + } + + if (cell_count) { + // Read and set the keys for the whole cell range + ggml_backend_tensor_set(k_l[il], io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row); + } + } + + if (!v_trans) { + for (uint32_t il = 0; il < n_layer; ++il) { + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + + // Read type of value + int32_t v_type_i_ref; + io.read_to(&v_type_i_ref, sizeof(v_type_i_ref)); + const int32_t v_type_i = (int32_t)v_l[il]->type; + if (v_type_i != v_type_i_ref) { + LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); + return false; + } + + // Read row size of value + uint64_t v_size_row_ref; + io.read_to(&v_size_row_ref, sizeof(v_size_row_ref)); + const size_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa); + if (v_size_row != v_size_row_ref) { + LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il); + return false; + } + + if (cell_count) { + // Read and set the values for the whole cell range + ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row); + } + } + } else { + // For each layer, read the values for each cell (transposed) + for (uint32_t il = 0; il < n_layer; ++il) { + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + + // Read type of value + int32_t v_type_i_ref; + io.read_to(&v_type_i_ref, sizeof(v_type_i_ref)); + const int32_t v_type_i = (int32_t)v_l[il]->type; + if (v_type_i != v_type_i_ref) { + LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); + return false; + } + + // Read element size of value + uint32_t v_size_el_ref; + io.read_to(&v_size_el_ref, sizeof(v_size_el_ref)); + const size_t v_size_el = ggml_type_size(v_l[il]->type); + if (v_size_el != v_size_el_ref) { + LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il); + return false; + } + + // Read GQA embedding size + uint32_t n_embd_v_gqa_ref; + io.read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref)); + if (n_embd_v_gqa != n_embd_v_gqa_ref) { + LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il); + return false; + } + + if (cell_count) { + // For each row in the transposed matrix, read the values for the whole cell range + for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { 
+                    const size_t dst_offset = (head + j * size) * v_size_el;
+                    ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
+                }
+            }
+        }
+    }
+
+    return true;
+}
+
+//
+// llama_kv_cache_recurrent_state
+//
+
+llama_kv_cache_recurrent_state::llama_kv_cache_recurrent_state(llama_memory_status status) : status(status) {}
+
+llama_kv_cache_recurrent_state::llama_kv_cache_recurrent_state(
+        llama_memory_status status,
+        llama_kv_cache_recurrent * kv) : status(status), kv(kv), is_full(true) {
+}
+
+llama_kv_cache_recurrent_state::llama_kv_cache_recurrent_state(
+        llama_memory_status status,
+        llama_kv_cache_recurrent * kv,
+        llama_sbatch sbatch,
+        std::vector<llama_ubatch> ubatches) : status(status), kv(kv), sbatch(std::move(sbatch)), ubatches(std::move(ubatches)) {}
+
+llama_kv_cache_recurrent_state::~llama_kv_cache_recurrent_state() = default;
+
+bool llama_kv_cache_recurrent_state::next() {
+    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
+    if (++i_next >= ubatches.size()) {
+        return false;
+    }
+
+    return true;
+}
+
+bool llama_kv_cache_recurrent_state::apply() {
+    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
+    kv->find_slot(ubatches[i_next]);
+
+    return true;
+}
+
+std::vector<int64_t> & llama_kv_cache_recurrent_state::out_ids() {
+    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
+    return sbatch.out_ids;
+}
+
+llama_memory_status llama_kv_cache_recurrent_state::get_status() const {
+    return status;
+}
+
+const llama_ubatch & llama_kv_cache_recurrent_state::get_ubatch() const {
+    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
+    return ubatches[i_next];
+}
+
+uint32_t llama_kv_cache_recurrent_state::get_n_kv() const {
+    return is_full ? kv->size : kv->n;
+}
+
+uint32_t llama_kv_cache_recurrent_state::get_head() const {
+    return is_full ? 0 : kv->head;
+}
+
+uint32_t llama_kv_cache_recurrent_state::get_size() const {
+    return kv->size;
+}
+
+ggml_tensor * llama_kv_cache_recurrent_state::get_k_l(int32_t il) const {
+    return kv->k_l[il];
+}
+
+ggml_tensor * llama_kv_cache_recurrent_state::get_v_l(int32_t il) const {
+    return kv->v_l[il];
+}
+
+int32_t llama_kv_cache_recurrent_state::s_copy(int i) const {
+    return kv->s_copy(i);
+}
+
+float llama_kv_cache_recurrent_state::s_mask(int i) const {
+    return kv->s_mask(i);
+}
diff --git a/src/llama-kv-cache-recurrent.h b/src/llama-kv-cache-recurrent.h
new file mode 100644
index 0000000000000..a178ae85c146a
--- /dev/null
+++ b/src/llama-kv-cache-recurrent.h
@@ -0,0 +1,191 @@
+#pragma once
+
+#include "llama-batch.h"
+#include "llama-graph.h"
+#include "llama-kv-cache.h"
+
+#include <set>
+#include <vector>
+
+//
+// llama_kv_cache_recurrent
+//
+
+// TODO: extract the KV cache state used for graph computation into llama_kv_cache_recurrent_state_i
+//       see the implementation of llama_kv_cache_unified_state_i for an example of how to do it
+class llama_kv_cache_recurrent : public llama_kv_cache {
+public:
+    llama_kv_cache_recurrent(
+            const llama_model & model,
+            ggml_type type_k,
+            ggml_type type_v,
+            bool offload,
+            uint32_t kv_size,
+            uint32_t n_seq_max);
+
+    ~llama_kv_cache_recurrent() = default;
+
+    //
+    // llama_memory_i
+    //
+
+    void clear() override;
+
+    bool seq_rm  (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
+    void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
+    void seq_keep(llama_seq_id seq_id) override;
+    void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
+    void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
+
+    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
+    llama_pos seq_pos_max(llama_seq_id seq_id) const override;
+
+    //
+    // llama_kv_cache
+    //
+
+    llama_memory_state_ptr init_batch(
+            const llama_batch & batch,
+            uint32_t n_ubatch,
+            bool embd_pooled,
+            bool logits_all) override;
+
+    llama_memory_state_ptr init_full() override;
+
+    bool update(llama_context & lctx) override;
+
+    void defrag_sched(float thold) override;
+
+    bool prepare(const std::vector<llama_ubatch> & ubatches);
+
+    // find a contiguous slot of kv cells and emplace the ubatch there
+    bool find_slot(const llama_ubatch & ubatch);
+
+    bool get_can_shift() const override;
+
+    // TODO: temporary methods - they are not really const as they do const_cast<>, fix this
+    int32_t s_copy(int i) const;
+    float   s_mask(int i) const;
+
+    // state write/load
+
+    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
+    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1)       override;
+
+    uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot())
+    uint32_t size = 0; // total number of cells, shared across all sequences
+    uint32_t used = 0; // used cells (i.e. cells with at least one seq_id)
+
+    // computed before each graph build
+    uint32_t n = 0;
+
+    // TODO: optimize for recurrent state needs
+    struct kv_cell {
+        llama_pos pos  = -1;
+        int32_t   src  = -1; // used to copy states
+        int32_t   tail = -1;
+
+        std::set<llama_seq_id> seq_id;
+
+        bool has_seq_id(const llama_seq_id & id) const {
+            return seq_id.find(id) != seq_id.end();
+        }
+
+        bool is_empty() const {
+            return seq_id.empty();
+        }
+
+        bool is_same_seq(const kv_cell & other) const {
+            return seq_id == other.seq_id;
+        }
+    };
+
+    std::vector<kv_cell> cells;
+
+    std::vector<ggml_tensor *> k_l; // per layer
+    std::vector<ggml_tensor *> v_l;
+
+private:
+    //const llama_model & model;
+    const llama_hparams & hparams;
+
+    const uint32_t n_seq_max = 1;
+
+    std::vector<ggml_context_ptr>        ctxs;
+    std::vector<ggml_backend_buffer_ptr> bufs;
+
+    size_t total_size() const;
+
+    size_t size_k_bytes() const;
+    size_t size_v_bytes() const;
+
+    void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
+    void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
+
+    bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
+    bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
+};
+
+class llama_kv_cache_recurrent_state : public llama_memory_state_i {
+public:
+    // used for errors
+    llama_kv_cache_recurrent_state(llama_memory_status status);
+
+    // used to create a full-cache state
+    llama_kv_cache_recurrent_state(
+            llama_memory_status status,
+            llama_kv_cache_recurrent * kv);
+
+    // used to create a state from a batch
+    llama_kv_cache_recurrent_state(
+            llama_memory_status status,
+            llama_kv_cache_recurrent * kv,
+            llama_sbatch sbatch,
+            std::vector<llama_ubatch> ubatches);
+
+    virtual ~llama_kv_cache_recurrent_state();
+
+    //
+    // llama_memory_state_i
+    //
+
+    bool next()  override;
+    bool apply() override;
+
+    std::vector<int64_t> & out_ids() override;
+
+    llama_memory_status  get_status() const override;
+    const llama_ubatch & get_ubatch() const override;
+
+    //
+    // llama_kv_cache_recurrent_state specific API
+    //
+
+    uint32_t get_n_kv() const;
+    uint32_t get_head() const;
+    uint32_t get_size() const;
+
+    ggml_tensor * get_k_l(int32_t il) const;
+    ggml_tensor * get_v_l(int32_t il) const;
+
+    int32_t s_copy(int i) const;
+    float   s_mask(int i) const;
+
+private:
+    const llama_memory_status status;
+
+    llama_kv_cache_recurrent * kv;
+
+    llama_sbatch sbatch;
+
+    size_t i_next = 0;
+
+    std::vector<llama_ubatch> ubatches;
+
+    //
+    // data needed for building the compute graph for the current ubatch:
+    // TODO: extract all the state like `head` and `n` here
+    //
+
+    const bool is_full = false;
+};
diff --git a/src/llama-kv-cache-unified-iswa.cpp b/src/llama-kv-cache-unified-iswa.cpp
new file mode 100644
index 0000000000000..0eb0456343546
--- /dev/null
+++ b/src/llama-kv-cache-unified-iswa.cpp
@@ -0,0 +1,249 @@
+#include "llama-kv-cache-unified-iswa.h"
+
+#include "llama-impl.h"
+#include "llama-batch.h"
+#include "llama-model.h"
+
+#include <algorithm>
+#include <cassert>
+
+//
+// llama_kv_cache_unified_iswa
+//
+
+llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
+        const llama_model & model,
+        ggml_type type_k,
+        ggml_type type_v,
+        bool v_trans,
+        bool offload,
+        bool swa_full,
+        uint32_t kv_size,
+        uint32_t n_seq_max,
+        uint32_t n_ubatch,
+        uint32_t n_pad) : hparams(model.hparams) {
+    llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
+    llama_kv_cache_unified::layer_filter_cb filter_swa  = [&](int32_t il) { return  model.hparams.is_swa(il); };
+
+    const uint32_t size_base = kv_size;
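+    // note: the SWA cache only has to keep the attention window of each sequence (n_swa*n_seq_max)
+    //       plus one ubatch of headroom for the tokens being placed, rounded up to the cache
+    //       padding and capped at the base cache size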
+ + uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_ubatch, n_pad)); + + // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size + if (swa_full) { + LLAMA_LOG_WARN("%s: using full-size SWA cache (ref: %s)\n", + __func__, "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055"); + + size_swa = size_base; + } + + LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base); + + kv_base = std::make_unique( + model, std::move(filter_base), type_k, type_v, + v_trans, offload, size_base, n_seq_max, n_pad, + 0, LLAMA_SWA_TYPE_NONE); + + LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa); + + kv_swa = std::make_unique( + model, std::move(filter_swa), type_k, type_v, + v_trans, offload, size_swa, n_seq_max, n_pad, + hparams.n_swa, hparams.swa_type); +} + +void llama_kv_cache_unified_iswa::clear() { + kv_base->clear(); + kv_swa ->clear(); +} + +bool llama_kv_cache_unified_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + bool res = true; + + res = res & kv_base->seq_rm(seq_id, p0, p1); + res = res & kv_swa ->seq_rm(seq_id, p0, p1); + + return res; +} + +void llama_kv_cache_unified_iswa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { + kv_base->seq_cp(seq_id_src, seq_id_dst, p0, p1); + kv_swa ->seq_cp(seq_id_src, seq_id_dst, p0, p1); +} + +void llama_kv_cache_unified_iswa::seq_keep(llama_seq_id seq_id) { + kv_base->seq_keep(seq_id); + kv_swa ->seq_keep(seq_id); +} + +void llama_kv_cache_unified_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) { + kv_base->seq_add(seq_id, p0, p1, shift); + kv_swa ->seq_add(seq_id, p0, p1, shift); +} + +void llama_kv_cache_unified_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { + kv_base->seq_div(seq_id, p0, p1, d); + kv_swa ->seq_div(seq_id, p0, p1, d); +} + +llama_pos llama_kv_cache_unified_iswa::seq_pos_min(llama_seq_id seq_id) const { + // the base cache is a superset of the SWA cache, so we can just check the SWA cache + return kv_swa->seq_pos_min(seq_id); +} + +llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const { + return kv_swa->seq_pos_max(seq_id); +} + +llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) { + GGML_UNUSED(embd_pooled); + + // TODO: if we fail with split_simple, we should attempt different splitting strategies + // but to do that properly, we first have to refactor the batches to be more flexible + + auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all); + + std::vector ubatches; + + while (sbatch.n_tokens > 0) { + auto ubatch = sbatch.split_simple(n_ubatch); + + ubatches.push_back(ubatch); + } + + auto heads_base = kv_base->prepare(ubatches); + if (heads_base.empty()) { + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); + } + + auto heads_swa = kv_swa->prepare(ubatches); + if (heads_swa.empty()) { + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); + } + + assert(heads_base.size() == heads_swa.size()); + + return std::make_unique(LLAMA_MEMORY_STATUS_SUCCESS, + this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches)); +} + +llama_memory_state_ptr llama_kv_cache_unified_iswa::init_full() { + return std::make_unique(LLAMA_MEMORY_STATUS_SUCCESS, this); +} + +bool 
llama_kv_cache_unified_iswa::update(llama_context & lctx) { + bool res = false; + + res = res | kv_base->update(lctx); + res = res | kv_swa ->update(lctx); + + return res; +} + +void llama_kv_cache_unified_iswa::defrag_sched(float thold) { + kv_base->defrag_sched(thold); + kv_swa ->defrag_sched(thold); +} + +bool llama_kv_cache_unified_iswa::get_can_shift() const { + return kv_base->get_size() == kv_swa->get_size(); +} + +void llama_kv_cache_unified_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id) const { + kv_base->state_write(io, seq_id); + kv_swa ->state_write(io, seq_id); +} + +void llama_kv_cache_unified_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id) { + kv_base->state_read(io, seq_id); + kv_swa ->state_read(io, seq_id); +} + +llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_base() const { + return kv_base.get(); +} + +llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const { + return kv_swa.get(); +} + +// +// llama_kv_cache_unified_iswa_state +// + +llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(llama_memory_status status) : status(status) {} + +llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state( + llama_memory_status status, + llama_kv_cache_unified_iswa * kv) : status(status) { + state_base.reset(new llama_kv_cache_unified_state(status, kv->get_base())); + state_swa .reset(new llama_kv_cache_unified_state(status, kv->get_swa ())); +} + +llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state( + llama_memory_status status, + llama_kv_cache_unified_iswa * kv, + llama_sbatch sbatch, + std::vector heads_base, + std::vector heads_swa, + std::vector ubatches) + : status(status), + sbatch(std::move(sbatch)), + ubatches(std::move(ubatches)) { + // note: here we copy the ubatches. 
not sure if this is ideal + state_base.reset(new llama_kv_cache_unified_state(status, kv->get_base(), {}, std::move(heads_base), this->ubatches)); + state_swa .reset(new llama_kv_cache_unified_state(status, kv->get_swa (), {}, std::move(heads_swa), this->ubatches)); + } + +llama_kv_cache_unified_iswa_state:: ~llama_kv_cache_unified_iswa_state() = default; + +bool llama_kv_cache_unified_iswa_state::next() { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + state_base->next(); + state_swa ->next(); + + if (++i_next >= ubatches.size()) { + return false; + } + + return true; +} + +bool llama_kv_cache_unified_iswa_state::apply() { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + bool res = true; + + res = res & state_base->apply(); + res = res & state_swa ->apply(); + + return res; +} + +std::vector & llama_kv_cache_unified_iswa_state::out_ids() { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + return sbatch.out_ids; +} + +llama_memory_status llama_kv_cache_unified_iswa_state::get_status() const { + return status; +} + +const llama_ubatch & llama_kv_cache_unified_iswa_state::get_ubatch() const { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + return ubatches[i_next]; +} + +const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_base() const { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + return state_base.get(); +} + +const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_swa() const { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + return state_swa.get(); +} diff --git a/src/llama-kv-cache-unified-iswa.h b/src/llama-kv-cache-unified-iswa.h new file mode 100644 index 0000000000000..8b067da038af6 --- /dev/null +++ b/src/llama-kv-cache-unified-iswa.h @@ -0,0 +1,136 @@ +#pragma once + +#include "llama-kv-cache-unified.h" + +#include + +// +// llama_kv_cache_unified_iswa +// + +// utilizes two instances of llama_kv_cache_unified +// the first instance is for the non-SWA layers of the model and the second instance is for the SWA layers + +class llama_kv_cache_unified_iswa : public llama_kv_cache { +public: + llama_kv_cache_unified_iswa( + const llama_model & model, + ggml_type type_k, + ggml_type type_v, + bool v_trans, + bool offload, + bool swa_full, + uint32_t kv_size, + uint32_t n_seq_max, + uint32_t n_ubatch, + uint32_t n_pad); + + ~llama_kv_cache_unified_iswa() = default; + + // + // llama_memory_i + // + + void clear() override; + + bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; + void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; + void seq_keep(llama_seq_id seq_id) override; + void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override; + void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override; + + llama_pos seq_pos_min(llama_seq_id seq_id) const override; + llama_pos seq_pos_max(llama_seq_id seq_id) const override; + + // + // llama_kv_cache + // + + llama_memory_state_ptr init_batch( + const llama_batch & batch, + uint32_t n_ubatch, + bool embd_pooled, + bool logits_all) override; + + llama_memory_state_ptr init_full() override; + + bool update(llama_context & lctx) override; + + void defrag_sched(float thold) override; + + bool get_can_shift() const override; + + // state write/load + + void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override; + void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override; + + // + // llama_kv_cache_unified_iswa specific API + 
// + + llama_kv_cache_unified * get_base() const; + llama_kv_cache_unified * get_swa () const; + +private: + const llama_hparams & hparams; + + std::unique_ptr kv_base; + std::unique_ptr kv_swa; +}; + +class llama_kv_cache_unified_iswa_state : public llama_memory_state_i { +public: + // used for errors + llama_kv_cache_unified_iswa_state(llama_memory_status status); + + // used to create a full-cache state + llama_kv_cache_unified_iswa_state( + llama_memory_status status, + llama_kv_cache_unified_iswa * kv); + + // used to create a state from a batch + llama_kv_cache_unified_iswa_state( + llama_memory_status status, + llama_kv_cache_unified_iswa * kv, + llama_sbatch sbatch, + std::vector heads_base, + std::vector heads_swa, + std::vector ubatches); + + virtual ~llama_kv_cache_unified_iswa_state(); + + // + // llama_memory_state_i + // + + bool next() override; + bool apply() override; + + std::vector & out_ids() override; + + llama_memory_status get_status() const override; + const llama_ubatch & get_ubatch() const override; + + // + // llama_kv_cache_unified_iswa_state specific API + // + + const llama_kv_cache_unified_state * get_base() const; + const llama_kv_cache_unified_state * get_swa() const; + +private: + const llama_memory_status status; + + //llama_kv_cache_unified_iswa * kv; + + llama_sbatch sbatch; + + // the index of the next ubatch to process + size_t i_next = 0; + + std::vector ubatches; + + std::unique_ptr state_base; + std::unique_ptr state_swa; +}; diff --git a/src/llama-kv-cache-unified.cpp b/src/llama-kv-cache-unified.cpp new file mode 100644 index 0000000000000..a817154769a32 --- /dev/null +++ b/src/llama-kv-cache-unified.cpp @@ -0,0 +1,1717 @@ +#include "llama-kv-cache-unified.h" + +#include "llama-impl.h" +#include "llama-model.h" +#include "llama-context.h" + +#include +#include +#include +#include +#include +#include + +// +// llama_kv_cache_unified +// + +llama_kv_cache_unified::llama_kv_cache_unified( + const llama_model & model, + layer_filter_cb && filter, + ggml_type type_k, + ggml_type type_v, + bool v_trans, + bool offload, + uint32_t kv_size, + uint32_t n_seq_max, + uint32_t n_pad, + uint32_t n_swa, + llama_swa_type swa_type) : + model(model), hparams(model.hparams), v_trans(v_trans), + n_seq_max(n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) { + + GGML_ASSERT(kv_size % n_pad == 0); + + // create a context for each buffer type + std::map ctx_map; + auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * { + auto it = ctx_map.find(buft); + if (it == ctx_map.end()) { + ggml_init_params params = { + /*.mem_size =*/ size_t(2u*hparams.n_layer*ggml_tensor_overhead()), + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + + ggml_context * ctx = ggml_init(params); + if (!ctx) { + return nullptr; + } + + ctx_map[buft] = ctx; + ctxs.emplace_back(ctx); + + return ctx; + } + + return it->second; + }; + + head = 0; + + cells.resize(kv_size); + + for (uint32_t il = 0; il < hparams.n_layer; il++) { + if (filter && !filter(il)) { + LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il); + continue; + } + + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + + const char * dev_name = "CPU"; + + ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type(); + + if (offload) { + auto * dev = model.dev_layer(il); + buft = ggml_backend_dev_buffer_type(dev); + + dev_name = ggml_backend_dev_name(dev); + } + + LLAMA_LOG_DEBUG("%s: 
layer %3d: dev = %s\n", __func__, il, dev_name); + + ggml_context * ctx = ctx_for_buft(buft); + if (!ctx) { + throw std::runtime_error("failed to create ggml context for kv cache"); + } + + ggml_tensor * k; + ggml_tensor * v; + + k = ggml_new_tensor_2d(ctx, type_k, n_embd_k_gqa, kv_size); + v = ggml_new_tensor_2d(ctx, type_v, n_embd_v_gqa, kv_size); + + ggml_format_name(k, "cache_k_l%d", il); + ggml_format_name(v, "cache_v_l%d", il); + + map_layer_ids[il] = layers.size(); + layers.push_back({ il, k, v }); + } + + // allocate tensors and initialize the buffers to avoid NaNs in the padding + for (auto it : ctx_map) { + auto * buft = it.first; + auto * ctx = it.second; + + ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); + if (!buf) { + throw std::runtime_error("failed to allocate buffer for kv cache"); + } + + LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0); + + ggml_backend_buffer_clear(buf, 0); + bufs.emplace_back(buf); + } + + { + const size_t memory_size_k = size_k_bytes(); + const size_t memory_size_v = size_v_bytes(); + + LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u seqs), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__, + (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), kv_size, (int) layers.size(), n_seq_max, + ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f), + ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f)); + } +} + +void llama_kv_cache_unified::clear() { + cells.reset(); + + head = 0; + + for (auto & buf : bufs) { + ggml_backend_buffer_clear(buf.get(), 0); + } +} + +bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + uint32_t new_head = cells.size(); + + if (p0 < 0) { + p0 = 0; + } + + if (p1 < 0) { + p1 = std::numeric_limits::max(); + } + + for (uint32_t i = 0; i < cells.size(); ++i) { + if (!cells.pos_in(i, p0, p1)) { + continue; + } + + if (cells.seq_has(i, seq_id) && cells.seq_rm(i, seq_id)) { + if (new_head == cells.size()) { + new_head = i; + } + } + } + + // If we freed up a slot, set head to it so searching can start there. + if (new_head != cells.size() && new_head < head) { + head = new_head; + } + + return true; +} + +void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { + if (seq_id_src == seq_id_dst) { + return; + } + + if (p0 < 0) { + p0 = 0; + } + + if (p1 < 0) { + p1 = std::numeric_limits::max(); + } + + for (uint32_t i = 0; i < cells.size(); ++i) { + if (!cells.pos_in(i, p0, p1)) { + continue; + } + + if (cells.seq_has(i, seq_id_src)) { + cells.seq_add(i, seq_id_dst); + } + } +} + +void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) { + uint32_t new_head = cells.size(); + + for (uint32_t i = 0; i < cells.size(); ++i) { + if (cells.seq_keep(i, seq_id)) { + if (new_head == cells.size()) { + new_head = i; + } + } + } + + // If we freed up a slot, set head to it so searching can start there. + if (new_head != cells.size() && new_head < head) { + head = new_head; + } +} + +void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) { + if (shift == 0) { + return; + } + + uint32_t new_head = cells.size(); + + if (p0 < 0) { + p0 = 0; + } + + if (p1 < 0) { + p1 = std::numeric_limits::max(); + } + + // If there is no range then return early to avoid looping over all cells. 
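+    // (negative p0/p1 were already normalized above, so p0 == p1 can only denote an empty [p0, p1) interval)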
+ if (p0 == p1) { + return; + } + + for (uint32_t i = 0; i < cells.size(); ++i) { + if (!cells.pos_in(i, p0, p1)) { + continue; + } + + if (cells.seq_has(i, seq_id)) { + if (cells.pos_add(i, shift)) { + if (new_head == cells.size()) { + new_head = i; + } + } + } + } + + // If we freed up a slot, set head to it so searching can start there. + // Otherwise we just start the next search from the beginning. + head = new_head != cells.size() ? new_head : 0; +} + +void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { + if (d == 1) { + return; + } + + if (p0 < 0) { + p0 = 0; + } + + if (p1 < 0) { + p1 = std::numeric_limits::max(); + } + + // If there is no range then return early to avoid looping over the cache. + if (p0 == p1) { + return; + } + + for (uint32_t i = 0; i < cells.size(); ++i) { + if (!cells.pos_in(i, p0, p1)) { + continue; + } + + if (cells.seq_has(i, seq_id)) { + cells.pos_div(i, d); + } + } +} + +llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const { + return cells.seq_pos_min(seq_id); +} + +llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const { + return cells.seq_pos_max(seq_id); +} + +llama_memory_state_ptr llama_kv_cache_unified::init_batch( + const llama_batch & batch, + uint32_t n_ubatch, + bool embd_pooled, + bool logits_all) { + GGML_UNUSED(embd_pooled); + + auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all); + + std::vector ubatches; + while (sbatch.n_tokens > 0) { + ubatches.push_back(sbatch.split_simple(n_ubatch)); + } + + auto heads = prepare(ubatches); + if (heads.empty()) { + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); + } + + return std::make_unique(LLAMA_MEMORY_STATUS_SUCCESS, + this, std::move(sbatch), std::move(heads), std::move(ubatches)); +} + +llama_memory_state_ptr llama_kv_cache_unified::init_full() { + return std::make_unique(LLAMA_MEMORY_STATUS_SUCCESS, this); +} + +std::vector llama_kv_cache_unified::prepare(const std::vector & ubatches) { + std::vector res; + + struct state { + uint32_t head_old; // old position of the head, before placing the ubatch + uint32_t head_new; // new position of the head, after placing the ubatch + + llama_kv_cells_unified cells; // copy of the old cells, before placing the ubatch + }; + + // remember the old state of the cells so we can restore it in the end + std::vector states; + + bool success = true; + + for (const auto & ubatch : ubatches) { + // only find a suitable slot for the ubatch. 
don't modify the cells yet
+        const int32_t head_new = find_slot(ubatch);
+        if (head_new < 0) {
+            success = false;
+            break;
+        }
+
+        // remember the position that we found
+        res.push_back(head_new);
+
+        // store the old state of the cells in the recovery stack
+        states.push_back({head, (uint32_t) head_new, cells.cp(head_new, ubatch.n_tokens)});
+
+        // now emplace the ubatch
+        apply_ubatch(head_new, ubatch);
+    }
+
+    // iterate backwards and restore the cells to their original state
+    for (auto it = states.rbegin(); it != states.rend(); ++it) {
+        cells.set(it->head_new, it->cells);
+        head = it->head_old;
+    }
+
+    if (!success) {
+        return {};
+    }
+
+    return res;
+}
+
+bool llama_kv_cache_unified::update(llama_context & lctx) {
+    bool updated = false;
+
+    auto * sched = lctx.get_sched();
+
+    if (cells.get_has_shift()) {
+        if (!get_can_shift()) {
+            GGML_ABORT("The current KV cache / model configuration does not support K-shift");
+        }
+
+        LLAMA_LOG_DEBUG("%s: applying K-shift\n", __func__);
+
+        // apply K-shift if needed
+        if (hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
+            ggml_backend_sched_reset(sched);
+
+            auto * gf = lctx.graph_init();
+
+            auto res = build_graph_shift(lctx.get_cparams(), lctx.get_ctx_compute(), gf);
+            if (!res) {
+                LLAMA_LOG_ERROR("%s: failed to build graph for K-shift\n", __func__);
+                return updated;
+            }
+
+            if (!ggml_backend_sched_alloc_graph(sched, gf)) {
+                LLAMA_LOG_ERROR("%s: failed to allocate compute graph for K-shift\n", __func__);
+                return updated;
+            }
+
+            res->set_inputs(nullptr);
+
+            if (lctx.graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
+                LLAMA_LOG_ERROR("%s: failed to compute K-shift\n", __func__);
+                return updated;
+            }
+
+            updated = true;
+        }
+
+        cells.reset_shift();
+    }
+
+    if (do_defrag) {
+        LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);
+
+        if (defrag_prepare(lctx.graph_max_nodes())) {
+            ggml_backend_sched_reset(sched);
+
+            auto * gf = lctx.graph_init();
+
+            auto res = build_graph_defrag(lctx.get_cparams(), lctx.get_ctx_compute(), gf);
+            if (!res) {
+                LLAMA_LOG_ERROR("%s: failed to build graph for defrag\n", __func__);
+                return updated;
+            }
+
+            if (!ggml_backend_sched_alloc_graph(sched, gf)) {
+                LLAMA_LOG_ERROR("%s: failed to allocate compute graph for defrag\n", __func__);
+                return updated;
+            }
+
+            res->set_inputs(nullptr);
+
+            if (lctx.graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
+                LLAMA_LOG_ERROR("%s: failed to compute defrag\n", __func__);
+                return updated;
+            }
+
+            updated = true;
+        }
+
+        do_defrag = false;
+    }
+
+    return updated;
+}
+
+void llama_kv_cache_unified::defrag_sched(float thold) {
+    const auto n_kv = cells.used_max_p1();
+
+    // - do not defrag small contexts (i.e. < 2048 tokens)
+    // - count the padding towards the number of used tokens
+    const float fragmentation = n_kv >= 2048 ? std::max(0.0f, 1.0f - (float(cells.get_used() + n_pad)/n_kv)) : 0.0f;
+
+    // queue defragmentation for next llama_kv_cache_update
+    if (fragmentation > thold) {
+        LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);
+
+        do_defrag = true;
+    }
+}
+
+int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
+    const uint32_t n_tokens = ubatch.n_tokens;
+
+    uint32_t head_cur = this->head;
+
+    // if we have enough unused cells before the current head ->
+    //   better to start searching from the beginning of the cache, hoping to fill it
+    if (head_cur > cells.get_used() + 2*ubatch.n_tokens) {
+        head_cur = 0;
+    }
+
+    // otherwise, one cell per token.
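+    // (each token of the ubatch occupies exactly one cell, so a ubatch larger than the cache can never fit)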
+
+    if (n_tokens > cells.size()) {
+        LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size());
+        return -1;
+    }
+
+//#define FIND_SLOT_DEBUG 1
+#if FIND_SLOT_DEBUG
+    LLAMA_LOG_WARN("begin: n = %5d, used = %5d, head = %5d, n_swa = %5d\n", cells.used_max_p1(), cells.get_used(), head, n_swa);
+
+    // for debugging
+    {
+        std::string ss;
+        if (n_swa > 0) {
+            for (uint32_t i = 0; i < cells.size(); ++i) {
+                if (cells.is_empty(i)) {
+                    ss += '.';
+                } else {
+                    ss += std::to_string(cells.seq_get(i));
+                }
+                if (i%256 == 255) {
+                    ss += '\n';
+                }
+            }
+        }
+        LLAMA_LOG_WARN("\n%s\n", ss.c_str());
+    }
+
+    for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+        if (cells.seq_pos_min(s) < 0) {
+            continue;
+        }
+
+        LLAMA_LOG_WARN("kv_cells: n_swa = %4d, min[%d] = %5d, max[%d] = %5d\n", n_swa, s, cells.seq_pos_min(s), s, cells.seq_pos_max(s));
+    }
+#endif
+
+    uint32_t n_tested = 0;
+
+    while (true) {
+        if (head_cur + n_tokens > cells.size()) {
+            n_tested += cells.size() - head_cur;
+            head_cur = 0;
+            continue;
+        }
+
+        // keep track of what the minimum sequence positions would be if we accept the ubatch
+        llama_pos seq_pos_min[LLAMA_MAX_PARALLEL_SEQUENCES];
+        for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+            seq_pos_min[s] = cells.seq_pos_min(s);
+        }
+
+        bool found = true;
+        for (uint32_t i = 0; i < n_tokens; i++) {
+            const llama_pos    pos    = ubatch.pos[i];
+            const llama_seq_id seq_id = ubatch.seq_id[i][0];
+
+            // can we use this cell? either:
+            // - the cell is empty
+            // - the cell is occupied only by one sequence:
+            //   - mask causally, if the sequence is the same as the one we are inserting
+            //   - mask SWA, using current max pos for that sequence in the cache
+            //     (always insert in the cell with minimum pos)
+            bool can_use = cells.is_empty(head_cur + i);
+
+            if (!can_use && cells.seq_count(head_cur + i) == 1) {
+                const llama_pos pos_cell = cells.pos_get(head_cur + i);
+
+                // causal mask
+                if (cells.seq_has(head_cur + i, seq_id)) {
+                    can_use = pos_cell >= pos;
+                }
+
+                if (!can_use) {
+                    const llama_seq_id seq_id_cell = cells.seq_get(head_cur + i);
+
+                    // SWA mask
+                    // note: we insert only in the cell with minimum pos in order to preserve the invariant that
+                    //       all positions between [pos_min, pos_max] for each sequence will be present in the cache
+                    //       ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092
+                    if (pos_cell == seq_pos_min[seq_id_cell] &&
+                        is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
+                        seq_pos_min[seq_id_cell]++;
+                        can_use = true;
+                    }
+                }
+            }
+
+            if (!can_use) {
+                found = false;
+                head_cur += i + 1;
+                n_tested += i + 1;
+                break;
+            }
+        }
+
+        if (found) {
+            break;
+        }
+
+        if (n_tested >= cells.size()) {
+            //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
+            return -1;
+        }
+    }
+
+    return head_cur;
+}
+
+void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch) {
+    for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
+        if (!cells.is_empty(head_cur + i)) {
+            cells.rm(head_cur + i);
+        }
+
+        cells.pos_set(head_cur + i, ubatch.pos[i]);
+
+        for (int32_t j = 0; j < ubatch.n_seq_id[i]; j++) {
+            cells.seq_add(head_cur + i, ubatch.seq_id[i][j]);
+        }
+    }
+
+    // move the head to the end of the slot
+    head = head_cur + ubatch.n_tokens;
+}
+
+bool llama_kv_cache_unified::get_can_shift() const {
+    return true;
+}
+
+uint32_t llama_kv_cache_unified::get_size() const {
+    return cells.size();
+}
+
+uint32_t llama_kv_cache_unified::get_n_kv() const {
+    return std::min(cells.size(),
std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad))); +} + +ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il, uint32_t n_kv) const { + const int32_t ikv = map_layer_ids.at(il); + + auto * k = layers[ikv].k; + + return ggml_view_3d(ctx, k, + hparams.n_embd_head_k, hparams.n_head_kv(il), n_kv, + ggml_row_size(k->type, hparams.n_embd_head_k), + ggml_row_size(k->type, hparams.n_embd_k_gqa(il)), + 0); +} + +ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const { + const int32_t ikv = map_layer_ids.at(il); + + auto * v = layers[ikv].v; + + if (!v_trans) { + // note: v->nb[1] <= v->nb[2] + return ggml_view_3d(ctx, v, + hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv, + ggml_row_size(v->type, hparams.n_embd_head_v), // v->nb[1] + ggml_row_size(v->type, hparams.n_embd_v_gqa(il)), // v->nb[2] + 0); + } + + // note: v->nb[1] > v->nb[2] + return ggml_view_3d(ctx, v, + n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v, + ggml_row_size(v->type, v->ne[1]*hparams.n_embd_head_v), // v->nb[1] + ggml_row_size(v->type, v->ne[1]), // v->nb[2] + 0); +} + +ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il, uint32_t head_cur) const { + const int32_t ikv = map_layer_ids.at(il); + + auto * k = layers[ikv].k; + + const int64_t n_tokens = k_cur->ne[2]; + + ggml_tensor * k_view = ggml_view_1d(ctx, k, + n_tokens*hparams.n_embd_k_gqa(il), + ggml_row_size(k->type, hparams.n_embd_k_gqa(il))*head_cur); + + return ggml_cpy(ctx, k_cur, k_view); +} + +ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il, uint32_t head_cur) const { + const int32_t ikv = map_layer_ids.at(il); + + auto * v = layers[ikv].v; + + const int64_t n_tokens = v_cur->ne[2]; + + v_cur = ggml_reshape_2d(ctx, v_cur, hparams.n_embd_v_gqa(il), n_tokens); + + ggml_tensor * v_view = nullptr; + + if (!v_trans) { + v_view = ggml_view_1d(ctx, v, + n_tokens*hparams.n_embd_v_gqa(il), + ggml_row_size(v->type, hparams.n_embd_v_gqa(il))*head_cur); + } else { + // note: the V cache is transposed when not using flash attention + v_view = ggml_view_2d(ctx, v, n_tokens, hparams.n_embd_v_gqa(il), + (v->ne[1])*ggml_element_size(v), + (head_cur)*ggml_element_size(v)); + + v_cur = ggml_transpose(ctx, v_cur); + } + + return ggml_cpy(ctx, v_cur, v_view); +} + +void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const { + const int64_t n_tokens = ubatch->n_tokens; + const int64_t n_seq_tokens = ubatch->n_seq_tokens; + const int64_t n_seqs = ubatch->n_seqs; + + GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); + float * data = (float *) dst->data; + + const auto n_kv = dst->ne[0]; + + // Use only the previous KV cells of the correct sequence for each token of the ubatch. + // It's assumed that if a token in the batch has multiple sequences, they are equivalent. 
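+    // (based on that assumption, only ubatch->seq_id[s][0] is consulted below)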
+    // Example with a cache of 10 tokens, 2 tokens populated in the cache and 3 tokens in the batch:
+    //   Causal mask:
+    //      xxx-------
+    //      xxxx------
+    //      xxxxx-----
+    //   Non-causal mask:
+    //      xxxxx-----
+    //      xxxxx-----
+    //      xxxxx-----
+    // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
+    for (int h = 0; h < 1; ++h) {
+        for (int s = 0; s < n_seqs; ++s) {
+            const llama_seq_id seq_id = ubatch->seq_id[s][0];
+
+            for (int j = 0; j < n_seq_tokens; ++j) {
+                const llama_pos p1 = ubatch->pos[s*n_seq_tokens + j];
+
+                for (uint32_t i = 0; i < n_kv; ++i) {
+                    float f = 0.0f;
+
+                    bool masked = false;
+
+                    if (cells.is_empty(i)) {
+                        masked = true;
+                    } else {
+                        const llama_pos p0 = cells.pos_get(i);
+
+                        // mask the token if not the same sequence
+                        masked = masked || (!cells.seq_has(i, seq_id));
+
+                        // mask future tokens
+                        masked = masked || (causal_attn && p0 > p1);
+
+                        // apply SWA if any
+                        masked = masked || (is_masked_swa(p0, p1));
+
+                        if (!masked && hparams.use_alibi) {
+                            f = -std::abs(p0 - p1);
+                        }
+                    }
+
+                    if (masked) {
+                        f = -INFINITY;
+                    }
+
+                    data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
+                }
+            }
+        }
+
+        // mask padded tokens
+        if (data) {
+            for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
+                for (uint32_t j = 0; j < n_kv; ++j) {
+                    data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
+                }
+            }
+        }
+    }
+}
+
+void llama_kv_cache_unified::set_input_k_shift(ggml_tensor * dst) const {
+    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
+
+    int32_t * data = (int32_t *) dst->data;
+
+    for (uint32_t i = 0; i < cells.size(); ++i) {
+        data[i] = cells.is_empty(i) ? 0 : cells.get_shift(i);
+    }
+}
+
+void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
+    const int64_t n_tokens = ubatch->n_tokens;
+
+    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
+    GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
+
+    int32_t * data = (int32_t *) dst->data;
+
+    const int32_t n_kv = dst->ne[0];
+
+    for (int h = 0; h < 1; ++h) {
+        for (int j = 0; j < n_tokens; ++j) {
+            for (int i = 0; i < n_kv; ++i) {
+                // the position when the cell is empty is irrelevant - it will be masked out later in the attention
+                const llama_pos p0 = cells.is_empty(i) ?
-1 : cells.pos_get(i); + + data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(p0, ubatch->pos[j], hparams.n_rel_attn_bkts, false); + } + } + } +} + +size_t llama_kv_cache_unified::total_size() const { + size_t size = 0; + + for (const auto & buf : bufs) { + size += ggml_backend_buffer_get_size(buf.get()); + } + + return size; +} + +size_t llama_kv_cache_unified::size_k_bytes() const { + size_t size_k_bytes = 0; + + for (const auto & layer : layers) { + size_k_bytes += ggml_nbytes(layer.k); + } + + return size_k_bytes; +} + +size_t llama_kv_cache_unified::size_v_bytes() const { + size_t size_v_bytes = 0; + + for (const auto & layer : layers) { + size_v_bytes += ggml_nbytes(layer.v); + } + + return size_v_bytes; +} + +ggml_tensor * llama_kv_cache_unified::build_rope_shift( + const llama_cparams & cparams, + ggml_context * ctx, + ggml_tensor * cur, + ggml_tensor * shift, + ggml_tensor * factors, + float freq_base, + float freq_scale) const { + const auto & n_ctx_orig = cparams.n_ctx_orig_yarn; + + const auto & yarn_ext_factor = cparams.yarn_ext_factor; + const auto & yarn_beta_fast = cparams.yarn_beta_fast; + const auto & yarn_beta_slow = cparams.yarn_beta_slow; + + const auto & n_rot = hparams.n_rot; + const auto & rope_type = hparams.rope_type == LLAMA_ROPE_TYPE_MROPE + // @ngxson : this is a workaround + // for M-RoPE, we want to rotate the whole vector when doing KV shift + // a normal RoPE should work, we just need to use the correct ordering + // ref: https://github.com/ggml-org/llama.cpp/pull/13870 + ? LLAMA_ROPE_TYPE_NEOX + : hparams.rope_type; + + // See llm_build_deepseek2() for why attn_factor has to be scaled for YaRN RoPE to work correctly. + // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation. + const float yarn_attn_factor = model.arch == LLM_ARCH_DEEPSEEK2 + ? 
1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)) + : cparams.yarn_attn_factor; + + ggml_tensor * tmp; + + if (ggml_is_quantized(cur->type)) { + // dequantize to f32 -> RoPE -> quantize back + tmp = ggml_cast(ctx, cur, GGML_TYPE_F32); + + tmp = ggml_rope_ext(ctx, tmp, + shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow); + + tmp = ggml_cpy(ctx, tmp, cur); + } else { + // we rotate only the first n_rot dimensions + tmp = ggml_rope_ext_inplace(ctx, cur, + shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow); + } + + return tmp; +} + +class llm_graph_input_k_shift : public llm_graph_input_i { +public: + llm_graph_input_k_shift(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {} + virtual ~llm_graph_input_k_shift() = default; + + void set_input(const llama_ubatch * ubatch) override; + + ggml_tensor * k_shift; // I32 [kv_size] + + const llama_kv_cache_unified * kv_self; +}; + +void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) { + GGML_UNUSED(ubatch); + + if (k_shift) { + kv_self->set_input_k_shift(k_shift); + } +} + +llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift( + const llama_cparams & cparams, + ggml_context * ctx, + ggml_cgraph * gf) const { + auto res = std::make_unique(); + + const auto & n_embd_head_k = hparams.n_embd_head_k; + //const auto & n_embd_head_v = hparams.n_embd_head_v; + + //GGML_ASSERT(kv_self->size == n_ctx); + + auto inp = std::make_unique(this); + + inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, cparams.n_ctx); + ggml_set_input(inp->k_shift); + + for (const auto & layer : layers) { + const uint32_t il = layer.il; + + const int64_t n_head_kv = hparams.n_head_kv(il); + const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); + + const float freq_base_l = model.get_rope_freq_base (cparams, il); + const float freq_scale_l = model.get_rope_freq_scale(cparams, il); + + ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); + + ggml_tensor * k = + ggml_view_3d(ctx, layer.k, + n_embd_head_k, n_head_kv, cells.size(), + ggml_row_size(layer.k->type, n_embd_head_k), + ggml_row_size(layer.k->type, n_embd_k_gqa), + 0); + + ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l); + + ggml_build_forward_expand(gf, cur); + } + + res->add_input(std::move(inp)); + + return res; +} + +llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag( + const llama_cparams & cparams, + ggml_context * ctx, + ggml_cgraph * gf) const { + auto res = std::make_unique(); + + const auto & ids = defrag_info.ids; + +#if 0 + // CPU defrag + // + // TODO: optimizations are possible: + // - multiple threads + // - avoid copying to the host memory when already there + // + // likely not worth the effort, as we have ggml_graph based defrag + // + + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(); + + const uint32_t kv_size = size; + + std::vector buf_k; + std::vector buf_v; + + for (uint32_t il = 0; il < n_layer; ++il) { + const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa); + const size_t k_size = ggml_row_size(k_l[il]->type, n_embd_k_gqa*kv_size); + + const size_t v_size_el = ggml_type_size(v_l[il]->type); + const size_t v_size = ggml_row_size (v_l[il]->type, n_embd_v_gqa*kv_size); + + buf_k.resize(k_size); + buf_v.resize(v_size); + + 
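+        // bring this layer's K/V data into host memory, compact it there, then write it back below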
ggml_backend_tensor_get(k_l[il], buf_k.data(), 0, buf_k.size()); + ggml_backend_tensor_get(v_l[il], buf_v.data(), 0, buf_v.size()); + + // batch move [i, i+nm) to [id, id+nm) + // note: cells can move only to a lower index + for (uint32_t i = 0; i < n_kv; ++i) { + const uint32_t id = ids[i]; + + if (i == id || id == n_kv) { + continue; + } + + uint32_t nm = 1; + + while (i + nm < n_kv && ids[i + nm] == id + nm) { + nm++; + } + + // move keys + { + const int64_t os = i*k_size_row; + const int64_t od = id*k_size_row; + + memcpy(buf_k.data() + od, buf_k.data() + os, nm*k_size_row); + } + + // move values (note: they are transposed) + { + const int64_t os = i; + const int64_t od = id; + + for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { + memcpy(buf_v.data() + (od + j*kv_size)*v_size_el, buf_v.data() + (os + j*kv_size)*v_size_el, nm*v_size_el); + } + } + + i += nm - 1; + } + + ggml_backend_tensor_set(k_l[il], buf_k.data(), 0, buf_k.size()); + ggml_backend_tensor_set(v_l[il], buf_v.data(), 0, buf_v.size()); + } +#else + for (uint32_t i = 0; i < ids.size(); ++i) { + const uint32_t id = ids[i]; + + if (i == id || id == ids.size()) { + continue; + } + + uint32_t nm = 1; + + while (i + nm < ids.size() && ids[i + nm] == id + nm) { + nm++; + } + + for (const auto & layer : layers) { + const uint32_t il = layer.il; + + const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); + const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); + + ggml_tensor * view_k_src = ggml_view_2d(ctx, layer.k, + n_embd_k_gqa, nm, + ggml_row_size(layer.k->type, n_embd_k_gqa), + ggml_row_size(layer.k->type, n_embd_k_gqa*i)); + + ggml_tensor * view_k_dst = ggml_view_2d(ctx, layer.k, + n_embd_k_gqa, nm, + ggml_row_size(layer.k->type, n_embd_k_gqa), + ggml_row_size(layer.k->type, n_embd_k_gqa*id)); + + ggml_tensor * view_v_src; + ggml_tensor * view_v_dst; + + if (cparams.flash_attn) { + // NOTE: the V cache is not transposed when using flash attention + view_v_src = ggml_view_2d(ctx, layer.v, + n_embd_v_gqa, nm, + ggml_row_size(layer.v->type, n_embd_v_gqa), + ggml_row_size(layer.v->type, n_embd_v_gqa*i)); + + view_v_dst = ggml_view_2d(ctx, layer.v, + n_embd_v_gqa, nm, + ggml_row_size(layer.v->type, n_embd_v_gqa), + ggml_row_size(layer.v->type, n_embd_v_gqa*id)); + } else { + view_v_src = ggml_view_2d(ctx, layer.v, + nm, n_embd_v_gqa, + ggml_row_size(layer.v->type, cells.size()), + ggml_row_size(layer.v->type, i)); + + view_v_dst = ggml_view_2d(ctx, layer.v, + nm, n_embd_v_gqa, + ggml_row_size(layer.v->type, cells.size()), + ggml_row_size(layer.v->type, id)); + } + + ggml_build_forward_expand(gf, ggml_cpy(ctx, view_k_src, view_k_dst)); + ggml_build_forward_expand(gf, ggml_cpy(ctx, view_v_src, view_v_dst)); + } + + i += nm - 1; + } + + //LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes); +#endif + + return res; +} + +bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) { + const uint32_t n_layer = layers.size(); + + const uint32_t n_kv = cells.used_max_p1(); + const uint32_t n_used = cells.get_used(); + + assert(n_used <= n_kv); + + //const int64_t t_start = ggml_time_us(); + + // number of cells moved + uint32_t n_moves = 0; + + // each move requires 6*n_layer tensors (see graph_build_kv_self_defrag) + // - source view, destination view, copy operation + // - x2 for keys and values + //const uint32_t max_moves = max_nodes()/(6*n_layer); + // TODO: tmp fix https://github.com/ggerganov/llama.cpp/issues/6685#issuecomment-2057579516 + const uint32_t max_moves = (n_max_nodes - 2*n_layer)/(6*n_layer); + + // determine which KV 
cells to move where + // + // cell i moves to ids[i] + // + // if ids[i] == i || ids[i] == n_kv, then cell i is not moved + // + auto & ids = defrag_info.ids; + + ids.clear(); + ids.resize(n_kv, n_kv); + + for (uint32_t i0 = 0; i0 < n_used; ++i0) { + if (!cells.is_empty(i0)) { + ids[i0] = i0; + + continue; + } + + // found a hole - fill it with data from the end of the cache + + uint32_t nh = 1; + + // determine the size of the hole + while (i0 + nh < n_used && cells.is_empty(i0 + nh)) { + nh++; + } + + uint32_t nf = 0; + uint32_t is = n_kv - 1; + + // starting from the end, find nh non-empty cells + for (; is > i0; --is) { + if (cells.is_empty(is) || ids[is] != n_kv) { + continue; + } + + // non-empty cell which is not yet moved + nf++; + + if (nf == nh) { + break; + } + } + + // this can only happen if `n_used` is not accurate, which would be a bug + GGML_ASSERT(nf == nh && "KV defrag bug: nf != nh"); + + nf = 0; + + uint32_t i1 = is; + + // are we moving a continuous block of memory? + bool cont = false; + + // should we stop searching for the next move? + bool stop = false; + + // go back and move the nf cells to the hole + for (; i1 < n_kv; ++i1) { + if (cells.is_empty(i1) || ids[i1] != n_kv) { + if (n_moves == max_moves) { + stop = true; + break; + } + + cont = false; + continue; + } + + // this cell goes to (i0 + nf) + ids[i1] = i0 + nf; + + // move the cell meta data + cells.mv(i1, i0 + nf); + + head = n_used; + + if (!cont) { + n_moves++; + cont = true; + } + + nf++; + + if (nf == nh) { + break; + } + } + + if (stop || n_moves == max_moves) { + break; + } + + //LLAMA_LOG_INFO("(tmp log) KV defrag: move [%u, %u) to [%u, %u)\n", is, i1 + 1, i0, i0 + nh); + + i0 += nh - 1; + } + + if (n_moves == 0) { + return false; + } + + LLAMA_LOG_DEBUG("%s: (tmp log) KV defrag cell moves: %u\n", __func__, n_moves); + + LLAMA_LOG_DEBUG("%s: expected gf nodes: %u\n", __func__, 6*n_moves*n_layer); + + return true; +} + +bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const { + assert(p0 >= 0 && p1 >= 0); + + switch (swa_type) { + case LLAMA_SWA_TYPE_NONE: + { + } break; + case LLAMA_SWA_TYPE_STANDARD: + { + if (p1 - p0 >= (int32_t) n_swa) { + return true; + } + } break; + case LLAMA_SWA_TYPE_CHUNKED: + { + const llama_pos pos_chunk_start = (p1 / n_swa) * n_swa; + + if (p0 < pos_chunk_start) { + return true; + } + } break; + } + + return false; +} + +void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id) const { + std::vector> cell_ranges; // ranges, from inclusive, to exclusive + uint32_t cell_count = 0; + + // Count the number of cells with the specified seq_id + // Find all the ranges of cells with this seq id (or all, when -1) + uint32_t cell_range_begin = cells.size(); + + for (uint32_t i = 0; i < cells.size(); ++i) { + if (!cells.is_empty(i) && (seq_id == -1 || cells.seq_has(i, seq_id))) { + ++cell_count; + if (cell_range_begin == cells.size()) { + cell_range_begin = i; + } + } else { + if (cell_range_begin != cells.size()) { + cell_ranges.emplace_back(cell_range_begin, i); + cell_range_begin = cells.size(); + } + } + } + + if (cell_range_begin != cells.size()) { + cell_ranges.emplace_back(cell_range_begin, cells.size()); + } + + // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count + uint32_t cell_count_check = 0; + for (const auto & range : cell_ranges) { + cell_count_check += range.second - range.first; + } + GGML_ASSERT(cell_count == cell_count_check); + + io.write(&cell_count, sizeof(cell_count)); + + 
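+    // write the cell metadata (positions + seq ids) before the K/V data so that
+    // state_read() can re-create the cells before streaming in the tensor contents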
state_write_meta(io, cell_ranges, seq_id); + state_write_data(io, cell_ranges); +} + +void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_id) { + uint32_t cell_count; + io.read_to(&cell_count, sizeof(cell_count)); + + bool res = true; + res = res && state_read_meta(io, cell_count, seq_id); + res = res && state_read_data(io, cell_count); + + if (!res) { + if (seq_id == -1) { + clear(); + } else { + seq_rm(seq_id, -1, -1); + } + throw std::runtime_error("failed to restore kv cache"); + } +} + +void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const std::vector> & cell_ranges, llama_seq_id seq_id) const { + for (const auto & range : cell_ranges) { + for (uint32_t i = range.first; i < range.second; ++i) { + std::vector seq_ids; + + for (llama_seq_id cur = 0; cur < (int) n_seq_max; ++cur) { + if (cur == seq_id || seq_id == -1) { + if (cells.seq_has(i, cur)) { + seq_ids.push_back(cur); + } + } + } + + const llama_pos pos = cells.pos_get(i); + const uint32_t n_seq_id = seq_ids.size(); + + io.write(&pos, sizeof(pos)); + io.write(&n_seq_id, sizeof(n_seq_id)); + + for (const auto & seq_id : seq_ids) { + io.write(&seq_id, sizeof(seq_id)); + } + } + } +} + +void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::vector> & cell_ranges) const { + const uint32_t v_trans = this->v_trans ? 1 : 0; + const uint32_t n_layer = layers.size(); + + io.write(&v_trans, sizeof(v_trans)); + io.write(&n_layer, sizeof(n_layer)); + + std::vector tmp_buf; + + // Iterate and write all the keys first, each row is a cell + // Get whole range at a time + for (const auto & layer : layers) { + const uint32_t il = layer.il; + + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); + + // Write key type + const int32_t k_type_i = (int32_t)layer.k->type; + io.write(&k_type_i, sizeof(k_type_i)); + + // Write row size of key + const uint64_t k_size_row = ggml_row_size(layer.k->type, n_embd_k_gqa); + io.write(&k_size_row, sizeof(k_size_row)); + + // Read each range of cells of k_size length each into tmp_buf and write out + for (const auto & range : cell_ranges) { + const size_t range_size = range.second - range.first; + const size_t buf_size = range_size * k_size_row; + io.write_tensor(layer.k, range.first * k_size_row, buf_size); + } + } + + if (!v_trans) { + for (const auto & layer : layers) { + const uint32_t il = layer.il; + + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + + // Write value type + const int32_t v_type_i = (int32_t)layer.v->type; + io.write(&v_type_i, sizeof(v_type_i)); + + // Write row size of value + const uint64_t v_size_row = ggml_row_size(layer.v->type, n_embd_v_gqa); + io.write(&v_size_row, sizeof(v_size_row)); + + // Read each range of cells of v_size length each into tmp_buf and write out + for (const auto & range : cell_ranges) { + const size_t range_size = range.second - range.first; + const size_t buf_size = range_size * v_size_row; + io.write_tensor(layer.v, range.first * v_size_row, buf_size); + } + } + } else { + // When v is transposed, we also need the element size and get the element ranges from each row + const uint32_t kv_size = cells.size(); + + for (const auto & layer : layers) { + const uint32_t il = layer.il; + + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + + // Write value type + const int32_t v_type_i = (int32_t)layer.v->type; + io.write(&v_type_i, sizeof(v_type_i)); + + // Write element size + const uint32_t v_size_el = 
ggml_type_size(layer.v->type); + io.write(&v_size_el, sizeof(v_size_el)); + + // Write GQA embedding size + io.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa)); + + // For each row, we get the element values of each cell + for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { + // Read each range of cells of v_size_el length each into tmp_buf and write out + for (const auto & range : cell_ranges) { + const size_t range_size = range.second - range.first; + const size_t src_offset = (range.first + j * kv_size) * v_size_el; + const size_t buf_size = range_size * v_size_el; + io.write_tensor(layer.v, src_offset, buf_size); + } + } + } + } +} + +bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) { + if (dest_seq_id != -1) { + // single sequence + + seq_rm(dest_seq_id, -1, -1); + + llama_sbatch sbatch; + llama_ubatch batch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false); + + batch.n_tokens = cell_count; + + for (uint32_t i = 0; i < cell_count; ++i) { + llama_pos pos; + uint32_t n_seq_id; + + io.read_to(&pos, sizeof(pos)); + io.read_to(&n_seq_id, sizeof(n_seq_id)); + + if (n_seq_id != 1) { + LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__); + return false; + } + + // read the sequence id, but directly discard it - we will use dest_seq_id instead + { + llama_seq_id seq_id; + io.read_to(&seq_id, sizeof(seq_id)); + } + + batch.pos[i] = pos; + batch.n_seq_id[i] = n_seq_id; + batch.seq_id[i] = &dest_seq_id; + } + + const auto head_cur = find_slot(batch); + if (head_cur < 0) { + LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__); + return false; + } + + apply_ubatch(head_cur, batch); + + // keep the head at the old position because we will read the KV data into it in state_read_data() + head = head_cur; + + // DEBUG CHECK: head_cur should be our first cell, head_cur + cell_count - 1 should be our last cell (verify seq_id and pos values) + // Assume that this is one contiguous block of cells + GGML_ASSERT(head_cur + cell_count <= cells.size()); + GGML_ASSERT(cells.pos_get(head_cur) == batch.pos[0]); + GGML_ASSERT(cells.pos_get(head_cur + cell_count - 1) == batch.pos[cell_count - 1]); + GGML_ASSERT(cells.seq_has(head_cur, dest_seq_id)); + GGML_ASSERT(cells.seq_has(head_cur + cell_count - 1, dest_seq_id)); + } else { + // whole KV cache restore + + if (cell_count > cells.size()) { + LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__); + return false; + } + + clear(); + + for (uint32_t i = 0; i < cell_count; ++i) { + llama_pos pos; + uint32_t n_seq_id; + + io.read_to(&pos, sizeof(pos)); + io.read_to(&n_seq_id, sizeof(n_seq_id)); + + cells.pos_set(i, pos); + + for (uint32_t j = 0; j < n_seq_id; ++j) { + llama_seq_id seq_id; + io.read_to(&seq_id, sizeof(seq_id)); + + if (seq_id < 0 || (uint32_t) seq_id >= n_seq_max) { + LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, n_seq_max); + return false; + } + + cells.seq_add(i, seq_id); + } + } + + head = 0; + } + + return true; +} + +bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell_count) { + uint32_t v_trans; + uint32_t n_layer; + + io.read_to(&v_trans, sizeof(v_trans)); + io.read_to(&n_layer, sizeof(n_layer)); + + if (n_layer != layers.size()) { + LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, (uint32_t) layers.size()); + return false; + } + + if (cell_count > cells.size()) { + LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state 
(%u > %u)\n", __func__, cell_count, cells.size()); + return false; + } + + if (this->v_trans != (bool) v_trans) { + LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__); + return false; + } + + // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block + for (const auto & layer : layers) { + const uint32_t il = layer.il; + + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); + + // Read type of key + int32_t k_type_i_ref; + io.read_to(&k_type_i_ref, sizeof(k_type_i_ref)); + const int32_t k_type_i = (int32_t) layer.k->type; + if (k_type_i != k_type_i_ref) { + LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il); + return false; + } + + // Read row size of key + uint64_t k_size_row_ref; + io.read_to(&k_size_row_ref, sizeof(k_size_row_ref)); + const size_t k_size_row = ggml_row_size(layer.k->type, n_embd_k_gqa); + if (k_size_row != k_size_row_ref) { + LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il); + return false; + } + + if (cell_count) { + // Read and set the keys for the whole cell range + ggml_backend_tensor_set(layer.k, io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row); + } + } + + if (!this->v_trans) { + for (const auto & layer : layers) { + const uint32_t il = layer.il; + + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + + // Read type of value + int32_t v_type_i_ref; + io.read_to(&v_type_i_ref, sizeof(v_type_i_ref)); + const int32_t v_type_i = (int32_t)layer.v->type; + if (v_type_i != v_type_i_ref) { + LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); + return false; + } + + // Read row size of value + uint64_t v_size_row_ref; + io.read_to(&v_size_row_ref, sizeof(v_size_row_ref)); + const size_t v_size_row = ggml_row_size(layer.v->type, n_embd_v_gqa); + if (v_size_row != v_size_row_ref) { + LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il); + return false; + } + + if (cell_count) { + // Read and set the values for the whole cell range + ggml_backend_tensor_set(layer.v, io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row); + } + } + } else { + // For each layer, read the values for each cell (transposed) + for (const auto & layer : layers) { + const uint32_t il = layer.il; + + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + + // Read type of value + int32_t v_type_i_ref; + io.read_to(&v_type_i_ref, sizeof(v_type_i_ref)); + const int32_t v_type_i = (int32_t)layer.v->type; + if (v_type_i != v_type_i_ref) { + LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); + return false; + } + + // Read element size of value + uint32_t v_size_el_ref; + io.read_to(&v_size_el_ref, sizeof(v_size_el_ref)); + const size_t v_size_el = ggml_type_size(layer.v->type); + if (v_size_el != v_size_el_ref) { + LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il); + return false; + } + + // Read GQA embedding size + uint32_t n_embd_v_gqa_ref; + io.read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref)); + if (n_embd_v_gqa != n_embd_v_gqa_ref) { + LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, 
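The restore path writes each layer's K rows as a single contiguous block starting at the current head, while the transposed V path (below) issues one strided write per embedding row. A quick standalone cross-check of that offset arithmetic, with made-up sizes:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

int main() {
    // contiguous K restore: one row per cell, written starting at the current head
    const size_t   k_size_row = 4096;   // e.g. ggml_row_size(type_k, n_embd_k_gqa)
    const uint32_t head = 128, cell_count = 32;
    assert((size_t) head       * k_size_row == 524288); // dst offset of the single write
    assert((size_t) cell_count * k_size_row == 131072); // bytes written

    // transposed V restore: element j of every cell lives in row j of the tensor,
    // so each of the n_embd_v_gqa rows gets its own strided write
    const size_t   v_size_el = 2;       // e.g. ggml_type_size(GGML_TYPE_F16)
    const uint32_t kv_size   = 1024;    // cells.size(), the row stride in elements
    const uint32_t j         = 3;       // row index inside [0, n_embd_v_gqa)
    assert((head + (size_t) j * kv_size) * v_size_el == 6400); // dst offset for row j
    return 0;
}
```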
n_embd_v_gqa_ref, il); + return false; + } + + if (cell_count) { + // For each row in the transposed matrix, read the values for the whole cell range + for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { + const size_t dst_offset = (head + j * cells.size()) * v_size_el; + ggml_backend_tensor_set(layer.v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el); + } + } + } + } + + return true; +} + +// +// llama_kv_cache_unified_state +// + +llama_kv_cache_unified_state::llama_kv_cache_unified_state(llama_memory_status status) : status(status) {} + +llama_kv_cache_unified_state::llama_kv_cache_unified_state( + llama_memory_status status, + llama_kv_cache_unified * kv) : status(status), kv(kv) { + n_kv = kv->get_size(); + head = 0; + } + +llama_kv_cache_unified_state::llama_kv_cache_unified_state( + llama_memory_status status, + llama_kv_cache_unified * kv, + llama_sbatch sbatch, + std::vector<uint32_t> heads, + std::vector<llama_ubatch> ubatches) + : status(status), + kv(kv), + sbatch(std::move(sbatch)), + heads(std::move(heads)), + ubatches(std::move(ubatches)) { + } + +llama_kv_cache_unified_state::~llama_kv_cache_unified_state() = default; + +bool llama_kv_cache_unified_state::next() { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + if (++i_next >= ubatches.size()) { + return false; + } + + return true; +} + +bool llama_kv_cache_unified_state::apply() { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + kv->apply_ubatch(heads[i_next], ubatches[i_next]); + + n_kv = kv->get_n_kv(); + head = heads[i_next]; + + return true; +} + +std::vector<int64_t> & llama_kv_cache_unified_state::out_ids() { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + return sbatch.out_ids; +} + +llama_memory_status llama_kv_cache_unified_state::get_status() const { + return status; +} + +const llama_ubatch & llama_kv_cache_unified_state::get_ubatch() const { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + return ubatches[i_next]; +} + +uint32_t llama_kv_cache_unified_state::get_n_kv() const { + return n_kv; +} + +ggml_tensor * llama_kv_cache_unified_state::get_k(ggml_context * ctx, int32_t il) const { + return kv->get_k(ctx, il, n_kv); +} + +ggml_tensor * llama_kv_cache_unified_state::get_v(ggml_context * ctx, int32_t il) const { + return kv->get_v(ctx, il, n_kv); +} + +ggml_tensor * llama_kv_cache_unified_state::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const { + return kv->cpy_k(ctx, k_cur, il, head); +} + +ggml_tensor * llama_kv_cache_unified_state::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const { + return kv->cpy_v(ctx, v_cur, il, head); +} + +void llama_kv_cache_unified_state::set_input_k_shift(ggml_tensor * dst) const { + kv->set_input_k_shift(dst); +} + +void llama_kv_cache_unified_state::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const { + kv->set_input_kq_mask(dst, ubatch, causal_attn); +} + +void llama_kv_cache_unified_state::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const { + kv->set_input_pos_bucket(dst, ubatch); +} + +uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) { + // the FA kernels require padding to avoid extra runtime boundary checks + return cparams.flash_attn ? 256u : 32u; +}
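The state object implemented above is consumed by repeatedly calling apply() and then next(). A toy stand-in (not the real classes) showing the protocol a driver loop is expected to follow:

```cpp
// Minimal stand-in illustrating the next()/apply() protocol of
// llama_kv_cache_unified_state: apply() works on ubatch i_next,
// next() advances and reports whether another ubatch remains.
#include <cstdio>
#include <vector>

struct toy_state {
    std::vector<int> heads = {0, 32, 64}; // one head location per pending ubatch
    size_t i_next = 0;

    bool apply() { printf("insert ubatch %zu at head %d\n", i_next, heads[i_next]); return true; }
    bool next()  { return ++i_next < heads.size(); }
};

int main() {
    toy_state st;
    do {
        st.apply(); // ubatches are processed in order, same shape as the real driver loop
    } while (st.next());
    return 0;
}
```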
diff --git a/src/llama-kv-cache-unified.h b/src/llama-kv-cache-unified.h new file mode 100644 index 0000000000000..1f1d44b97c2ac --- /dev/null +++ b/src/llama-kv-cache-unified.h @@ -0,0 +1,278 @@ +#pragma once + +#include "llama-batch.h" +#include "llama-graph.h" +#include "llama-kv-cache.h" +#include "llama-kv-cells.h" + +#include <unordered_map> +#include <vector> + +struct llama_cparams; +struct llama_hparams; +struct llama_model; +struct llama_context; + +// +// llama_kv_cache_unified +// + +class llama_kv_cache_unified : public llama_kv_cache { +public: + static uint32_t get_padding(const llama_cparams & cparams); + + // this callback is used to filter out layers that should not be included in the cache + using layer_filter_cb = std::function<bool(int32_t il)>; + + llama_kv_cache_unified( + const llama_model & model, + layer_filter_cb && filter, + ggml_type type_k, + ggml_type type_v, + bool v_trans, + bool offload, + uint32_t kv_size, + uint32_t n_seq_max, + uint32_t n_pad, + uint32_t n_swa, + llama_swa_type swa_type); + + ~llama_kv_cache_unified() = default; + + // + // llama_memory_i + // + + void clear() override; + + bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; + void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; + void seq_keep(llama_seq_id seq_id) override; + void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override; + void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override; + + llama_pos seq_pos_min(llama_seq_id seq_id) const override; + llama_pos seq_pos_max(llama_seq_id seq_id) const override; + + // + // llama_kv_cache + // + + llama_memory_state_ptr init_batch( + const llama_batch & batch, + uint32_t n_ubatch, + bool embd_pooled, + bool logits_all) override; + + llama_memory_state_ptr init_full() override; + + bool update(llama_context & lctx) override; + + void defrag_sched(float thold) override; + + bool get_can_shift() const override; + + // state write/load + + void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override; + void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override; + + // + // llama_kv_cache_unified specific API + // + + uint32_t get_size() const; + + // + // graph_build API + // + + uint32_t get_n_kv() const; + + // get views of the current state of the cache + ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv) const; + ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const; + + // store k_cur and v_cur in the cache based on the provided head location + ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il, uint32_t head_cur) const; + ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il, uint32_t head_cur) const; + + // + // preparation API + // + + // find places for the provided ubatches in the cache, returns the head locations + // return empty vector on failure + std::vector<uint32_t> prepare(const std::vector<llama_ubatch> & ubatches); + + // return the cell position where we can insert the ubatch + // return -1 on failure to find a contiguous slot of kv cells + int32_t find_slot(const llama_ubatch & ubatch) const; + + // emplace the ubatch context into slot: [head_cur, head_cur + ubatch.n_tokens) + void apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch); + + // + // set_input API + // + + void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const; + void set_input_k_shift (ggml_tensor * dst) const; + void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
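get_padding() is meant to be fed into GGML_PAD-style rounding when the caller sizes the cache. A small sketch of that rounding (context size assumed; pad() reproduces the GGML_PAD macro from ggml.h):

```cpp
#include <cstdint>
#include <cstdio>

// same rounding ggml.h performs with GGML_PAD(x, n)
static uint32_t pad(uint32_t x, uint32_t n) { return (x + n - 1) & ~(n - 1); }

int main() {
    const uint32_t n_ctx = 8000;

    printf("no FA : kv_size = %u\n", pad(n_ctx,  32)); // 8000 -> 8000 (already aligned)
    printf("FA    : kv_size = %u\n", pad(n_ctx, 256)); // 8000 -> 8192
    return 0;
}
```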
+ +private: + const llama_model & model; + const llama_hparams & hparams; + + struct kv_layer { + // layer index in the model + // note: can be different from the layer index in the KV cache + uint32_t il; + + ggml_tensor * k; + ggml_tensor * v; + }; + + bool do_defrag = false; + bool v_trans = true; // the value tensor is transposed + + // the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot()) + // note: this is not part of the KV state and it's only used to speed up the find_slot() method + uint32_t head = 0; + + const uint32_t n_seq_max = 1; + + // required padding + const uint32_t n_pad = 1; + + // SWA + const uint32_t n_swa = 0; + + const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE; + + std::vector<ggml_context_ptr> ctxs; + std::vector<ggml_backend_buffer_ptr> bufs; + + llama_kv_cells_unified cells; + + std::vector<kv_layer> layers; + + // model layer id -> KV cache layer id + std::unordered_map<int32_t, int32_t> map_layer_ids; + + // defrag + struct { + std::vector<uint32_t> ids; + } defrag_info; + + // return true if cells have been moved + bool defrag_prepare(int32_t n_max_nodes); + + size_t total_size() const; + + size_t size_k_bytes() const; + size_t size_v_bytes() const; + + bool is_masked_swa(llama_pos p0, llama_pos p1) const; + + ggml_tensor * build_rope_shift( + const llama_cparams & cparams, + ggml_context * ctx, + ggml_tensor * cur, + ggml_tensor * shift, + ggml_tensor * factors, + float freq_base, + float freq_scale) const; + + llm_graph_result_ptr build_graph_shift( + const llama_cparams & cparams, + ggml_context * ctx, + ggml_cgraph * gf) const; + + llm_graph_result_ptr build_graph_defrag( + const llama_cparams & cparams, + ggml_context * ctx, + ggml_cgraph * gf) const; + + void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const; + void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const; + + bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1); + bool state_read_data(llama_io_read_i & io, uint32_t cell_count); +}; + +class llama_kv_cache_unified_state : public llama_memory_state_i { +public: + // used for errors + llama_kv_cache_unified_state(llama_memory_status status); + + // used to create a full-cache state + llama_kv_cache_unified_state( + llama_memory_status status, + llama_kv_cache_unified * kv); + + // used to create a state from a batch + llama_kv_cache_unified_state( + llama_memory_status status, + llama_kv_cache_unified * kv, + llama_sbatch sbatch, + std::vector<uint32_t> heads, + std::vector<llama_ubatch> ubatches); + + virtual ~llama_kv_cache_unified_state(); + + // + // llama_memory_state_i + // + + bool next() override; + bool apply() override; + + std::vector<int64_t> & out_ids() override; + + llama_memory_status get_status() const override; + const llama_ubatch & get_ubatch() const override; + + // + // llama_kv_cache_unified_state specific API + // + + uint32_t get_n_kv() const; + + // get views of the current state of the cache + ggml_tensor * get_k(ggml_context * ctx, int32_t il) const; + ggml_tensor * get_v(ggml_context * ctx, int32_t il) const; + + // store k_cur and v_cur in the cache based on the provided head location + ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const; + ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const; + + void set_input_k_shift(ggml_tensor * dst) const; + + void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * 
ubatch, bool causal_attn) const; + void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const; + +private: + const llama_memory_status status; + + llama_kv_cache_unified * kv; + + llama_sbatch sbatch; + + // the index of the next ubatch to process + size_t i_next = 0; + + std::vector<uint32_t> heads; + std::vector<llama_ubatch> ubatches; + + // + // data needed for building the compute graph for the current ubatch: + // + + // a heuristic, to avoid attending the full cache if it is not yet utilized + // as the cache gets filled, the benefit from this heuristic disappears + int32_t n_kv; + + // the beginning of the current slot in which the ubatch will be inserted + int32_t head; +}; diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 447c09c969baa..aefd23e324791 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -1,3081 +1 @@ #include "llama-kv-cache.h" - -#include "llama-impl.h" -#include "llama-batch.h" -#include "llama-cparams.h" -#include "llama-model.h" -#include "llama-context.h" - -#include <algorithm> -#include <cassert> -#include <cmath> -#include <limits> -#include <map> -#include <stdexcept> - -// -// llama_kv_cache_unified -// - -llama_kv_cache_unified::llama_kv_cache_unified( - const llama_model & model, - layer_filter_cb && filter, - ggml_type type_k, - ggml_type type_v, - bool v_trans, - bool offload, - uint32_t kv_size, - uint32_t n_seq_max, - uint32_t n_pad, - uint32_t n_swa, - llama_swa_type swa_type) : - model(model), hparams(model.hparams), v_trans(v_trans), - n_seq_max(n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) { - - GGML_ASSERT(kv_size % n_pad == 0); - - // create a context for each buffer type - std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map; - auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * { - auto it = ctx_map.find(buft); - if (it == ctx_map.end()) { - ggml_init_params params = { - /*.mem_size =*/ size_t(2u*hparams.n_layer*ggml_tensor_overhead()), - /*.mem_buffer =*/ NULL, - /*.no_alloc =*/ true, - }; - - ggml_context * ctx = ggml_init(params); - if (!ctx) { - return nullptr; - } - - ctx_map[buft] = ctx; - ctxs.emplace_back(ctx); - - return ctx; - } - - return it->second; - }; - - head = 0; - - cells.resize(kv_size); - - for (uint32_t il = 0; il < hparams.n_layer; il++) { - if (filter && !filter(il)) { - LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il); - continue; - } - - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); - - const char * dev_name = "CPU"; - - ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type(); - - if (offload) { - auto * dev = model.dev_layer(il); - buft = ggml_backend_dev_buffer_type(dev); - - dev_name = ggml_backend_dev_name(dev); - } - - LLAMA_LOG_DEBUG("%s: layer %3d: dev = %s\n", __func__, il, dev_name); - - ggml_context * ctx = ctx_for_buft(buft); - if (!ctx) { - throw std::runtime_error("failed to create ggml context for kv cache"); - } - - ggml_tensor * k; - ggml_tensor * v; - - k = ggml_new_tensor_2d(ctx, type_k, n_embd_k_gqa, kv_size); - v = ggml_new_tensor_2d(ctx, type_v, n_embd_v_gqa, kv_size); - - ggml_format_name(k, "cache_k_l%d", il); - ggml_format_name(v, "cache_v_l%d", il); - - map_layer_ids[il] = layers.size(); - layers.push_back({ il, k, v }); - } - - // allocate tensors and initialize the buffers to avoid NaNs in the padding - for (auto it : ctx_map) { - auto * buft = it.first; - auto * ctx = it.second; - - ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); - if (!buf) { - throw 
std::runtime_error("failed to allocate buffer for kv cache"); - } - - LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0); - - ggml_backend_buffer_clear(buf, 0); - bufs.emplace_back(buf); - } - - { - const size_t memory_size_k = size_k_bytes(); - const size_t memory_size_v = size_v_bytes(); - - LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u seqs), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__, - (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), kv_size, (int) layers.size(), n_seq_max, - ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f), - ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f)); - } -} - -void llama_kv_cache_unified::clear() { - cells.reset(); - - head = 0; - - for (auto & buf : bufs) { - ggml_backend_buffer_clear(buf.get(), 0); - } -} - -bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { - uint32_t new_head = cells.size(); - - if (p0 < 0) { - p0 = 0; - } - - if (p1 < 0) { - p1 = std::numeric_limits::max(); - } - - for (uint32_t i = 0; i < cells.size(); ++i) { - if (!cells.pos_in(i, p0, p1)) { - continue; - } - - if (cells.seq_has(i, seq_id) && cells.seq_rm(i, seq_id)) { - if (new_head == cells.size()) { - new_head = i; - } - } - } - - // If we freed up a slot, set head to it so searching can start there. - if (new_head != cells.size() && new_head < head) { - head = new_head; - } - - return true; -} - -void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { - if (seq_id_src == seq_id_dst) { - return; - } - - if (p0 < 0) { - p0 = 0; - } - - if (p1 < 0) { - p1 = std::numeric_limits::max(); - } - - for (uint32_t i = 0; i < cells.size(); ++i) { - if (!cells.pos_in(i, p0, p1)) { - continue; - } - - if (cells.seq_has(i, seq_id_src)) { - cells.seq_add(i, seq_id_dst); - } - } -} - -void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) { - uint32_t new_head = cells.size(); - - for (uint32_t i = 0; i < cells.size(); ++i) { - if (cells.seq_keep(i, seq_id)) { - if (new_head == cells.size()) { - new_head = i; - } - } - } - - // If we freed up a slot, set head to it so searching can start there. - if (new_head != cells.size() && new_head < head) { - head = new_head; - } -} - -void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) { - if (shift == 0) { - return; - } - - uint32_t new_head = cells.size(); - - if (p0 < 0) { - p0 = 0; - } - - if (p1 < 0) { - p1 = std::numeric_limits::max(); - } - - // If there is no range then return early to avoid looping over all cells. - if (p0 == p1) { - return; - } - - for (uint32_t i = 0; i < cells.size(); ++i) { - if (!cells.pos_in(i, p0, p1)) { - continue; - } - - if (cells.seq_has(i, seq_id)) { - if (cells.pos_add(i, shift)) { - if (new_head == cells.size()) { - new_head = i; - } - } - } - } - - // If we freed up a slot, set head to it so searching can start there. - // Otherwise we just start the next search from the beginning. - head = new_head != cells.size() ? new_head : 0; -} - -void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { - if (d == 1) { - return; - } - - if (p0 < 0) { - p0 = 0; - } - - if (p1 < 0) { - p1 = std::numeric_limits::max(); - } - - // If there is no range then return early to avoid looping over the cache. 
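All of the seq_* operations above share one range convention: [p0, p1) is half-open, and a negative bound means "unbounded" on that side. A standalone sketch of the normalization they apply (mirroring the code, not calling it):

```cpp
#include <cassert>
#include <cstdint>
#include <limits>

using llama_pos = int32_t;

static bool pos_in_range(llama_pos pos, llama_pos p0, llama_pos p1) {
    if (p0 < 0) { p0 = 0; }
    if (p1 < 0) { p1 = std::numeric_limits<llama_pos>::max(); }
    return pos >= p0 && pos < p1;
}

int main() {
    assert( pos_in_range(5, -1, -1)); // negative bounds select the whole sequence
    assert( pos_in_range(5,  5,  8)); // p0 is inclusive
    assert(!pos_in_range(8,  5,  8)); // p1 is exclusive
    return 0;
}
```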
- if (p0 == p1) { - return; - } - - for (uint32_t i = 0; i < cells.size(); ++i) { - if (!cells.pos_in(i, p0, p1)) { - continue; - } - - if (cells.seq_has(i, seq_id)) { - cells.pos_div(i, d); - } - } -} - -llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const { - return cells.seq_pos_min(seq_id); -} - -llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const { - return cells.seq_pos_max(seq_id); -} - -llama_memory_state_ptr llama_kv_cache_unified::init_batch( - const llama_batch & batch, - uint32_t n_ubatch, - bool embd_pooled, - bool logits_all) { - GGML_UNUSED(embd_pooled); - - auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all); - - std::vector<llama_ubatch> ubatches; - while (sbatch.n_tokens > 0) { - ubatches.push_back(sbatch.split_simple(n_ubatch)); - } - - auto heads = prepare(ubatches); - if (heads.empty()) { - return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE); - } - - return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_SUCCESS, - this, std::move(sbatch), std::move(heads), std::move(ubatches)); -} - -llama_memory_state_ptr llama_kv_cache_unified::init_full() { - return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_SUCCESS, this); -} - -std::vector<uint32_t> llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) { - std::vector<uint32_t> res; - - struct state { - uint32_t head_old; // old position of the head, before placing the ubatch - uint32_t head_new; // new position of the head, after placing the ubatch - - llama_kv_cells_unified cells; // copy of the old cells, before placing the ubatch - }; - - // remember the old state of the cells so we can restore it in the end - std::vector<state> states; - - bool success = true; - - for (const auto & ubatch : ubatches) { - // only find a suitable slot for the ubatch. don't modify the cells yet - const int32_t head_new = find_slot(ubatch); - if (head_new < 0) { - success = false; - break; - } - - // remember the position that we found - res.push_back(head_new); - - // store the old state of the cells in the recovery stack - states.push_back({head, (uint32_t) head_new, cells.cp(head_new, ubatch.n_tokens)}); - - // now emplace the ubatch - apply_ubatch(head_new, ubatch); - } - - // iterate backwards and restore the cells to their original state - for (auto it = states.rbegin(); it != states.rend(); ++it) { - cells.set(it->head_new, it->cells); - head = it->head_old; - } - - if (!success) { - return {}; - } - - return res; -} - -bool llama_kv_cache_unified::update(llama_context & lctx) { - bool updated = false; - - auto * sched = lctx.get_sched(); - - if (cells.get_has_shift()) { - if (!get_can_shift()) { - GGML_ABORT("The current KV cache / model configuration does not support K-shift"); - } - - LLAMA_LOG_DEBUG("%s: applying K-shift\n", __func__); - - // apply K-shift if needed - if (hparams.rope_type != LLAMA_ROPE_TYPE_NONE) { - ggml_backend_sched_reset(sched); - - auto * gf = lctx.graph_init(); - - auto res = build_graph_shift(lctx.get_cparams(), lctx.get_ctx_compute(), gf); - if (!res) { - LLAMA_LOG_ERROR("%s: failed to build graph for K-shift\n", __func__); - return updated; - } - - if (!ggml_backend_sched_alloc_graph(sched, gf)) { - LLAMA_LOG_ERROR("%s: failed to allocate compute graph for K-shift\n", __func__); - return updated; - } - - res->set_inputs(nullptr); - - if (lctx.graph_compute(gf, false) != GGML_STATUS_SUCCESS) { - LLAMA_LOG_ERROR("%s: failed to compute K-shift\n", __func__); - return updated; - } - - updated = true; - } - - cells.reset_shift(); - } - - if (do_defrag) { - LLAMA_LOG_DEBUG("%s: defragmenting KV 
cache\n", __func__); - - if (defrag_prepare(lctx.graph_max_nodes())) { - ggml_backend_sched_reset(sched); - - auto * gf = lctx.graph_init(); - - auto res = build_graph_defrag(lctx.get_cparams(), lctx.get_ctx_compute(), gf); - if (!res) { - LLAMA_LOG_ERROR("%s: failed to build graph for defrag\n", __func__); - return updated; - } - - if (!ggml_backend_sched_alloc_graph(sched, gf)) { - LLAMA_LOG_ERROR("%s: failed to allocate compute graph for defrag\n", __func__); - return updated; - } - - res->set_inputs(nullptr); - - if (lctx.graph_compute(gf, false) != GGML_STATUS_SUCCESS) { - LLAMA_LOG_ERROR("%s: failed to compute defrag\n", __func__); - return updated; - } - - updated = true; - } - - do_defrag = false; - } - - return updated; -} - -void llama_kv_cache_unified::defrag_sched(float thold) { - const auto n_kv = cells.used_max_p1(); - - // - do not defrag small contexts (i.e. < 2048 tokens) - // - count the padding towards the number of used tokens - const float fragmentation = n_kv >= 2048 ? std::max(0.0f, 1.0f - (float(cells.get_used() + n_pad)/n_kv)) : 0.0f; - - // queue defragmentation for next llama_kv_cache_update - if (fragmentation > thold) { - LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation); - - do_defrag = true; - } -} - -int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const { - const uint32_t n_tokens = ubatch.n_tokens; - - uint32_t head_cur = this->head; - - // if we have enough unused cells before the current head -> - // better to start searching from the beginning of the cache, hoping to fill it - if (head_cur > cells.get_used() + 2*ubatch.n_tokens) { - head_cur = 0; - } - - // otherwise, one cell per token. - - if (n_tokens > cells.size()) { - LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size()); - return -1; - } - -//#define FIND_SLOT_DEBUG 1 -#if FIND_SLOT_DEBUG - LLAMA_LOG_WARN("begin: n = %5d, used = %5d, head = %5d, n_swa = %5d\n", cells.used_max_p1(), cells.get_used(), head, n_swa); - - // for debugging - { - std::string ss; - if (n_swa > 0) { - for (uint32_t i = 0; i < cells.size(); ++i) { - if (cells.is_empty(i)) { - ss += '.'; - } else { - ss += std::to_string(cells.seq_get(i)); - } - if (i%256 == 255) { - ss += '\n'; - } - } - } - LLAMA_LOG_WARN("\n%s\n", ss.c_str()); - } - - for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) { - if (cells.seq_pos_min(s) < 0) { - continue; - } - - LLAMA_LOG_WARN("kv_cells: n_swa = %4d, min[%d] = %5d, max[%d] = %5d\n", n_swa, s, cells.seq_pos_min(s), s, cells.seq_pos_max(s)); - } -#endif - - uint32_t n_tested = 0; - - while (true) { - if (head_cur + n_tokens > cells.size()) { - n_tested += cells.size() - head_cur; - head_cur = 0; - continue; - } - - // keep track of what the minimum sequence positions would be if we accept the ubatch - llama_seq_id seq_pos_min[LLAMA_MAX_PARALLEL_SEQUENCES]; - for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) { - seq_pos_min[s] = cells.seq_pos_min(s); - } - - bool found = true; - for (uint32_t i = 0; i < n_tokens; i++) { - const llama_pos pos = ubatch.pos[i]; - const llama_seq_id seq_id = ubatch.seq_id[i][0]; - - // can we use this cell? 
either: - // - the cell is empty - // - the cell is occupied only by one sequence: - // - mask causally, if the sequence is the same as the one we are inserting - // - mask SWA, using current max pos for that sequence in the cache - // always insert in the cell with minimum pos - bool can_use = cells.is_empty(head_cur + i); - - if (!can_use && cells.seq_count(head_cur + i) == 1) { - const llama_pos pos_cell = cells.pos_get(head_cur + i); - - // causal mask - if (cells.seq_has(head_cur + i, seq_id)) { - can_use = pos_cell >= pos; - } - - if (!can_use) { - const llama_seq_id seq_id_cell = cells.seq_get(head_cur + i); - - // SWA mask - // note: we insert only in the cell with minimum pos in order to preserve the invariant that - // all positions between [pos_min, pos_max] for each sequence will be present in the cache - // ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092 - if (pos_cell == seq_pos_min[seq_id_cell] && - is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) { - seq_pos_min[seq_id_cell]++; - can_use = true; - } - } - } - - if (!can_use) { - found = false; - head_cur += i + 1; - n_tested += i + 1; - break; - } - } - - if (found) { - break; - } - - if (n_tested >= cells.size()) { - //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens); - return -1; - } - } - - return head_cur; -} - -void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch) { - for (uint32_t i = 0; i < ubatch.n_tokens; ++i) { - if (!cells.is_empty(head_cur + i)) { - cells.rm(head_cur + i); - } - - cells.pos_set(head_cur + i, ubatch.pos[i]); - - for (int32_t j = 0; j < ubatch.n_seq_id[i]; j++) { - cells.seq_add(head_cur + i, ubatch.seq_id[i][j]); - } - } - - // move the head at the end of the slot - head = head_cur + ubatch.n_tokens; -} - -bool llama_kv_cache_unified::get_can_shift() const { - return true; -} - -uint32_t llama_kv_cache_unified::get_size() const { - return cells.size(); -} - -uint32_t llama_kv_cache_unified::get_n_kv() const { - return std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad))); -} - -ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il, uint32_t n_kv) const { - const int32_t ikv = map_layer_ids.at(il); - - auto * k = layers[ikv].k; - - return ggml_view_3d(ctx, k, - hparams.n_embd_head_k, hparams.n_head_kv(il), n_kv, - ggml_row_size(k->type, hparams.n_embd_head_k), - ggml_row_size(k->type, hparams.n_embd_k_gqa(il)), - 0); -} - -ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const { - const int32_t ikv = map_layer_ids.at(il); - - auto * v = layers[ikv].v; - - if (!v_trans) { - // note: v->nb[1] <= v->nb[2] - return ggml_view_3d(ctx, v, - hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv, - ggml_row_size(v->type, hparams.n_embd_head_v), // v->nb[1] - ggml_row_size(v->type, hparams.n_embd_v_gqa(il)), // v->nb[2] - 0); - } - - // note: v->nb[1] > v->nb[2] - return ggml_view_3d(ctx, v, - n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v, - ggml_row_size(v->type, v->ne[1]*hparams.n_embd_head_v), // v->nb[1] - ggml_row_size(v->type, v->ne[1]), // v->nb[2] - 0); -} - -ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il, uint32_t head_cur) const { - const int32_t ikv = map_layer_ids.at(il); - - auto * k = layers[ikv].k; - - const int64_t n_tokens = k_cur->ne[2]; - - ggml_tensor * k_view = ggml_view_1d(ctx, k, - n_tokens*hparams.n_embd_k_gqa(il), - ggml_row_size(k->type, 
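get_n_kv() above trades a little extra attention work for a padded, FA-friendly range. A worked check of the clamping with assumed sizes (pad() reproduces GGML_PAD from ggml.h):

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

static uint32_t pad(uint32_t x, uint32_t n) { return (x + n - 1) & ~(n - 1); }

int main() {
    const uint32_t size = 4096, n_pad = 32;

    auto n_kv = [&](uint32_t used_max_p1) {
        return std::min(size, std::max(n_pad, pad(used_max_p1, n_pad)));
    };

    assert(n_kv(0)    == 32);   // an empty cache still attends one padded block
    assert(n_kv(100)  == 128);  // 100 rounded up to the next multiple of 32
    assert(n_kv(4096) == 4096); // clamped to the cache size
    return 0;
}
```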
hparams.n_embd_k_gqa(il))*head_cur); - - return ggml_cpy(ctx, k_cur, k_view); -} - -ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il, uint32_t head_cur) const { - const int32_t ikv = map_layer_ids.at(il); - - auto * v = layers[ikv].v; - - const int64_t n_tokens = v_cur->ne[2]; - - v_cur = ggml_reshape_2d(ctx, v_cur, hparams.n_embd_v_gqa(il), n_tokens); - - ggml_tensor * v_view = nullptr; - - if (!v_trans) { - v_view = ggml_view_1d(ctx, v, - n_tokens*hparams.n_embd_v_gqa(il), - ggml_row_size(v->type, hparams.n_embd_v_gqa(il))*head_cur); - } else { - // note: the V cache is transposed when not using flash attention - v_view = ggml_view_2d(ctx, v, n_tokens, hparams.n_embd_v_gqa(il), - (v->ne[1])*ggml_element_size(v), - (head_cur)*ggml_element_size(v)); - - v_cur = ggml_transpose(ctx, v_cur); - } - - return ggml_cpy(ctx, v_cur, v_view); -} - -void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const { - const int64_t n_tokens = ubatch->n_tokens; - const int64_t n_seq_tokens = ubatch->n_seq_tokens; - const int64_t n_seqs = ubatch->n_seqs; - - GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); - float * data = (float *) dst->data; - - const auto n_kv = dst->ne[0]; - - // Use only the previous KV cells of the correct sequence for each token of the ubatch. - // It's assumed that if a token in the batch has multiple sequences, they are equivalent. - // Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch: - // Causal mask: - // xxx------- - // xxxx------ - // xxxxx----- - // Non-causal mask: - // xxxxx----- - // xxxxx----- - // xxxxx----- - // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615 - for (int h = 0; h < 1; ++h) { - for (int s = 0; s < n_seqs; ++s) { - const llama_seq_id seq_id = ubatch->seq_id[s][0]; - - for (int j = 0; j < n_seq_tokens; ++j) { - const llama_pos p1 = ubatch->pos[s*n_seq_tokens + j]; - - for (uint32_t i = 0; i < n_kv; ++i) { - float f = 0.0f; - - bool masked = false; - - if (cells.is_empty(i)) { - masked = true; - } else { - const llama_pos p0 = cells.pos_get(i); - - // mask the token if not the same sequence - masked = masked || (!cells.seq_has(i, seq_id)); - - // mask future tokens - masked = masked || (causal_attn && p0 > p1); - - // apply SWA if any - masked = masked || (is_masked_swa(p0, p1)); - - if (!masked && hparams.use_alibi) { - f = -std::abs(p0 - p1); - } - } - - if (masked) { - f = -INFINITY; - } - - data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f; - } - } - } - - // mask padded tokens - if (data) { - for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) { - for (uint32_t j = 0; j < n_kv; ++j) { - data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY; - } - } - } - } -} - -void llama_kv_cache_unified::set_input_k_shift(ggml_tensor * dst) const { - GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); - - int32_t * data = (int32_t *) dst->data; - - for (uint32_t i = 0; i < cells.size(); ++i) { - data[i] = cells.is_empty(i) ? 
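The masking rules in set_input_kq_mask() reduce, per (cell, token) pair, to a short decision chain. A standalone sketch of the causal part (SWA and ALiBi omitted; see the full code above):

```cpp
#include <cmath>
#include <cstdint>

using llama_pos = int32_t;

// -INFINITY removes the cell from the softmax; 0.0f leaves it attended
static float kq_mask_value(bool cell_empty, bool same_seq,
                           llama_pos p0 /* cell */, llama_pos p1 /* token */,
                           bool causal_attn) {
    const bool masked = cell_empty
                     || !same_seq                 // cell belongs to another sequence
                     || (causal_attn && p0 > p1); // cell is in the token's future
    return masked ? -INFINITY : 0.0f;
}

int main() {
    // a token at pos 5 may attend cells of its own sequence at pos <= 5
    return kq_mask_value(false, true, 3, 5, true) == 0.0f &&
           kq_mask_value(false, true, 7, 5, true) == -INFINITY ? 0 : 1;
}
```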
0 : cells.get_shift(i); - } -} - -void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const { - const int64_t n_tokens = ubatch->n_tokens; - - GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); - GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing - - int32_t * data = (int32_t *) dst->data; - - const int32_t n_kv = dst->ne[0]; - - for (int h = 0; h < 1; ++h) { - for (int j = 0; j < n_tokens; ++j) { - for (int i = 0; i < n_kv; ++i) { - // the position when the cells is empty is irrelevant - it will be masked out later in the attention - const llama_pos p0 = cells.is_empty(i) ? -1 : cells.pos_get(i); - - data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(p0, ubatch->pos[j], hparams.n_rel_attn_bkts, false); - } - } - } -} - -size_t llama_kv_cache_unified::total_size() const { - size_t size = 0; - - for (const auto & buf : bufs) { - size += ggml_backend_buffer_get_size(buf.get()); - } - - return size; -} - -size_t llama_kv_cache_unified::size_k_bytes() const { - size_t size_k_bytes = 0; - - for (const auto & layer : layers) { - size_k_bytes += ggml_nbytes(layer.k); - } - - return size_k_bytes; -} - -size_t llama_kv_cache_unified::size_v_bytes() const { - size_t size_v_bytes = 0; - - for (const auto & layer : layers) { - size_v_bytes += ggml_nbytes(layer.v); - } - - return size_v_bytes; -} - -ggml_tensor * llama_kv_cache_unified::build_rope_shift( - const llama_cparams & cparams, - ggml_context * ctx, - ggml_tensor * cur, - ggml_tensor * shift, - ggml_tensor * factors, - float freq_base, - float freq_scale) const { - const auto & n_ctx_orig = cparams.n_ctx_orig_yarn; - - const auto & yarn_ext_factor = cparams.yarn_ext_factor; - const auto & yarn_beta_fast = cparams.yarn_beta_fast; - const auto & yarn_beta_slow = cparams.yarn_beta_slow; - - const auto & n_rot = hparams.n_rot; - const auto & rope_type = hparams.rope_type == LLAMA_ROPE_TYPE_MROPE - // @ngxson : this is a workaround - // for M-RoPE, we want to rotate the whole vector when doing KV shift - // a normal RoPE should work, we just need to use the correct ordering - // ref: https://github.com/ggml-org/llama.cpp/pull/13870 - ? LLAMA_ROPE_TYPE_NEOX - : hparams.rope_type; - - // See llm_build_deepseek2() for why attn_factor has to be scaled for YaRN RoPE to work correctly. - // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation. - const float yarn_attn_factor = model.arch == LLM_ARCH_DEEPSEEK2 - ? 
1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)) - : cparams.yarn_attn_factor; - - ggml_tensor * tmp; - - if (ggml_is_quantized(cur->type)) { - // dequantize to f32 -> RoPE -> quantize back - tmp = ggml_cast(ctx, cur, GGML_TYPE_F32); - - tmp = ggml_rope_ext(ctx, tmp, - shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow); - - tmp = ggml_cpy(ctx, tmp, cur); - } else { - // we rotate only the first n_rot dimensions - tmp = ggml_rope_ext_inplace(ctx, cur, - shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow); - } - - return tmp; -} - -class llm_graph_input_k_shift : public llm_graph_input_i { -public: - llm_graph_input_k_shift(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {} - virtual ~llm_graph_input_k_shift() = default; - - void set_input(const llama_ubatch * ubatch) override; - - ggml_tensor * k_shift; // I32 [kv_size] - - const llama_kv_cache_unified * kv_self; -}; - -void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) { - GGML_UNUSED(ubatch); - - if (k_shift) { - kv_self->set_input_k_shift(k_shift); - } -} - -llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift( - const llama_cparams & cparams, - ggml_context * ctx, - ggml_cgraph * gf) const { - auto res = std::make_unique(); - - const auto & n_embd_head_k = hparams.n_embd_head_k; - //const auto & n_embd_head_v = hparams.n_embd_head_v; - - //GGML_ASSERT(kv_self->size == n_ctx); - - auto inp = std::make_unique(this); - - inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, cparams.n_ctx); - ggml_set_input(inp->k_shift); - - for (const auto & layer : layers) { - const uint32_t il = layer.il; - - const int64_t n_head_kv = hparams.n_head_kv(il); - const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); - - const float freq_base_l = model.get_rope_freq_base (cparams, il); - const float freq_scale_l = model.get_rope_freq_scale(cparams, il); - - ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); - - ggml_tensor * k = - ggml_view_3d(ctx, layer.k, - n_embd_head_k, n_head_kv, cells.size(), - ggml_row_size(layer.k->type, n_embd_head_k), - ggml_row_size(layer.k->type, n_embd_k_gqa), - 0); - - ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l); - - ggml_build_forward_expand(gf, cur); - } - - res->add_input(std::move(inp)); - - return res; -} - -llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag( - const llama_cparams & cparams, - ggml_context * ctx, - ggml_cgraph * gf) const { - auto res = std::make_unique(); - - const auto & ids = defrag_info.ids; - -#if 0 - // CPU defrag - // - // TODO: optimizations are possible: - // - multiple threads - // - avoid copying to the host memory when already there - // - // likely not worth the effort, as we have ggml_graph based defrag - // - - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(); - - const uint32_t kv_size = size; - - std::vector buf_k; - std::vector buf_v; - - for (uint32_t il = 0; il < n_layer; ++il) { - const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa); - const size_t k_size = ggml_row_size(k_l[il]->type, n_embd_k_gqa*kv_size); - - const size_t v_size_el = ggml_type_size(v_l[il]->type); - const size_t v_size = ggml_row_size (v_l[il]->type, n_embd_v_gqa*kv_size); - - buf_k.resize(k_size); - buf_v.resize(v_size); - - 
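The DeepSeek2-specific YaRN attention factor computed earlier in this hunk is easy to sanity-check numerically (the freq_scale value here is assumed):

```cpp
#include <cmath>
#include <cstdio>

int main() {
    const float freq_scale = 0.25f; // e.g. a 4x context extension

    const float yarn_attn_factor = 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale));

    printf("%.4f\n", yarn_attn_factor); // ~0.878: attention is slightly damped
    return 0;
}
```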
ggml_backend_tensor_get(k_l[il], buf_k.data(), 0, buf_k.size()); - ggml_backend_tensor_get(v_l[il], buf_v.data(), 0, buf_v.size()); - - // batch move [i, i+nm) to [id, id+nm) - // note: cells can move only to a lower index - for (uint32_t i = 0; i < n_kv; ++i) { - const uint32_t id = ids[i]; - - if (i == id || id == n_kv) { - continue; - } - - uint32_t nm = 1; - - while (i + nm < n_kv && ids[i + nm] == id + nm) { - nm++; - } - - // move keys - { - const int64_t os = i*k_size_row; - const int64_t od = id*k_size_row; - - memcpy(buf_k.data() + od, buf_k.data() + os, nm*k_size_row); - } - - // move values (note: they are transposed) - { - const int64_t os = i; - const int64_t od = id; - - for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { - memcpy(buf_v.data() + (od + j*kv_size)*v_size_el, buf_v.data() + (os + j*kv_size)*v_size_el, nm*v_size_el); - } - } - - i += nm - 1; - } - - ggml_backend_tensor_set(k_l[il], buf_k.data(), 0, buf_k.size()); - ggml_backend_tensor_set(v_l[il], buf_v.data(), 0, buf_v.size()); - } -#else - for (uint32_t i = 0; i < ids.size(); ++i) { - const uint32_t id = ids[i]; - - if (i == id || id == ids.size()) { - continue; - } - - uint32_t nm = 1; - - while (i + nm < ids.size() && ids[i + nm] == id + nm) { - nm++; - } - - for (const auto & layer : layers) { - const uint32_t il = layer.il; - - const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); - const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); - - ggml_tensor * view_k_src = ggml_view_2d(ctx, layer.k, - n_embd_k_gqa, nm, - ggml_row_size(layer.k->type, n_embd_k_gqa), - ggml_row_size(layer.k->type, n_embd_k_gqa*i)); - - ggml_tensor * view_k_dst = ggml_view_2d(ctx, layer.k, - n_embd_k_gqa, nm, - ggml_row_size(layer.k->type, n_embd_k_gqa), - ggml_row_size(layer.k->type, n_embd_k_gqa*id)); - - ggml_tensor * view_v_src; - ggml_tensor * view_v_dst; - - if (cparams.flash_attn) { - // NOTE: the V cache is not transposed when using flash attention - view_v_src = ggml_view_2d(ctx, layer.v, - n_embd_v_gqa, nm, - ggml_row_size(layer.v->type, n_embd_v_gqa), - ggml_row_size(layer.v->type, n_embd_v_gqa*i)); - - view_v_dst = ggml_view_2d(ctx, layer.v, - n_embd_v_gqa, nm, - ggml_row_size(layer.v->type, n_embd_v_gqa), - ggml_row_size(layer.v->type, n_embd_v_gqa*id)); - } else { - view_v_src = ggml_view_2d(ctx, layer.v, - nm, n_embd_v_gqa, - ggml_row_size(layer.v->type, cells.size()), - ggml_row_size(layer.v->type, i)); - - view_v_dst = ggml_view_2d(ctx, layer.v, - nm, n_embd_v_gqa, - ggml_row_size(layer.v->type, cells.size()), - ggml_row_size(layer.v->type, id)); - } - - ggml_build_forward_expand(gf, ggml_cpy(ctx, view_k_src, view_k_dst)); - ggml_build_forward_expand(gf, ggml_cpy(ctx, view_v_src, view_v_dst)); - } - - i += nm - 1; - } - - //LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes); -#endif - - return res; -} - -bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) { - const uint32_t n_layer = layers.size(); - - const uint32_t n_kv = cells.used_max_p1(); - const uint32_t n_used = cells.get_used(); - - assert(n_used <= n_kv); - - //const int64_t t_start = ggml_time_us(); - - // number of cells moved - uint32_t n_moves = 0; - - // each move requires 6*n_layer tensors (see graph_build_kv_self_defrag) - // - source view, destination view, copy operation - // - x2 for keys and values - //const uint32_t max_moves = max_nodes()/(6*n_layer); - // TODO: tmp fix https://github.com/ggerganov/llama.cpp/issues/6685#issuecomment-2057579516 - const uint32_t max_moves = (n_max_nodes - 2*n_layer)/(6*n_layer); - - // determine which KV 
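The ids[] mapping consumed by the defrag copies above (and built by defrag_prepare() just below) is easiest to see on a tiny cache. A worked example with assumed occupancy:

```cpp
// cell i moves to ids[i]; ids[i] == i or ids[i] == n_kv means "stays put".
// Cache layout (x = used, . = empty):  x . . x x
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
    const uint32_t n_kv = 5;
    // the hole [1, 3) is filled from the end: cells 3 and 4 move down into it
    const std::vector<uint32_t> ids = {0, n_kv, n_kv, 1, 2};

    assert(ids[0] == 0);    // already in place
    assert(ids[3] == 1);    // moves down into the hole
    assert(ids[4] == 2);
    assert(ids[1] == n_kv); // empty cells receive data; they are not sources
    return 0;
}
```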
cells to move where - // - // cell i moves to ids[i] - // - // if ids[i] == i || ids[i] == n_kv, then cell i is not moved - // - auto & ids = defrag_info.ids; - - ids.clear(); - ids.resize(n_kv, n_kv); - - for (uint32_t i0 = 0; i0 < n_used; ++i0) { - if (!cells.is_empty(i0)) { - ids[i0] = i0; - - continue; - } - - // found a hole - fill it with data from the end of the cache - - uint32_t nh = 1; - - // determine the size of the hole - while (i0 + nh < n_used && cells.is_empty(i0 + nh)) { - nh++; - } - - uint32_t nf = 0; - uint32_t is = n_kv - 1; - - // starting from the end, find nh non-empty cells - for (; is > i0; --is) { - if (cells.is_empty(is) || ids[is] != n_kv) { - continue; - } - - // non-empty cell which is not yet moved - nf++; - - if (nf == nh) { - break; - } - } - - // this can only happen if `n_used` is not accurate, which would be a bug - GGML_ASSERT(nf == nh && "KV defrag bug: nf != nh"); - - nf = 0; - - uint32_t i1 = is; - - // are we moving a continuous block of memory? - bool cont = false; - - // should we stop searching for the next move? - bool stop = false; - - // go back and move the nf cells to the hole - for (; i1 < n_kv; ++i1) { - if (cells.is_empty(i1) || ids[i1] != n_kv) { - if (n_moves == max_moves) { - stop = true; - break; - } - - cont = false; - continue; - } - - // this cell goes to (i0 + nf) - ids[i1] = i0 + nf; - - // move the cell meta data - cells.mv(i1, i0 + nf); - - head = n_used; - - if (!cont) { - n_moves++; - cont = true; - } - - nf++; - - if (nf == nh) { - break; - } - } - - if (stop || n_moves == max_moves) { - break; - } - - //LLAMA_LOG_INFO("(tmp log) KV defrag: move [%u, %u) to [%u, %u)\n", is, i1 + 1, i0, i0 + nh); - - i0 += nh - 1; - } - - if (n_moves == 0) { - return false; - } - - LLAMA_LOG_DEBUG("%s: (tmp log) KV defrag cell moves: %u\n", __func__, n_moves); - - LLAMA_LOG_DEBUG("%s: expected gf nodes: %u\n", __func__, 6*n_moves*n_layer); - - return true; -} - -bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const { - assert(p0 >= 0 && p1 >= 0); - - switch (swa_type) { - case LLAMA_SWA_TYPE_NONE: - { - } break; - case LLAMA_SWA_TYPE_STANDARD: - { - if (p1 - p0 >= (int32_t) n_swa) { - return true; - } - } break; - case LLAMA_SWA_TYPE_CHUNKED: - { - const llama_pos pos_chunk_start = (p1 / n_swa) * n_swa; - - if (p0 < pos_chunk_start) { - return true; - } - } break; - } - - return false; -} - -void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id) const { - std::vector> cell_ranges; // ranges, from inclusive, to exclusive - uint32_t cell_count = 0; - - // Count the number of cells with the specified seq_id - // Find all the ranges of cells with this seq id (or all, when -1) - uint32_t cell_range_begin = cells.size(); - - for (uint32_t i = 0; i < cells.size(); ++i) { - if (!cells.is_empty(i) && (seq_id == -1 || cells.seq_has(i, seq_id))) { - ++cell_count; - if (cell_range_begin == cells.size()) { - cell_range_begin = i; - } - } else { - if (cell_range_begin != cells.size()) { - cell_ranges.emplace_back(cell_range_begin, i); - cell_range_begin = cells.size(); - } - } - } - - if (cell_range_begin != cells.size()) { - cell_ranges.emplace_back(cell_range_begin, cells.size()); - } - - // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count - uint32_t cell_count_check = 0; - for (const auto & range : cell_ranges) { - cell_count_check += range.second - range.first; - } - GGML_ASSERT(cell_count == cell_count_check); - - io.write(&cell_count, sizeof(cell_count)); - - 
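is_masked_swa() above implements two window styles. A standalone check of both rules with an assumed window of n_swa = 4:

```cpp
#include <cassert>
#include <cstdint>

using llama_pos = int32_t;

int main() {
    const llama_pos n_swa = 4;

    // standard sliding window: cell p0 is hidden from token p1 once it leaves the window
    auto masked_std = [&](llama_pos p0, llama_pos p1) { return p1 - p0 >= n_swa; };
    assert(!masked_std(6, 9)); // still inside the window
    assert( masked_std(5, 9)); // exactly n_swa away -> masked

    // chunked window (e.g. Llama 4): everything before the current chunk start is masked
    auto masked_chunked = [&](llama_pos p0, llama_pos p1) { return p0 < (p1/n_swa)*n_swa; };
    assert(!masked_chunked(8, 9)); // same chunk [8, 12)
    assert( masked_chunked(7, 9)); // previous chunk
    return 0;
}
```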
state_write_meta(io, cell_ranges, seq_id); - state_write_data(io, cell_ranges); -} - -void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_id) { - uint32_t cell_count; - io.read_to(&cell_count, sizeof(cell_count)); - - bool res = true; - res = res && state_read_meta(io, cell_count, seq_id); - res = res && state_read_data(io, cell_count); - - if (!res) { - if (seq_id == -1) { - clear(); - } else { - seq_rm(seq_id, -1, -1); - } - throw std::runtime_error("failed to restore kv cache"); - } -} - -void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const std::vector> & cell_ranges, llama_seq_id seq_id) const { - for (const auto & range : cell_ranges) { - for (uint32_t i = range.first; i < range.second; ++i) { - std::vector seq_ids; - - for (llama_seq_id cur = 0; cur < (int) n_seq_max; ++cur) { - if (cur == seq_id || seq_id == -1) { - if (cells.seq_has(i, cur)) { - seq_ids.push_back(cur); - } - } - } - - const llama_pos pos = cells.pos_get(i); - const uint32_t n_seq_id = seq_ids.size(); - - io.write(&pos, sizeof(pos)); - io.write(&n_seq_id, sizeof(n_seq_id)); - - for (const auto & seq_id : seq_ids) { - io.write(&seq_id, sizeof(seq_id)); - } - } - } -} - -void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::vector> & cell_ranges) const { - const uint32_t v_trans = this->v_trans ? 1 : 0; - const uint32_t n_layer = layers.size(); - - io.write(&v_trans, sizeof(v_trans)); - io.write(&n_layer, sizeof(n_layer)); - - std::vector tmp_buf; - - // Iterate and write all the keys first, each row is a cell - // Get whole range at a time - for (const auto & layer : layers) { - const uint32_t il = layer.il; - - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); - - // Write key type - const int32_t k_type_i = (int32_t)layer.k->type; - io.write(&k_type_i, sizeof(k_type_i)); - - // Write row size of key - const uint64_t k_size_row = ggml_row_size(layer.k->type, n_embd_k_gqa); - io.write(&k_size_row, sizeof(k_size_row)); - - // Read each range of cells of k_size length each into tmp_buf and write out - for (const auto & range : cell_ranges) { - const size_t range_size = range.second - range.first; - const size_t buf_size = range_size * k_size_row; - io.write_tensor(layer.k, range.first * k_size_row, buf_size); - } - } - - if (!v_trans) { - for (const auto & layer : layers) { - const uint32_t il = layer.il; - - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); - - // Write value type - const int32_t v_type_i = (int32_t)layer.v->type; - io.write(&v_type_i, sizeof(v_type_i)); - - // Write row size of value - const uint64_t v_size_row = ggml_row_size(layer.v->type, n_embd_v_gqa); - io.write(&v_size_row, sizeof(v_size_row)); - - // Read each range of cells of v_size length each into tmp_buf and write out - for (const auto & range : cell_ranges) { - const size_t range_size = range.second - range.first; - const size_t buf_size = range_size * v_size_row; - io.write_tensor(layer.v, range.first * v_size_row, buf_size); - } - } - } else { - // When v is transposed, we also need the element size and get the element ranges from each row - const uint32_t kv_size = cells.size(); - - for (const auto & layer : layers) { - const uint32_t il = layer.il; - - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); - - // Write value type - const int32_t v_type_i = (int32_t)layer.v->type; - io.write(&v_type_i, sizeof(v_type_i)); - - // Write element size - const uint32_t v_size_el = 
ggml_type_size(layer.v->type); - io.write(&v_size_el, sizeof(v_size_el)); - - // Write GQA embedding size - io.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa)); - - // For each row, we get the element values of each cell - for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { - // Read each range of cells of v_size_el length each into tmp_buf and write out - for (const auto & range : cell_ranges) { - const size_t range_size = range.second - range.first; - const size_t src_offset = (range.first + j * kv_size) * v_size_el; - const size_t buf_size = range_size * v_size_el; - io.write_tensor(layer.v, src_offset, buf_size); - } - } - } - } -} - -bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) { - if (dest_seq_id != -1) { - // single sequence - - seq_rm(dest_seq_id, -1, -1); - - llama_sbatch sbatch; - llama_ubatch batch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false); - - batch.n_tokens = cell_count; - - for (uint32_t i = 0; i < cell_count; ++i) { - llama_pos pos; - uint32_t n_seq_id; - - io.read_to(&pos, sizeof(pos)); - io.read_to(&n_seq_id, sizeof(n_seq_id)); - - if (n_seq_id != 1) { - LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__); - return false; - } - - // read the sequence id, but directly discard it - we will use dest_seq_id instead - { - llama_seq_id seq_id; - io.read_to(&seq_id, sizeof(seq_id)); - } - - batch.pos[i] = pos; - batch.n_seq_id[i] = n_seq_id; - batch.seq_id[i] = &dest_seq_id; - } - - const auto head_cur = find_slot(batch); - if (head_cur < 0) { - LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__); - return false; - } - - apply_ubatch(head_cur, batch); - - // keep the head at the old position because we will read the KV data into it in state_read_data() - head = head_cur; - - // DEBUG CHECK: head_cur should be our first cell, head_cur + cell_count - 1 should be our last cell (verify seq_id and pos values) - // Assume that this is one contiguous block of cells - GGML_ASSERT(head_cur + cell_count <= cells.size()); - GGML_ASSERT(cells.pos_get(head_cur) == batch.pos[0]); - GGML_ASSERT(cells.pos_get(head_cur + cell_count - 1) == batch.pos[cell_count - 1]); - GGML_ASSERT(cells.seq_has(head_cur, dest_seq_id)); - GGML_ASSERT(cells.seq_has(head_cur + cell_count - 1, dest_seq_id)); - } else { - // whole KV cache restore - - if (cell_count > cells.size()) { - LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__); - return false; - } - - clear(); - - for (uint32_t i = 0; i < cell_count; ++i) { - llama_pos pos; - uint32_t n_seq_id; - - io.read_to(&pos, sizeof(pos)); - io.read_to(&n_seq_id, sizeof(n_seq_id)); - - cells.pos_set(i, pos); - - for (uint32_t j = 0; j < n_seq_id; ++j) { - llama_seq_id seq_id; - io.read_to(&seq_id, sizeof(seq_id)); - - if (seq_id < 0 || (uint32_t) seq_id >= n_seq_max) { - LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, n_seq_max); - return false; - } - - cells.seq_add(i, seq_id); - } - } - - head = 0; - } - - return true; -} - -bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell_count) { - uint32_t v_trans; - uint32_t n_layer; - - io.read_to(&v_trans, sizeof(v_trans)); - io.read_to(&n_layer, sizeof(n_layer)); - - if (n_layer != layers.size()) { - LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, (uint32_t) layers.size()); - return false; - } - - if (cell_count > cells.size()) { - LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state 
(%u > %u)\n", __func__, cell_count, cells.size()); - return false; - } - - if (this->v_trans != (bool) v_trans) { - LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__); - return false; - } - - // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block - for (const auto & layer : layers) { - const uint32_t il = layer.il; - - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); - - // Read type of key - int32_t k_type_i_ref; - io.read_to(&k_type_i_ref, sizeof(k_type_i_ref)); - const int32_t k_type_i = (int32_t) layer.k->type; - if (k_type_i != k_type_i_ref) { - LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il); - return false; - } - - // Read row size of key - uint64_t k_size_row_ref; - io.read_to(&k_size_row_ref, sizeof(k_size_row_ref)); - const size_t k_size_row = ggml_row_size(layer.k->type, n_embd_k_gqa); - if (k_size_row != k_size_row_ref) { - LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il); - return false; - } - - if (cell_count) { - // Read and set the keys for the whole cell range - ggml_backend_tensor_set(layer.k, io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row); - } - } - - if (!this->v_trans) { - for (const auto & layer : layers) { - const uint32_t il = layer.il; - - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); - - // Read type of value - int32_t v_type_i_ref; - io.read_to(&v_type_i_ref, sizeof(v_type_i_ref)); - const int32_t v_type_i = (int32_t)layer.v->type; - if (v_type_i != v_type_i_ref) { - LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); - return false; - } - - // Read row size of value - uint64_t v_size_row_ref; - io.read_to(&v_size_row_ref, sizeof(v_size_row_ref)); - const size_t v_size_row = ggml_row_size(layer.v->type, n_embd_v_gqa); - if (v_size_row != v_size_row_ref) { - LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il); - return false; - } - - if (cell_count) { - // Read and set the values for the whole cell range - ggml_backend_tensor_set(layer.v, io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row); - } - } - } else { - // For each layer, read the values for each cell (transposed) - for (const auto & layer : layers) { - const uint32_t il = layer.il; - - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); - - // Read type of value - int32_t v_type_i_ref; - io.read_to(&v_type_i_ref, sizeof(v_type_i_ref)); - const int32_t v_type_i = (int32_t)layer.v->type; - if (v_type_i != v_type_i_ref) { - LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); - return false; - } - - // Read element size of value - uint32_t v_size_el_ref; - io.read_to(&v_size_el_ref, sizeof(v_size_el_ref)); - const size_t v_size_el = ggml_type_size(layer.v->type); - if (v_size_el != v_size_el_ref) { - LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il); - return false; - } - - // Read GQA embedding size - uint32_t n_embd_v_gqa_ref; - io.read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref)); - if (n_embd_v_gqa != n_embd_v_gqa_ref) { - LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, 
n_embd_v_gqa_ref, il); - return false; - } - - if (cell_count) { - // For each row in the transposed matrix, read the values for the whole cell range - for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { - const size_t dst_offset = (head + j * cells.size()) * v_size_el; - ggml_backend_tensor_set(layer.v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el); - } - } - } - } - - return true; -} - -// -// llama_kv_cache_unified_state -// - -llama_kv_cache_unified_state::llama_kv_cache_unified_state(llama_memory_status status) : status(status) {} - -llama_kv_cache_unified_state::llama_kv_cache_unified_state( - llama_memory_status status, - llama_kv_cache_unified * kv) : status(status), kv(kv) { - n_kv = kv->get_size(); - head = 0; - } - -llama_kv_cache_unified_state::llama_kv_cache_unified_state( - llama_memory_status status, - llama_kv_cache_unified * kv, - llama_sbatch sbatch, - std::vector heads, - std::vector ubatches) - : status(status), - kv(kv), - sbatch(std::move(sbatch)), - heads(std::move(heads)), - ubatches(std::move(ubatches)) { - } - -llama_kv_cache_unified_state::~llama_kv_cache_unified_state() = default; - -bool llama_kv_cache_unified_state::next() { - assert(status == LLAMA_MEMORY_STATUS_SUCCESS); - - if (++i_next >= ubatches.size()) { - return false; - } - - return true; -} - -bool llama_kv_cache_unified_state::apply() { - assert(status == LLAMA_MEMORY_STATUS_SUCCESS); - - kv->apply_ubatch(heads[i_next], ubatches[i_next]); - - n_kv = kv->get_n_kv(); - head = heads[i_next]; - - return true; -} - -std::vector & llama_kv_cache_unified_state::out_ids() { - assert(status == LLAMA_MEMORY_STATUS_SUCCESS); - - return sbatch.out_ids; -} - -llama_memory_status llama_kv_cache_unified_state::get_status() const { - return status; -} - -const llama_ubatch & llama_kv_cache_unified_state::get_ubatch() const { - assert(status == LLAMA_MEMORY_STATUS_SUCCESS); - - return ubatches[i_next]; -} - -uint32_t llama_kv_cache_unified_state::get_n_kv() const { - return n_kv; -} - -ggml_tensor * llama_kv_cache_unified_state::get_k(ggml_context * ctx, int32_t il) const { - return kv->get_k(ctx, il, n_kv); -} - -ggml_tensor * llama_kv_cache_unified_state::get_v(ggml_context * ctx, int32_t il) const { - return kv->get_v(ctx, il, n_kv); -} - -ggml_tensor * llama_kv_cache_unified_state::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const { - return kv->cpy_k(ctx, k_cur, il, head); -} - -ggml_tensor * llama_kv_cache_unified_state::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const { - return kv->cpy_v(ctx, v_cur, il, head); -} - -void llama_kv_cache_unified_state::set_input_k_shift(ggml_tensor * dst) const { - kv->set_input_k_shift(dst); -} - -void llama_kv_cache_unified_state::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const { - kv->set_input_kq_mask(dst, ubatch, causal_attn); -} - -void llama_kv_cache_unified_state::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const { - kv->set_input_pos_bucket(dst, ubatch); -} - -uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) { - // the FA kernels require padding to avoid extra runtime boundary checks - return cparams.flash_attn ? 
256u : 32u; -} - -// -// llama_kv_cache_unified_iswa -// - -llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa( - const llama_model & model, - ggml_type type_k, - ggml_type type_v, - bool v_trans, - bool offload, - bool swa_full, - uint32_t kv_size, - uint32_t n_seq_max, - uint32_t n_ubatch, - uint32_t n_pad) : hparams(model.hparams) { - llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); }; - llama_kv_cache_unified::layer_filter_cb filter_swa = [&](int32_t il) { return model.hparams.is_swa(il); }; - - const uint32_t size_base = kv_size; - - uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_ubatch, n_pad)); - - // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size - if (swa_full) { - LLAMA_LOG_WARN("%s: using full-size SWA cache (ref: %s)\n", - __func__, "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055"); - - size_swa = size_base; - } - - LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base); - - kv_base = std::make_unique( - model, std::move(filter_base), type_k, type_v, - v_trans, offload, size_base, n_seq_max, n_pad, - 0, LLAMA_SWA_TYPE_NONE); - - LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa); - - kv_swa = std::make_unique( - model, std::move(filter_swa), type_k, type_v, - v_trans, offload, size_swa, n_seq_max, n_pad, - hparams.n_swa, hparams.swa_type); -} - -void llama_kv_cache_unified_iswa::clear() { - kv_base->clear(); - kv_swa ->clear(); -} - -bool llama_kv_cache_unified_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { - bool res = true; - - res = res & kv_base->seq_rm(seq_id, p0, p1); - res = res & kv_swa ->seq_rm(seq_id, p0, p1); - - return res; -} - -void llama_kv_cache_unified_iswa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { - kv_base->seq_cp(seq_id_src, seq_id_dst, p0, p1); - kv_swa ->seq_cp(seq_id_src, seq_id_dst, p0, p1); -} - -void llama_kv_cache_unified_iswa::seq_keep(llama_seq_id seq_id) { - kv_base->seq_keep(seq_id); - kv_swa ->seq_keep(seq_id); -} - -void llama_kv_cache_unified_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) { - kv_base->seq_add(seq_id, p0, p1, shift); - kv_swa ->seq_add(seq_id, p0, p1, shift); -} - -void llama_kv_cache_unified_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { - kv_base->seq_div(seq_id, p0, p1, d); - kv_swa ->seq_div(seq_id, p0, p1, d); -} - -llama_pos llama_kv_cache_unified_iswa::seq_pos_min(llama_seq_id seq_id) const { - // the base cache is a superset of the SWA cache, so we can just check the SWA cache - return kv_swa->seq_pos_min(seq_id); -} - -llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const { - return kv_swa->seq_pos_max(seq_id); -} - -llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) { - GGML_UNUSED(embd_pooled); - - // TODO: if we fail with split_simple, we should attempt different splitting strategies - // but to do that properly, we first have to refactor the batches to be more flexible - - auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all); - - std::vector ubatches; - - while (sbatch.n_tokens > 0) { - auto ubatch = sbatch.split_simple(n_ubatch); - - ubatches.push_back(ubatch); - } - - auto heads_base = kv_base->prepare(ubatches); - if (heads_base.empty()) { 
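// illustrative sketch (not part of this diff): if either cache fails to
// prepare, the whole init_batch() fails; per the TODO above, a caller could
// retry with a smaller ubatch size (the retry loop below is hypothetical):
//
//   auto state = kv->init_batch(batch, n_ubatch, false, logits_all);
//   while (state->get_status() == LLAMA_MEMORY_STATUS_FAILED_PREPARE && n_ubatch > 1) {
//       n_ubatch /= 2; // hypothetical fallback until better splitting strategies exist
//       state = kv->init_batch(batch, n_ubatch, false, logits_all);
//   }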
-        return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
-    }
-
-    auto heads_swa = kv_swa->prepare(ubatches);
-    if (heads_swa.empty()) {
-        return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
-    }
-
-    assert(heads_base.size() == heads_swa.size());
-
-    return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_SUCCESS,
-            this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
-}
-
-llama_memory_state_ptr llama_kv_cache_unified_iswa::init_full() {
-    return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_SUCCESS, this);
-}
-
-bool llama_kv_cache_unified_iswa::update(llama_context & lctx) {
-    bool res = false;
-
-    res = res | kv_base->update(lctx);
-    res = res | kv_swa ->update(lctx);
-
-    return res;
-}
-
-void llama_kv_cache_unified_iswa::defrag_sched(float thold) {
-    kv_base->defrag_sched(thold);
-    kv_swa ->defrag_sched(thold);
-}
-
-bool llama_kv_cache_unified_iswa::get_can_shift() const {
-    return kv_base->get_size() == kv_swa->get_size();
-}
-
-void llama_kv_cache_unified_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
-    kv_base->state_write(io, seq_id);
-    kv_swa ->state_write(io, seq_id);
-}
-
-void llama_kv_cache_unified_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
-    kv_base->state_read(io, seq_id);
-    kv_swa ->state_read(io, seq_id);
-}
-
-llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_base() const {
-    return kv_base.get();
-}
-
-llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const {
-    return kv_swa.get();
-}
-
-//
-// llama_kv_cache_unified_iswa_state
-//
-
-llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(llama_memory_status status) : status(status) {}
-
-llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
-        llama_memory_status status,
-        llama_kv_cache_unified_iswa * kv) : status(status) {
-    state_base.reset(new llama_kv_cache_unified_state(status, kv->get_base()));
-    state_swa .reset(new llama_kv_cache_unified_state(status, kv->get_swa ()));
-}
-
-llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
-        llama_memory_status status,
-        llama_kv_cache_unified_iswa * kv,
-        llama_sbatch sbatch,
-        std::vector<uint32_t> heads_base,
-        std::vector<uint32_t> heads_swa,
-        std::vector<llama_ubatch> ubatches)
-    : status(status),
-    sbatch(std::move(sbatch)),
-    ubatches(std::move(ubatches)) {
-    // note: here we copy the ubatches.
not sure if this is ideal - state_base.reset(new llama_kv_cache_unified_state(status, kv->get_base(), {}, std::move(heads_base), this->ubatches)); - state_swa .reset(new llama_kv_cache_unified_state(status, kv->get_swa (), {}, std::move(heads_swa), this->ubatches)); - } - -llama_kv_cache_unified_iswa_state:: ~llama_kv_cache_unified_iswa_state() = default; - -bool llama_kv_cache_unified_iswa_state::next() { - assert(status == LLAMA_MEMORY_STATUS_SUCCESS); - - state_base->next(); - state_swa ->next(); - - if (++i_next >= ubatches.size()) { - return false; - } - - return true; -} - -bool llama_kv_cache_unified_iswa_state::apply() { - assert(status == LLAMA_MEMORY_STATUS_SUCCESS); - - bool res = true; - - res = res & state_base->apply(); - res = res & state_swa ->apply(); - - return res; -} - -std::vector & llama_kv_cache_unified_iswa_state::out_ids() { - assert(status == LLAMA_MEMORY_STATUS_SUCCESS); - - return sbatch.out_ids; -} - -llama_memory_status llama_kv_cache_unified_iswa_state::get_status() const { - return status; -} - -const llama_ubatch & llama_kv_cache_unified_iswa_state::get_ubatch() const { - assert(status == LLAMA_MEMORY_STATUS_SUCCESS); - return ubatches[i_next]; -} - -const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_base() const { - assert(status == LLAMA_MEMORY_STATUS_SUCCESS); - - return state_base.get(); -} - -const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_swa() const { - assert(status == LLAMA_MEMORY_STATUS_SUCCESS); - - return state_swa.get(); -} - -// -// llama_kv_cache_recurrent -// - -llama_kv_cache_recurrent::llama_kv_cache_recurrent( - const llama_model & model, - ggml_type type_k, - ggml_type type_v, - bool offload, - uint32_t kv_size, - uint32_t n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) { - const int32_t n_layer = hparams.n_layer; - - LLAMA_LOG_INFO("%s: kv_size = %u, n_seq_max = %u, type_k = '%s', type_v = '%s', n_layer = %d\n", - __func__, kv_size, n_seq_max, ggml_type_name(type_k), ggml_type_name(type_v), n_layer); - - head = 0; - size = kv_size; - used = 0; - - cells.clear(); - cells.resize(kv_size); - - // create a context for each buffer type - std::map ctx_map; - auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * { - auto it = ctx_map.find(buft); - if (it == ctx_map.end()) { - ggml_init_params params = { - /*.mem_size =*/ size_t(2u*n_layer*ggml_tensor_overhead()), - /*.mem_buffer =*/ NULL, - /*.no_alloc =*/ true, - }; - - ggml_context * ctx = ggml_init(params); - if (!ctx) { - return nullptr; - } - - ctx_map[buft] = ctx; - ctxs.emplace_back(ctx); - - return ctx; - } - - return it->second; - }; - - k_l.reserve(n_layer); - v_l.reserve(n_layer); - - for (int i = 0; i < n_layer; i++) { - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(); - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(); - - const char * dev_name = "CPU"; - - ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type(); - - if (offload) { - auto * dev = model.dev_layer(i); - buft = ggml_backend_dev_buffer_type(dev); - - dev_name = ggml_backend_dev_name(dev); - } - - LLAMA_LOG_DEBUG("%s, layer %3d: dev = %s\n", __func__, i, dev_name); - - ggml_context * ctx = ctx_for_buft(buft); - if (!ctx) { - throw std::runtime_error("failed to create ggml context for kv cache"); - } - - ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size); - ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size); - ggml_format_name(k, 
"cache_k_l%d", i); - ggml_format_name(v, "cache_v_l%d", i); - k_l.push_back(k); - v_l.push_back(v); - } - - // allocate tensors and initialize the buffers to avoid NaNs in the padding - for (auto it : ctx_map) { - auto * buft = it.first; - auto * ctx = it.second; - - ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); - if (!buf) { - throw std::runtime_error("failed to allocate buffer for kv cache"); - } - ggml_backend_buffer_clear(buf, 0); - LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0); - bufs.emplace_back(buf); - } - - { - const size_t memory_size_k = size_k_bytes(); - const size_t memory_size_v = size_v_bytes(); - - LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__, - (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), - ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f), - ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f)); - } -} - -void llama_kv_cache_recurrent::clear() { - for (int32_t i = 0; i < (int32_t) size; ++i) { - cells[i].pos = -1; - cells[i].seq_id.clear(); - cells[i].src = -1; - cells[i].tail = -1; - } - head = 0; - used = 0; - - for (auto & buf : bufs) { - ggml_backend_buffer_clear(buf.get(), 0); - } -} - -bool llama_kv_cache_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { - uint32_t new_head = size; - - if (p0 < 0) { - p0 = 0; - } - - if (p1 < 0) { - p1 = std::numeric_limits::max(); - } - - // models like Mamba or RWKV can't have a state partially erased - if (seq_id >= (int64_t) size) { - // could be fatal - return false; - } - if (0 <= seq_id) { - int32_t & tail_id = cells[seq_id].tail; - if (tail_id >= 0) { - const kv_cell & cell = cells[tail_id]; - // partial intersection is invalid - if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) { - return false; - } - // invalidate tails which will be cleared - if (p0 <= cell.pos && cell.pos < p1) { - tail_id = -1; - } - } - } else { - // seq_id is negative, then the range should include everything or nothing - if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits::max())) { - return false; - } - } - - for (uint32_t i = 0; i < size; ++i) { - if (cells[i].pos >= p0 && cells[i].pos < p1) { - if (seq_id < 0) { - cells[i].seq_id.clear(); - } else if (cells[i].has_seq_id(seq_id)) { - cells[i].seq_id.erase(seq_id); - } else { - continue; - } - if (cells[i].is_empty()) { - // keep count of the number of used cells - if (cells[i].pos >= 0) { - used--; - } - cells[i].pos = -1; - cells[i].src = -1; - if (new_head == size) { - new_head = i; - } - } - } - } - - // If we freed up a slot, set head to it so searching can start there. 
- if (new_head != size && new_head < head) { - head = new_head; - } - - return true; -} - -void llama_kv_cache_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { - if (seq_id_src == seq_id_dst) { - return; - } - - if (p0 < 0) { - p0 = 0; - } - - if (p1 < 0) { - p1 = std::numeric_limits::max(); - } - - if ((uint32_t) seq_id_dst < size && (uint32_t) seq_id_src < size) { - kv_cell & tail_src = cells[seq_id_src]; - kv_cell & tail_dst = cells[seq_id_dst]; - if (tail_dst.tail >= 0) { - // clear destination seq_id if it wasn't empty - kv_cell & cell_dst = cells[tail_dst.tail]; - - cell_dst.seq_id.erase(seq_id_dst); - tail_dst.tail = -1; - if (cell_dst.seq_id.empty()) { - cell_dst.pos = -1; - cell_dst.src = -1; - used -= 1; - } - } - if (tail_src.tail >= 0) { - kv_cell & cell_src = cells[tail_src.tail]; - - cell_src.seq_id.insert(seq_id_dst); - tail_dst.tail = tail_src.tail; - } - } -} - -void llama_kv_cache_recurrent::seq_keep(llama_seq_id seq_id) { - uint32_t new_head = size; - - for (uint32_t i = 0; i < size; ++i) { - if ((llama_seq_id) i != seq_id) { - cells[i].tail = -1; - } - - if (!cells[i].has_seq_id(seq_id)) { - if (cells[i].pos >= 0) { - used--; - } - - cells[i].pos = -1; - cells[i].src = -1; - cells[i].seq_id.clear(); - - if (new_head == size){ - new_head = i; - } - } else { - cells[i].seq_id.clear(); - cells[i].seq_id.insert(seq_id); - } - } - - // If we freed up a slot, set head to it so searching can start there. - if (new_head != size && new_head < head) { - head = new_head; - } -} - -void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) { - if (shift == 0) { - return; - } - - if (p0 < 0) { - p0 = 0; - } - - if (p1 < 0) { - p1 = std::numeric_limits::max(); - } - - // If there is no range then return early to avoid looping over the - if (p0 == p1) { - return; - } - - // for Mamba-like or RWKV models, only the pos needs to be shifted - if (0 <= seq_id && seq_id < (int64_t) size) { - const int32_t tail_id = cells[seq_id].tail; - if (tail_id >= 0) { - kv_cell & cell = cells[tail_id]; - if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) { - cell.pos += shift; - } - } - } -} - -void llama_kv_cache_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { - if (d == 1) { - return; - } - - if (p0 < 0) { - p0 = 0; - } - - if (p1 < 0) { - p1 = std::numeric_limits::max(); - } - - // If there is no range then return early to avoid looping over the cache. 
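// illustrative example (not part of this diff): for a recurrent cache,
// seq_add() and seq_div() only touch the tail cell's position; with a tail
// at pos = 9:
//
//   kv.seq_add(seq_id, 0, -1, 4); // tail pos: 9 -> 13
//   kv.seq_div(seq_id, 0, -1, 2); // tail pos: 13 -> 6 (integer division)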
- if (p0 == p1) { - return; - } - - // for Mamba-like or RWKV models, only the pos needs to be changed - if (0 <= seq_id && seq_id < (int64_t) size) { - const int32_t tail_id = cells[seq_id].tail; - if (tail_id >= 0) { - kv_cell & cell = cells[tail_id]; - if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) { - cell.pos /= d; - } - } - } -} - -llama_pos llama_kv_cache_recurrent::seq_pos_min(llama_seq_id seq_id) const { - llama_pos result = std::numeric_limits::max(); - - for (uint32_t i = 0; i < size; ++i) { - if (cells[i].has_seq_id(seq_id)) { - result = std::min(result, cells[i].pos); - } - } - - if (result == std::numeric_limits::max()) { - result = -1; - } - - return result; -} - -llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const { - llama_pos result = -1; - - for (uint32_t i = 0; i < size; ++i) { - if (cells[i].has_seq_id(seq_id)) { - result = std::max(result, cells[i].pos); - } - } - - return result; -} - -llama_memory_state_ptr llama_kv_cache_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) { - GGML_UNUSED(embd_pooled); - - auto sbatch = llama_sbatch(batch, hparams.n_embd, false, logits_all); - - std::vector ubatches; - - while (sbatch.n_tokens > 0) { - llama_ubatch ubatch; - - if (embd_pooled) { - // Pooled embeddings cannot be split across ubatches (yet) - ubatch = sbatch.split_seq(n_ubatch); - } else { - ubatch = sbatch.split_equal(n_ubatch); - } - - ubatches.push_back(ubatch); - } - - if (!prepare(ubatches)) { - return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); - } - - return std::make_unique(LLAMA_MEMORY_STATUS_SUCCESS, this, std::move(sbatch), std::move(ubatches)); -} - -llama_memory_state_ptr llama_kv_cache_recurrent::init_full() { - return std::make_unique(LLAMA_MEMORY_STATUS_SUCCESS, this); -} - -bool llama_kv_cache_recurrent::prepare(const std::vector & ubatches) { - // simply remember the full state because it is very small for this type of cache - // TODO: optimize - auto org_cells = cells; - auto org_used = used; - auto org_head = head; - - bool success = true; - - // TODO: here we have to verify that all ubatches can fit in the cells - // however, the current implementation is broken because it relies on s_copy() and s_mask() to update the cells - // during the compute of each ubatch. 
// to reproduce, uncomment the following loop and run:
-    //
-    //   $ llama-parallel -m ./mamba-130m/ggml-model-f16.gguf -np 5 -ns 8
-    //
-    // recovery from failures when the batch does not fit in the KV cache will not work correctly until this is fixed
-    //
-    GGML_UNUSED(ubatches);
-    //for (const auto & ubatch : ubatches) {
-    //    if (!find_slot(ubatch)) {
-    //        success = false;
-    //        break;
-    //    }
-    //}
-
-    // restore the original state
-    cells = std::move(org_cells);
-    used  = org_used;
-    head  = org_head;
-
-    return success;
-}
-
-bool llama_kv_cache_recurrent::update(llama_context & lctx) {
-    GGML_UNUSED(lctx);
-    // noop
-    return false;
-}
-
-void llama_kv_cache_recurrent::defrag_sched(float thold) {
-    GGML_UNUSED(thold);
-    // noop
-}
-
-bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
-    const uint32_t n_tokens = ubatch.n_tokens;
-    const uint32_t n_seqs   = ubatch.n_seqs;
-
-    const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
-
-    // if we have enough unused cells before the current head ->
-    //   better to start searching from the beginning of the cache, hoping to fill it
-    if (head > used + 2*n_tokens) {
-        head = 0;
-    }
-
-    // For recurrent state architectures (like Mamba or RWKV),
-    // each cache cell can store the state for a whole sequence.
-    // A slot should always be contiguous.
-
-    // can only process batches with an equal number of new tokens in each sequence
-    GGML_ASSERT(ubatch.equal_seqs);
-
-    int32_t min = size - 1;
-    int32_t max = 0;
-
-    // everything should fit if all seq_ids are smaller than the max
-    for (uint32_t s = 0; s < n_seqs; ++s) {
-        const uint32_t n_seq_id = ubatch.n_seq_id[s];
-        for (uint32_t j = 0; j < n_seq_id; ++j) {
-            const llama_seq_id seq_id = ubatch.seq_id[s][j];
-
-            if (seq_id < 0 || (uint32_t) seq_id >= size) {
-                // too big seq_id
-                // TODO: would it be possible to resize the cache instead?
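// note (illustrative, not part of this diff): in this cache the seq_id
// doubles as the cell/state slot index, so every seq_id must be < size;
// e.g. a cache built with kv_size = n_seq_max = 4 rejects any ubatch that
// carries seq_id >= 4 and logs the error below: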
- LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%u Try using a bigger --parallel value\n", __func__, seq_id, n_seq_max); - return false; - } - if (j > 0) { - kv_cell & seq = cells[seq_id]; - if (seq.tail >= 0) { - kv_cell & cell = cells[seq.tail]; - // clear cells from seq_ids that become shared - // (should not normally happen, but let's handle it anyway) - cell.seq_id.erase(seq_id); - seq.tail = -1; - if (cell.seq_id.empty()) { - cell.pos = -1; - cell.src = -1; - used -= 1; - } - } - } - } - } - -#ifndef NDEBUG - { - std::vector tails_verif; - tails_verif.assign(size, -1); - for (uint32_t i = 0; i < size; ++i) { - kv_cell & cell = cells[i]; - for (llama_seq_id seq_id : cell.seq_id) { - if (tails_verif[seq_id] != -1) { - LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tails_verif[seq_id]); - } - tails_verif[seq_id] = i; - } - } - for (uint32_t i = 0; i < size; ++i) { - if (tails_verif[i] != cells[i].tail) { - LLAMA_LOG_ERROR("%s: wrong tail for seq_id %d, (%d instead of %d)\n", __func__, i, cells[i].tail, tails_verif[i]); - } - } - } -#endif - - // find next empty cell - uint32_t next_empty_cell = head; - - for (uint32_t i = 0; i < size; ++i) { - if (next_empty_cell >= size) { next_empty_cell -= size; } - kv_cell & cell = cells[next_empty_cell]; - if (cell.is_empty()) { break; } - next_empty_cell += 1; - } - - // find usable cell range - for (uint32_t s = 0; s < n_seqs; ++s) { - const llama_seq_id seq_id = ubatch.seq_id[s][0]; - kv_cell & seq_meta = cells[seq_id]; - bool has_cell = false; - if (seq_meta.tail >= 0) { - kv_cell & cell = cells[seq_meta.tail]; - GGML_ASSERT(cell.has_seq_id(seq_id)); - // does this seq_id "own" the cell? - if (cell.seq_id.size() == 1) { has_cell = true; } - } - if (!has_cell) { - kv_cell & empty_cell = cells[next_empty_cell]; - GGML_ASSERT(empty_cell.is_empty()); - // copy old tail into the empty cell - if (seq_meta.tail >= 0) { - kv_cell & orig_cell = cells[seq_meta.tail]; - empty_cell.pos = orig_cell.pos; - empty_cell.src = orig_cell.src; - orig_cell.seq_id.erase(seq_id); - empty_cell.seq_id.insert(seq_id); // will be overwritten - } - seq_meta.tail = next_empty_cell; - // find next empty cell - if (s + 1 < n_seqs) { - next_empty_cell += 1; - for (uint32_t i = 0; i < size; ++i) { - if (next_empty_cell >= size) { next_empty_cell -= size; } - kv_cell & cell = cells[next_empty_cell]; - if (cell.is_empty()) { break; } - next_empty_cell += 1; - } - } - } - if (min > seq_meta.tail) { min = seq_meta.tail; } - if (max < seq_meta.tail) { max = seq_meta.tail; } - } - - // gather and re-order - for (uint32_t s = 0; s < n_seqs; ++s) { - int32_t dst_id = s + min; - int32_t src_id = cells[ubatch.seq_id[s][0]].tail; - if (dst_id != src_id) { - kv_cell & dst_cell = cells[dst_id]; - kv_cell & src_cell = cells[src_id]; - - std::swap(dst_cell.pos, src_cell.pos); - std::swap(dst_cell.src, src_cell.src); - std::swap(dst_cell.seq_id, src_cell.seq_id); - - // swap tails (assuming they NEVER overlap) - for (const llama_seq_id seq_id : src_cell.seq_id) { - cells[seq_id].tail = src_id; - } - for (const llama_seq_id seq_id : dst_cell.seq_id) { - cells[seq_id].tail = dst_id; - } - } - } - - // update the pos of the used seqs - for (uint32_t s = 0; s < n_seqs; ++s) { - const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1]; - int32_t cell_id = s + min; - kv_cell & cell = cells[cell_id]; - - if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) { - // What should happen when the pos backtracks or skips a 
value? - // Clearing the state mid-batch would require special-casing which isn't done. - LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n", - __func__, last_pos, cell.pos, ubatch.seq_id[s][0], n_seq_tokens); - } - cell.pos = last_pos; - cell.seq_id.clear(); - for (int32_t j = 0; j < ubatch.n_seq_id[s]; ++j) { - const llama_seq_id seq_id = ubatch.seq_id[s][j]; - cell.seq_id.insert(seq_id); - cells[seq_id].tail = cell_id; - } - } - - // allow getting the range of used cells, from head to head + n - head = min; - n = max - min + 1; - used = std::count_if(cells.begin(), cells.end(), - [](const kv_cell & cell){ return !cell.is_empty(); }); - - // sanity check - return n >= n_seqs; -} - -bool llama_kv_cache_recurrent::get_can_shift() const { - return false; -} - -int32_t llama_kv_cache_recurrent::s_copy(int i) const { - const uint32_t cell_id = i + head; - - ////////////////////////////////////////////// - // TODO: this should not mutate the KV cache ! - kv_cell & cell = const_cast(cells[cell_id]); - - // prevent out-of-bound sources - if (cell.src < 0 || (uint32_t) cell.src >= size) { - cell.src = cell_id; - } - - int32_t res = cell.src; - - // TODO: do not mutate the KV cache - // ensure copy only happens once - if (cell.src != (int32_t) cell_id) { - cell.src = cell_id; - } - - return res; -} - -float llama_kv_cache_recurrent::s_mask(int i) const { - const uint32_t cell_id = i + head; - - ////////////////////////////////////////////// - // TODO: this should not mutate the KV cache ! - kv_cell & cell = const_cast(cells[cell_id]); - - float res = (float) (cell.src >= 0); - - // only clear once - if (cell.src < 0) { - cell.src = cell_id; - } - - return res; -} - -size_t llama_kv_cache_recurrent::total_size() const { - size_t size = 0; - for (const auto & buf : bufs) { - size += ggml_backend_buffer_get_size(buf.get()); - } - - return size; -} - -size_t llama_kv_cache_recurrent::size_k_bytes() const { - size_t size_k_bytes = 0; - - for (const auto & k : k_l) { - size_k_bytes += ggml_nbytes(k); - } - - return size_k_bytes; -} - -size_t llama_kv_cache_recurrent::size_v_bytes() const { - size_t size_v_bytes = 0; - - for (const auto & v : v_l) { - size_v_bytes += ggml_nbytes(v); - } - - return size_v_bytes; -} - -void llama_kv_cache_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id) const { - std::vector> cell_ranges; // ranges, from inclusive, to exclusive - uint32_t cell_count = 0; - - // Count the number of cells with the specified seq_id - // Find all the ranges of cells with this seq id (or all, when -1) - uint32_t cell_range_begin = size; - for (uint32_t i = 0; i < size; ++i) { - const auto & cell = cells[i]; - if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) { - ++cell_count; - if (cell_range_begin == size) { - cell_range_begin = i; - } - } else { - if (cell_range_begin != size) { - cell_ranges.emplace_back(cell_range_begin, i); - cell_range_begin = size; - } - } - } - if (cell_range_begin != size) { - cell_ranges.emplace_back(cell_range_begin, size); - } - - // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count - uint32_t cell_count_check = 0; - for (const auto & range : cell_ranges) { - cell_count_check += range.second - range.first; - } - GGML_ASSERT(cell_count == cell_count_check); - - io.write(&cell_count, sizeof(cell_count)); - - state_write_meta(io, cell_ranges, seq_id); - state_write_data(io, cell_ranges); -} - -void llama_kv_cache_recurrent::state_read(llama_io_read_i & io, 
llama_seq_id seq_id) { - uint32_t cell_count; - io.read_to(&cell_count, sizeof(cell_count)); - - bool res = true; - - res = res && state_read_meta(io, cell_count, seq_id); - res = res && state_read_data(io, cell_count); - - if (!res) { - if (seq_id == -1) { - clear(); - } else { - seq_rm(seq_id, -1, -1); - } - throw std::runtime_error("failed to restore kv cache"); - } -} - -void llama_kv_cache_recurrent::state_write_meta(llama_io_write_i & io, const std::vector> & cell_ranges, llama_seq_id seq_id) const { - for (const auto & range : cell_ranges) { - for (uint32_t i = range.first; i < range.second; ++i) { - const auto & cell = cells[i]; - const llama_pos pos = cell.pos; - const uint32_t n_seq_id = seq_id == -1 ? cell.seq_id.size() : 0; - - io.write(&pos, sizeof(pos)); - io.write(&n_seq_id, sizeof(n_seq_id)); - - if (n_seq_id) { - for (auto seq_id : cell.seq_id) { - io.write(&seq_id, sizeof(seq_id)); - } - } - } - } -} - -void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std::vector> & cell_ranges) const { - const uint32_t v_trans = 0; - const uint32_t n_layer = hparams.n_layer; - - io.write(&v_trans, sizeof(v_trans)); - io.write(&n_layer, sizeof(n_layer)); - - std::vector tmp_buf; - - // Iterate and write all the keys first, each row is a cell - // Get whole range at a time - for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); - - // Write key type - const int32_t k_type_i = (int32_t)k_l[il]->type; - io.write(&k_type_i, sizeof(k_type_i)); - - // Write row size of key - const uint64_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa); - io.write(&k_size_row, sizeof(k_size_row)); - - // Read each range of cells of k_size length each into tmp_buf and write out - for (const auto & range : cell_ranges) { - const size_t range_size = range.second - range.first; - const size_t buf_size = range_size * k_size_row; - io.write_tensor(k_l[il], range.first * k_size_row, buf_size); - } - } - - if (!v_trans) { - for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); - - // Write value type - const int32_t v_type_i = (int32_t)v_l[il]->type; - io.write(&v_type_i, sizeof(v_type_i)); - - // Write row size of value - const uint64_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa); - io.write(&v_size_row, sizeof(v_size_row)); - - // Read each range of cells of v_size length each into tmp_buf and write out - for (const auto & range : cell_ranges) { - const size_t range_size = range.second - range.first; - const size_t buf_size = range_size * v_size_row; - io.write_tensor(v_l[il], range.first * v_size_row, buf_size); - } - } - } else { - // When v is transposed, we also need the element size and get the element ranges from each row - const uint32_t kv_size = size; - for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); - - // Write value type - const int32_t v_type_i = (int32_t)v_l[il]->type; - io.write(&v_type_i, sizeof(v_type_i)); - - // Write element size - const uint32_t v_size_el = ggml_type_size(v_l[il]->type); - io.write(&v_size_el, sizeof(v_size_el)); - - // Write GQA embedding size - io.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa)); - - // For each row, we get the element values of each cell - for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { - // Read each range of cells of v_size_el length each into tmp_buf and write out - for (const auto & range : 
cell_ranges) { - const size_t range_size = range.second - range.first; - const size_t src_offset = (range.first + j * kv_size) * v_size_el; - const size_t buf_size = range_size * v_size_el; - io.write_tensor(v_l[il], src_offset, buf_size); - } - } - } - } -} - -bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) { - if (dest_seq_id != -1) { - // single sequence - - seq_rm(dest_seq_id, -1, -1); - - llama_sbatch sbatch; - llama_ubatch batch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false); - - batch.n_tokens = cell_count; - batch.n_seq_tokens = cell_count; - batch.n_seqs = 1; - - for (uint32_t i = 0; i < cell_count; ++i) { - llama_pos pos; - uint32_t n_seq_id; - - io.read_to(&pos, sizeof(pos)); - io.read_to(&n_seq_id, sizeof(n_seq_id)); - - if (n_seq_id != 0) { - LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__); - return false; - } - - batch.pos[i] = pos; - } - batch.n_seq_id[0] = 1; - batch.seq_id[0] = &dest_seq_id; - - if (!find_slot(batch)) { - LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__); - return false; - } - - // DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values) - // Assume that this is one contiguous block of cells - GGML_ASSERT(head + cell_count <= size); - GGML_ASSERT(cells[head].pos == batch.pos[0]); - GGML_ASSERT(cells[head + cell_count - 1].pos == batch.pos[cell_count - 1]); - GGML_ASSERT(cells[head].has_seq_id(dest_seq_id)); - GGML_ASSERT(cells[head + cell_count - 1].has_seq_id(dest_seq_id)); - } else { - // whole KV cache restore - - if (cell_count > size) { - LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__); - return false; - } - - clear(); - - for (uint32_t i = 0; i < cell_count; ++i) { - kv_cell & cell = cells[i]; - - llama_pos pos; - uint32_t n_seq_id; - - io.read_to(&pos, sizeof(pos)); - io.read_to(&n_seq_id, sizeof(n_seq_id)); - - cell.pos = pos; - - for (uint32_t j = 0; j < n_seq_id; ++j) { - llama_seq_id seq_id; - io.read_to(&seq_id, sizeof(seq_id)); - - // TODO: llama_kv_cache_recurrent should have a notion of max sequences - //if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { - if (seq_id < 0) { - //LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx)); - LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, inf)\n", __func__, seq_id); - return false; - } - - cell.seq_id.insert(seq_id); - - int32_t & tail = cells[seq_id].tail; - if (tail != -1) { - LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tail); - return false; - } - tail = i; - } - } - - head = 0; - used = cell_count; - } - - for (uint32_t i = 0; i < cell_count; ++i) { - uint32_t cell_id = head + i; - // make sure the recurrent states will keep their restored state - cells[cell_id].src = cell_id; - } - - return true; -} - -bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell_count) { - uint32_t v_trans; - uint32_t n_layer; - io.read_to(&v_trans, sizeof(v_trans)); - io.read_to(&n_layer, sizeof(n_layer)); - - if (n_layer != hparams.n_layer) { - LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, hparams.n_layer); - return false; - } - if (cell_count > size) { - LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, size); - return false; - } - if (false != (bool) v_trans) { 
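// note (illustrative, not part of this diff): state_write_data() above always
// records v_trans = 0 for this cache, so a non-zero flag here can only come
// from an incompatible, transposed-V cache layout: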
- LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__); - return false; - } - - // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block - for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); - - // Read type of key - int32_t k_type_i_ref; - io.read_to(&k_type_i_ref, sizeof(k_type_i_ref)); - const int32_t k_type_i = (int32_t) k_l[il]->type; - if (k_type_i != k_type_i_ref) { - LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il); - return false; - } - - // Read row size of key - uint64_t k_size_row_ref; - io.read_to(&k_size_row_ref, sizeof(k_size_row_ref)); - const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa); - if (k_size_row != k_size_row_ref) { - LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il); - return false; - } - - if (cell_count) { - // Read and set the keys for the whole cell range - ggml_backend_tensor_set(k_l[il], io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row); - } - } - - if (!v_trans) { - for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); - - // Read type of value - int32_t v_type_i_ref; - io.read_to(&v_type_i_ref, sizeof(v_type_i_ref)); - const int32_t v_type_i = (int32_t)v_l[il]->type; - if (v_type_i != v_type_i_ref) { - LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); - return false; - } - - // Read row size of value - uint64_t v_size_row_ref; - io.read_to(&v_size_row_ref, sizeof(v_size_row_ref)); - const size_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa); - if (v_size_row != v_size_row_ref) { - LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il); - return false; - } - - if (cell_count) { - // Read and set the values for the whole cell range - ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row); - } - } - } else { - // For each layer, read the values for each cell (transposed) - for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); - - // Read type of value - int32_t v_type_i_ref; - io.read_to(&v_type_i_ref, sizeof(v_type_i_ref)); - const int32_t v_type_i = (int32_t)v_l[il]->type; - if (v_type_i != v_type_i_ref) { - LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); - return false; - } - - // Read element size of value - uint32_t v_size_el_ref; - io.read_to(&v_size_el_ref, sizeof(v_size_el_ref)); - const size_t v_size_el = ggml_type_size(v_l[il]->type); - if (v_size_el != v_size_el_ref) { - LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il); - return false; - } - - // Read GQA embedding size - uint32_t n_embd_v_gqa_ref; - io.read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref)); - if (n_embd_v_gqa != n_embd_v_gqa_ref) { - LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il); - return false; - } - - if (cell_count) { - // For each row in the transposed matrix, read the values for the whole cell range - for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { 
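// offset sketch (illustrative, not part of this diff): in the transposed
// layout, row j of V holds one element per cell and is `size` cells wide, so
// the chunk for cells [head, head + cell_count) of row j starts at byte
// offset (head + j * size) * v_size_el, as computed below: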
-                    const size_t dst_offset = (head + j * size) * v_size_el;
-                    ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
-                }
-            }
-        }
-    }
-
-    return true;
-}
-
-//
-// llama_kv_cache_recurrent_state
-//
-
-llama_kv_cache_recurrent_state::llama_kv_cache_recurrent_state(llama_memory_status status) : status(status) {}
-
-llama_kv_cache_recurrent_state::llama_kv_cache_recurrent_state(
-        llama_memory_status status,
-        llama_kv_cache_recurrent * kv) : status(status), kv(kv), is_full(true) {
-}
-
-llama_kv_cache_recurrent_state::llama_kv_cache_recurrent_state(
-        llama_memory_status status,
-        llama_kv_cache_recurrent * kv,
-        llama_sbatch sbatch,
-        std::vector<llama_ubatch> ubatches) : status(status), kv(kv), sbatch(std::move(sbatch)), ubatches(std::move(ubatches)) {}
-
-llama_kv_cache_recurrent_state::~llama_kv_cache_recurrent_state() = default;
-
-bool llama_kv_cache_recurrent_state::next() {
-    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
-
-    if (++i_next >= ubatches.size()) {
-        return false;
-    }
-
-    return true;
-}
-
-bool llama_kv_cache_recurrent_state::apply() {
-    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
-
-    kv->find_slot(ubatches[i_next]);
-
-    return true;
-}
-
-std::vector<size_t> & llama_kv_cache_recurrent_state::out_ids() {
-    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
-
-    return sbatch.out_ids;
-}
-
-llama_memory_status llama_kv_cache_recurrent_state::get_status() const {
-    return status;
-}
-
-const llama_ubatch & llama_kv_cache_recurrent_state::get_ubatch() const {
-    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
-
-    return ubatches[i_next];
-}
-
-uint32_t llama_kv_cache_recurrent_state::get_n_kv() const {
-    return is_full ? kv->size : kv->n;
-}
-
-uint32_t llama_kv_cache_recurrent_state::get_head() const {
-    return is_full ?
0 : kv->head; -} - -uint32_t llama_kv_cache_recurrent_state::get_size() const { - return kv->size; -} - -ggml_tensor * llama_kv_cache_recurrent_state::get_k_l(int32_t il) const { - return kv->k_l[il]; -} - -ggml_tensor * llama_kv_cache_recurrent_state::get_v_l(int32_t il) const { - return kv->v_l[il]; -} - -int32_t llama_kv_cache_recurrent_state::s_copy(int i) const { - return kv->s_copy(i); -} - -float llama_kv_cache_recurrent_state::s_mask(int i) const { - return kv->s_mask(i); -} diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index 75969878e318c..2d04705f27857 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -2,21 +2,7 @@ #include "llama.h" #include "llama-io.h" -#include "llama-batch.h" -#include "llama-graph.h" #include "llama-memory.h" -#include "llama-kv-cells.h" - -#include "ggml-cpp.h" - -#include -#include -#include - -struct llama_cparams; -struct llama_hparams; -struct llama_model; -struct llama_context; struct llama_kv_cache : public llama_memory_i { virtual ~llama_kv_cache() = default; @@ -56,581 +42,3 @@ struct llama_kv_cache : public llama_memory_i { virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const = 0; virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) = 0; }; - -// -// llama_kv_cache_unified -// - -class llama_kv_cache_unified : public llama_kv_cache { -public: - static uint32_t get_padding(const llama_cparams & cparams); - - // this callback is used to filter out layers that should not be included in the cache - using layer_filter_cb = std::function; - - llama_kv_cache_unified( - const llama_model & model, - layer_filter_cb && filter, - ggml_type type_k, - ggml_type type_v, - bool v_trans, - bool offload, - uint32_t kv_size, - uint32_t n_seq_max, - uint32_t n_pad, - uint32_t n_swa, - llama_swa_type swa_type); - - ~llama_kv_cache_unified() = default; - - // - // llama_memory_i - // - - void clear() override; - - bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; - void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; - void seq_keep(llama_seq_id seq_id) override; - void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override; - void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override; - - llama_pos seq_pos_min(llama_seq_id seq_id) const override; - llama_pos seq_pos_max(llama_seq_id seq_id) const override; - - // - // llama_kv_cache - // - - llama_memory_state_ptr init_batch( - const llama_batch & batch, - uint32_t n_ubatch, - bool embd_pooled, - bool logits_all) override; - - llama_memory_state_ptr init_full() override; - - bool update(llama_context & lctx) override; - - void defrag_sched(float thold) override; - - bool get_can_shift() const override; - - // state write/load - - void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override; - void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override; - - // - // llama_kv_cache_unified specific API - // - - uint32_t get_size() const; - - // - // graph_build API - // - - uint32_t get_n_kv() const; - - // get views of the current state of the cache - ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv) const; - ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const; - - // store k_cur and v_cur in the cache based on the provided head location - ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il, uint32_t head_cur) const; - ggml_tensor * 
cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il, uint32_t head_cur) const; - - // - // preparation API - // - - // find places for the provided ubatches in the cache, returns the head locations - // return empty vector on failure - std::vector prepare(const std::vector & ubatches); - - // return the cell position where we can insert the ubatch - // return -1 on failure to find a contiguous slot of kv cells - int32_t find_slot(const llama_ubatch & ubatch) const; - - // emplace the ubatch context into slot: [head_cur, head_cur + ubatch.n_tokens) - void apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch); - - // - // set_input API - // - - void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const; - void set_input_k_shift (ggml_tensor * dst) const; - void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const; - -private: - const llama_model & model; - const llama_hparams & hparams; - - struct kv_layer { - // layer index in the model - // note: can be different from the layer index in the KV cache - uint32_t il; - - ggml_tensor * k; - ggml_tensor * v; - }; - - bool do_defrag = false; - bool v_trans = true; // the value tensor is transposed - - // the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot()) - // note: this is not part of the KV state and it's only used to speed-up the find_slot() method - uint32_t head = 0; - - const uint32_t n_seq_max = 1; - - // required padding - const uint32_t n_pad = 1; - - // SWA - const uint32_t n_swa = 0; - - const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE; - - std::vector ctxs; - std::vector bufs; - - llama_kv_cells_unified cells; - - std::vector layers; - - // model layer id -> KV cache layer id - std::unordered_map map_layer_ids; - - // defrag - struct { - std::vector ids; - } defrag_info; - - // return true if cells have been moved - bool defrag_prepare(int32_t n_max_nodes); - - size_t total_size() const; - - size_t size_k_bytes() const; - size_t size_v_bytes() const; - - bool is_masked_swa(llama_pos p0, llama_pos p1) const; - - ggml_tensor * build_rope_shift( - const llama_cparams & cparams, - ggml_context * ctx, - ggml_tensor * cur, - ggml_tensor * shift, - ggml_tensor * factors, - float freq_base, - float freq_scale) const; - - llm_graph_result_ptr build_graph_shift( - const llama_cparams & cparams, - ggml_context * ctx, - ggml_cgraph * gf) const; - - llm_graph_result_ptr build_graph_defrag( - const llama_cparams & cparams, - ggml_context * ctx, - ggml_cgraph * gf) const; - - void state_write_meta(llama_io_write_i & io, const std::vector> & cell_ranges, llama_seq_id seq_id = -1) const; - void state_write_data(llama_io_write_i & io, const std::vector> & cell_ranges) const; - - bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1); - bool state_read_data(llama_io_read_i & io, uint32_t cell_count); -}; - -class llama_kv_cache_unified_state : public llama_memory_state_i { -public: - // used for errors - llama_kv_cache_unified_state(llama_memory_status status); - - // used to create a full-cache state - llama_kv_cache_unified_state( - llama_memory_status status, - llama_kv_cache_unified * kv); - - // used to create a state from a batch - llama_kv_cache_unified_state( - llama_memory_status status, - llama_kv_cache_unified * kv, - llama_sbatch sbatch, - std::vector heads, - std::vector ubatches); - - virtual ~llama_kv_cache_unified_state(); - - // - // llama_memory_state_i - // - - 
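// usage sketch (illustrative, not part of this diff): a caller drives this
// interface by applying each prepared ubatch in turn; the graph step is
// elided:
//
//   auto mstate = kv->init_batch(batch, n_ubatch, embd_pooled, logits_all);
//   if (mstate->get_status() == LLAMA_MEMORY_STATUS_SUCCESS) {
//       do {
//           mstate->apply();                      // place the current ubatch in the cache
//           const auto & ubatch = mstate->get_ubatch();
//           // ... build and compute the graph for ubatch ...
//       } while (mstate->next());                 // advance; returns false when done
//   }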
bool next() override; - bool apply() override; - - std::vector & out_ids() override; - - llama_memory_status get_status() const override; - const llama_ubatch & get_ubatch() const override; - - // - // llama_kv_cache_unified_state specific API - // - - uint32_t get_n_kv() const; - - // get views of the current state of the cache - ggml_tensor * get_k(ggml_context * ctx, int32_t il) const; - ggml_tensor * get_v(ggml_context * ctx, int32_t il) const; - - // store k_cur and v_cur in the cache based on the provided head location - ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const; - ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const; - - void set_input_k_shift(ggml_tensor * dst) const; - - void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const; - void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const; - -private: - const llama_memory_status status; - - llama_kv_cache_unified * kv; - - llama_sbatch sbatch; - - // the index of the next ubatch to process - size_t i_next = 0; - - std::vector heads; - std::vector ubatches; - - // - // data needed for building the compute graph for the current ubatch: - // - - // a heuristic, to avoid attending the full cache if it is not yet utilized - // as the cache gets filled, the benefit from this heuristic disappears - int32_t n_kv; - - // the beginning of the current slot in which the ubatch will be inserted - int32_t head; -}; - -// -// llama_kv_cache_unified_iswa -// - -// utilizes two instances of llama_kv_cache_unified -// the first instance is for the non-SWA layers of the model and the second instance is for the SWA layers - -class llama_kv_cache_unified_iswa : public llama_kv_cache { -public: - llama_kv_cache_unified_iswa( - const llama_model & model, - ggml_type type_k, - ggml_type type_v, - bool v_trans, - bool offload, - bool swa_full, - uint32_t kv_size, - uint32_t n_seq_max, - uint32_t n_ubatch, - uint32_t n_pad); - - ~llama_kv_cache_unified_iswa() = default; - - // - // llama_memory_i - // - - void clear() override; - - bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; - void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; - void seq_keep(llama_seq_id seq_id) override; - void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override; - void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override; - - llama_pos seq_pos_min(llama_seq_id seq_id) const override; - llama_pos seq_pos_max(llama_seq_id seq_id) const override; - - // - // llama_kv_cache - // - - llama_memory_state_ptr init_batch( - const llama_batch & batch, - uint32_t n_ubatch, - bool embd_pooled, - bool logits_all) override; - - llama_memory_state_ptr init_full() override; - - bool update(llama_context & lctx) override; - - void defrag_sched(float thold) override; - - bool get_can_shift() const override; - - // state write/load - - void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override; - void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override; - - // - // llama_kv_cache_unified_iswa specific API - // - - llama_kv_cache_unified * get_base() const; - llama_kv_cache_unified * get_swa () const; - -private: - const llama_hparams & hparams; - - std::unique_ptr kv_base; - std::unique_ptr kv_swa; -}; - -class llama_kv_cache_unified_iswa_state : public llama_memory_state_i { -public: - // used for errors - 
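// selection sketch (illustrative, not part of this diff): graph-build code
// would pick the child state per layer, e.g. (hypothetical call site):
//
//   const auto * st = hparams.is_swa(il) ? mstate->get_swa() : mstate->get_base();
//   ggml_tensor * k = st->get_k(ctx0, il); // K view from the matching cache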
llama_kv_cache_unified_iswa_state(llama_memory_status status); - - // used to create a full-cache state - llama_kv_cache_unified_iswa_state( - llama_memory_status status, - llama_kv_cache_unified_iswa * kv); - - // used to create a state from a batch - llama_kv_cache_unified_iswa_state( - llama_memory_status status, - llama_kv_cache_unified_iswa * kv, - llama_sbatch sbatch, - std::vector heads_base, - std::vector heads_swa, - std::vector ubatches); - - virtual ~llama_kv_cache_unified_iswa_state(); - - // - // llama_memory_state_i - // - - bool next() override; - bool apply() override; - - std::vector & out_ids() override; - - llama_memory_status get_status() const override; - const llama_ubatch & get_ubatch() const override; - - // - // llama_kv_cache_unified_iswa_state specific API - // - - const llama_kv_cache_unified_state * get_base() const; - const llama_kv_cache_unified_state * get_swa() const; - -private: - const llama_memory_status status; - - //llama_kv_cache_unified_iswa * kv; - - llama_sbatch sbatch; - - // the index of the next ubatch to process - size_t i_next = 0; - - std::vector ubatches; - - std::unique_ptr state_base; - std::unique_ptr state_swa; -}; - -// -// llama_kv_cache_recurrent -// - -// TODO: extract the KV cache state used for graph computation into llama_kv_cache_recurrent_state_i -// see the implementation of llama_kv_cache_unified_state_i for an example how to do it -class llama_kv_cache_recurrent : public llama_kv_cache { -public: - llama_kv_cache_recurrent( - const llama_model & model, - ggml_type type_k, - ggml_type type_v, - bool offload, - uint32_t kv_size, - uint32_t n_seq_max); - - ~llama_kv_cache_recurrent() = default; - - // - // llama_memory_i - // - - void clear() override; - - bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; - void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; - void seq_keep(llama_seq_id seq_id) override; - void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override; - void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override; - - llama_pos seq_pos_min(llama_seq_id seq_id) const override; - llama_pos seq_pos_max(llama_seq_id seq_id) const override; - - // - // llama_kv_cache - // - - llama_memory_state_ptr init_batch( - const llama_batch & batch, - uint32_t n_ubatch, - bool embd_pooled, - bool logits_all) override; - - llama_memory_state_ptr init_full() override; - - bool update(llama_context & lctx) override; - - void defrag_sched(float thold) override; - - bool prepare(const std::vector & ubatches); - - // find a contiguous slot of kv cells and emplace the ubatch there - bool find_slot(const llama_ubatch & ubatch); - - bool get_can_shift() const override; - - // TODO: temporary methods - they are not really const as they do const_cast<>, fix this - int32_t s_copy(int i) const; - float s_mask(int i) const; - - // state write/load - - void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override; - void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override; - - uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot()) - uint32_t size = 0; // total number of cells, shared across all sequences - uint32_t used = 0; // used cells (i.e. 
at least one seq_id) - - // computed before each graph build - uint32_t n = 0; - - // TODO: optimize for recurrent state needs - struct kv_cell { - llama_pos pos = -1; - int32_t src = -1; // used to copy states - int32_t tail = -1; - - std::set seq_id; - - bool has_seq_id(const llama_seq_id & id) const { - return seq_id.find(id) != seq_id.end(); - } - - bool is_empty() const { - return seq_id.empty(); - } - - bool is_same_seq(const kv_cell & other) const { - return seq_id == other.seq_id; - } - }; - - std::vector cells; - - std::vector k_l; // per layer - std::vector v_l; - -private: - //const llama_model & model; - const llama_hparams & hparams; - - const uint32_t n_seq_max = 1; - - std::vector ctxs; - std::vector bufs; - - size_t total_size() const; - - size_t size_k_bytes() const; - size_t size_v_bytes() const; - - void state_write_meta(llama_io_write_i & io, const std::vector> & cell_ranges, llama_seq_id seq_id = -1) const; - void state_write_data(llama_io_write_i & io, const std::vector> & cell_ranges) const; - - bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1); - bool state_read_data(llama_io_read_i & io, uint32_t cell_count); -}; - -class llama_kv_cache_recurrent_state : public llama_memory_state_i { -public: - // used for errors - llama_kv_cache_recurrent_state(llama_memory_status status); - - // used to create a full-cache state - llama_kv_cache_recurrent_state( - llama_memory_status status, - llama_kv_cache_recurrent * kv); - - // used to create a state from a batch - llama_kv_cache_recurrent_state( - llama_memory_status status, - llama_kv_cache_recurrent * kv, - llama_sbatch sbatch, - std::vector ubatches); - - virtual ~llama_kv_cache_recurrent_state(); - - // - // llama_memory_state_i - // - - bool next() override; - bool apply() override; - - std::vector & out_ids() override; - - llama_memory_status get_status() const override; - const llama_ubatch & get_ubatch() const override; - - // - // llama_kv_cache_recurrent_state specific API - // - - uint32_t get_n_kv() const; - uint32_t get_head() const; - uint32_t get_size() const; - - ggml_tensor * get_k_l(int32_t il) const; - ggml_tensor * get_v_l(int32_t il) const; - - int32_t s_copy(int i) const; - float s_mask(int i) const; - -private: - const llama_memory_status status; - - llama_kv_cache_recurrent * kv; - - llama_sbatch sbatch; - - size_t i_next = 0; - - std::vector ubatches; - - // - // data needed for building the compute graph for the current ubatch: - // TODO: extract all the state like `head` and `n` here - // - - const bool is_full = false; -}; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 44c07ff457dce..afef8487030fb 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -5,7 +5,10 @@ #include "llama-batch.h" #include "llama-cparams.h" #include "llama-model-loader.h" -#include "llama-kv-cache.h" + +#include "llama-kv-cache-unified.h" +#include "llama-kv-cache-unified-iswa.h" +#include "llama-kv-cache-recurrent.h" #include "ggml-cpp.h" @@ -953,6 +956,11 @@ void llama_model::load_hparams(llama_model_loader & ml) { case 46: type = LLM_TYPE_27B; break; default: type = LLM_TYPE_UNKNOWN; } + + // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/config.py#L173 + hparams.f_attention_scale = type == LLM_TYPE_27B + ? 
diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp
index d5a036a8c4413..b51976699ca7b 100644
--- a/src/llama-vocab.cpp
+++ b/src/llama-vocab.cpp
@@ -2080,9 +2080,11 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
 
         std::string model_name;
         std::string tokenizer_pre;
+        std::string general_arch;
 
         ml.get_key(LLM_KV_GENERAL_NAME,  model_name,    false);
         ml.get_key(LLM_KV_TOKENIZER_PRE, tokenizer_pre, false);
+        ml.get_key(LLM_KV_GENERAL_ARCHITECTURE, general_arch, false);
 
         // model name to lowercase
         std::transform(model_name.begin(), model_name.end(), model_name.begin(),
@@ -2091,8 +2093,11 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
             }
         );
 
-        // set attributes by model/tokenizer name
-        if (_contains_any(tokenizer_pre, {"jina-v2-de", "jina-v2-es", "jina-v2-code"})) {
+        // set attributes by model/tokenizer/architecture name
+        if (false
+                || _contains_any(tokenizer_pre, {"jina-v2-de", "jina-v2-es", "jina-v2-code"})
+                || _contains_any(general_arch, {"nomic-bert-moe"})
+           ) {
             _set_token_attr("<mask>", LLAMA_TOKEN_ATTR_LSTRIP, true);
         } else if (_contains_any(model_name, {"phi-3", "phi3"})) {
             for (auto id : cache_special_tokens) {
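For reference, `_contains_any` above is a substring scan over a list of needles, and the `if (false || ...)` shape simply keeps each condition on its own line so further architectures can be appended. The sketch below is an illustrative reimplementation of the needle-scan idea, not a copy of the vocab code:

    #include <string>
    #include <vector>

    // Returns true if `str` contains any of the given needles.
    // Illustrative only; the real helper in llama-vocab.cpp may differ in details.
    static bool contains_any(const std::string & str, const std::vector<std::string> & needles) {
        for (const auto & needle : needles) {
            if (str.find(needle) != std::string::npos) {
                return true;
            }
        }
        return false;
    }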
diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp
index 1c98079217235..c6d998f101912 100644
--- a/tests/test-chat.cpp
+++ b/tests/test-chat.cpp
@@ -19,8 +19,8 @@ using json = nlohmann::ordered_json;
 
 static std::ostream & operator<<(std::ostream & os, const common_chat_msg_diff & diff) {
-    // os << "reasoning_content_delta: " << diff.reasoning_content_delta << '\n';
     os << "{ content_delta: " << diff.content_delta << "; ";
+    os << "reasoning_content_delta: " << diff.reasoning_content_delta << "; ";
     if (diff.tool_call_index != std::string::npos) {
         os << "tool_call_index: " << diff.tool_call_index << "; ";
         os << "tool_call_delta.name: " << diff.tool_call_delta.name << "; ";
diff --git a/tests/test-gguf.cpp b/tests/test-gguf.cpp
index eaf572c666410..3f0c312e2f003 100644
--- a/tests/test-gguf.cpp
+++ b/tests/test-gguf.cpp
@@ -16,6 +16,7 @@ constexpr int offset_has_data = 3000;
 
 enum handcrafted_file_type {
     HANDCRAFTED_HEADER_BAD_MAGIC          = 10,
+    HANDCRAFTED_HEADER_BAD_VERSION_0      = 15,
    HANDCRAFTED_HEADER_BAD_VERSION_1      = 20,
     HANDCRAFTED_HEADER_BAD_VERSION_FUTURE = 30,
     HANDCRAFTED_HEADER_BAD_N_TENSORS      = 40,
@@ -51,6 +52,7 @@ enum handcrafted_file_type {
 static std::string handcrafted_file_type_name(const enum handcrafted_file_type hft) {
     switch (hft) {
         case HANDCRAFTED_HEADER_BAD_MAGIC:          return "HEADER_BAD_MAGIC";
+        case HANDCRAFTED_HEADER_BAD_VERSION_0:      return "HEADER_BAD_VERSION_0";
         case HANDCRAFTED_HEADER_BAD_VERSION_1:      return "HEADER_BAD_VERSION_1";
         case HANDCRAFTED_HEADER_BAD_VERSION_FUTURE: return "HEADER_BAD_VERSION_FUTURE";
         case HANDCRAFTED_HEADER_BAD_N_KV:           return "HEADER_BAD_N_KV";
@@ -171,7 +173,10 @@ static FILE * get_handcrafted_file(const unsigned int seed, const enum handcraft
         helper_write(file, GGUF_MAGIC, 4);
     }
 
-    if (hft == HANDCRAFTED_HEADER_BAD_VERSION_1) {
+    if (hft == HANDCRAFTED_HEADER_BAD_VERSION_0) {
+        const uint32_t version = 0;
+        helper_write(file, version);
+    } else if (hft == HANDCRAFTED_HEADER_BAD_VERSION_1) {
         const uint32_t version = 1;
         helper_write(file, version);
     } else if (hft == HANDCRAFTED_HEADER_BAD_VERSION_FUTURE) {
@@ -660,6 +665,7 @@ static std::pair<int, int> test_handcrafted_file(const unsigned int seed) {
     const std::vector<handcrafted_file_type> hfts = {
         HANDCRAFTED_HEADER_BAD_MAGIC,
+        HANDCRAFTED_HEADER_BAD_VERSION_0,
         HANDCRAFTED_HEADER_BAD_VERSION_1,
         HANDCRAFTED_HEADER_BAD_VERSION_FUTURE,
         HANDCRAFTED_HEADER_BAD_N_KV,
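The new `HANDCRAFTED_HEADER_BAD_VERSION_0` case feeds the reader a GGUF file whose version field is 0, which must be rejected just like the long-unsupported version 1. A GGUF file begins with the 4-byte magic `GGUF` followed by a `uint32` version; a minimal sketch of the corresponding check (illustrative, assuming a little-endian host, not the actual gguf reader):

    #include <cstdint>
    #include <cstdio>
    #include <cstring>

    // Sketch of the header validation the new test case targets.
    static bool check_gguf_header(FILE * file) {
        char magic[4];
        if (std::fread(magic, 1, 4, file) != 4 || std::memcmp(magic, "GGUF", 4) != 0) {
            return false; // corresponds to HEADER_BAD_MAGIC
        }
        uint32_t version = 0;
        if (std::fread(&version, sizeof(version), 1, file) != 1) {
            return false;
        }
        // version 0 was never valid; version 1 (32-bit counts) is no longer supported
        return version != 0 && version != 1;
    }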
diff --git a/tools/mtmd/mtmd-cli.cpp b/tools/mtmd/mtmd-cli.cpp
index 508a64c586de1..40deab5ab00a8 100644
--- a/tools/mtmd/mtmd-cli.cpp
+++ b/tools/mtmd/mtmd-cli.cpp
@@ -70,6 +70,7 @@ struct mtmd_cli_context {
     llama_model       * model;
     llama_context     * lctx;
     const llama_vocab * vocab;
+    common_sampler    * smpl;
     llama_batch         batch;
     int                 n_batch;
 
@@ -89,8 +90,9 @@ struct mtmd_cli_context {
         model = llama_init.model.get();
         lctx = llama_init.context.get();
         vocab = llama_model_get_vocab(model);
+        smpl = common_sampler_init(model, params.sampling);
         n_threads = params.cpuparams.n_threads;
-        batch = llama_batch_init(params.n_batch, 0, 1);
+        batch = llama_batch_init(1, 0, 1); // batch for next token generation
         n_batch = params.n_batch;
 
         if (!model || !lctx) {
@@ -118,6 +120,11 @@ struct mtmd_cli_context {
         }
     }
 
+    ~mtmd_cli_context() {
+        llama_batch_free(batch);
+        common_sampler_free(smpl);
+    }
+
     void init_vision_context(common_params & params) {
         const char * clip_path = params.mmproj.path.c_str();
         mtmd_context_params mparams = mtmd_context_params_default();
@@ -153,7 +160,7 @@ struct mtmd_cli_context {
     }
 };
 
-static int generate_response(mtmd_cli_context & ctx, common_sampler * smpl, int n_predict) {
+static int generate_response(mtmd_cli_context & ctx, int n_predict) {
     llama_tokens generated_tokens;
     for (int i = 0; i < n_predict; i++) {
         if (i > n_predict || !g_is_generating || g_is_interrupted) {
@@ -161,9 +168,9 @@ static int generate_response(mtmd_cli_context & ctx, common_sampler * smpl, int
             break;
         }
 
-        llama_token token_id = common_sampler_sample(smpl, ctx.lctx, -1);
+        llama_token token_id = common_sampler_sample(ctx.smpl, ctx.lctx, -1);
         generated_tokens.push_back(token_id);
-        common_sampler_accept(smpl, token_id, true);
+        common_sampler_accept(ctx.smpl, token_id, true);
 
         if (llama_vocab_is_eog(ctx.vocab, token_id) || ctx.check_antiprompt(generated_tokens)) {
             LOG("\n");
@@ -261,7 +268,6 @@ int main(int argc, char ** argv) {
 
     bool is_single_turn = !params.prompt.empty() && !params.image.empty();
 
-    struct common_sampler * smpl = common_sampler_init(ctx.model, params.sampling);
     int n_predict = params.n_predict < 0 ? INT_MAX : params.n_predict;
 
     // Ctrl+C handling
@@ -300,7 +306,7 @@ int main(int argc, char ** argv) {
         if (eval_message(ctx, msg, true)) {
             return 1;
         }
-        if (!g_is_interrupted && generate_response(ctx, smpl, n_predict)) {
+        if (!g_is_interrupted && generate_response(ctx, n_predict)) {
             return 1;
         }
 
@@ -366,7 +372,7 @@ int main(int argc, char ** argv) {
                 return 1;
             }
             if (g_is_interrupted) break;
-            if (generate_response(ctx, smpl, n_predict)) {
+            if (generate_response(ctx, n_predict)) {
                 return 1;
             }
             content.clear();
diff --git a/tools/mtmd/mtmd-helper.cpp b/tools/mtmd/mtmd-helper.cpp
index 64f03fd1e7eb2..686f42f3960fe 100644
--- a/tools/mtmd/mtmd-helper.cpp
+++ b/tools/mtmd/mtmd-helper.cpp
@@ -311,6 +311,7 @@ int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx,
         GGML_ABORT("chunk type not supported");
     }
 
+    llama_batch_free(text_batch);
     return 0;
 }
diff --git a/tools/server/server.cpp b/tools/server/server.cpp
index 4b92eeac9499b..9038df4c3830e 100644
--- a/tools/server/server.cpp
+++ b/tools/server/server.cpp
@@ -360,7 +360,7 @@ struct server_task {
                 params.oaicompat_chat_syntax.format = defaults.oaicompat_chat_syntax.format;
             }
             params.oaicompat_chat_syntax.reasoning_format = params_base.reasoning_format;
-            params.oaicompat_chat_syntax.reasoning_in_content = params.stream;
+            params.oaicompat_chat_syntax.reasoning_in_content = params.stream && (params_base.reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY);
             params.oaicompat_chat_syntax.thinking_forced_open = json_value(data, "thinking_forced_open", false);
             params.oaicompat_chat_syntax.parse_tool_calls = json_value(data, "parse_tool_calls", false);
         }
@@ -2016,6 +2016,11 @@ struct server_context {
                 params_base.n_cache_reuse = 0;
                 SRV_WRN("%s\n", "cache_reuse is not supported by this context, it will be disabled");
             }
+
+            if (!params_base.speculative.model.path.empty()) {
+                SRV_ERR("%s\n", "err: speculative decode is not supported by this context");
+                return false;
+            }
         }
 
         return true;
@@ -3203,9 +3208,7 @@ struct server_context {
                     }
                 } else {
                     // if we don't cache the prompt, we have to remove the entire KV cache
-                    llama_kv_self_seq_rm(ctx, slot.id, 0, -1);
                     slot.n_past = 0;
-                    slot.cache_tokens.clear(); // TODO: not needed, will be cleared later via "keep_first()"
                 }
 
                 if (slot.n_past > 0 && slot.n_past < (int) slot.cache_tokens.size()) {
@@ -3220,7 +3223,6 @@ struct server_context {
                         SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min, n_swa);
                         SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA, see %s)\n", "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
-                        llama_kv_self_seq_rm(ctx, slot.id, 0, -1);
                         slot.n_past = 0;
                     }
                 }
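The mtmd-cli change above moves the sampler into `mtmd_cli_context` and adds a destructor so the batch and sampler are released together with the context instead of leaking at exit. The same ownership could alternatively be expressed in the type system with `std::unique_ptr` and a custom deleter; a sketch of that design (the declarations below stand in for the real llama.cpp API, shown only to illustrate the idea):

    #include <memory>

    // Hypothetical stand-ins for the C API used by mtmd-cli.
    struct common_sampler;
    void common_sampler_free(common_sampler * smpl);

    // RAII alternative: ownership lives in the type, no hand-written destructor needed.
    struct common_sampler_deleter {
        void operator()(common_sampler * smpl) const { common_sampler_free(smpl); }
    };
    using common_sampler_ptr = std::unique_ptr<common_sampler, common_sampler_deleter>;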
"bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None), - (128, None, CompletionMode.NORMAL, None, "^The sum of 102 and 7 is 109[\\s\\S]*", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None), - (1024, 'deepseek', CompletionMode.NORMAL, "I need to calculate the sum of 102 and 7[\\s\\S]*", "To find the sum of[\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), - (1024, 'deepseek', CompletionMode.STREAMED, None, "^I need to calculate [\\s\\S]*?To find the sum of [\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), - (1024, 'deepseek', CompletionMode.NORMAL, "First, I [\\s\\S]*", "To find the sum of[\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)), - (1024, 'deepseek', CompletionMode.STREAMED, None, "^First, I [\\s\\S]*?To find the sum of[\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)), +@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED]) +@pytest.mark.parametrize("n_predict,reasoning_format,expect_reasoning_content,expect_content,hf_repo,template_override", [ + (128, 'deepseek', None, "^The sum of 102 and 7 is 109[\\s\\S]*", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None), + (128, None, None, "^The sum of 102 and 7 is 109[\\s\\S]*", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None), + (1024, 'deepseek', "I need to calculate the sum of 102 and 7[\\s\\S]*", "To find the sum of[\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), + (1024, 'deepseek', "First, I [\\s\\S]*", "To find the sum of[\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)), # (1024, 'none', CompletionMode.NORMAL, None, "^(\\s*)?I need[\\s\\S]*?\\s*To find[\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), # (128, 'deepseek', None, "^Okay, let me figure out the sum of 102 and 7[\\s\\S]*", "bartowski/Qwen_QwQ-32B-GGUF:Q4_K_M", None), ]) diff --git a/tools/server/tests/utils.py b/tools/server/tests/utils.py index f7e1b3b3b7b8e..bc547ca03bf1b 100644 --- a/tools/server/tests/utils.py +++ b/tools/server/tests/utils.py @@ -308,10 +308,12 @@ def make_any_request( stream = data.get('stream', False) if stream: content: list[str] = [] + reasoning_content: list[str] = [] tool_calls: list[dict] = [] finish_reason: Optional[str] = None content_parts = 0 + reasoning_content_parts = 0 tool_call_parts = 0 arguments_parts = 0 @@ -322,6 +324,10 @@ def make_any_request( assert len(choice['delta']['content']) > 0, f'Expected non empty content delta!' content.append(choice['delta']['content']) content_parts += 1 + if choice['delta'].get('reasoning_content') is not None: + assert len(choice['delta']['reasoning_content']) > 0, f'Expected non empty reasoning_content delta!' + reasoning_content.append(choice['delta']['reasoning_content']) + reasoning_content_parts += 1 if choice['delta'].get('finish_reason') is not None: finish_reason = choice['delta']['finish_reason'] for tc in choice['delta'].get('tool_calls', []): @@ -349,8 +355,10 @@ def make_any_request( tool_call['function']['name'] = tool_call['function'].get('name', '') + fct['name'] if fct.get('arguments') is not None: tool_call['function']['arguments'] += fct['arguments'] + arguments_parts += 1 + tool_call_parts += 1 - print(f'Streamed response had {content_parts} content parts, {tool_call_parts} tool call parts incl. 
diff --git a/tools/server/tests/utils.py b/tools/server/tests/utils.py
index f7e1b3b3b7b8e..bc547ca03bf1b 100644
--- a/tools/server/tests/utils.py
+++ b/tools/server/tests/utils.py
@@ -308,10 +308,12 @@ def make_any_request(
         stream = data.get('stream', False)
         if stream:
             content: list[str] = []
+            reasoning_content: list[str] = []
             tool_calls: list[dict] = []
             finish_reason: Optional[str] = None
 
             content_parts = 0
+            reasoning_content_parts = 0
             tool_call_parts = 0
             arguments_parts = 0
 
@@ -322,6 +324,10 @@
                     assert len(choice['delta']['content']) > 0, f'Expected non empty content delta!'
                     content.append(choice['delta']['content'])
                     content_parts += 1
+                if choice['delta'].get('reasoning_content') is not None:
+                    assert len(choice['delta']['reasoning_content']) > 0, f'Expected non empty reasoning_content delta!'
+                    reasoning_content.append(choice['delta']['reasoning_content'])
+                    reasoning_content_parts += 1
                 if choice['delta'].get('finish_reason') is not None:
                     finish_reason = choice['delta']['finish_reason']
                 for tc in choice['delta'].get('tool_calls', []):
@@ -349,8 +355,10 @@
                             tool_call['function']['name'] = tool_call['function'].get('name', '') + fct['name']
                         if fct.get('arguments') is not None:
                             tool_call['function']['arguments'] += fct['arguments']
+                            arguments_parts += 1
+                    tool_call_parts += 1
 
-            print(f'Streamed response had {content_parts} content parts, {tool_call_parts} tool call parts incl. {arguments_parts} arguments parts')
+            print(f'Streamed response had {content_parts} content parts, {reasoning_content_parts} reasoning_content parts, {tool_call_parts} tool call parts incl. {arguments_parts} arguments parts')
             result = dict(
                 choices=[
                     dict(
@@ -359,6 +367,7 @@
                         message=dict(
                             role='assistant',
                             content=''.join(content) if content else None,
+                            reasoning_content=''.join(reasoning_content) if reasoning_content else None,
                             tool_calls=tool_calls if tool_calls else None,
                         ),
                     )
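The updated test helper does what any OpenAI-compatible streaming client must: concatenate `content` and `reasoning_content` deltas in arrival order and assemble the final message. A compact C++ sketch of the same accumulation using nlohmann::json (the delta payloads below are made-up examples in the shape the server emits, not captured output):

    #include <nlohmann/json.hpp>
    #include <iostream>
    #include <string>
    #include <vector>

    using json = nlohmann::json;

    int main() {
        // Made-up stream of OAI-compatible deltas, in arrival order.
        const std::vector<json> deltas = {
            json{{"reasoning_content", "First, I add the numbers: "}},
            json{{"reasoning_content", "102 + 7 = 109."}},
            json{{"content", "The sum is "}},
            json{{"content", "109."}},
        };

        std::string reasoning_content;
        std::string content;
        for (const auto & delta : deltas) {
            if (delta.contains("reasoning_content")) {
                reasoning_content += delta.at("reasoning_content").get<std::string>();
            }
            if (delta.contains("content")) {
                content += delta.at("content").get<std::string>();
            }
        }

        std::cout << "reasoning_content: " << reasoning_content << "\n";
        std::cout << "content:           " << content << "\n";
        return 0;
    }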