diff --git a/common/chat-parser.h b/common/chat-parser.h
index 7c660e539..1e7a3f949 100644
--- a/common/chat-parser.h
+++ b/common/chat-parser.h
@@ -24,9 +24,9 @@ class common_chat_msg_parser {
         std::string prelude;
         std::vector<common_string_range> groups;
     };
-
+
     common_chat_msg_parser(const std::string & input, bool is_partial, const common_chat_syntax & syntax);
-
+
     // Accessors
     const std::string & input() const { return input_; }
     size_t pos() const { return pos_; }
@@ -42,7 +42,7 @@ class common_chat_msg_parser {
         }
         pos_ = pos;
     }
-
+
     void move_back(size_t n) {
         if (pos_ < n) {
             throw std::runtime_error("Can't move back that far!");
         }
@@ -56,46 +56,46 @@ class common_chat_msg_parser {
     // Content manipulation
     void add_content(const std::string & content);
     void add_reasoning_content(const std::string & reasoning_content);
-
+
     // Tool call manipulation
     void add_tool_call(const common_chat_tool_call & tool_call);
     bool add_tool_call(const std::string & name, const std::string & id, const std::string & arguments);
     bool add_tool_call(const json & tool_call);
     bool add_tool_calls(const json & arr);
     void clear_tools();
-
+
     // Parsing utilities
     std::string consume_rest();
     bool try_consume_literal(const std::string & literal);
     void consume_literal(const std::string & literal);
     bool try_parse_reasoning(const std::string & start_think, const std::string & end_think);
-
+
     // Regex-based parsing methods (new)
     std::optional<find_regex_result> try_find_regex(const common_regex & regex, size_t from = std::string::npos, bool add_prelude_to_content = true);
     find_regex_result consume_regex(const common_regex & regex);
     std::optional<find_regex_result> try_consume_regex(const common_regex & regex);
-
+
     // Progressive parsing primitives (for Phase 4)
     std::optional<find_regex_result> try_find_literal(const std::string & literal);
     bool consume_spaces();
     void set_healing_marker(const std::string & marker);
-
-
+
+
     // Main parsing entry point
     void parse();
-
+
     // Finishing
     void finish();
-
+
     // Result extraction
     common_chat_msg result_and_reset();
-
+
     // Advanced JSON parsing (following original llama.cpp patterns)
     struct consume_json_result {
         json value;
         bool is_partial;
     };
-
+
     std::optional<common_json> try_consume_json();
     common_json consume_json();
     consume_json_result consume_json_with_dumped_args(
@@ -112,8 +112,8 @@ class common_chat_msg_parser {
     void parse_kimi_k2_format();
     void parse_deepseek_r1_format();
     void parse_generic_format();
-
-
+
+
     // JSON parsing utilities (enhanced streaming support)
     struct json_parse_result {
         json value;
@@ -121,11 +121,11 @@ class common_chat_msg_parser {
         bool is_partial;
         std::string healing_marker;
     };
-
+
     // Partial detection utilities
     bool detect_partial_function_call(const std::string& content);
     void handle_partial_detection();
-
+
     // Legacy find_literal for compatibility
     std::optional<find_regex_result> try_find_literal_legacy(const std::string & literal);
 };
@@ -133,4 +133,4 @@ class common_chat_msg_parser {
 // Main parsing function (public API)
 common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax);
-// Content-only parsing for fallback scenarios (static internal function)
\ No newline at end of file
+// Content-only parsing for fallback scenarios (static internal function)
diff --git a/common/chat.cpp b/common/chat.cpp
index f62c28011..8be086333 100644
--- a/common/chat.cpp
+++ b/common/chat.cpp
@@ -220,7 +220,7 @@ void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) {
     // Check for the new tools array format first (no DeepSeek markers)
     auto original_pos = builder.pos();
-
+
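(Reviewer note: a minimal sketch of the shape the tools-array branch below is matching, built from the parser primitives declared in chat-parser.h above. The positioning/healing semantics are assumed to follow upstream llama.cpp, and `sketch_parse` is a hypothetical illustration, not code from this patch.)

```cpp
// Expected model output: the literal prefix "function" + newline + a fenced
// json block whose object carries a "tools" array, e.g.
//   {"tools": [{"name": "get_weather", "arguments": {"city": "Paris"}}]}
static const common_regex prefix("function\n```json\n");

void sketch_parse(common_chat_msg_parser & builder) {
    if (auto res = builder.try_find_regex(prefix)) { // leaves pos() just past the prefix
        auto j = builder.try_consume_json();         // tolerates truncated JSON when is_partial
        if (!j) {
            throw common_chat_msg_partial_exception("invalid JSON");
        }
        // j->json.at("tools") is then iterated and each "arguments" object is
        // re-serialized with dump(), since common_chat_tool_call::arguments
        // is a string, not a json object.
    }
}
```

     // First, try the tools array format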
for content like "function\n```json\n{"tools": [...]}" if (builder.try_find_regex(function_regex_simple)) { builder.move_to(original_pos); @@ -231,7 +231,7 @@ void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) { // Fall through to try standard DeepSeek patterns } } - + // If tools array format didn't work, try XML-wrapped format builder.move_to(original_pos); try { @@ -240,7 +240,7 @@ void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) { } catch (const common_chat_msg_partial_exception&) { // Fall through to try standard DeepSeek patterns } - + // If XML wrapper format didn't work, try standard DeepSeek patterns builder.move_to(original_pos); try { @@ -278,7 +278,7 @@ void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) { throw; // Re-throw for partial mode } } - + // Add any remaining content (critical for responses without tool calls) builder.add_content(builder.consume_rest()); } @@ -286,19 +286,19 @@ void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) { // Parse DeepSeek R1 tools array format following original llama.cpp parse_prefixed_json_tool_call_array pattern static void parse_deepseek_r1_tools_array(common_chat_msg_parser & builder) { static const common_regex prefix("function\n```json\n"); - - + + if (auto res = builder.try_find_regex(prefix)) { // Parse JSON and manually process tools array to convert arguments to strings auto json_result = builder.try_consume_json(); if (!json_result) { throw common_chat_msg_partial_exception("invalid JSON"); } - - + + // DeepSeek R1 format has "tools" array, manually process each tool if (json_result->json.contains("tools") && json_result->json.at("tools").is_array()) { - + // Manually create tool calls array with string arguments (following original pattern) json tools_with_dumped_args = json::array(); for (const auto& tool : json_result->json.at("tools")) { @@ -310,15 +310,15 @@ static void parse_deepseek_r1_tools_array(common_chat_msg_parser & builder) { tools_with_dumped_args.push_back(formatted_tool); } } - - + + if (!builder.add_tool_calls(tools_with_dumped_args) || !json_result->healing_marker.marker.empty()) { throw common_chat_msg_partial_exception("incomplete tool call array"); } } else { throw common_chat_msg_partial_exception("tools key not found or not array"); } - + // Consume closing ``` builder.try_consume_regex(common_regex("```")); } else { @@ -326,41 +326,41 @@ static void parse_deepseek_r1_tools_array(common_chat_msg_parser & builder) { } } -// Parse DeepSeek R1 XML-wrapped format following original Hermes-2-Pro pattern +// Parse DeepSeek R1 XML-wrapped format following original Hermes-2-Pro pattern static void parse_deepseek_r1_xml_wrapped(common_chat_msg_parser & builder) { - + // Pattern for: \nfunctionFunctionName\n```json\n{...}\n```\n static const common_regex xml_pattern( "\\s*" // Opening XML tag - "function([^\\n]+)" // Function name after "function" + "function([^\\n]+)" // Function name after "function" "\\s*```json\\s*" // JSON block start ); - + if (auto res = builder.try_find_regex(xml_pattern)) { - + // Extract function name from capture group std::string function_name = builder.str(res->groups[1]); - + // Parse JSON arguments auto json_result = builder.try_consume_json(); if (!json_result) { throw common_chat_msg_partial_exception("invalid JSON in XML wrapper"); } - - + + // Create single tool call following original pattern json tool_call; tool_call["name"] = function_name; tool_call["arguments"] = json_result->json.dump(); // Convert to 
string
-
+
         json tool_calls_array = json::array();
         tool_calls_array.push_back(tool_call);
-
-
+
+
         if (!builder.add_tool_calls(tool_calls_array) || !json_result->healing_marker.marker.empty()) {
             throw common_chat_msg_partial_exception("incomplete XML wrapped tool call");
         }
-
+
         // Consume closing ```\n
         builder.try_consume_regex(common_regex("```\\s*"));
     } else {
@@ -384,6 +384,15 @@ static void common_chat_parse_kimi_k2(common_chat_msg_parser & builder) {
     builder.add_content(kimi_k2::clean_content(builder.input()));
 }
 
+static void common_chat_parse_gpt_oss(common_chat_msg_parser & builder) {
+    // TODO @ngxson : this won't work with --special enabled, we should fix that
+    builder.try_parse_reasoning("<|channel|>analysis<|message|>", "<|start|>assistant<|channel|>final<|message|>");
+    if (!builder.syntax().enable_tool_calls) {
+        builder.add_content(builder.consume_rest());
+        return;
+    }
+}
+
 // Main parsing dispatch function
 static void common_chat_parse(common_chat_msg_parser & builder) {
     switch (builder.syntax().format) {
@@ -399,6 +408,9 @@ static void common_chat_parse(common_chat_msg_parser & builder) {
         case COMMON_CHAT_FORMAT_KIMI_K2:
             common_chat_parse_kimi_k2(builder);
             break;
+        case COMMON_CHAT_FORMAT_GPT_OSS:
+            common_chat_parse_gpt_oss(builder);
+            break;
         default:
             throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format));
     }
@@ -432,6 +444,19 @@ const char* common_chat_format_name(common_chat_format format) {
         case COMMON_CHAT_FORMAT_GENERIC: return "generic";
         case COMMON_CHAT_FORMAT_DEEPSEEK_R1: return "deepseek_r1";
         case COMMON_CHAT_FORMAT_KIMI_K2: return "kimi_k2";
+        case COMMON_CHAT_FORMAT_GPT_OSS: return "gpt-oss";
         default: return "unknown";
     }
-}
\ No newline at end of file
+}
+
+const char * common_reasoning_format_name(common_reasoning_format format) {
+    switch (format) {
+        case COMMON_REASONING_FORMAT_NONE:            return "none";
+        case COMMON_REASONING_FORMAT_AUTO:            return "auto";
+        case COMMON_REASONING_FORMAT_DEEPSEEK:        return "deepseek";
+        case COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY: return "deepseek-legacy";
+        default:
+            throw std::runtime_error("Unknown reasoning format");
+    }
+}
+
diff --git a/common/chat.h b/common/chat.h
index 5899ef1a1..83e31566d 100644
--- a/common/chat.h
+++ b/common/chat.h
@@ -13,20 +13,20 @@ struct common_chat_templates;
 struct common_string_range {
     size_t begin;
     size_t end;
-
+
     common_string_range(size_t begin, size_t end) : begin(begin), end(end) {
         if (begin > end) {
             throw std::runtime_error("Invalid range");
         }
     }
-
+
     // prevent default ctor
     common_string_range() = delete;
-
+
     bool empty() const {
         return begin == end;
     }
-
+
     bool operator==(const common_string_range & other) const {
         return begin == other.begin && end == other.end;
     }
@@ -40,7 +40,7 @@ struct common_chat_tool_call {
     bool operator==(const common_chat_tool_call & other) const {
         return name == other.name && arguments == other.arguments && id == other.id;
     }
-
+
     bool operator!=(const common_chat_tool_call & other) const {
         return !(*this == other);
     }
@@ -65,10 +65,10 @@ struct common_chat_msg {
     std::string tool_call_id;
 
     bool empty() const {
-        return content.empty() && content_parts.empty() && tool_calls.empty() &&
+        return content.empty() && content_parts.empty() && tool_calls.empty() &&
             reasoning_content.empty() && tool_name.empty() && tool_call_id.empty();
     }
-
+
     void ensure_tool_call_ids_set(std::vector<std::string> & ids_cache, const std::function<std::string()> & gen_tool_call_id) {
         for (auto i = 0u; i < tool_calls.size(); i++) {
             if (ids_cache.size() <= i) {
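(Reviewer note: the new gpt-oss path above leans entirely on try_parse_reasoning to split the Harmony-style channels. A hedged usage sketch; the raw string is a hypothetical completion, and the field names are the ones declared in chat.h below.)

```cpp
const std::string raw =
    "<|channel|>analysis<|message|>User asks about weather; call the tool."
    "<|start|>assistant<|channel|>final<|message|>It is sunny in Paris.";

common_chat_syntax syntax;
syntax.format           = COMMON_CHAT_FORMAT_GPT_OSS;
syntax.reasoning_format = COMMON_REASONING_FORMAT_AUTO;

common_chat_msg msg = common_chat_parse(raw, /*is_partial=*/false, syntax);
// expected: msg.reasoning_content holds the analysis channel,
//           msg.content holds the final channel.
```

@@ -91,7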
+91,7 @@ struct common_chat_msg { && tool_name == other.tool_name && tool_call_id == other.tool_call_id; } - + bool operator!=(const common_chat_msg & other) const { return !(*this == other); } @@ -110,7 +110,7 @@ struct common_chat_msg_diff { && tool_call_index == other.tool_call_index && tool_call_delta == other.tool_call_delta; } - + bool operator!=(const common_chat_msg_diff & other) const { return !(*this == other); } @@ -132,18 +132,20 @@ enum common_chat_format { COMMON_CHAT_FORMAT_CONTENT_ONLY, COMMON_CHAT_FORMAT_GENERIC, COMMON_CHAT_FORMAT_DEEPSEEK_R1, + COMMON_CHAT_FORMAT_GPT_OSS, COMMON_CHAT_FORMAT_KIMI_K2, // Our custom format (keep last for backward compatibility) }; enum common_reasoning_format { COMMON_REASONING_FORMAT_NONE, + COMMON_REASONING_FORMAT_AUTO, COMMON_REASONING_FORMAT_DEEPSEEK, COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY, }; struct common_chat_syntax { common_chat_format format = COMMON_CHAT_FORMAT_KIMI_K2; - common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE; + common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_AUTO; //COMMON_REASONING_FORMAT_NONE; // Whether reasoning_content should be inlined in the content (e.g. for reasoning_format=deepseek in stream mode) bool reasoning_in_content = false; bool thinking_forced_open = false; @@ -165,11 +167,12 @@ class common_chat_msg_partial_exception : public std::runtime_error { // Format detection from chat template common_chat_format common_chat_format_detect(const std::string & chat_template); const char* common_chat_format_name(common_chat_format format); +const char* common_reasoning_format_name(common_reasoning_format format); // Main parsing function (entry point for original llama.cpp compatibility) common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax); -// Forward declare parser class +// Forward declare parser class class common_chat_msg_parser; // Format-specific parsing functions (accessible from chat-parser) diff --git a/examples/sweep-bench/sweep-bench.cpp b/examples/sweep-bench/sweep-bench.cpp index 31dd3ce0e..865f65f06 100644 --- a/examples/sweep-bench/sweep-bench.cpp +++ b/examples/sweep-bench/sweep-bench.cpp @@ -61,7 +61,7 @@ int main(int argc, char ** argv) { const llama_vocab * vocab = llama_get_vocab(ctx); - llama_token bos = llama_token_bos_impl(*vocab); + llama_token bos = vocab->token_bos(); //llama_token eos = llama_token_eos_impl(*vocab); const unsigned int n_vocab = llama_n_vocab(model); diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 0276c171c..ce90a42f6 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -132,7 +132,7 @@ set (GGML_CUDA_MIN_BATCH_OFFLOAD "32" CACHE STRING option(GGML_CUDA_NO_PEER_COPY "ggml: do not use peer to peer copies" OFF) option(GGML_CUDA_NO_VMM "ggml: do not try to use CUDA VMM" OFF) option(GGML_CUDA_FA_ALL_QUANTS "ggml: compile all quants for FlashAttention" OFF) -option(GGML_CUDA_USE_GRAPHS "ggml: use CUDA graphs (llama.cpp only)" OFF) +option(GGML_CUDA_USE_GRAPHS "ggml: use CUDA graphs (llama.cpp only)" ON) option(GGML_IQK_FLASH_ATTENTION "ggml: enable the IQK FlashAttention CPU kernels" ON) option(GGML_IQK_FA_ALL_QUANTS "ggml: compile all quants for IQK FlashAttention" OFF) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 2c2613928..d6350f6e2 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -325,6 +325,16 @@ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \ GGML_TENSOR_LOCALS(size_t, nb, dst, nb) +#define GGML_TENSOR_TERNARY_OP_LOCALS \ 
+ GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \ + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \ + GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \ + GGML_TENSOR_LOCALS(size_t, nb1, src1, nb) \ + GGML_TENSOR_LOCALS(int64_t, ne2, src2, ne) \ + GGML_TENSOR_LOCALS(size_t, nb2, src2, nb) \ + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \ + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + #define GGML_TENSOR_BINARY_OP_LOCALS01 \ GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \ GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \ @@ -571,6 +581,7 @@ extern "C" { GGML_OP_DUP, GGML_OP_ADD, + GGML_OP_ADD_ID, GGML_OP_ADD1, GGML_OP_ACC, GGML_OP_SUB, @@ -674,6 +685,7 @@ extern "C" { GGML_UNARY_OP_HARDSWISH, GGML_UNARY_OP_HARDSIGMOID, GGML_UNARY_OP_SWIGLU, + GGML_UNARY_OP_SWIGLU_OAI, GGML_UNARY_OP_COUNT, }; @@ -1028,6 +1040,13 @@ extern "C" { struct ggml_tensor * b, enum ggml_type type); + // dst[i0, i1, i2] = a[i0, i1, i2] + b[i0, ids[i1, i2]] + GGML_API struct ggml_tensor * ggml_add_id( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * ids); + GGML_API struct ggml_tensor * ggml_add1( struct ggml_context * ctx, struct ggml_tensor * a, @@ -1268,6 +1287,13 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a); + GGML_API struct ggml_tensor * ggml_swiglu_oai( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + float alpha, + float limit); + // a - x // b - dy GGML_API struct ggml_tensor * ggml_silu_back( @@ -1370,6 +1396,16 @@ extern "C" { struct ggml_tensor * ids, enum ggml_unary_op op); + GGML_API struct ggml_tensor * ggml_moe_up_gate_ext( + struct ggml_context * ctx, + struct ggml_tensor * a_up, + struct ggml_tensor * a_gate, + struct ggml_tensor * b, + struct ggml_tensor * ids, + struct ggml_tensor * a_up_b, + struct ggml_tensor * a_gate_b, + enum ggml_unary_op op); + // A: m columns, n rows, // B: p columns, n rows, // result is m columns, p rows @@ -1662,6 +1698,11 @@ extern "C" { float scale, float max_bias); + GGML_API void ggml_soft_max_add_sinks( + struct ggml_tensor * a, + struct ggml_tensor * sinks); + + GGML_API struct ggml_tensor * ggml_soft_max_back( struct ggml_context * ctx, struct ggml_tensor * a, @@ -1998,6 +2039,10 @@ extern "C" { struct ggml_tensor * a, enum ggml_prec prec); + GGML_API void ggml_flash_attn_ext_add_sinks( + struct ggml_tensor * a, + struct ggml_tensor * sinks); + // TODO: needs to be adapted to ggml_flash_attn_ext GGML_API struct ggml_tensor * ggml_flash_attn_back( struct ggml_context * ctx, diff --git a/ggml/src/ggml-alloc.c b/ggml/src/ggml-alloc.c index 67c1ba18f..5bdfa9420 100644 --- a/ggml/src/ggml-alloc.c +++ b/ggml/src/ggml-alloc.c @@ -43,6 +43,7 @@ static bool ggml_op_can_inplace(enum ggml_op op) { case GGML_OP_DIAG_MASK_ZERO: case GGML_OP_DIAG_MASK_INF: case GGML_OP_ADD: + case GGML_OP_ADD_ID: case GGML_OP_ADD1: case GGML_OP_SUB: case GGML_OP_MUL: diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index 9372c05c1..5bc773a95 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -37,6 +37,8 @@ #include "ggml-cuda/unary.cuh" #include "ggml-cuda/upscale.cuh" #include "ggml-cuda/conv-transpose-1d.cuh" +#include "ggml-cuda/add-id.cuh" +#include "ggml-cuda/graph.cuh" #include #include @@ -49,6 +51,7 @@ #include #include #include +#include #include #include #include @@ -77,6 +80,7 @@ GGML_API void ggml_backend_cuda_log_set_callback(ggml_log_callback log_callback, #define GGML_CUDA_LOG_INFO(...) ggml_cuda_log(GGML_LOG_LEVEL_INFO, __VA_ARGS__) #define GGML_CUDA_LOG_WARN(...) 
ggml_cuda_log(GGML_LOG_LEVEL_WARN, __VA_ARGS__) #define GGML_CUDA_LOG_ERROR(...) ggml_cuda_log(GGML_LOG_LEVEL_ERROR, __VA_ARGS__) +#define GGML_CUDA_LOG_DEBUG(...) ggml_cuda_log(GGML_LOG_LEVEL_DEBUG, __VA_ARGS__) GGML_ATTRIBUTE_FORMAT(2, 3) static void ggml_cuda_log(enum ggml_log_level level, const char * format, ...) { @@ -444,6 +448,35 @@ std::unique_ptr ggml_backend_cuda_context::new_pool_for_device(i return std::unique_ptr(new ggml_cuda_pool_leg(device)); } +static std::mutex ggml_cuda_lock; +static std::condition_variable ggml_cuda_lock_cv; +static std::atomic ggml_cuda_lock_counter; + +ggml_backend_cuda_context::ggml_backend_cuda_context(int device) : + device(device), name(GGML_CUDA_NAME + std::to_string(device)) { +} + +ggml_backend_cuda_context::~ggml_backend_cuda_context() { + + std::unique_lock lock(ggml_cuda_lock); + ggml_cuda_lock_cv.wait(lock, []{ return ggml_cuda_lock_counter.load(std::memory_order_relaxed) == 0; }); + + if (copy_event != nullptr) { + CUDA_CHECK(cudaEventDestroy(copy_event)); + } + for (int i = 0; i < GGML_CUDA_MAX_DEVICES; ++i) { + for (int j = 0; j < GGML_CUDA_MAX_STREAMS; ++j) { + if (streams[i][j] != nullptr) { + CUDA_CHECK(cudaStreamDestroy(streams[i][j])); + } + } + if (cublas_handles[i] != nullptr) { + CUBLAS_CHECK(cublasDestroy(cublas_handles[i])); + } + } + +} + // cuda buffer struct ggml_backend_cuda_buffer_context { @@ -2220,6 +2253,24 @@ static __global__ void k_copy_dst_from_contiguous(char * __restrict__ dst_origin } } +//static __global__ void k_quick_add(uint32_t n, uint32_t n_per_row, const float * src1, const float * src2, float * dst) { +// +// for (uint32_t j = threadIdx.x; j < n; j += blockDim.x) { +// dst[j] = src1[j] + src2[j % n_per_row]; +// } +//} + +static __global__ void k_quick_add(uint32_t n_per_row, const float * src1, const float * src2, float * dst) { + + uint32_t row = blockIdx.x; + const float * src1_row = src1 + row*n_per_row; + float * dst_row = dst + row*n_per_row; + + for (uint32_t j = threadIdx.x; j < n_per_row; j += blockDim.x) { + dst_row[j] = src1_row[j] + src2[j]; + } +} + static inline bool prepare_row_mappigs(ggml_backend_cuda_context& ctx, int64_t n_as, int64_t n_ids, const ggml_tensor * ids, std::vector& moe_counts, std::vector& cum_moe_counts, ggml_cuda_pool_alloc& dev_row_mapping) { @@ -2270,7 +2321,7 @@ static inline bool prepare_row_mappigs(ggml_backend_cuda_context& ctx, int64_t n return is_ser; } -static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { +static bool ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * next) { const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; const ggml_tensor * ids = dst->src[2]; @@ -2319,7 +2370,25 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * 0, src0->ne[1], 1, src1_padded_col_size, stream); CUDA_CHECK(cudaGetLastError()); - return; + if (next && next->op == GGML_OP_MUL_MAT_ID && next->src[0]->type == src0->type && src1 == next->src[1] && + ggml_are_same_shape(src0, next->src[0]) && + ggml_backend_buffer_is_cuda(next->src[0]->buffer) && + ggml_backend_buffer_is_cuda(next->buffer) && + !ggml_backend_buffer_is_cuda_split(next->src[0]->buffer)) { + ggml_backend_cuda_buffer_context * next_src0_ctx = (ggml_backend_cuda_buffer_context *) next->src[0]->buffer->context; + ggml_backend_cuda_buffer_context * next_dst_ctx = (ggml_backend_cuda_buffer_context *) next->buffer->context; + if (next_src0_ctx->device == device_id && + next_dst_ctx->device == 
device_id) { + local_dst.data = next->data; + ggml_cuda_op_mul_mat_vec_q_id(ctx, next->src[0], &local_src1, ids, &local_dst, + (const char *)next->src[0]->data, nullptr, src1_quantized.get(), (float *)next->data, + 0, src0->ne[1], 1, src1_padded_col_size, stream); + CUDA_CHECK(cudaGetLastError()); + return true; + } + } + + return false; } } @@ -2356,7 +2425,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst_row.nb[2] = nb1; dst_row.nb[3] = nb1; - if (ne12 == 1) { + if (false && ne12 == 1) { std::vector ids_host(ggml_nbytes(ids)); const char * ids_dev = (const char *) ids->data; CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream)); @@ -2442,6 +2511,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * } } } + return false; } static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * next) { @@ -2470,6 +2540,8 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor src0_2_ctx->device == device_id && src1_ctx->device == device_id && dst_ctx->device == device_id) { + //printf("%s(%s, %s): %ld x %ld x %ld, %ld x %ld x %ld, %ld x %ld x %ld\n", __func__, src0_1->name, src0_2->name, + // src0->ne[0], src0->ne[1], src0->ne[2], src1->ne[0], src1->ne[1], src1->ne[2], ids->ne[0], ids->ne[1], ids->ne[2]); // Fast TG path const int64_t n_ids = ids->ne[0]; auto stream = ctx.stream(device_id, 0); @@ -2505,12 +2577,26 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor 0, src0_1->ne[1], 1, src1_padded_col_size, stream); CUDA_CHECK(cudaGetLastError()); + if (dst->src[4]) { + ggml_cuda_add_id((const float *)local_dst.data, (const float *)dst->src[4]->data, + (const int32_t *)ids->data, (float *)local_dst.data, + local_dst.ne[0], local_dst.ne[2], local_dst.ne[1], local_dst.ne[0], local_dst.ne[2], + local_dst.nb[1], local_dst.nb[2], dst->src[4]->nb[1], ids->nb[2], stream); + } + local_dst.data = dst_gate_contiguous.get(); ggml_cuda_op_mul_mat_vec_q_id(ctx, src0_2, &local_src1, ids, &local_dst, (const char *)src0_2->data, (const float *)src1->data, src1_quantized.get(), (float *)dst_gate_contiguous.get(), 0, src0_2->ne[1], 1, src1_padded_col_size, stream); CUDA_CHECK(cudaGetLastError()); + if (dst->src[5]) { + ggml_cuda_add_id((const float *)local_dst.data, (const float *)dst->src[5]->data, + (const int32_t *)ids->data, (float *)local_dst.data, + local_dst.ne[0], local_dst.ne[2], local_dst.ne[1], local_dst.ne[0], local_dst.ne[2], + local_dst.nb[1], local_dst.nb[2], dst->src[5]->nb[1], ids->nb[2], stream); + } + if (next && next->op == GGML_OP_MUL_MAT_ID && ggml_is_quantized(next->src[0]->type) && ggml_backend_buffer_is_cuda(next->src[0]->buffer) && !ggml_backend_buffer_is_cuda_split(next->src[0]->buffer) && @@ -2518,8 +2604,15 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor ggml_backend_buffer_is_cuda(next->buffer) && ((ggml_backend_cuda_buffer_context *)next->buffer->context)->device == device_id) { - ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], dst->ne[0]*n_ids, - (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)dst_gate_contiguous.get()); + auto unary_op = (ggml_unary_op)dst->op_params[0]; + if (unary_op == GGML_UNARY_OP_SWIGLU_OAI) { + ggml_swiglu_oai_cuda_f32((const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), + (float *)dst_gate_contiguous.get(), dst->ne[0]*n_ids, 
dst->ne[0], dst->ne[0], dst->ne[0], 1.702f, 7.0f, stream); + } else { + ggml_fused_mul_unary(ctx, unary_op, dst->ne[0]*n_ids, + (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), + (float *)dst_gate_contiguous.get()); + } CUDA_CHECK(cudaGetLastError()); const int64_t dst_padded_col_size = GGML_PAD(dst->ne[0], MATRIX_ROW_PADDING); @@ -2555,8 +2648,14 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor return true; } else { CUDA_CHECK(cudaMemsetAsync(dst->data, 0, ggml_nbytes(dst), stream)); - ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], ggml_nelements(dst), - (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)dst->data); + auto unary_op = (ggml_unary_op)dst->op_params[0]; + if (unary_op == GGML_UNARY_OP_SWIGLU_OAI) { + ggml_swiglu_oai_cuda_f32((const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), + (float *)dst->data, dst->ne[0]*n_ids, dst->ne[0], dst->ne[0], dst->ne[0], 1.702f, 7.0f, stream); + } else { + ggml_fused_mul_unary(ctx, unary_op, ggml_nelements(dst), + (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)dst->data); + } CUDA_CHECK(cudaGetLastError()); return false; } @@ -2624,7 +2723,7 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor final_src.nb[3] = final_src.nb[2]; } - if (ne12 == 1) { + if (false && ne12 == 1) { ggml_cuda_pool_alloc dst_up_contiguous(ctx.pool(), sizeof(float)*dst_row.ne[0]); ggml_cuda_pool_alloc dst_gate_contiguous(ctx.pool(), sizeof(float)*dst_row.ne[0]); if (fuse_down) { @@ -2761,6 +2860,14 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor } CUDA_CHECK(cudaGetLastError()); + if (dst->src[4]) { + dim3 block_dims(std::min(uint32_t(dst_row.ne[0]), 768u)); + dim3 grid_dims(num_src1_rows); + k_quick_add<<>>(dst_row.ne[0], (const float *)dst_row.data, + (const float *)((const char *)dst->src[4]->data + i02*dst->src[4]->nb[1]), (float *)dst_row.data); + CUDA_CHECK(cudaGetLastError()); + } + dst_row.data = dst_gate_contiguous.get(); if (use_quantized_src1) { ggml_cuda_op_mul_mat_q(ctx, &src0_2_row, &src1_row, &dst_row, (const char *)src0_2_row.data, nullptr, src1_quantized.get(), (float *)dst_row.data, @@ -2770,8 +2877,24 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor } CUDA_CHECK(cudaGetLastError()); - ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], ggml_nelements(&dst_row), - (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)dst_gate_contiguous.get()); + if (dst->src[5]) { + dim3 block_dims(std::min(uint32_t(dst_row.ne[0]), 768u)); + dim3 grid_dims(num_src1_rows); + k_quick_add<<>>(dst_row.ne[0], (const float *)dst_row.data, + (const float *)((const char *)dst->src[5]->data + i02*dst->src[5]->nb[1]), (float *)dst_row.data); + CUDA_CHECK(cudaGetLastError()); + } + + auto unary_op = (ggml_unary_op)dst->op_params[0]; + if (unary_op == GGML_UNARY_OP_SWIGLU_OAI) { + ggml_swiglu_oai_cuda_f32((const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), + (float *)dst_gate_contiguous.get(), ggml_nelements(&dst_row), dst_row.ne[0], dst_row.ne[0], dst_row.ne[0], + 1.702f, 7.0f, stream); + } else { + ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], ggml_nelements(&dst_row), + (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), + (float *)dst_gate_contiguous.get()); + } 
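(Reviewer note: both call sites above hard-code alpha = 1.702 and limit = 7.0 for GGML_UNARY_OP_SWIGLU_OAI. For orientation, here is a scalar sketch of the pointwise definition, assuming the gpt-oss reference convention, where the gate is capped at limit and the linear half is clamped to plus/minus limit and given a +1 bias; this is an illustration, not the kernel itself.)

```cpp
#include <algorithm>
#include <cmath>

// g = gate half (goes through the sigmoid), x = up/linear half
static float swiglu_oai_ref(float g, float x, float alpha = 1.702f, float limit = 7.0f) {
    g = std::min(g, limit);                              // cap the gate
    x = std::clamp(x, -limit, limit);                    // clamp the linear half
    const float sig = 1.0f / (1.0f + std::exp(-alpha * g));
    return g * sig * (x + 1.0f);                         // note the +1 bias
}
```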
CUDA_CHECK(cudaGetLastError()); if (fuse_down) { @@ -2851,6 +2974,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_ADD: ggml_cuda_op_add(ctx, dst); break; + case GGML_OP_ADD_ID: + ggml_cuda_op_add_id(ctx, dst); + break; case GGML_OP_MULTI_ADD: ggml_cuda_op_multi_add(ctx, dst); break; @@ -2877,6 +3003,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_UNARY_OP_SWIGLU: ggml_cuda_op_swiglu(ctx, dst); break; + case GGML_UNARY_OP_SWIGLU_OAI: + ggml_cuda_op_swiglu_oai(ctx, dst); + break; case GGML_UNARY_OP_GELU_QUICK: ggml_cuda_op_gelu_quick(ctx, dst); break; @@ -2938,7 +3067,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg } break; case GGML_OP_MUL_MAT_ID: - ggml_cuda_mul_mat_id(ctx, dst); + skip_next = ggml_cuda_mul_mat_id(ctx, dst, next); break; case GGML_OP_MOE_FUSED_UP_GATE: skip_next = ggml_cuda_up_gate_unary(ctx, dst, next); @@ -3119,6 +3248,105 @@ GGML_CALL static void ggml_backend_cuda_synchronize(ggml_backend_t backend) { GGML_UNUSED(backend); } +#ifdef USE_CUDA_GRAPH + +static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, + bool use_cuda_graph) { + + // Loop over nodes in GGML graph to obtain info needed for CUDA graph + cuda_ctx->cuda_graph->cpy_dest_ptrs.clear(); + + const std::string gemma3n_per_layer_proj_src0_name = "inp_per_layer_selected"; + const std::string gemma3n_per_layer_proj_src1_name = "per_layer_proj"; + const std::string ffn_moe_gate_bias_prefix = "ffn_moe_gate_biased"; + const std::string ffn_moe_up_bias_prefix = "ffn_moe_up_biased"; + const std::string ffn_moe_down_bias_prefix = "ffn_moe_down_biased"; + + for (int i = 0; i < cgraph->n_nodes; i++) { + ggml_tensor * node = cgraph->nodes[i]; + + if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) { + continue; + } + + if (node->src[0] && node->src[0]->buffer && ggml_backend_buft_is_cuda_split(node->src[0]->buffer->buft)) { + use_cuda_graph = false; // Split buffers are not supported by CUDA graph capture +#ifndef NDEBUG + GGML_CUDA_LOG_DEBUG("%s: disabling CUDA graphs due to split buffer\n", __func__); +#endif + } + + if (node->op == GGML_OP_MUL_MAT_ID && (node->ne[2] != 1 || node->src[2]->ne[0] != 1)) { + use_cuda_graph = false; // This node type is not supported by CUDA graph capture +#ifndef NDEBUG + GGML_CUDA_LOG_DEBUG("%s(%s): disabling CUDA graphs due to unsupported node type %ld %ld\n", + __func__, node->src[0]->name, node->ne[2], node->src[2]->ne[0]); +#endif + } + if (node->op == GGML_OP_MOE_FUSED_UP_GATE) { + auto src0_1 = node->src[0]; + auto src0_2 = node->src[1]; + auto src1 = node->src[2]; + if (src1->ne[1] != 1 || src1->ne[2] != 1 || src1->ne[3] != 1 || src1->type != GGML_TYPE_F32 || + !ggml_is_quantized(src0_1->type) || !ggml_is_quantized(src0_2->type)) { + use_cuda_graph = false; + } else { + if (i < cgraph->n_nodes-1) { + auto next = cgraph->nodes[i+1]; + if (next->op == GGML_OP_MUL_MAT_ID && ggml_is_quantized(next->src[0]->type)) { + ++i; + } + } + } + } + + if (node->op == GGML_OP_ADD && + node->src[1] && node->src[1]->ne[1] > 1 && + (node->src[0] ? node->src[0]->name != gemma3n_per_layer_proj_src0_name : true) && + (node->src[1] ? 
node->src[1]->name != gemma3n_per_layer_proj_src1_name : true) && + strncmp(node->name, ffn_moe_gate_bias_prefix.c_str(), ffn_moe_gate_bias_prefix.size()) != 0 && + strncmp(node->name, ffn_moe_up_bias_prefix.c_str(), ffn_moe_up_bias_prefix.size()) != 0 && + strncmp(node->name, ffn_moe_down_bias_prefix.c_str(), ffn_moe_down_bias_prefix.size()) != 0) { + // disable CUDA graphs for batch size > 1 for now while excluding the matrix-matrix addition as part of Gemma3n's `project_per_layer_input` operation + // by means of matching node names. See + // https://github.com/ggml-org/llama.cpp/blob/f9a31eea06a859e34cecb88b4d020c7f03d86cc4/src/llama-model.cpp#L10199-L10241 and + // https://github.com/huggingface/transformers/blob/bda75b4011239d065de84aa3e744b67ebfa7b245/src/transformers/models/gemma3n/modeling_gemma3n.py#L1773, + // Generally, changes in batch size or context size can cause changes to the grid size of some kernels. + use_cuda_graph = false; +#ifndef NDEBUG + GGML_CUDA_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]); +#endif + } + + if (node->op == GGML_OP_CPY) { + + // Store the pointers which are updated for each token, such that these can be sent + // to the device and accessed using indirection from CUDA graph + cuda_ctx->cuda_graph->cpy_dest_ptrs.push_back((char *) node->src[1]->data); + + // store a pointer to each copy op CUDA kernel to identify it later + void * ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]); + if (!ptr) { + use_cuda_graph = false; +#ifndef NDEBUG + GGML_CUDA_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported copy op\n", __func__); +#endif + } + } + if (!use_cuda_graph) { + break; + } + } + + if (use_cuda_graph) { + cuda_ctx->cuda_graph->use_cpy_indirection = true; + // copy pointers to GPU so they can be accessed via indirection within CUDA graph + ggml_cuda_cpy_dest_ptrs_copy(cuda_ctx->cuda_graph.get(), cuda_ctx->cuda_graph->cpy_dest_ptrs.data(), cuda_ctx->cuda_graph->cpy_dest_ptrs.size(), cuda_ctx->stream()); + } + + return use_cuda_graph; +} + static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) { graph_node_properties->node_address = node->data; graph_node_properties->node_op = node->op; @@ -3129,6 +3357,7 @@ static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_p for (int i = 0; i < GGML_MAX_SRC; i++) { graph_node_properties->src_address[i] = node->src[i] ? 
node->src[i]->data : nullptr; } + memcpy(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS); } static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) { @@ -3160,9 +3389,246 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra return false; } } + + if (node->op == GGML_OP_SCALE && + memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) { + return false; + } + return true; } +static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) { + + bool cuda_graph_update_required = false; + + if (cuda_ctx->cuda_graph->instance == nullptr) { + cuda_graph_update_required = true; + } + + // Check if the graph size has changed + if (cuda_ctx->cuda_graph->ggml_graph_properties.size() != (size_t)cgraph->n_nodes) { + cuda_graph_update_required = true; + cuda_ctx->cuda_graph->ggml_graph_properties.resize(cgraph->n_nodes); + } + + // Loop over nodes in GGML graph to determine if CUDA graph update is required + // and store properties to allow this comparison for the next token + for (int i = 0; i < cgraph->n_nodes; i++) { + bool has_matching_properties = true; + if (!cuda_graph_update_required) { + has_matching_properties = ggml_graph_node_has_matching_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]); + } + if (!has_matching_properties) { + cuda_graph_update_required = true; + } + set_ggml_graph_node_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]); + } + + return cuda_graph_update_required; +} + +static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) { + +#if CUDART_VERSION >= 12000 + cudaGraphExecUpdateResultInfo result_info; + cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info); +#else + cudaGraphNode_t errorNode; + cudaGraphExecUpdateResult result_info; + cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &errorNode, &result_info); +#endif // CUDART_VERSION >= 12000 + + if (stat == cudaErrorGraphExecUpdateFailure) { +#ifndef NDEBUG + GGML_CUDA_LOG_DEBUG("%s: CUDA graph update failed\n", __func__); +#endif + + // The pre-existing graph exec cannot be updated due to violated constraints + // so instead clear error and re-instantiate + (void)cudaGetLastError(); + CUDA_CHECK(cudaGraphExecDestroy(cuda_ctx->cuda_graph->instance)); + cuda_ctx->cuda_graph->instance = nullptr; + CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0)); + } else { + GGML_ASSERT(stat == cudaSuccess); + } +} +#endif + +static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, + bool & graph_evaluated_or_captured, bool & use_cuda_graph, bool & cuda_graph_update_required) { + // flag used to determine whether it is an integrated_gpu + // TODO + const bool integrated = false; //ggml_cuda_info().devices[cuda_ctx->device].integrated; + + while (!graph_evaluated_or_captured) { + // Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph. + // With the use of CUDA graphs, the execution will be performed by the graph launch. + if (!use_cuda_graph || cuda_graph_update_required) { + + for (int i = 0; i < cgraph->n_nodes; i++) { + ggml_tensor * node = cgraph->nodes[i]; + ggml_tensor * next = i < cgraph->n_nodes-1 ? 
cgraph->nodes[i+1] : nullptr; + + if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) { + continue; + } + +#if 0 + static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr); + if (!disable_fusion) { + if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL }, {})) { + ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i+1]); + i++; + continue; + } + + if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SCALE, GGML_OP_UNARY, GGML_OP_SCALE }, { GGML_UNARY_OP_TANH })) { + i += 2; + ggml_cuda_op_softcap(*cuda_ctx, cgraph->nodes[i], node); + continue; + } + } +#endif +#ifndef NDEBUG + assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device)); + for (int j = 0; j < GGML_MAX_SRC; j++) { + if (node->src[j] != nullptr) { + assert(node->src[j]->buffer); + //assert(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) || + // ggml_backend_buft_is_cuda_split(node->src[j]->buffer->buft) || (integrated && ggml_backend_buft_is_cuda_host(node->src[j]->buffer->buft))); + } + } +#else + GGML_UNUSED(integrated); +#endif // NDEBUG + + bool skip_next = false; + bool ok = ggml_cuda_compute_forward(*cuda_ctx, node, next, skip_next); + if (!ok) { + GGML_CUDA_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op)); + } + GGML_ASSERT(ok); + if (skip_next) ++i; + } + } +#ifdef USE_CUDA_GRAPH + if (use_cuda_graph && cuda_graph_update_required) { // End CUDA graph capture + if (cuda_ctx->cuda_graph->graph != nullptr) { + CUDA_CHECK(cudaGraphDestroy(cuda_ctx->cuda_graph->graph)); + cuda_ctx->cuda_graph->graph = nullptr; + } + + CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph)); + graph_evaluated_or_captured = true; // CUDA graph has been captured + + std::lock_guard lock(ggml_cuda_lock); + if (ggml_cuda_lock_counter.fetch_sub(1, std::memory_order_relaxed) == 1) { + ggml_cuda_lock_cv.notify_all(); + } + } else { + graph_evaluated_or_captured = true; // ggml graph has been directly evaluated + } + } + + if (use_cuda_graph) { + if (cuda_ctx->cuda_graph->instance == nullptr) { // Create executable graph from captured graph. 
+ CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0)); + } + if (cuda_graph_update_required) { // Update graph executable + update_cuda_graph_executable(cuda_ctx); + } + // Launch graph + CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream())); +#else + graph_evaluated_or_captured = true; +#endif // USE_CUDA_GRAPH + } +} + +GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; + + ggml_cuda_set_device(cuda_ctx->device); + +#ifdef USE_CUDA_GRAPH + static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr); + + // Objects required for CUDA Graph + if (cuda_ctx->cuda_graph == nullptr) { + cuda_ctx->cuda_graph.reset(new ggml_cuda_graph()); + } + + bool use_cuda_graph = true; + bool cuda_graph_update_required = false; + + if (cuda_ctx->cuda_graph->graph == nullptr) { + if (ggml_cuda_info().devices[cuda_ctx->device].cc < CC_AMPERE) { + cuda_ctx->cuda_graph->disable_due_to_gpu_arch = true; +#ifndef NDEBUG + GGML_CUDA_LOG_DEBUG("%s: disabling CUDA graphs due to GPU architecture\n", __func__); +#endif + } + } + + // Disable CUDA graphs in presence of env var, old GPU, use-case which is changing too rapidly, + // or previous graph capture failure. + // Also disable for multi-gpu for now. TO DO investigate + if (disable_cuda_graphs_due_to_env + || cuda_ctx->cuda_graph->disable_due_to_gpu_arch + || cuda_ctx->cuda_graph->disable_due_to_too_many_updates + || cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture) { + use_cuda_graph = false; + } + + if (use_cuda_graph) { + cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph); + + use_cuda_graph = check_node_graph_compatibility_and_refresh_copy_ops(cuda_ctx, cgraph, use_cuda_graph); + + // Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates. 
+ if (use_cuda_graph && cuda_graph_update_required) { + cuda_ctx->cuda_graph->number_consecutive_updates++; + } else { + cuda_ctx->cuda_graph->number_consecutive_updates = 0; + } + + if (cuda_ctx->cuda_graph->number_consecutive_updates >= 4) { + cuda_ctx->cuda_graph->disable_due_to_too_many_updates = true; +#ifndef NDEBUG + GGML_CUDA_LOG_DEBUG("%s: disabling CUDA graphs due to too many consecutive updates\n", __func__); +#endif + } + } + + if (use_cuda_graph && cuda_graph_update_required) { + // Start CUDA graph capture + { + std::lock_guard lock(ggml_cuda_lock); + ggml_cuda_lock_counter.fetch_add(1, std::memory_order_relaxed); + } + + CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed)); + } + + if (!use_cuda_graph) { + cuda_ctx->cuda_graph->use_cpy_indirection = false; + } + +#else + bool use_cuda_graph = false; + bool cuda_graph_update_required = false; +#endif // USE_CUDA_GRAPH + + bool graph_evaluated_or_captured = false; + + evaluate_and_capture_cuda_graph(cuda_ctx, cgraph, graph_evaluated_or_captured, use_cuda_graph, cuda_graph_update_required); + + return GGML_STATUS_SUCCESS; +} + +/* GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; @@ -3431,6 +3897,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t return GGML_STATUS_SUCCESS; } +*/ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_tensor * op) { ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context; @@ -3440,6 +3907,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons case GGML_UNARY_OP_GELU: case GGML_UNARY_OP_SILU: case GGML_UNARY_OP_SWIGLU: + case GGML_UNARY_OP_SWIGLU_OAI: case GGML_UNARY_OP_RELU: case GGML_UNARY_OP_SIGMOID: case GGML_UNARY_OP_HARDSIGMOID: @@ -3629,6 +4097,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons case GGML_OP_PERMUTE: case GGML_OP_TRANSPOSE: case GGML_OP_ADD: + case GGML_OP_ADD_ID: case GGML_OP_MULTI_ADD: case GGML_OP_MUL: case GGML_OP_DIV: diff --git a/ggml/src/ggml-cuda/add-id.cu b/ggml/src/ggml-cuda/add-id.cu new file mode 100644 index 000000000..34e66954b --- /dev/null +++ b/ggml/src/ggml-cuda/add-id.cu @@ -0,0 +1,72 @@ +#include "add-id.cuh" + +static __global__ void add_id_kernel( + const float * src0, const float * src1, const int32_t * src2, float * dst, + int64_t ne0, int64_t ne1, + size_t nb01, size_t nb02, + size_t nb11, + size_t nb21 + ) { + + const int64_t i1 = blockIdx.x; + const int64_t i2 = blockIdx.y; + + const int i11 = *(int32_t *) ((char *) src2 + i1*sizeof(int32_t) + i2*nb21); + + const size_t nb1 = ne0 * sizeof(float); + const size_t nb2 = ne1 * nb1; + + float * dst_row = (float *)((char *)dst + i1*nb1 + i2*nb2); + const float * src0_row = (const float *)((char *)src0 + i1*nb01 + i2*nb02); + const float * src1_row = (const float *)((char *)src1 + i11*nb11); + + for (int64_t i0 = threadIdx.x; i0 < ne0; i0 += blockDim.x) { + dst_row[i0] = src0_row[i0] + src1_row[i0]; + } +} + +void ggml_cuda_op_add_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + const ggml_tensor * src2 = dst->src[2]; + + GGML_TENSOR_TERNARY_OP_LOCALS + + GGML_ASSERT(dst->type == GGML_TYPE_F32); + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == 
GGML_TYPE_F32); + GGML_ASSERT(src2->type == GGML_TYPE_I32); + + GGML_ASSERT(nb00 == sizeof(float)); + GGML_ASSERT(nb10 == sizeof(float)); + GGML_ASSERT(nb20 == sizeof(int32_t)); + + const float * src0_d = (const float *)src0->data; + const float * src1_d = (const float *)src1->data; + const int32_t * src2_d = (const int32_t *)src2->data; + float * dst_d = (float *)dst->data; + + int threads = std::min((int)ne00, 768); // cols + dim3 blocks(ne01, ne02); // n_experts_used, n_tokens + add_id_kernel<<>>( + src0_d, src1_d, src2_d, dst_d, + ne0, ne1, + nb01, nb02, + nb11, + nb21 + ); +} + +void ggml_cuda_add_id(const float * src0, const float * src1, const int32_t * src2, float * dst, + int64_t ne00, int64_t ne01, int64_t ne02, + int64_t ne0, int64_t ne1, size_t nb01, size_t nb02, size_t nb11, size_t nb21, cudaStream_t stream) { + int threads = std::min((int)ne00, 768); // cols + dim3 blocks(ne01, ne02); // n_experts_used, n_tokens + add_id_kernel<<>>( + src0, src1, src2, dst, + ne0, ne1, + nb01, nb02, + nb11, + nb21 + ); +} diff --git a/ggml/src/ggml-cuda/add-id.cuh b/ggml/src/ggml-cuda/add-id.cuh new file mode 100644 index 000000000..175d6800e --- /dev/null +++ b/ggml/src/ggml-cuda/add-id.cuh @@ -0,0 +1,8 @@ +#include "common.cuh" + +void ggml_cuda_op_add_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_add_id(const float * src0, const float * src1, const int32_t * src2, float * dst, + int64_t ne00, int64_t ne01, int64_t ne02, + int64_t ne0, int64_t ne1, size_t nb01, size_t nb02, size_t nb11, size_t nb21, cudaStream_t stream); + diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index c856a44b6..b24f7fba7 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -108,6 +108,23 @@ static const char * cu_get_error_str(CUresult err) { #define CU_CHECK(err) CUDA_CHECK_GEN(err, CUDA_SUCCESS, cu_get_error_str) #endif +#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA) +# define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes) \ + do { \ + static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = { false }; \ + const int id = ggml_cuda_get_device(); \ + if (!shared_memory_limit_raised[id]) { \ + CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes)); \ + shared_memory_limit_raised[id] = true; \ + } \ + } while (0) +#else +# define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes) \ + do { \ + GGML_UNUSED(nbytes); \ + } while (0) +#endif // !(defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + #if CUDART_VERSION >= 11100 || defined(GGML_USE_MUSA) #define GGML_CUDA_ASSUME(x) __builtin_assume(x) #else @@ -808,37 +825,7 @@ struct ggml_tensor_extra_gpu { #define USE_CUDA_GRAPH #endif -struct ggml_graph_node_properties { - void * node_address; - ggml_op node_op; - int64_t ne[GGML_MAX_DIMS]; - size_t nb[GGML_MAX_DIMS]; - void * src_address[GGML_MAX_SRC]; -}; - -struct ggml_cuda_graph { -#ifdef USE_CUDA_GRAPH - ~ggml_cuda_graph() { - if (instance != nullptr) { - CUDA_CHECK(cudaGraphExecDestroy(instance)); - } - if (graph != nullptr) { - CUDA_CHECK(cudaGraphDestroy(graph)); - } - } - cudaGraph_t graph = nullptr; - cudaGraphExec_t instance = nullptr; - size_t num_nodes = 0; - std::vector nodes; - std::vector params; - bool disable_due_to_gpu_arch = false; - bool disable_due_to_too_many_updates = false; - bool disable_due_to_failed_graph_capture = false; - int number_consecutive_updates = 0; - std::vector ggml_graph_properties; - std::vector updated_kernel_arg; -#endif -}; +struct ggml_cuda_graph; 
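(Reviewer note: GGML_OP_ADD_ID, which add-id.cu implements above, is documented in ggml.h as dst[i0, i1, i2] = a[i0, i1, i2] + b[i0, ids[i1, i2]], i.e. each per-token, per-chosen-expert row gets the bias row of the expert selected in ids. A host-side reference loop for checking the kernel, assuming contiguous f32 tensors; a sketch, not part of the patch.)

```cpp
#include <cstdint>

// a:   [ne0, n_used, n_tok]  per-token, per-chosen-expert activations
// b:   [ne0, n_expert]       one bias row per expert
// ids: [n_used, n_tok]       which expert each row belongs to
static void add_id_ref(const float * a, const float * b, const int32_t * ids,
                       float * dst, int64_t ne0, int64_t n_used, int64_t n_tok) {
    for (int64_t i2 = 0; i2 < n_tok; ++i2) {
        for (int64_t i1 = 0; i1 < n_used; ++i1) {
            const int32_t expert = ids[i2*n_used + i1];
            const float * arow = a   + (i2*n_used + i1)*ne0;
            const float * brow = b   + expert*ne0;
            float       * drow = dst + (i2*n_used + i1)*ne0;
            for (int64_t i0 = 0; i0 < ne0; ++i0) {
                drow[i0] = arow[i0] + brow[i0];   // add the selected expert's bias
            }
        }
    }
}
```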
 struct ggml_backend_cuda_context {
     int device;
@@ -850,26 +837,9 @@ struct ggml_backend_cuda_context {
 
     std::unique_ptr<ggml_cuda_graph> cuda_graph;
 
-    explicit ggml_backend_cuda_context(int device) :
-        device(device),
-        name(GGML_CUDA_NAME + std::to_string(device)) {
-    }
+    explicit ggml_backend_cuda_context(int device);
 
-    ~ggml_backend_cuda_context() {
-        if (copy_event != nullptr) {
-            CUDA_CHECK(cudaEventDestroy(copy_event));
-        }
-        for (int i = 0; i < GGML_CUDA_MAX_DEVICES; ++i) {
-            for (int j = 0; j < GGML_CUDA_MAX_STREAMS; ++j) {
-                if (streams[i][j] != nullptr) {
-                    CUDA_CHECK(cudaStreamDestroy(streams[i][j]));
-                }
-            }
-            if (cublas_handles[i] != nullptr) {
-                CUBLAS_CHECK(cublasDestroy(cublas_handles[i]));
-            }
-        }
-    }
+    ~ggml_backend_cuda_context();
 
     cudaStream_t stream(int device, int stream) {
         if (streams[device][stream] == nullptr) {
diff --git a/ggml/src/ggml-cuda/cpy-utils.cuh b/ggml/src/ggml-cuda/cpy-utils.cuh
new file mode 100644
index 000000000..5f5865c81
--- /dev/null
+++ b/ggml/src/ggml-cuda/cpy-utils.cuh
@@ -0,0 +1,262 @@
+#pragma once
+
+#include "ggml-common.h"
+
+template <typename src_t, typename dst_t>
+static __device__ __forceinline__ void convert_flt(const src_t * src, dst_t * dst) {
+    if constexpr (std::is_same_v<src_t, dst_t>) {
+        *dst = *src;
+    } else {
+        *dst = float(*src);
+    }
+}
+
+static __device__ __forceinline__ int best_index_int8(int n, const int8_t * val, float x) {
+    if (x <= val[0]) return 0;
+    if (x >= val[n-1]) return n-1;
+    int ml = 0, mu = n-1;
+    while (mu-ml > 1) {
+        int mav = (ml+mu)/2;
+        if (x < val[mav]) mu = mav; else ml = mav;
+    }
+    return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
+}
+
+static __device__ void quantize_f32_q4_0_block(const float * __restrict__ x, block_q4_0 * __restrict__ y) {
+    float amax = 0.0f;
+    float vmax = 0.0f;
+
+    for (int j = 0; j < QK4_0; ++j) {
+        const float v = x[j];
+        if (amax < fabsf(v)) {
+            amax = fabsf(v);
+            vmax = v;
+        }
+    }
+
+    const float d  = vmax / -8;
+    const float id = d ? 1.0f/d : 0.0f;
+
+    y->d = d;
+
+    for (int j = 0; j < QK4_0/2; ++j) {
+        const float x0 = x[0 + j]*id;
+        const float x1 = x[QK4_0/2 + j]*id;
+
+        const uint8_t xi0 = min(15, (int8_t)(x0 + 8.5f));
+        const uint8_t xi1 = min(15, (int8_t)(x1 + 8.5f));
+
+        y->qs[j]  = xi0;
+        y->qs[j] |= xi1 << 4;
+    }
+}
+
+static __device__ void quantize_f32_q4_1_block(const float * __restrict__ x, block_q4_1 * __restrict__ y) {
+    float vmin = FLT_MAX;
+    float vmax = -FLT_MAX;
+
+    for (int j = 0; j < QK4_1; ++j) {
+        const float v = x[j];
+        if (v < vmin) vmin = v;
+        if (v > vmax) vmax = v;
+    }
+
+    const float d  = (vmax - vmin) / ((1 << 4) - 1);
+    const float id = d ? 1.0f/d : 0.0f;
+
+    y->dm.x = d;
+    y->dm.y = vmin;
+
+    for (int j = 0; j < QK4_1/2; ++j) {
+        const float x0 = (x[0 + j] - vmin)*id;
+        const float x1 = (x[QK4_1/2 + j] - vmin)*id;
+
+        const uint8_t xi0 = min(15, (int8_t)(x0 + 0.5f));
+        const uint8_t xi1 = min(15, (int8_t)(x1 + 0.5f));
+
+        y->qs[j]  = xi0;
+        y->qs[j] |= xi1 << 4;
+    }
+}
+
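(Reviewer note: since cpy-utils.cuh now isolates the one-block quantizers, a host-side round trip is a convenient sanity check. A sketch mirroring quantize_f32_q8_0_block, i.e. round-to-nearest with d = amax/127; `q8_0_roundtrip_err` is a hypothetical helper, not part of the patch.)

```cpp
#include <cmath>
#include <cstdint>

// Quantize one q8_0-style block, dequantize it, and report the max error.
// With round-to-nearest on q = x/d, the per-element error is bounded by d/2.
static float q8_0_roundtrip_err(const float * x, int n /* = QK8_0 = 32 */) {
    float amax = 0.0f;
    for (int j = 0; j < n; ++j) amax = std::fmax(amax, std::fabs(x[j]));
    const float d  = amax / 127.0f;
    const float id = d ? 1.0f/d : 0.0f;
    float err = 0.0f;
    for (int j = 0; j < n; ++j) {
        const int8_t q = (int8_t)std::round(x[j]*id);   // quantize
        err = std::fmax(err, std::fabs(x[j] - d*q));    // dequantize and compare
    }
    return err;
}
```

+static __device__ void quantize_f32_q5_0_block(const float * __restrict__ x, block_q5_0 * __restrict__ y) {
+    float amax = 0.0f;
+    float vmax = 0.0f;
+
+    for (int j = 0; j < QK5_0; ++j) {
+        const float v = x[j];
+        if (amax < fabsf(v)) {
+            amax = fabsf(v);
+            vmax = v;
+        }
+    }
+
+    const float d  = vmax / -16;
+    const float id = d ?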
1.0f/d : 0.0f; + + y->d = d; + + uint32_t qh = 0; + for (int j = 0; j < QK5_0/2; ++j) { + const float x0 = x[0 + j]*id; + const float x1 = x[QK5_0/2 + j]*id; + + const uint8_t xi0 = min(31, (int8_t)(x0 + 16.5f)); + const uint8_t xi1 = min(31, (int8_t)(x1 + 16.5f)); + + y->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4); + qh |= ((xi0 & 0x10u) >> 4) << (j + 0); + qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2); + } + memcpy(y->qh, &qh, sizeof(qh)); +} + +static __device__ void quantize_f32_q5_1_block(const float * __restrict__ x, block_q5_1 * __restrict__ y) { + float min = x[0]; + float max = x[0]; + + for (int j = 1; j < QK5_1; ++j) { + const float v = x[j]; + min = v < min ? v : min; + max = v > max ? v : max; + } + + const float d = (max - min) / 31; + const float id = d ? 1.0f/d : 0.0f; + + y->dm.x = d; + y->dm.y = min; + + uint32_t qh = 0; + for (int j = 0; j < QK5_1/2; ++j) { + const float x0 = (x[0 + j] - min)*id; + const float x1 = (x[QK5_1/2 + j] - min)*id; + + const uint8_t xi0 = (uint8_t)(x0 + 0.5f); + const uint8_t xi1 = (uint8_t)(x1 + 0.5f); + + y->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4); + qh |= ((xi0 & 0x10u) >> 4) << (j + 0); + qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2); + } + memcpy(y->qh, &qh, sizeof(qh)); +} + +static __device__ void quantize_f32_q8_0_block(const float * __restrict__ x, block_q8_0 * __restrict__ y) { + float amax = 0.0f; // absolute max + + for (int j = 0; j < QK8_0; j++) { + const float v = x[j]; + amax = fmaxf(amax, fabsf(v)); + } + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f/d : 0.0f; + + y->d = d; + + for (int j = 0; j < QK8_0; ++j) { + const float x0 = x[j]*id; + y->qs[j] = roundf(x0); + } +} + +static __device__ void quantize_f32_iq4_nl_block(const float * __restrict__ x, block_iq4_nl * __restrict__ y) { + float amax = 0.0f; + float vmax = 0.0f; + + for (int j = 0; j < QK4_NL; ++j) { + const float v = x[j]; + if (amax < fabsf(v)) { + amax = fabsf(v); + vmax = v; + } + } + + float d = vmax / kvalues_iq4nl[0]; + const float id = d ? 1.0f/d : 0.0f; + + float sumqx = 0, sumq2 = 0; + for (int j = 0; j < QK4_NL/2; ++j) { + const float x0 = x[0 + j]*id; + const float x1 = x[QK4_NL/2 + j]*id; + const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl, x0); + const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl, x1); + y->qs[j] = xi0 | (xi1 << 4); + const float v0 = kvalues_iq4nl[xi0]; + const float v1 = kvalues_iq4nl[xi1]; + const float w0 = x[0 + j]*x[0 + j]; + const float w1 = x[QK4_NL/2 + j]*x[QK4_NL/2 + j]; + sumqx += w0*v0*x[j] + w1*v1*x[QK4_NL/2 + j]; + sumq2 += w0*v0*v0 + w1*v1*v1; + } + + y->d = sumq2 > 0 ? sumqx/sumq2 : d; +} + +static __device__ void quantize_f32_q6_0_block(const float * __restrict__ xi, block_q6_0 * __restrict__ y) { + + float amax = 0.0f; + float vmax = 0.0f; + + for (int j = 0; j < QK6_0; ++j) { + const float v = xi[j]; + const float av = fabsf(xi[j]); + if (amax < av) { + amax = av; + vmax = v; + } + } + + const float d = vmax / -32; + const float id = d ? 
1.0f/d : 0.0f; + + y->d = d; + memset(y->qh, 0, QK6_0/4); + + for (int j = 0; j < QK6_0/2; ++j) { + const float x0 = xi[0 + j]*id; + const float x1 = xi[QK4_0/2 + j]*id; + + const uint8_t xi0 = min(63, (int8_t)(x0 + 32.5f)); + const uint8_t xi1 = min(63, (int8_t)(x1 + 32.5f)); + + y->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4); + const uint8_t h = (xi0 >> 4) | ((xi1 >> 4) << 2); + y->qh[j%(QK6_0/4)] |= (h << 4*(j/(QK6_0/4))); + } +} + +// Wrapper functions for cpy.cu compatibility +static __device__ void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) { + quantize_f32_q4_0_block((const float *)cxi, (block_q4_0 *)cdsti); +} + +static __device__ void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) { + quantize_f32_q4_1_block((const float *)cxi, (block_q4_1 *)cdsti); +} + +static __device__ void cpy_blck_f32_q5_0(const char * cxi, char * cdsti) { + quantize_f32_q5_0_block((const float *)cxi, (block_q5_0 *)cdsti); +} + +static __device__ void cpy_blck_f32_q5_1(const char * cxi, char * cdsti) { + quantize_f32_q5_1_block((const float *)cxi, (block_q5_1 *)cdsti); +} + +static __device__ void cpy_blck_f32_q6_0(const char * cxi, char * cdsti) { + quantize_f32_q6_0_block((const float *)cxi, (block_q6_0 *)cdsti); +} + +static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) { + quantize_f32_q8_0_block((const float *)cxi, (block_q8_0 *)cdsti); +} + +static __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) { + quantize_f32_iq4_nl_block((const float *)cxi, (block_iq4_nl *)cdsti); +} + +template +static __device__ void cpy_1_flt(const char * cxi, char * cdsti) { + convert_flt((const src_t *)cxi, (dst_t *)cdsti); +} diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu index 312e0c09a..7fb8d9dbf 100644 --- a/ggml/src/ggml-cuda/cpy.cu +++ b/ggml/src/ggml-cuda/cpy.cu @@ -1,61 +1,26 @@ -// -// Copyright (C) 2023-2024 The ggml authors -// Copyright (C) 2024 Iwan Kawrakow -// MIT license -// SPDX-License-Identifier: MIT -// - #include "cpy.cuh" -#include "convert.cuh" +#include "dequantize.cuh" +#include "graph.cuh" +#include "cpy-utils.cuh" +#if defined(GGML_USE_MUSA) && defined(GGML_MUSA_MUDNN_COPY) +#include "ggml-musa/mudnn.cuh" +#endif // GGML_USE_MUSA && GGML_MUSA_MUDNN_COPY typedef void (*cpy_kernel_t)(const char * cx, char * cdst); -static __device__ void cpy_1_f32_f32(const char * cxi, char * cdsti) { - const float * xi = (const float *) cxi; - float * dsti = (float *) cdsti; - - *dsti = *xi; -} - -static __device__ void cpy_1_f32_f16(const char * cxi, char * cdsti) { - const float * xi = (const float *) cxi; - half * dsti = (half *) cdsti; - - *dsti = __float2half(*xi); -} - -static __device__ void cpy_1_f32_bf16(const char * cxi, char * cdsti) { - const float * xi = (const float *) cxi; - nv_bfloat16 * dsti = (nv_bfloat16 *) cdsti; - - *dsti = __float2bfloat16(*xi); -} - -static __device__ void cpy_1_f16_f16(const char * cxi, char * cdsti) { - const half * xi = (const half *) cxi; - half * dsti = (half *) cdsti; - - *dsti = *xi; -} - -static __device__ void cpy_1_f16_f32(const char * cxi, char * cdsti) { - const half * xi = (const half *) cxi; - float * dsti = (float *) cdsti; - - *dsti = *xi; -} - template -static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne, - const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, - const int nb12, const int nb13) { +static __global__ void cpy_flt(const 
diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu
index 312e0c09a..7fb8d9dbf 100644
--- a/ggml/src/ggml-cuda/cpy.cu
+++ b/ggml/src/ggml-cuda/cpy.cu
@@ -1,61 +1,26 @@
-//
-// Copyright (C) 2023-2024 The ggml authors
-// Copyright (C) 2024 Iwan Kawrakow
-// MIT license
-// SPDX-License-Identifier: MIT
-//
-
 #include "cpy.cuh"
-#include "convert.cuh"
+#include "dequantize.cuh"
+#include "graph.cuh"
+#include "cpy-utils.cuh"
+#if defined(GGML_USE_MUSA) && defined(GGML_MUSA_MUDNN_COPY)
+#include "ggml-musa/mudnn.cuh"
+#endif // GGML_USE_MUSA && GGML_MUSA_MUDNN_COPY
 
 typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
 
-static __device__ void cpy_1_f32_f32(const char * cxi, char * cdsti) {
-    const float * xi = (const float *) cxi;
-    float * dsti = (float *) cdsti;
-
-    *dsti = *xi;
-}
-
-static __device__ void cpy_1_f32_f16(const char * cxi, char * cdsti) {
-    const float * xi = (const float *) cxi;
-    half * dsti = (half *) cdsti;
-
-    *dsti = __float2half(*xi);
-}
-
-static __device__ void cpy_1_f32_bf16(const char * cxi, char * cdsti) {
-    const float * xi = (const float *) cxi;
-    nv_bfloat16 * dsti = (nv_bfloat16 *) cdsti;
-
-    *dsti = __float2bfloat16(*xi);
-}
-
-static __device__ void cpy_1_f16_f16(const char * cxi, char * cdsti) {
-    const half * xi = (const half *) cxi;
-    half * dsti = (half *) cdsti;
-
-    *dsti = *xi;
-}
-
-static __device__ void cpy_1_f16_f32(const char * cxi, char * cdsti) {
-    const half * xi = (const half *) cxi;
-    float * dsti = (float *) cdsti;
-
-    *dsti = *xi;
-}
-
 template <cpy_kernel_t cpy_1>
-static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
-    const int nb12, const int nb13) {
+static __global__ void cpy_flt(const char * cx, char * cdst_direct, const int ne,
+    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
+    const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) {
     const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i >= ne) {
         return;
     }
 
+    char * cdst = (cdst_indirect != nullptr) ? cdst_indirect[graph_cpynode_index]: cdst_direct;
+
     // determine indices i03/i13, i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor
     // then combine those indices with the corresponding byte offsets to get the total offsets
     const int64_t i03 = i/(ne00 * ne01 * ne02);
@@ -73,468 +38,329 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
     cpy_1(cx + x_offset, cdst + dst_offset);
 }
 
-template <typename dst_t>
-static __global__ void k_cpy_q8_0_to_float(const char * cx, dst_t * dst, const int ne,
-    const int ne00, const int ne01, const int ne02, const int nb01, const int nb02, const int nb03) {
-    const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
-
-    if (i >= ne) {
-        return;
-    }
-
-    const int64_t i03 = i/(ne00 * ne01 * ne02);
-    const int64_t i02 = (i - i03*ne00*ne01*ne02) / (ne00*ne01);
-    const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne00*ne01) / ne00;
-    const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne00*ne01 - i01*ne00;
-
-    const block_q8_0 * q8 = (const block_q8_0 *)(cx + i01*nb01 + i02*nb02 + i03*nb03);
-    const int ib = i00/QK8_0;
-    const int iq = i00%QK8_0;
-
-    if constexpr (std::is_same_v<dst_t, nv_bfloat16>) {
-        dst[i00 + i01*ne00 + i02*ne00*ne01 + i03*ne00*ne01*ne02] = __float2bfloat16(__half2float(q8[ib].d)*q8[ib].qs[iq]);
-    } else {
-        dst[i00 + i01*ne00 + i02*ne00*ne01 + i03*ne00*ne01*ne02] = __half2float(q8[ib].d)*q8[ib].qs[iq];
-    }
-}
-
-static __global__ void k_transpose_q8_0(const char * cx, char * cdst,
-    const int ne10, const int ne11, const int ne12,
-    const int nb01, const int nb02, const int nb03,
-    const int nb11, const int nb12, const int nb13) {
-    const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
-
-    const int64_t i13 = i/(ne10 * ne11 * ne12);
-    const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
-    const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
-    const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
-
-    //const int64_t ne00 = ne11;
-    //const int64_t ne01 = ne10;
-    //const int64_t ne02 = ne12;
-    const int64_t i03 = i13;
-    const int64_t i02 = i12;
-    const int64_t i01 = i10; //(i - i03*ne00*ne01*ne02 - i02*ne00*ne01) / ne00;
-    const int64_t i00 = i11; //i - i03*ne00*ne01*ne02 - i02*ne00*ne01 - i01*ne00;
-
-    const block_q8_0 * q8 = (const block_q8_0 *)(cx + i01*nb01 + i02*nb02 + i03*nb03);
-    const int ib0 = i00/QK8_0;
-    const int iq0 = i00%QK8_0;
-
-    float xi = __half2float(q8[ib0].d)*q8[ib0].qs[iq0];
-    float amax = fabsf(xi);
-    amax = warp_reduce_max(amax);
-
-    //printf("%d, %d, %d: i = %ld, i11 = %ld i10 = %ld, xi = %g, amax = %g\n", blockDim.x, blockIdx.x, threadIdx.x, i, i11, i10, xi, amax);
-
-    float d = amax/127;
-    int8_t q = amax == 0.0f ? 
0 : roundf(xi / d); - - block_q8_0 * dst = (block_q8_0 *)(cdst + i11*nb11 + i12*nb12 + i13*nb13); - dst[i10 / QK8_0].qs[i10 % QK8_0] = q; - - if (threadIdx.x == 0) { - dst[i10 / QK8_0].d = __float2half(d); - } -} - -static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) { - const float * xi = (const float *) cxi; - block_q8_0 * dsti = (block_q8_0 *) cdsti; - - float amax = 0.0f; // absolute max - - for (int j = 0; j < QK8_0; j++) { - const float v = xi[j]; - amax = fmaxf(amax, fabsf(v)); - } - - const float d = amax / ((1 << 7) - 1); - const float id = d ? 1.0f/d : 0.0f; - - dsti->d = d; - - for (int j = 0; j < QK8_0; ++j) { - const float x0 = xi[j]*id; - - dsti->qs[j] = roundf(x0); - } -} - -static __device__ void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) { - const float * xi = (const float *) cxi; - block_q4_0 * dsti = (block_q4_0 *) cdsti; - - float amax = 0.0f; - float vmax = 0.0f; - - for (int j = 0; j < QK4_0; ++j) { - const float v = xi[j]; - if (amax < fabsf(v)) { - amax = fabsf(v); - vmax = v; - } - } - - const float d = vmax / -8; - const float id = d ? 1.0f/d : 0.0f; - - dsti->d = d; - - for (int j = 0; j < QK4_0/2; ++j) { - const float x0 = xi[0 + j]*id; - const float x1 = xi[QK4_0/2 + j]*id; - - const uint8_t xi0 = min(15, (int8_t)(x0 + 8.5f)); - const uint8_t xi1 = min(15, (int8_t)(x1 + 8.5f)); - - dsti->qs[j] = xi0; - dsti->qs[j] |= xi1 << 4; - } -} - -static __device__ void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) { - const float * xi = (const float *) cxi; - block_q4_1 * dsti = (block_q4_1 *) cdsti; - - float vmin = FLT_MAX; - float vmax = -FLT_MAX; - - for (int j = 0; j < QK4_1; ++j) { - const float v = xi[j]; - - if (v < vmin) vmin = v; - if (v > vmax) vmax = v; - } - - const float d = (vmax - vmin) / ((1 << 4) - 1); - const float id = d ? 1.0f/d : 0.0f; - - dsti->dm.x = d; - dsti->dm.y = vmin; - - for (int j = 0; j < QK4_1/2; ++j) { - const float x0 = (xi[0 + j] - vmin)*id; - const float x1 = (xi[QK4_1/2 + j] - vmin)*id; - - const uint8_t xi0 = min(15, (int8_t)(x0 + 0.5f)); - const uint8_t xi1 = min(15, (int8_t)(x1 + 0.5f)); +static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) { + float * cdstf = (float *)(cdsti); - dsti->qs[j] = xi0; - dsti->qs[j] |= xi1 << 4; +#pragma unroll + for (int j = 0; j < QK8_0; j += 2) { + dfloat2 dq; + dequantize_q8_0(cxi, 0, j, dq); + *(cdstf + j) = dq.x; + *(cdstf + j + 1) = dq.y; } } -static __device__ void cpy_blck_f32_q5_0(const char * cxi, char * cdsti) { - const float * xi = (const float *) cxi; - block_q5_0 * dsti = (block_q5_0 *) cdsti; - - float amax = 0.0f; - float vmax = 0.0f; - - for (int j = 0; j < QK5_0; ++j) { - const float v = xi[j]; - if (amax < fabsf(v)) { - amax = fabsf(v); - vmax = v; - } - } - - const float d = vmax / -16; - const float id = d ? 
1.0f/d : 0.0f; - - dsti->d = d; - - uint32_t qh = 0; - for (int j = 0; j < QK5_0/2; ++j) { - const float x0 = xi[0 + j]*id; - const float x1 = xi[QK5_0/2 + j]*id; - - const uint8_t xi0 = min(31, (int8_t)(x0 + 16.5f)); - const uint8_t xi1 = min(31, (int8_t)(x1 + 16.5f)); +static __device__ void cpy_blck_q8_0_f16(const char * cxi, char * cdsti) { + half * dsth = (half *)(cdsti); - dsti->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4); - qh |= ((xi0 & 0x10u) >> 4) << (j + 0); - qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2); +#pragma unroll + for (int j = 0; j < QK8_0; j += 2) { + dfloat2 dq; + dequantize_q8_0(cxi, 0, j, dq); + *(dsth + j + 0) = __float2half(dq.x); + *(dsth + j + 1) = __float2half(dq.y); } - memcpy(dsti->qh, &qh, sizeof(qh)); } -static __device__ void cpy_blck_f32_q5_1(const char * cxi, char * cdsti) { - const float * xi = (const float *) cxi; - block_q5_1 * dsti = (block_q5_1 *) cdsti; - - float min = xi[0]; - float max = xi[0]; - - for (int j = 1; j < QK5_1; ++j) { - const float v = xi[j]; - min = v < min ? v : min; - max = v > max ? v : max; - } - - const float d = (max - min) / 31; - const float id = d ? 1.0f/d : 0.0f; +template +static __device__ void cpy_blck_q_f32(const char * cxi, char * cdsti) { + float * cdstf = (float *)(cdsti); - dsti->dm.x = d; - dsti->dm.y = min; - - uint32_t qh = 0; - for (int j = 0; j < QK5_1/2; ++j) { - const float x0 = (xi[0 + j] - min)*id; - const float x1 = (xi[QK5_1/2 + j] - min)*id; - - const uint8_t xi0 = (uint8_t)(x0 + 0.5f); - const uint8_t xi1 = (uint8_t)(x1 + 0.5f); - - dsti->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4); - qh |= ((xi0 & 0x10u) >> 4) << (j + 0); - qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2); +#pragma unroll + for (int j = 0; j < qk/2; j++) { + dfloat2 dq; + dequant(cxi, 0, j, dq); + *(cdstf + j) = dq.x; + *(cdstf + j + qk/2) = dq.y; } - memcpy(dsti->qh, &qh, sizeof(qh)); } -static __device__ void cpy_blck_f32_q6_0(const char * cxi, char * cdsti) { - const float * xi = (const float *) cxi; - block_q6_0 * dsti = (block_q6_0 *) cdsti; - - float amax = 0.0f; - float vmax = 0.0f; - - for (int j = 0; j < QK6_0; ++j) { - const float v = xi[j]; - const float av = fabsf(xi[j]); - if (amax < av) { - amax = av; - vmax = v; - } - } - - const float d = vmax / -32; - const float id = d ? 
1.0f/d : 0.0f; - - dsti->d = d; - memset(dsti->qh, 0, QK6_0/4); - - for (int j = 0; j < QK6_0/2; ++j) { - const float x0 = xi[0 + j]*id; - const float x1 = xi[QK4_0/2 + j]*id; - - const uint8_t xi0 = min(63, (int8_t)(x0 + 32.5f)); - const uint8_t xi1 = min(63, (int8_t)(x1 + 32.5f)); +template +static __global__ void cpy_f32_q(const char * cx, char * cdst_direct, const int ne, + const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, + const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, + const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) { + const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk; - dsti->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4); - const uint8_t h = (xi0 >> 4) | ((xi1 >> 4) << 2); - dsti->qh[j%(QK6_0/4)] |= (h << 4*(j/(QK6_0/4))); + if (i >= ne) { + return; } -} - -static __device__ const int8_t iq4nl_index[241] = { - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16, 16, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 1, 17, 17, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 18, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, - 3, 3, 3, 3, 3, 3, 19, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 20, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, - 5, 5, 21, 21, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 22, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 23, 23, 8, 8, 8, 8, - 8, 8, 8, 8, 8, 8, 24, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 25, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 26, 26, - 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 27, 27, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 28, 13, 13, 13, - 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 29, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, - 14, 14, 14, 14, 30, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15 -}; -static __device__ __forceinline__ int best_index_iq4nl(const int8_t * values, float x) { - int ix = (int)x - values[0]; - if (ix < 0 || ix >= 241) return ix < 0 ? 0 : 15; - ix = iq4nl_index[ix]; - return ix < 16 ? ix : x - values[ix-16] < values[ix-15] - x ? ix-16 : ix-15; -} -static __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) { - const float * xi = (const float *) cxi; - block_iq4_nl * dsti = (block_iq4_nl *) cdsti; + char * cdst = (cdst_indirect != nullptr) ? cdst_indirect[graph_cpynode_index]: cdst_direct; - float amax = 0.0f; - float vmax = 0.0f; - - for (int j = 0; j < QK4_NL; ++j) { - const float v = xi[j]; - if (amax < fabsf(v)) { - amax = fabsf(v); - vmax = v; - } - } + const int i03 = i/(ne00 * ne01 * ne02); + const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01); + const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00; + const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00; + const int x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03; - float d = vmax / kvalues_iq4nl[0]; - const float id = d ? 
1.0f/d : 0.0f; - - //dsti->d = d; - - float sumqx = 0, sumq2 = 0; - for (int j = 0; j < QK4_NL/2; ++j) { - const float x0 = xi[0 + j]*id; - const float x1 = xi[QK4_NL/2 + j]*id; - const uint8_t xi0 = best_index_iq4nl(kvalues_iq4nl, x0); - const uint8_t xi1 = best_index_iq4nl(kvalues_iq4nl, x1); - dsti->qs[j] = xi0 | (xi1 << 4); - const float v0 = kvalues_iq4nl[xi0]; - const float v1 = kvalues_iq4nl[xi1]; - const float w0 = xi[0 + j]*xi[0 + j]; - const float w1 = xi[QK4_NL/2 + j]*xi[QK4_NL/2 + j]; - sumqx += w0*v0*xi[j] + w1*v1*xi[QK4_NL/2 + j]; - sumq2 += w0*v0*v0 + w1*v1*v1; - } + const int i13 = i/(ne10 * ne11 * ne12); + const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11); + const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10; + const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10; + const int dst_offset = (i10/qk)*nb10 + i11*nb11 + i12*nb12 + i13*nb13; - dsti->d = sumq2 > 0 ? sumqx/sumq2 : d; + cpy_blck(cx + x_offset, cdst + dst_offset); } template -static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne, +static __global__ void cpy_q_f32(const char * cx, char * cdst_direct, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, - const int nb12, const int nb13) { + const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) { const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk; if (i >= ne) { return; } + char * cdst = (cdst_indirect != nullptr) ? cdst_indirect[graph_cpynode_index]: cdst_direct; + const int i03 = i/(ne00 * ne01 * ne02); const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01); const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00; const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00; - const int x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03; + const int x_offset = (i00/qk)*nb00 + i01*nb01 + i02*nb02 + i03 * nb03; const int i13 = i/(ne10 * ne11 * ne12); const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11); const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10; const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10; - const int dst_offset = (i10/qk)*nb10 + i11*nb11 + i12*nb12 + i13*nb13; + const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13; cpy_blck(cx + x_offset, cdst + dst_offset); } -static void ggml_cpy_f16_f32_cuda( - const char * cx, char * cdst, const int ne, - const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { +// Copy destination pointers to GPU to be available when pointer indirection is in use - const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; - cpy_f32_f16<<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); +void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_dest_ptrs, const int host_dest_ptrs_size, cudaStream_t stream) { +#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS) + if (cuda_graph->dest_ptrs_size < host_dest_ptrs_size) { // (re-)allocate GPU memory for destination pointers + CUDA_CHECK(cudaStreamSynchronize(stream)); + if (cuda_graph->dest_ptrs_d != nullptr) { + CUDA_CHECK(cudaFree(cuda_graph->dest_ptrs_d)); + } + 
CUDA_CHECK(cudaMalloc(&cuda_graph->dest_ptrs_d, host_dest_ptrs_size*sizeof(char *))); + cuda_graph->dest_ptrs_size = host_dest_ptrs_size; + } + // copy destination pointers to GPU + CUDA_CHECK(cudaMemcpyAsync(cuda_graph->dest_ptrs_d, host_dest_ptrs, host_dest_ptrs_size*sizeof(char *), cudaMemcpyHostToDevice, stream)); + cuda_graph->graph_cpynode_index = 0; // reset index +#else + GGML_UNUSED(cuda_graph); GGML_UNUSED(host_dest_ptrs); + GGML_UNUSED(host_dest_ptrs_size); GGML_UNUSED(stream); +#endif } -static void ggml_cpy_f32_f32_cuda( +template +static void ggml_cpy_flt_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { + const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; - cpy_f32_f16<<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); + cpy_flt><<>> + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); } -static void ggml_cpy_f32_f16_cuda( +static void ggml_cpy_f32_q8_0_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { + const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { - const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; - cpy_f32_f16<<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); + GGML_ASSERT(ne % QK8_0 == 0); + const int num_blocks = ne / QK8_0; + cpy_f32_q<<>> + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); } -static void ggml_cpy_f32_bf16_cuda( +static void ggml_cpy_q8_0_f32_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { + const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { - const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; - cpy_f32_f16<<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); + const int num_blocks = ne; + cpy_q_f32<<>> + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); } -static void ggml_cpy_f32_q8_0_cuda( +static void ggml_cpy_q8_0_f16_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const 
int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { + const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { - GGML_ASSERT(ne % QK8_0 == 0); - const int num_blocks = ne / QK8_0; - cpy_f32_q<<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); + const int num_blocks = ne; + cpy_q_f32<<>> + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); } static void ggml_cpy_f32_q4_0_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { + const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { GGML_ASSERT(ne % QK4_0 == 0); const int num_blocks = ne / QK4_0; cpy_f32_q<<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); +} + +static void ggml_cpy_q4_0_f32_cuda( + const char * cx, char * cdst, const int ne, + const int ne00, const int ne01, const int ne02, + const int nb00, const int nb01, const int nb02, + const int nb03, const int ne10, const int ne11, const int ne12, + const int nb10, const int nb11, const int nb12, const int nb13, + cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { + const int num_blocks = ne; + cpy_q_f32, QK4_0><<>>( + cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, + ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); } static void ggml_cpy_f32_q4_1_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { + const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { GGML_ASSERT(ne % QK4_1 == 0); const int num_blocks = ne / QK4_1; cpy_f32_q<<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); +} + +static void ggml_cpy_q4_1_f32_cuda( + const char * cx, char * cdst, const int ne, + const int ne00, const int ne01, const int ne02, + const int nb00, const int nb01, const int nb02, + const int nb03, const int ne10, const int ne11, const int ne12, + const int nb10, const int nb11, const int nb12, const int nb13, + cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { + const int num_blocks = ne; + cpy_q_f32, QK4_1><<>>( + cx, cdst, ne, ne00, ne01, ne02, 
nb00, nb01, nb02, nb03, + ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); } static void ggml_cpy_f32_q5_0_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { + const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { GGML_ASSERT(ne % QK5_0 == 0); const int num_blocks = ne / QK5_0; cpy_f32_q<<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); +} + +static void ggml_cpy_q5_0_f32_cuda( + const char * cx, char * cdst, const int ne, + const int ne00, const int ne01, const int ne02, + const int nb00, const int nb01, const int nb02, + const int nb03, const int ne10, const int ne11, const int ne12, + const int nb10, const int nb11, const int nb12, const int nb13, + cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { + const int num_blocks = ne; + cpy_q_f32, QK5_0><<>>( + cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, + ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); } static void ggml_cpy_f32_q5_1_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { + const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { GGML_ASSERT(ne % QK5_1 == 0); const int num_blocks = ne / QK5_1; cpy_f32_q<<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); } -static void ggml_cpy_f32_q6_0_cuda( +static void ggml_cpy_q5_1_f32_cuda( const char * cx, char * cdst, const int ne, - const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { - - GGML_ASSERT(ne % QK6_0 == 0); - const int num_blocks = ne / QK6_0; - cpy_f32_q<<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); + const int ne00, const int ne01, const int ne02, + const int nb00, const int nb01, const int nb02, + const int nb03, const int ne10, const int ne11, const int ne12, + const int nb10, const int nb11, const int nb12, const int nb13, + cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { + const int num_blocks = ne; + cpy_q_f32, QK5_1><<>>( + cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, + ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); } static void ggml_cpy_f32_iq4_nl_cuda( const char * cx, char * cdst, const int ne, 
    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
 
     GGML_ASSERT(ne % QK4_NL == 0);
     const int num_blocks = ne / QK4_NL;
     cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL><<<num_blocks, 1, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
 }
 
-static void ggml_cpy_f16_f16_cuda(
+static void ggml_cpy_f32_q6_0_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
 
-    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
-    cpy_f32_f16<cpy_1_f16_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+    GGML_ASSERT(ne % QK6_0 == 0);
+    const int num_blocks = ne / QK6_0;
+    cpy_f32_q<cpy_blck_f32_q6_0, QK6_0><<<num_blocks, 1, 0, stream>>>
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
+}
+
+static void ggml_cpy_q6_0_f32_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02,
+    const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12,
+    const int nb10, const int nb11, const int nb12, const int nb13,
+    cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
+    const int num_blocks = ne;
+    cpy_q_f32<cpy_blck_q_f32<dequantize_q6_0, QK6_0>, QK6_0><<<num_blocks, 1, 0, stream>>>(
+        cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
+        ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
+}
+
+static __global__ void k_transpose_q8_0(const char * cx, char * cdst,
+    const int ne10, const int ne11, const int ne12,
+    const int nb01, const int nb02, const int nb03,
+    const int nb11, const int nb12, const int nb13) {
+    const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    const int64_t i13 = i/(ne10 * ne11 * ne12);
+    const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
+    const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
+    const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
+
+    //const int64_t ne00 = ne11;
+    //const int64_t ne01 = ne10;
+    //const int64_t ne02 = ne12;
+    const int64_t i03 = i13;
+    const int64_t i02 = i12;
+    const int64_t i01 = i10; //(i - i03*ne00*ne01*ne02 - i02*ne00*ne01) / ne00;
+    const int64_t i00 = i11; //i - i03*ne00*ne01*ne02 - i02*ne00*ne01 - i01*ne00;
+
+    const block_q8_0 * q8 = (const block_q8_0 *)(cx + i01*nb01 + i02*nb02 + i03*nb03);
+    const int ib0 = i00/QK8_0;
+    const int iq0 = i00%QK8_0;
+
+    float xi = __half2float(q8[ib0].d)*q8[ib0].qs[iq0];
+    float amax = fabsf(xi);
+    amax = warp_reduce_max(amax);
+
+    float d = amax/127;
+    int8_t q = amax == 0.0f ? 
0 : roundf(xi / d); + + block_q8_0 * dst = (block_q8_0 *)(cdst + i11*nb11 + i12*nb12 + i13*nb13); + dst[i10 / QK8_0].qs[i10 % QK8_0] = q; + + if (threadIdx.x == 0) { + dst[i10 / QK8_0].d = __float2half(d); + } } static void transpose_q8_0(ggml_backend_cuda_context & ctx, const ggml_tensor * src, ggml_tensor * dst) { @@ -546,36 +372,13 @@ static void transpose_q8_0(ggml_backend_cuda_context & ctx, const ggml_tensor * dst->nb[1], dst->nb[2], dst->nb[3]); } -static void copy_q8_0_to_float(ggml_backend_cuda_context & ctx, const ggml_tensor * src, ggml_tensor * dst) { - auto stream = ctx.stream(); - auto num_blocks = ggml_nelements(dst)/QK8_0; - if (dst->type == GGML_TYPE_F16) { - k_cpy_q8_0_to_float<<>>((const char *)src->data, (half *)dst->data, ggml_nelements(dst), - src->ne[0], src->ne[1], src->ne[2], src->nb[1], src->nb[2], src->nb[3]); - } - else if (dst->type == GGML_TYPE_F32) { - k_cpy_q8_0_to_float<<>>((const char *)src->data, (float *)dst->data, ggml_nelements(dst), - src->ne[0], src->ne[1], src->ne[2], src->nb[1], src->nb[2], src->nb[3]); - } - else if (dst->type == GGML_TYPE_BF16) { - k_cpy_q8_0_to_float<<>>((const char *)src->data, (nv_bfloat16 *)dst->data, ggml_nelements(dst), - src->ne[0], src->ne[1], src->ne[2], src->nb[1], src->nb[2], src->nb[3]); - } - else { - GGML_ABORT("fatal error"); - } -} -void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1) { +void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1, bool disable_indirection_for_this_node) { const int64_t ne = ggml_nelements(src0); GGML_ASSERT(ne == ggml_nelements(src1)); - //if (ggml_nbytes(src0) > INT_MAX) { - // printf("%s: %s has %zu bytes\n", __func__, src0->name, ggml_nbytes(src0)); - //} - // These asserts appear to be unnecessary. Why were they added? 
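The asserts reinstated just below answer that question: every copy kernel in this file computes its flattened index and byte offsets in 32-bit int arithmetic (e.g. cpy_f32_q's `const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk;`), so a tensor larger than INT_MAX bytes would silently wrap instead of failing loudly. A minimal illustration of the overflow, with a hypothetical size only:

// sketch: a 3 GiB tensor's byte count does not fit in a 32-bit int
const int64_t nbytes     = 3ll*1024*1024*1024;
const int     nbytes_i32 = (int) nbytes;   // wraps to a negative value instead of asserting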
- //GGML_ASSERT(ggml_nbytes(src0) <= INT_MAX); - //GGML_ASSERT(ggml_nbytes(src1) <= INT_MAX); + GGML_ASSERT(ggml_nbytes(src0) <= INT_MAX); + GGML_ASSERT(ggml_nbytes(src1) <= INT_MAX); const int64_t ne00 = src0->ne[0]; const int64_t ne01 = src0->ne[1]; @@ -604,120 +407,151 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg char * src0_ddc = (char *) src0->data; char * src1_ddc = (char *) src1->data; - if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) { - ggml_cpy_f32_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); - } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) { - ggml_cpy_f32_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + char ** dest_ptrs_d = nullptr; + int graph_cpynode_index = -1; +#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS) + if(ctx.cuda_graph->use_cpy_indirection && !disable_indirection_for_this_node) { + dest_ptrs_d = ctx.cuda_graph->dest_ptrs_d; + graph_cpynode_index = ctx.cuda_graph->graph_cpynode_index; + } +#else + GGML_UNUSED(disable_indirection_for_this_node); +#endif + if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) { + GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1)); +#if defined(GGML_USE_MUSA) && defined(GGML_MUSA_MUDNN_COPY) + if (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16) { + CUDA_CHECK(mudnnMemcpyAsync(ctx, src1, src0)); + } else +#endif // GGML_USE_MUSA && GGML_MUSA_MUDNN_COPY + { + CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream)); + } + } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) { + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) { - ggml_cpy_f32_bf16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) { + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) { - ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) { + ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F16) { + ggml_cpy_q8_0_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, 
main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) { - ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + } else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) { + ggml_cpy_q4_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, + nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) { - ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + } else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) { + ggml_cpy_q4_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, + nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) { - ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); - } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q6_0) { - ggml_cpy_f32_q6_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + } else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) { + ggml_cpy_q5_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, + nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) { - ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) { - ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) { + ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q6_0) { + ggml_cpy_f32_q6_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, 
dest_ptrs_d, graph_cpynode_index); + } else if (src0->type == GGML_TYPE_Q6_0 && src1->type == GGML_TYPE_F32) { + ggml_cpy_q6_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) { - ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); - } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) { - ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) { + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) { - ggml_cpy_f16_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); - } else if (ggml_are_same_shape(src0, src1) && src0->type == GGML_TYPE_Q8_0 && - (src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_BF16 || src1->type == GGML_TYPE_F32)) { - copy_q8_0_to_float(ctx, src0, src1); - } else if (ggml_is_contiguous(src0) && ggml_are_same_shape(src0, src1)) { - if (src1->type == GGML_TYPE_F16) { - auto to_fp16 = ggml_get_to_fp16_cuda(src0->type); - if (to_fp16) { - to_fp16(src0->data, (half *)src1->data, ggml_nrows(src0), src0->ne[0], main_stream); - } - } - else if (src1->type == GGML_TYPE_F32) { - auto to_fp32 = ggml_get_to_fp32_cuda(src0->type); - if (to_fp32) { - to_fp32(src0->data, (float *)src1->data, ggml_nrows(src0), src0->ne[0], main_stream); - } - } - else if (src1->type == GGML_TYPE_BF16) { - auto to_bf16 = ggml_get_to_bf16_cuda(src0->type); - if (to_bf16) { - to_bf16(src0->data, (nv_bfloat16 *)src1->data, ggml_nrows(src0), src0->ne[0], main_stream); - } - } + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) { + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) { + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) { + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (ggml_are_same_shape(src0, src1) && src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_Q8_0) { + // This is needed for MLA with mla=2 when using q8_0 cache. 
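A q8_0→q8_0 copy only reaches this branch when the source is a transposed view (the contiguous same-type case is handled by the plain cudaMemcpyAsync above); because transposition regroups values into different 32-wide blocks, each destination block must be dequantized and requantized with a fresh scale, which is what k_transpose_q8_0 does with a warp reduction. A CPU reference for the same semantics, a sketch only, assuming M and N are multiples of QK8_0 and the usual ggml half conversion macros:

// a: M x N row-major, rows quantized in q8_0 blocks of 32; b = transpose(a): N x M, same layout
static void transpose_q8_0_ref(const block_q8_0 * a, block_q8_0 * b, int M, int N) {
    for (int c = 0; c < N; ++c) {                  // each destination row = source column
        for (int ib = 0; ib < M/QK8_0; ++ib) {     // each destination block within that row
            float v[QK8_0];
            float amax = 0.0f;
            for (int l = 0; l < QK8_0; ++l) {
                const int r = ib*QK8_0 + l;        // source row index
                const block_q8_0 * src = &a[r*(N/QK8_0) + c/QK8_0];
                v[l] = GGML_FP16_TO_FP32(src->d) * src->qs[c % QK8_0];
                amax = fmaxf(amax, fabsf(v[l]));
            }
            const float d = amax/127.0f;           // fresh scale for the regrouped block
            block_q8_0 * dst = &b[c*(M/QK8_0) + ib];
            dst->d = GGML_FP32_TO_FP16(d);
            for (int l = 0; l < QK8_0; ++l) {
                dst->qs[l] = amax == 0.0f ? 0 : (int8_t) roundf(v[l]/d);
            }
        }
    }
}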
        transpose_q8_0(ctx, src0, src1);
     } else {
-        fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__,
+        GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
             ggml_type_name(src0->type), ggml_type_name(src1->type));
-        fprintf(stderr, "%s: %ld x %ld x %ld; %zu x %zu %zu -> %ld x %ld x %ld; %zu x %zu x %zu\n", __func__,
-            src0->ne[0], src0->ne[1], src0->ne[2], src0->nb[1], src0->nb[2], src0->nb[3],
-            src1->ne[0], src1->ne[1], src1->ne[2], src1->nb[1], src1->nb[2], src1->nb[3]);
-        GGML_ABORT("fatal error");
     }
+#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS)
+    if(ctx.cuda_graph->use_cpy_indirection && !disable_indirection_for_this_node) {
+        ctx.cuda_graph->graph_cpynode_index = graph_cpynode_index;
+    }
+#else
+    GGML_UNUSED(disable_indirection_for_this_node);
+#endif
+
 }
 
 void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
-    ggml_cuda_cpy(ctx, src0, dst);
+    bool disable_indirection = true;
+    ggml_cuda_cpy(ctx, src0, dst, disable_indirection);
 }
 
 void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
-    if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
-        return (void*) cpy_f32_f16<cpy_1_f32_f32>;
-    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
-        return (void*) cpy_f32_f16<cpy_1_f32_f16>;
+    if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
+        return nullptr;
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
+        return (void*) cpy_flt<cpy_1_flt<float, float>>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
-        return (void*) cpy_f32_f16<cpy_1_f32_bf16>;
+        return (void*) cpy_flt<cpy_1_flt<float, nv_bfloat16>>;
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
+        return (void*) cpy_flt<cpy_1_flt<float, half>>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
-        return (void*) cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>;
+        return (void*) cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>;
+    } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
+        return (void*) cpy_q_f32<cpy_blck_q8_0_f32, QK8_0>;
+    } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F16) {
+        return (void*) cpy_q_f32<cpy_blck_q8_0_f16, QK8_0>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
-        return (void*) cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>;
+        return (void*) cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>;
+    } else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
+        return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
-        return (void*) cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>;
+        return (void*) cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>;
+    } else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) {
+        return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q4_1, QK4_1>, QK4_1>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
-        return (void*) cpy_f32_q<cpy_blck_f32_q5_0, QK5_0>;
+        return (void*) cpy_f32_q<cpy_blck_f32_q5_0, QK5_0>;
+    } else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) {
+        return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q5_0, QK5_0>, QK5_0>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
-        return (void*) cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL>;
+        return (void*) cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
-        return (void*) cpy_f32_q<cpy_blck_f32_q5_1, QK5_1>;
+        return (void*) cpy_f32_q<cpy_blck_f32_q5_1, QK5_1>;
+    } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
+        return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q6_0) {
-        return (void*) cpy_f32_q<cpy_blck_f32_q6_0, QK6_0>;
+        return (void*) cpy_f32_q<cpy_blck_f32_q6_0, QK6_0>;
+    } else if (src0->type == GGML_TYPE_Q6_0 && src1->type == GGML_TYPE_F32) {
+        return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q6_0, QK6_0>, QK6_0>;
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
-        return (void*) cpy_f32_f16<cpy_1_f16_f16>;
+        return (void*) cpy_flt<cpy_1_flt<half, half>>;
+    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
+        return (void*) cpy_flt<cpy_1_flt<half, nv_bfloat16>>;
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
-        return (void*) cpy_f32_f16<cpy_1_f16_f32>;
-    } else if (ggml_are_same_shape(src0, src1) && src0->type == GGML_TYPE_Q8_0 &&
-              (src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_BF16 || src1->type == GGML_TYPE_F32)) {
-        return (void*)copy_q8_0_to_float;
-    } else if (ggml_is_contiguous(src0) && ggml_are_same_shape(src0, src1)) {
-        if (src1->type == GGML_TYPE_F16) {
-            auto to_fp16 = ggml_get_to_fp16_cuda(src0->type);
-            if (to_fp16) return (void*)to_fp16;
-        }
-        else if (src1->type == GGML_TYPE_F32) {
-            auto to_fp32 = ggml_get_to_fp32_cuda(src0->type);
-            if (to_fp32) return (void*)to_fp32;
-        }
-        else if (src1->type == GGML_TYPE_BF16) {
-            auto to_bf16 = ggml_get_to_bf16_cuda(src0->type);
-            if (to_bf16) return (void*)to_bf16;
-        }
-    }
-    else if (ggml_are_same_shape(src0, src1) && src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_Q8_0) {
+        return (void*) cpy_flt<cpy_1_flt<half, float>>;
+    } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
+        return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, half>>;
+    } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
+        return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, nv_bfloat16>>;
+    } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
+        return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, float>>;
+    } else if (ggml_are_same_shape(src0, src1) && src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_Q8_0) {
         return (void *)transpose_q8_0;
+    } else {
+        GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
+            ggml_type_name(src0->type), ggml_type_name(src1->type));
     }
-    fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__,
-        ggml_type_name(src0->type), ggml_type_name(src1->type));
-    fprintf(stderr, "%s: %ld x %ld x %ld; %zu x %zu %zu -> %ld x %ld x %ld; %zu x %zu x %zu\n", __func__,
-        src0->ne[0], src0->ne[1], src0->ne[2], src0->nb[1], src0->nb[2], src0->nb[3],
-        src1->ne[0], src1->ne[1], src1->ne[2], src1->nb[1], src1->nb[2], src1->nb[3]);
-    GGML_ABORT("fatal error");
 }
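The pointer returned by ggml_cuda_cpy_fn identifies which kernel instantiation a copy node uses; during CUDA graph reuse, a node whose instantiation is unchanged only needs its destination pointer refreshed through the indirection table rather than a full graph re-capture. A rough sketch of the caller-side pattern (recorded_fn, host_dest_ptrs and n_ptrs are illustrative names, not from this patch):

// capture time: remember the instantiation chosen for this copy node
void * recorded_fn = ggml_cuda_cpy_fn(node->src[0], node);

// ... later, when replaying the captured graph ...
// same instantiation -> only upload fresh destination pointers to the GPU table
if (ggml_cuda_cpy_fn(node->src[0], node) == recorded_fn) {
    ggml_cuda_cpy_dest_ptrs_copy(cuda_graph, host_dest_ptrs, n_ptrs, stream);
}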
diff --git a/ggml/src/ggml-cuda/cpy.cuh b/ggml/src/ggml-cuda/cpy.cuh
index 796167426..0bd3c0c6f 100644
--- a/ggml/src/ggml-cuda/cpy.cuh
+++ b/ggml/src/ggml-cuda/cpy.cuh
@@ -1,9 +1,11 @@
 #include "common.cuh"
 
-#define CUDA_CPY_BLOCK_SIZE 32
+#define CUDA_CPY_BLOCK_SIZE 64
 
-void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1);
+void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1, bool disable_indirection = false);
 
 void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1);
+
+void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_dest_ptrs, const int host_dest_ptrs_size, cudaStream_t stream);
diff --git a/ggml/src/ggml-cuda/dequantize.cuh b/ggml/src/ggml-cuda/dequantize.cuh
index bd3c2d9db..bd76dc2f8 100644
--- a/ggml/src/ggml-cuda/dequantize.cuh
+++ b/ggml/src/ggml-cuda/dequantize.cuh
@@ -86,6 +86,24 @@ static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const in
 #endif // GGML_CUDA_F16
 }
 
+static __device__ __forceinline__ void dequantize_q6_0(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
+    const block_q6_0 * x = (const block_q6_0 *) vx;
+
+    const dfloat d = x[ib].d;
+
+    const uint8_t h = x[ib].qh[iqs%8] >> 2*(iqs/8);
+    v.x = ((x[ib].qs[iqs] & 0xf) | ((h & 0x3) 
<< 4)); + v.y = ((x[ib].qs[iqs] >> 4) | ((h & 0xc) << 2)); + +#ifdef GGML_CUDA_F16 + v = __hsub2(v, {32.0f, 32.0f}); + v = __hmul2(v, {d, d}); +#else + v.x = (v.x - 32.0f) * d; + v.y = (v.y - 32.0f) * d; +#endif // GGML_CUDA_F16 +} + static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){ const block_q8_0 * x = (const block_q8_0 *) vx; diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index 5e8ee0f63..8e7fadba2 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -22,6 +22,7 @@ typedef void (* fattn_kernel_t)( const char * __restrict__ K, const char * __restrict__ V, const char * __restrict__ mask, + const char * __restrict__ sinks, float * __restrict__ dst, float2 * __restrict__ dst_meta, const float scale, @@ -747,6 +748,7 @@ void launch_fattn( const ggml_tensor * V = dst->src[2]; const ggml_tensor * mask = dst->src[3]; + const ggml_tensor * sinks = dst->src[4]; ggml_tensor * KQV = dst; @@ -837,6 +839,7 @@ void launch_fattn( K_data, V_data, mask ? ((const char *) mask->data) : nullptr, + sinks ? ((const char *) sinks->data) : nullptr, (parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, scale, max_bias, m0, m1, softcap, n_head_log2, Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], @@ -1008,7 +1011,8 @@ void launch_fattn_mma( const ggml_tensor * K = dst->src[1]; const ggml_tensor * V = dst->src[2]; - const ggml_tensor * mask = dst->src[3]; + const ggml_tensor * mask = dst->src[3]; + const ggml_tensor * sinks = dst->src[4]; ggml_tensor * KQV = dst; @@ -1162,6 +1166,7 @@ void launch_fattn_mma( K_data, V_data, mask ? ((const char *) mask->data) : nullptr, + sinks ? ((const char *)sinks->data) : nullptr, !stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr, scale, max_bias, m0, m1, n_head_log2, logit_softcap, Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index af38071d2..14444832c 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -425,6 +425,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const half2 * const __restrict__ K_h2, const half2 * const __restrict__ V_h2, const half2 * const __restrict__ mask_h2, + const float * const __restrict__ sinks_f, float2 * const __restrict__ dstk, float2 * const __restrict__ dstk_fixup, const float scale, @@ -584,6 +585,52 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( } } + // If attention sinks are used, potentially re-scale if KQ_max is small. + // Also add the sink as a value to KQ_rowsum, this is done after synchonization of KQ_rowsum + // so it's being done unconditionally for every thread. + if (!is_fixup && (np == 1 || threadIdx.y % np == 0) && sinks_f) { + float KQ_max_scale[cols_per_thread]; +#pragma unroll + for (int col = 0; col < cols_per_thread; ++col) { + static_assert(ntiles == 1 || ntiles == 2, "ntiles > 2 not implemented"); + const int jc = ntiles == 1 ? 
2*tile_C_VKQ::get_j(col/2) + col % 2 : tile_C_VKQ_16::get_i(col); + const float sink = sinks_f[jc % ncols2]; + + const float KQ_max_new = fmaxf(KQ_max[col], sink); + const float KQ_max_diff = KQ_max[col] - KQ_max_new; + KQ_max_scale[col] = expf(KQ_max_diff); + KQ_max[col] = KQ_max_new; + + *((uint32_t *) &KQ_max_scale[col]) *= KQ_max_diff >= SOFTMAX_FTZ_THRESHOLD; + + const float KQ_max_add = expf(sink - KQ_max_new); + KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_max_add; + } + + if (ntiles == 1) { + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]); +#pragma unroll + for (int i = 0; i < D/tile_C_VKQ::I; ++i) { +#pragma unroll + for (int l = 0; l < tile_C_VKQ::ne; ++l) { + VKQ_C[i].x[l] *= KQ_max_scale_h2; + } + } + } else { +#pragma unroll + for (int col = 0; col < cols_per_thread; ++col) { + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]); +#pragma unroll + for (int i = 0; i < D/tile_C_VKQ_16::J; ++i) { +#pragma unroll + for (int l0 = 0; l0 < tile_C_VKQ_16::ne; l0 += 2) { + VKQ_C_16[i*ntiles/2 + col/2].x[l0 + col % 2] *= KQ_max_scale_h2; + } + } + } + } + } + // Write VKQ accumulators to shared memory in column-major format. // It's faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM. // Also for np > 1 the combination is done via these values in shared memory. @@ -823,6 +870,7 @@ static __global__ void flash_attn_mma_ext_f16( const char * __restrict__ K, const char * __restrict__ V, const char * __restrict__ mask, + const char * __restrict__ sinks, float * __restrict__ dst, float2 * __restrict__ dst_meta, const float scale, @@ -896,6 +944,7 @@ static __global__ void flash_attn_mma_ext_f16( const half2 * V_h2 = (const half2 *) (V + nb12*(channel*ncols2 / gqa_ratio)); // K and V have same shape const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr; float2 * dstk = ((float2 *) dst) + channel*(ncols2 * D/2); + const float * sinks_f = sinks ? (const float *) sinks + channel * ncols2 : nullptr; const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f; @@ -906,12 +955,12 @@ static __global__ void flash_attn_mma_ext_f16( if (kb0_start == 0) { constexpr bool needs_fixup = false; // CUDA block is working on an entire tile. flash_attn_ext_f16_process_tile - (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap, + (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); } else { constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile. flash_attn_ext_f16_process_tile - (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap, + (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); } @@ -934,6 +983,7 @@ static __global__ void flash_attn_mma_ext_f16( const half2 * V_h2 = (const half2 *) (V + nb12*(channel*ncols2 / gqa_ratio)); // K and V have same shape const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr; float2 * dstk = ((float2 *) dst) + channel*(ncols2 * D/2); + const float * sinks_f = sinks ? (const float *) sinks + channel*ncols2 : nullptr; const float slope = ncols2 == 1 ? 
get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f; @@ -943,10 +993,10 @@ static __global__ void flash_attn_mma_ext_f16( constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks. constexpr bool needs_fixup = false; flash_attn_ext_f16_process_tile - (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap, + (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); #else - GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); + GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks); GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1); GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00); diff --git a/ggml/src/ggml-cuda/fattn-new-mma.cu b/ggml/src/ggml-cuda/fattn-new-mma.cu index 6178b3e5a..c37a618ff 100644 --- a/ggml/src/ggml-cuda/fattn-new-mma.cu +++ b/ggml/src/ggml-cuda/fattn-new-mma.cu @@ -43,37 +43,37 @@ struct fattn_mma_f16_config; // Perhaps the 256 head size needs a closer look // to see if this implementation is better. // -//template <> -//struct fattn_mma_f16_config< 64, 64> { -// static constexpr int nbatch_fa = 64; -// static constexpr int nwarps_max = 4; -// static constexpr bool Q_in_reg = true; -// static constexpr int nstages_target = 2; -// -// static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { -// return 32; -// } -// -// static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) { -// return 32; -// } -// -// static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) { -// return 32; -// } -// -// static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) { -// return 32; -// } -// -// static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) { -// return 32; -// } -// -// static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) { -// return 32; -// } -//}; +template <> +struct fattn_mma_f16_config< 64, 64> { + static constexpr int nbatch_fa = 64; + static constexpr int nwarps_max = 4; + static constexpr bool Q_in_reg = true; + static constexpr int nstages_target = 2; + + static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { + return 32; + } + + static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) { + return 32; + } + + static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) { + return 32; + } + + static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) { + return 32; + } + + static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) { + return 32; + } + + static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) { + return 32; + } +}; // //template <> //struct fattn_mma_f16_config< 80, 80> { @@ -493,7 +493,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( (V_h2 + k_VKQ_0*stride_V, tile_V, nbatch_V2, stride_V); } else { constexpr bool use_cp_async = nstages == 1; - if constexpr (ncols2 > 1 || mask_h2) { + if (ncols2 > 1 || mask_h2) { flash_attn_ext_f16_load_mask(mask_h2 + k_VKQ_0/2, tile_mask, stride_mask); } } @@ -576,7 +576,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( float KQ_rowsum_add[cols_per_thread] = {0.0f}; if constexpr (ntiles == 1) { - if constexpr (ncols2 > 1 || mask_h2) { + if (ncols2 > 1 || mask_h2) { #pragma unroll for 
(int i00 = 0; i00 < c::nbatch_fa; i00 += np*tile_C_KQ::I) { const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ::I; @@ -818,6 +818,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const half2 * const __restrict__ K_h2, const half2 * const __restrict__ V_h2, const half2 * const __restrict__ mask_h2, + const float * const __restrict__ sinks_f, float2 * const __restrict__ dstk, float2 * const __restrict__ dstk_fixup, const float scale, @@ -975,6 +976,52 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( __syncthreads(); } + // If attention sinks are used, potentially re-scale if KQ_max is small. + // Also add the sink as a value to KQ_rowsum; this is done after synchronization of KQ_rowsum, + // so it is done unconditionally for every thread. + if (!is_fixup && (np == 1 || threadIdx.y % np == 0) && sinks_f) { + float KQ_max_scale[cols_per_thread]; +#pragma unroll + for (int col = 0; col < cols_per_thread; ++col) { + static_assert(ntiles == 1 || ntiles == 2, "ntiles > 2 not implemented"); + const int jc = ntiles == 1 ? 2*tile_C_VKQ::get_j(col/2) + col % 2 : tile_C_VKQ_16::get_i(col); + const float sink = sinks_f[jc % ncols2]; + + const float KQ_max_new = fmaxf(KQ_max[col], sink); + const float KQ_max_diff = KQ_max[col] - KQ_max_new; + KQ_max_scale[col] = expf(KQ_max_diff); + KQ_max[col] = KQ_max_new; + + *((uint32_t *) &KQ_max_scale[col]) *= KQ_max_diff >= SOFTMAX_FTZ_THRESHOLD; + + const float KQ_max_add = expf(sink - KQ_max_new); + KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_max_add; + } + + if (ntiles == 1) { + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]); +#pragma unroll + for (int i = 0; i < DV/tile_C_VKQ::I; ++i) { +#pragma unroll + for (int l = 0; l < tile_C_VKQ::ne; ++l) { + VKQ_C[i].x[l] *= KQ_max_scale_h2; + } + } + } else { +#pragma unroll + for (int col = 0; col < cols_per_thread; ++col) { + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]); +#pragma unroll + for (int i = 0; i < DV/tile_C_VKQ_16::J; ++i) { +#pragma unroll + for (int l0 = 0; l0 < tile_C_VKQ_16::ne; l0 += 2) { + VKQ_C_16[i*ntiles/2 + col/2].x[l0 + col % 2] *= KQ_max_scale_h2; + } + } + } + } + } + // Finally, sum up partial KQ rowsums. // The partial sums are spread across 8/4 threads each, does not need full reduce. { @@ -1222,7 +1269,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( } } #else - GGML_UNUSED(Q_f2); GGML_UNUSED(K_h2); GGML_UNUSED(V_h2); + GGML_UNUSED(Q_f2); GGML_UNUSED(K_h2); GGML_UNUSED(V_h2); GGML_UNUSED(sinks_f); GGML_UNUSED(mask_h2); GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup); GGML_UNUSED(scale); GGML_UNUSED(slope); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_Q1); @@ -1239,6 +1286,7 @@ static __global__ void flash_attn_ext_f16( const char * __restrict__ K, const char * __restrict__ V, const char * __restrict__ mask, + const char * __restrict__ sinks, float * __restrict__ dst, float2 * __restrict__ dst_meta, const float scale, @@ -1323,6 +1371,7 @@ static __global__ void flash_attn_ext_f16( const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio)); const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr; float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2); + const float * sinks_f = sinks ? (const float *) sinks + channel*ncols2 : nullptr; const half2 * V_h2 = mla ?
K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio)); @@ -1335,12 +1384,12 @@ static __global__ void flash_attn_ext_f16( if (kb0_start == 0) { constexpr bool needs_fixup = false; // CUDA block is working on an entire tile. flash_attn_ext_f16_process_tile - (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap, + (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); } else { constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile. flash_attn_ext_f16_process_tile - (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap, + (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); } @@ -1362,6 +1411,7 @@ static __global__ void flash_attn_ext_f16( const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio)); const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr; float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2); + const float * sinks_f = sinks ? (const float *) sinks + channel*ncols2 : nullptr; const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio)); @@ -1373,7 +1423,7 @@ static __global__ void flash_attn_ext_f16( constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks. constexpr bool needs_fixup = false; flash_attn_ext_f16_process_tile - (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap, + (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); #else GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); @@ -1535,7 +1585,8 @@ static void launch_fattn_new_mma( const ggml_tensor * K = dst->src[1]; const ggml_tensor * V = dst->src[2]; - const ggml_tensor * mask = dst->src[3]; + const ggml_tensor * mask = dst->src[3]; + const ggml_tensor * sinks = dst->src[4]; ggml_tensor * KQV = dst; @@ -1709,6 +1760,7 @@ static void launch_fattn_new_mma( K_data, V_data, mask ? ((const char *) mask->data) : nullptr, + sinks ? ((const char *)sinks->data) : nullptr, !stream_k && parallel_blocks > 1 ? 
dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr, scale, max_bias, m0, m1, logit_softcap, n_head_log2, Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], @@ -1853,6 +1905,11 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_con GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); const int gqa_ratio = Q->ne[2] / K->ne[2]; + if (use_gqa_opt && gqa_ratio % 16 == 0) { + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst); + return; + } + if (use_gqa_opt && gqa_ratio % 8 == 0) { ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst); return; @@ -1878,8 +1935,6 @@ void ggml_cuda_flash_attn_ext_mma_new(ggml_backend_cuda_context & ctx, ggml_tens const ggml_tensor * V = dst->src[2]; const ggml_tensor * mask = dst->src[3]; - GGML_ASSERT(Q->ne[0] == 576 && K->ne[0] == 576 && V->ne[0] == 512); - float max_bias = 0.0f; memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float)); @@ -1888,6 +1943,12 @@ void ggml_cuda_flash_attn_ext_mma_new(ggml_backend_cuda_context & ctx, ggml_tens GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); const int gqa_ratio = Q->ne[2] / K->ne[2]; + + if (K->ne[0] == 64 && V->ne[0] == 64) { + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<64, 64>(ctx, dst); + return; + } + GGML_ASSERT(Q->ne[0] == 576 && K->ne[0] == 576 && V->ne[0] == 512); GGML_ASSERT(gqa_ratio % 16 == 0); ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst); diff --git a/ggml/src/ggml-cuda/fattn-tile-f16.cu b/ggml/src/ggml-cuda/fattn-tile-f16.cu index 420f0bb08..c79b4821d 100644 --- a/ggml/src/ggml-cuda/fattn-tile-f16.cu +++ b/ggml/src/ggml-cuda/fattn-tile-f16.cu @@ -20,6 +20,7 @@ static __global__ void flash_attn_tile_ext_f16( const char * __restrict__ K, const char * __restrict__ V, const char * __restrict__ mask, + const char * __restrict__ sinks, float * __restrict__ dst, float2 * __restrict__ dst_meta, const float scale, diff --git a/ggml/src/ggml-cuda/fattn-tile-f32.cu b/ggml/src/ggml-cuda/fattn-tile-f32.cu index f525f1bbf..3d1926ce4 100644 --- a/ggml/src/ggml-cuda/fattn-tile-f32.cu +++ b/ggml/src/ggml-cuda/fattn-tile-f32.cu @@ -20,6 +20,7 @@ static __global__ void flash_attn_tile_ext_f32( const char * __restrict__ K, const char * __restrict__ V, const char * __restrict__ mask, + const char * __restrict__ sinks, float * __restrict__ dst, float2 * __restrict__ dst_meta, const float scale, diff --git a/ggml/src/ggml-cuda/fattn-vec-f16.cuh b/ggml/src/ggml-cuda/fattn-vec-f16.cuh index 2cf4f4ef9..95dd0e963 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f16.cuh @@ -17,6 +17,7 @@ static __global__ void flash_attn_vec_ext_f16( const char * __restrict__ K, const char * __restrict__ V, const char * __restrict__ mask, + const char * __restrict__ sinks, float * __restrict__ dst, float2 * __restrict__ dst_meta, const float scale, @@ -71,6 +72,7 @@ static __global__ void flash_attn_vec_ext_f16( V += nb22*(blockIdx.y / gqa_ratio); const half * maskh = (const half *) mask + ne11*ic0; + const float * sinksf = (const float *) (sinks); const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1); const half slopeh = __float2half(slopef); @@ -270,6 +272,39 @@ static __global__ void flash_attn_vec_ext_f16( __syncthreads(); } + if (sinksf) { + const half sink = __float2half(sinksf[blockIdx.y]); + +#pragma unroll + for (int j = 0; j < ncols; ++j) { + if (threadIdx.x == 0) { + kqmax_shared[j][threadIdx.y] = fmaxf(kqmax[j], sink); + } + } + + __syncthreads(); + +#pragma unroll + for (int j = 0; j < ncols; ++j) { + half kqmax_new_j = 
kqmax_shared[j][threadIdx.y]; + kqmax_new_j = warp_reduce_max(kqmax_new_j); + + const half KQ_max_scale = hexp(kqmax[j] - kqmax_new_j); + kqmax[j] = kqmax_new_j; + + const half val = hexp(sink - kqmax[j]); + kqsum[j] = kqsum[j]*KQ_max_scale; + + if (tid == 0) { + kqsum[j] += val; + } + + VKQ[j] *= __half2half2(KQ_max_scale); + } + + __syncthreads(); + } + #pragma unroll for (int j = 0; j < ncols; ++j) { kqsum[j] = warp_reduce_sum(kqsum[j]); diff --git a/ggml/src/ggml-cuda/fattn-vec-f32.cuh b/ggml/src/ggml-cuda/fattn-vec-f32.cuh index c91cef3dc..a97b3737d 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f32.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f32.cuh @@ -17,6 +17,7 @@ static __global__ void flash_attn_vec_ext_f32( const char * __restrict__ K, const char * __restrict__ V, const char * __restrict__ mask, + const char * __restrict__ sinks, float * __restrict__ dst, float2 * __restrict__ dst_meta, const float scale, @@ -69,6 +70,7 @@ static __global__ void flash_attn_vec_ext_f32( K += nb12*(blockIdx.y / gqa_ratio); V += nb22*(blockIdx.y / gqa_ratio); // K and V have same shape const half * maskh = (const half *) mask + ne11*ic0; + const float * sinksf = (const float *) (sinks); const float slope = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1); @@ -254,6 +256,39 @@ static __global__ void flash_attn_vec_ext_f32( __syncthreads(); } + if (sinksf) { + const float sink = sinksf[blockIdx.y]; + +#pragma unroll + for (int j = 0; j < ncols; ++j) { + if (threadIdx.x == 0) { + kqmax_shared[j][threadIdx.y] = fmaxf(kqmax[j], sink); + } + } + + __syncthreads(); + +#pragma unroll + for (int j = 0; j < ncols; ++j) { + float kqmax_new_j = kqmax_shared[j][threadIdx.y]; + kqmax_new_j = warp_reduce_max(kqmax_new_j); + + const float KQ_max_scale = expf(kqmax[j] - kqmax_new_j); + kqmax[j] = kqmax_new_j; + + const float val = expf(sink - kqmax[j]); + kqsum[j] = kqsum[j]*KQ_max_scale; + + if (tid == 0) { + kqsum[j] += val; + } + + VKQ[j] *= KQ_max_scale; + } + + __syncthreads(); + } + #pragma unroll for (int j = 0; j < ncols; ++j) { kqsum[j] = warp_reduce_sum(kqsum[j]);
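Both vec kernels above apply the same sink correction: the sink logit may raise the running softmax maximum, so the already-accumulated sum and output must be rescaled, and the sink then contributes probability mass (but no value row) to the normalizer. A minimal scalar sketch of that update, outside the patch (fold_in_sink and the scalar V are illustrative names, not part of the code):

#include <math.h>

// Fold one attention-sink logit into a finished streaming-softmax state.
//   M: running maximum of the logits processed so far
//   S: running sum of exp(logit - M)
//   V: accumulated exp-weighted value (scalar here; a vector in the kernels)
// The sink only adds to the normalizer S; it carries no value row.
static void fold_in_sink(float sink, float * M, float * S, float * V) {
    const float M_new = fmaxf(*M, sink);
    const float scale = expf(*M - M_new); // 1.0f unless the sink raises the max
    *S  = *S * scale + expf(sink - M_new);
    *V *= scale;
    *M  = M_new;
}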
diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cuh b/ggml/src/ggml-cuda/fattn-wmma-f16.cuh index d39c6a6e5..e1e2ec6ed 100644 --- a/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cuh @@ -5,6 +5,8 @@ // SPDX-License-Identifier: MIT // +// TODO: attention sinks !!! + #include "common.cuh" #include "fattn-common.cuh" @@ -22,6 +24,7 @@ static __global__ void flash_attn_ext_f16( const char * __restrict__ K, const char * __restrict__ V, const char * __restrict__ mask, + const char * __restrict__ sinks, float * __restrict__ dst, float2 * __restrict__ dst_meta, const float scale, @@ -93,6 +96,7 @@ static __global__ void flash_attn_ext_f16( const half * V_h = (const half *) (V + nb22*(blockIdx.y / gqa_ratio)); // K and V have same shape const half * maskh = (const half *) mask + (nb31/sizeof(half))* ic0; const half2 * mask2 = (const half2 *) mask + (nb31/sizeof(half))*(ic0/2); + const float * sinks_f = sinks ? (const float *)sinks + blockIdx.y : nullptr; const int stride_Q = nb01 / sizeof(float); const int stride_K = nb11 / sizeof(half); diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 725b443dd..ffcaf219b 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -539,7 +539,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst return; } - // As mentioned above, the new new MMA is slower than then the new MMA. + // As mentioned above, the new-new MMA is slower than the new MMA. ggml_cuda_flash_attn_ext_mma_f16(ctx, dst); //ggml_cuda_flash_attn_ext_mma_new(ctx, dst); } diff --git a/ggml/src/ggml-cuda/graph.cuh b/ggml/src/ggml-cuda/graph.cuh new file mode 100644 index 000000000..ed032aa5e --- /dev/null +++ b/ggml/src/ggml-cuda/graph.cuh @@ -0,0 +1,41 @@ +#pragma once + +struct ggml_graph_node_properties { + void * node_address; + ggml_op node_op; + int64_t ne[GGML_MAX_DIMS]; + size_t nb[GGML_MAX_DIMS]; + void * src_address[GGML_MAX_SRC]; + int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)]; +}; + +struct ggml_cuda_graph { +#ifdef USE_CUDA_GRAPH + ~ggml_cuda_graph() { + if (instance != nullptr) { + CUDA_CHECK(cudaGraphExecDestroy(instance)); + } + if (graph != nullptr) { + CUDA_CHECK(cudaGraphDestroy(graph)); + } + } + cudaGraph_t graph = nullptr; + cudaGraphExec_t instance = nullptr; + size_t num_nodes = 0; + std::vector nodes; + std::vector params; + bool disable_due_to_gpu_arch = false; + bool disable_due_to_too_many_updates = false; + bool disable_due_to_failed_graph_capture = false; + int number_consecutive_updates = 0; + std::vector ggml_graph_properties; + bool use_cpy_indirection = false; + std::vector cpy_dest_ptrs; + char ** dest_ptrs_d; + int dest_ptrs_size = 0; + // Index to allow each cpy kernel to be aware of its position within the graph + // relative to other cpy nodes. + int graph_cpynode_index = -1; +#endif +}; +
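The softmax.cu changes that follow replace the old kernel (kept as soft_max_f32_nosinks) with a sink-aware version: the sink logit is folded into both the row maximum and the normalizer, but no output element is written for it. The intended row semantics, as a plain-C reference that omits scaling, masking and ALiBi (softmax_row_with_sink is an illustrative name, not part of the patch):

#include <math.h>

// One softmax row with a sink: the sink competes for probability mass but
// produces no output element, so the row sums to less than 1.
static void softmax_row_with_sink(const float * x, float * dst, int n, float sink) {
    float max_val = sink; // the kernel below seeds max_val with sinks[i02]
    for (int i = 0; i < n; ++i) {
        max_val = fmaxf(max_val, x[i]);
    }
    float sum = expf(sink - max_val); // the sink's share of the normalizer
    for (int i = 0; i < n; ++i) {
        dst[i] = expf(x[i] - max_val);
        sum   += dst[i];
    }
    const float inv_sum = 1.0f / sum;
    for (int i = 0; i < n; ++i) {
        dst[i] *= inv_sum;
    }
}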
diff --git a/ggml/src/ggml-cuda/softmax.cu b/ggml/src/ggml-cuda/softmax.cu index c006301fa..4c9aa7edf 100644 --- a/ggml/src/ggml-cuda/softmax.cu +++ b/ggml/src/ggml-cuda/softmax.cu @@ -19,7 +19,7 @@ __device__ float __forceinline__ t2f32(half val) { } template -static __global__ void soft_max_f32(const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2, float cap_params0, float cap_params1, bool do_softcap) { +static __global__ void soft_max_f32_nosinks(const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2, float cap_params0, float cap_params1, bool do_softcap) { const int ncols = ncols_template == 0 ? ncols_par : ncols_template; const int tid = threadIdx.x; @@ -124,7 +124,7 @@ static __global__ void soft_max_f32(const float * x, const T * mask, float * dst } template -static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, float cap_params0, float cap_params1, bool do_softcap, cudaStream_t stream) { +static void soft_max_f32_cuda_nosinks(const float * x, const T * mask, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, float cap_params0, float cap_params1, bool do_softcap, cudaStream_t stream) { int nth = WARP_SIZE; while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2; const dim3 block_dims(nth, 1, 1); @@ -142,39 +142,40 @@ static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, cons if (shmem < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) { switch (ncols_x) { case 32: - soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap); + soft_max_f32_nosinks<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap); break; case 64: - soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap); + soft_max_f32_nosinks<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap); break; case 128: - soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap); + soft_max_f32_nosinks<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap); break; case 256: - soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap); + soft_max_f32_nosinks<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap); break; case 512: - soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap); + soft_max_f32_nosinks<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap); break; case 1024: - soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap); + soft_max_f32_nosinks<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap); break; case 2048: - soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap); + soft_max_f32_nosinks<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap); break; case 4096: - soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap); + soft_max_f32_nosinks<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap); break; default: - soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap); + soft_max_f32_nosinks<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap); break; } } else { const size_t
shmem_low = WARP_SIZE*sizeof(float); - soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap); + soft_max_f32_nosinks<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap); } } +#if 0 void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; @@ -205,13 +206,14 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { if (use_f16) { const half * src1_dd = (const half *)src1_d; - soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, 0, 0, false, stream); + soft_max_f32_cuda_nosinks(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, 0, 0, false, stream); } else { const float * src1_dd = (const float *)src1_d; - soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, 0, 0, false, stream); + soft_max_f32_cuda_nosinks(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, 0, 0, false, stream); } } +#endif void ggml_cuda_op_soft_cap_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; @@ -241,10 +243,283 @@ void ggml_cuda_op_soft_cap_max(ggml_backend_cuda_context & ctx, ggml_tensor * ds if (use_f16) { const half * src1_dd = (const half *)src1_d; - soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, params[0], params[1], params[2], params[3], true, stream); + soft_max_f32_cuda_nosinks(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, params[0], params[1], params[2], params[3], true, stream); } else { const float * src1_dd = (const float *)src1_d; - soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, params[0], params[1], params[2], params[3], true, stream); + soft_max_f32_cuda_nosinks(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, params[0], params[1], params[2], params[3], true, stream); } } + +struct soft_max_params { + + int64_t nheads; + uint32_t n_head_log2; + int64_t ncols; + int64_t nrows_x; + int64_t nrows_y; + int64_t ne00; + int64_t ne01; + int64_t ne02; + int64_t ne03; + int64_t nb11; + int64_t nb12; + int64_t nb13; + + int64_t ne12; + int64_t ne13; + float scale; + float max_bias; + float m0; + float m1; +}; + +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wpass-failed" +#endif // __clang__ +template +static __global__ void soft_max_f32( + const float * x, const T * mask, const float * sinks, float * dst, const soft_max_params p) { + const int ncols = ncols_template == 0 ? p.ncols : ncols_template; + + const int tid = threadIdx.x; + + const int64_t i03 = blockIdx.z; + const int64_t i02 = blockIdx.y; + const int64_t i01 = blockIdx.x; + + // TODO: non-contiguous inputs/outputs + const int rowx = blockIdx.x + blockIdx.y * gridDim.x + blockIdx.z * gridDim.x * gridDim.y; + + const int64_t i11 = i01; + const int64_t i12 = i02 % p.ne12; + const int64_t i13 = i03 % p.ne13; + + x += int64_t(rowx)*ncols; + mask += (i11*p.nb11 + i12*p.nb12 + i13*p.nb13) / sizeof(T) * (mask != nullptr); + dst += int64_t(rowx)*ncols; + + const int block_size = block_size_template == 0 ?
blockDim.x : block_size_template; + + const int warp_id = threadIdx.x / WARP_SIZE; + const int lane_id = threadIdx.x % WARP_SIZE; + + const float slope = get_alibi_slope(p.max_bias, i02, p.n_head_log2, p.m0, p.m1); + + extern __shared__ float data_soft_max_f32[]; + float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication + // shared memory buffer to cache values between iterations: + float * vals = use_shared ? buf_iw + WARP_SIZE : dst; + + float max_val = sinks ? sinks[i02] : -INFINITY; + +#pragma unroll + for (int col0 = 0; col0 < ncols; col0 += block_size) { + const int col = col0 + tid; + + if (ncols_template == 0 && col >= ncols) { + break; + } + + const float val = x[col]*p.scale + (mask ? slope*t2f32(mask[col]) : 0.0f); + + vals[col] = val; + max_val = max(max_val, val); + } + + // find the max value in the block + max_val = warp_reduce_max(max_val); + if (block_size > WARP_SIZE) { + if (warp_id == 0) { + buf_iw[lane_id] = -INFINITY; + } + __syncthreads(); + + if (lane_id == 0) { + buf_iw[warp_id] = max_val; + } + __syncthreads(); + + max_val = buf_iw[lane_id]; + max_val = warp_reduce_max(max_val); + } + + float tmp = 0.0f; // partial sum +#pragma unroll + for (int col0 = 0; col0 < ncols; col0 += block_size) { + const int col = col0 + tid; + + if (ncols_template == 0 && col >= ncols) { + break; + } + + const float val = expf(vals[col] - max_val); + tmp += val; + vals[col] = val; + } + + // find the sum of exps in the block + tmp = warp_reduce_sum(tmp); + if (block_size > WARP_SIZE) { + __syncthreads(); + if (warp_id == 0) { + buf_iw[lane_id] = 0.0f; + } + __syncthreads(); + + if (lane_id == 0) { + buf_iw[warp_id] = tmp; + } + __syncthreads(); + + tmp = buf_iw[lane_id]; + tmp = warp_reduce_sum(tmp); + } + + if (sinks) { + tmp += expf(sinks[i02] - max_val); + } + + const float inv_sum = 1.0f / tmp; + +#pragma unroll + for (int col0 = 0; col0 < ncols; col0 += block_size) { + const int col = col0 + tid; + + if (ncols_template == 0 && col >= ncols) { + return; + } + + dst[col] = vals[col] * inv_sum; + } +} +#ifdef __clang__ +#pragma clang diagnostic pop +#endif // __clang__ + +template +static void launch_soft_max_kernels(const float * x, const T * mask, const float * sinks, float * dst, + const soft_max_params & p, cudaStream_t stream, dim3 block_dims, dim3 block_nums, size_t nbytes_shared) +{ + const int id = ggml_cuda_get_device(); + const size_t smpbo = ggml_cuda_info().devices[id].smpbo; + + auto launch_kernel = [=](auto I) -> bool { + constexpr int ncols = decltype(I)::value; + constexpr int block = (ncols > 1024 ? 
1024 : ncols); + + if (p.ncols == ncols) { + CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32), smpbo); + soft_max_f32<<>> + (x, mask, sinks, dst, p); + return true; + } + return false; + }; + + // unary fold over launch_kernel + if ((launch_kernel(std::integral_constant{}) || ...)) { + return; + } + + //default case + CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32), smpbo); + soft_max_f32<<>>(x, mask, sinks, dst, p); +} + +template +static void soft_max_f32_cuda(const float * x, const T * mask, const float * sinks, float * dst, const soft_max_params & params, cudaStream_t stream) { + int nth = WARP_SIZE; + const int64_t ncols_x = params.ncols; + + while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2; + const dim3 block_dims(nth, 1, 1); + const dim3 block_nums(params.ne01, params.ne02, params.ne03); + const size_t nbytes_shared = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float); + static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted."); + + + const int id = ggml_cuda_get_device(); + const size_t smpbo = ggml_cuda_info().devices[id].smpbo; + + + if (nbytes_shared <= smpbo) { + launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(x, mask, sinks, dst, params, stream, block_dims, block_nums, nbytes_shared); + } else { + const size_t nbytes_shared_low = WARP_SIZE*sizeof(float); + soft_max_f32<<>>(x, mask, sinks, dst, params); + } +} + +void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + const ggml_tensor * src2 = dst->src[2]; + + const float * src0_d = (const float *) src0->data; + const void * src1_d = src1 ? (const void *) src1->data : nullptr; + const void * src2_d = src2 ? (const void *) src2->data : nullptr; + float * dst_d = (float *) dst->data; + + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional + + const int64_t nrows_x = ggml_nrows(src0); + const int64_t nrows_y = src0->ne[1]; + + const int64_t ne00 = src0->ne[0]; + + float scale = 1.0f; + float max_bias = 0.0f; + + memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float)); + memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float)); + + const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16); + + const int64_t nb11 = src1 ? src1->nb[1] : 1; + const int64_t nb12 = src1 ? src1->nb[2] : 1; + const int64_t nb13 = src1 ? src1->nb[3] : 1; + + const int64_t ne12 = src1 ? src1->ne[2] : 1; + const int64_t ne13 = src1 ? 
src1->ne[3] : 1; + + const uint32_t n_head = src0->ne[2]; + const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); + + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + + soft_max_params params = {}; + params.nheads = src0->ne[2]; + params.n_head_log2 = n_head_log2; + params.ncols = ne00; + params.nrows_x = nrows_x; + params.nrows_y = nrows_y; + params.ne00 = src0->ne[0]; + params.ne01 = src0->ne[1]; + params.ne02 = src0->ne[2]; + params.ne03 = src0->ne[3]; + params.nb11 = nb11; + params.nb12 = nb12; + params.nb13 = nb13; + params.ne12 = ne12; + params.ne13 = ne13; + params.scale = scale; + params.max_bias = max_bias; + params.m0 = m0; + params.m1 = m1; + + if (use_f16) { + soft_max_f32_cuda(src0_d, (const half *) src1_d, (const float *) src2_d, dst_d, params, stream); + } else { + soft_max_f32_cuda(src0_d, (const float *) src1_d, (const float *) src2_d, dst_d, params, stream); + } +} + diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu index 6312f25c5..a6742b980 100644 --- a/ggml/src/ggml-cuda/unary.cu +++ b/ggml/src/ggml-cuda/unary.cu @@ -470,3 +470,83 @@ void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { sqrt_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream); } + +template +static __global__ void swiglu_oai_kernel(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1, float alpha, float limit) { + const int64_t i = int64_t(blockDim.x)*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + + // perform base op and multiply with gate (either offset in same tensor or a separate one) + const int64_t j0 = (i / n) * o0 + (i % n); + const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n); + + float xi = x[j0]; + float gi = g[j1]; + xi = fminf(xi, limit); + gi = fmaxf(fminf(gi, limit), -limit); + + float out_glu = xi / (1.0f + expf(-xi * alpha)); + out_glu = out_glu * (1.0f + gi); + + dst[i] = out_glu; +} + +template +static void swiglu_oai_cuda(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1, const float alpha, const float limit, cudaStream_t stream) { + const int64_t num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE; + swiglu_oai_kernel<<>>(x, g, dst, k, n, o0, o1, alpha, limit); +} + +void ggml_cuda_op_swiglu_oai(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + void * src0_d = src0->data; + void * src1_d = src1 ? src1->data : src0->data; + const int64_t src0_o = src0->nb[1]; + const int64_t src1_o = src1 ? src1->nb[1] : src0->nb[1]; + void * dst_d = dst->data; + const int64_t nc = src1 ? 
src0->ne[0] : src0->ne[0] / 2; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(ggml_is_contiguous_1(src0)); + GGML_ASSERT(src0->nb[0] == ggml_element_size(src0)); + GGML_ASSERT(ggml_is_contiguous(dst)); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + GGML_ASSERT(src0->type == dst->type); + GGML_ASSERT(dst->ne[0] == nc); + GGML_ASSERT(ggml_nrows(dst) == ggml_nrows(src0)); + + if (src1) { + GGML_ASSERT(ggml_is_contiguous_1(src1)); + GGML_ASSERT(src1->nb[0] == ggml_element_size(src1)); + GGML_ASSERT(src1->ne[0] == nc); + GGML_ASSERT(src0->type == src1->type); + } + + //const int32_t swapped = ((const int32_t *) dst->op_params)[1]; + const int32_t swapped = false; //ggml_get_op_params_i32(dst, 1); + const float * op_params = (const float *)dst->op_params; + const float alpha = op_params[2]; + const float limit = op_params[3]; + + float * src0_p = (float *) src0_d; + float * src1_p = (float *) src1_d; + + if (!src1) { + src0_p += swapped ? nc : 0; + src1_p += swapped ? 0 : nc; + } + + swiglu_oai_cuda(src0_p, src1_p, (float *)dst_d, ggml_nelements(dst), nc, + src0_o / sizeof(float), src1_o / sizeof(float), alpha, limit, stream); +} + +void ggml_swiglu_oai_cuda_f32(const float * x, const float * g, float * dst, const int64_t k, const int64_t n, + const int64_t o0, const int64_t o1, const float alpha, const float limit, cudaStream_t stream) { + swiglu_oai_cuda(x, g, dst, k, n, o0, o1, alpha, limit, stream); +} diff --git a/ggml/src/ggml-cuda/unary.cuh b/ggml/src/ggml-cuda/unary.cuh index 9bcd30a84..9da1f8ca2 100644 --- a/ggml/src/ggml-cuda/unary.cuh +++ b/ggml/src/ggml-cuda/unary.cuh @@ -47,3 +47,9 @@ void ggml_fused_mul_unary(ggml_backend_cuda_context & ctx, ggml_unary_op op, int64_t nelements, const float * x, const float * y, float * z); void ggml_cuda_op_multi_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_swiglu_oai(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_swiglu_oai_cuda_f32(const float * x, const float * g, float * dst, const int64_t k, const int64_t n, + const int64_t o0, const int64_t o1, const float alpha, const float limit, cudaStream_t stream); +
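The SWIGLU_OAI kernel above clamps the SiLU branch from above only, clamps the linear branch symmetrically, applies a temperature alpha inside the sigmoid, and shifts the gate by one. Per element, both the CUDA path and the CPU path added later in this patch compute the following (swiglu_oai_ref is an illustrative name, not part of the code):

#include <math.h>

// Reference for one element of SWIGLU_OAI: x feeds the SiLU branch,
// g the linear branch; alpha and limit come from the op parameters.
static float swiglu_oai_ref(float x, float g, float alpha, float limit) {
    x = fminf(x, limit);                  // clamp SiLU input from above only
    g = fmaxf(fminf(g, limit), -limit);   // clamp gate symmetrically
    const float out_glu = x / (1.0f + expf(-alpha * x)); // SiLU with temperature
    return out_glu * (1.0f + g);          // gate shifted by +1
}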
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index f3a23727b..695dc7227 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -2823,7 +2823,6 @@ inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const int32_t inline static void ggml_vec_set_bf16(const int n, ggml_bf16_t * x, const ggml_bf16_t v) { for (int i = 0; i < n; ++i) x[i] = v; } -inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] + y[i]; } inline static void ggml_vec_add1_f32(const int n, float * z, const float * x, const float v) { for (int i = 0; i < n; ++i) z[i] = x[i] + v; } inline static void ggml_vec_acc_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] += x[i]; } inline static void ggml_vec_acc1_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] += v; } @@ -2834,6 +2833,19 @@ inline static void ggml_vec_neg_f32 (const int n, float * y, const float * x) inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]*y[i]; } inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]/y[i]; } +inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { + int i = 0; +#if defined(__AVX2__) + for (; i + 7 < n; i += 8) { + __m256 vx = _mm256_loadu_ps(x + i); + __m256 vy = _mm256_loadu_ps(y + i); + __m256 vz = _mm256_add_ps(vx, vy); + _mm256_storeu_ps(z + i, vz); + } +#endif + for (; i < n; ++i) z[i] = x[i] + y[i]; +} + static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * restrict x, size_t bx, const float * restrict y, size_t by, int nrc) { assert(nrc == 1); UNUSED(nrc); @@ -4004,6 +4016,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "DUP", "ADD", + "ADD_ID", "ADD1", "ACC", "SUB", @@ -4092,13 +4105,14 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "CROSS_ENTROPY_LOSS_BACK", }; -static_assert(GGML_OP_COUNT == 81, "GGML_OP_COUNT != 81"); +static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", "x", "x+y", + "x[i]+y", "x+y", "view(x,nb,offset)+=y->x", "x-y", @@ -4187,7 +4201,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "cross_entropy_loss_back(x,y)", }; -static_assert(GGML_OP_COUNT == 81, "GGML_OP_COUNT != 81"); +static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -4207,9 +4221,10 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = { "HARDSWISH", "HARDSIGMOID", "SWIGLU", + "SWIGLU_OAI", }; -static_assert(GGML_UNARY_OP_COUNT == 14, "GGML_UNARY_OP_COUNT != 14"); +static_assert(GGML_UNARY_OP_COUNT == 15, "GGML_UNARY_OP_COUNT != 15"); static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN"); @@ -5917,6 +5932,29 @@ struct ggml_tensor * ggml_add_cast( return ggml_add_cast_impl(ctx, a, b, type); } +// ggml_add_id + +struct ggml_tensor * ggml_add_id( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * ids) { + + GGML_ASSERT(a->ne[0] == b->ne[0]); + GGML_ASSERT(a->ne[1] == ids->ne[0]); + GGML_ASSERT(a->ne[2] == ids->ne[1]); + GGML_ASSERT(ids->type == GGML_TYPE_I32); + + struct ggml_tensor * result = ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_ADD_ID; + result->src[0] = a; + result->src[1] = b; + result->src[2] = ids; + + return result; +} + // ggml_add1 static struct ggml_tensor * ggml_add1_impl( @@ -6662,6 +6700,36 @@ struct ggml_tensor * ggml_swiglu( return result; } +struct ggml_tensor * ggml_swiglu_oai( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + float alpha, + float limit) { + + GGML_ASSERT(ggml_is_contiguous_1(a)); + if (b) { + GGML_ASSERT(ggml_is_contiguous_1(b)); + GGML_ASSERT(ggml_are_same_shape(a, b)); + GGML_ASSERT(a->type == b->type); + } + + int64_t ne[4] = {a->ne[0]/2, a->ne[1], a->ne[2], a->ne[3]}; + + struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, GGML_MAX_DIMS, b ?
a->ne : ne, NULL, 0); + + result->op = GGML_OP_UNARY; + result->grad = NULL; + result->src[0] = a; + result->src[1] = b; + + ggml_set_op_params_i32(result, 0, (int32_t) GGML_UNARY_OP_SWIGLU_OAI); + ggml_set_op_params_f32(result, 2, alpha); + ggml_set_op_params_f32(result, 3, limit); + + return result; +} + // ggml_silu_back struct ggml_tensor * ggml_silu_back( @@ -7017,6 +7085,66 @@ struct ggml_tensor * ggml_moe_up_gate( result->src[1] = as_gate; result->src[2] = b; result->src[3] = ids; + result->src[4] = NULL; + result->src[5] = NULL; + + ggml_set_op_params_i32(result, 0, (int32_t) op); + + return result; +} + +struct ggml_tensor * ggml_moe_up_gate_ext( + struct ggml_context * ctx, + struct ggml_tensor * as_up, + struct ggml_tensor * as_gate, + struct ggml_tensor * b, + struct ggml_tensor * ids, + struct ggml_tensor * as_up_b, + struct ggml_tensor * as_gate_b, + enum ggml_unary_op op) { + + if (!as_up_b && !as_gate_b) { + return ggml_moe_up_gate(ctx, as_up, as_gate, b, ids, op); + } + + if (as_up->type != as_gate->type || !ggml_are_same_shape(as_up, as_gate)) { + struct ggml_tensor * result_up = ggml_mul_mat_id(ctx, as_up, b, ids); + if (as_up_b) { + result_up = ggml_add_id(ctx, result_up, as_up_b, ids); + } + struct ggml_tensor * result_gate = ggml_mul_mat_id(ctx, as_gate, b, ids); + if (as_gate_b) { + result_gate = ggml_add_id(ctx, result_gate, as_gate_b, ids); + } + return ggml_fused_mul_unary(ctx, result_gate, result_up, op); + } + + GGML_ASSERT(!ggml_is_transposed(as_up)); + GGML_ASSERT(!ggml_is_transposed(as_gate)); + GGML_ASSERT(ids->type == GGML_TYPE_I32); + + GGML_ASSERT(as_up->ne[3] == 1); // as is 3d (one matrix per expert) + GGML_ASSERT(b->ne[3] == 1); // b is 3d + GGML_ASSERT(ids->ne[2] == 1 && ids->ne[3] == 1); // ids is 2d + GGML_ASSERT(ids->ne[1] == b->ne[2]); // must have an expert list per b row + GGML_ASSERT(as_up->ne[0] == b->ne[0]); // can_mul_mat + GGML_ASSERT(ids->ne[0] % b->ne[1] == 0); // can broadcast + + GGML_ASSERT(as_up->ne[1] == as_up_b->ne[0]); + GGML_ASSERT(as_gate->ne[1] == as_gate_b->ne[0]); + bool is_node = false; + + const int64_t ne[4] = { as_up->ne[1], ids->ne[0], b->ne[2], 1 }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); + + result->op = GGML_OP_MOE_FUSED_UP_GATE; + result->grad = is_node ? 
ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = as_up; + result->src[1] = as_gate; + result->src[2] = b; + result->src[3] = ids; + result->src[4] = as_up_b; + result->src[5] = as_gate_b; ggml_set_op_params_i32(result, 0, (int32_t) op); @@ -7970,6 +8098,22 @@ struct ggml_tensor * ggml_soft_max_ext( return ggml_soft_max_impl(ctx, a, mask, scale, max_bias, false); } +void ggml_soft_max_add_sinks( + struct ggml_tensor * a, + struct ggml_tensor * sinks) { + if (!sinks) { + a->src[2] = NULL; + return; + } + + GGML_ASSERT(a->op == GGML_OP_SOFT_MAX); + GGML_ASSERT(a->src[2] == NULL); + GGML_ASSERT(a->src[0]->ne[2] == sinks->ne[0]); + GGML_ASSERT(sinks->type == GGML_TYPE_F32); + + a->src[2] = sinks; +} + // ggml_soft_max_back static struct ggml_tensor * ggml_soft_max_back_impl( @@ -8833,6 +8977,22 @@ void ggml_flash_attn_ext_set_prec( ggml_set_op_params_i32(a, 3, prec_i32); // scale is on first pos, max_bias on second } +void ggml_flash_attn_ext_add_sinks( + struct ggml_tensor * a, + struct ggml_tensor * sinks) { + if (!sinks) { + a->src[4] = NULL; + return; + } + + GGML_ASSERT(a->op == GGML_OP_FLASH_ATTN_EXT); + GGML_ASSERT(a->src[4] == NULL); + GGML_ASSERT(a->src[0]->ne[2] == sinks->ne[0]); + GGML_ASSERT(sinks->type == GGML_TYPE_F32); + + a->src[4] = sinks; +} + // ggml_flash_attn_back struct ggml_tensor * ggml_flash_attn_back( @@ -11497,6 +11657,77 @@ static void ggml_compute_forward_multi_add( } } +// ggml_compute_forward_add_id + +static void ggml_compute_forward_add_id_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + const struct ggml_tensor * src2 = dst->src[2]; + + GGML_ASSERT(dst->type == GGML_TYPE_F32); + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(src2->type == GGML_TYPE_I32); + + GGML_ASSERT(src0->nb[0] == sizeof(float)); + GGML_ASSERT(src1->nb[0] == sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = ggml_nrows(src0); + + GGML_TENSOR_TERNARY_OP_LOCALS + + GGML_ASSERT( nb0 == sizeof(float)); + GGML_ASSERT(nb10 == sizeof(float)); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int ir = ir0; ir < ir1; ++ir) { + // src0 indices + const int i3 = ir/(ne2*ne1); + const int i2 = (ir - i3*ne2*ne1)/ne1; + const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + + // src1 indices + const int i11 = *(int32_t *) ((char *) src2->data + i1*nb20 + i2*nb21); + + GGML_ASSERT(i11 >= 0 && i11 < ne11); + + ggml_vec_add_f32(ne0, + (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), + (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), + (float *) ((char *) src1->data + i11*nb11)); + } +} + +static void ggml_compute_forward_add_id( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_add_id_f32(params, dst); + } break; + default: + { + GGML_ABORT("unsupported type for ggml_compute_forward_add_id: %s", ggml_type_name(src0->type)); + } + } +} + // ggml_compute_forward_add1 static void ggml_compute_forward_add1_f32(
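ggml_compute_forward_add_id above gives GGML_OP_ADD_ID its CPU semantics: every row of src0 receives one bias row of src1, chosen per row by the I32 ids tensor; ggml_moe_up_gate_ext uses this to add per-expert biases after mul_mat_id. Flattened over the trailing dimensions, the operation reduces to the following sketch (add_id_ref is an illustrative name, not part of the code):

#include <stdint.h>

// dst[r][:] = src0[r][:] + bias[ids[r]][:] for each of n_rows rows of width ne0.
static void add_id_ref(const float * src0, const float * bias, const int32_t * ids,
                       float * dst, int64_t ne0, int64_t n_rows) {
    for (int64_t r = 0; r < n_rows; ++r) {
        const int32_t b = ids[r]; // which bias row this output row selects
        for (int64_t i = 0; i < ne0; ++i) {
            dst[r*ne0 + i] = src0[r*ne0 + i] + bias[b*ne0 + i];
        }
    }
}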
@@ -13760,6 +13991,93 @@ static void ggml_compute_forward_swiglu( } } +// ggml_compute_forward_swiglu_oai + +static void ggml_compute_forward_swiglu_oai_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + char * src0_d = (char *) src0->data; + char * src1_d = (char *) (src1 ? src1->data : src0->data); + const size_t src0_o = src0->nb[1]; + const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1]; + + GGML_ASSERT(ggml_is_contiguous_1(src0)); + GGML_ASSERT(ggml_is_contiguous_1(dst)); + + if (src1) { + GGML_ASSERT(ggml_is_contiguous_1(src1)); + GGML_ASSERT(src0->type == src1->type); + } + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2; + const int nr = ggml_nrows(src0); + + GGML_ASSERT(dst->ne[0] == nc); + GGML_ASSERT(ggml_nrows(dst) == nr); + + const int32_t swapped = false; //ggml_get_op_params_i32(dst, 1); + const float alpha = ggml_get_op_params_f32(dst, 2); + const float limit = ggml_get_op_params_f32(dst, 3); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + float * src0_p = (float *) (src0_d + i1*src0_o); + float * src1_p = (float *) (src1_d + i1*src1_o); + float * dst_p = (float *) ((char *) dst->data + i1*(dst->nb[1])); + + if (!src1) { + src0_p += swapped ? nc : 0; + src1_p += swapped ? 0 : nc; + } + + for (int k = 0; k < nc; k++) { + const float x = MIN(src0_p[k], limit); + const float y = MAX(MIN(src1_p[k], limit), -limit); + const float out_glu = x / (1.f + expf(alpha * (-x))); + dst_p[k] = out_glu * (y + 1.f); + } + +#ifndef NDEBUG + for (int k = 0; k < nc; k++) { + const float x = dst_p[k]; + GGML_UNUSED(x); + assert(!isnan(x)); + assert(!isinf(x)); + } +#endif + } +} + +static void ggml_compute_forward_swiglu_oai( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_swiglu_oai_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + // ggml_compute_forward_fused_mul_unary static void ggml_compute_forward_fused_mul_unary_f32( @@ -15167,6 +15485,8 @@ static void ggml_compute_forward_mul_mat_id_up_gate( const struct ggml_tensor * src1 = dst->src[2]; const struct ggml_tensor * ids = dst->src[3]; + const struct ggml_tensor * up_b = dst->src[4]; + const struct ggml_tensor * gate_b = dst->src[5]; const struct ggml_tensor * src0_1 = dst->src[0]; const struct ggml_tensor * src0_2 = dst->src[1]; const struct ggml_tensor * src0 = src0_1; // so GGML_TENSOR_BINARY_OP_LOCALS works @@ -15191,6 +15511,9 @@ static void ggml_compute_forward_mul_mat_id_up_gate( GGML_ASSERT(nb2 <= nb3); GGML_ASSERT(ne13 == 1); + const size_t nb41 = up_b ? up_b->nb[1] : 0; + const size_t nb51 = gate_b ? gate_b->nb[1] : 0; + // row groups const int n_ids = ids->ne[0]; // n_expert_used const int n_as = ne02; // n_expert @@ -15278,6 +15601,8 @@ static void ggml_compute_forward_mul_mat_id_up_gate( const char * src0_1_cur = (const char *) src0_1->data + cur_a*nb02; const char * src0_2_cur = (const char *) src0_2->data + cur_a*nb02; + const char * up_b_cur = up_b ? (const char *)up_b->data + cur_a*nb41 : NULL; + const char * gate_b_cur = gate_b ? (const char *)gate_b->data + cur_a*nb51 : NULL; const void * wdata = (src1->type == vec_dot_type) ?
src1->data : params->wdata; const size_t row_size = ggml_row_size(vec_dot_type, ne10); @@ -15288,6 +15613,7 @@ static void ggml_compute_forward_mul_mat_id_up_gate( if (!iqk_moe_fused_up_gate(nr0, nr1, ne00, ne11, dst->op_params[0], type, src0_1_cur, src0_2_cur, nb01, vec_dot_type, (const char *)wdata, row_size, + up_b_cur, gate_b_cur, (float *)dst->data, nb1, nb2, matrix_rows + cur_a*ne12, ith, nth)) GGML_ABORT("fatal error"); @@ -16645,6 +16971,7 @@ static void ggml_compute_forward_soft_max_f32( const struct ggml_tensor * src0 = dst->src[0]; const struct ggml_tensor * src1 = dst->src[1]; + const struct ggml_tensor * src2 = dst->src[2]; assert(ggml_is_contiguous(dst)); assert(ggml_are_same_shape(src0, dst)); @@ -16662,6 +16989,13 @@ static void ggml_compute_forward_soft_max_f32( GGML_TENSOR_UNARY_OP_LOCALS + const int64_t nb11 = src1 ? src1->nb[1] : 1; + const int64_t nb12 = src1 ? src1->nb[2] : 1; + const int64_t nb13 = src1 ? src1->nb[3] : 1; + + const int64_t ne12 = src1 ? src1->ne[2] : 1; + const int64_t ne13 = src1 ? src1->ne[3] : 1; + //const int64_t ne11 = src1 ? src1->ne[1] : 1; // TODO: is this supposed to be ceil instead of floor? @@ -16673,67 +17007,80 @@ static void ggml_compute_forward_soft_max_f32( const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); const int nc = src0->ne[0]; - const int nr = ggml_nrows(src0); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith; const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16); - for (int i1 = ir0; i1 < ir1; i1++) { - // ALiBi - const uint32_t h = (i1/ne01)%ne02; // head - const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f; + // sinks + const float * sk = src2 ? (float *)((char *) src2->data) : NULL; - float * sp = (float *)((char *) src0->data + i1*src0->nb[1]); - float * dp = (float *)((char *) dst->data + i1*dst->nb[1]); - - // broadcast the mask across rows - ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL; - float * mp_f32 = src1 ? (float *)((char *) src1->data) + (i1%ne01)*ne00 : NULL; + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = ith; i01 < ne01; i01 += nth) { + const int64_t i11 = i01; + const int64_t i12 = i02%ne12; + const int64_t i13 = i03%ne13; + + // ALiBi + const uint32_t h = i02; // head + const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f; + + float * sp = (float *)((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); + float * dp = (float *)((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3); + + // broadcast the mask across rows + ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL; + float * mp_f32 = src1 ? 
(float *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL; + + ggml_vec_cpy_f32 (ne00, wp, sp); + ggml_vec_scale_f32(ne00, wp, scale); + if (mp_f32) { + if (use_f16) { + for (int i = 0; i < ne00; ++i) { + wp[i] += slope*GGML_FP16_TO_FP32(mp_f16[i]); + } + } else { + for (int i = 0; i < ne00; ++i) { + wp[i] += slope*mp_f32[i]; + } + } + } - ggml_vec_cpy_f32 (nc, wp, sp); - ggml_vec_scale_f32(nc, wp, scale); - if (mp_f32) { - if (use_f16) { - for (int i = 0; i < nc; ++i) { - wp[i] += slope*GGML_FP16_TO_FP32(mp_f16[i]); +#ifndef NDEBUG + for (int i = 0; i < ne00; ++i) { + //printf("p[%d] = %f\n", i, p[i]); + assert(!isnan(wp[i])); } - } else { - for (int i = 0; i < nc; ++i) { - wp[i] += slope*mp_f32[i]; +#endif + + float max = -INFINITY; + ggml_vec_max_f32(ne00, &max, wp); + + // if we have sinks, make a correction as if they were included in the softmax + if (sk) { + max = MAX(max, sk[i02]); } - } - } -//#ifndef NDEBUG -// for (int i = 0; i < nc; ++i) { -// //printf("p[%d] = %f\n", i, p[i]); -// assert(!isnan(wp[i])); -// } -//#endif + ggml_float sum = ggml_vec_soft_max_f32(ne00, dp, wp, max); + assert(sum > 0.0); - float max = -INFINITY; - ggml_vec_max_f32(nc, &max, wp); + if (sk) { + sum += (ggml_float) expf(sk[i02] - max); + } - ggml_float sum = ggml_vec_soft_max_f32(nc, dp, wp, max); - //assert(sum > 0.0); + sum = 1.0/sum; + ggml_vec_scale_f32(ne00, dp, sum); - sum = 1.0/sum; - ggml_vec_scale_f32(nc, dp, sum); +#ifndef NDEBUG + for (int i = 0; i < ne00; ++i) { + assert(!isnan(dp[i])); + assert(!isinf(dp[i])); + } +#endif -//#ifndef NDEBUG -// for (int i = 0; i < nc; ++i) { -// assert(!isnan(dp[i])); -// assert(!isinf(dp[i])); -// } -//#endif + } + } } } @@ -16755,7 +17102,6 @@ static void ggml_compute_forward_soft_max( } } - // ggml_compute_forward_soft_max_back static void ggml_compute_forward_soft_max_back_f32( @@ -18308,12 +18654,14 @@ static void ggml_compute_forward_argsort_thresh( static void ggml_compute_forward_flash_attn_ext_f16( const struct ggml_compute_params * params, - const struct ggml_tensor * q, - const struct ggml_tensor * k, - const struct ggml_tensor * v, - const struct ggml_tensor * mask, struct ggml_tensor * dst) { + const struct ggml_tensor * q = dst->src[0]; + const struct ggml_tensor * k = dst->src[1]; + const struct ggml_tensor * v = dst->src[2]; + const struct ggml_tensor * mask = dst->src[3]; + const struct ggml_tensor * sinks = dst->src[4]; + GGML_TENSOR_LOCALS(int64_t, neq, q, ne) GGML_TENSOR_LOCALS(size_t, nbq, q, nb) GGML_TENSOR_LOCALS(int64_t, nek, k, ne) @@ -18383,6 +18731,7 @@ static void ggml_compute_forward_flash_attn_ext_f16( } #if GGML_USE_IQK_MULMAT + // For now we do not implement sinks in the iqk FA implementation if (iqk_flash_attn_noalibi(q->type, mask->type, max_bias, q->ne[3], q->ne[2], q->nb[3], q->nb[2], k->ne[3], k->ne[2], k->nb[3], k->nb[2], @@ -18390,7 +18739,7 @@ static void ggml_compute_forward_flash_attn_ext_f16( dst->ne[2], dst->ne[1], dst->nb[1], k->type, v->type, Dk, Dv, neq1, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1], - q->data, k->data, v->data, mask->data, + q->data, k->data, v->data, mask->data, sinks ? 
sinks->data : NULL, scale, softcap, (float *)dst->data, params->wdata, (barrier_t)ggml_barrier, (void *)params->shared, ith, nth)) return; @@ -18447,6 +18796,9 @@ static void ggml_compute_forward_flash_attn_ext_f16( ggml_vec_dot_t const kq_vec_dot = type_traits[k->type].vec_dot; ggml_to_float_t const v_to_float = type_traits[v->type].to_float; + GGML_ASSERT(( q_to_vec_dot) && "fattn: unsupported K-type"); + GGML_ASSERT((v->type == GGML_TYPE_F32 || v_to_float ) && "fattn: unsupported V-type"); + const int64_t Dkv = MAX(Dk, Dv); // loop over n_batch and n_head @@ -18552,6 +18904,22 @@ static void ggml_compute_forward_flash_attn_ext_f16( } } + if (sinks) { + const float s = ((float *)((char *) sinks->data))[h]; + + float ms = 1.0f; + float vs = 1.0f; + + if (s > M) { + ms = expf(M - s); + ggml_vec_scale_f32(Dv, VKQ32, ms); + } else { + vs = expf(s - M); + } + + S = S*ms + vs; + } + // V /= S const float S_inv = 1.0f/S; ggml_vec_scale_f32(Dv, VKQ32, S_inv); @@ -18571,17 +18939,13 @@ static void ggml_compute_forward_flash_attn_ext_f16( static void ggml_compute_forward_flash_attn_ext( const struct ggml_compute_params * params, - const struct ggml_tensor * q, - const struct ggml_tensor * k, - const struct ggml_tensor * v, - const struct ggml_tensor * mask, struct ggml_tensor * dst) { switch (dst->op_params[3]) { case GGML_PREC_DEFAULT: case GGML_PREC_F32: { // uses F32 accumulators - ggml_compute_forward_flash_attn_ext_f16(params, q, k, v, mask, dst); + ggml_compute_forward_flash_attn_ext_f16(params, dst); } break; default: { @@ -19350,6 +19714,10 @@ static void ggml_compute_forward_unary( { ggml_compute_forward_swiglu(params, dst); } break; + case GGML_UNARY_OP_SWIGLU_OAI: + { + ggml_compute_forward_swiglu_oai(params, dst); + } break; case GGML_UNARY_OP_HARDSWISH: { ggml_compute_forward_hardswish(params, dst); @@ -19898,6 +20266,10 @@ static bool ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_add(params, tensor); } break; + case GGML_OP_ADD_ID: + { + ggml_compute_forward_add_id(params, tensor); + } break; case GGML_OP_ADD1: { ggml_compute_forward_add1(params, tensor); @@ -20136,7 +20508,7 @@ static bool ggml_compute_forward(struct ggml_compute_params * params, struct ggm } break; case GGML_OP_FLASH_ATTN_EXT: { - ggml_compute_forward_flash_attn_ext(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor); + ggml_compute_forward_flash_attn_ext(params, tensor); } break; case GGML_OP_FLASH_ATTN_BACK: { @@ -20486,6 +20858,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor src1->grad = ggml_add_or_set(ctx, src1->grad, tensor->grad, zero_table); } } break; + case GGML_OP_ADD_ID: + { + GGML_ABORT("fatal error"); // TODO: implement + } break; case GGML_OP_ADD1: { if (src0->grad) { @@ -21719,6 +22095,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_OP_DUP: case GGML_OP_CONT: case GGML_OP_ADD: + case GGML_OP_ADD_ID: case GGML_OP_ADD1: case GGML_OP_ACC: case GGML_OP_MULTI_ADD: @@ -21758,6 +22135,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_UNARY_OP_GELU_QUICK: case GGML_UNARY_OP_SILU: case GGML_UNARY_OP_SWIGLU: + case GGML_UNARY_OP_SWIGLU_OAI: { n_tasks = n_threads; } break; @@ -21952,6 +22330,7 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa } } break; case GGML_OP_ADD: + case GGML_OP_ADD_ID: case GGML_OP_ADD1: { if (ggml_is_quantized(node->src[0]->type)) { diff --git 
a/ggml/src/iqk/fa/iqk_fa_128_128.cpp b/ggml/src/iqk/fa/iqk_fa_128_128.cpp index 52eb289da..e5db1720f 100644 --- a/ggml/src/iqk/fa/iqk_fa_128_128.cpp +++ b/ggml/src/iqk/fa/iqk_fa_128_128.cpp @@ -19,26 +19,26 @@ IQK_FA_CASE(iqk_fa_128_128) { if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types if (nk%64 == 0) { iqk_flash_helper_T<128, 128, 64>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv, - q, ck, cv, cm, scale, softcap, qkv, M, S); + q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S); return true; } iqk_flash_helper_T<128, 128, 32>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv, - q, ck, cv, cm, scale, softcap, qkv, M, S); + q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S); return true; } #endif if (nk%128 == 0) { return iqk_flash_helper_T<128, 128, 128>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv, - q, ck, cv, cm, scale, softcap, qkv, M, S); + q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S); } if (nk%64 == 0) { return iqk_flash_helper_T<128, 128, 64>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv, - q, ck, cv, cm, scale, softcap, qkv, M, S); + q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S); } return iqk_flash_helper_T<128, 128, 32>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv, - q, ck, cv, cm, scale, softcap, qkv, M, S); + q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S); } diff --git a/ggml/src/iqk/fa/iqk_fa_192_128.cpp b/ggml/src/iqk/fa/iqk_fa_192_128.cpp index 6c4c51fb6..9cd62afd0 100644 --- a/ggml/src/iqk/fa/iqk_fa_192_128.cpp +++ b/ggml/src/iqk/fa/iqk_fa_192_128.cpp @@ -19,26 +19,26 @@ IQK_FA_CASE(iqk_fa_192_128) { if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types if (nk%64 == 0) { iqk_flash_helper_T<192, 128, 64>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv, - q, ck, cv, cm, scale, softcap, qkv, M, S); + q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S); return true; } iqk_flash_helper_T<192, 128, 32>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv, - q, ck, cv, cm, scale, softcap, qkv, M, S); + q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S); return true; } #endif if (nk%128 == 0) { return iqk_flash_helper_T<192, 128, 128>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv, - q, ck, cv, cm, scale, softcap, qkv, M, S); + q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S); } if (nk%64 == 0) { return iqk_flash_helper_T<192, 128, 64>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv, - q, ck, cv, cm, scale, softcap, qkv, M, S); + q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S); } return iqk_flash_helper_T<192, 128, 32>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv, - q, ck, cv, cm, scale, softcap, qkv, M, S); + q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S); } diff --git a/ggml/src/iqk/fa/iqk_fa_256_256.cpp b/ggml/src/iqk/fa/iqk_fa_256_256.cpp index b0bc35e34..a8565de2c 100644 --- a/ggml/src/iqk/fa/iqk_fa_256_256.cpp +++ b/ggml/src/iqk/fa/iqk_fa_256_256.cpp @@ -19,26 +19,26 @@ IQK_FA_CASE(iqk_fa_256_256) { if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types if (nk%64 == 0) { iqk_flash_helper_T<256, 256, 64>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv, - q, ck, cv, cm, scale, softcap, qkv, M, S); + q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S); return true; } iqk_flash_helper_T<256, 256, 32>(nq, nk, 
stride_q, stride_k, stride_v, stride_m, stride_qkv, - q, ck, cv, cm, scale, softcap, qkv, M, S); + q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S); return true; } #endif if (nk%128 == 0) { return iqk_flash_helper_T<256, 256, 128>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv, - q, ck, cv, cm, scale, softcap, qkv, M, S); + q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S); } if (nk%64 == 0) { return iqk_flash_helper_T<256, 256, 64>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv, - q, ck, cv, cm, scale, softcap, qkv, M, S); + q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S); } return iqk_flash_helper_T<256, 256, 32>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv, - q, ck, cv, cm, scale, softcap, qkv, M, S); + q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S); } diff --git a/ggml/src/iqk/fa/iqk_fa_576_512.cpp b/ggml/src/iqk/fa/iqk_fa_576_512.cpp index 5174be30a..9517eaa15 100644 --- a/ggml/src/iqk/fa/iqk_fa_576_512.cpp +++ b/ggml/src/iqk/fa/iqk_fa_576_512.cpp @@ -9,7 +9,8 @@ namespace { template inline void iqk_deepseek_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv, - const float * q, const char * mask, float scale, float softcap, float * qkv, float * M, float * S) { + const float * q, const char * mask, float scale, float softcap, float * qkv, + const float * sinkf, float * M, float * S) { auto update = [&nq1, &mask, &q, &qkv, &M, &S, stride_q, stride_m, stride_qkv] (int n) { nq1 -= n; if (nq1 == 0) return true; @@ -21,29 +22,29 @@ inline void iqk_deepseek_helper(KHelper& kh, VHelper& vh, }; if (nq1 >= 16) { int n_step = nq1/16; - FlashAttn<576, 512, 16, step_k> fa(scale, softcap); + FlashAttn<576, 512, 16, step_k> fa(scale, softcap, sinkf); fa.compute(kh, vh, 16*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S); if (update(16*n_step)) return; } if (nq1 >= 8) { int n_step = nq1/8; - FlashAttn<576, 512, 8, step_k> fa(scale, softcap); + FlashAttn<576, 512, 8, step_k> fa(scale, softcap, sinkf); fa.compute(kh, vh, 8*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S); if (update(8*n_step)) return; } if (nq1 >= 4) { int n_step = nq1/4; - FlashAttn<576, 512, 4, step_k> fa(scale, softcap); + FlashAttn<576, 512, 4, step_k> fa(scale, softcap, sinkf); fa.compute(kh, vh, 4*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S); if (update(4*n_step)) return; } if (nq1 >= 2) { int n_step = nq1/2; - FlashAttn<576, 512, 2, step_k> fa(scale, softcap); + FlashAttn<576, 512, 2, step_k> fa(scale, softcap, sinkf); fa.compute(kh, vh, 2*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S); if (update(2*n_step)) return; } - FlashAttn<576, 512, 1, step_k> fa(scale, softcap); + FlashAttn<576, 512, 1, step_k> fa(scale, softcap, sinkf); fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S); } @@ -51,37 +52,37 @@ template inline bool iqk_deepseek_helper(ggml_type type_k, int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv, const float * q, const char * k, const char * v, const char * mask, - float scale, float softcap, float * qkv, float * M, float * S) { + float scale, float softcap, float * qkv, const float * sinkf, float * M, float * S) { if (type_k == GGML_TYPE_Q8_0) { HelperQ80 kh((const char *)k, stride_k); HelperQ80 vh((const char *)v, stride_v); - iqk_deepseek_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); + 
iqk_deepseek_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S); return true; } if (type_k == GGML_TYPE_Q8_0_R8) { HelperQ80R8<576> kh((const char *)k, stride_k); HelperQ80 vh((const char *)v, stride_v); - iqk_deepseek_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); + iqk_deepseek_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S); return true; } if (type_k == GGML_TYPE_Q6_0) { HelperQ60 kh((const char *)k, stride_k); HelperQ60 vh((const char *)v, stride_v); - iqk_deepseek_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); + iqk_deepseek_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S); return true; } #if GGML_IQK_FA_ALL_QUANTS if (type_k == GGML_TYPE_Q8_KV) { HelperQ8KV<576> kh((const char *)k, stride_k); HelperQ8KV<512> vh((const char *)v, stride_v); - iqk_deepseek_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); + iqk_deepseek_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S); return true; } #endif if (type_k == GGML_TYPE_F16) { HelperF16 kh((const char *)k, stride_k); HelperF16 vh((const char *)v, stride_v); - iqk_deepseek_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); + iqk_deepseek_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S); return true; } #ifdef __AVX512BF16__ @@ -89,10 +90,10 @@ inline bool iqk_deepseek_helper(ggml_type type_k, HelperBF16<576, step_k> kh((const char *)k, stride_k); HelperBF16<512, step_k> vh((const char *)v, stride_v); if (nq1 % 8 == 0) { - FlashAttnBF16<576, 512, 8, step_k> fa(scale, softcap); + FlashAttnBF16<576, 512, 8, step_k> fa(scale, softcap, sinkf); fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); } else { - FlashAttnBF16<576, 512, 1, step_k> fa(scale, softcap); + FlashAttnBF16<576, 512, 1, step_k> fa(scale, softcap, sinkf); fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); } return true; @@ -113,7 +114,7 @@ IQK_FA_CASE(iqk_fa_576_512) { } stride_q /= sizeof(float); // q stride as float return iqk_deepseek_helper<32>(type_k, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv, - q, (const char *)k, (const char *)v, (const char *)mask, scale, softcap, qkv, M, S); + q, (const char *)k, (const char *)v, (const char *)mask, scale, softcap, qkv, sinkf, M, S); } diff --git a/ggml/src/iqk/fa/iqk_fa_64_64.cpp b/ggml/src/iqk/fa/iqk_fa_64_64.cpp index 652f682ba..84b9bad02 100644 --- a/ggml/src/iqk/fa/iqk_fa_64_64.cpp +++ b/ggml/src/iqk/fa/iqk_fa_64_64.cpp @@ -19,26 +19,26 @@ IQK_FA_CASE(iqk_fa_64_64) { if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types if (nk%64 == 0) { iqk_flash_helper_T<64, 64, 64>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv, - q, ck, cv, cm, scale, softcap, qkv, M, S); + q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S); return true; } iqk_flash_helper_T<64, 64, 32>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv, - q, ck, cv, cm, scale, softcap, qkv, M, S); + q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S); return true; } #endif if (nk%128 == 0) { return iqk_flash_helper_T<64, 64, 128>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, 
stride_m, stride_qkv, - q, ck, cv, cm, scale, softcap, qkv, M, S); + q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S); } if (nk%64 == 0) { return iqk_flash_helper_T<64, 64, 64>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv, - q, ck, cv, cm, scale, softcap, qkv, M, S); + q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S); } return iqk_flash_helper_T<64, 64, 32>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv, - q, ck, cv, cm, scale, softcap, qkv, M, S); + q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S); } diff --git a/ggml/src/iqk/fa/iqk_fa_96_96.cpp b/ggml/src/iqk/fa/iqk_fa_96_96.cpp index fed49cb04..44544f8ba 100644 --- a/ggml/src/iqk/fa/iqk_fa_96_96.cpp +++ b/ggml/src/iqk/fa/iqk_fa_96_96.cpp @@ -19,26 +19,26 @@ IQK_FA_CASE(iqk_fa_96_96) { if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types if (nk%64 == 0) { iqk_flash_helper_T<96, 96, 64>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv, - q, ck, cv, cm, scale, softcap, qkv, M, S); + q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S); return true; } iqk_flash_helper_T<96, 96, 32>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv, - q, ck, cv, cm, scale, softcap, qkv, M, S); + q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S); return true; } #endif if (nk%128 == 0) { return iqk_flash_helper_T<96, 96, 128>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv, - q, ck, cv, cm, scale, softcap, qkv, M, S); + q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S); } if (nk%64 == 0) { return iqk_flash_helper_T<96, 96, 64>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv, - q, ck, cv, cm, scale, softcap, qkv, M, S); + q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S); } return iqk_flash_helper_T<96, 96, 32>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv, - q, ck, cv, cm, scale, softcap, qkv, M, S); + q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S); } diff --git a/ggml/src/iqk/fa/iqk_fa_templates.h b/ggml/src/iqk/fa/iqk_fa_templates.h index 6de2acea3..1971c4729 100644 --- a/ggml/src/iqk/fa/iqk_fa_templates.h +++ b/ggml/src/iqk/fa/iqk_fa_templates.h @@ -1141,10 +1141,25 @@ struct FlashQKV { } template - inline void normalize_and_store_1row(const FMS& fms, int j, const qkv_cache_t * R, float * qkv) const { + inline void normalize_and_store_1row(const FMS& fms, int j, qkv_cache_t * R, float * qkv, const float * sinkf) const { static_assert(q_step == FMS::q_step); - GGML_ASSERT(fms.S[j] > 0); - auto norm = F16::set1(1/fms.S[j]); + float S = fms.S[j]; + if (sinkf) { + float s = *sinkf; + if (s > fms.M[j]) { + float m = expf(fms.M[j] - s); + auto vm = F16::set1(m); + for (int i = 0; i < D/F16::block_size; ++i) { + auto Ri = R + F16::block_size*i; + F16::store(Ri, F16::mul(vm, F16::load(Ri))); + } + S = S*m + 1; + } else { + S += expf(s - fms.M[j]); + } + } + GGML_ASSERT(S > 0); + auto norm = F16::set1(1/S); //auto norm = F16::set1(fms.S[j] > 0 ? 
1/fms.S[j] : 0.f); for (int i = 0; i < D/F16::block_size; ++i) { auto r = F16::load(R + F16::block_size*i); @@ -1153,7 +1168,7 @@ struct FlashQKV { } template - inline void normalize_and_store(const FMS& fms, int nq1, int stride_qkv, float * qkv, float * M, float * S) const { + inline void normalize_and_store(const FMS& fms, int nq1, int stride_qkv, float * qkv, const float * sinkf, float * M, float * S) { static_assert(q_step == FMS::q_step); if (M && S) { std::memcpy(M, fms.M, nq1*sizeof(float)); @@ -1173,7 +1188,7 @@ struct FlashQKV { } else { auto R = qkv_cache; for (int j = 0; j < nq1; ++j) { - normalize_and_store_1row(fms, j, R, qkv); + normalize_and_store_1row(fms, j, R, qkv, sinkf); qkv += stride_qkv; R += D; } @@ -1181,7 +1196,7 @@ struct FlashQKV { } template - inline void normalize_and_store(const FMS& fms, int stride_qkv, float * qkv, float * M, float * S) const { + inline void normalize_and_store(const FMS& fms, int stride_qkv, float * qkv, const float * sinkf, float * M, float * S) { static_assert(q_step == FMS::q_step); if (M && S) { std::memcpy(M, fms.M, q_step*sizeof(float)); @@ -1201,7 +1216,7 @@ struct FlashQKV { } else { auto R = qkv_cache; for (int j = 0; j < q_step; ++j) { - normalize_and_store_1row(fms, j, R, qkv); + normalize_and_store_1row(fms, j, R, qkv, sinkf); qkv += stride_qkv; R += D; } @@ -1332,7 +1347,7 @@ void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, in FlashMS& fms, FlashQKV& fqkv, const float * q, const char * mask, float * qkv, - float * M, float * S) { + const float * sinkf, float * M, float * S) { #ifdef __aarch64__ float16_t q_f16[Dk*q_step]; #endif @@ -1356,7 +1371,7 @@ void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, in vh.next_block(k_step); mr += k_step*sizeof(ggml_half); } - fqkv.normalize_and_store(fms, stride_qkv, qkv, M, S); + fqkv.normalize_and_store(fms, stride_qkv, qkv, sinkf, M, S); q += q_step*stride_q; mask += q_step*stride_m; @@ -1383,7 +1398,7 @@ void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, in vh.next_block(k_step); mr += k_step*sizeof(ggml_half); } - fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv, M, S); + fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv, sinkf, M, S); } } @@ -1392,7 +1407,7 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, FlashMS& fms, FlashQKV& fqkv, const float * q, const char * mask, float * qkv, - float * M, float * S, char * qptr) { + const float * sinkf, float * M, float * S, char * qptr) { auto q8 = (typename KHelper::block_q8 *)qptr; if constexpr (q_step > 1 && std::is_same_v) { if (nq1 == q_step) { @@ -1412,7 +1427,7 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, vh.next_block(k_step); mr += k_step*sizeof(ggml_half); } - fqkv.normalize_and_store(fms, stride_qkv, qkv, M, S); + fqkv.normalize_and_store(fms, stride_qkv, qkv, sinkf, M, S); return; } } @@ -1449,10 +1464,10 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, } #if FA_TIMING t1 = Perf::cur_time(); - fqkv.normalize_and_store(fms, stride_qkv, qkv, M, S); + fqkv.normalize_and_store(fms, stride_qkv, qkv, sinkf, M, S); perf.accum_nolock(3, t1); #else - fqkv.normalize_and_store(fms, stride_qkv, qkv, M, S); + fqkv.normalize_and_store(fms, stride_qkv, qkv, sinkf, M, S); #endif q += q_step*stride_q; @@ -1474,7 +1489,7 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, vh.next_block(k_step); mr += k_step*sizeof(ggml_half); 
} - fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv, M, S); + fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv, sinkf, M, S); } #if FA_TIMING Perf::instance().add(perf); @@ -1504,7 +1519,7 @@ struct FlashAttn { static_assert(k_step%F16::block_size == 0); static_assert(q_step <= 4 || q_step%4 == 0); - FlashAttn(float scale, float softcap) : fms(scale, softcap) {} + FlashAttn(float scale, float softcap, const float * sinkf) : fms(scale, softcap), sinkf(sinkf) {} template void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv, @@ -1533,7 +1548,7 @@ struct FlashAttn { HelperQ80R8 khr4(nk1, kh); #endif compute_helper_q, VHelper, FlashQKfp32>( - khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S, qptr); + khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, sinkf, M, S, qptr); return; } @@ -1547,29 +1562,30 @@ struct FlashAttn { HelperQ8KVR8 khr4(nk1, kh); #endif compute_helper_q, VHelper, FlashQKfp32>( - khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S, qptr); + khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, sinkf, M, S, qptr); return; } #endif } compute_helper_q>( - kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S, qptr); + kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, sinkf, M, S, qptr); } else { typename KHelper::block_q8 q8[q_step*(Dk/KHelper::block_size_q)]; compute_helper_q>( - kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S, (char *)q8); + kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, sinkf, M, S, (char *)q8); } } else { compute_helper>( - kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S); + kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, sinkf, M, S); } } FlashMS fms; FlashQKV fqkv; + const float * sinkf; }; @@ -1927,7 +1943,7 @@ struct FlashAttnBF16 { static_assert(k_step%32 == 0); static_assert(q_step <= 4 || q_step%4 == 0); - FlashAttnBF16(float scale, float softcap) : fms(scale, softcap) {} + FlashAttnBF16(float scale, float softcap, const float * sinkf) : fms(scale, softcap), sinkf(sinkf) {} template void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv, @@ -1967,7 +1983,7 @@ struct FlashAttnBF16 { #if FA_TIMING t1 = Perf::cur_time(); #endif - fqkv.normalize_and_store(fms, stride_qkv, qkv, M, S); + fqkv.normalize_and_store(fms, stride_qkv, qkv, sinkf, M, S); #if FA_TIMING perf.accum_nolock(4, t1); #endif @@ -1990,7 +2006,7 @@ struct FlashAttnBF16 { vh.next_block(k_step); mr += k_step*sizeof(ggml_half); } - fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv, M, S); + fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv, sinkf, M, S); } #if FA_TIMING Perf::instance().add(perf); @@ -1999,12 +2015,14 @@ struct FlashAttnBF16 { FlashMS fms; FlashQKV fqkv; + const float * sinkf; }; #endif template inline void iqk_flash_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv, - const float * q, const char * mask, float scale, float softcap, float * qkv, float * M, float * S) { + const float * q, const char * mask, float scale, float softcap, float * qkv, + const float * sinkf, float * M, float * S) { auto update = [&nq1, &mask, &q, &qkv, &M, &S, stride_q, stride_m, stride_qkv] (int n) { nq1 -= n; @@ -2018,48 +2036,48 @@ inline void iqk_flash_helper(KHelper& 
kh, VHelper& vh, int nq1, int nk1, int str if (nk1 >= 512) { if (nq1 >= 128) { int n_step = nq1/128; - FlashAttn fa(scale, softcap); + FlashAttn fa(scale, softcap, sinkf); fa.compute(kh, vh, 128*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); if (update(128*n_step)) return; } if (nq1 >= 64) { int n_step = nq1/64; - FlashAttn fa(scale, softcap); + FlashAttn fa(scale, softcap, sinkf); fa.compute(kh, vh, 64*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); if (update(64*n_step)) return; } if (nq1 >= 32) { int n_step = nq1/32; - FlashAttn fa(scale, softcap); + FlashAttn fa(scale, softcap, sinkf); fa.compute(kh, vh, 32*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); if (update(32*n_step)) return; } if (nq1 >= 16) { int n_step = nq1/16; - FlashAttn fa(scale, softcap); + FlashAttn fa(scale, softcap, sinkf); fa.compute(kh, vh, 16*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); if (update(16*n_step)) return; } } if (nq1 >= 8) { int n_step = nq1/8; - FlashAttn fa(scale, softcap); + FlashAttn fa(scale, softcap, sinkf); fa.compute(kh, vh, 8*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); if (update(8*n_step)) return; } else if (nq1 >= 4) { int n_step = nq1/4; - FlashAttn fa(scale, softcap); + FlashAttn fa(scale, softcap, sinkf); fa.compute(kh, vh, 4*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); if (update(4*n_step)) return; } else if (nq1 >= 2) { int n_step = nq1/2; - FlashAttn fa(scale, softcap); + FlashAttn fa(scale, softcap, sinkf); fa.compute(kh, vh, 2*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); if (update(2*n_step)) return; } - FlashAttn fa(scale, softcap); + FlashAttn fa(scale, softcap, sinkf); fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); } @@ -2067,26 +2085,26 @@ inline void iqk_flash_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int str template inline void iqk_flash_helper_T(int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv, const float * q, const char * k, const char * v, const char * mask, - float scale, float softcap, float * qkv, float * M, float * S) { + float scale, float softcap, float * qkv, const float * sinkf, float * M, float * S) { HelperBF16 kh(k, stride_k); HelperBF16 vh(v, stride_v); if (nk1 >= 4096) { if (nq1 >= 64) { - FlashAttnBF16 fa(scale, softcap); + FlashAttnBF16 fa(scale, softcap, sinkf); fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); return; } else if (nq1 >= 16) { - FlashAttnBF16 fa(scale, softcap); + FlashAttnBF16 fa(scale, softcap, sinkf); fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); return; } } if (nq1 >= 8) { - FlashAttnBF16 fa(scale, softcap); + FlashAttnBF16 fa(scale, softcap, sinkf); fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); } else { - FlashAttnBF16 fa(scale, softcap); + FlashAttnBF16 fa(scale, softcap, sinkf); fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); } } @@ -2096,43 +2114,43 @@ template inline bool iqk_flash_helper_T(KHelper& kh, ggml_type type_v, int nq1, int nk1, int stride_q, int stride_v, int stride_m, int stride_qkv, const float * q, const char * v, const char * mask, - float scale, float softcap, float * qkv, float * M, float * S) { 
+ float scale, float softcap, float * qkv, const float * sinkf, float * M, float * S) { switch (type_v) { case GGML_TYPE_F16: { HelperF16 vh(v, stride_v); - iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); + iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S); } break; #ifdef __AVX512BF16__ case GGML_TYPE_BF16: { HelperBF16 vh(v, stride_v); - iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); + iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S); } break; #endif case GGML_TYPE_Q8_0: { HelperQ80 vh(v, stride_v); - iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); + iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S); } break; case GGML_TYPE_Q8_KV: { HelperQ8KV vh(v, stride_v); - iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); + iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S); } break; case GGML_TYPE_Q6_0: { HelperQ60 vh(v, stride_v); - iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); + iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S); } break; #if GGML_IQK_FA_ALL_QUANTS case GGML_TYPE_Q4_0: { HelperQ40 vh(v, stride_v); - iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); + iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S); } break; case GGML_TYPE_Q4_1: { HelperQ41 vh(v, stride_v); - iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); + iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S); } break; case GGML_TYPE_IQ4_NL: { HelperIQ4nl vh(v, stride_v); - iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); + iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S); } break; #endif default: return false; @@ -2144,42 +2162,42 @@ template inline bool iqk_flash_helper_T(ggml_type type_k, ggml_type type_v, int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv, const float * q, const char * k, const char * v, const char * mask, - float scale, float softcap, float * qkv, float * M, float * S) { + float scale, float softcap, float * qkv, const float * sinkf, float * M, float * S) { bool result = false; switch (type_k) { case GGML_TYPE_F16: { HelperF16 kh(k, stride_k); - result = iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S); + result = iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, sinkf, M, S); } break; case GGML_TYPE_Q8_0: { HelperQ80 kh(k, stride_k); - result = iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S); + result = iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, sinkf, M, S); } break; case 
GGML_TYPE_Q8_0_R8: { HelperQ80R8 kh(k, stride_k); - result = iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S); + result = iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, sinkf, M, S); } break; case GGML_TYPE_Q6_0: { HelperQ60 kh(k, stride_k); - result = iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S); + result = iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, sinkf, M, S); } break; #if GGML_IQK_FA_ALL_QUANTS case GGML_TYPE_Q8_KV: { HelperQ8KV kh(k, stride_k); - result = iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S); + result = iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, sinkf, M, S); } break; case GGML_TYPE_Q4_0: { HelperQ40 kh(k, stride_k); - result = iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S); + result = iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, sinkf, M, S); } break; case GGML_TYPE_Q4_1: { HelperQ41 kh(k, stride_k); - result = iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S); + result = iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, sinkf, M, S); } break; case GGML_TYPE_IQ4_NL: { HelperIQ4nl kh(k, stride_k); - result = iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S); + result = iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, sinkf, M, S); } break; #endif default: break; @@ -2194,7 +2212,7 @@ inline bool iqk_flash_helper_T(ggml_type type_k, ggml_type type_v, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv,\ const float * q, const void * k, const void * v, const void * mask,\ float scale, float softcap,\ - float * qkv, float * M, float * S) + float * qkv, const float * sinkf, float * M, float * S) IQK_FA_CASE(iqk_fa_576_512); IQK_FA_CASE(iqk_fa_192_128); diff --git a/ggml/src/iqk/iqk_flash_attn.cpp b/ggml/src/iqk/iqk_flash_attn.cpp index 9a974ae74..19e6cd25d 100644 --- a/ggml/src/iqk/iqk_flash_attn.cpp +++ b/ggml/src/iqk/iqk_flash_attn.cpp @@ -66,6 +66,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float const void * k, // k matrix. Assumed to be fp16, nq x nk elements const void * v, // v matrix. Assumed to be fp16, nq x nk elements const void * mask, // mask. If not null, assumed to be fp16. nq x nk elements + const void * sinks, // attention sinks. If not null, assumed to be fp32,
one value per attention head float scale, // scale applied before softmax float softcap, // if > 0, a "soft-cap" operation is applied before softmax float * qkv, // v*softmax(scale*(k*q)) @@ -139,7 +140,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float auto work_this_thread = (float *)(result_buffer + ith*size_thread); if (!iqk_flash_attn_impl(int_type_k, int_type_v, Dk, Dv, nq_this_thread, nek1/gcd_k, nbq2, stride_k, stride_v, 0, Dv, //Dk*sizeof(uint16_t), Dv, - (const float *)qth, (const void *)kth, (const void *)vth, (const void *)mth, + (const float *)qth, (const void *)kth, (const void *)vth, (const void *)mth, nullptr, 0, scale, softcap, work_this_thread, work_this_thread + (Dv+0)*nq_this_thread, work_this_thread + (Dv+1)*nq_this_thread)) return false; @@ -182,51 +183,6 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float if (neq3 == 1 && rk2 > 1 && rk2 == rv2 && neq1 == 1 && nth >= 1 && nek2*nek1 >= 32*nth) { auto result_size = (Dv + 16)*rk2*sizeof(float); int gcd = simple_gcd(nek2, nth); - if (false && gcd > 1) { - int nth_g = nth/gcd; - int ith_g = ith%nth_g; - int nek1_32 = nek1/32; - int nek1_pt = (nek1_32 + nth_g - 1)/nth_g; - int ith_mid = nth_g; - if (nek1_pt*nth_g > nek1_32) { - ith_mid = nek1_32 - nth_g*(nek1_pt - 1); - } - nek1_pt *= 32; - int nek1_mid = ith_mid*nek1_pt; - int nek1_thread = ith_g < ith_mid ? nek1_pt : nek1_pt - 32; - for (int ik02 = ith/nth_g; ik02 < nek2; ik02 += gcd) { - int ik01 = ith_g < ith_mid ? ith_g*nek1_pt : nek1_mid + (ith_g - ith_mid)*nek1_thread; - auto this_result = (float *)((char *)work_buffer + (ik02*nth_g + ith_g)*result_size); - auto this_q = (const float *)((const char *)q + ik02*rk2*nbq2); - auto this_k = (const char *)k + ik01*stride_k + ik02*nbk2; - auto this_v = (const char *)v + ik01*stride_v + ik02*nbv2; - auto this_m = (const char *)mask + ik01*sizeof(uint16_t); // we don't have ggml_half available here - if (!iqk_flash_attn_impl(int_type_k, int_type_v, - Dk, Dv, rk2, nek1_thread, nbq2, stride_k, stride_v, 0, Dv, - this_q, (const void *)this_k, (const void *)this_v, (const void *)this_m, - scale, softcap, this_result, this_result + (Dv+0)*rk2, this_result + (Dv+1)*rk2)) return false; - } - - barrier(barrier_data); - - for (int iq2 = ith; iq2 < neq2; iq2 += nth) { - int ik02 = iq2/rk2; - int il = iq2 - ik02*rk2; - auto Racc = qkv + iq2*nb1/sizeof(float); - float M = -INFINITY, S = 0; - for (int ig = 0; ig < nth_g; ++ig) { - int istep_k = ik02*nth_g + ig; - auto this_result = (float *)((char *)work_buffer + istep_k*result_size); - const float * R = this_result + il*Dv; - const float * Mj = this_result + Dv*rk2; - const float * Sj = Mj + rk2; - accumulate_qkv(Dv, M, S, Mj[il], Sj[il], Racc, R); - } - float norm = S > 0 ? 
1/S : 1; - for (int i = 0; i < Dv; ++i) Racc[i] *= norm; - } - return true; - } int nth_k = nth/gcd; int nek2_k = nek2/gcd; int nchunk = nek2_k*nek1/32; @@ -259,7 +215,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float auto this_m = (const char *)mask + ik01*sizeof(uint16_t); // we don't have ggml_half available here if (!iqk_flash_attn_impl(int_type_k, int_type_v, Dk, Dv, rk2, this_nk, nbq2, stride_k, stride_v, 0, Dv, - this_q, (const void *)this_k, (const void *)this_v, (const void *)this_m, + this_q, (const void *)this_k, (const void *)this_v, (const void *)this_m, nullptr, 0, scale, softcap, this_result, this_result + (Dv+0)*rk2, this_result + (Dv+1)*rk2)) return false; } @@ -281,6 +237,16 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float const float * Sj = Mj + rk2; accumulate_qkv(Dv, M, S, Mj[il], Sj[il], Racc, R); } + if (sinks) { + float s = ((const float *)sinks)[iq2]; + if (s > M) { + float m = expf(M - s); + for (int i = 0; i < Dv; ++i) Racc[i] *= m; + S = S*m + 1; + } else { + S += expf(s - M); + } + } float norm = S > 0 ? 1/S : 1; for (int i = 0; i < Dv; ++i) Racc[i] *= norm; } @@ -306,6 +272,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float int counter = 0; for (int64_t iq3 = 0; iq3 < neq3; iq3++) { for (int64_t iq2 = 0; iq2 < neq2; iq2++) { + auto sinksf = sinks ? (const float *)sinks + iq2 : nullptr; if (counter++ % (nth/ntg) == ith/ntg) { int iq1 = (ith%ntg)*neq1g; int this_neq1 = std::min(neq1g, neq1-iq1); @@ -314,7 +281,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float (const float *)((const char *)q + iq2*nbq2 + iq3*nbq3 + iq1*stride_q), (const void *)((const char *)k + iq2/rk2*nbk2 + iq3/rk3*nbk3), (const void *)((const char *)v + iq2/rv2*nbv2 + iq3/rv3*nbv3), - (const void *)((const char *)mask + iq1*stride_m), + (const void *)((const char *)mask + iq1*stride_m), sinksf, 1, scale, softcap, (float *)((char *)qkv + (iq3*ne2*ne1 + iq2 + iq1*ne1)*nb1), nullptr, nullptr)) return false; } diff --git a/ggml/src/iqk/iqk_flash_impl.h b/ggml/src/iqk/iqk_flash_impl.h index 6f62e56b1..8db97731b 100644 --- a/ggml/src/iqk/iqk_flash_impl.h +++ b/ggml/src/iqk/iqk_flash_impl.h @@ -23,6 +23,8 @@ bool iqk_flash_attn_impl(int type_k, // type of k const void * k, // k matrix. Assumed to be fp16, nq x nk elements const void * v, // v matrix. Assumed to be fp16, nq x nk elements const void * mask, // mask. If not null, assumed to be fp16. 
nq x nk elements + const float * sinksf, // attention sinks + int nsinks, // number of sinks float scale, // scale applied before softmax float softcap, // if > 0, a "soft-cap" operation is applied before softmax float * qkv, // v*softmax(scale*(k*q)) diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index f8624bfc2..041ad1650 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -120,16 +120,21 @@ struct MulMat { funcs[n_left-1](n, vx, bx, info, nrc_x); } } - inline void gelu(int n, const float * src, float * dst); - inline void relu(int n, const float * src, float * dst); - inline void silu(int n, const float * src, float * dst); - inline void activate(ggml_unary_op op, int n, const float * src, float * dst) { + inline static void gelu(int n, const float * src, float * dst); + inline static void relu(int n, const float * src, float * dst); + inline static void silu(int n, const float * src, float * dst); + inline static void swiglu_oai(int n, const float * src, float * dst); + inline static void clamp_oai(int n, float *x); + inline static void activate(ggml_unary_op op, int n, const float * src, float * dst) { if (op == GGML_UNARY_OP_GELU) gelu(n, src, dst); else if (op == GGML_UNARY_OP_RELU) relu(n, src, dst); else if (op == GGML_UNARY_OP_SILU) silu(n, src, dst); + else if (op == GGML_UNARY_OP_SWIGLU_OAI) swiglu_oai(n, src, dst); else GGML_ABORT("fatal error"); } - inline void mul_mat_up_gate_NxM(int n, const void * vx_up, const void * vx_gate, size_t bx, DataInfo& info, int nrc_x, int nrc_y, int unary_op) { + inline void mul_mat_up_gate_NxM(int n, const void * vx_up, const void * vx_gate, size_t bx, + const float * up_b, const float * gate_b, + DataInfo& info, int nrc_x, int nrc_y, int unary_op) { #ifdef __aarch64__ constexpr int k_x_step = 64; //8192; // Tiling does not seem to help on my M2 Max (but difference to tiling is small) #else @@ -137,6 +142,29 @@ struct MulMat { #endif auto op = ggml_unary_op(unary_op); float tmp[k_x_step*16]; + auto process = [&tmp, n, op, vx_gate, vx_up, gate_b, up_b, bx, xstep = k_x_step] (mul_mat_t func, const DataInfo& this_info, int ix, int this_nrc_x, int ny) { + func(n, (const void *)((const char *)vx_gate + ix*bx), bx, this_info, this_nrc_x); + for (int ky = 0; ky < ny; ++ky) { + if (gate_b) { + auto b = gate_b + ix; + auto x = this_info.dst_row(ky); + for (int j = 0; j < this_nrc_x; ++j) x[j] += b[j]; + } + activate(op, this_nrc_x, this_info.dst_row(ky), tmp + ky*xstep); + } + func(n, (const void *)((const char *)vx_up + ix*bx), bx, this_info, this_nrc_x); + for (int ky = 0; ky < ny; ++ky) { + auto result = this_info.dst_row(ky); + if (up_b) { + auto b = up_b + ix; + for (int j = 0; j < this_nrc_x; ++j) result[j] += b[j]; + } + if (op == GGML_UNARY_OP_SWIGLU_OAI) { + clamp_oai(this_nrc_x, result); + } + for (int j = 0; j < this_nrc_x; ++j) result[j] *= tmp[ky*xstep + j]; + } + }; if (func16 && nrc_y >= 16) { int n_step = (nrc_y - info.cur_y)/16; for (int ix = 0; ix < nrc_x; ix += k_x_step) { @@ -144,15 +172,7 @@ struct MulMat { this_info.s += ix; int this_nrc_x = ix + k_x_step <= nrc_x ? 
k_x_step : nrc_x - ix; for (int iy = 0; iy < n_step; ++iy) { - func16(n, (const void *)((const char *)vx_gate + ix*bx), bx, this_info, this_nrc_x); - for (int ky = 0; ky < 16; ++ky) { - activate(op, this_nrc_x, this_info.dst_row(ky), tmp + ky*k_x_step); - } - func16(n, (const void *)((const char *)vx_up + ix*bx), bx, this_info, this_nrc_x); - for (int ky = 0; ky < 16; ++ky) { - auto result = this_info.dst_row(ky); - for (int j = 0; j < this_nrc_x; ++j) result[j] *= tmp[ky*k_x_step + j]; - } + process(func16, this_info, ix, this_nrc_x, 16); this_info.cur_y += 16; } } @@ -175,23 +195,11 @@ struct MulMat { this_info.s += ix; int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix; for (int iy = 0; iy < my1; ++iy) { - funcs[ny1-1](n, (const void *)((const char *)vx_gate + ix*bx), bx, this_info, this_nrc_x); - for (int ky = 0; ky < ny1; ++ky) activate(op, this_nrc_x, this_info.dst_row(ky), tmp + ky*k_x_step); - funcs[ny1-1](n, (const void *)((const char *)vx_up + ix*bx), bx, this_info, this_nrc_x); - for (int ky = 0; ky < ny1; ++ky) { - auto result = this_info.dst_row(ky); - for (int j = 0; j < this_nrc_x; ++j) result[j] *= tmp[ky*k_x_step + j]; - } + process(funcs[ny1-1], this_info, ix, this_nrc_x, ny1); this_info.cur_y += ny1; } for (int iy = 0; iy < my2; ++iy) { - funcs[ny2-1](n, (const void *)((const char *)vx_gate + ix*bx), bx, this_info, this_nrc_x); - for (int ky = 0; ky < ny2; ++ky) activate(op, this_nrc_x, this_info.dst_row(ky), tmp + ky*k_x_step); - funcs[ny2-1](n, (const void *)((const char *)vx_up + ix*bx), bx, this_info, this_nrc_x); - for (int ky = 0; ky < ny2; ++ky) { - auto result = this_info.dst_row(ky); - for (int j = 0; j < this_nrc_x; ++j) result[j] *= tmp[ky*k_x_step + j]; - } + process(funcs[ny2-1], this_info, ix, this_nrc_x, ny2); this_info.cur_y += ny2; } } @@ -203,13 +211,7 @@ struct MulMat { this_info.s += ix; int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix; for (int iy = 0; iy < n_step; ++iy) { - funcs[ny-1](n, (const void *)((const char *)vx_gate + ix*bx), bx, this_info, this_nrc_x); - for (int ky = 0; ky < ny; ++ky) activate(op, this_nrc_x, this_info.dst_row(ky), tmp + ky*k_x_step); - funcs[ny-1](n, (const void *)((const char *)vx_up + ix*bx), bx, this_info, this_nrc_x); - for (int ky = 0; ky < ny; ++ky) { - auto result = this_info.dst_row(ky); - for (int j = 0; j < this_nrc_x; ++j) result[j] *= tmp[ky*k_x_step + j]; - } + process(funcs[ny-1], this_info, ix, this_nrc_x, ny); this_info.cur_y += ny; } } @@ -222,13 +224,7 @@ struct MulMat { auto this_info = info; this_info.s += ix; int this_nrc_x = ix + k_x_step <= nrc_x ? 
k_x_step : nrc_x - ix; - funcs[n_left-1](n, (const void *)((const char *)vx_gate + ix*bx), bx, this_info, this_nrc_x); - for (int ky = 0; ky < n_left; ++ky) activate(op, this_nrc_x, this_info.dst_row(ky), tmp + ky*k_x_step); - funcs[n_left-1](n, (const void *)((const char *)vx_up + ix*bx), bx, this_info, this_nrc_x); - for (int ky = 0; ky < n_left; ++ky) { - auto result = this_info.dst_row(ky); - for (int j = 0; j < this_nrc_x; ++j) result[j] *= tmp[ky*k_x_step + j]; - } + process(funcs[n_left-1], this_info, ix, this_nrc_x, n_left); } } } @@ -731,6 +727,7 @@ extern "C" IQK_API bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11, extern "C" IQK_API bool iqk_moe_fused_up_gate(long Nx, long Ny, long ne00, int ne11, int unary_op, int typeA, const void * Aup, const void * Agate, long strideA, int typeB, const void * B, long strideB, + const char * up_b_c, const char * gate_b_c, float * C, long nb1, long nb2, const void * vrow_mapping, int ith, int nth) { const mmid_row_mapping * row_mapping = (const mmid_row_mapping *)vrow_mapping; @@ -774,7 +771,9 @@ extern "C" IQK_API bool iqk_moe_fused_up_gate(long Nx, long Ny, long ne00, int n if (!iqk_convert_repack(typeA, ne00, (const char *)Agate + (first_x + ix)*strideA, strideA, Xg, ne00, this_nrc_x)) { GGML_ABORT("Fatal error"); } - mm.mul_mat_up_gate_NxM(ne00, Xu, Xg, row_size_qx, this_info, this_nrc_x, Ny, unary_op); + auto up_b = up_b_c ? (const float *)up_b_c + first_x + ix : nullptr; + auto gate_b = gate_b_c ? (const float *)gate_b_c + first_x + ix : nullptr; + mm.mul_mat_up_gate_NxM(ne00, Xu, Xg, row_size_qx, up_b, gate_b, this_info, this_nrc_x, Ny, unary_op); } return true; @@ -795,7 +794,10 @@ extern "C" IQK_API bool iqk_moe_fused_up_gate(long Nx, long Ny, long ne00, int n nrc_x *= num_rows; DataInfo info{C + first_x, (const char *)B, nb1/sizeof(float), row_size_qy, 0, ne11, row_mapping, nb2/sizeof(float)}; - mm.mul_mat_up_gate_NxM(ne00, (const char *)Aup + row_size_qx*first_x, (const char *)Agate + row_size_qx*first_x, row_size_qx, info, nrc_x, Ny, unary_op); + auto up_b = up_b_c ? (const float *)up_b_c + first_x : nullptr; + auto gate_b = gate_b_c ? 
(const float *)gate_b_c + first_x : nullptr; + mm.mul_mat_up_gate_NxM(ne00, (const char *)Aup + row_size_qx*first_x, (const char *)Agate + row_size_qx*first_x, row_size_qx, + up_b, gate_b, info, nrc_x, Ny, unary_op); return true; } @@ -993,6 +995,46 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) { namespace { +// TODO: these swiglu_oai constants shouldn't be hard coded +constexpr float k_swiglu_oai_alpha = 1.702f; +constexpr float k_swiglu_oai_limit = 7.f; + +void MulMat::swiglu_oai(int n, const float * x, float * y) { +// int i = 0; +//#if defined __AVX512F__ && defined __AVX512DQ__ +// { +// auto max = _mm512_set1_ps(k_swiglu_oai_limit); +// auto alpha = _mm512_set1_ps(-k_swiglu_oai_alpha); +// for (; i + 15 < n; i += 16) { +// auto xc = v_clamp_max(_mm512_loadu_ps(x + i), max); +// _mm512_storeu_ps(y + i, v_silu_oai(xc, alpha)); +// } +// } +//#endif +//#if defined __AVX2__ && defined __FMA__ +// if (i + 7 < n) { +// auto max = _mm256_set1_ps(k_swiglu_oai_limit); +// auto alpha = _mm256_set1_ps(-k_swiglu_oai_alpha); +// for (; i + 7 < n; i += 8) { +// auto xc = v_clamp_max(_mm256_loadu_ps(x + i), max); +// _mm256_storeu_ps(y + i, v_silu_oai(xc, alpha)); +// } +// } +//#endif +// for (; i < n; ++i) { +// auto xi = std::min(x[i], k_swiglu_oai_limit); +// y[i] = xi / (1.0f + expf(-xi * k_swiglu_oai_alpha)); +// } + for (int i = 0; i < n; ++i) { + auto xi = std::min(x[i], k_swiglu_oai_limit); + y[i] = xi / (1.0f + expf(-xi * k_swiglu_oai_alpha)); + } +} + +void MulMat::clamp_oai(int n, float * x) { + for (int i = 0; i < n; ++i) x[i] = 1.f + std::max(std::min(x[i], k_swiglu_oai_limit), -k_swiglu_oai_limit); +} + #if defined(__ARM_NEON) && defined(__aarch64__) void MulMat::gelu(int n, const float * x, float * y) { constexpr float GELU_COEF_A = 0.044715f; @@ -1040,6 +1082,37 @@ void MulMat::gelu(int n, const float * x, float * y) { for (; i < n; ++i) y[i] = 0.5f*x[i]*(1.0f + tanhf(SQRT_2_OVER_PI*x[i]*(1.0f + GELU_COEF_A*x[i]*x[i]))); } +//void MulMat::swiglu_oai(int n, const float * x, float * y) { +// int i = 0; +//#if defined __AVX512F__ && defined __AVX512DQ__ +// { +// auto limit = _mm512_set1_ps(k_swiglu_oai_limit); +// auto alpha = _mm512_set1_ps(k_swiglu_oai_alpha); +// for (; i + 15 < n; i += 16) { +// auto xi = _mm512_loadu_ps(x + i); +// auto mask = _mm512_cmp +// +// } +// __m512 c1 = _mm512_set1_ps(GELU_COEF_A); +// __m512 c2 = _mm512_set1_ps(2.f*SQRT_2_OVER_PI); +// for (; i + 15 < n; i += 16) _mm512_storeu_ps(y + i, v_gelu(_mm512_loadu_ps(x + i), c1, c2)); +// } +//#endif +//#if defined __AVX2__ && defined __FMA__ +// if (i + 7 < n) { +// __m256 c1 = _mm256_set1_ps(GELU_COEF_A); +// __m256 c2 = _mm256_set1_ps(2.f*SQRT_2_OVER_PI); +// for (; i + 7 < n; i += 8) _mm256_storeu_ps(y + i, v_gelu(_mm256_loadu_ps(x + i), c1, c2)); +// +// } +//#endif +// for (; i < n; ++i) { +// auto xi = std::min(x[i], k_swiglu_oai_limit); +// y[i] = xi / (1.0f + expf(-xi * k_swiglu_oai_alpha)); +// } +//} + + void MulMat::silu(int n, const float * x, float * y) { int i = 0; #if defined __AVX512F__ && defined __AVX512DQ__ @@ -1188,6 +1261,8 @@ bool iqk_flash_attn_impl(int int_type_k, // type of k const void * k, // k matrix. Assumed to be fp16, nq x nk elements const void * v, // v matrix. Assumed to be fp16, nq x nk elements const void * mask, // mask. If not null, assumed to be fp16. nq x nk elements + const float * sinksf, // attention sinks. If not null, assumed to be fp32, one value per attention head
+ [[maybe_unused]] int nsinks, float scale, // scale applied before softmax float softcap, // if > 0, a "soft-cap" operation is applied before softmax float * qkv, // v*softmax(scale*(k*q)) @@ -1197,32 +1272,32 @@ bool iqk_flash_attn_impl(int int_type_k, // type of k if (Dk == 576 && Dv == 512) { return iqk_fa_576_512(int_type_k, int_type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, - q, k, v, mask, scale, softcap, qkv, M, S); + q, k, v, mask, scale, softcap, qkv, sinksf, M, S); } if (Dk == 192 && Dv == 128) { return iqk_fa_192_128(int_type_k, int_type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, - q, k, v, mask, scale, softcap, qkv, M, S); + q, k, v, mask, scale, softcap, qkv, sinksf, M, S); } if (Dk == 256 && Dv == 256) { return iqk_fa_256_256(int_type_k, int_type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, - q, k, v, mask, scale, softcap, qkv, M, S); + q, k, v, mask, scale, softcap, qkv, sinksf, M, S); } if (Dk == 128 && Dv == 128) { return iqk_fa_128_128(int_type_k, int_type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, - q, k, v, mask, scale, softcap, qkv, M, S); + q, k, v, mask, scale, softcap, qkv, sinksf, M, S); } if (Dk == 96 && Dv == 96) { return iqk_fa_96_96(int_type_k, int_type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, - q, k, v, mask, scale, softcap, qkv, M, S); + q, k, v, mask, scale, softcap, qkv, sinksf, M, S); } if (Dk == 64 && Dv == 64) { return iqk_fa_64_64(int_type_k, int_type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, - q, k, v, mask, scale, softcap, qkv, M, S); + q, k, v, mask, scale, softcap, qkv, sinksf, M, S); } return false; diff --git a/ggml/src/iqk/iqk_mul_mat.h b/ggml/src/iqk/iqk_mul_mat.h index 87722f6fe..b131095b0 100644 --- a/ggml/src/iqk/iqk_mul_mat.h +++ b/ggml/src/iqk/iqk_mul_mat.h @@ -32,6 +32,7 @@ IQK_API bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11, IQK_API bool iqk_moe_fused_up_gate(long Nx, long Ny, long ne00, int ne11, int unary_op, int typeA, const void * Aup, const void * Agate, long strideA, int typeB, const void * B, long strideB, + const char * up_b, const char * gate_b, float * C, long nb1, long nb2, const void * vrow_mapping, int ith, int nth); IQK_API int iqk_dequant_type(int type, int Ny); @@ -57,6 +58,7 @@ IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float max_bias, const void * k, // k matrix. Assumed to be fp16, nq x nk elements const void * v, // v matrix. Assumed to be fp16, nq x nk elements const void * mask, // mask. If not null, assumed to be fp16. nq x nk elements + const void * sinks, // attention sinks. If not null, assumed to be fp32, one value per attention head
float scale, // scale applied before softmax float softcap, // if > 0, a "soft-cap" operation is applied before softmax float * qkv, // v*softmax(scale*(k*q)) diff --git a/ggml/src/iqk/iqk_utils.h b/ggml/src/iqk/iqk_utils.h index 194bf9b81..435ae4dd2 100644 --- a/ggml/src/iqk/iqk_utils.h +++ b/ggml/src/iqk/iqk_utils.h @@ -61,6 +61,13 @@ static inline float32x4_t v_silu(float32x4_t x) { const float32x4_t one_plus_exp_neg_x = vaddq_f32(one, exp_neg_x); return vdivq_f32(x, one_plus_exp_neg_x); } +static inline float32x4_t v_silu_oai(float32x4_t x, float32x4_t alpha) { + const float32x4_t one = vdupq_n_f32(1.0f); + const float32x4_t neg_x = vmulq_f32(alpha, x); + const float32x4_t exp_neg_x = v_expf(neg_x); + const float32x4_t one_plus_exp_neg_x = vaddq_f32(one, exp_neg_x); + return vdivq_f32(x, one_plus_exp_neg_x); +} static inline float32x4_t v_gelu(float32x4_t x, float32x4_t c1, float32x4_t c2) { const float32x4_t one = vdupq_n_f32(1.0f); float32x4_t arg = vfmaq_f32(one, c1, vmulq_f32(x, x)); @@ -131,6 +138,17 @@ static inline __m512 v_silu(__m512 x) { const __m512 one_plus_exp_neg_x = _mm512_add_ps(one, exp_neg_x); return _mm512_div_ps(x, one_plus_exp_neg_x); } +static inline __m512 v_silu_oai(__m512 x, __m512 alpha) { + const __m512 one = _mm512_set1_ps(1); + const __m512 neg_x = _mm512_mul_ps(alpha, x); + const __m512 exp_neg_x = v_expf(neg_x); + const __m512 one_plus_exp_neg_x = _mm512_add_ps(one, exp_neg_x); + return _mm512_div_ps(x, one_plus_exp_neg_x); +} +static inline __m512 v_clamp_max(__m512 x, __m512 max) { + auto mask = _mm512_cmp_ps_mask(x, max, _CMP_GT_OQ); + return _mm512_mask_blend_ps(mask, x, max); +} #endif // __AVX512__ #if defined(__AVX2__) && defined(__FMA__) @@ -195,12 +213,23 @@ static inline __m256 v_gelu(__m256 x, __m256 c1, __m256 c2) { } static inline __m256 v_silu(__m256 x) { const __m256 one = _mm256_set1_ps(1); - const __m256 zero = _mm256_setzero_ps(); + const __m256 zero = _mm256_setzero_ps(); const __m256 neg_x = _mm256_sub_ps(zero, x); const __m256 exp_neg_x = v_expf(neg_x); const __m256 one_plus_exp_neg_x = _mm256_add_ps(one, exp_neg_x); return _mm256_div_ps(x, one_plus_exp_neg_x); } +static inline __m256 v_silu_oai(__m256 x, __m256 alpha) { + const __m256 one = _mm256_set1_ps(1); + const __m256 neg_x = _mm256_mul_ps(alpha, x); + const __m256 exp_neg_x = v_expf(neg_x); + const __m256 one_plus_exp_neg_x = _mm256_add_ps(one, exp_neg_x); + return _mm256_div_ps(x, one_plus_exp_neg_x); +} +static inline __m256 v_clamp_max(__m256 x, __m256 max) { + auto mask = _mm256_cmp_ps(x, max, _CMP_GT_OQ); + return _mm256_or_ps(_mm256_and_ps(mask, max), _mm256_andnot_ps(mask, x)); +} #endif // __AVX2__ diff --git a/include/llama.h b/include/llama.h index c68fa2298..482f02b15 100644 --- a/include/llama.h +++ b/include/llama.h @@ -70,50 +70,52 @@ extern "C" { typedef int32_t llama_seq_id; enum llama_vocab_type { - LLAMA_VOCAB_TYPE_NONE = 0, // For models without vocab - LLAMA_VOCAB_TYPE_SPM = 1, // LLaMA tokenizer based on byte-level BPE with byte fallback - LLAMA_VOCAB_TYPE_BPE = 2, // GPT-2 tokenizer based on byte-level BPE - LLAMA_VOCAB_TYPE_WPM = 3, // BERT tokenizer based on WordPiece - LLAMA_VOCAB_TYPE_UGM = 4, // T5 tokenizer based on Unigram + LLAMA_VOCAB_TYPE_NONE = 0, // For models without vocab + LLAMA_VOCAB_TYPE_SPM = 1, // LLaMA tokenizer based on byte-level BPE with byte fallback + LLAMA_VOCAB_TYPE_BPE = 2, // GPT-2 tokenizer based on byte-level BPE + LLAMA_VOCAB_TYPE_WPM = 3, // BERT tokenizer based on WordPiece + LLAMA_VOCAB_TYPE_UGM = 4, // 
T5 tokenizer based on Unigram + LLAMA_VOCAB_TYPE_RWKV = 5, // RWKV tokenizer based on greedy tokenization + LLAMA_VOCAB_TYPE_PLAMO2 = 6, // PLaMo-2 tokenizer based on Aho-Corasick with dynamic programming }; // pre-tokenization types - enum llama_vocab_pre_type { - LLAMA_VOCAB_PRE_TYPE_DEFAULT = 0, - LLAMA_VOCAB_PRE_TYPE_LLAMA3 = 1, - LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM = 2, - LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER = 3, - LLAMA_VOCAB_PRE_TYPE_FALCON = 4, - LLAMA_VOCAB_PRE_TYPE_MPT = 5, - LLAMA_VOCAB_PRE_TYPE_STARCODER = 6, - LLAMA_VOCAB_PRE_TYPE_GPT2 = 7, - LLAMA_VOCAB_PRE_TYPE_REFACT = 8, - LLAMA_VOCAB_PRE_TYPE_COMMAND_R = 9, - LLAMA_VOCAB_PRE_TYPE_STABLELM2 = 10, - LLAMA_VOCAB_PRE_TYPE_QWEN2 = 11, - LLAMA_VOCAB_PRE_TYPE_OLMO = 12, - LLAMA_VOCAB_PRE_TYPE_DBRX = 13, - LLAMA_VOCAB_PRE_TYPE_SMAUG = 14, - LLAMA_VOCAB_PRE_TYPE_PORO = 15, - LLAMA_VOCAB_PRE_TYPE_CHATGLM3 = 16, - LLAMA_VOCAB_PRE_TYPE_CHATGLM4 = 17, - LLAMA_VOCAB_PRE_TYPE_VIKING = 18, - LLAMA_VOCAB_PRE_TYPE_JAIS = 19, - LLAMA_VOCAB_PRE_TYPE_TEKKEN = 20, - LLAMA_VOCAB_PRE_TYPE_SMOLLM = 21, - LLAMA_VOCAB_PRE_TYPE_CODESHELL = 22, - LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 28, //llama.cpp lists this as 28 - LLAMA_VOCAB_PRE_TYPE_GPT4O = 29, - LLAMA_VOCAB_PRE_TYPE_SUPERBPE = 30, - LLAMA_VOCAB_PRE_TYPE_TRILLION = 31, - LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32, - LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33, - LLAMA_VOCAB_PRE_TYPE_FALCON_3 = 34, - LLAMA_VOCAB_PRE_TYPE_FALCON_E = 35, - LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 36, //llama.cpp lists this as 35 - LLAMA_VOCAB_PRE_TYPE_HUNYUAN = 37, //llama.cpp lists this as 36 - LLAMA_VOCAB_PRE_TYPE_KIMI_K2 = 38, //llama.cpp lists this as 37 - }; + //enum llama_vocab_pre_type { + // LLAMA_VOCAB_PRE_TYPE_DEFAULT = 0, + // LLAMA_VOCAB_PRE_TYPE_LLAMA3 = 1, + // LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM = 2, + // LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER = 3, + // LLAMA_VOCAB_PRE_TYPE_FALCON = 4, + // LLAMA_VOCAB_PRE_TYPE_MPT = 5, + // LLAMA_VOCAB_PRE_TYPE_STARCODER = 6, + // LLAMA_VOCAB_PRE_TYPE_GPT2 = 7, + // LLAMA_VOCAB_PRE_TYPE_REFACT = 8, + // LLAMA_VOCAB_PRE_TYPE_COMMAND_R = 9, + // LLAMA_VOCAB_PRE_TYPE_STABLELM2 = 10, + // LLAMA_VOCAB_PRE_TYPE_QWEN2 = 11, + // LLAMA_VOCAB_PRE_TYPE_OLMO = 12, + // LLAMA_VOCAB_PRE_TYPE_DBRX = 13, + // LLAMA_VOCAB_PRE_TYPE_SMAUG = 14, + // LLAMA_VOCAB_PRE_TYPE_PORO = 15, + // LLAMA_VOCAB_PRE_TYPE_CHATGLM3 = 16, + // LLAMA_VOCAB_PRE_TYPE_CHATGLM4 = 17, + // LLAMA_VOCAB_PRE_TYPE_VIKING = 18, + // LLAMA_VOCAB_PRE_TYPE_JAIS = 19, + // LLAMA_VOCAB_PRE_TYPE_TEKKEN = 20, + // LLAMA_VOCAB_PRE_TYPE_SMOLLM = 21, + // LLAMA_VOCAB_PRE_TYPE_CODESHELL = 22, + // LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 28, //llama.cpp lists this as 28 + // LLAMA_VOCAB_PRE_TYPE_GPT4O = 29, + // LLAMA_VOCAB_PRE_TYPE_SUPERBPE = 30, + // LLAMA_VOCAB_PRE_TYPE_TRILLION = 31, + // LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32, + // LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33, + // LLAMA_VOCAB_PRE_TYPE_FALCON_3 = 34, + // LLAMA_VOCAB_PRE_TYPE_FALCON_E = 35, + // LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 36, //llama.cpp lists this as 35 + // LLAMA_VOCAB_PRE_TYPE_HUNYUAN = 37, //llama.cpp lists this as 36 + // LLAMA_VOCAB_PRE_TYPE_KIMI_K2 = 38, //llama.cpp lists this as 37 + //}; // note: these values should be synchronized with ggml_rope // TODO: maybe move this enum to ggml.h (ggml_rope_type) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index f86137c43..0d5ab7357 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -17,6 +17,8 @@ add_library(llama llama-vocab.cpp llama-grammar.cpp llama-sampling.cpp + llama-mmap.cpp + llama-model-loader.cpp unicode.h 
unicode.cpp unicode-data.cpp diff --git a/src/llama-arch.h b/src/llama-arch.h new file mode 100644 index 000000000..76ac44c48 --- /dev/null +++ b/src/llama-arch.h @@ -0,0 +1,288 @@ +#pragma once + +#include <string> + +enum llm_arch { + LLM_ARCH_LLAMA, + LLM_ARCH_LLAMA4, + LLM_ARCH_DECI, + LLM_ARCH_FALCON, + LLM_ARCH_BAICHUAN, + LLM_ARCH_GROK, + LLM_ARCH_GPT2, + LLM_ARCH_GPTJ, + LLM_ARCH_GPTNEOX, + LLM_ARCH_MPT, + LLM_ARCH_STARCODER, + LLM_ARCH_REFACT, + LLM_ARCH_BERT, + LLM_ARCH_NOMIC_BERT, + LLM_ARCH_JINA_BERT_V2, + LLM_ARCH_BLOOM, + LLM_ARCH_STABLELM, + LLM_ARCH_QWEN, + LLM_ARCH_QWEN2, + LLM_ARCH_QWEN2MOE, + LLM_ARCH_QWEN3, + LLM_ARCH_QWEN3MOE, + LLM_ARCH_PHI2, + LLM_ARCH_PHI3, + LLM_ARCH_PLAMO, + LLM_ARCH_CODESHELL, + LLM_ARCH_ORION, + LLM_ARCH_INTERNLM2, + LLM_ARCH_MINICPM, + LLM_ARCH_GEMMA, + LLM_ARCH_GEMMA2, + LLM_ARCH_GEMMA3, + LLM_ARCH_STARCODER2, + LLM_ARCH_MAMBA, + LLM_ARCH_XVERSE, + LLM_ARCH_COMMAND_R, + LLM_ARCH_DBRX, + LLM_ARCH_OLMO, + LLM_ARCH_OPENELM, + LLM_ARCH_ARCTIC, + LLM_ARCH_DEEPSEEK2, + LLM_ARCH_CHATGLM, + LLM_ARCH_GLM4, + LLM_ARCH_GLM4_MOE, + LLM_ARCH_BITNET, + LLM_ARCH_BITNET_25, + LLM_ARCH_BITNET_B158, + LLM_ARCH_T5, + LLM_ARCH_T5ENCODER, + LLM_ARCH_JAIS, + LLM_ARCH_GRANITE, + LLM_ARCH_GRANITE_MOE, + LLM_ARCH_COHERE2, + LLM_ARCH_DOTS1, + LLM_ARCH_HUNYUAN_MOE, + LLM_ARCH_OPENAI_MOE, + LLM_ARCH_UNKNOWN, +}; + +enum llm_kv { + LLM_KV_GENERAL_TYPE, + LLM_KV_GENERAL_ARCHITECTURE, + LLM_KV_GENERAL_QUANTIZATION_VERSION, + LLM_KV_GENERAL_ALIGNMENT, + LLM_KV_GENERAL_NAME, + LLM_KV_GENERAL_AUTHOR, + LLM_KV_GENERAL_VERSION, + LLM_KV_GENERAL_URL, + LLM_KV_GENERAL_DESCRIPTION, + LLM_KV_GENERAL_LICENSE, + LLM_KV_GENERAL_SOURCE_URL, + LLM_KV_GENERAL_SOURCE_HF_REPO, + + LLM_KV_VOCAB_SIZE, + LLM_KV_CONTEXT_LENGTH, + LLM_KV_EMBEDDING_LENGTH, + LLM_KV_BLOCK_COUNT, + LLM_KV_LEADING_DENSE_BLOCK_COUNT, + LLM_KV_FEED_FORWARD_LENGTH, + LLM_KV_EXPERT_FEED_FORWARD_LENGTH, + LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, + LLM_KV_USE_PARALLEL_RESIDUAL, + LLM_KV_TENSOR_DATA_LAYOUT, + LLM_KV_EXPERT_COUNT, + LLM_KV_EXPERT_USED_COUNT, + LLM_KV_EXPERT_SHARED_COUNT, + LLM_KV_EXPERT_WEIGHTS_SCALE, + LLM_KV_EXPERT_WEIGHTS_NORM, + LLM_KV_EXPERT_GATING_FUNC, + LLM_KV_NEXTN_PREDICT_LAYERS, + LLM_KV_POOLING_TYPE, + LLM_KV_LOGIT_SCALE, + LLM_KV_DECODER_START_TOKEN_ID, + LLM_KV_ATTN_LOGIT_SOFTCAPPING, + LLM_KV_FINAL_LOGIT_SOFTCAPPING, + LLM_KV_SWIN_NORM, + LLM_KV_RESCALE_EVERY_N_LAYERS, + LLM_KV_TIME_MIX_EXTRA_DIM, + LLM_KV_TIME_DECAY_EXTRA_DIM, + LLM_KV_RESIDUAL_SCALE, + LLM_KV_EMBEDDING_SCALE, + LLM_KV_TOKEN_SHIFT_COUNT, + LLM_KV_INTERLEAVE_MOE_LAYER_STEP, + + LLM_KV_ATTENTION_HEAD_COUNT, + LLM_KV_ATTENTION_HEAD_COUNT_KV, + LLM_KV_ATTENTION_MAX_ALIBI_BIAS, + LLM_KV_ATTENTION_CLAMP_KQV, + LLM_KV_ATTENTION_KEY_LENGTH, + LLM_KV_ATTENTION_VALUE_LENGTH, + LLM_KV_ATTENTION_LAYERNORM_EPS, + LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, + LLM_KV_ATTENTION_CAUSAL, + LLM_KV_ATTENTION_Q_LORA_RANK, + LLM_KV_ATTENTION_KV_LORA_RANK, + LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, + LLM_KV_ATTENTION_SLIDING_WINDOW, + LLM_KV_ATTENTION_SCALE, + + LLM_KV_ROPE_DIMENSION_COUNT, + LLM_KV_ROPE_FREQ_BASE, + LLM_KV_ROPE_SCALE_LINEAR, + LLM_KV_ROPE_SCALING_TYPE, + LLM_KV_ROPE_SCALING_FACTOR, + LLM_KV_ROPE_SCALING_ATTN_FACTOR, + LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, + LLM_KV_ROPE_SCALING_FINETUNED, + LLM_KV_ROPE_SCALING_YARN_LOG_MUL, + + LLM_KV_SPLIT_NO, + LLM_KV_SPLIT_COUNT, + LLM_KV_SPLIT_TENSORS_COUNT, + + LLM_KV_SSM_INNER_SIZE, + LLM_KV_SSM_CONV_KERNEL, + LLM_KV_SSM_STATE_SIZE, + LLM_KV_SSM_TIME_STEP_RANK, + + LLM_KV_TOKENIZER_MODEL, + 
LLM_KV_TOKENIZER_PRE, + LLM_KV_TOKENIZER_LIST, + LLM_KV_TOKENIZER_TOKEN_TYPE, + LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, + LLM_KV_TOKENIZER_SCORES, + LLM_KV_TOKENIZER_MERGES, + LLM_KV_TOKENIZER_BOS_ID, + LLM_KV_TOKENIZER_EOS_ID, + LLM_KV_TOKENIZER_UNK_ID, + LLM_KV_TOKENIZER_SEP_ID, + LLM_KV_TOKENIZER_PAD_ID, + LLM_KV_TOKENIZER_CLS_ID, + LLM_KV_TOKENIZER_MASK_ID, + LLM_KV_TOKENIZER_ADD_BOS, + LLM_KV_TOKENIZER_ADD_EOS, + LLM_KV_TOKENIZER_ADD_SEP, + LLM_KV_TOKENIZER_ADD_PREFIX, + LLM_KV_TOKENIZER_REMOVE_EXTRA_WS, + LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, + LLM_KV_TOKENIZER_HF_JSON, + LLM_KV_TOKENIZER_RWKV, + LLM_KV_TOKENIZER_CHAT_TEMPLATE, + LLM_KV_TOKENIZER_CHAT_TEMPLATE_N, + LLM_KV_TOKENIZER_FIM_PRE_ID, + LLM_KV_TOKENIZER_FIM_SUF_ID, + LLM_KV_TOKENIZER_FIM_MID_ID, + LLM_KV_TOKENIZER_FIM_PAD_ID, + LLM_KV_TOKENIZER_FIM_REP_ID, + LLM_KV_TOKENIZER_FIM_SEP_ID, + LLM_KV_TOKENIZER_PREFIX_ID, + LLM_KV_TOKENIZER_SUFFIX_ID, + LLM_KV_TOKENIZER_MIDDLE_ID, + LLM_KV_TOKENIZER_EOT_ID, + LLM_KV_TOKENIZER_EOM_ID, + + LLM_KV_ADAPTER_TYPE, + LLM_KV_ADAPTER_LORA_ALPHA, +}; + +struct LLM_KV { + LLM_KV(llm_arch arch, const char* suffix = nullptr); + + llm_arch arch; + const char* suffix; + std::string operator()(llm_kv kv) const; +}; + +enum llm_tensor { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_TOKEN_EMBD_NORM, + LLM_TENSOR_TOKEN_TYPES, + LLM_TENSOR_POS_EMBD, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_ROPE_FREQS, + LLM_TENSOR_ROPE_FACTORS_LONG, + LLM_TENSOR_ROPE_FACTORS_SHORT, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_QKV, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_NORM_2, + LLM_TENSOR_ATTN_OUT_NORM, + LLM_TENSOR_ATTN_POST_NORM, + LLM_TENSOR_ATTN_ROT_EMBD, + LLM_TENSOR_ATTN_SINKS, + LLM_TENSOR_FFN_GATE_INP, + LLM_TENSOR_FFN_GATE_INP_SHEXP, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_POST_NORM, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + LLM_TENSOR_FFN_ACT, + LLM_TENSOR_FFN_DOWN_EXP, // split experts for backward compatibility + LLM_TENSOR_FFN_GATE_EXP, + LLM_TENSOR_FFN_UP_EXP, + LLM_TENSOR_FFN_NORM_EXPS, + LLM_TENSOR_FFN_DOWN_EXPS, // merged experts + LLM_TENSOR_FFN_GATE_EXPS, + LLM_TENSOR_FFN_UP_EXPS, + LLM_TENSOR_FFN_DOWN_SHEXP, + LLM_TENSOR_FFN_GATE_SHEXP, + LLM_TENSOR_FFN_UP_SHEXP, + LLM_TENSOR_FFN_EXP_PROBS_B, + LLM_TENSOR_ATTN_Q_NORM, + LLM_TENSOR_ATTN_K_NORM, + LLM_TENSOR_LAYER_OUT_NORM, + LLM_TENSOR_SSM_IN, + LLM_TENSOR_SSM_CONV1D, + LLM_TENSOR_SSM_X, + LLM_TENSOR_SSM_DT, + LLM_TENSOR_SSM_A, + LLM_TENSOR_SSM_D, + LLM_TENSOR_SSM_OUT, + LLM_TENSOR_ATTN_Q_A, + LLM_TENSOR_ATTN_Q_B, + LLM_TENSOR_ATTN_KV_A_MQA, + LLM_TENSOR_ATTN_KV_B, + LLM_TENSOR_ATTN_K_B, + LLM_TENSOR_ATTN_V_B, + LLM_TENSOR_ATTN_Q_A_NORM, + LLM_TENSOR_ATTN_KV_A_NORM, + LLM_TENSOR_ATTN_SUB_NORM, + LLM_TENSOR_FFN_SUB_NORM, + LLM_TENSOR_DEC_ATTN_NORM, + LLM_TENSOR_DEC_ATTN_Q, + LLM_TENSOR_DEC_ATTN_K, + LLM_TENSOR_DEC_ATTN_V, + LLM_TENSOR_DEC_ATTN_OUT, + LLM_TENSOR_DEC_ATTN_REL_B, + LLM_TENSOR_DEC_CROSS_ATTN_NORM, + LLM_TENSOR_DEC_CROSS_ATTN_Q, + LLM_TENSOR_DEC_CROSS_ATTN_K, + LLM_TENSOR_DEC_CROSS_ATTN_V, + LLM_TENSOR_DEC_CROSS_ATTN_OUT, + LLM_TENSOR_DEC_CROSS_ATTN_REL_B, + LLM_TENSOR_DEC_FFN_NORM, + LLM_TENSOR_DEC_FFN_GATE, + LLM_TENSOR_DEC_FFN_DOWN, + LLM_TENSOR_DEC_FFN_UP, + LLM_TENSOR_DEC_OUTPUT_NORM, + LLM_TENSOR_ENC_ATTN_NORM, + LLM_TENSOR_ENC_ATTN_Q, + LLM_TENSOR_ENC_ATTN_K, + LLM_TENSOR_ENC_ATTN_V, + LLM_TENSOR_ENC_ATTN_OUT, + LLM_TENSOR_ENC_ATTN_REL_B, + LLM_TENSOR_ENC_FFN_NORM, + LLM_TENSOR_ENC_FFN_GATE, + LLM_TENSOR_ENC_FFN_DOWN, + 
LLM_TENSOR_ENC_FFN_UP, + LLM_TENSOR_ENC_OUTPUT_NORM, + LLM_TENSOR_NEXTN_EH_PROJ, + LLM_TENSOR_NEXTN_EMBED_TOKENS, + LLM_TENSOR_NEXTN_ENORM, + LLM_TENSOR_NEXTN_HNORM, + LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, + LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, +}; + +llm_arch llm_arch_from_string(const std::string & name); diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp index b123d7331..c5c29d217 100644 --- a/src/llama-grammar.cpp +++ b/src/llama-grammar.cpp @@ -486,9 +486,9 @@ void llama_grammar_sample_impl(const struct llama_grammar * grammar, const struc for (size_t i = 0; i < candidates->size; ++i) { const llama_token id = candidates->data[i].id; - const std::string & piece = vocab->cache_token_to_piece.at(id); + const std::string & piece = vocab->token_to_piece(id); - if (llama_token_is_eog_impl(*vocab, id)) { + if (vocab->is_eog(id)) { if (!allow_eog) { candidates->data[i].logit = -INFINITY; } @@ -511,7 +511,7 @@ void llama_grammar_sample_impl(const struct llama_grammar * grammar, const struc void llama_grammar_accept_token_impl(struct llama_grammar * grammar, const struct llama_vocab * vocab, const struct llama_sampling * smpl, llama_token token) { const int64_t t_start_sample_us = ggml_time_us(); - if (llama_token_is_eog_impl(*vocab, token)) { + if (vocab->is_eog(token)) { for (const auto & stack : grammar->stacks) { if (stack.empty()) { return; @@ -520,7 +520,7 @@ void llama_grammar_accept_token_impl(struct llama_grammar * grammar, const struc GGML_ABORT("fatal error"); } - const std::string & piece = vocab->cache_token_to_piece.at(token); + const std::string & piece = vocab->token_to_piece(token); // Note terminating 0 in decoded string const auto decoded = decode_utf8(piece, grammar->partial_utf8); diff --git a/src/llama-impl.h b/src/llama-impl.h index a50f60cfd..cd4e0730a 100644 --- a/src/llama-impl.h +++ b/src/llama-impl.h @@ -10,6 +10,11 @@ #define LLAMA_API_INTERNAL #include "llama.h" #include +#include +#include +#include +#include +#include #ifdef __GNUC__ #ifdef __MINGW32__ @@ -33,6 +38,7 @@ void llama_log_callback_default(ggml_log_level level, const char * text, void * #define LLAMA_LOG_INFO(...) llama_log_internal(GGML_LOG_LEVEL_INFO , __VA_ARGS__) #define LLAMA_LOG_WARN(...) llama_log_internal(GGML_LOG_LEVEL_WARN , __VA_ARGS__) #define LLAMA_LOG_ERROR(...) llama_log_internal(GGML_LOG_LEVEL_ERROR, __VA_ARGS__) +#define LLAMA_LOG_DEBUG(...) llama_log_internal(GGML_LOG_LEVEL_DEBUG, __VA_ARGS__) // // helpers @@ -166,3 +172,49 @@ struct ring_buffer { size_t pos = 0; std::vector data; }; + +LLAMA_ATTRIBUTE_FORMAT(1, 2) +static std::string format(const char * fmt, ...) 
{
+    va_list ap;
+    va_list ap2;
+    va_start(ap, fmt);
+    va_copy(ap2, ap);
+    int size = vsnprintf(NULL, 0, fmt, ap);
+    GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT
+    std::vector<char> buf(size + 1);
+    int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2);
+    GGML_ASSERT(size2 == size);
+    va_end(ap2);
+    va_end(ap);
+    return std::string(buf.data(), size);
+}
+
+static std::string llama_format_tensor_shape(const std::vector<int64_t> & ne) {
+    char buf[256];
+    snprintf(buf, sizeof(buf), "%5" PRId64, ne.at(0));
+    for (size_t i = 1; i < ne.size(); i++) {
+        snprintf(buf + strlen(buf), sizeof(buf) - strlen(buf), ", %5" PRId64, ne.at(i));
+    }
+    return buf;
+}
+
+static std::string llama_format_tensor_shape(const struct ggml_tensor * t) {
+    char buf[256];
+    snprintf(buf, sizeof(buf), "%5" PRId64, t->ne[0]);
+    for (int i = 1; i < GGML_MAX_DIMS; i++) {
+        snprintf(buf + strlen(buf), sizeof(buf) - strlen(buf), ", %5" PRId64, t->ne[i]);
+    }
+    return buf;
+}
+
+template <typename T>
+struct no_init {
+    T value;
+    no_init() { /* do nothing */ }
+};
+
+
+struct gguf_context;
+std::string gguf_kv_to_str(const gguf_context * ctx_gguf, int i);
+
+ggml_backend_buffer_type_t llama_default_buffer_type_cpu(bool host_buffer);
diff --git a/src/llama-mmap.cpp b/src/llama-mmap.cpp
new file mode 100644
index 000000000..4a65c9cb6
--- /dev/null
+++ b/src/llama-mmap.cpp
@@ -0,0 +1,650 @@
+#include "llama-mmap.h"
+
+#include "llama-impl.h"
+
+#include "ggml.h"
+
+#include <cstring>
+#include <climits>
+#include <stdexcept>
+#include <cerrno>
+#include <algorithm>
+#include <fstream>
+#include <sstream>
+
+#ifdef __has_include
+    #if __has_include(<unistd.h>)
+        #include <unistd.h>
+        #if defined(_POSIX_MAPPED_FILES)
+            #include <sys/mman.h>
+            #include <fcntl.h>
+        #endif
+        #if defined(_POSIX_MEMLOCK_RANGE)
+            #include <sys/resource.h>
+        #endif
+    #endif
+#endif
+
+#if defined(_WIN32)
+    #define WIN32_LEAN_AND_MEAN
+    #ifndef NOMINMAX
+        #define NOMINMAX
+    #endif
+    #include <windows.h>
+    #ifndef PATH_MAX
+        #define PATH_MAX MAX_PATH
+    #endif
+    #include <io.h>
+#endif
+
+#if defined(__APPLE__)
+#include <TargetConditionals.h>
+#endif
+
+// TODO: consider moving to llama-impl.h if needed in more places
+#if defined(_WIN32)
+static std::string llama_format_win_err(DWORD err) {
+    LPSTR buf;
+    size_t size = FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS,
+                                 NULL, err, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&buf, 0, NULL);
+    if (!size) {
+        return "FormatMessageA failed";
+    }
+    std::string ret(buf, size);
+    LocalFree(buf);
+    return ret;
+}
+#endif
+
+// llama_file
+
+struct llama_file::impl {
+#if defined(_WIN32)
+    HANDLE fp_win32;
+    std::string GetErrorMessageWin32(DWORD error_code) const {
+        std::string ret;
+        LPSTR lpMsgBuf = NULL;
+        DWORD bufLen = FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS,
+                                      NULL, error_code, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&lpMsgBuf, 0, NULL);
+        if (!bufLen) {
+            ret = format("Win32 error code: %lx", error_code);
+        } else {
+            ret = lpMsgBuf;
+            LocalFree(lpMsgBuf);
+        }
+
+        return ret;
+    }
+
+    impl(const char * fname, const char * mode) {
+        fp = ggml_fopen(fname, mode);
+        if (fp == NULL) {
+            throw std::runtime_error(format("failed to open %s: %s", fname, strerror(errno)));
+        }
+        fp_win32 = (HANDLE) _get_osfhandle(_fileno(fp));
+        seek(0, SEEK_END);
+        size = tell();
+        seek(0, SEEK_SET);
+    }
+
+    size_t tell() const {
+        LARGE_INTEGER li;
+        li.QuadPart = 0;
+        BOOL ret = SetFilePointerEx(fp_win32, li, &li, FILE_CURRENT);
+        if (!ret) {
+            throw std::runtime_error(format("read error: %s", GetErrorMessageWin32(GetLastError()).c_str()));
+        }
+
+        return li.QuadPart;
+    }
+
+    void seek(size_t offset, int whence) const {
+        static_assert(SEEK_SET == FILE_BEGIN, "SEEK_SET != FILE_BEGIN");
+        static_assert(SEEK_CUR == FILE_CURRENT, "SEEK_CUR != FILE_CURRENT");
+        static_assert(SEEK_END == FILE_END, "SEEK_END != FILE_END");
+
+        LARGE_INTEGER li;
+        li.QuadPart = offset;
+        BOOL ret = SetFilePointerEx(fp_win32, li, NULL, whence);
+        if (!ret) {
+            throw std::runtime_error(format("read error: %s", GetErrorMessageWin32(GetLastError()).c_str()));
+        }
+    }
+
+    void read_raw(void * ptr, size_t len) const {
+        size_t bytes_read = 0;
+        while (bytes_read < len) {
+            size_t chunk_size = std::min<size_t>(len - bytes_read, 64*1024*1024);
+            DWORD chunk_read = 0;
+            BOOL result = ReadFile(fp_win32, reinterpret_cast<char *>(ptr) + bytes_read, chunk_size, &chunk_read, NULL);
+            if (!result) {
+                throw std::runtime_error(format("read error: %s", GetErrorMessageWin32(GetLastError()).c_str()));
+            }
+            if (chunk_read < chunk_size || chunk_read == 0) {
+                throw std::runtime_error("unexpectedly reached end of file");
+            }
+
+            bytes_read += chunk_read;
+        }
+    }
+
+    uint32_t read_u32() const {
+        uint32_t val;
+        read_raw(&val, sizeof(val));
+        return val;
+    }
+
+    void write_raw(const void * ptr, size_t len) const {
+        size_t bytes_written = 0;
+        while (bytes_written < len) {
+            size_t chunk_size = std::min<size_t>(len - bytes_written, 64*1024*1024);
+            DWORD chunk_written = 0;
+            BOOL result = WriteFile(fp_win32, reinterpret_cast<const char *>(ptr) + bytes_written, chunk_size, &chunk_written, NULL);
+            if (!result) {
+                throw std::runtime_error(format("write error: %s", GetErrorMessageWin32(GetLastError()).c_str()));
+            }
+            if (chunk_written < chunk_size || chunk_written == 0) {
+                throw std::runtime_error("unexpectedly failed to write bytes");
+            }
+
+            bytes_written += chunk_written;
+        }
+    }
+
+    void write_u32(uint32_t val) const {
+        write_raw(&val, sizeof(val));
+    }
+
+    ~impl() {
+        if (fp) {
+            std::fclose(fp);
+        }
+    }
+#else
+    impl(const char * fname, const char * mode) {
+        fp = ggml_fopen(fname, mode);
+        if (fp == NULL) {
+            throw std::runtime_error(format("failed to open %s: %s", fname, strerror(errno)));
+        }
+        seek(0, SEEK_END);
+        size = tell();
+        seek(0, SEEK_SET);
+    }
+
+    size_t tell() const {
+// TODO: this ifdef is never true?
+#ifdef _WIN32
+        __int64 ret = _ftelli64(fp);
+#else
+        long ret = std::ftell(fp);
+#endif
+        if (ret == -1) {
+            throw std::runtime_error(format("ftell error: %s", strerror(errno)));
+        }
+
+        return (size_t) ret;
+    }
+
+    void seek(size_t offset, int whence) const {
+// TODO: this ifdef is never true?
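+//       (this branch is compiled under the #else of the outer `#if defined(_WIN32)` above,
+//       so the inner _WIN32 path below appears to be dead code)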
+#ifdef _WIN32
+        int ret = _fseeki64(fp, (__int64) offset, whence);
+#else
+        int ret = std::fseek(fp, (long) offset, whence);
+#endif
+        if (ret != 0) {
+            throw std::runtime_error(format("seek error: %s", strerror(errno)));
+        }
+    }
+
+    void read_raw(void * ptr, size_t len) const {
+        if (len == 0) {
+            return;
+        }
+        errno = 0;
+        std::size_t ret = std::fread(ptr, len, 1, fp);
+        if (ferror(fp)) {
+            throw std::runtime_error(format("read error: %s", strerror(errno)));
+        }
+        if (ret != 1) {
+            throw std::runtime_error("unexpectedly reached end of file");
+        }
+    }
+
+    uint32_t read_u32() const {
+        uint32_t ret;
+        read_raw(&ret, sizeof(ret));
+        return ret;
+    }
+
+    void write_raw(const void * ptr, size_t len) const {
+        if (len == 0) {
+            return;
+        }
+        errno = 0;
+        size_t ret = std::fwrite(ptr, len, 1, fp);
+        if (ret != 1) {
+            throw std::runtime_error(format("write error: %s", strerror(errno)));
+        }
+    }
+
+    void write_u32(uint32_t val) const {
+        write_raw(&val, sizeof(val));
+    }
+
+    ~impl() {
+        if (fp) {
+            std::fclose(fp);
+        }
+    }
+#endif
+
+    FILE * fp;
+    size_t size;
+};
+
+llama_file::llama_file(const char * fname, const char * mode) : pimpl(std::make_unique<impl>(fname, mode)) {}
+llama_file::~llama_file() = default;
+
+size_t llama_file::tell() const { return pimpl->tell(); }
+size_t llama_file::size() const { return pimpl->size; }
+
+int llama_file::file_id() const {
+#ifdef _WIN32
+    return _fileno(pimpl->fp);
+#else
+#if defined(fileno)
+    return fileno(pimpl->fp);
+#else
+    return ::fileno(pimpl->fp);
+#endif
+#endif
+}
+
+void llama_file::seek(size_t offset, int whence) const { pimpl->seek(offset, whence); }
+void llama_file::read_raw(void * ptr, size_t len) const { pimpl->read_raw(ptr, len); }
+
+uint32_t llama_file::read_u32() const { return pimpl->read_u32(); }
+
+void llama_file::write_raw(const void * ptr, size_t len) const { pimpl->write_raw(ptr, len); }
+void llama_file::write_u32(uint32_t val) const { pimpl->write_u32(val); }
+
+// llama_mmap
+
+struct llama_mmap::impl {
+#ifdef _POSIX_MAPPED_FILES
+    std::vector<std::pair<size_t, size_t>> mapped_fragments;
+
+    impl(struct llama_file * file, size_t prefetch, bool numa, bool use_thp) {
+        size = file->size();
+        int fd = file->file_id();
+        int flags = MAP_SHARED;
+        if (numa) { prefetch = 0; }
+#ifdef __linux__
+        if (posix_fadvise(fd, 0, 0, POSIX_FADV_SEQUENTIAL)) {
+            LLAMA_LOG_WARN("warning: posix_fadvise(.., POSIX_FADV_SEQUENTIAL) failed: %s\n",
+                    strerror(errno));
+        }
+        if (prefetch) { flags |= MAP_POPULATE; }
+        if (use_thp) {
+            size_t huge = get_default_huge_page_size();
+            auto size = huge*((file->size() + huge - 1)/huge);
+            addr = mmap(nullptr, size, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_ANONYMOUS | MAP_HUGETLB, -1, 0);
+            if (addr != MAP_FAILED) {
+                printf("%s: using THP with page size %zu MiB ", __func__, huge/(1024*1024));
+                fflush(stdout);
+                size_t tot = 0;
+                while (tot < file->size()) {
+                    auto n_read = pread(fd, static_cast<char *>(addr) + tot, file->size() - tot, tot);
+                    if (n_read < 0) throw std::runtime_error(format("Reading into mapped huge pages failed at %zu (%s)", tot, strerror(errno)));
+                    printf("."); fflush(stdout);
+                    tot += n_read;
+                }
+                printf(" done\n");
+                mapped_fragments.emplace_back(0, file->size());
+                mapped_page_size = huge;
+                return;
+            }
+            else {
+                fprintf(stderr, "%s: mmap with huge page size %zu MiB failed (%s)\n", __func__, huge/(1024*1024), strerror(errno));
+            }
+        }
+#endif
+        addr = mmap(NULL, file->size(), PROT_READ, flags, fd, 0);
+        if (addr == MAP_FAILED) {
+            throw std::runtime_error(format("mmap failed: %s", strerror(errno)));
+        }
+
+        if (prefetch > 0) {
+            if (posix_madvise(addr, std::min(file->size(), prefetch), POSIX_MADV_WILLNEED)) {
+                LLAMA_LOG_WARN("warning: posix_madvise(.., POSIX_MADV_WILLNEED) failed: %s\n",
+                        strerror(errno));
+            }
+        }
+        if (numa) {
+            if (posix_madvise(addr, file->size(), POSIX_MADV_RANDOM)) {
+                LLAMA_LOG_WARN("warning: posix_madvise(.., POSIX_MADV_RANDOM) failed: %s\n",
+                        strerror(errno));
+            }
+        }
+
+        mapped_fragments.emplace_back(0, file->size());
+    }
+
+#ifdef __linux__
+    static int get_default_huge_page_size() {
+        int pg_size = 2048;
+        std::ifstream in("/proc/meminfo");
+        if (in) {
+            std::string line;
+            while (true) {
+                std::getline(in, line);
+                if (in.fail()) break;
+                if (auto pos = line.find("Hugepagesize:"); pos != std::string::npos) {
+                    std::istringstream str(line.data() + pos + 13);
+                    int aux;
+                    str >> aux;
+                    if (!str.fail()) pg_size = aux;
+                    break;
+                }
+            }
+        }
+        return pg_size * 1024;
+    }
+#endif
+
+
+    static void align_range(size_t * first, size_t * last, size_t page_size) {
+        size_t offset_in_page = *first & (page_size - 1);
+        size_t offset_to_page = offset_in_page == 0 ? 0 : page_size - offset_in_page;
+        *first += offset_to_page;
+
+        *last = *last & ~(page_size - 1);
+
+        if (*last <= *first) {
+            *last = *first;
+        }
+    }
+
+    void unmap_fragment(size_t first, size_t last) {
+        int page_size = mapped_page_size > 0 ? mapped_page_size : sysconf(_SC_PAGESIZE);
+        align_range(&first, &last, page_size);
+        size_t len = last - first;
+
+        if (len == 0) {
+            return;
+        }
+
+        GGML_ASSERT(first % page_size == 0);
+        GGML_ASSERT(last % page_size == 0);
+        GGML_ASSERT(last > first);
+
+        void * next_page_start = (uint8_t *) addr + first;
+
+        if (munmap(next_page_start, len)) {
+            LLAMA_LOG_WARN("warning: munmap failed: %s\n", strerror(errno));
+        }
+
+        std::vector<std::pair<size_t, size_t>> new_mapped_fragments;
+        for (const auto & frag : mapped_fragments) {
+            if (frag.first < first && frag.second > last) {
+                new_mapped_fragments.emplace_back(frag.first, first);
+                new_mapped_fragments.emplace_back(last, frag.second);
+            } else if (frag.first < first && frag.second > first) {
+                new_mapped_fragments.emplace_back(frag.first, first);
+            } else if (frag.first < last && frag.second > last) {
+                new_mapped_fragments.emplace_back(last, frag.second);
+            } else if (frag.first >= first && frag.second <= last) {
+            } else {
+                new_mapped_fragments.push_back(frag);
+            }
+        }
+        mapped_fragments = std::move(new_mapped_fragments);
+    }
+
+    ~impl() {
+        for (const auto & frag : mapped_fragments) {
+            if (munmap((char *) addr + frag.first, frag.second - frag.first)) {
+                LLAMA_LOG_WARN("warning: munmap failed: %s\n", strerror(errno));
+            }
+        }
+    }
+#elif defined(_WIN32)
+    impl(struct llama_file * file, size_t prefetch, bool numa, [[maybe_unused]] bool use_thp) {
+        GGML_UNUSED(numa);
+
+        size = file->size();
+
+        HANDLE hFile = (HANDLE) _get_osfhandle(file->file_id());
+
+        HANDLE hMapping = CreateFileMappingA(hFile, NULL, PAGE_READONLY, 0, 0, NULL);
+
+        if (hMapping == NULL) {
+            DWORD error = GetLastError();
+            throw std::runtime_error(format("CreateFileMappingA failed: %s", llama_format_win_err(error).c_str()));
+        }
+
+        addr = MapViewOfFile(hMapping, FILE_MAP_READ, 0, 0, 0);
+        DWORD error = GetLastError();
+        CloseHandle(hMapping);
+
+        if (addr == NULL) {
+            throw std::runtime_error(format("MapViewOfFile failed: %s", llama_format_win_err(error).c_str()));
+        }
+
+        if (prefetch > 0) {
+#if _WIN32_WINNT >= 0x602
+            BOOL (WINAPI *pPrefetchVirtualMemory) (HANDLE, ULONG_PTR, PWIN32_MEMORY_RANGE_ENTRY, ULONG);
+            HMODULE hKernel32 = GetModuleHandleW(L"kernel32.dll");
+
+            pPrefetchVirtualMemory = (decltype(pPrefetchVirtualMemory))(void *) GetProcAddress(hKernel32, "PrefetchVirtualMemory");
+
+            if (pPrefetchVirtualMemory) {
+                WIN32_MEMORY_RANGE_ENTRY range;
+                range.VirtualAddress = addr;
+                range.NumberOfBytes = (SIZE_T) std::min(size, prefetch);
+                if (!pPrefetchVirtualMemory(GetCurrentProcess(), 1, &range, 0)) {
+                    LLAMA_LOG_WARN("warning: PrefetchVirtualMemory failed: %s\n",
+                            llama_format_win_err(GetLastError()).c_str());
+                }
+            }
+#else
+            LLAMA_LOG_DEBUG("skipping PrefetchVirtualMemory because _WIN32_WINNT < 0x602\n");
+#endif
+        }
+    }
+
+    void unmap_fragment(size_t first, size_t last) {
+        GGML_UNUSED(first);
+        GGML_UNUSED(last);
+    }
+
+    ~impl() {
+        if (!UnmapViewOfFile(addr)) {
+            LLAMA_LOG_WARN("warning: UnmapViewOfFile failed: %s\n",
+                    llama_format_win_err(GetLastError()).c_str());
+        }
+    }
+#else
+    impl(struct llama_file * file, size_t prefetch, bool numa, [[maybe_unused]] bool use_thp) {
+        GGML_UNUSED(file);
+        GGML_UNUSED(prefetch);
+        GGML_UNUSED(numa);
+
+        throw std::runtime_error("mmap not supported");
+    }
+
+    void unmap_fragment(size_t first, size_t last) {
+        GGML_UNUSED(first);
+        GGML_UNUSED(last);
+
+        throw std::runtime_error("mmap not supported");
+    }
+#endif
+
+    void * addr;
+    size_t size;
+    size_t mapped_page_size = 0;
+};
+
+llama_mmap::llama_mmap(struct llama_file * file, size_t prefetch, bool numa, bool use_thp) :
+    pimpl(std::make_unique<impl>(file, prefetch, numa, use_thp)) {}
+llama_mmap::~llama_mmap() = default;
+
+size_t llama_mmap::size() const { return pimpl->size; }
+void * llama_mmap::addr() const { return pimpl->addr; }
+
+void llama_mmap::unmap_fragment(size_t first, size_t last) { pimpl->unmap_fragment(first, last); }
+
+#if defined(_POSIX_MAPPED_FILES) || defined(_WIN32)
+const bool llama_mmap::SUPPORTED = true;
+#else
+const bool llama_mmap::SUPPORTED = false;
+#endif
+
+// llama_mlock
+
+struct llama_mlock::impl {
+#ifdef _POSIX_MEMLOCK_RANGE
+    static size_t lock_granularity() {
+        return (size_t) sysconf(_SC_PAGESIZE);
+    }
+
+    bool raw_lock(const void * addr, size_t size) const {
+        if (!mlock(addr, size)) {
+            return true;
+        }
+
+#ifdef __APPLE__
+#define MLOCK_SUGGESTION \
+        "Try increasing the sysctl values 'vm.user_wire_limit' and 'vm.global_user_wire_limit' and/or " \
+        "decreasing 'vm.global_no_user_wire_amount'. Also try increasing RLIMIT_MEMLOCK (ulimit -l).\n"
+#else
+#define MLOCK_SUGGESTION \
+        "Try increasing RLIMIT_MEMLOCK ('ulimit -l' as root).\n"
+#endif
+
+        char* errmsg = std::strerror(errno);
+        bool suggest = (errno == ENOMEM);
+#if defined(TARGET_OS_VISION) || defined(TARGET_OS_TV) || defined(_AIX)
+        // visionOS/tvOS don't support RLIMIT_MEMLOCK
+        // Skip resource limit checks on visionOS/tvOS
+        suggest = false;
+#else
+        struct rlimit lock_limit;
+        if (suggest && getrlimit(RLIMIT_MEMLOCK, &lock_limit)) {
+            suggest = false;
+        }
+        if (suggest && (lock_limit.rlim_max > lock_limit.rlim_cur + size)) {
+            suggest = false;
+        }
+#endif
+
+        LLAMA_LOG_WARN("warning: failed to mlock %zu-byte buffer (after previously locking %zu bytes): %s\n%s",
+                size, this->size, errmsg, suggest ? MLOCK_SUGGESTION : "");
+        return false;
+    }
+
+    static void raw_unlock(void * addr, size_t size) {
+        if (munlock(addr, size)) {
+            LLAMA_LOG_WARN("warning: failed to munlock buffer: %s\n", std::strerror(errno));
+        }
+    }
+#elif defined(_WIN32)
+    static size_t lock_granularity() {
+        SYSTEM_INFO si;
+        GetSystemInfo(&si);
+        return (size_t) si.dwPageSize;
+    }
+
+    bool raw_lock(void * ptr, size_t len) const {
+        for (int tries = 1; ; tries++) {
+            if (VirtualLock(ptr, len)) {
+                return true;
+            }
+            if (tries == 2) {
+                LLAMA_LOG_WARN("warning: failed to VirtualLock %zu-byte buffer (after previously locking %zu bytes): %s\n",
+                        len, size, llama_format_win_err(GetLastError()).c_str());
+                return false;
+            }
+
+            SIZE_T min_ws_size, max_ws_size;
+            if (!GetProcessWorkingSetSize(GetCurrentProcess(), &min_ws_size, &max_ws_size)) {
+                LLAMA_LOG_WARN("warning: GetProcessWorkingSetSize failed: %s\n",
+                        llama_format_win_err(GetLastError()).c_str());
+                return false;
+            }
+            size_t increment = len + 1048576;
+            min_ws_size += increment;
+            max_ws_size += increment;
+            if (!SetProcessWorkingSetSize(GetCurrentProcess(), min_ws_size, max_ws_size)) {
+                LLAMA_LOG_WARN("warning: SetProcessWorkingSetSize failed: %s\n",
+                        llama_format_win_err(GetLastError()).c_str());
+                return false;
+            }
+        }
+    }
+
+    static void raw_unlock(void * ptr, size_t len) {
+        if (!VirtualUnlock(ptr, len)) {
+            LLAMA_LOG_WARN("warning: failed to VirtualUnlock buffer: %s\n",
+                    llama_format_win_err(GetLastError()).c_str());
+        }
+    }
+#else
+    static size_t lock_granularity() {
+        return (size_t) 65536;
+    }
+
+    bool raw_lock(const void * addr, size_t len) const {
+        LLAMA_LOG_WARN("warning: mlock not supported on this system\n");
+        return false;
+    }
+
+    static void raw_unlock(const void * addr, size_t len) {}
+#endif
+
+    impl() : addr(NULL), size(0), failed_already(false) {}
+
+    void init(void * ptr) {
+        GGML_ASSERT(addr == NULL && size == 0);
+        addr = ptr;
+    }
+
+    void grow_to(size_t target_size) {
+        GGML_ASSERT(addr);
+        if (failed_already) {
+            return;
+        }
+        size_t granularity = lock_granularity();
+        target_size = (target_size + granularity - 1) & ~(granularity - 1);
+        if (target_size > size) {
+            if (raw_lock((uint8_t *) addr + size, target_size - size)) {
+                size = target_size;
+            } else {
+                failed_already = true;
+            }
+        }
+    }
+
+    void * addr;
+    size_t size;
+
+    bool failed_already;
+};
+
+llama_mlock::llama_mlock() : pimpl(std::make_unique<impl>()) {}
+llama_mlock::~llama_mlock() = default;
+
+void llama_mlock::init(void * ptr) { pimpl->init(ptr); }
+void llama_mlock::grow_to(size_t target_size) { pimpl->grow_to(target_size); }
+
+#if defined(_POSIX_MEMLOCK_RANGE) || defined(_WIN32)
+const bool llama_mlock::SUPPORTED = true;
+#else
+const bool llama_mlock::SUPPORTED = false;
+#endif
+
+size_t llama_path_max() {
+    return PATH_MAX;
+}
diff --git a/src/llama-mmap.h b/src/llama-mmap.h
new file mode 100644
index 000000000..a1efa068f
--- /dev/null
+++ b/src/llama-mmap.h
@@ -0,0 +1,68 @@
+#pragma once
+
+#include <cstdint>
+#include <memory>
+#include <vector>
+
+struct llama_file;
+struct llama_mmap;
+struct llama_mlock;
+
+using llama_files  = std::vector<std::unique_ptr<llama_file>>;
+using llama_mmaps  = std::vector<std::unique_ptr<llama_mmap>>;
+using llama_mlocks = std::vector<std::unique_ptr<llama_mlock>>;
+
+struct llama_file {
+    llama_file(const char * fname, const char * mode);
+    ~llama_file();
+
+    size_t tell() const;
+    size_t size() const;
+
+    int file_id() const; // fileno overload
+
+    void seek(size_t offset, int whence) const;
+
+    void read_raw(void * ptr, size_t len) const;
+    uint32_t read_u32() const;
+
+    void write_raw(const void * ptr, size_t len) const;
+    void write_u32(uint32_t val) const;
+
+private:
+    struct impl;
+    std::unique_ptr<impl> pimpl;
+};
+
+struct llama_mmap {
+    llama_mmap(const llama_mmap &) = delete;
+    llama_mmap(struct llama_file * file, size_t prefetch = (size_t) -1, bool numa = false, bool use_thp = false);
+    ~llama_mmap();
+
+    size_t size() const;
+    void * addr() const;
+
+    void unmap_fragment(size_t first, size_t last);
+
+    static const bool SUPPORTED;
+
+private:
+    struct impl;
+    std::unique_ptr<impl> pimpl;
+};
+
+struct llama_mlock {
+    llama_mlock();
+    ~llama_mlock();
+
+    void init(void * ptr);
+    void grow_to(size_t target_size);
+
+    static const bool SUPPORTED;
+
+private:
+    struct impl;
+    std::unique_ptr<impl> pimpl;
+};
+
+size_t llama_path_max();
diff --git a/src/llama-model-loader.cpp b/src/llama-model-loader.cpp
new file mode 100644
index 000000000..915764615
--- /dev/null
+++ b/src/llama-model-loader.cpp
@@ -0,0 +1,1082 @@
+#include "llama-model-loader.h"
+#include "llama-impl.h"
+#include "llama-mmap.h"
+#include "ggml.h"
+//#include "ggml-backend.h"
+
+#ifdef GGML_USE_CUDA
+# include "ggml-cuda.h"
+#elif defined(GGML_USE_VULKAN)
+# include "ggml-vulkan.h"
+#elif defined(GGML_USE_SYCL)
+# include "ggml-sycl.h"
+#elif defined(GGML_USE_KOMPUTE)
+# include "ggml-kompute.h"
+#elif defined(GGML_USE_CANN)
+# include "ggml-cann.h"
+#endif
+
+#include <array>
+#include <cinttypes>
+#include <cstring>
+#include <future>
+
+#if defined(_WIN32)
+    #define WIN32_LEAN_AND_MEAN
+    #ifndef NOMINMAX
+        #define NOMINMAX
+    #endif
+    #include <windows.h>
+    #ifndef PATH_MAX
+        #define PATH_MAX MAX_PATH
+    #endif
+    #include <io.h>
+#endif
+
+#define LLAMA_API_INTERNAL
+
+namespace GGUFMeta {
+    template <typename T, gguf_type gt_, T (*gfun)(const gguf_context *, const int)>
+    struct GKV_Base_Type {
+        static constexpr gguf_type gt = gt_;
+
+        static T getter(const gguf_context * ctx, const int kid) {
+            return gfun(ctx, kid);
+        }
+    };
+
+    template<typename T> struct GKV_Base;
+
+    template<> struct GKV_Base: GKV_Base_Type {};
+    template<> struct GKV_Base: GKV_Base_Type {};
+    template<> struct GKV_Base: GKV_Base_Type {};
+    template<> struct GKV_Base: GKV_Base_Type {};
+    template<> struct GKV_Base: GKV_Base_Type {};
+    template<> struct GKV_Base: GKV_Base_Type {};
+    template<> struct GKV_Base: GKV_Base_Type {};
+    template<> struct GKV_Base: GKV_Base_Type {};
+    template<> struct GKV_Base: GKV_Base_Type {};
+    template<> struct GKV_Base: GKV_Base_Type {};
+    template<> struct GKV_Base: GKV_Base_Type {};
+    template<> struct GKV_Base: GKV_Base_Type {};
+
+    template<> struct GKV_Base<std::string> {
+        static constexpr gguf_type gt = GGUF_TYPE_STRING;
+
+        static std::string getter(const gguf_context * ctx, const int kid) {
+            return gguf_get_val_str(ctx, kid);
+        }
+    };
+
+    struct ArrayInfo {
+        const gguf_type gt;
+        const size_t length;
+        const void * data;
+    };
+
+    template<> struct GKV_Base<ArrayInfo> {
+        public:
+        static constexpr gguf_type gt = GGUF_TYPE_ARRAY;
+        static ArrayInfo getter(const gguf_context *ctx, const int k) {
+            return ArrayInfo {
+                gguf_get_arr_type(ctx, k),
+                size_t(gguf_get_arr_n(ctx, k)),
+                gguf_get_arr_data(ctx, k),
+            };
+        }
+    };
+
+    template<typename T>
+    class GKV : public GKV_Base<T> {
+        GKV() = delete;
+
+        public:
+        static T get_kv(const gguf_context * ctx, const int k) {
+            const enum gguf_type kt = gguf_get_kv_type(ctx, k);
+
+            if (kt != GKV::gt) {
+                throw std::runtime_error(format("key %s has wrong type %s but expected type %s",
+                    gguf_get_key(ctx, k), gguf_type_name(kt), gguf_type_name(GKV::gt)));
+            }
+            return GKV::getter(ctx, k);
+        }
+
+        static const char * override_type_to_str(const llama_model_kv_override_type ty) {
+            switch (ty) {
+                case LLAMA_KV_OVERRIDE_TYPE_BOOL:  return "bool";
+                case LLAMA_KV_OVERRIDE_TYPE_INT:   return "int";
+                case LLAMA_KV_OVERRIDE_TYPE_FLOAT: return "float";
+                case LLAMA_KV_OVERRIDE_TYPE_STR:   return "str";
+            }
+            return "unknown";
+        }
+
+        static bool validate_override(const llama_model_kv_override_type expected_type, const struct llama_model_kv_override * ovrd) {
+            if (!ovrd) { return false; }
+            if (ovrd->tag == expected_type) {
+                LLAMA_LOG_INFO("%s: Using metadata override (%5s) '%s' = ",
+                    __func__, override_type_to_str(ovrd->tag), ovrd->key);
+                switch (ovrd->tag) {
+                    case LLAMA_KV_OVERRIDE_TYPE_BOOL:  {
+                        LLAMA_LOG_INFO("%s\n", ovrd->val_bool ? "true" : "false");
+                    } break;
+                    case LLAMA_KV_OVERRIDE_TYPE_INT:   {
+                        LLAMA_LOG_INFO("%" PRId64 "\n", ovrd->val_i64);
+                    } break;
+                    case LLAMA_KV_OVERRIDE_TYPE_FLOAT: {
+                        LLAMA_LOG_INFO("%.6f\n", ovrd->val_f64);
+                    } break;
+                    case LLAMA_KV_OVERRIDE_TYPE_STR:   {
+                        LLAMA_LOG_INFO("%s\n", ovrd->val_str);
+                    } break;
+                    default:
+                        // Shouldn't be possible to end up here, but just in case...
+                        throw std::runtime_error(
+                            format("Unsupported attempt to override %s type for metadata key %s\n",
+                                override_type_to_str(ovrd->tag), ovrd->key));
+                }
+                return true;
+            }
+            LLAMA_LOG_WARN("%s: Warning: Bad metadata override type for key '%s', expected %s but got %s\n",
+                __func__, ovrd->key, override_type_to_str(expected_type), override_type_to_str(ovrd->tag));
+            return false;
+        }
+
+        template<typename OT>
+        static typename std::enable_if<std::is_same<OT, bool>::value, bool>::type
+        try_override(OT & target, const struct llama_model_kv_override * ovrd) {
+            if (validate_override(LLAMA_KV_OVERRIDE_TYPE_BOOL, ovrd)) {
+                target = ovrd->val_bool;
+                return true;
+            }
+            return false;
+        }
+
+        template<typename OT>
+        static typename std::enable_if<!std::is_same<OT, bool>::value && std::is_integral<OT>::value, bool>::type
+        try_override(OT & target, const struct llama_model_kv_override * ovrd) {
+            if (validate_override(LLAMA_KV_OVERRIDE_TYPE_INT, ovrd)) {
+                target = ovrd->val_i64;
+                return true;
+            }
+            return false;
+        }
+
+        template<typename OT>
+        static typename std::enable_if<std::is_floating_point<OT>::value, bool>::type
+        try_override(T & target, const struct llama_model_kv_override * ovrd) {
+            if (validate_override(LLAMA_KV_OVERRIDE_TYPE_FLOAT, ovrd)) {
+                target = ovrd->val_f64;
+                return true;
+            }
+            return false;
+        }
+
+        template<typename OT>
+        static typename std::enable_if<std::is_same<OT, std::string>::value, bool>::type
+        try_override(T & target, const struct llama_model_kv_override * ovrd) {
+            if (validate_override(LLAMA_KV_OVERRIDE_TYPE_STR, ovrd)) {
+                target = ovrd->val_str;
+                return true;
+            }
+            return false;
+        }
+
+        static bool set(const gguf_context * ctx, const int k, T & target, const struct llama_model_kv_override * ovrd = nullptr) {
+            if (try_override<T>(target, ovrd)) {
+                return true;
+            }
+            if (k < 0) { return false; }
+            target = get_kv(ctx, k);
+            return true;
+        }
+
+        static bool set(const gguf_context * ctx, const char * key, T & target, const struct llama_model_kv_override * ovrd = nullptr) {
+            return set(ctx, gguf_find_key(ctx, key), target, ovrd);
+        }
+
+        static bool set(const gguf_context * ctx, const std::string & key, T & target, const struct llama_model_kv_override * ovrd = nullptr) {
+            return set(ctx, key.c_str(), target, ovrd);
+        }
+    };
+}
+
+llama_model_loader::llama_model_loader(const std::string & fname, bool use_mmap, bool check_tensors, bool repack_tensors, bool use_thp,
+                                       const llama_model_kv_override * param_overrides_p,
+                                       const llama_model_tensor_buft_override * param_tensor_buft_overrides_p) {
+    int trace = 0;
+    if (getenv("LLAMA_TRACE")) {
+        trace = atoi(getenv("LLAMA_TRACE"));
+    }
+
+#ifdef _WIN32
+    // Only bump maxstdio if the user really wants large contexts:
+#if
defined(GGML_MAX_CONTEXTS) && (GGML_MAX_CONTEXTS > 512) + // Cap at MSVC's hard limit of 8192 - https://learn.microsoft.com/en-us/cpp/c-runtime-library/reference/setmaxstdio?view=msvc-160 +#if (GGML_MAX_CONTEXTS > 8192) +#define _GGML_STDIO_TARGET 8192 +#else +#define _GGML_STDIO_TARGET GGML_MAX_CONTEXTS +#endif + int _setmaxstdio_ret = _setmaxstdio(_GGML_STDIO_TARGET); + if (_setmaxstdio_ret == -1) { + LLAMA_LOG_INFO("%s: failed to set max stdio to %d. (setmaxstdio returned -1)\n", __func__, _GGML_STDIO_TARGET); + } else { + LLAMA_LOG_INFO("%s: max stdio successfully set to %d\n", __func__, _setmaxstdio_ret); + } +#endif // GGML_MAX_CONTEXTS > 512 +#endif // _WIN32 + + if (param_overrides_p != nullptr) { + for (const struct llama_model_kv_override * p = param_overrides_p; p->key[0] != 0; p++) { + kv_overrides.insert({std::string(p->key), *p}); + } + } + + tensor_buft_overrides = param_tensor_buft_overrides_p; + + struct ggml_context * ctx = NULL; + struct gguf_init_params params = { + /*.no_alloc = */ true, + /*.ctx = */ &ctx, + }; + + meta = gguf_init_from_file(fname.c_str(), params); + if (!meta) { + throw std::runtime_error(format("%s: failed to load model from %s\n", __func__, fname.c_str())); + } + + get_key(llm_kv(LLM_KV_GENERAL_ARCHITECTURE), arch_name, false); + llm_kv = LLM_KV(llm_arch_from_string(arch_name)); + + files.emplace_back(new llama_file(fname.c_str(), "rb")); + contexts.emplace_back(ctx); + + // Save tensors data offset of the main file. + // For subsidiary files, `meta` tensor data offset must not be used, + // so we build a unified tensors index for weights. + for (ggml_tensor * cur = ggml_get_first_tensor(ctx); cur; cur = ggml_get_next_tensor(ctx, cur)) { + weights.emplace_back(files.back().get(), 0, cur->name, meta, cur); + } + uint16_t n_split = 0; + get_key(llm_kv(LLM_KV_SPLIT_COUNT), n_split, false); + + // Load additional GGML contexts + if (n_split > 1) { + uint16_t idx = 0; + get_key(llm_kv(LLM_KV_SPLIT_NO), idx); + if (idx != 0) { + throw std::runtime_error(format("illegal split file: %d, model must be loaded with the first split", idx)); + } + + char split_prefix[PATH_MAX] = {0}; + if (!llama_split_prefix(split_prefix, sizeof(split_prefix), fname.c_str(), idx, n_split)) { + throw std::runtime_error(format("invalid split file: %s", fname.c_str())); + } + + if (trace > 0) { + LLAMA_LOG_INFO("%s: loading additional %d GGUFs\n", __func__, n_split); + } + + char split_path[PATH_MAX] = {0}; + for (idx = 1; idx < n_split; idx++) { + llama_split_path(split_path, sizeof(split_path), split_prefix, idx, n_split); + + struct gguf_init_params split_params = { + /*.no_alloc = */ true, + /*.ctx = */ &ctx, + }; + struct gguf_context * ctx_gguf = gguf_init_from_file(split_path, split_params); + if (!ctx_gguf) { + throw std::runtime_error(format("%s: failed to load GGUF split from %s\n", __func__, split_path)); + } + + files.emplace_back(new llama_file(split_path, "rb")); + contexts.emplace_back(ctx); + + // Save tensors data offset info of the shard. 
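+            // (each shard weight below records its split index `idx`, so the mapping and
+            //  loading code can later locate the tensor data inside the correct split file)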
+            for (ggml_tensor * cur = ggml_get_first_tensor(ctx); cur; cur = ggml_get_next_tensor(ctx, cur)) {
+                weights.emplace_back(files.back().get(), idx, cur->name, ctx_gguf, cur);
+            }
+
+            gguf_free(ctx_gguf);
+        }
+
+        get_key(llm_kv(LLM_KV_SPLIT_TENSORS_COUNT), n_tensors);
+
+        // sanity check
+        {
+            const int n_tensors_loaded = (int) weights.size();
+            if (n_tensors != n_tensors_loaded) {
+                throw std::runtime_error(format("corrupted model: %d tensors expected but %d found", n_tensors, n_tensors_loaded));
+            }
+        }
+
+        LLAMA_LOG_INFO("%s: additional %d GGUFs metadata loaded.\n", __func__, n_split - 1);
+    }
+
+    n_kv      = gguf_get_n_kv(meta);
+    n_tensors = weights.size();
+
+    fver = (enum llama_fver) gguf_get_version(meta);
+
+    std::set<std::string> tensor_names;
+    for (auto & w : weights) {
+        n_elements += ggml_nelements(w.tensor);
+        n_bytes    += ggml_nbytes(w.tensor);
+        // make sure there are no duplicated tensor names
+        const std::string name(w.tensor->name);
+        auto found = tensor_names.find(name);
+        if (found != tensor_names.end()) {
+            throw std::runtime_error(format("invalid model: tensor '%s' is duplicated", w.tensor->name));
+        }
+        tensor_names.insert(name);
+    }
+
+    LLAMA_LOG_INFO("%s: loaded meta data with %d key-value pairs and %d tensors from %s (version %s)\n",
+            __func__, n_kv, n_tensors, fname.c_str(), llama_file_version_name(fver));
+
+    // determine file type based on the number of tensors for each quantization and print meta data
+    // TODO: make optional
+    {
+        std::map<enum ggml_type, uint32_t> n_type;
+
+        uint32_t n_type_max = 0;
+        enum ggml_type type_max = GGML_TYPE_F32;
+
+        for (int i = 0; i < n_tensors; i++) {
+            const ggml_tensor * tensor = weights.at(i).tensor;
+            enum ggml_type type = tensor->type;
+
+            n_type[type]++;
+
+            if (n_type_max < n_type[type]) {
+                n_type_max = n_type[type];
+                type_max   = type;
+            }
+
+            if (trace > 0) {
+                const uint16_t sid = weights.at(i).idx;
+                LLAMA_LOG_INFO("%s: - tensor %4d, split %2d: %32s %-8s [ %s ]\n", __func__, i, sid, ggml_get_name(tensor), ggml_type_name(type), llama_format_tensor_shape(tensor).c_str());
+            }
+        }
+
+        switch (type_max) {
+            case GGML_TYPE_F32:     ftype = LLAMA_FTYPE_ALL_F32;        break;
+            case GGML_TYPE_F16:     ftype = LLAMA_FTYPE_MOSTLY_F16;     break;
+            case GGML_TYPE_BF16:    ftype = LLAMA_FTYPE_MOSTLY_BF16;    break;
+            case GGML_TYPE_BF16_R16:ftype = LLAMA_FTYPE_MOSTLY_BF16_R16;break;
+            case GGML_TYPE_Q4_0:    ftype = LLAMA_FTYPE_MOSTLY_Q4_0;    break;
+            case GGML_TYPE_Q4_1:    ftype = LLAMA_FTYPE_MOSTLY_Q4_1;    break;
+            case GGML_TYPE_Q5_0:    ftype = LLAMA_FTYPE_MOSTLY_Q5_0;    break;
+            case GGML_TYPE_Q5_1:    ftype = LLAMA_FTYPE_MOSTLY_Q5_1;    break;
+            case GGML_TYPE_Q6_0:    ftype = LLAMA_FTYPE_MOSTLY_Q6_0;    break;
+            case GGML_TYPE_Q8_0:    ftype = LLAMA_FTYPE_MOSTLY_Q8_0;    break;
+            case GGML_TYPE_Q8_KV:   ftype = LLAMA_FTYPE_MOSTLY_Q8_KV;   break;
+            case GGML_TYPE_Q2_K:    ftype = LLAMA_FTYPE_MOSTLY_Q2_K;    break;
+            case GGML_TYPE_Q3_K:    ftype = LLAMA_FTYPE_MOSTLY_Q3_K_M;  break;
+            case GGML_TYPE_Q3_K_R4: ftype = LLAMA_FTYPE_MOSTLY_Q3_K_R4; break;
+            case GGML_TYPE_Q4_K:    ftype = LLAMA_FTYPE_MOSTLY_Q4_K_M;  break;
+            case GGML_TYPE_Q4_K_R4: ftype = LLAMA_FTYPE_MOSTLY_Q4_K_R4; break;
+            case GGML_TYPE_Q5_K:    ftype = LLAMA_FTYPE_MOSTLY_Q5_K_M;  break;
+            case GGML_TYPE_Q5_K_R4: ftype = LLAMA_FTYPE_MOSTLY_Q5_K_R4; break;
+            case GGML_TYPE_Q6_K:    ftype = LLAMA_FTYPE_MOSTLY_Q6_K;    break;
+            case GGML_TYPE_Q6_K_R4: ftype = LLAMA_FTYPE_MOSTLY_Q6_K_R4; break;
+            case GGML_TYPE_Q8_K_R8: ftype = LLAMA_FTYPE_MOSTLY_Q8_K_R8; break;
+            case GGML_TYPE_Q8_KV_R8: ftype = LLAMA_FTYPE_MOSTLY_Q8_KV_R8; break;
+            case GGML_TYPE_IQ2_XXS: ftype =
LLAMA_FTYPE_MOSTLY_IQ2_XXS; break; + case GGML_TYPE_IQ2_XXS_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ2_XXS_R4; break; + case GGML_TYPE_IQ2_XS: ftype = LLAMA_FTYPE_MOSTLY_IQ2_XS; break; + case GGML_TYPE_IQ2_XS_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ2_XS_R4; break; + case GGML_TYPE_IQ2_KS: ftype = LLAMA_FTYPE_MOSTLY_IQ2_KS; break; + case GGML_TYPE_IQ2_S: ftype = LLAMA_FTYPE_MOSTLY_IQ2_M; break; + case GGML_TYPE_IQ2_S_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ2_M_R4;break; + case GGML_TYPE_IQ3_XXS: ftype = LLAMA_FTYPE_MOSTLY_IQ3_XXS; break; + case GGML_TYPE_IQ3_XXS_R4: ftype = LLAMA_FTYPE_MOSTLY_IQ3_XXS_R4; break; + case GGML_TYPE_IQ1_KT: ftype = LLAMA_FTYPE_MOSTLY_IQ1_KT; break; + case GGML_TYPE_IQ2_KT: ftype = LLAMA_FTYPE_MOSTLY_IQ2_KT; break; + case GGML_TYPE_IQ3_KT: ftype = LLAMA_FTYPE_MOSTLY_IQ3_KT; break; + case GGML_TYPE_IQ4_KT: ftype = LLAMA_FTYPE_MOSTLY_IQ4_KT; break; + case GGML_TYPE_IQ1_S: ftype = LLAMA_FTYPE_MOSTLY_IQ1_S; break; + case GGML_TYPE_IQ1_S_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ1_S_R4;break; + case GGML_TYPE_IQ1_M_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ1_M_R4;break; + case GGML_TYPE_IQ1_M: ftype = LLAMA_FTYPE_MOSTLY_IQ1_M; break; + case GGML_TYPE_IQ1_BN: ftype = LLAMA_FTYPE_MOSTLY_IQ1_BN; break; + case GGML_TYPE_IQ2_BN: ftype = LLAMA_FTYPE_MOSTLY_IQ2_BN; break; + case GGML_TYPE_IQ2_BN_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ2_BN_R4;break; + case GGML_TYPE_IQ4_NL: ftype = LLAMA_FTYPE_MOSTLY_IQ4_NL; break; + case GGML_TYPE_IQ4_NL_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ4_NL_R4;break; + case GGML_TYPE_IQ4_XS_R8:ftype = LLAMA_FTYPE_MOSTLY_IQ4_XS_R8;break; + case GGML_TYPE_Q4_0_R8: ftype = LLAMA_FTYPE_MOSTLY_Q4_0_R8; break; + case GGML_TYPE_Q5_0_R4: ftype = LLAMA_FTYPE_MOSTLY_Q5_0_R4; break; + case GGML_TYPE_Q6_0_R4: ftype = LLAMA_FTYPE_MOSTLY_Q6_0_R4; break; + case GGML_TYPE_Q8_0_R8: ftype = LLAMA_FTYPE_MOSTLY_Q8_0_R8; break; + case GGML_TYPE_MXFP4: ftype = LLAMA_FTYPE_MOSTLY_MXFP4; break; + case GGML_TYPE_IQ4_XS: ftype = LLAMA_FTYPE_MOSTLY_IQ4_XS; break; + case GGML_TYPE_IQ4_KS: ftype = LLAMA_FTYPE_MOSTLY_IQ4_KS; break; + case GGML_TYPE_IQ4_KS_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ4_KS_R4; break; + case GGML_TYPE_IQ5_KS_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ5_KS_R4; break; + case GGML_TYPE_IQ4_KSS: ftype = LLAMA_FTYPE_MOSTLY_IQ4_KSS; break; + case GGML_TYPE_IQ5_KS: ftype = LLAMA_FTYPE_MOSTLY_IQ5_KS; break; + case GGML_TYPE_IQ2_K: ftype = LLAMA_FTYPE_MOSTLY_IQ2_K; break; + case GGML_TYPE_IQ2_K_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ2_K_R4;break; + case GGML_TYPE_IQ3_KS: ftype = LLAMA_FTYPE_MOSTLY_IQ3_KS; break; + case GGML_TYPE_IQ2_KL: ftype = LLAMA_FTYPE_MOSTLY_IQ2_KL; break; + case GGML_TYPE_IQ3_K: ftype = LLAMA_FTYPE_MOSTLY_IQ3_K; break; + case GGML_TYPE_IQ3_K_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ3_K_R4;break; + case GGML_TYPE_IQ4_K: ftype = LLAMA_FTYPE_MOSTLY_IQ4_K; break; + case GGML_TYPE_IQ4_K_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ4_K_R4;break; + case GGML_TYPE_IQ5_K: ftype = LLAMA_FTYPE_MOSTLY_IQ5_K; break; + case GGML_TYPE_IQ5_K_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ5_K_R4;break; + case GGML_TYPE_IQ6_K: ftype = LLAMA_FTYPE_MOSTLY_IQ6_K; break; + case GGML_TYPE_IQ3_S: ftype = LLAMA_FTYPE_MOSTLY_IQ3_S; break; + case GGML_TYPE_IQ3_S_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ3_S_R4;break; + case GGML_TYPE_Q4_0_4_4: ftype = LLAMA_FTYPE_MOSTLY_Q4_0_4_4; break; + case GGML_TYPE_Q4_0_4_8: ftype = LLAMA_FTYPE_MOSTLY_Q4_0_4_8; break; + case GGML_TYPE_Q4_0_8_8: ftype = LLAMA_FTYPE_MOSTLY_Q4_0_8_8; break; + default: + { + LLAMA_LOG_WARN("%s: unknown type %s\n", __func__, ggml_type_name(type_max)); + ftype = LLAMA_FTYPE_ALL_F32; + } break; + } + + // this is a way to mark that 
we have "guessed" the file type
+        ftype = (llama_ftype) (ftype | LLAMA_FTYPE_GUESSED);
+
+        {
+            const int kid = gguf_find_key(meta, "general.file_type"); // TODO: use LLM_KV
+            if (kid >= 0) {
+                ftype = (llama_ftype) gguf_get_val_u32(meta, kid);
+            }
+        }
+
+        LLAMA_LOG_INFO("%s: Dumping metadata keys/values. Note: KV overrides do not apply in this output.\n", __func__);
+
+        for (int i = 0; i < n_kv; i++) {
+            const char * name           = gguf_get_key(meta, i);
+            const enum gguf_type type   = gguf_get_kv_type(meta, i);
+            const std::string type_name =
+                type == GGUF_TYPE_ARRAY
+                ? format("%s[%s,%d]", gguf_type_name(type), gguf_type_name(gguf_get_arr_type(meta, i)), gguf_get_arr_n(meta, i))
+                : gguf_type_name(type);
+
+            std::string value          = gguf_kv_to_str(meta, i);
+            const size_t MAX_VALUE_LEN = 40;
+            if (value.size() > MAX_VALUE_LEN) {
+                value = format("%s...", value.substr(0, MAX_VALUE_LEN - 3).c_str());
+            }
+            replace_all(value, "\n", "\\n");
+
+            LLAMA_LOG_INFO("%s: - kv %3d: %42s %-16s = %s\n", __func__, i, name, type_name.c_str(), value.c_str());
+        }
+
+        // print type counts
+        for (auto & kv : n_type) {
+            if (kv.second == 0) {
+                continue;
+            }
+
+            LLAMA_LOG_INFO("%s: - type %4s: %4d tensors\n", __func__, ggml_type_name(kv.first), kv.second);
+        }
+    }
+
+    if (!llama_mmap::SUPPORTED) {
+        LLAMA_LOG_WARN("%s: mmap is not supported on this platform\n", __func__);
+        use_mmap = false;
+    }
+    if (repack_tensors) {
+        use_mmap = false;
+    }
+
+    this->use_mmap = use_mmap;
+    this->check_tensors = check_tensors;
+    this->repack_tensors = repack_tensors;
+    this->use_thp = use_thp;
+}
+
+llama_model_loader::~llama_model_loader() {
+    if (meta) {
+        gguf_free(meta);
+    }
+    for (auto * ctx : contexts) {
+        ggml_free(ctx);
+    }
+}
+
+template<typename T>
+typename std::enable_if<std::is_integral<T>::value, bool>::type
+llama_model_loader::get_arr_n(const std::string & key, T & result, const bool required) {
+    const int kid = gguf_find_key(meta, key.c_str());
+
+    if (kid < 0) {
+        if (required) {
+            throw std::runtime_error(format("key not found in model: %s", key.c_str()));
+        }
+        return false;
+    }
+
+    struct GGUFMeta::ArrayInfo arr_info =
+        GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(meta, kid);
+
+
+    result = arr_info.length;
+    return true;
+}
+
+template<typename T>
+typename std::enable_if<std::is_integral<T>::value, bool>::type
+llama_model_loader::get_arr_n(const enum llm_kv kid, T & result, const bool required) {
+    return get_arr_n(llm_kv(kid), result, required);
+}
+
+template<typename T>
+bool llama_model_loader::get_arr(const std::string & key, std::vector<T> & result, const bool required) {
+    const int kid = gguf_find_key(meta, key.c_str());
+
+    if (kid < 0 || gguf_get_kv_type(meta, kid) != GGUF_TYPE_ARRAY) {
+        if (required) {
+            throw std::runtime_error(format("array key not found in model: %s", key.c_str()));
+        }
+        return false;
+    }
+
+    struct GGUFMeta::ArrayInfo arr_info =
+        GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(meta, kid);
+
+    switch (arr_info.gt) {
+        case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T, float>::value)); break;
+        case GGUF_TYPE_INT32:   GGML_ASSERT(
+                                        (std::is_same<T,  int32_t>::value) ||
+                                        (std::is_same<T, uint32_t>::value)); break;
+        default:
+            throw std::runtime_error(format("%s is not a float32, int32 array", key.c_str()));
+    }
+
+    result.resize(arr_info.length);
+    result.assign((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length);
+
+    return true;
+}
+
+template<typename T, size_t N_MAX>
+bool llama_model_loader::get_arr(const std::string & key, std::array<T, N_MAX> & result, const bool required) {
+    const int kid = gguf_find_key(meta, key.c_str());
+
+    if (kid < 0 || gguf_get_kv_type(meta, kid) != GGUF_TYPE_ARRAY) {
+        if (required) {
+            throw std::runtime_error(format("array key not found in model: %s", key.c_str()));
+        }
+        return false;
+    }
+
+    struct GGUFMeta::ArrayInfo arr_info =
+        GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(meta, kid);
+
+    switch (arr_info.gt) {
+        case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T, float>::value)); break;
+        case GGUF_TYPE_INT32:   GGML_ASSERT(
+                                        (std::is_same<T,  int32_t>::value) ||
+                                        (std::is_same<T, uint32_t>::value)); break;
+        default:
+            throw std::runtime_error(format("%s is not a float32, int32 array", key.c_str()));
+    }
+
+    if (arr_info.length > N_MAX) {
+        throw std::runtime_error(format("array length %u for key %s exceeds max %u", (uint32_t) arr_info.length, key.c_str(), (uint32_t) N_MAX));
+    }
+
+    std::copy((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length, result.begin());
+
+    return true;
+}
+
+template<typename T>
+bool llama_model_loader::get_arr(const enum llm_kv kid, T & result, const bool required) {
+    return get_arr(llm_kv(kid), result, required);
+}
+
+template<typename T>
+bool llama_model_loader::get_key(const std::string & key, T & result, const bool required) {
+    auto it = kv_overrides.find(key);
+
+    const struct llama_model_kv_override * override =
+        it != kv_overrides.end() ? &it->second : nullptr;
+
+    const bool found = GGUFMeta::GKV<T>::set(meta, key, result, override);
+
+    if (required && !found) {
+        throw std::runtime_error(format("key not found in model: %s", key.c_str()));
+    }
+
+    return found;
+}
+
+template<typename T>
+bool llama_model_loader::get_key(const enum llm_kv kid, T & result, const bool required) {
+    return get_key(llm_kv(kid), result, required);
+}
+
+// get array of n <= N_MAX elements, or a single element repeated n times
+template<typename T, size_t N_MAX>
+bool llama_model_loader::get_key_or_arr(const std::string & key, std::array<T, N_MAX> & result, uint32_t n, const bool required) {
+    const int kid = gguf_find_key(meta, key.c_str());
+
+    if (kid < 0) {
+        if (required) {
+            throw std::runtime_error(format("key not found in model: %s", key.c_str()));
+        }
+        return false;
+    }
+
+    if (n > N_MAX) {
+        throw std::runtime_error(format("n > N_MAX: %u > %u for key %s", (uint32_t) n, (uint32_t) N_MAX, key.c_str()));
+    }
+
+    if (gguf_get_kv_type(meta, kid) == GGUF_TYPE_ARRAY) {
+        struct GGUFMeta::ArrayInfo arr_info =
+            GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(meta, kid);
+
+        if (n != arr_info.length) {
+            throw std::runtime_error(format("key %s has wrong array length; expected %u, got %u", key.c_str(), n, (uint32_t) arr_info.length));
+        }
+
+        return get_arr(key, result, required);
+    } else {
+        T value;
+
+        bool ok = get_key(key, value, required);
+        if (!ok) {
+            return false;
+        }
+
+        for (uint32_t i = 0; i < n; i++) {
+            result[i] = value;
+        }
+
+        return true;
+    }
+}
+
+template<typename T>
+bool llama_model_loader::get_key_or_arr(const enum llm_kv kid, T & result, uint32_t n, const bool required) {
+    return get_key_or_arr(llm_kv(kid), result, n, required);
+}
+
+const char * llama_model_loader::get_tensor_name(int i) const {
+    return weights.at(i).tensor->name;
+}
+
+const llama_model_loader::llama_tensor_weight * llama_model_loader::get_weight(const char * name) const {
+    for (const auto & weight : weights) {
+        if (strcmp(name, weight.tensor->name) == 0) {
+            return &weight;
+        }
+    }
+    return nullptr;
+}
+
+const llama_model_loader::llama_tensor_weight & llama_model_loader::require_weight(const char * name) const {
+    const llama_tensor_weight * weight = get_weight(name);
+    if (!weight) {
+        throw std::runtime_error(format("%s: tensor '%s' not found", __func__, name));
+    }
+    return *weight;
+}
+
+struct ggml_tensor * llama_model_loader::get_tensor_meta(const char * name) const {
+    const auto * weight = get_weight(name);
+    if (!weight) {
+        return nullptr;
+    }
+    return weight->tensor;
+}
+
+struct ggml_tensor * llama_model_loader::require_tensor_meta(const char * name) const {
+    struct ggml_tensor * tensor = get_tensor_meta(name);
+    if (!tensor) {
+        throw std::runtime_error(format("%s: tensor '%s' not found", __func__, name));
+    }
+    return tensor;
+}
+
+struct ggml_tensor * llama_model_loader::create_tensor_for(struct ggml_context * ctx, const struct ggml_tensor * cur, bool duplicated) {
+    struct ggml_tensor * tensor = ggml_dup_tensor(ctx, cur);
+    ggml_set_name(tensor, ggml_get_name(cur));
+
+    if (duplicated) {
+        size_data += ggml_nbytes(cur);
+    } else {
+        n_created++;
+    }
+
+    return tensor;
+}
+
+const struct ggml_tensor * llama_model_loader::check_tensor_dims(const std::string & name, const std::vector<int64_t> & ne, bool required) const {
+    const struct ggml_tensor * cur = get_tensor_meta(name.c_str());
+
+    if (cur == NULL) {
+        if (!required) {
+            return NULL;
+        }
+        throw std::runtime_error(format("%s: tensor '%s' not found", __func__, name.c_str()));
+    }
+
+    {
+        bool is_ok = true;
+        for (size_t i = 0; i < GGML_MAX_DIMS; ++i) {
+            if ((i < ne.size() && ne[i] != cur->ne[i]) || (i >= ne.size() && cur->ne[i] != 1)) {
+                is_ok = false;
+                break;
+            }
+        }
+        if (!is_ok) {
+            throw std::runtime_error(
+                    format("%s: tensor '%s' has wrong shape; expected %s, got %s",
+                        __func__, name.c_str(),
+                        llama_format_tensor_shape(ne).c_str(),
+                        llama_format_tensor_shape(cur).c_str()));
+        }
+    }
+
+    return cur;
+}
+
+struct ggml_tensor * llama_model_loader::create_tensor(struct ggml_context * ctx, const std::string & name,
+        const std::vector<int64_t> & ne, int flags) {
+    const struct ggml_tensor * cur = check_tensor_dims(name, ne, !(flags & TENSOR_NOT_REQUIRED));
+
+    if (cur == NULL) {
+        return NULL;
+    }
+
+    // skip unused tensors
+    if (flags & TENSOR_SKIP) {
+        const size_t nbytes = ggml_nbytes(cur);
+        LLAMA_LOG_WARN("model has unused tensor %s (size = %zu bytes) -- ignoring\n", name.c_str(), nbytes);
+
+        size_data -= nbytes;
+        n_created++;
+
+        return nullptr;
+    }
+
+    return create_tensor_for(ctx, cur, flags & TENSOR_DUPLICATED);
+}
+
+struct ggml_tensor * llama_model_loader::create_tensor_as_view(struct ggml_context * ctx, struct ggml_tensor * base,
+        const std::string & name, const std::vector<int64_t> & ne, size_t offset, bool required) {
+    const struct ggml_tensor * cur = check_tensor_dims(name, ne, required);
+
+    if (cur == NULL) {
+        return NULL;
+    }
+
+    if (cur->type != base->type) {
+        throw std::runtime_error(format("%s: tensor '%s' has wrong type; expected %s, got %s", __func__, name.c_str(), ggml_type_name(base->type), ggml_type_name(cur->type)));
+    }
+
+    std::array<int64_t, GGML_MAX_DIMS> dims;
+    for (size_t i = 0; i < GGML_MAX_DIMS; ++i) {
+        dims[i] = i < ne.size() ? ne[i] : 1;
+    }
+
+    struct ggml_tensor * tensor = ggml_view_4d(ctx, base,
+            dims[0], dims[1], dims[2], dims[3],
+            cur->nb[1], cur->nb[2], cur->nb[3],
+            offset);
+
+    ggml_set_name(tensor, name.c_str());
+
+    n_created++;
+
+    return tensor;
+}
+
+void llama_model_loader::done_getting_tensors() const {
+    if (n_created != n_tensors) {
+        throw std::runtime_error(format("%s: wrong number of tensors; expected %d, got %d", __func__, n_tensors, n_created));
+    }
+}
+
+void llama_model_loader::init_mappings(bool prefetch, llama_mlocks * mlock_mmaps, bool use_thp) {
+    if (use_mmap) {
+        mappings.reserve(files.size());
+        mmaps_used.reserve(files.size());
+        for (const auto & file : files) {
+            std::unique_ptr<llama_mmap> mapping(new llama_mmap(file.get(), prefetch ? -1 : 0, ggml_is_numa(), use_thp));
+            mmaps_used.emplace_back(mapping->size(), 0);
+            if (mlock_mmaps) {
+                std::unique_ptr<llama_mlock> mlock_mmap(new llama_mlock());
+                mlock_mmap->init(mapping->addr());
+                mlock_mmaps->emplace_back(std::move(mlock_mmap));
+            }
+            mappings.emplace_back(std::move(mapping));
+        }
+    }
+
+    // compute the total size of all tensors for progress reporting
+    for (auto & w : weights) {
+        size_data += ggml_nbytes(w.tensor);
+    }
+}
+
+void llama_model_loader::get_mapping_range(size_t * first, size_t * last, void ** addr, int idx, ggml_context * ctx) const {
+    GGML_ASSERT(!mappings.empty());
+    const auto & mapping = mappings.at(idx);
+
+    *first = mapping->size();
+    *last  = 0;
+    *addr = mapping->addr();
+    for (ggml_tensor * tensor = ggml_get_first_tensor(ctx); tensor; tensor = ggml_get_next_tensor(ctx, tensor)) {
+        try {
+            const auto * weight = get_weight(ggml_get_name(tensor));
+            if (!weight) {
+                continue;
+            }
+            if (weight->idx != idx) {
+                continue;
+            }
+            *first = std::min(*first, weight->offs);
+            *last  = std::max(*last,  weight->offs + ggml_nbytes(tensor));
+        } catch(...) {
+            // the tensor is not in the model
+        }
+    }
+}
+
+// for backwards compatibility, does not support ggml-backend
+void llama_model_loader::load_data_for(struct ggml_tensor * cur) const {
+    const auto & w = require_weight(ggml_get_name(cur));
+
+    if (use_mmap) {
+        const auto & mapping = mappings.at(w.idx);
+        if (cur->data == nullptr) {
+            cur->data = (uint8_t *)mapping->addr() + w.offs;
+        } else {
+            memcpy(cur->data, (uint8_t *)mapping->addr() + w.offs, ggml_nbytes(cur));
+        }
+    } else {
+        GGML_ASSERT(cur->data != nullptr);
+        GGML_ASSERT(w.idx < files.size());
+        const auto & file = files.at(w.idx);
+        file->seek(w.offs, SEEK_SET);
+        file->read_raw(cur->data, ggml_nbytes(cur));
+    }
+
+    if (check_tensors && !ggml_validate_row_data(cur->type, cur->data, ggml_nbytes(cur))) {
+        throw std::runtime_error(format("tensor '%s' has invalid data", ggml_get_name(cur)));
+    }
+}
+
+// Returns false if cancelled by progress_callback
+bool llama_model_loader::load_all_data(
+        struct ggml_context * ctx,
+        llama_buf_map & bufs_mmap,
+        llama_mlocks * lmlocks,
+        llama_progress_callback progress_callback,
+        void * progress_callback_user_data) {
+    GGML_ASSERT(size_data != 0 && "call init_mappings() first");
+
+    std::vector<no_init<uint8_t>> read_buf;
+    std::vector<std::future<std::pair<ggml_tensor *, bool>>> validation_result;
+
+#if defined(GGML_USE_CUDA)
+    // 4 staging buffers for async uploads, each sized 1MB seems to be a good default for single NVMe drives.
+    // NVMe raid configurations might require more / larger buffers.
+    constexpr size_t n_buffers = 4;
+    constexpr size_t buffer_size = 1 * 1024 * 1024; // 1MB
+
+    std::vector<ggml_backend_buffer_t> host_buffers;
+    std::vector<void *> host_ptrs;
+    std::vector<ggml_backend_event_t> events;
+    size_t buffer_idx = 0; // buffer to use for async loads
+
+    ggml_backend_t cuda_backend = nullptr;
+    if (!use_mmap && !check_tensors) {
+        // When not using mmaped io use async uploads from pinned memory to GPU memory.
+        // First determine if the CUDA backend is active, and if so, determine the device ID.
+        ggml_backend_buffer_t buf = bufs_mmap.count(0) ? bufs_mmap.at(0) : nullptr;
+        if (buf) {
+            ggml_backend_buffer_type_t buffer_type = ggml_backend_buffer_get_type(buf);
+            for (int i = 0; i < ggml_backend_cuda_get_device_count(); ++i) {
+                auto * cuda_buffer_type = ggml_backend_cuda_buffer_type(i);
+                if (buffer_type == cuda_buffer_type) {
+                    cuda_backend = ggml_backend_cuda_init(i);
+                    break;
+                }
+            }
+        }
+
+        // If the cuda backend is active create pinned memory buffers and events for synchronisation.
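+        // (one CUDA event per staging buffer: the upload loop further down cycles the
+        //  buffers round-robin and waits on a buffer's event before reusing it, so disk
+        //  reads overlap with device uploads)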
+ if (cuda_backend) { + for (size_t idx = 0; idx < n_buffers; ++idx) { + host_buffers.emplace_back(ggml_backend_buft_alloc_buffer(llama_default_buffer_type_cpu(true), buffer_size)); + host_ptrs.emplace_back(ggml_backend_buffer_get_base(host_buffers[idx])); + events.emplace_back(ggml_backend_event_new(cuda_backend)); + } + } + } +#endif + + for (struct ggml_tensor * cur = ggml_get_first_tensor(ctx); cur != NULL; cur = ggml_get_next_tensor(ctx, cur)) { + const auto * weight = get_weight(ggml_get_name(cur)); + if (weight == nullptr) { + // this can happen with split experts models + continue; + } + + if (progress_callback) { + if (!progress_callback((float) size_done / size_data, progress_callback_user_data)) { + return false; + } + } + + size_t n_size = ggml_nbytes(cur); + + if (use_mmap) { + const auto & mapping = mappings.at(weight->idx); + ggml_backend_buffer_t buf_mmap = nullptr; + if (bufs_mmap.count(weight->idx)) { + buf_mmap = bufs_mmap.at(weight->idx); + } + uint8_t * data = (uint8_t *) mapping->addr() + weight->offs; + + if (check_tensors) { + validation_result.emplace_back(std::async(std::launch::async, [cur, data, n_size] { + return std::make_pair(cur, ggml_validate_row_data(cur->type, data, n_size)); + })); + } + + GGML_ASSERT(buf_mmap || cur->data); // either we have a buffer to allocate the tensor in, or it is already allocated + if (buf_mmap && cur->data == nullptr) { + ggml_backend_tensor_alloc(buf_mmap, cur, data); + if (lmlocks) { + const auto & lmlock = lmlocks->at(weight->idx); + lmlock->grow_to(weight->offs + n_size); + } + + auto & mmap_used = mmaps_used[weight->idx]; + mmap_used.first = std::min(mmap_used.first, weight->offs); + mmap_used.second = std::max(mmap_used.second, weight->offs + n_size); + } else { + ggml_backend_tensor_set(cur, data, 0, n_size); + } + } else { + GGML_ASSERT(weight->idx < files.size()); + const auto & file = files.at(weight->idx); + if (ggml_backend_buffer_is_host(cur->buffer)) { + file->seek(weight->offs, SEEK_SET); + file->read_raw(cur->data, n_size); + if (check_tensors) { + validation_result.emplace_back(std::async(std::launch::async, [cur, n_size] { + return std::make_pair(cur, ggml_validate_row_data(cur->type, cur->data, n_size)); + })); + } + } else { +#if defined(GGML_USE_CUDA) + // If cuda_backend is valid load the tensor in chunks to pinned memory and upload the buffers asynchronously to the GPU. 
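+                // each chunk repeats the same event-guarded cycle: wait for the slot,
+                // refill it from disk, enqueue the async device copy, re-arm the event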
+ if (cuda_backend) { + file->seek(weight->offs, SEEK_SET); + + size_t bytes_read = 0; + + while (bytes_read < n_size) { + size_t read_iteration = std::min(buffer_size, n_size - bytes_read); + + ggml_backend_event_synchronize(events[buffer_idx]); + file->read_raw(host_ptrs[buffer_idx], read_iteration); + ggml_backend_tensor_set_async(cuda_backend, cur, host_ptrs[buffer_idx], bytes_read, read_iteration); + ggml_backend_event_record(events[buffer_idx]); + + bytes_read += read_iteration; + ++buffer_idx; + buffer_idx %= n_buffers; + } + } + else +#endif + { + read_buf.resize(n_size); + file->seek(weight->offs, SEEK_SET); + file->read_raw(read_buf.data(), n_size); + ggml_backend_tensor_set(cur, read_buf.data(), 0, n_size); + if (check_tensors && !ggml_validate_row_data(cur->type, read_buf.data(), n_size)) { + throw std::runtime_error(format("tensor '%s' has invalid data", ggml_get_name(cur))); + } + } + } + } + + size_done += n_size; + } + +#if defined(GGML_USE_CUDA) + // free temporary resources used for async cuda uploads + if (cuda_backend) { + for (size_t idx = 0; idx < n_buffers;++idx) { + ggml_backend_event_synchronize(events[idx]); + ggml_backend_event_free(events[idx]); + ggml_backend_buffer_free(host_buffers[idx]); + } + ggml_backend_free(cuda_backend); + } +#endif + + // check validation results + bool validation_failed = false; + for (auto & future : validation_result) { + auto result = future.get(); + if (!result.second) { + LLAMA_LOG_ERROR("%s: tensor '%s' has invalid data\n", __func__, ggml_get_name(result.first)); + validation_failed = true; + } + } + if (validation_failed) { + throw std::runtime_error("found tensors with invalid data"); + } + + // check if this is the last call and do final cleanup + if (size_done >= size_data) { + // unmap offloaded tensors and metadata + if (use_mmap) { + for (uint32_t idx = 0; idx < mappings.size(); idx++) { + const auto & mmap_used = mmaps_used.at(idx); + auto & mapping = mappings.at(idx); + mapping->unmap_fragment(0, mmap_used.first); + if (mmap_used.second != 0) { + mapping->unmap_fragment(mmap_used.second, mapping->size()); + } + } + } + if (progress_callback) { + // Even though the model is done loading, we still honor + // cancellation since we need to free allocations. 
+            return progress_callback(1.0f, progress_callback_user_data);
+        }
+    }
+
+    return true;
+}
+
+template<>
+bool llama_model_loader::get_key(const enum llm_kv kid, enum llama_pooling_type & result, const bool required) {
+    uint32_t tmp;
+    const bool found = get_key(kid, tmp, required);
+    if (found) {
+        result = (enum llama_pooling_type) tmp;
+    } else {
+        result = LLAMA_POOLING_TYPE_UNSPECIFIED;
+    }
+    return found;
+}
+template bool llama_model_loader::get_key<bool>       (enum llm_kv kid, bool & result, bool required);
+template bool llama_model_loader::get_key<float>      (enum llm_kv kid, float & result, bool required);
+template bool llama_model_loader::get_key<uint32_t>   (enum llm_kv kid, uint32_t & result, bool required);
+template bool llama_model_loader::get_key<std::string>(enum llm_kv kid, std::string & result, bool required);
+
+template bool llama_model_loader::get_key_or_arr<std::array<int, 4>>(enum llm_kv kid, std::array<int, 4> & result, uint32_t n, bool required);
+template bool llama_model_loader::get_key_or_arr<std::array<uint32_t, 512>>(enum llm_kv kid, std::array<uint32_t, 512> & result, uint32_t n, bool required);
+
+template std::enable_if<std::is_integral<unsigned int>::value, bool>::type llama_model_loader::get_arr_n<unsigned int>(enum llm_kv, unsigned int &, bool);
diff --git a/src/llama-model-loader.h b/src/llama-model-loader.h
new file mode 100644
index 000000000..4240b0d18
--- /dev/null
+++ b/src/llama-model-loader.h
@@ -0,0 +1,169 @@
+#pragma once
+
+#include "llama.h"
+#include "llama-impl.h"
+#include "llama-mmap.h"
+#include "llama-arch.h"
+
+#include <cstddef>
+#include <map>
+#include <stdexcept>
+#include <unordered_map>
+#include <vector>
+
+enum llama_fver {
+    GGUF_FILE_VERSION_V1 = 1,
+    GGUF_FILE_VERSION_V2 = 2,
+    GGUF_FILE_VERSION_V3 = 3,
+};
+
+static const char * llama_file_version_name(llama_fver version) {
+    switch (version) {
+        case GGUF_FILE_VERSION_V1: return "GGUF V1 (support until nov 2023)";
+        case GGUF_FILE_VERSION_V2: return "GGUF V2";
+        case GGUF_FILE_VERSION_V3: return "GGUF V3 (latest)";
+    }
+
+    return "unknown";
+}
+
+using llama_buf_map = std::unordered_map<uint32_t, ggml_backend_buffer_t>;
+
+struct llama_model_loader {
+    int n_kv      = 0;
+    int n_tensors = 0;
+    int n_created = 0;
+
+    int64_t n_elements = 0;
+    size_t  n_bytes    = 0;
+
+    bool use_mmap = false;
+    bool check_tensors;
+    bool repack_tensors = false;
+    bool use_thp = false;
+
+    llama_files files;
+    llama_ftype ftype;
+    llama_fver  fver;
+
+    llama_mmaps mappings;
+
+    // Holds information on a model weight
+    struct llama_tensor_weight {
+        uint16_t idx;  // source file index
+        size_t   offs; // tensor data offset in the original file
+
+        ggml_tensor * tensor;
+
+        llama_tensor_weight(const llama_file * file, uint16_t idx, const char * name, const struct gguf_context * gguf_ctx, ggml_tensor * tensor) : idx(idx), tensor(tensor) {
+            const int tensor_idx = gguf_find_tensor(gguf_ctx, name);
+            offs = gguf_get_data_offset(gguf_ctx) + gguf_get_tensor_offset(gguf_ctx, tensor_idx);
+
+            if (offs + ggml_nbytes(tensor) < offs || offs + ggml_nbytes(tensor) > file->size()) {
+                throw std::runtime_error(format("tensor '%s' data is not within the file bounds, model is corrupted or incomplete", name));
+            }
+        }
+    };
+    std::vector<llama_tensor_weight> weights;
+
+    std::unordered_map<std::string, struct llama_model_kv_override> kv_overrides;
+    const llama_model_tensor_buft_override * tensor_buft_overrides;
+
+    gguf_context * meta = NULL;
+    std::vector<ggml_context *> contexts;
+
+    std::string arch_name;
+    LLM_KV      llm_kv = LLM_KV(LLM_ARCH_UNKNOWN);
+
+    llama_model_loader(const std::string & fname, bool use_mmap, bool check_tensors, bool repack_tensors, bool use_thp,
+                       const llama_model_kv_override * param_overrides_p,
+                       const llama_model_tensor_buft_override * param_tensor_buft_overrides_p);
+
+    ~llama_model_loader();
+
+    template<typename T>
+    typename std::enable_if<std::is_integral<T>::value, bool>::type
+    get_arr_n(const std::string & key, T & result, const bool required = true);
+
+    template<typename T>
+    typename std::enable_if<std::is_integral<T>::value, bool>::type
+    get_arr_n(const enum llm_kv kid, T & result, const bool required = true);
+
+    template<typename T>
+    bool get_arr(const std::string & key, std::vector<T> & result, const bool required = true);
+
+    template<typename T, size_t N_MAX>
+    bool get_arr(const std::string & key, std::array<T, N_MAX> & result, const bool required = true);
+
+    template<typename T>
+    bool get_arr(const enum llm_kv kid, T & result, const bool required = true);
+
+    template<typename T>
+    bool get_key(const std::string & key, T & result, const bool required = true);
+
+    template<typename T>
+    bool get_key(const enum llm_kv kid, T & result, const bool required = true);
+
+    // get array of n <= N_MAX elements, or a single element repeated n times
+    template<typename T, size_t N_MAX>
+    bool get_key_or_arr(const std::string & key, std::array<T, N_MAX> & result, uint32_t n, const bool required = true);
+
+    template<typename T>
+    bool get_key_or_arr(const enum llm_kv kid, T & result, uint32_t n, const bool required = true);
+
+    const std::string& get_arch_name() const { return arch_name; }
+
+    enum llm_arch get_arch() const { return llm_kv.arch; }
+
+    const char * get_tensor_name(int i) const;
+
+    const llama_tensor_weight * get_weight(const char * name) const;
+
+    const llama_tensor_weight * get_weight(int i) const {
+        return get_weight(get_tensor_name(i));
+    }
+
+    const llama_tensor_weight & require_weight(const char * name) const;
+
+    struct ggml_tensor * get_tensor_meta(const char * name) const;
+
+    struct ggml_tensor * require_tensor_meta(const char * name) const;
+
+    struct ggml_tensor * get_tensor_meta(int i) const {
+        return get_tensor_meta(get_tensor_name(i));
+    }
+
+    struct ggml_tensor * create_tensor_for(struct ggml_context * ctx, const struct ggml_tensor * cur, bool duplicated);
+
+    const struct ggml_tensor * check_tensor_dims(const std::string & name, const std::vector<int64_t> & ne, bool required) const;
+
+    static const int TENSOR_NOT_REQUIRED = 1 << 0;
+    static const int TENSOR_DUPLICATED   = 1 << 1;
+    static const int TENSOR_SKIP         = 1 << 2;
+
+    struct ggml_tensor * create_tensor(struct ggml_context * ctx, const std::string & name, const std::vector<int64_t> & ne, int flags = 0);
+
+    struct ggml_tensor * create_tensor_as_view(struct ggml_context * ctx, struct ggml_tensor * base,
+                                               const std::string & name, const std::vector<int64_t> & ne, size_t offset, bool required = true);
+
+    void done_getting_tensors() const;
+
+    void init_mappings(bool prefetch = true, llama_mlocks * mlock_mmaps = nullptr, bool use_thp = false);
+
+    void get_mapping_range(size_t * first, size_t * last, void ** addr, int idx, ggml_context * ctx) const;
+
+    // for backwards compatibility, does not support ggml-backend
+    void load_data_for(struct ggml_tensor * cur) const;
+
+    size_t size_done = 0;
+    size_t size_data = 0;
+    std::vector<std::pair<size_t, size_t>> mmaps_used;
+
+    // Returns false if cancelled by progress_callback
+    bool load_all_data(
+            struct ggml_context * ctx,
+            llama_buf_map & bufs_mmap,
+            llama_mlocks * lmlocks,
+            llama_progress_callback progress_callback,
+            void * progress_callback_user_data);
+};
diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp
index 40d9963dc..0d806d8a8 100644
--- a/src/llama-sampling.cpp
+++ b/src/llama-sampling.cpp
@@ -734,7 +734,7 @@ llama_token llama_sample_token_impl(struct llama_sampling * smpl, llama_token_da
 // Ported from Koboldcpp, original PR: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am)
 static void get_overlapping_token_sequences(const llama_vocab& vocab,
const std::string& str, std::unordered_multimap>& token_sequences, int max_tail_len = -1) { for (llama_token token_id = 0; token_id < (llama_token)vocab.n_tokens(); token_id++) { - std::string word = llama_detokenize(vocab, { token_id }, true); + auto word = vocab.detokenize( { token_id }, true); if (word.find(str) != std::string::npos) { token_sequences.emplace(token_id, std::vector()); } @@ -751,7 +751,8 @@ static void get_overlapping_token_sequences(const llama_vocab& vocab, const std: } } if (match) { - std::vector tokenization = llama_tokenize_internal(vocab, str.substr(i), false, false); + auto tokenization = vocab.tokenize(str.substr(i), false, false); + //std::vector tokenization = llama_tokenize_internal(vocab, str.substr(i), false, false); if (max_tail_len >= 0 && tokenization.size() > (size_t)max_tail_len) { tokenization.resize(max_tail_len); } diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index a2bc72d9b..b92adfd4c 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -1,37 +1,29 @@ #include "llama-vocab.h" +#include "ggml.h" +#include "llama-impl.h" +#include "llama-model-loader.h" + #include "unicode.h" #include #include +#include #include -#include +#include #include #include #include +#include +#include #include -#include +#include +#include // // helpers // -LLAMA_ATTRIBUTE_FORMAT(1, 2) -static std::string format(const char * fmt, ...) { - va_list ap; - va_list ap2; - va_start(ap, fmt); - va_copy(ap2, ap); - int size = vsnprintf(NULL, 0, fmt, ap); - GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT - std::vector buf(size + 1); - int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2); - GGML_ASSERT(size2 == size); - va_end(ap2); - va_end(ap); - return std::string(buf.data(), size); -} - struct naive_trie { naive_trie() : has_value(false), value(0) { } @@ -50,7 +42,7 @@ struct naive_trie { res.first->second.insert(key + 1, len - 1, value); } } - std::pair get_longest_prefix(const char * key, size_t len, size_t offset = 0) { + std::pair get_longest_prefix(const char * key, size_t len, size_t offset = 0) const { if (len == 0 || offset == len) { return std::make_pair(key, offset); } @@ -58,107 +50,31 @@ struct naive_trie { auto res = children.find(c); if (res != children.end()) { return res->second.get_longest_prefix(key, len, offset + 1); - } else { - return std::make_pair(key, offset); } + + return std::make_pair(key, offset); } - struct naive_trie * traverse(const char c) { + const struct naive_trie * traverse(const char c) const { auto res = children.find(c); if (res != children.end()) { return &res->second; - } else { - return NULL; } + + return NULL; } std::map children; bool has_value; llama_token value; }; -uint32_t llama_vocab::n_tokens() const { - return (uint32_t)id_to_token.size(); -} // -// impl +// tokenizers // -int llama_vocab::find_bpe_rank(const std::string & token_left, const std::string & token_right) const { - GGML_ASSERT(token_left.find(' ') == std::string::npos); - GGML_ASSERT(token_left.find('\n') == std::string::npos); - GGML_ASSERT(token_right.find(' ') == std::string::npos); - GGML_ASSERT(token_right.find('\n') == std::string::npos); - - auto it = bpe_ranks.find(std::make_pair(token_left, token_right)); - if (it == bpe_ranks.end()) { - return -1; - } - - return it->second; -} - -static enum llama_vocab_type llama_vocab_get_type(const llama_vocab & vocab) { - return vocab.type; -} - -static bool llama_is_normal_token(const llama_vocab & vocab, llama_token id) { - GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE); - return 
vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_NORMAL; -} - -static bool llama_is_unknown_token(const llama_vocab & vocab, llama_token id) { - GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE); - return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_UNKNOWN; -} - -static bool llama_is_control_token(const llama_vocab & vocab, llama_token id) { - GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE); - return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_CONTROL; -} - -static bool llama_is_byte_token(const llama_vocab & vocab, llama_token id) { - GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE); - return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_BYTE; -} - -static bool llama_is_user_defined_token(const llama_vocab & vocab, llama_token id) { - GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE); - return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_USER_DEFINED; -} - -static bool llama_is_unused_token(const llama_vocab & vocab, llama_token id) { - GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE); - return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_UNUSED; -} - -static uint8_t llama_token_to_byte(const llama_vocab & vocab, llama_token id) { - GGML_ASSERT(llama_vocab_get_type(vocab) != LLAMA_VOCAB_TYPE_NONE); - GGML_ASSERT(llama_is_byte_token(vocab, id)); - const auto & token_data = vocab.id_to_token.at(id); - switch (llama_vocab_get_type(vocab)) { - case LLAMA_VOCAB_TYPE_SPM: - case LLAMA_VOCAB_TYPE_UGM: { - auto buf = token_data.text.substr(3, 2); - return strtol(buf.c_str(), NULL, 16); - } - case LLAMA_VOCAB_TYPE_BPE: { - GGML_ABORT("fatal error"); - //return unicode_utf8_to_byte(token_data.text); // TODO: why is this here after GGML_ASSERT? - } - case LLAMA_VOCAB_TYPE_WPM: { - GGML_ABORT("fatal error"); - } - default: - GGML_ABORT("fatal error"); - } -} - -static void llama_escape_whitespace(std::string & text) { - replace_all(text, " ", "\xe2\x96\x81"); -} - -static void llama_unescape_whitespace(std::string & word) { - replace_all(word, "\xe2\x96\x81", " "); -} +struct llm_tokenizer { + llm_tokenizer() {} + virtual ~llm_tokenizer() = default; +}; struct llm_symbol { using index = int; @@ -190,10 +106,14 @@ struct llm_bigram_spm { size_t size; }; -struct llm_tokenizer_spm { - llm_tokenizer_spm(const llama_vocab & vocab) : vocab(vocab) {} +struct llm_tokenizer_spm : llm_tokenizer { + llm_tokenizer_spm(const llama_vocab & /*vocab*/) {} +}; + +struct llm_tokenizer_spm_session { + llm_tokenizer_spm_session(const llama_vocab & vocab) : vocab(vocab) {} - void tokenize(const std::string & text, std::vector & output) { + void tokenize(const std::string & text, std::vector & output) { // split string into utf8 chars int index = 0; size_t offs = 0; @@ -210,7 +130,7 @@ struct llm_tokenizer_spm { } // seed the work queue with all possible 2-character tokens. - for (size_t i = 1; i < symbols.size(); ++i) { + for (int i = 1; i < (int) symbols.size(); ++i) { try_add_bigram(i - 1, i); } @@ -252,13 +172,13 @@ struct llm_tokenizer_spm { } private: - void resegment(llm_symbol & symbol, std::vector & output) { + void resegment(llm_symbol & symbol, std::vector & output) { auto text = std::string(symbol.text, symbol.n); - auto token = vocab.token_to_id.find(text); + auto token = vocab.text_to_token(text); // Do we need to support is_unused? - if (token != vocab.token_to_id.end()) { - output.push_back((*token).second); + if (token != LLAMA_TOKEN_NULL) { + output.push_back(token); return; } @@ -268,13 +188,13 @@ struct llm_tokenizer_spm { // output any symbols that did not form tokens as bytes. 
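            // SPM-style vocabs reserve a token per byte value (e.g. "<0x0A>"), so
            // unmatched text degrades to byte tokens instead of being dropped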
output.reserve(output.size() + symbol.n); for (int j = 0; j < (int)symbol.n; ++j) { - llama_vocab::id token_id = llama_byte_to_token_impl(vocab, symbol.text[j]); - output.push_back(token_id); + llama_token id = vocab.byte_to_token(symbol.text[j]); + output.push_back(id); } return; } - resegment(symbols[p->second.first], output); + resegment(symbols[p->second.first], output); resegment(symbols[p->second.second], output); } @@ -282,19 +202,18 @@ struct llm_tokenizer_spm { if (left == -1 || right == -1) { return; } - const std::string text = std::string(symbols[left].text, symbols[left].n + symbols[right].n); - auto token = vocab.token_to_id.find(text); + auto token = vocab.text_to_token(text); - if (token == vocab.token_to_id.end()) { + if (token == LLAMA_TOKEN_NULL) { return; } - if (static_cast((*token).second) >= vocab.id_to_token.size()) { + if (static_cast(token) >= vocab.n_tokens()) { return; } - const auto & tok_data = vocab.id_to_token[(*token).second]; + const auto & tok_data = vocab.get_token_data(token); llm_bigram_spm bigram; bigram.left = left; @@ -309,10 +228,11 @@ struct llm_tokenizer_spm { } const llama_vocab & vocab; + // currently unused + // const llm_tokenizer_spm * spm_tokenizer; std::vector symbols; llm_bigram_spm::queue work_queue; - std::map> rev_merge; }; @@ -324,6 +244,21 @@ struct llm_tokenizer_spm { // TODO: there are a lot of common parts between spm and bpe tokenizers, should be refactored and reused +template, typename Compare = std::less> +class llama_priority_queue : public std::priority_queue { +public: + using std::priority_queue::priority_queue; + + T pop_move() { + T item = std::move(this->c.front()); + std::pop_heap(this->c.begin(), this->c.end(), this->comp); + this->c.pop_back(); + return item; + } + + void pop() = delete; +}; + struct llm_bigram_bpe { struct comparator { bool operator()(const llm_bigram_bpe & l, const llm_bigram_bpe & r) const { @@ -332,7 +267,7 @@ struct llm_bigram_bpe { }; using queue_storage = std::vector; - using queue = std::priority_queue; + using queue = llama_priority_queue; llm_symbol::index left; llm_symbol::index right; std::string text; @@ -340,10 +275,10 @@ struct llm_bigram_bpe { size_t size; }; -struct llm_tokenizer_bpe { - llm_tokenizer_bpe(const llama_vocab & vocab): vocab(vocab) { - GGML_ASSERT(vocab.type == LLAMA_VOCAB_TYPE_BPE); - switch (vocab.type_pre) { +struct llm_tokenizer_bpe : llm_tokenizer { + llm_tokenizer_bpe(const llama_vocab & vocab) { + GGML_ASSERT(vocab.get_type() == LLAMA_VOCAB_TYPE_BPE); + switch (vocab.get_pre_type()) { case LLAMA_VOCAB_PRE_TYPE_LLAMA3: regex_exprs = { // original regex from tokenizer.json @@ -371,6 +306,7 @@ struct llm_tokenizer_bpe { }; break; case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM: + case LLAMA_VOCAB_PRE_TYPE_HUNYUAN_DENSE: regex_exprs = { "\\p{N}{1,3}", "[一-龥぀-ゟ゠-ヿ]+", @@ -393,25 +329,13 @@ struct llm_tokenizer_bpe { "[0-9][0-9][0-9]", }; break; - case LLAMA_VOCAB_PRE_TYPE_FALCON_3: - regex_exprs = { - "[\\p{P}\\$\\+<=>\\^~\\|`]+", - "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", - "[0-9]", - }; - break; - case LLAMA_VOCAB_PRE_TYPE_FALCON_E: - regex_exprs = { - "[\\p{P}\\$\\+<=>\\^~\\|`]+", - "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", - "[0-9]", - }; - break; case LLAMA_VOCAB_PRE_TYPE_STARCODER: case LLAMA_VOCAB_PRE_TYPE_REFACT: case LLAMA_VOCAB_PRE_TYPE_COMMAND_R: case LLAMA_VOCAB_PRE_TYPE_SMOLLM: case LLAMA_VOCAB_PRE_TYPE_CODESHELL: + case LLAMA_VOCAB_PRE_TYPE_EXAONE: + case LLAMA_VOCAB_PRE_TYPE_MINERVA: 
regex_exprs = { "\\p{N}", "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", @@ -421,6 +345,7 @@ struct llm_tokenizer_bpe { case LLAMA_VOCAB_PRE_TYPE_MPT: case LLAMA_VOCAB_PRE_TYPE_OLMO: case LLAMA_VOCAB_PRE_TYPE_JAIS: + case LLAMA_VOCAB_PRE_TYPE_TRILLION: regex_exprs = { "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", }; @@ -435,6 +360,8 @@ struct llm_tokenizer_bpe { }; break; case LLAMA_VOCAB_PRE_TYPE_PORO: + case LLAMA_VOCAB_PRE_TYPE_BLOOM: + case LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH: regex_exprs = { " ?[^(\\s|.,!?…。,、।۔،)]+", }; @@ -457,6 +384,20 @@ struct llm_tokenizer_bpe { "[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))*((?=[\\p{L}])([^A-Z]))+|[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))+((?=[\\p{L}])([^A-Z]))*|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", }; break; + case LLAMA_VOCAB_PRE_TYPE_CHAMELEON: + // Note: in theory, the special token (sentinel and image token) regex_exprs below + // are unnecessary, as they are split in `tokenizer_st_partition` anyway. + // However, since the upstream pre-tokenizer uses them, they are also + // included here (see https://huggingface.co/facebook/chameleon-7b). + regex_exprs = { + "", // Sentinel tokens + "(IMGIMG)((A|B|C|D|E|F|G|H|I){1,4})Z", // Image tokens + "([\\t\\n]| | )", // directly from tokenizer.json + "\\p{N}", // Individual digits + "[\\p{P}!-/:-@\\[-`{-~]", // Punctuation, Isolated + "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", + }; + break; case LLAMA_VOCAB_PRE_TYPE_GPT4O: regex_exprs = { // original regex from tokenizer.json @@ -485,7 +426,7 @@ struct llm_tokenizer_bpe { "'(?:[sSdDmMtT]|[lL][lL]|[vV][eE]|[rR][eE])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]|\\s+(?!\\S)|\\s+", }; break; - case LLAMA_VOCAB_PRE_TYPE_SEED_CODER: + case LLAMA_VOCAB_PRE_TYPE_SEED_CODER: regex_exprs = { // original regex from tokenizer.json // "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1}| ?[^\\s\\p{L}\\p{N}\r\n]+|\\s*[\r\n]+|\\s+(?!\\S)|\\s+" @@ -504,36 +445,42 @@ struct llm_tokenizer_bpe { } } - void append(const llama_vocab::id token_id, std::vector & output) const { + std::vector regex_exprs; +}; + +struct llm_tokenizer_bpe_session { + llm_tokenizer_bpe_session(const llama_vocab & vocab, const llm_tokenizer_bpe & tokenizer) : vocab(vocab), tokenizer(tokenizer) {} + + static void append(const llama_token token_id, std::vector & output) { output.push_back(token_id); } - bool append_bos(std::vector & output) const { - if (vocab.tokenizer_add_bos) { - GGML_ASSERT(vocab.special_bos_id != -1); - output.push_back(vocab.special_bos_id); + bool append_bos(std::vector & output) const { + if (vocab.get_add_bos()) { + GGML_ASSERT(vocab.token_bos() != LLAMA_TOKEN_NULL); + output.push_back(vocab.token_bos()); return true; } return false; } - bool append_eos(std::vector & output) const { - if (vocab.tokenizer_add_eos) { - GGML_ASSERT(vocab.special_eos_id != -1); - output.push_back(vocab.special_eos_id); + bool append_eos(std::vector & output) const { + if (vocab.get_add_eos()) { + GGML_ASSERT(vocab.token_eos() != LLAMA_TOKEN_NULL); + output.push_back(vocab.token_eos()); return true; } return false; } - void check_double_bos_eos(const std::vector & output) const { - if (vocab.tokenizer_add_bos && output.size() >= 2 && output[1] == vocab.special_bos_id) { + void check_double_bos_eos(const std::vector & output) const { + if (vocab.get_add_bos() && output.size() >= 2 && output[1] == 
vocab.token_bos()) { LLAMA_LOG_WARN( "%s: Added a BOS token to the prompt as specified by the model but the prompt " "also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. " "Are you sure this is what you want?\n", __FUNCTION__); } - if (vocab.tokenizer_add_eos && output.size() >= 2 && *(output.end()-2) == vocab.special_eos_id) { + if (vocab.get_add_eos() && output.size() >= 2 && *(output.end()-2) == vocab.token_eos()) { LLAMA_LOG_WARN( "%s: Added a EOS token to the prompt as specified by the model but the prompt " "also ends with a EOS token. So now the final prompt ends with 2 EOS tokens. " @@ -541,21 +488,21 @@ struct llm_tokenizer_bpe { } } - void tokenize(const std::string & text, std::vector & output) { + void tokenize(const std::string & text, std::vector & output) { int final_prev_index = -1; - - const auto word_collection = unicode_regex_split(text, regex_exprs); + const auto word_collection = unicode_regex_split(text, tokenizer.regex_exprs); symbols_final.clear(); - for (auto & word : word_collection) { + for (const auto & word : word_collection) { work_queue = llm_bigram_bpe::queue(); symbols.clear(); int index = 0; size_t offset = 0; - if (vocab.tokenizer_ignore_merges && vocab.token_to_id.find(word) != vocab.token_to_id.end()) { + //if (vocab.tokenizer_ignore_merges && vocab.token_to_id.find(word) != vocab.token_to_id.end()) { + if (vocab.get_ignore_merges() && vocab.text_to_token(word) != LLAMA_TOKEN_NULL) { symbols.emplace_back(llm_symbol{-1, -1, word.c_str(), word.size()}); offset = word.size(); } @@ -571,14 +518,13 @@ struct llm_tokenizer_bpe { index++; symbols.emplace_back(sym); } - for (size_t i = 1; i < symbols.size(); ++i) { + for (int i = 1; i < (int) symbols.size(); ++i) { add_new_bigram(i - 1, i); } // build token(s) while (!work_queue.empty()) { - auto bigram = work_queue.top(); - work_queue.pop(); + auto bigram = work_queue.pop_move(); auto & left_symbol = symbols[bigram.left]; auto & right_symbol = symbols[bigram.right]; @@ -630,18 +576,18 @@ struct llm_tokenizer_bpe { } const std::string str = std::string(symbol.text, symbol.n); - const auto token = vocab.token_to_id.find(str); + const auto token = vocab.text_to_token(str); - if (token == vocab.token_to_id.end()) { + if (token == LLAMA_TOKEN_NULL) { for (auto j = str.begin(); j != str.end(); ++j) { std::string byte_str(1, *j); - auto token_multibyte = vocab.token_to_id.find(byte_str); - if (token_multibyte != vocab.token_to_id.end()) { - output.push_back(token_multibyte->second); + auto token_multibyte = vocab.text_to_token(byte_str); + if (token_multibyte != LLAMA_TOKEN_NULL) { + output.push_back(token_multibyte); } } } else { - output.push_back((*token).second); + output.push_back(token); } } } @@ -652,7 +598,6 @@ struct llm_tokenizer_bpe { if (left == -1 || right == -1) { return; } - std::string left_token = std::string(symbols[left].text, symbols[left].n); std::string right_token = std::string(symbols[right].text, symbols[right].n); @@ -676,12 +621,10 @@ struct llm_tokenizer_bpe { } const llama_vocab & vocab; - - std::vector regex_exprs; + const llm_tokenizer_bpe & tokenizer; std::vector symbols; std::vector symbols_final; - llm_bigram_bpe::queue work_queue; }; @@ -689,15 +632,16 @@ struct llm_tokenizer_bpe { // WPM tokenizer // -struct llm_tokenizer_wpm { - llm_tokenizer_wpm(const llama_vocab & vocab): vocab(vocab) {} +struct llm_tokenizer_wpm : llm_tokenizer { + llm_tokenizer_wpm(const llama_vocab & /*vocab*/) {} +}; - void tokenize(const std::string & text, std::vector & output) 
const { - const auto & token_map = vocab.token_to_id; +struct llm_tokenizer_wpm_session { + llm_tokenizer_wpm_session(const llama_vocab & vocab) : vocab(vocab) {} + void tokenize(const std::string & text, std::vector & output) { // normalize and split by whitespace std::vector words = preprocess(text); - // bos token prepended already // find the longest tokens that form the words @@ -718,10 +662,10 @@ struct llm_tokenizer_wpm { for (int i = 0; i < n; ++i) { // loop through possible match length bool match = false; - for (int j = std::min(n, i + vocab.max_token_len + 1); j > i; j--) { - auto it = token_map.find(word1.substr(i, j - i)); - if (it != token_map.end()) { - output.push_back(it->second); + for (int j = std::min(n, i + vocab.max_token_len() + 1); j > i; j--) { + auto id = vocab.text_to_token(word1.substr(i, j - i)); + if (id != LLAMA_TOKEN_NULL) { + output.push_back(id); match = true; i = j - 1; break; @@ -736,18 +680,18 @@ struct llm_tokenizer_wpm { // we didn't find any matches for this word if (current_tokens == output.size()) { - output.push_back(vocab.special_unk_id); + output.push_back(vocab.token_unk()); } } } // TODO: reduce string copies by using cpts_offs array - std::vector preprocess(const std::string & text) const { + static std::vector preprocess(const std::string & text) { const std::vector cpts_nfd = unicode_cpts_normalize_nfd(unicode_cpts_from_utf8(text)); std::vector words(1, ""); for (const uint32_t cpt : cpts_nfd) { - const auto flags = unicode_cpt_flags(cpt); + const auto flags = unicode_cpt_flags_from_cpt(cpt); if (flags.is_whitespace) { if (words.back().size()) { // finish previous word if any @@ -794,53 +738,56 @@ struct llm_tokenizer_wpm { //(cpt >= 0xFF00 && cpt <= 0xFFEF); } +private: const llama_vocab & vocab; + // currently unused + // const llm_tokenizer_wpm * wpm_tokenizer; }; // // UGM tokenizer // -struct llm_tokenizer_ugm { - llm_tokenizer_ugm(const llama_vocab & vocab) : vocab(vocab) { - if (vocab.precompiled_charsmap.size() > 0) { +struct llm_tokenizer_ugm : llm_tokenizer { + llm_tokenizer_ugm(const llama_vocab & vocab, const std::vector & precompiled_charsmap) { + if (precompiled_charsmap.size() > 0) { size_t charsmap_offset = 0; // First four bytes of precompiled_charsmap contains length of binary // blob containing XOR-compressed compact double array (XCDA) entries - uint32_t xcda_blob_size = *(const uint32_t *) &vocab.precompiled_charsmap[0]; + uint32_t xcda_blob_size = *(const uint32_t *) &precompiled_charsmap[0]; charsmap_offset += sizeof(xcda_blob_size); - if (xcda_blob_size + charsmap_offset >= vocab.precompiled_charsmap.size()) { + if (xcda_blob_size + charsmap_offset >= precompiled_charsmap.size()) { throw std::runtime_error("Index out of array bounds in precompiled charsmap!"); } // Next xcda_blob_size bytes contain entries of XOR-compressed compact // double array (XCDA). Each entry is bit-packed into a 32-bit integer. - xcda_array = (const uint32_t *) &vocab.precompiled_charsmap[charsmap_offset]; + xcda_array = (const uint32_t *) &precompiled_charsmap[charsmap_offset]; xcda_array_size = xcda_blob_size / sizeof(uint32_t); charsmap_offset += xcda_blob_size; // Remaining bytes of precompiled charsmap contain null-terminated // replacement strings for prefixes matched by the XCDA. 
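            // overall blob layout: [uint32 xcda_blob_size][bit-packed XCDA entries]
            //                      [null-terminated replacement strings to end of blob]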
- prefix_replacements = &vocab.precompiled_charsmap[charsmap_offset]; - prefix_replacements_size = vocab.precompiled_charsmap.size() - charsmap_offset; + prefix_replacements = &precompiled_charsmap[charsmap_offset]; + prefix_replacements_size = precompiled_charsmap.size() - charsmap_offset; } - for (unsigned int id = 0; id < vocab.id_to_token.size(); ++id) { - const auto &token_data = vocab.id_to_token[id]; + for (uint32_t id = 0; id < vocab.n_tokens(); ++id) { + const auto & token_data = vocab.get_token_data(id); - if (llama_is_normal_token(vocab, id)) { + if (vocab.is_normal(id)) { min_score = std::min(min_score, token_data.score); max_score = std::max(max_score, token_data.score); } - if (llama_is_normal_token(vocab, id) || - llama_is_user_defined_token(vocab, id) || - llama_is_unused_token(vocab, id)) { + if (vocab.is_normal(id) || + vocab.is_user_defined(id) || + vocab.is_unused(id)) { token_matcher.insert(token_data.text.data(), token_data.text.size(), id); } - if (llama_is_user_defined_token(vocab, id)) { + if (vocab.is_user_defined(id)) { user_defined_token_matcher.insert(token_data.text.data(), token_data.text.size()); } } @@ -848,6 +795,29 @@ struct llm_tokenizer_ugm { unknown_token_score = min_score - unknown_token_score_penalty; } + // escaped space symbol - U+2581 (Lower One Eighth Block) + const std::string escaped_space = "\xE2\x96\x81"; + + const char * prefix_replacements = NULL; + size_t prefix_replacements_size = 0; + + const uint32_t * xcda_array = NULL; + size_t xcda_array_size = 0; + + struct naive_trie user_defined_token_matcher; + + float min_score = FLT_MAX; + float max_score = -FLT_MAX; + + float unknown_token_score_penalty = 10.0; + float unknown_token_score; + + struct naive_trie token_matcher; +}; + +struct llm_tokenizer_ugm_session { + llm_tokenizer_ugm_session(const llama_vocab & vocab, const llm_tokenizer_ugm & tokenizer) : vocab(vocab), tokenizer(tokenizer) {} + /* This implementation is based on SentencePiece optimized Viterbi algorithm for * unigram language models. The general idea is to: * - move along the input sequence in steps of one UTF code point, @@ -861,7 +831,7 @@ struct llm_tokenizer_ugm { * After processing the whole sequence we backtrack from the end to get * the best tokenization. 
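     * Each entry of tokenization_results[i] is thus the champion for the prefix of
     * length i: the winning token, the offset where its span started, and the
     * running score sum, which is what the backtracking pass uses to emit tokens.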
*/ - void tokenize(const std::string & text, std::vector & output) { + void tokenize(const std::string & text, std::vector & output) { // get current size of output (for reversal later) size_t output_size = output.size(); @@ -874,9 +844,9 @@ struct llm_tokenizer_ugm { } // initialize score_sum to -FLT_MAX so it will be always lower than sums of token scores - std::vector tokenization_results(input_len + 1, {vocab.special_unk_id, 0, -FLT_MAX}); + std::vector tokenization_results(input_len + 1, {vocab.token_unk(), 0, -DBL_MAX}); // at the beginning tokenization score is zero - tokenization_results[0] = { vocab.special_unk_id, 0, 0 }; + tokenization_results[0] = { vocab.token_unk(), 0, 0 }; for (size_t input_offset = 0; input_offset < input_len;) { size_t prefix_offset = input_offset; @@ -886,7 +856,7 @@ struct llm_tokenizer_ugm { // traverse the token matcher trie to find a matching token bool single_codepoint_token_found = false; const struct best_tokenization & current_best = tokenization_results[input_offset]; - struct naive_trie * node = token_matcher.traverse(normalized[prefix_offset++]); + const struct naive_trie * node = tokenizer.token_matcher.traverse(normalized[prefix_offset++]); while (prefix_offset <= input_len && node != NULL) { // check if we found valid token in prefix @@ -896,17 +866,17 @@ struct llm_tokenizer_ugm { single_codepoint_token_found = true; } llama_token token_id = node->value; - const auto & token_data = vocab.id_to_token[token_id]; + const auto & token_data = vocab.get_token_data(token_id); // we set the user-defined token scores to 0 to make them more likely to be selected // (normal token scores are log probabilities, so they are negative) // score type is double here to make tokenization results exactly // the same as in the HF tokenizer using SentencePiece - const double token_score = llama_is_user_defined_token(vocab, token_id) ? 0.0 : token_data.score; + const double token_score = vocab.is_user_defined(token_id) ? 
0.0 : token_data.score; const double challenger_score = current_best.score_sum + token_score; struct best_tokenization & current_champ = tokenization_results[prefix_offset]; if (challenger_score > current_champ.score_sum) { - struct best_tokenization challenger = { token_id, input_offset, (float) challenger_score }; + struct best_tokenization challenger = { token_id, input_offset, challenger_score }; current_champ = challenger; } } @@ -916,11 +886,11 @@ struct llm_tokenizer_ugm { // if we didn't find a valid token corresponding to the whole UTF code point // then use unknown token as the tokenization of this UTF code point if (!single_codepoint_token_found) { - const double challenger_score = current_best.score_sum + unknown_token_score; + const double challenger_score = current_best.score_sum + tokenizer.unknown_token_score; prefix_offset = input_offset + n_utf8_code_units; struct best_tokenization & current_champ = tokenization_results[prefix_offset]; if (challenger_score > current_champ.score_sum) { - struct best_tokenization challenger = { vocab.special_unk_id, input_offset, (float) challenger_score }; + struct best_tokenization challenger = { vocab.token_unk(), input_offset, challenger_score }; current_champ = challenger; } } @@ -933,7 +903,7 @@ struct llm_tokenizer_ugm { // merge sequences of consecutive unknown tokens into single unknown tokens bool is_prev_unknown = false; for (struct best_tokenization & tokenization = tokenization_results[input_len]; ; tokenization = tokenization_results[tokenization.input_offset]) { - bool is_unknown = tokenization.token_id == vocab.special_unk_id; + bool is_unknown = tokenization.token_id == vocab.token_unk(); if (!(is_prev_unknown && is_unknown)) { output.push_back(tokenization.token_id); } @@ -948,7 +918,6 @@ struct llm_tokenizer_ugm { } private: - const llama_vocab & vocab; // helper structure for returning normalization results struct normalization_result { @@ -961,11 +930,11 @@ struct llm_tokenizer_ugm { normalized->clear(); normalized->reserve(input.size() * 3); - const std::string space = vocab.tokenizer_escape_whitespaces ? escaped_space : " "; + const std::string space = vocab.get_escape_whitespaces() ? tokenizer.escaped_space : " "; - bool shall_prepend_space = !vocab.tokenizer_treat_whitespace_as_suffix && vocab.tokenizer_add_space_prefix; - bool shall_append_space = vocab.tokenizer_treat_whitespace_as_suffix && vocab.tokenizer_add_space_prefix; - bool shall_merge_spaces = vocab.tokenizer_remove_extra_whitespaces; + const bool shall_prepend_space = !vocab.get_treat_whitespace_as_suffix() && vocab.get_add_space_prefix(); + const bool shall_append_space = vocab.get_treat_whitespace_as_suffix() && vocab.get_add_space_prefix(); + const bool shall_merge_spaces = vocab.get_remove_extra_whitespaces(); bool is_space_prepended = false; bool processing_non_ws = false; @@ -1006,7 +975,7 @@ struct llm_tokenizer_ugm { /* * This structure is a view wrapper for XOR-compressed double array (XCDA) * See Shunsuke Kanda (2018). Space- and Time-Efficient String Dictionaries. 
- * Eeach bit-packed entry contains: + * Each bit-packed entry contains: * - BASE array value in bits 10-30 * - LCHECK array value in bits 0-7 * - LEAF array value in bit 9 @@ -1043,13 +1012,21 @@ struct llm_tokenizer_ugm { size_t xcda_array_size; }; + // this structure stores the best tokenization so far at input_offset + struct best_tokenization { + llama_token token_id; + size_t input_offset; + double score_sum; + }; + struct normalization_result normalize_prefix(const std::string & input, size_t input_offset) { if (input_offset == input.size()) { return { &input[input_offset], 0, 0 }; } // if input prefix matches some user-defined token return this token as normalization result - auto user_defined_token_match = user_defined_token_matcher.get_longest_prefix(&input[input_offset], input.size() - input_offset); + auto user_defined_token_match = + tokenizer.user_defined_token_matcher.get_longest_prefix(&input[input_offset], input.size() - input_offset); if (user_defined_token_match.second > 0) { return { &input[input_offset], user_defined_token_match.second, user_defined_token_match.second }; } @@ -1057,8 +1034,8 @@ struct llm_tokenizer_ugm { size_t longest_prefix_length = 0; size_t longest_prefix_offset = 0; - if (xcda_array_size > 0) { - struct xcda_array_view xcda_view(xcda_array, xcda_array_size); + if (tokenizer.xcda_array_size > 0) { + struct xcda_array_view xcda_view(tokenizer.xcda_array, tokenizer.xcda_array_size); // Find the longest normalized sequence matching the input prefix by walking // the XOR-compressed compact double array (XCDA) starting from the root node @@ -1094,722 +1071,2720 @@ struct llm_tokenizer_ugm { if (longest_prefix_length > 0) { // we have a match, so return the replacement sequence - if (longest_prefix_offset >= prefix_replacements_size) { + if (longest_prefix_offset >= tokenizer.prefix_replacements_size) { throw std::runtime_error("Index out of array bounds in precompiled charsmap!"); } - const char * prefix_replacement = &prefix_replacements[longest_prefix_offset]; + const char * prefix_replacement = &(tokenizer.prefix_replacements)[longest_prefix_offset]; return { prefix_replacement, strlen(prefix_replacement), longest_prefix_length }; - } else { - // check if the input prefix contains a valid sequence of UTF-8 code units - try { - // if yes, return this sequence unmodified - size_t prefix_offset = input_offset; - unicode_cpt_from_utf8(input, prefix_offset); - return { &input[input_offset], prefix_offset - input_offset, prefix_offset - input_offset }; - } catch (std::invalid_argument & /*ex*/) { - // if no, consume 1 byte and return U+FFFD - REPLACEMENT CHARACTER - return { "\xEF\xBF\xBD", 3, 1 }; - } } - } - // escaped space symbol - U+2581 (Lower One Eighth Block) - const std::string escaped_space = "\xE2\x96\x81"; + // check if the input prefix contains a valid sequence of UTF-8 code units + try { + // if yes, return this sequence unmodified + size_t prefix_offset = input_offset; + unicode_cpt_from_utf8(input, prefix_offset); + return { &input[input_offset], prefix_offset - input_offset, prefix_offset - input_offset }; + } catch (std::invalid_argument & /*ex*/) { + // if no, consume 1 byte and return U+FFFD - REPLACEMENT CHARACTER + return { "\xEF\xBF\xBD", 3, 1 }; + } + } - const char * prefix_replacements = NULL; - size_t prefix_replacements_size = 0; + const llama_vocab & vocab; + const llm_tokenizer_ugm & tokenizer; +}; - const uint32_t * xcda_array = NULL; - size_t xcda_array_size = 0; +// +// RWKV tokenizer +// - struct naive_trie 
user_defined_token_matcher; +static std::vector llama_unescape_rwkv_token(const std::string & escaped) { + std::vector output; + output.reserve(escaped.size()); + + // Parser state + bool escaping = false; + uint8_t hex_remaining = 0; + uint8_t hex_acc = 0; + + // Step through characters, performing parsing + for (const char & c : escaped) { + // If we're parsing a hex code, interpret the next character + if (hex_remaining != 0) { + uint8_t value = (c >= 'a') ? (c - 'a' + 10) : (c - '0'); + hex_acc = (hex_acc << 4) + value; + + hex_remaining -= 1; + if (hex_remaining == 0) { + output.push_back(hex_acc); + hex_acc = 0; + } - // this structure stores the best tokenization so far at input_offset - struct best_tokenization { - llama_token token_id; - size_t input_offset; - float score_sum; - }; + continue; + } - float min_score = FLT_MAX; - float max_score = -FLT_MAX; + // If we got an escape character, interpret it + if (escaping) { + if (c == 't') { + output.push_back('\t'); + } else if (c == 'n') { + output.push_back('\n'); + } else if (c == 'r') { + output.push_back('\r'); + } else if (c == 'x') { + hex_remaining = 2; + } else { + output.push_back(c); + } - float unknown_token_score_penalty = 10.0; - float unknown_token_score; + escaping = false; + continue; + } - struct naive_trie token_matcher; -}; + if (c == '\\') { + escaping = true; + continue; + } -// -// (de-) tokenize -// + output.push_back(c); + } -typedef enum FRAGMENT_BUFFER_VARIANT_TYPE { - FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN, - FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT -} FRAGMENT_BUFFER_VARIANT_TYPE; + return output; +} -struct fragment_buffer_variant { - fragment_buffer_variant(llama_vocab::id _token) - : - type(FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN), - token(_token), - raw_text(_dummy), - offset(0), - length(0) {} +struct llm_tokenizer_rwkv : llm_tokenizer { + llm_tokenizer_rwkv(const llama_vocab & vocab) { + // RWKV supports arbitrary byte tokens, but the vocab struct only supports string tokens. + // For now, we decode the vocab here into the lookup we'll use for tokenization. 
- fragment_buffer_variant(const std::string & _raw_text, int64_t _offset, int64_t _length) - : - type(FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT), - token((llama_vocab::id) - 1), - raw_text(_raw_text), - offset(_offset), - length(_length){ - GGML_ASSERT(_offset >= 0); - GGML_ASSERT(_length >= 1); - GGML_ASSERT(offset + length <= raw_text.length()); + // build trie + for (uint32_t id = 0; id < vocab.n_tokens(); ++id) { + const auto & data = vocab.get_token_data(id); + const auto text = llama_unescape_rwkv_token(data.text); + token_matcher.insert((const char *) text.data(), text.size(), id); } + } - const FRAGMENT_BUFFER_VARIANT_TYPE type; - const llama_vocab::id token; - const std::string _dummy; - const std::string & raw_text; - const uint64_t offset; - const uint64_t length; + struct naive_trie token_matcher; }; -// #define PRETOKENIZERDEBUG +struct llm_tokenizer_rwkv_session { + llm_tokenizer_rwkv_session(const llama_vocab & vocab, const llm_tokenizer_rwkv & tokenizer) : vocab(vocab), tokenizer(tokenizer) {} + + void tokenize(const std::string & text, std::vector & output) { + uint32_t position = 0; + while (position < text.size()) { + const struct naive_trie * node = tokenizer.token_matcher.traverse(text[position]); + if (node == NULL) { + // no matching token found, add unknown token + output.push_back(vocab.token_unk()); + position += 1; + continue; + } -static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list & buffer, bool parse_special) { - // for each special token - for (const llama_vocab::id special_id : vocab.cache_special_tokens) { - const auto & data = vocab.id_to_token[special_id]; - const auto & special_token = data.text; + // traverse the trie to find the longest matching token + uint32_t token_id = 0; + uint32_t token_length = 0; + while (node != NULL) { + if (node->has_value) { + token_id = node->value; + token_length = position + 1; + } + node = node->traverse(text[++position]); + } - if (!parse_special && (data.attr & (LLAMA_TOKEN_ATTR_CONTROL | LLAMA_TOKEN_ATTR_UNKNOWN))) { - // Ignore control and unknown tokens when parse_special == false - continue; - // User-defined tokens are still pre-tokenized before everything else - // ref: https://github.com/huggingface/tokenizers/blob/fdd26ba9a3f0c133427aab0423888cbde91362d7/tokenizers/src/tokenizer/mod.rs#L726 - // This is mostly relevant for neox-style tokenizers (mpt, olmo, stablelm, etc.) 
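+            // the walk above keeps the deepest node that carried a value, making
+            // this a greedy longest-match scan over the raw bytes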
+ // add the longest matching token + output.push_back(token_id); + position = token_length; } + } - // for each text fragment - std::forward_list::iterator it = buffer.begin(); - while (it != buffer.end()) { - auto & fragment = (*it); +private: + const llama_vocab & vocab; + const llm_tokenizer_rwkv & tokenizer; +}; - // if a fragment is text ( not yet processed ) - if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { - auto & raw_text = fragment.raw_text; +struct llm_tokenizer_plamo2 : llm_tokenizer { + llm_tokenizer_plamo2(const llama_vocab & vocab) { + build(vocab); + } - auto raw_text_base_offset = fragment.offset; - auto raw_text_base_length = fragment.length; + void build(const llama_vocab & vocab) { + // Reset internal structures + tokens_.clear(); + bytes_.assign(256, 0); + to_suffix_id_.clear(); + table_.clear(); + + // Build token list and byte mapping + std::unordered_map suffix_to_score; + std::unordered_map token_to_id; + + for (size_t token_id = 0; token_id < vocab.n_tokens(); ++token_id) { + const auto & entry = vocab.get_token_data(token_id); + tokens_.push_back(entry.text); + token_to_id[entry.text] = static_cast(token_id); + + // Handle byte tokens + if (vocab.is_byte(token_id)) { + if (entry.text.length() == 6 && entry.text.substr(0, 3) == "<0x" && entry.text.back() == '>') { + std::string hex_str = entry.text.substr(3, 2); + int byte_val = std::stoi(hex_str, nullptr, 16); + bytes_[byte_val] = static_cast(token_id); + } + continue; + } - // loop over the text - while (true) { - // find the first occurrence of a given special token in this fragment - // passing offset argument only limit the "search area" but match coordinates - // are still relative to the source full raw_text - auto match = raw_text.find(special_token, raw_text_base_offset); + // Add token and all its suffixes to suffix_to_score + suffix_to_score[entry.text] = entry.score; - // no occurrences found, stop processing this fragment for a given special token - if (match == std::string::npos) break; + // Extract suffixes character by character (UTF-8 aware) + std::vector cpts = unicode_cpts_from_utf8(entry.text); + for (size_t i = 1; i < cpts.size(); ++i) { + std::string suffix; + for (size_t j = i; j < cpts.size(); ++j) { + suffix += unicode_cpt_to_utf8(cpts[j]); + } + if (suffix_to_score.find(suffix) == suffix_to_score.end()) { + suffix_to_score[suffix] = std::numeric_limits::quiet_NaN(); + } + } + } - // check if match is within bounds of offset <-> length - if (match + special_token.length() > raw_text_base_offset + raw_text_base_length) break; + // Check that all byte tokens are set + for (int i = 0; i < 256; ++i) { + if (bytes_[i] == 0) { + throw std::runtime_error("Byte token for <0x" + std::to_string(i) + "> is not set"); + } + } -#ifdef PRETOKENIZERDEBUG - LLAMA_LOG_WARN("FF: (%ld %ld %ld) '%s'\n", raw_text->length(), raw_text_base_offset, raw_text_base_length, raw_text->substr(raw_text_base_offset, raw_text_base_length).c_str()); -#endif - auto source = std::distance(buffer.begin(), it); + // Build suffix list in lexicographical order of reversed strings + std::vector suffixes; + for (const auto & pair : suffix_to_score) { + suffixes.push_back(pair.first); + } + suffixes.push_back(""); // Empty suffix - // if match is further than base offset - // then we have some text to the left of it - if (match > raw_text_base_offset) { - // left - const int64_t left_reminder_offset = raw_text_base_offset + 0; - int64_t left_reminder_length = match - raw_text_base_offset; + 
+        std::sort(suffixes.begin(), suffixes.end(), [](const std::string & a, const std::string & b) {
+            std::string rev_a(a.rbegin(), a.rend());
+            std::string rev_b(b.rbegin(), b.rend());
+            return rev_a < rev_b;
+        });
-            if (data.attr & LLAMA_TOKEN_ATTR_LSTRIP) {
-                while (left_reminder_length > 0 && isspace(raw_text[left_reminder_offset + left_reminder_length - 1])) {
-                    left_reminder_length--;
-                }
-            }
+        // Build suffix_to_id and to_suffix_id_
+        std::unordered_map<std::string, int32_t> suffix_to_id;
+        int32_t num_pieces = 0;
-            if (left_reminder_length > 0) {
-                buffer.emplace_after(it, raw_text, left_reminder_offset, left_reminder_length);
-                it++;
-            }
+        for (const auto & suffix : suffixes) {
+            suffix_to_id[suffix] = num_pieces;
+            if (!suffix.empty()) {
+                std::vector<uint32_t> cpts = unicode_cpts_from_utf8(suffix);
-#ifdef PRETOKENIZERDEBUG
-            LLAMA_LOG_WARN("FL: (%ld %ld) '%s'\n", left_reminder_offset, left_reminder_length, raw_text->substr(left_reminder_offset, left_reminder_length).c_str());
-#endif
+                std::string remaining;
+                for (size_t i = 1; i < cpts.size(); ++i) {
+                    remaining += unicode_cpt_to_utf8(cpts[i]);
+                }
+
+                int64_t piece_code = (static_cast<int64_t>(cpts[0]) << 32) | suffix_to_id[remaining];
+                to_suffix_id_[piece_code] = num_pieces;
+
+                // Count number of pieces for this suffix
+                int32_t pieces_for_suffix = 1; // sentinel row
+                for (int32_t piece_length = static_cast<int32_t>(cpts.size()); piece_length > 0; --piece_length) {
+                    std::string piece;
+                    for (int32_t i = 0; i < piece_length; ++i) {
+                        piece += unicode_cpt_to_utf8(cpts[i]);
                    }
+                    if (suffix_to_score.find(piece) != suffix_to_score.end()) {
+                        pieces_for_suffix++;
+                    }
+                }
+                num_pieces += pieces_for_suffix;
+            } else {
+                num_pieces++; // Empty suffix contributes one piece (sentinel row)
+            }
+        }
-            // special token
-            buffer.emplace_after(it, special_id);
-            it++;
+        // Build flattened table
+        table_.resize(num_pieces, std::vector<int32_t>(4, 0));
+        int32_t table_idx = 0;
-            // right
-            if (match + special_token.length() < raw_text_base_offset + raw_text_base_length) {
-                int64_t right_reminder_offset = match + special_token.length();
-                int64_t right_reminder_length = raw_text_base_length - ((match - raw_text_base_offset) + special_token.length());
+        for (const auto & suffix : suffixes) {
+            // Add all prefixes of the suffix to the table (in decreasing order of length)
+            std::vector<uint32_t> cpts = unicode_cpts_from_utf8(suffix);
+            for (int32_t piece_length = static_cast<int32_t>(cpts.size()); piece_length > 0; --piece_length) {
+                std::string piece;
+                for (int32_t i = 0; i < piece_length; ++i) {
+                    piece += unicode_cpt_to_utf8(cpts[i]);
+                }
-                if (data.attr & LLAMA_TOKEN_ATTR_RSTRIP) {
-                    while (right_reminder_length > 0 && isspace(raw_text[right_reminder_offset])) {
-                        right_reminder_offset++;
-                        right_reminder_length--;
-                    }
-                }
+                auto score_it = suffix_to_score.find(piece);
+                if (score_it == suffix_to_score.end()) {
+                    continue;
+                }
-                if (right_reminder_length > 0) {
-                    buffer.emplace_after(it, raw_text, right_reminder_offset, right_reminder_length);
-                    it++;
-                }
+                table_[table_idx][TABLE_PIECE_LENGTH] = piece_length;
+                auto token_it = token_to_id.find(piece);
+                table_[table_idx][TABLE_TOKEN_ID] = (token_it != token_to_id.end()) ? token_it->second : -1;
-#ifdef PRETOKENIZERDEBUG
-                LLAMA_LOG_WARN("FR: (%ld %ld) '%s'\n", right_reminder_offset, right_reminder_length, raw_text->substr(right_reminder_offset, right_reminder_length).c_str());
-#endif
+                float score = score_it->second;
+                table_[table_idx][TABLE_SCORE] = std::isfinite(score) ?
+                    static_cast<int32_t>(std::round(score * 1e4)) : INVALID_SCORE;
+                table_[table_idx][TABLE_PIECE_ID] = suffix_to_id[piece];
-                if (source == 0) {
-                    buffer.erase_after(buffer.before_begin());
-                } else {
-                    buffer.erase_after(std::next(buffer.begin(), (source-1)));
+                table_idx++;
+            }
+
+            // Add sentinel row
+            table_[table_idx][TABLE_PIECE_LENGTH] = 1;
+            table_[table_idx][TABLE_TOKEN_ID] = -1;
+            table_[table_idx][TABLE_SCORE] = UNKNOWN_SCORE;
+            table_idx++;
+        }
+    }
+
+    std::vector<llama_token> encode(const std::string & text) const {
+        std::vector<uint32_t> unicode_data = unicode_cpts_from_utf8(text);
+        // Skip the first code point if it is a BOM (Byte Order Mark)
+        if (!unicode_data.empty() && unicode_data[0] == 0xFEFF) {
+            unicode_data.erase(unicode_data.begin());
+        }
+
+        if (unicode_data.empty()) {
+            return {};
+        }
+
+        const size_t data_len = unicode_data.size();
+
+        // Initialize scores array (dynamic programming)
+        std::vector<int64_t> scores(data_len + 1, static_cast<int64_t>(1) << 60);
+        scores[data_len] = 0;
+
+        // Path array to track best tokenization
+        std::vector<std::vector<int32_t>> path(data_len + 1, std::vector<int32_t>(3, 0));
+
+        int32_t suffix_id = 0;
+
+        // Process from end to beginning
+        for (int i = static_cast<int>(data_len) - 1; i >= 0; --i) {
+            uint32_t c = unicode_data[i];
+
+            // Find next suffix ID
+            for (size_t p = suffix_id; p < table_.size(); ++p) {
+                int64_t piece_code = (static_cast<int64_t>(c) << 32) | table_[p][TABLE_PIECE_ID];
+                auto it = to_suffix_id_.find(piece_code);
+                suffix_id = (it != to_suffix_id_.end()) ? it->second : 0;
+
+                if (suffix_id > 0 || table_[p][TABLE_SCORE] == UNKNOWN_SCORE) {
+                    break;
+                }
+            }
+
+            // Update best path
+            for (size_t p = suffix_id; p < table_.size(); ++p) {
+                int32_t score = table_[p][TABLE_SCORE];
+                if (score > INVALID_SCORE) {
+                    int32_t piece_length = table_[p][TABLE_PIECE_LENGTH];
+                    int64_t s = scores[i + piece_length] - score;
+
+                    if (s < scores[i]) {
+                        scores[i] = s;
+                        path[i][PATH_TOKEN_LENGTH] = piece_length;
+                        path[i][PATH_TOKEN_ID] = table_[p][TABLE_TOKEN_ID];
+                        path[i][PATH_NUM_TOKENS] = path[i + piece_length][PATH_NUM_TOKENS] + 1;
+
+                        if (score == UNKNOWN_SCORE) {
+                            // Add UTF-8 byte count
+                            path[i][PATH_NUM_TOKENS] += (c >= 0x80) + (c >= 0x800) + (c >= 0x10000);
                        }
+                    }
+                }
-            // repeat for the right side
-            raw_text_base_offset = right_reminder_offset;
-            raw_text_base_length = right_reminder_length;
+                if (score == UNKNOWN_SCORE) {
+                    break;
+                }
+            }
+        }
-#ifdef PRETOKENIZERDEBUG
-            LLAMA_LOG_WARN("RR: (%ld %ld) '%s'\n", raw_text_base_offset, raw_text_base_length, raw_text->substr(raw_text_base_offset, raw_text_base_length).c_str());
-#endif
+        // Decode the best path
+        std::vector<llama_token> token_ids;
+        token_ids.reserve(path[0][PATH_NUM_TOKENS]);
+
+        int pos = 0;
+        while (pos < static_cast<int>(data_len)) {
+            if (path[pos][PATH_TOKEN_ID] >= 0) {
+                token_ids.push_back(path[pos][PATH_TOKEN_ID]);
+            } else {
+                // Fall back to byte tokens
+                uint32_t c = unicode_data[pos];
+                int s = 1 + (c >= 0x80) + (c >= 0x800) + (c >= 0x10000);
+
+                for (int i = 0; i < s; ++i) {
+                    uint8_t b;
+                    if (s == 1) {
+                        b = c;
                    } else {
-                if (source == 0) {
-                    buffer.erase_after(buffer.before_begin());
+                        if (i == 0) {
+                            b = (0xF00 >> s) & 0xFF;
                        } else {
-                    buffer.erase_after(std::next(buffer.begin(), (source-1)));
+                            b = 0x80;
                        }
-                break;
                    }
+                    token_ids.push_back(bytes_[b | ((c >> ((s - i - 1) * 6)) & 0x3F)]);
                }
            }
-            it++;
+
+            assert(path[pos][PATH_TOKEN_LENGTH] > 0);
+            pos += path[pos][PATH_TOKEN_LENGTH];
        }
+
+        return token_ids;
    }
-}
+
+private:
+    // Constants for table structure
+    static constexpr int32_t TABLE_PIECE_LENGTH = 0;
+    static constexpr int32_t TABLE_TOKEN_ID = 1;
+    static constexpr int32_t TABLE_SCORE = 2;
+    static constexpr int32_t TABLE_PIECE_ID = 3;
-std::vector<llama_token> llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool add_special, bool parse_special) {
-    std::vector<llama_token> output;
-    std::forward_list<fragment_buffer_variant> fragment_buffer;
+    // Constants for path array
+    static constexpr int32_t PATH_TOKEN_LENGTH = 0;
+    static constexpr int32_t PATH_TOKEN_ID = 1;
+    static constexpr int32_t PATH_NUM_TOKENS = 2;
-    if (!raw_text.empty()) {
-        fragment_buffer.emplace_front(raw_text, 0, raw_text.length());
-        tokenizer_st_partition(vocab, fragment_buffer, parse_special);
+    // Score constants
+    static constexpr int32_t INVALID_SCORE = -20000000;
+    static constexpr int32_t UNKNOWN_SCORE = -10000000;
+
+    // List of tokens in the vocabulary
+    std::vector<std::string> tokens_;
+
+    // Mapping from byte code point to token ID (for byte fallback)
+    std::vector<llama_token> bytes_;
+
+    // Mapping from piece code to suffix ID
+    std::unordered_map<int64_t, int32_t> to_suffix_id_;
+
+    // Flattened table representing the Trie structure
+    // Each row contains: [piece_length, token_id, score, piece_id]
+    std::vector<std::vector<int32_t>> table_;
+};
+
+struct llm_tokenizer_plamo2_session {
+    llm_tokenizer_plamo2_session(const llm_tokenizer_plamo2 & tokenizer) : tokenizer(tokenizer) {}
+
+    void tokenize(const std::string & text, std::vector<llama_token> & output) {
+        std::vector<llama_token> tokens = tokenizer.encode(text);
+        output.insert(output.end(), tokens.begin(), tokens.end());
    }
-    switch (vocab.type) {
-        case LLAMA_VOCAB_TYPE_SPM:
-            {
-                // OG tokenizer behavior:
-                //
-                // tokenizer.encode('<s>', add_special_tokens=True)  returns [1]
-                // tokenizer.encode('<s>', add_special_tokens=False) returns []
+private:
+    const llm_tokenizer_plamo2 & tokenizer;
+};
-                bool is_prev_special = true;  // prefix with space if first token
+//
+// impl
+//
-                if (add_special && vocab.tokenizer_add_bos) {
-                    GGML_ASSERT(vocab.special_bos_id != -1);
-                    output.push_back(vocab.special_bos_id);
-                    is_prev_special = true;
-                }
+typedef enum FRAGMENT_BUFFER_VARIANT_TYPE {
+    FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN,
+    FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT
+} FRAGMENT_BUFFER_VARIANT_TYPE;
-                for (const auto & fragment : fragment_buffer) {
-                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
-                        auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
+struct fragment_buffer_variant {
+    fragment_buffer_variant(llama_token _token)
+        :
+        type(FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN),
+        token(_token),
+        raw_text(_dummy),
+        offset(0),
+        length(0) {}
-                        // prefix with space if previous is special
-                        if (vocab.tokenizer_add_space_prefix && is_prev_special) {
-                            raw_text = " " + raw_text;
-                        }
+    fragment_buffer_variant(const std::string & _raw_text, int64_t _offset, int64_t _length)
+        :
+        type(FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT),
+        token((llama_token) - 1),
+        raw_text(_raw_text),
+        offset(_offset),
+        length(_length){
+            GGML_ASSERT(_offset >= 0);
+            GGML_ASSERT(_length >= 1);
+            GGML_ASSERT(offset + length <= raw_text.length());
+        }
-#ifdef PRETOKENIZERDEBUG
-                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
-#endif
-                        llm_tokenizer_spm tokenizer(vocab);
-                        llama_escape_whitespace(raw_text);
-                        tokenizer.tokenize(raw_text, output);
-                        is_prev_special = false;
-                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
-                        output.push_back(fragment.token);
-                        is_prev_special = true;
-                    }
-                }
+    const FRAGMENT_BUFFER_VARIANT_TYPE type;
+    const llama_token token;
+    const std::string _dummy;
+    const std::string & raw_text;
+    const uint64_t offset;
+    const uint64_t length;
+};
-                if (add_special && vocab.tokenizer_add_bos && output.size() >= 2 && output[1] == vocab.special_bos_id) {
-                    LLAMA_LOG_WARN(
-                        "%s: Added a BOS token to the prompt as specified by the model but the prompt "
-                        "also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. "
-                        "Are you sure this is what you want?\n", __FUNCTION__);
-                }
+struct llama_vocab::impl {
+    uint32_t n_token_types = 0; // for BERT-style token types
+
+    std::string tokenizer_model;
+    std::string tokenizer_pre;
+
+    enum llama_vocab_type type = LLAMA_VOCAB_TYPE_SPM;
+    enum llama_vocab_pre_type pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
+
+    int max_token_len = 0; // used for optimizing longest token search
+
+    // default LLaMA special tokens
+    // TODO: should we set all of these to LLAMA_TOKEN_NULL?
+    llama_token special_bos_id = 1;
+    llama_token special_eos_id = 2;
+    llama_token special_eot_id = LLAMA_TOKEN_NULL;
+    llama_token special_eom_id = LLAMA_TOKEN_NULL;
+    llama_token special_unk_id = 0;
+    llama_token special_sep_id = LLAMA_TOKEN_NULL;
+    llama_token special_pad_id = LLAMA_TOKEN_NULL;
+    llama_token special_mask_id = LLAMA_TOKEN_NULL;
+
+    llama_token linefeed_id = 13;
+
+    // fim tokens
+    llama_token special_fim_pre_id = LLAMA_TOKEN_NULL;
+    llama_token special_fim_suf_id = LLAMA_TOKEN_NULL;
+    llama_token special_fim_mid_id = LLAMA_TOKEN_NULL;
+    llama_token special_fim_pad_id = LLAMA_TOKEN_NULL;
+    llama_token special_fim_rep_id = LLAMA_TOKEN_NULL; // repo
+    llama_token special_fim_sep_id = LLAMA_TOKEN_NULL; // file separator
+
+    // tokenizer flags
+    bool add_space_prefix = false;
+    bool add_bos = false;
+    bool add_eos = false;
+    bool add_sep = false;
+    bool ignore_merges = false;
+    bool clean_spaces = false; // clean_up_tokenization_spaces
+    bool remove_extra_whitespaces = false;
+    bool escape_whitespaces = true;
+    bool treat_whitespace_as_suffix = false;
+
+    std::unordered_map<std::string, llama_token> token_to_id;
+    std::vector<token_data> id_to_token;
+
+    std::vector<llama_token> cache_special_tokens;
+    std::vector<std::string> cache_token_to_piece; // llama_token_to_piece(special = true);
+    struct pair_hash {
+        size_t operator()(const std::pair<std::string, std::string> & p) const {
+            return std::hash<std::string>{}(p.first) ^ //create some hash for pair
+                   (std::hash<std::string>{}(p.second) << 1);
+        }
+    };
+    std::unordered_map<std::pair<std::string, std::string>, int, pair_hash> bpe_ranks;
-                if (add_special && vocab.tokenizer_add_eos) {
-                    GGML_ASSERT(vocab.special_eos_id != -1);
-                    output.push_back(vocab.special_eos_id);
-                }
-            } break;
-        case LLAMA_VOCAB_TYPE_BPE:
-            {
-                llm_tokenizer_bpe tokenizer(vocab);
+    // set of all tokens that cause "end of generation"
+    std::set<llama_token> special_eog_ids;
-                if (add_special) {
-                    tokenizer.append_bos(output);
-                }
-                for (const auto & fragment : fragment_buffer) {
-                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
-                        auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
+    std::unique_ptr<llm_tokenizer> tokenizer;
-#ifdef PRETOKENIZERDEBUG
-                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
-#endif
-                        tokenizer.tokenize(raw_text, output);
-                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
-                        tokenizer.append(fragment.token, output);
-                    }
-                }
+    std::vector<char> precompiled_charsmap;
-                if (add_special) {
-                    tokenizer.append_eos(output);
-                    tokenizer.check_double_bos_eos(output);
+    impl(const llama_vocab & vocab) : vocab(vocab) {
+    }
+
+    ~impl() = default;
+
+    void load(llama_model_loader & ml, const LLM_KV & kv);
+
+    enum llama_vocab_type get_type() const;
+
+    std::string type_name() const;
+
+    bool is_normal      (llama_token id) const;
+    bool is_unknown     (llama_token id) const;
+    bool is_control     (llama_token id) const;
+    bool is_byte        (llama_token id) const;
+    bool is_user_defined(llama_token id) const;
+    bool is_unused      (llama_token id) const;
+    bool is_eog         (llama_token id) const;
+
+    uint8_t token_to_byte(llama_token id) const;
+
+    llama_token_attr token_get_attr(llama_token id) const;
+
+    void init_tokenizer(enum llama_vocab_type type);
+
+    void tokenizer_st_partition(std::forward_list<fragment_buffer_variant> & buffer, bool parse_special) const;
+
+    std::string token_to_piece_for_cache(
+            llama_token token,
+            bool special) const;
+
+
+    std::vector<llama_token> tokenize(
+            const std::string & raw_text,
+            bool add_special,
+            bool parse_special = false) const;
+
+    int32_t tokenize(
+            const char * text,
+            int32_t text_len,
+            llama_token * tokens,
+            int32_t n_tokens_max,
+            bool add_special,
+            bool parse_special) const;
+
+    // does not write null-terminator to buf
+    int32_t token_to_piece(
+            llama_token token,
+            char * buf,
+            int32_t length,
+            int32_t lstrip,
+            bool special) const;
+
+    // use cached data
+    const std::string & token_to_piece(llama_token token) const;
+
+    int32_t detokenize(
+            const llama_token * tokens,
+            int32_t n_tokens,
+            char * text,
+            int32_t text_len_max,
+            bool remove_special,
+            bool unparse_special) const;
+
+    std::string detokenize(
+            const std::vector<llama_token> & tokens,
+            bool special) const;
+
+    void print_info() const;
+
+private:
+    const llama_vocab & vocab;
+};
+
+void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
+    gguf_context * ctx = ml.meta;
+
+    // determine vocab type
+    {
+        ml.get_key(LLM_KV_TOKENIZER_MODEL, tokenizer_model);
+        ml.get_key(LLM_KV_TOKENIZER_PRE, tokenizer_pre, false);
+
+        ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, n_token_types, false);
+
+        if (tokenizer_model == "no_vocab" || tokenizer_model == "none") {
+            type = LLAMA_VOCAB_TYPE_NONE;
+
+            // default special tokens
+            special_bos_id = LLAMA_TOKEN_NULL;
+            special_eos_id = LLAMA_TOKEN_NULL;
+            special_unk_id = LLAMA_TOKEN_NULL;
+            special_sep_id = LLAMA_TOKEN_NULL;
+            special_pad_id = LLAMA_TOKEN_NULL;
+            special_mask_id = LLAMA_TOKEN_NULL;
+            linefeed_id = LLAMA_TOKEN_NULL;
+
+            // read vocab size from metadata
+            uint32_t n_tokens = 0;
+            if (ml.get_key(LLM_KV_VOCAB_SIZE, n_tokens, false)) {
+                LLAMA_LOG_WARN("%s: adding %u dummy tokens\n", __func__, n_tokens);
+                id_to_token.resize(n_tokens);
+            }
+
+            return;
+        }
+
+        if (tokenizer_model == "llama") {
+            type = LLAMA_VOCAB_TYPE_SPM;
+
+            // default special tokens
+            special_bos_id = 1;
+            special_eos_id = 2;
+            special_unk_id = 0;
+            special_sep_id = LLAMA_TOKEN_NULL;
+            special_pad_id = LLAMA_TOKEN_NULL;
+            special_mask_id = LLAMA_TOKEN_NULL;
+        } else if (tokenizer_model == "bert") {
+            type = LLAMA_VOCAB_TYPE_WPM;
+
+            // default special tokens
+            special_bos_id = 101;
+            special_eos_id = LLAMA_TOKEN_NULL;
+            special_unk_id = 100;
+            special_sep_id = 102;
+            special_pad_id = 0;
+            special_mask_id = 103;
+
+            add_sep = true;
+        } else if (tokenizer_model == "gpt2") {
+            type = LLAMA_VOCAB_TYPE_BPE;
+
+            // read bpe merges and populate bpe ranks
+            const int merges_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_MERGES).c_str());
+            if (merges_keyidx == -1) {
+                throw std::runtime_error("cannot find tokenizer merges in model file\n");
+            }
+
+            const int n_merges = gguf_get_arr_n(ctx, merges_keyidx);
+            for (int i = 0; i < n_merges; i++) {
+                const std::string word = gguf_get_arr_str(ctx, merges_keyidx, i);
+                //GGML_ASSERT(unicode_cpts_from_utf8(word).size() > 0);
+
+                std::string first;
+                std::string second;
+
+                const size_t pos = word.find(' ', 1);
+
+                if (pos != std::string::npos) {
+                    first  = word.substr(0, pos);
+                    second = word.substr(pos + 1);
                }
+
+                bpe_ranks.emplace(std::make_pair(first, second), i);
+            }
+
+            // default special tokens
+            special_bos_id = 11;
+            special_eos_id = 11;
+            special_unk_id = LLAMA_TOKEN_NULL;
+            special_sep_id = LLAMA_TOKEN_NULL;
+            special_pad_id = LLAMA_TOKEN_NULL;
+            special_mask_id = LLAMA_TOKEN_NULL;
+        } else if (tokenizer_model == "t5") {
+            type = LLAMA_VOCAB_TYPE_UGM;
+
+            // default special tokens
+            special_bos_id = LLAMA_TOKEN_NULL;
+            special_eos_id = 1;
+            special_unk_id = 2;
+            special_sep_id = LLAMA_TOKEN_NULL;
+            special_pad_id = 0;
+            special_mask_id = LLAMA_TOKEN_NULL;
+
+            const int precompiled_charsmap_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP).c_str());
+            if (precompiled_charsmap_keyidx != -1) {
+                const gguf_type pc_type = gguf_get_arr_type(ctx, precompiled_charsmap_keyidx);
+                GGML_ASSERT(pc_type == GGUF_TYPE_INT8 || pc_type == GGUF_TYPE_UINT8);
+
+                const size_t n_precompiled_charsmap = gguf_get_arr_n(ctx, precompiled_charsmap_keyidx);
+                const char * pc = (const char *) gguf_get_arr_data(ctx, precompiled_charsmap_keyidx);
+                precompiled_charsmap.assign(pc, pc + n_precompiled_charsmap);
+#ifdef IS_BIG_ENDIAN
+                // correct endianness of data in precompiled_charsmap binary blob
+                uint32_t * xcda_blob_size = (uint32_t *) &precompiled_charsmap[0];
+                *xcda_blob_size = __builtin_bswap32(*xcda_blob_size);
+                assert(*xcda_blob_size + sizeof(uint32_t) < n_precompiled_charsmap);
+                size_t xcda_array_size = *xcda_blob_size / sizeof(uint32_t);
+                uint32_t * xcda_array = (uint32_t *) &precompiled_charsmap[sizeof(uint32_t)];
+                for (size_t i = 0; i < xcda_array_size; ++i) {
+                    xcda_array[i] = __builtin_bswap32(xcda_array[i]);
                }
+#endif
+            }
+        } else if (tokenizer_model == "rwkv") {
+            type = LLAMA_VOCAB_TYPE_RWKV;
+
+            // default special tokens
+            special_bos_id = LLAMA_TOKEN_NULL;
+            special_eos_id = LLAMA_TOKEN_NULL;
+            special_unk_id = LLAMA_TOKEN_NULL;
+            special_sep_id = LLAMA_TOKEN_NULL;
+            special_pad_id = LLAMA_TOKEN_NULL;
+        } else if (tokenizer_model == "plamo2") {
+            type = LLAMA_VOCAB_TYPE_PLAMO2;
+
+            // PLaMo-2 default special tokens (these will be overridden by model config)
+            special_bos_id = 1;  // <|plamo:bos|>
+            special_eos_id = 2;  // <|plamo:eos|>
+            special_unk_id = 0;  // <|plamo:unk|>
+            special_sep_id = LLAMA_TOKEN_NULL;
+            special_pad_id = 3;  // <|plamo:pad|>
+            special_mask_id = LLAMA_TOKEN_NULL;
+        } else {
+            throw std::runtime_error(format("unknown tokenizer: '%s'", tokenizer_model.c_str()));
+        }
+
+        // for now, only BPE models have pre-tokenizers
+        if (type == LLAMA_VOCAB_TYPE_BPE) {
+            add_space_prefix = false;
+            clean_spaces = true;
+            if (tokenizer_pre.empty()) {
+                LLAMA_LOG_WARN("%s: missing pre-tokenizer type, using: 'default'\n", __func__);
+                LLAMA_LOG_WARN("%s:                                             \n", __func__);
+                LLAMA_LOG_WARN("%s: ************************************        \n", __func__);
+                LLAMA_LOG_WARN("%s: GENERATION QUALITY WILL BE DEGRADED!        \n", __func__);
+                LLAMA_LOG_WARN("%s: CONSIDER REGENERATING THE MODEL             \n", __func__);
+                LLAMA_LOG_WARN("%s: ************************************        \n", __func__);
+                LLAMA_LOG_WARN("%s:                                             \n", __func__);
+                pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
+            } else if (tokenizer_pre == "default") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
+            } else if (
+                    tokenizer_pre == "llama3" ||
+                    tokenizer_pre == "llama-v3" ||
+                    tokenizer_pre == "llama-bpe"||
+                    tokenizer_pre == "falcon3" ||
+                    tokenizer_pre == "falcon-h1" ||
+                    tokenizer_pre == "pixtral" ||
+                    tokenizer_pre == "midm-2.0" ||
+                    tokenizer_pre == "lfm2") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_LLAMA3;
+                ignore_merges = true;
+                add_bos = true;
+            } else if (
+                    tokenizer_pre == "deepseek-llm") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM;
+                clean_spaces = false;
+            } else if (
+                    tokenizer_pre == "deepseek-coder") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER;
+                clean_spaces = false;
+            } else if (
+                    tokenizer_pre == "deepseek-v3") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM;
+                clean_spaces = false;
+            } else if (
+                    tokenizer_pre == "falcon") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_FALCON;
+            } else if (
+                    tokenizer_pre == "mpt") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_MPT;
+            } else if (
+                    tokenizer_pre == "starcoder") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_STARCODER;
+            } else if (
+                    tokenizer_pre == "gpt-2" ||
+                    tokenizer_pre == "phi-2" ||
+                    tokenizer_pre == "jina-es" ||
+                    tokenizer_pre == "jina-de" ||
+                    tokenizer_pre == "gigachat" ||
+                    tokenizer_pre == "jina-v2-es" ||
+                    tokenizer_pre == "jina-v2-de" ||
+                    tokenizer_pre == "a.x-4.0" ||
+                    tokenizer_pre == "mellum") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_GPT2;
+            } else if (
+                    tokenizer_pre == "jina-v1-en" ||
+                    tokenizer_pre == "jina-v2-code" ||
+                    tokenizer_pre == "roberta-bpe") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_GPT2;
+                add_sep = true;
+            } else if (
+                    tokenizer_pre == "refact") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_REFACT;
+            } else if (
+                    tokenizer_pre == "command-r") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_COMMAND_R;
+                clean_spaces = false;
+            } else if (
+                    tokenizer_pre == "qwen2" ||
+                    tokenizer_pre == "deepseek-r1-qwen") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_QWEN2;
+                clean_spaces = false;
+            } else if (
+                    tokenizer_pre == "stablelm2") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_STABLELM2;
+            } else if (
+                    tokenizer_pre == "olmo") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_OLMO;
+            } else if (
+                    tokenizer_pre == "dbrx") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_DBRX;
+            } else if (
+                    tokenizer_pre == "smaug-bpe") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_SMAUG;
+            } else if (
+                    tokenizer_pre == "poro-chat") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_PORO;
+                clean_spaces = false;
+            } else if (
+                    tokenizer_pre == "glm4" ||
+                    tokenizer_pre == "chatglm-bpe") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_CHATGLM4;
+                special_bos_id = LLAMA_TOKEN_NULL;
+            } else if (
+                    tokenizer_pre == "viking") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_VIKING;
+                clean_spaces = false;
+            } else if (
+                    tokenizer_pre == "jais") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_JAIS;
+            } else if (
+                    tokenizer_pre == "tekken") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_TEKKEN;
+                clean_spaces = false;
+                ignore_merges = true;
+                add_bos = true;
+            } else if (
+                    tokenizer_pre == "smollm") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_SMOLLM;
+                clean_spaces = false;
+            } else if (
+                    tokenizer_pre == "codeshell") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_CODESHELL;
+            } else if (
+                    tokenizer_pre == "bloom") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_BLOOM;
+            } else if (
+                    tokenizer_pre == "gpt3-finnish") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH;
+            } else if (
+                    tokenizer_pre == "exaone") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_EXAONE;
+            } else if (
+                    tokenizer_pre == "exaone4") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_GPT2;
+            } else if (
+                    tokenizer_pre == "chameleon") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_CHAMELEON;
+                add_bos = true;
+                clean_spaces = false;
+            } else if (
+                    tokenizer_pre == "minerva-7b") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_MINERVA;
+            } else if (
+                    tokenizer_pre == "megrez") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_QWEN2;
+            } else if (
+                    tokenizer_pre == "gpt-4o" ||
+                    tokenizer_pre == "llama4") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_GPT4O;
+                clean_spaces = false;
+            } else if (
+                    tokenizer_pre == "superbpe") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_SUPERBPE;
+                clean_spaces = false;
+            } else if (
+                    tokenizer_pre == "trillion") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_TRILLION;
+                clean_spaces = false;
+            } else if (
+                    tokenizer_pre == "bailingmoe") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_BAILINGMOE;
+                clean_spaces = false;
+            } else if (
+                    tokenizer_pre == "seed-coder") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_SEED_CODER;
+                clean_spaces = false;
+            } else if (
+                    tokenizer_pre == "hunyuan") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_HUNYUAN;
+                clean_spaces = false;
+            } else if (
+                    tokenizer_pre == "hunyuan-dense") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_HUNYUAN_DENSE;
+                clean_spaces = false;
+            } else if (
+                    tokenizer_pre == "kimi-k2") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_KIMI_K2;
+                clean_spaces = false;
+            } else {
+                throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
+            }
+        } else if (type == LLAMA_VOCAB_TYPE_SPM) {
+            pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
+            add_space_prefix = true;
+            clean_spaces = false;
+            add_bos = true;
+            add_eos = false;
+        } else if (type == LLAMA_VOCAB_TYPE_WPM) {
+            pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
+            add_space_prefix = false;
+            clean_spaces = true;
+            add_bos = true;
+            add_eos = false;
+            add_sep = true;
+        } else if (type == LLAMA_VOCAB_TYPE_UGM) {
+            pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
+            add_bos = false;
+            add_eos = true;
+        } else if (type == LLAMA_VOCAB_TYPE_RWKV) {
+            pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
+            add_space_prefix = false;
+            clean_spaces = false;
+            add_bos = false;
+            add_eos = false;
+        } else {
+            pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
+        }
-                llm_tokenizer_wpm tokenizer(vocab);
+        ml.get_key(LLM_KV_TOKENIZER_ADD_PREFIX, add_space_prefix, false);
+        ml.get_key(LLM_KV_TOKENIZER_REMOVE_EXTRA_WS, remove_extra_whitespaces, false);
+    }
-                for (const auto & fragment : fragment_buffer) {
-                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
-                        auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
+    const int token_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_LIST).c_str());
+    if (token_idx == -1) {
+        throw std::runtime_error("cannot find tokenizer vocab in model file\n");
+    }
-#ifdef PRETOKENIZERDEBUG
-                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
-#endif
-                        tokenizer.tokenize(raw_text, output);
-                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
-                        output.push_back(fragment.token);
+    const float * scores = nullptr;
+    const int score_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_SCORES).c_str());
+    if (score_idx != -1) {
+        scores = (const float * ) gguf_get_arr_data(ctx, score_idx);
+    }
+
+    const int * toktypes = nullptr;
+    const int toktype_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_TOKEN_TYPE).c_str());
+    if (toktype_idx != -1) {
+        toktypes = (const int * ) gguf_get_arr_data(ctx, toktype_idx);
+    }
+
+    uint32_t n_tokens = gguf_get_arr_n(ctx, token_idx);
+    id_to_token.resize(n_tokens);
+
+    for (uint32_t i = 0; i < n_tokens; i++) {
+        std::string word = gguf_get_arr_str(ctx, token_idx, i);
+        if (word.empty()) {
+            LLAMA_LOG_WARN("%s: empty token at index %u\n", __func__, i);
+            word = "[EMPTY_" + std::to_string(i) + "]";
+        }
+
+        token_to_id[word] = i;
+        max_token_len = std::max(max_token_len, (int) word.size());
+
+        auto & token_data = id_to_token[i];
+        token_data.text = std::move(word);
+        token_data.score = scores ? scores[i] : 0.0f;
+        token_data.attr = LLAMA_TOKEN_ATTR_NORMAL;
+
+        if (toktypes) { //TODO: remove, required until per token attributes are available from GGUF file
+            switch(toktypes[i]) {
+                case LLAMA_TOKEN_TYPE_UNKNOWN:      token_data.attr = LLAMA_TOKEN_ATTR_UNKNOWN;      break;
+                case LLAMA_TOKEN_TYPE_UNUSED:       token_data.attr = LLAMA_TOKEN_ATTR_UNUSED;       break;
+                case LLAMA_TOKEN_TYPE_NORMAL:       token_data.attr = LLAMA_TOKEN_ATTR_NORMAL;       break;
+                case LLAMA_TOKEN_TYPE_CONTROL:      token_data.attr = LLAMA_TOKEN_ATTR_CONTROL;      break;
+                case LLAMA_TOKEN_TYPE_USER_DEFINED: token_data.attr = LLAMA_TOKEN_ATTR_USER_DEFINED; break;
+                case LLAMA_TOKEN_TYPE_BYTE:         token_data.attr = LLAMA_TOKEN_ATTR_BYTE;         break;
+                case LLAMA_TOKEN_TYPE_UNDEFINED:    token_data.attr = LLAMA_TOKEN_ATTR_UNDEFINED;    break;
+                default:                            token_data.attr = LLAMA_TOKEN_ATTR_UNDEFINED;    break;
+            }
+        }
+    }
+    GGML_ASSERT(id_to_token.size() == token_to_id.size());
+
+    init_tokenizer(type);
+
+    // determine the newline token: LLaMA "<0x0A>" == 10 == '\n', Falcon 193 == '\n'
+    if (type == LLAMA_VOCAB_TYPE_SPM) {
+        try {
+            linefeed_id = vocab.byte_to_token('\n');
+        } catch (const std::exception & e) {
+            LLAMA_LOG_WARN("%s: SPM vocabulary, but newline token not found: %s! Using special_pad_id instead.", __func__, e.what());
+            linefeed_id = special_pad_id;
+        }
+    } else if (type == LLAMA_VOCAB_TYPE_WPM) {
+        linefeed_id = special_pad_id;
+    } else if (type == LLAMA_VOCAB_TYPE_RWKV) {
+        const std::vector<int> ids = tokenize("\n", false);
+        GGML_ASSERT(!ids.empty() && "model vocab missing newline token");
+        linefeed_id = ids[0];
+    } else {
+        const std::vector<int> ids = tokenize("\n", false);
+
+        //GGML_ASSERT(!ids.empty() && "model vocab missing newline token");
+        if (ids.empty()) {
+            LLAMA_LOG_WARN("%s: model vocab missing newline token, using special_pad_id instead\n", __func__);
+            linefeed_id = special_pad_id;
+        } else {
+            linefeed_id = ids[0];
+        }
+    }
+
+    // special tokens
+    {
+        const std::vector<std::pair<enum llm_kv, int32_t &>> special_token_types = {
+            { LLM_KV_TOKENIZER_BOS_ID,     special_bos_id     },
+            { LLM_KV_TOKENIZER_EOS_ID,     special_eos_id     },
+            { LLM_KV_TOKENIZER_EOT_ID,     special_eot_id     },
+            { LLM_KV_TOKENIZER_EOM_ID,     special_eom_id     },
+            { LLM_KV_TOKENIZER_UNK_ID,     special_unk_id     },
+            { LLM_KV_TOKENIZER_SEP_ID,     special_sep_id     },
+            { LLM_KV_TOKENIZER_PAD_ID,     special_pad_id     },
+            { LLM_KV_TOKENIZER_MASK_ID,    special_mask_id    },
+            { LLM_KV_TOKENIZER_FIM_PRE_ID, special_fim_pre_id },
+            { LLM_KV_TOKENIZER_FIM_SUF_ID, special_fim_suf_id },
+            { LLM_KV_TOKENIZER_FIM_MID_ID, special_fim_mid_id },
+            { LLM_KV_TOKENIZER_FIM_PAD_ID, special_fim_pad_id },
+            { LLM_KV_TOKENIZER_FIM_REP_ID, special_fim_rep_id },
+            { LLM_KV_TOKENIZER_FIM_SEP_ID, special_fim_sep_id },
+
+            // deprecated
+            { LLM_KV_TOKENIZER_PREFIX_ID, special_fim_pre_id },
+            { LLM_KV_TOKENIZER_SUFFIX_ID, special_fim_suf_id },
+            { LLM_KV_TOKENIZER_MIDDLE_ID, special_fim_mid_id },
+        };
+
+        for (const auto & it : special_token_types) {
+            const std::string & key = kv(std::get<0>(it));
+            int32_t & id = std::get<1>(it);
+
+            uint32_t new_id;
+            if (!ml.get_key(std::get<0>(it), new_id, false)) {
+                continue;
+            }
+            if (new_id >= id_to_token.size()) {
+                LLAMA_LOG_WARN("%s: bad special token: '%s' = %u, using default id %d\n",
+                    __func__, key.c_str(), new_id, id);
+            } else {
+                id = new_id;
+            }
+        }
+
+        // Handle add_bos, add_eos and add_sep
+        {
+            bool temp = true;
+
+            if (ml.get_key(LLM_KV_TOKENIZER_ADD_BOS, temp, false)) {
+                add_bos = temp;
+            }
+            if (ml.get_key(LLM_KV_TOKENIZER_ADD_EOS, temp, false)) {
+                add_eos = temp;
+            }
+            if (ml.get_key(LLM_KV_TOKENIZER_ADD_SEP, temp, false)) {
+                add_sep = temp;
+            }
+        }
+
+        // auto-detect special tokens by text
+        // TODO: convert scripts should provide these tokens through the KV metadata LLM_KV_TOKENIZER_...
+        //       for now, we apply this workaround to find the tokens based on their text
+
+        for (const auto & t : token_to_id) {
+            // find EOT token: "<|eot_id|>", "<|im_end|>", "<end_of_turn>", etc.
+            if (special_eot_id == LLAMA_TOKEN_NULL) {
+                if (false
+                        || t.first == "<|eot_id|>"
+                        || t.first == "<|im_end|>"
+                        || t.first == "<|end|>"
+                        || t.first == "<end_of_turn>"
+                        || t.first == "<|endoftext|>"
+                        || t.first == "<EOT>"
+                        || t.first == "_<EOT>"
+                        || t.first == "<|end▁of▁sentence|>" // DeepSeek
+                        || t.first == "<end_of_utterance>" // smoldocling
+                        ) {
+                    special_eot_id = t.second;
+                    if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
+                        id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
                    }
                }
+            }
-                if (add_special) {
-                    GGML_ASSERT(vocab.special_sep_id != -1);
-                    output.push_back(vocab.special_sep_id);
+            // find EOM token: "<|eom_id|>"
+            if (special_eom_id == LLAMA_TOKEN_NULL) {
+                if (false
+                        || t.first == "<|eom_id|>"
+                        ) {
+                    special_eom_id = t.second;
+                    if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
+                        id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                    }
                }
-            } break;
-        case LLAMA_VOCAB_TYPE_UGM:
-            {
-                llm_tokenizer_ugm tokenizer(vocab);
+            }
-                if (add_special && vocab.tokenizer_add_bos != 0) {
-                    GGML_ASSERT(vocab.special_bos_id != -1);
-                    output.push_back(vocab.special_bos_id);
+            // find FIM_PRE token: "<|fim_prefix|>", "<fim-prefix>", "<PRE>", etc.
+            if (special_fim_pre_id == LLAMA_TOKEN_NULL) {
+                if (false
+                        || t.first == "<|fim_prefix|>"  // Qwen
+                        || t.first == "<fim-prefix>"
+                        || t.first == "<fim_prefix>"    // Granite
+                        || t.first == "<|fim▁begin|>" // DeepSeek
+                        || t.first == "<PRE>"
+                        || t.first == "▁<PRE>"          // CodeLlama
+                        || t.first == "<|code_prefix|>" // GLM-4.5
+                        ) {
+                    special_fim_pre_id = t.second;
+                    if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
+                        id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                    }
                 }
+            }
 
-                for (const auto & fragment : fragment_buffer) {
-                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
-                        auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
-#ifdef PRETOKENIZERDEBUG
-                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
-#endif
-                        tokenizer.tokenize(raw_text, output);
-                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
-                        output.push_back(fragment.token);
+            // find FIM_SUF token: "<|fim_suffix|>", "<fim-suffix>", "<SUF>", etc.
+            if (special_fim_suf_id == LLAMA_TOKEN_NULL) {
+                if (false
+                        || t.first == "<|fim_suffix|>" // Qwen
+                        || t.first == "<fim-suffix>"
+                        || t.first == "<fim_suffix>"   // Granite
+                        || t.first == "<|fim▁hole|>" // DeepSeek
+                        || t.first == "<SUF>"
+                        || t.first == "▁<SUF>"         // CodeLlama
+                        || t.first == "<|code_suffix|>" // GLM-4.5
+                        ) {
+                    special_fim_suf_id = t.second;
+                    if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
+                        id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
                     }
                 }
+            }
 
-                if (add_special && vocab.tokenizer_add_bos != 0 && output.size() >= 2 && output[1] == vocab.special_bos_id) {
-                    LLAMA_LOG_WARN(
-                        "%s: Added a BOS token to the prompt as specified by the model but the prompt "
-                        "also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. "
-                        "Are you sure this is what you want?\n", __FUNCTION__);
+            // find FIM_MID token: "<|fim_middle|>", "<fim-middle>", "<MID>", etc.
+            if (special_fim_mid_id == LLAMA_TOKEN_NULL) {
+                if (false
+                        || t.first == "<|fim_middle|>" // Qwen
+                        || t.first == "<fim-middle>"
+                        || t.first == "<fim_middle>"   // Granite
+                        || t.first == "<|fim▁end|>"  // DeepSeek
+                        || t.first == "<MID>"
+                        || t.first == "▁<MID>"         // CodeLlama
+                        || t.first == "<|code_middle|>" // GLM-4.5
+                        ) {
+                    special_fim_mid_id = t.second;
+                    if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
+                        id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                    }
+                }
+            }
+
+            // find FIM_PAD token: "<|fim_pad|>", "<fim-pad>", "<PAD>", etc.
+            if (special_fim_pad_id == LLAMA_TOKEN_NULL) {
+                if (false
+                        || t.first == "<|fim_pad|>" // Qwen
+                        || t.first == "<fim-pad>"
+                        || t.first == "<fim_pad>"   // Granite
+                        || t.first == "<PAD>"
+                        ) {
+                    special_fim_pad_id = t.second;
+                    if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
+                        id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                    }
+                }
+            }
+
+            // find FIM_REP token: "<|fim_repo|>", "<fim-repo>", "<REPO>", etc.
+            if (special_fim_rep_id == LLAMA_TOKEN_NULL) {
+                if (false
+                        || t.first == "<|fim_repo|>"  // Qwen
+                        || t.first == "<|repo_name|>"
+                        || t.first == "<fim-repo>"
+                        || t.first == "<REPO>"
+                        || t.first == "<reponame>"    // Granite
+                        ) {
+                    special_fim_rep_id = t.second;
+                    if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
+                        id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                    }
+                }
+            }
+
+            // find FIM_SEP token: "<|file_sep|>"
+            if (special_fim_sep_id == LLAMA_TOKEN_NULL) {
+                if (false
+                        || t.first == "<|file_sep|>" // Qwen
+                        ) {
+                    special_fim_sep_id = t.second;
+                    if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
+                        id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                    }
+                }
+            }
+        }
+
+        // maintain a list of tokens that cause end-of-generation
+        // this is currently determined based on the token text, which is obviously not ideal
+        // ref: https://github.com/ggerganov/llama.cpp/issues/9606
+        special_eog_ids.clear();
+
+        if (special_fim_pad_id != LLAMA_TOKEN_NULL && special_eog_ids.count(special_fim_pad_id) == 0) {
+            special_eog_ids.insert(special_fim_pad_id);
+        }
+
+        if (special_fim_rep_id != LLAMA_TOKEN_NULL && special_eog_ids.count(special_fim_rep_id) == 0) {
+            special_eog_ids.insert(special_fim_rep_id);
+        }
+
+        if (special_fim_sep_id != LLAMA_TOKEN_NULL && special_eog_ids.count(special_fim_sep_id) == 0) {
+            special_eog_ids.insert(special_fim_sep_id);
+        }
+
+        for (const auto & t : token_to_id) {
+            if (false
+                    || t.first == "<|eot_id|>"
+                    || t.first == "<|im_end|>"
+                    || t.first == "<|end|>"
+                    || t.first == "<|return|>" // o200k_harmony
+                    || t.first == "<|call|>"   // o200k_harmony
+                    || t.first == "<end_of_turn>"
+                    || t.first == "<|endoftext|>"
+                    || t.first == "<|eom_id|>"
+                    || t.first == "<EOT>"
+                    || t.first == "_<EOT>"
+                    || t.first == "<|end_of_text|>"
+                    || t.first == "<end_of_utterance>" // smoldocling
+               ) {
+                special_eog_ids.insert(t.second);
+                if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                    LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                            __func__, t.second, t.first.c_str());
+                    id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                }
+            } else {
+                // token is control, but not marked as EOG -> print a debug log
+                if (id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL && special_eog_ids.count(t.second) == 0) {
+                    LLAMA_LOG_DEBUG("%s: control token: %6d '%s' is not marked as EOG\n",
+                            __func__, t.second, t.first.c_str());
+                }
+            }
+        }
+
+        // @ngxson : quick hack for gpt-oss, always render these tokens
+        for (const auto & t : token_to_id) {
+            if (t.first == "<|channel|>" || t.first == "<|message|>" || t.first == "<|start|>") {
+                id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_USER_DEFINED;
+            }
+        }
+
+        // sanity checks
+        if (special_eos_id != LLAMA_TOKEN_NULL && special_eog_ids.count(special_eos_id) == 0) {
+            special_eog_ids.insert(special_eos_id);
+            LLAMA_LOG_WARN("%s: special_eos_id is not in special_eog_ids - the tokenizer config may be incorrect\n", __func__);
+        }
+
+        if (special_eot_id != LLAMA_TOKEN_NULL && special_eog_ids.count(special_eot_id) == 0) {
+            special_eog_ids.insert(special_eot_id);
+            LLAMA_LOG_WARN("%s: special_eot_id is not in special_eog_ids - the tokenizer config may be incorrect\n", __func__);
+        }
+
+        if (special_eom_id != LLAMA_TOKEN_NULL && special_eog_ids.count(special_eom_id) == 0) {
+            special_eog_ids.insert(special_eom_id);
+            LLAMA_LOG_WARN("%s: special_eom_id is not in special_eog_ids - the tokenizer config may be incorrect\n", __func__);
+        }
+
+        // TODO: workaround for o200k_harmony tokenizer: the "<|end|>" token should not be EOG
+        //       we don't have a good way to detect this, so for now, if we have "<|return|>" and "<|call|>" tokens,
+        //       we remove the "<|end|>" token from the EOG list
+        {
+            bool has_return = false;
+            bool has_call   = false;
+            bool has_end    = false;
+
+            llama_token end_id = LLAMA_TOKEN_NULL;
+
+            LLAMA_LOG_INFO("%s: printing all EOG tokens:\n", __func__);
+            for (auto tid : special_eog_ids) {
+                LLAMA_LOG_INFO("%s:   - %d ('%s')\n", __func__, tid, id_to_token[tid].text.c_str());
+
+                if (id_to_token[tid].text == "<|return|>") {
+                    has_return = true;
+                } else if (id_to_token[tid].text == "<|call|>") {
+                    has_call = true;
+                } else if (id_to_token[tid].text == "<|end|>") {
+                    has_end = true;
+                    end_id = tid;
+                }
+            }
+
+            if (has_return && has_call && has_end) {
+                special_eog_ids.erase(end_id);
+                LLAMA_LOG_WARN("%s: special_eog_ids contains both '<|return|>' and '<|call|>' tokens, removing '<|end|>' token from EOG list\n", __func__);
+            }
+        }
+    }
+
+    // build special tokens cache
+    {
+        for (llama_token id = 0; id < (llama_token) n_tokens; ++id) {
+            if (id_to_token[id].attr & (LLAMA_TOKEN_ATTR_CONTROL | LLAMA_TOKEN_ATTR_USER_DEFINED | LLAMA_TOKEN_ATTR_UNKNOWN)) {
+                cache_special_tokens.push_back(id);
+            }
+        }
+
+        std::sort(cache_special_tokens.begin(), cache_special_tokens.end(),
+            [&] (const llama_token a, const llama_token b) {
+                return id_to_token[a].text.size() > id_to_token[b].text.size();
+            }
+        );
+
+        LLAMA_LOG_INFO("%s: special tokens cache size = %u\n", __func__, (uint32_t) cache_special_tokens.size());
+    }
+
+    // build token to piece cache
+    {
+        size_t size_cache = 0;
+
+        std::vector<std::string> cache(n_tokens);
+
+        for (uint32_t id = 0; id < n_tokens; ++id) {
+            cache[id] = token_to_piece_for_cache(id, true);
+
+            size_cache += cache[id].size();
+        }
+
+        std::swap(cache_token_to_piece, cache);
+
+        LLAMA_LOG_INFO("%s: token to piece cache size = %.4f MB\n", __func__, size_cache / 1024.0 / 1024.0);
+    }
+
+    // Handle per token attributes
+    //NOTE: Each model customizes per token attributes.
+    //NOTE: Per token attributes are missing from the GGUF file.
+    //TODO: Extract attributes from GGUF file.
+    {
+        auto _contains_any = [] (const std::string & str, const std::vector<std::string> & substrs) -> bool {
+            for (const auto & substr : substrs) {
+                if (str.find(substr) != std::string::npos) {
+                    return true;
+                }
+            }
+            return false;
+        };
+
+        auto _set_tokenid_attr = [&] (const llama_token id, llama_token_attr attr, bool value) {
+            uint32_t current = id_to_token.at(id).attr;
+            current = value ? (current | attr) : (current & ~attr);
+            id_to_token[id].attr = (llama_token_attr) current;
+        };
+
+        auto _set_token_attr = [&] (const std::string & token, llama_token_attr attr, bool value) {
+            _set_tokenid_attr(token_to_id.at(token), attr, value);
+        };
+
+        std::string model_name;
+        std::string tokenizer_pre;
+        std::string general_arch;
+
+        ml.get_key(LLM_KV_GENERAL_NAME,  model_name,    false);
+        ml.get_key(LLM_KV_TOKENIZER_PRE, tokenizer_pre, false);
+        ml.get_key(LLM_KV_GENERAL_ARCHITECTURE, general_arch, false);
+
+        // model name to lowercase
+        std::transform(model_name.begin(), model_name.end(), model_name.begin(),
+            [] (const std::string::value_type x) {
+                return std::tolower(x);
+            }
+        );
+
+        // set attributes by model/tokenizer/architecture name
+        if (false
+                || _contains_any(tokenizer_pre, {"jina-v2-de", "jina-v2-es", "jina-v2-code"})
+                || _contains_any(general_arch, {"nomic-bert-moe"})
+           ) {
+            if (token_to_id.count("<mask>") == 0) {
+                LLAMA_LOG_WARN("%s: Mask token is missing in vocab, please reconvert model!\n", __func__);
+            } else {
+                _set_token_attr("<mask>", LLAMA_TOKEN_ATTR_LSTRIP, true);
+            }
+        } else if (_contains_any(model_name, {"phi-3", "phi3"})) {
+            for (auto id : cache_special_tokens) {
+                _set_tokenid_attr(id, LLAMA_TOKEN_ATTR_RSTRIP, true);
+            }
+            for (const auto * token : {"</s>"}) {
+                _set_token_attr(token, LLAMA_TOKEN_ATTR_RSTRIP, true);
+            }
+            for (const auto * token : {"<unk>", "<s>", "<|endoftext|>"}) {
+                _set_token_attr(token, LLAMA_TOKEN_ATTR_RSTRIP, false);
+            }
+        }
+    }
+}
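
Note: the pair_hash defined in llama_vocab::impl above exists only so that (first, second) merge pairs can key an unordered_map. A minimal, self-contained sketch of the lookup pattern — the rank_of helper and the sample merges are illustrative, not part of this patch:

    #include <functional>
    #include <iostream>
    #include <string>
    #include <unordered_map>
    #include <utility>

    // same hash combination as llama_vocab::impl::pair_hash
    struct pair_hash {
        size_t operator()(const std::pair<std::string, std::string> & p) const {
            return std::hash<std::string>{}(p.first) ^ (std::hash<std::string>{}(p.second) << 1);
        }
    };

    int main() {
        // hypothetical merge table: a lower rank means the pair merges earlier in BPE
        std::unordered_map<std::pair<std::string, std::string>, int, pair_hash> bpe_ranks = {
            {{"t", "h"}, 0},
            {{"th", "e"}, 1},
        };

        auto rank_of = [&](const std::string & a, const std::string & b) {
            auto it = bpe_ranks.find({a, b});
            return it == bpe_ranks.end() ? -1 : it->second; // -1: pair never merges
        };

        std::cout << rank_of("t", "h") << "\n"; // 0
        std::cout << rank_of("e", "x") << "\n"; // -1
    }
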
+
+enum llama_vocab_type llama_vocab::impl::get_type() const {
+    return type;
+}
+
+std::string llama_vocab::impl::type_name() const {
+    switch (type) {
+        case LLAMA_VOCAB_TYPE_NONE:   return "no vocab";
+        case LLAMA_VOCAB_TYPE_SPM:    return "SPM";
+        case LLAMA_VOCAB_TYPE_BPE:    return "BPE";
+        case LLAMA_VOCAB_TYPE_WPM:    return "WPM";
+        case LLAMA_VOCAB_TYPE_UGM:    return "UGM";
+        case LLAMA_VOCAB_TYPE_RWKV:   return "RWKV";
+        case LLAMA_VOCAB_TYPE_PLAMO2: return "PLaMo2";
+        default:                      return "unknown";
+    }
+}
+
+bool llama_vocab::impl::is_normal(llama_token id) const {
+    GGML_ASSERT(type != LLAMA_VOCAB_TYPE_NONE);
+    return id_to_token[id].attr & LLAMA_TOKEN_ATTR_NORMAL;
+}
+
+bool llama_vocab::impl::is_unknown(llama_token id) const {
+    GGML_ASSERT(type != LLAMA_VOCAB_TYPE_NONE);
+    return id_to_token[id].attr & LLAMA_TOKEN_ATTR_UNKNOWN;
+}
+
+bool llama_vocab::impl::is_control(llama_token id) const {
+    GGML_ASSERT(type != LLAMA_VOCAB_TYPE_NONE);
+    return id_to_token[id].attr & LLAMA_TOKEN_ATTR_CONTROL;
+}
+
+bool llama_vocab::impl::is_byte(llama_token id) const {
+    GGML_ASSERT(type != LLAMA_VOCAB_TYPE_NONE);
+    return id_to_token[id].attr & LLAMA_TOKEN_ATTR_BYTE;
+}
+
+bool llama_vocab::impl::is_user_defined(llama_token id) const {
+    GGML_ASSERT(type != LLAMA_VOCAB_TYPE_NONE);
+    return id_to_token[id].attr & LLAMA_TOKEN_ATTR_USER_DEFINED;
+}
+
+bool llama_vocab::impl::is_unused(llama_token id) const {
+    GGML_ASSERT(type != LLAMA_VOCAB_TYPE_NONE);
+    return id_to_token[id].attr & LLAMA_TOKEN_ATTR_UNUSED;
+}
+
+bool llama_vocab::impl::is_eog(llama_token id) const {
+    return id != LLAMA_TOKEN_NULL && special_eog_ids.count(id) > 0;
+}
+
+uint8_t llama_vocab::impl::token_to_byte(llama_token id) const {
+    GGML_ASSERT(get_type() != LLAMA_VOCAB_TYPE_NONE);
+    GGML_ASSERT(is_byte(id));
+    const auto & token_data = id_to_token.at(id);
+    switch (get_type()) {
+        case LLAMA_VOCAB_TYPE_SPM:
+        case LLAMA_VOCAB_TYPE_UGM: {
+            auto buf = token_data.text.substr(3, 2);
+            return strtol(buf.c_str(), NULL, 16);
+        }
+        case LLAMA_VOCAB_TYPE_BPE: {
+            GGML_ABORT("fatal error");
+        }
+        case LLAMA_VOCAB_TYPE_WPM: {
+            GGML_ABORT("fatal error");
+        }
+        default:
+            GGML_ABORT("fatal error");
+    }
+}
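
Note: the substr(3, 2) in token_to_byte() assumes SPM/UGM byte tokens have the exact text form <0xNN>. A small self-contained check of that parse (a sketch under that assumption, not part of the patch):

    #include <cassert>
    #include <cstdint>
    #include <cstdlib>
    #include <string>

    // mirrors the SPM/UGM branch of token_to_byte(): "<0x0A>" -> 0x0A
    static uint8_t byte_token_to_byte(const std::string & text) {
        // skip "<0x" (3 chars), take the two hex digits
        const std::string hex = text.substr(3, 2);
        return (uint8_t) strtol(hex.c_str(), nullptr, 16);
    }

    int main() {
        assert(byte_token_to_byte("<0x0A>") == 0x0A); // newline
        assert(byte_token_to_byte("<0xFF>") == 0xFF);
    }
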
+
+llama_token_attr llama_vocab::impl::token_get_attr(llama_token id) const {
+    GGML_ASSERT(type != LLAMA_VOCAB_TYPE_NONE);
+    return id_to_token.at(id).attr;
+}
+
+void llama_vocab::impl::init_tokenizer(enum llama_vocab_type type) {
+    LLAMA_LOG_DEBUG("%s: initializing tokenizer for type %d\n", __func__, type);
+
+    switch (type) {
+        case LLAMA_VOCAB_TYPE_SPM:
+            tokenizer = std::make_unique<llm_tokenizer_spm>(vocab);
+            break;
+        case LLAMA_VOCAB_TYPE_BPE:
+            tokenizer = std::make_unique<llm_tokenizer_bpe>(vocab);
+            break;
+        case LLAMA_VOCAB_TYPE_WPM:
+            tokenizer = std::make_unique<llm_tokenizer_wpm>(vocab);
+            break;
+        case LLAMA_VOCAB_TYPE_UGM:
+            tokenizer = std::make_unique<llm_tokenizer_ugm>(vocab, precompiled_charsmap);
+            break;
+        case LLAMA_VOCAB_TYPE_RWKV:
+            tokenizer = std::make_unique<llm_tokenizer_rwkv>(vocab);
+            break;
+        case LLAMA_VOCAB_TYPE_PLAMO2:
+            tokenizer = std::make_unique<llm_tokenizer_plamo2>(vocab);
+            break;
+        default:
+            GGML_ABORT("unsupported vocab type");
+    }
+}
+
+//
+// (de-) tokenize
+//
+
+// #define PRETOKENIZERDEBUG
+
+void llama_vocab::impl::tokenizer_st_partition(std::forward_list<fragment_buffer_variant> & buffer, bool parse_special) const {
+    // for each special token
+    for (const llama_token special_id : cache_special_tokens) {
+        const auto & data = vocab.get_token_data(special_id);
+        const auto & text = data.text;
+
+        if (!parse_special && (data.attr & (LLAMA_TOKEN_ATTR_CONTROL | LLAMA_TOKEN_ATTR_UNKNOWN))) {
+            // Ignore control and unknown tokens when parse_special == false
+            continue;
+            // User-defined tokens are still pre-tokenized before everything else
+            // ref: https://github.com/huggingface/tokenizers/blob/fdd26ba9a3f0c133427aab0423888cbde91362d7/tokenizers/src/tokenizer/mod.rs#L726
+            // This is mostly relevant for neox-style tokenizers (mpt, olmo, stablelm, etc.)
+        }
+
+        // for each text fragment
+        std::forward_list<fragment_buffer_variant>::iterator it = buffer.begin();
+        while (it != buffer.end()) {
+            auto & fragment = (*it);
+
+            // if a fragment is text ( not yet processed )
+            if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
+                const auto & raw_text = fragment.raw_text;
+
+                auto raw_text_base_offset = fragment.offset;
+                auto raw_text_base_length = fragment.length;
+
+                // loop over the text
+                while (true) {
+                    // find the first occurrence of a given special token in this fragment
+                    //  passing offset argument only limit the "search area" but match coordinates
+                    //  are still relative to the source full raw_text
+                    //  string_view begins at pos 0 for the same reason
+                    auto match = std::string_view(raw_text.data(), raw_text_base_offset + raw_text_base_length).find(text, raw_text_base_offset);
+
+                    // no occurrences found, stop processing this fragment for a given special token
+                    if (match == std::string::npos) break;
+
+#ifdef PRETOKENIZERDEBUG
+                    LLAMA_LOG_WARN("FF: (%ld %ld %ld) '%s'\n", raw_text->length(), raw_text_base_offset, raw_text_base_length, raw_text->substr(raw_text_base_offset, raw_text_base_length).c_str());
+#endif
+                    auto source = std::distance(buffer.begin(), it);
+
+                    // if match is further than base offset
+                    //  then we have some text to the left of it
+                    if (match > raw_text_base_offset) {
+                        // left
+                        const int64_t left_reminder_offset = raw_text_base_offset + 0;
+                        int64_t left_reminder_length = match - raw_text_base_offset;
+
+                        if (data.attr & LLAMA_TOKEN_ATTR_LSTRIP) {
+                            while (left_reminder_length > 0 && isspace(raw_text[left_reminder_offset + left_reminder_length - 1])) {
+                                left_reminder_length--;
+                            }
+                        }
+
+                        if (left_reminder_length > 0) {
+                            buffer.emplace_after(it, raw_text, left_reminder_offset, left_reminder_length);
+                            it++;
+                        }
+
+#ifdef PRETOKENIZERDEBUG
+                        LLAMA_LOG_WARN("FL: (%ld %ld) '%s'\n", left_reminder_offset, left_reminder_length, raw_text->substr(left_reminder_offset, left_reminder_length).c_str());
+#endif
+                    }
+
+                    // special token
+                    buffer.emplace_after(it, special_id);
+                    it++;
+
+                    // right
+                    if (match + text.length() < raw_text_base_offset + raw_text_base_length) {
+                        int64_t right_reminder_offset = match + text.length();
+                        int64_t right_reminder_length = raw_text_base_length - ((match - raw_text_base_offset) + text.length());
+
+                        if (data.attr & LLAMA_TOKEN_ATTR_RSTRIP) {
+                            while (right_reminder_length > 0 && isspace(raw_text[right_reminder_offset])) {
+                                right_reminder_offset++;
+                                right_reminder_length--;
+                            }
+                        }
+
+                        if (right_reminder_length > 0) {
+                            buffer.emplace_after(it, raw_text, right_reminder_offset, right_reminder_length);
+                            it++;
+                        }
+
+#ifdef PRETOKENIZERDEBUG
+                        LLAMA_LOG_WARN("FR: (%ld %ld) '%s'\n", right_reminder_offset, right_reminder_length, raw_text->substr(right_reminder_offset, right_reminder_length).c_str());
+#endif
+
+                        if (source == 0) {
+                            buffer.erase_after(buffer.before_begin());
+                        } else {
+                            buffer.erase_after(std::next(buffer.begin(), (source - 1)));
+                        }
+
+                        // repeat for the right side
+                        raw_text_base_offset = right_reminder_offset;
+                        raw_text_base_length = right_reminder_length;
+
+#ifdef PRETOKENIZERDEBUG
+                        LLAMA_LOG_WARN("RR: (%ld %ld) '%s'\n", raw_text_base_offset, raw_text_base_length, raw_text->substr(raw_text_base_offset, raw_text_base_length).c_str());
+#endif
+                    } else {
+                        if (source == 0) {
+                            buffer.erase_after(buffer.before_begin());
+                        } else {
+                            buffer.erase_after(std::next(buffer.begin(), (source - 1)));
+                        }
+                        break;
+                    }
                 }
+            }
+            it++;
+        }
+    }
+}
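
Note: to see what tokenizer_st_partition() produces, here is a deliberately simplified, self-contained sketch of the same splitting idea — a single special token, no LSTRIP/RSTRIP or token-attribute handling, so it is an illustration rather than the actual implementation:

    #include <iostream>
    #include <string>
    #include <vector>

    // each fragment is either raw text or a matched special token
    struct frag {
        bool        is_token;
        std::string text; // raw text, or the special token's text
    };

    // split `input` on every occurrence of `special`, keeping the matches
    static std::vector<frag> partition(const std::string & input, const std::string & special) {
        std::vector<frag> out;
        size_t pos = 0;
        while (true) {
            const size_t match = input.find(special, pos);
            if (match == std::string::npos) break;
            if (match > pos) out.push_back({false, input.substr(pos, match - pos)});
            out.push_back({true, special});
            pos = match + special.length();
        }
        if (pos < input.length()) out.push_back({false, input.substr(pos)});
        return out;
    }

    int main() {
        // "hello<|im_end|>world" -> RAW "hello", TOKEN "<|im_end|>", RAW "world"
        for (const auto & f : partition("hello<|im_end|>world", "<|im_end|>")) {
            std::cout << (f.is_token ? "TOKEN " : "RAW   ") << f.text << "\n";
        }
    }
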
+
+// NOTE: avoid ever using this except for building the token_to_piece caches
+std::string llama_vocab::impl::token_to_piece_for_cache(llama_token token, bool special) const {
+    std::string piece;
+    piece.resize(piece.capacity());  // using string internal cache
+    const int n_chars = vocab.token_to_piece(token, &piece[0], piece.size(), 0, special);
+    if (n_chars < 0) {
+        piece.resize(-n_chars);
+        int check = vocab.token_to_piece(token, &piece[0], piece.size(), 0, special);
+        GGML_ASSERT(check == -n_chars);
+    }
+    else {
+        piece.resize(n_chars);
+    }
+
+    return piece;
+}
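
Note: the negative return value handled above is the buffer-size negotiation convention of token_to_piece(): a call whose output does not fit returns minus the required length. A minimal sketch of the resize-and-retry pattern against a hypothetical stand-in writer (write_piece is illustrative, not the real API):

    #include <cstring>
    #include <iostream>
    #include <string>

    // stand-in for token_to_piece(): writes `piece` into buf if it fits,
    // otherwise returns the negated required length
    static int write_piece(const std::string & piece, char * buf, int len) {
        if ((int) piece.size() > len) return -(int) piece.size();
        memcpy(buf, piece.data(), piece.size());
        return (int) piece.size();
    }

    int main() {
        const std::string piece = "a fairly long token piece";

        std::string out;
        out.resize(8); // deliberately too small
        int n = write_piece(piece, &out[0], (int) out.size());
        if (n < 0) {
            out.resize(-n);                                    // grow to the required size
            n = write_piece(piece, &out[0], (int) out.size()); // retry must succeed
        }
        out.resize(n);
        std::cout << out << "\n";
    }
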
+
+static void llama_escape_whitespace(std::string & text) {
+    replace_all(text, " ", "\xe2\x96\x81");
+}
+
+static void llama_unescape_whitespace(std::string & word) {
+    replace_all(word, "\xe2\x96\x81", " ");
+}
+
+static std::string llama_decode_text(const std::string & text) {
+    std::string decoded_text;
+
+    const auto cpts = unicode_cpts_from_utf8(text);
+    for (const auto cpt : cpts) {
+        const auto utf8 = unicode_cpt_to_utf8(cpt);
+        try {
+            decoded_text += unicode_utf8_to_byte(utf8);
+        } catch (const std::out_of_range & /*e*/) {
+            decoded_text += "[UNK_BYTE_0x";
+            for (const auto c : utf8) {
+                decoded_text += format("%02x", (uint8_t) c);
+            }
+            decoded_text += text + "]";
+        }
+    }
+
+    return decoded_text;
+}
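
Note: when unicode_utf8_to_byte() has no mapping for a code point, llama_decode_text() above emits a readable [UNK_BYTE_0x...] marker instead of failing. A self-contained sketch of that formatting, using snprintf in place of llama.cpp's format() helper:

    #include <cstdio>
    #include <iostream>
    #include <string>

    // builds the same "[UNK_BYTE_0x..<text>]" marker as the catch branch above
    static std::string unk_byte_marker(const std::string & utf8, const std::string & text) {
        std::string out = "[UNK_BYTE_0x";
        for (const char c : utf8) {
            char hex[3];
            snprintf(hex, sizeof(hex), "%02x", (unsigned char) c);
            out += hex;
        }
        out += text + "]";
        return out;
    }

    int main() {
        std::cout << unk_byte_marker("\xe2\x96\x81", "▁") << "\n"; // [UNK_BYTE_0xe29681▁]
    }
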
+
+std::vector<llama_token> llama_vocab::impl::tokenize(
+        const std::string & raw_text,
+        bool add_special,
+        bool parse_special) const {
+    GGML_ASSERT(tokenizer && "Tokenizer not initialized. Call llama_vocab::init_tokenizer() first.");
+
+    std::vector<llama_token> output;
+    std::forward_list<fragment_buffer_variant> fragment_buffer;
+
+    if (!raw_text.empty()) {
+        fragment_buffer.emplace_front(raw_text, 0, raw_text.length());
+        tokenizer_st_partition(fragment_buffer, parse_special);
+    }
+
+    switch (get_type()) {
+        case LLAMA_VOCAB_TYPE_SPM:
+            {
+                // OG tokenizer behavior:
+                //
+                // tokenizer.encode('', add_special_tokens=True)  returns [1]
+                // tokenizer.encode('', add_special_tokens=False) returns []
+
+                bool is_prev_special = true;  // prefix with space if first token
+
+                if (add_special && add_bos) {
+                    GGML_ASSERT(special_bos_id != LLAMA_TOKEN_NULL);
+                    output.push_back(special_bos_id);
+                    is_prev_special = true;
+                }
+
+                for (const auto & fragment : fragment_buffer) {
+                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
+                        std::string text;
+
+                        // prefix with space if previous is special
+                        if (add_space_prefix && is_prev_special) {
+                            text = ' ';
+                        }
+
+                        text += fragment.raw_text.substr(fragment.offset, fragment.length);
+
+#ifdef PRETOKENIZERDEBUG
+                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", text.length(), fragment.offset, fragment.length, text.c_str());
+#endif
+                        llama_escape_whitespace(text);
+                        llm_tokenizer_spm_session session(vocab);
+                        session.tokenize(text, output);
+                        is_prev_special = false;
+                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
+                        output.push_back(fragment.token);
+                        is_prev_special = true;
+                    }
+                }
+
+                if (add_special && add_bos && output.size() >= 2 && output[1] == special_bos_id) {
+                    LLAMA_LOG_WARN(
+                        "%s: Added a BOS token to the prompt as specified by the model but the prompt "
+                        "also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. "
+                        "Are you sure this is what you want?\n", __FUNCTION__);
+                }
+
+                if (add_special && add_eos) {
+                    GGML_ASSERT(special_eos_id != LLAMA_TOKEN_NULL);
+                    output.push_back(special_eos_id);
+                }
+            } break;
+        case LLAMA_VOCAB_TYPE_BPE:
+            {
+                llm_tokenizer_bpe_session session(vocab, *static_cast<const llm_tokenizer_bpe *>(tokenizer.get()));
+                // the session calls methods that do not exist on the base llm_tokenizer,
+                // so we downcast to the concrete BPE tokenizer here
+                if (add_special) {
+                    session.append_bos(output);
+                }
+                for (const auto & fragment : fragment_buffer) {
+                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
+                        std::string text = fragment.raw_text.substr(fragment.offset, fragment.length);
+
+#ifdef PRETOKENIZERDEBUG
+                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", text.length(), fragment.offset, fragment.length, text.c_str());
+#endif
+                        session.tokenize(text, output);
+                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
+                        session.append(fragment.token, output);
+                    }
+                }
+
+                if (add_special) {
+                    session.append_eos(output);
+                    session.check_double_bos_eos(output);
+                }
+            } break;
+        case LLAMA_VOCAB_TYPE_WPM:
+            {
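+                // WPM (BERT-style) vocabs bracket the sequence with BOS/SEP, typically [CLS] ... [SEP]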
+                if (add_special) {
+                    GGML_ASSERT(special_bos_id != LLAMA_TOKEN_NULL);
+                    output.push_back(special_bos_id);
+                }
+
+                llm_tokenizer_wpm_session session(vocab);
+
+                for (const auto & fragment : fragment_buffer) {
+                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
+                        std::string text = fragment.raw_text.substr(fragment.offset, fragment.length);
+
+#ifdef PRETOKENIZERDEBUG
+                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", text.length(), fragment.offset, fragment.length, text.c_str());
+#endif
+                        session.tokenize(text, output);
+                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
+                        output.push_back(fragment.token);
+                    }
+                }
+
+                if (add_special) {
+                    GGML_ASSERT(special_sep_id != LLAMA_TOKEN_NULL);
+                    output.push_back(special_sep_id);
+                }
+            } break;
+        case LLAMA_VOCAB_TYPE_UGM:
+            {
+                if (add_special && add_bos) {
+                    GGML_ASSERT(special_bos_id != LLAMA_TOKEN_NULL);
+                    output.push_back(special_bos_id);
+                }
+                llm_tokenizer_ugm_session session(vocab, *static_cast<const llm_tokenizer_ugm *>(tokenizer.get()));
+
+                for (const auto & fragment : fragment_buffer) {
+                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
+                        std::string text = fragment.raw_text.substr(fragment.offset, fragment.length);
+#ifdef PRETOKENIZERDEBUG
+                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", text.length(), fragment.offset, fragment.length, text.c_str());
+#endif
+                        session.tokenize(text, output);
+                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
+                        output.push_back(fragment.token);
+                    }
+                }
+
+                if (add_special && add_bos && output.size() >= 2 && output[1] == special_bos_id) {
+                    LLAMA_LOG_WARN(
+                        "%s: Added a BOS token to the prompt as specified by the model but the prompt "
+                        "also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. "
+                        "Are you sure this is what you want?\n", __FUNCTION__);
+                }
+
+                if (add_special && add_eos) {
+                    GGML_ASSERT(special_eos_id != LLAMA_TOKEN_NULL);
+                    output.push_back(special_eos_id);
+                }
+            } break;
+        case LLAMA_VOCAB_TYPE_RWKV:
+            {
+                llm_tokenizer_rwkv_session session(vocab, *static_cast<const llm_tokenizer_rwkv *>(tokenizer.get()));
+                for (const auto & fragment : fragment_buffer) {
+                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
+                        std::string text = fragment.raw_text.substr(fragment.offset, fragment.length);
+
+#ifdef PRETOKENIZERDEBUG
+                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", text.length(), fragment.offset, fragment.length, text.c_str());
+#endif
+
+                        session.tokenize(text, output);
+                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
+                        output.push_back(fragment.token);
+                    }
+                }
+            } break;
+        case LLAMA_VOCAB_TYPE_PLAMO2:
+            {
+                llm_tokenizer_plamo2_session session(*static_cast<const llm_tokenizer_plamo2 *>(tokenizer.get()));
+                for (const auto & fragment : fragment_buffer) {
+                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
+                        std::string text = fragment.raw_text.substr(fragment.offset, fragment.length);
+
+#ifdef PRETOKENIZERDEBUG
+                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", text.length(), fragment.offset, fragment.length, text.c_str());
+#endif
+
+                        session.tokenize(text, output);
+                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
+                        output.push_back(fragment.token);
+                    }
+                }
+            } break;
+        case LLAMA_VOCAB_TYPE_NONE:
+            GGML_ABORT("fatal error");
+    }
+
+    return output;
+}
+
+int32_t llama_vocab::impl::token_to_piece(llama_token token, char * buf, int32_t length, int32_t lstrip, bool special) const {
+    // ref: https://github.com/ggerganov/llama.cpp/pull/7587#discussion_r1620983843
+    static const int attr_special = LLAMA_TOKEN_ATTR_UNKNOWN | LLAMA_TOKEN_ATTR_CONTROL;
+    const llama_token_attr attr = token_get_attr(token);
+    if (!special && (attr & attr_special)) {
+        return 0;
+    }
+
+    // copy piece chars to output text buffer
+    // skip up to 'lstrip' leading spaces before copying
+    auto _try_copy = [=] (const char * token, size_t size) -> int32_t {
+        if (size >= static_cast<size_t>(std::numeric_limits<int32_t>::max())) {
+            GGML_ABORT("invalid token size: %zu exceeds int32_t limit", size);
+        }
+
+        for (int32_t i = 0; i < lstrip && size && *token == ' '; ++i) {
+            token++;
+            size--;
+        }
+        if (length < (int32_t)size) {
+            return -(int32_t) size;
+        }
+        memcpy(buf, token, size);
+        return (int32_t) size;
+    };
+
+    // if we have a cache - use it
+    {
+        const auto & cache = cache_token_to_piece;
+
+        if (!cache.empty()) {
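+            // cache entries are pre-rendered with special tokens enabled (see cache_token_to_piece)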
+            const auto & result = cache.at(token);
+            return _try_copy(result.data(), result.size());
+        }
+    }
+
+    if (0 <= token && token < (int32_t) id_to_token.size()) {
+        const std::string & token_text = id_to_token[token].text;
+        switch (get_type()) {
+            case LLAMA_VOCAB_TYPE_WPM:
+            case LLAMA_VOCAB_TYPE_SPM:
+            case LLAMA_VOCAB_TYPE_UGM: {
+                // NOTE: we accept all unsupported token types,
+                // suppressing them like CONTROL tokens.
+                if (attr & (attr_special | LLAMA_TOKEN_ATTR_USER_DEFINED)) {
+                    return _try_copy(token_text.data(), token_text.size());
+                }
+                if (attr & LLAMA_TOKEN_ATTR_NORMAL) {
+                    std::string result = token_text;
+                    llama_unescape_whitespace(result);
+                    return _try_copy(result.data(), result.size());
+                }
+                if (attr & LLAMA_TOKEN_ATTR_BYTE) {
+                    char byte = (char) token_to_byte(token);
+                    return _try_copy((char*) &byte, 1);
+                }
+                break;
+            }
+            case LLAMA_VOCAB_TYPE_BPE: {
+                // NOTE: we accept all unsupported token types,
+                // suppressing them like CONTROL tokens.
+                if (attr & (attr_special | LLAMA_TOKEN_ATTR_USER_DEFINED)) {
+                    return _try_copy(token_text.data(), token_text.size());
+                }
+                if (attr & LLAMA_TOKEN_ATTR_NORMAL) {
+                    std::string result = llama_decode_text(token_text);
+                    return _try_copy(result.data(), result.size());
+                }
+                break;
+            }
+            case LLAMA_VOCAB_TYPE_RWKV: {
+                std::vector<uint8_t> result = llama_unescape_rwkv_token(token_text);
+
+                // If we don't have enough space, return an error
+                if (result.size() > (size_t)length) {
+                    return -(int)result.size();
+                }
+
+                memcpy(buf, result.data(), result.size());
+                return (int)result.size();
+            }
+            case LLAMA_VOCAB_TYPE_PLAMO2: {
+                // PLaMo-2 uses similar token handling to BPE/SPM
+                if (vocab.is_byte(token)) {
+                    // Handle byte tokens like <0xXX>
+                    if (token_text.length() == 6 && token_text.substr(0, 3) == "<0x" && token_text.back() == '>') {
+                        int hex_val = std::stoi(token_text.substr(3, 2), nullptr, 16);
+                        if (length < 1) {
+                            return -1;
+                        }
+                        buf[0] = static_cast<char>(hex_val);
+                        return 1;
+                    }
+                }
+
+                // Normal token - just copy the text
+                std::string result = token_text;
+                return _try_copy(result.data(), result.size());
+            }
+            default:
+                GGML_ABORT("fatal error");
+        }
+    }
+
+    return 0;
+}
+
+const std::string & llama_vocab::impl::token_to_piece(llama_token token) const {
+    return cache_token_to_piece.at(token);
+}
+
+int32_t llama_vocab::impl::detokenize(
+               const llama_token * tokens,
+                         int32_t   n_tokens,
+                            char * text,
+                         int32_t   text_len_max,
+                            bool   remove_special,
+                            bool   unparse_special) const {
+    if (type == LLAMA_VOCAB_TYPE_NONE) {
+        return 0;
+    }
+
+    GGML_ASSERT(tokenizer && "Tokenizer not initialized. Call llama_vocab::init_tokenizer() first.");
+
+    int32_t avail = text_len_max;
+    int32_t total = 0;
+
+    // remove the leading space
+    bool remove_space = add_space_prefix;
+
+    if (remove_special && add_bos) {
+        if (n_tokens > 0 && tokens[0] == special_bos_id) {
+            remove_space = false;
+            n_tokens--;
+            tokens++;
+        }
+    }
+
+    if (remove_special && add_eos) {
+        if (n_tokens > 0 && tokens[n_tokens - 1] == special_eos_id) {
+            n_tokens--;
+        }
+    }
+
+    for (int32_t i = 0; i < n_tokens; ++i) {
+        GGML_ASSERT(avail >= 0);
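+        // a negative n_chars is the size the piece would have needed; track it so the caller can resize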
+        int32_t n_chars = token_to_piece(tokens[i], text, avail, remove_space, unparse_special);
+        remove_space = false;
+        if (n_chars < 0) {
+            avail = 0;
+            total -= n_chars;
+        } else if (n_chars > 0) {
+            avail -= n_chars;
+            text  += n_chars;
+            total += n_chars;
+        }
+    }
+
+    if (total > text_len_max) {
+        return -total;
+    }
+
+    if (clean_spaces) {
+        text -= total;  // restart text
+
+        // first pass: characters ?!.,  //TODO: where do these characters come from?
+        const int32_t total1 = total;
+        total = total ? 1 : 0;
+        for (int32_t i = 1; i < total1; ++i) {
+            const char x = text[i];
+            if (text[i - 1] == ' ') {
+                if (x == '?' || x == '!' || x == '.' || x == ',') {  // " ?", " !", " .", " ,"
+                    total--;  // remove space
+                }
+            }
+            text[total++] = x;
+        }
+
+        // second pass: strip single apostrophe between spaces
+        const int32_t total2 = total;
+        total = total ? 1 : 0;
+        for (int32_t i = 1; i < total2; ++i) {
+            const char x = text[i];
+            if (x == '\'' && i + 1 < total2 && text[i - 1] == ' ' && text[i + 1] == ' ') {  // " ' "
+                total--;           // remove prev space
+                text[++i] = '\0';  // remove next space
+            }
+            text[total++] = x;
+        }
+
+        // third pass: apostrophe contractions  //NOTE: does this make sense?
+        const int32_t total3 = total;
+        total = total ? 1 : 0;
+        for (int32_t i = 1; i < total3; ++i) {
+            const char x = text[i];
+            if (text[i - 1] == ' ') {
+                if (x == '\'' && i + 1 < total3) {
+                    const char x1 = text[i + 1];
+                    if (x1 == 't' || x1 == 'd') {  // " 't", " 'd"
+                        //total--;  // remove space
+                    } else if (x1 == 's' || x1 == 'm') {  // " 's", " 'm"
+                        total--;  // remove space
+                    } else if (i + 2 < total3) {
+                        const char x2 = text[i + 2];
+                        if ((x1 == 'l' && x2 == 'l')) {  // " 'll"
+                            //total--;  // remove space
+                        } else if ((x1 == 'r' && x2 == 'e') || (x1 == 'v' && x2 == 'e')) {  // " 're", " 've"
+                            total--;  // remove space
+                        } else {
+                            //total--;  // remove space
+                        }
+                    } else {
+                        //total--;  // remove space
+                    }
+                }
+            }
+            text[total++] = x;
+        }
+    }
+
+    return total <= text_len_max ? total : -total;
+}
+
+void llama_vocab::impl::print_info() const {
+    LLAMA_LOG_INFO("%s: vocab type       = %s\n",     __func__, type_name().c_str());
+    LLAMA_LOG_INFO("%s: n_vocab          = %u\n",     __func__, vocab.n_tokens());
+    LLAMA_LOG_INFO("%s: n_merges         = %u\n",     __func__, (uint32_t) bpe_ranks.size());
+
+    // special tokens
+    if (special_bos_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: BOS token        = %d '%s'\n", __func__, special_bos_id,     id_to_token.at(special_bos_id).text.c_str() );  }
+    if (special_eos_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: EOS token        = %d '%s'\n", __func__, special_eos_id,     id_to_token.at(special_eos_id).text.c_str() );  }
+    if (special_eot_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: EOT token        = %d '%s'\n", __func__, special_eot_id,     id_to_token.at(special_eot_id).text.c_str() );  }
+    if (special_eom_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: EOM token        = %d '%s'\n", __func__, special_eom_id,     id_to_token.at(special_eom_id).text.c_str() );  }
+    if (special_unk_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: UNK token        = %d '%s'\n", __func__, special_unk_id,     id_to_token.at(special_unk_id).text.c_str() );  }
+    if (special_sep_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: SEP token        = %d '%s'\n", __func__, special_sep_id,     id_to_token.at(special_sep_id).text.c_str() );  }
+    if (special_pad_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: PAD token        = %d '%s'\n", __func__, special_pad_id,     id_to_token.at(special_pad_id).text.c_str() );  }
+    if (special_mask_id != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: MASK token       = %d '%s'\n", __func__, special_mask_id,    id_to_token.at(special_mask_id).text.c_str() ); }
+
+    if (linefeed_id != LLAMA_TOKEN_NULL)        { LLAMA_LOG_INFO( "%s: LF token         = %d '%s'\n", __func__, linefeed_id,        id_to_token.at(linefeed_id).text.c_str() ); }
+
+    if (special_fim_pre_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM PRE token    = %d '%s'\n", __func__, special_fim_pre_id, id_to_token.at(special_fim_pre_id).text.c_str() ); }
+    if (special_fim_suf_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM SUF token    = %d '%s'\n", __func__, special_fim_suf_id, id_to_token.at(special_fim_suf_id).text.c_str() ); }
+    if (special_fim_mid_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM MID token    = %d '%s'\n", __func__, special_fim_mid_id, id_to_token.at(special_fim_mid_id).text.c_str() ); }
+    if (special_fim_pad_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM PAD token    = %d '%s'\n", __func__, special_fim_pad_id, id_to_token.at(special_fim_pad_id).text.c_str() ); }
+    if (special_fim_rep_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM REP token    = %d '%s'\n", __func__, special_fim_rep_id, id_to_token.at(special_fim_rep_id).text.c_str() ); }
+    if (special_fim_sep_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM SEP token    = %d '%s'\n", __func__, special_fim_sep_id, id_to_token.at(special_fim_sep_id).text.c_str() ); }
+
+    for (const auto & id : special_eog_ids) {
+        LLAMA_LOG_INFO( "%s: EOG token        = %d '%s'\n", __func__, id, id_to_token.at(id).text.c_str() );
+    }
+
+    LLAMA_LOG_INFO("%s: max token length = %d\n", __func__, max_token_len);
+}
+
+llama_vocab::llama_vocab() : pimpl(new impl(*this)) {
+}
+
+llama_vocab::~llama_vocab() {
+}
+
+void llama_vocab::load(llama_model_loader & ml, const LLM_KV & kv) {
+    pimpl->load(ml, kv);
+}
+
+std::string llama_vocab::get_tokenizer_model() const {
+    return pimpl->tokenizer_model;
+}
+
+std::string llama_vocab::get_tokenizer_pre() const {
+    return pimpl->tokenizer_pre;
+}
+
+enum llama_vocab_type llama_vocab::get_type() const {
+    return pimpl->type;
+}
+
+enum llama_vocab_pre_type llama_vocab::get_pre_type() const {
+    return pimpl->pre_type;
+}
+
+uint32_t llama_vocab::n_tokens() const {
+    return (uint32_t) pimpl->id_to_token.size();
+}
+
+uint32_t llama_vocab::n_token_types() const {
+    return (uint32_t) pimpl->n_token_types;
+}
+
+std::string llama_vocab::type_name() const {
+    return pimpl->type_name();
+}
+
+bool llama_vocab::is_normal(llama_token id) const {
+    return pimpl->is_normal(id);
+}
+
+bool llama_vocab::is_unknown(llama_token id) const {
+    return pimpl->is_unknown(id);
+}
+
+bool llama_vocab::is_control(llama_token id) const {
+    return pimpl->is_control(id);
+}
+
+bool llama_vocab::is_byte(llama_token id) const {
+    return pimpl->is_byte(id);
+}
+
+bool llama_vocab::is_user_defined(llama_token id) const {
+    return pimpl->is_user_defined(id);
+}
+
+bool llama_vocab::is_unused(llama_token id) const {
+    return pimpl->is_unused(id);
+}
+
+bool llama_vocab::is_eog(llama_token id) const {
+    return pimpl->is_eog(id);
+}
+
+uint8_t llama_vocab::token_to_byte(llama_token id) const {
+    return pimpl->token_to_byte(id);
+}
+
+llama_token llama_vocab::byte_to_token(uint8_t ch) const {
+    GGML_ASSERT(get_type() != LLAMA_VOCAB_TYPE_NONE);
+    static const char * hex = "0123456789ABCDEF";
+    switch (get_type()) {
+        case LLAMA_VOCAB_TYPE_SPM:
+        case LLAMA_VOCAB_TYPE_UGM: {
+            const char buf[7] = { '<', '0', 'x', hex[ch >> 4], hex[ch & 15], '>', 0 };
+            auto token = pimpl->token_to_id.find(buf);
+            if (token != pimpl->token_to_id.end()) {
+                return (*token).second;
+            }
+            // Try to fall back to just the byte as a string
+            const char buf2[2] = { (char)ch, 0 };
+            return pimpl->token_to_id.at(buf2);
+        }
+        case LLAMA_VOCAB_TYPE_WPM:
+        case LLAMA_VOCAB_TYPE_BPE: {
+            return pimpl->token_to_id.at(unicode_byte_to_utf8(ch));
+        }
+        case LLAMA_VOCAB_TYPE_PLAMO2: {
+            // PLaMo-2 uses byte tokens in the format <0xXX>
+            char hex_str[8];
+            snprintf(hex_str, sizeof(hex_str), "<0x%02X>", ch);
+            return pimpl->token_to_id.at(hex_str);
+        }
+        default:
+            GGML_ABORT("fatal error");
+    }
+}
+
+llama_token llama_vocab::text_to_token(const std::string & text) const {
+    GGML_ASSERT(pimpl->type != LLAMA_VOCAB_TYPE_NONE);
+    auto it = pimpl->token_to_id.find(text);
+    if (it != pimpl->token_to_id.end()) {
+        return (*it).second;
+    }
+    return LLAMA_TOKEN_NULL;
+}
+
+const llama_vocab::token_data & llama_vocab::get_token_data(llama_token id) const {
+    GGML_ASSERT(pimpl->type != LLAMA_VOCAB_TYPE_NONE);
+    return pimpl->id_to_token.at(id);
+}
+
+const char * llama_vocab::token_get_text(llama_token id) const {
+    GGML_ASSERT(pimpl->type != LLAMA_VOCAB_TYPE_NONE);
+    return pimpl->id_to_token.at(id).text.c_str();
+}
+
+float llama_vocab::token_get_score(llama_token id) const {
+    GGML_ASSERT(pimpl->type != LLAMA_VOCAB_TYPE_NONE);
+    return pimpl->id_to_token.at(id).score;
+}
+
+llama_token_attr llama_vocab::token_get_attr(llama_token id) const {
+    return pimpl->token_get_attr(id);
+}
+
+llama_token llama_vocab::token_bos() const {
+    return pimpl->special_bos_id;
+}
+
+llama_token llama_vocab::token_eos() const {
+    return pimpl->special_eos_id;
+}
+
+llama_token llama_vocab::token_eot() const {
+    return pimpl->special_eot_id;
+}
+
+llama_token llama_vocab::token_eom() const {
+    return pimpl->special_eom_id;
+}
+
+llama_token llama_vocab::token_unk() const {
+    return pimpl->special_unk_id;
+}
+
+llama_token llama_vocab::token_sep() const {
+    return pimpl->special_sep_id;
+}
+
+llama_token llama_vocab::token_nl() const {
+    return pimpl->linefeed_id;
+}
+
+llama_token llama_vocab::token_pad() const {
+    return pimpl->special_pad_id;
+}
+
+llama_token llama_vocab::token_prefix() const {
+    return pimpl->special_fim_pre_id;
+}
+
+llama_token llama_vocab::token_middle() const {
+    return pimpl->special_fim_mid_id;
+}
+
+llama_token llama_vocab::token_suffix() const {
+    return pimpl->special_fim_suf_id;
+}
+
+llama_token llama_vocab::token_fim_pre() const {
+    return pimpl->special_fim_pre_id;
+}
+
+llama_token llama_vocab::token_fim_suf() const {
+    return pimpl->special_fim_suf_id;
+}
+
+llama_token llama_vocab::token_fim_mid() const {
+    return pimpl->special_fim_mid_id;
+}
+
+llama_token llama_vocab::token_fim_pad() const {
+    return pimpl->special_fim_pad_id;
+}
+
+llama_token llama_vocab::token_fim_rep() const {
+    return pimpl->special_fim_rep_id;
+}
+
+llama_token llama_vocab::token_fim_sep() const {
+    return pimpl->special_fim_sep_id;
+}
+
+llama_token llama_vocab::token_mask() const {
+    return pimpl->special_mask_id;
+}
+
+bool llama_vocab::get_add_space_prefix() const {
+    return pimpl->add_space_prefix;
+}
+
+bool llama_vocab::get_add_bos() const {
+    return pimpl->add_bos;
+}
+
+bool llama_vocab::get_add_eos() const {
+    return pimpl->add_eos;
+}
+
+bool llama_vocab::get_add_sep() const {
+    return pimpl->add_sep;
+}
+
+bool llama_vocab::get_ignore_merges() const {
+    return pimpl->ignore_merges;
+}
+
+bool llama_vocab::get_clean_spaces() const {
+    return pimpl->clean_spaces;
+}
+
+bool llama_vocab::get_remove_extra_whitespaces() const {
+    return pimpl->remove_extra_whitespaces;
+}
+
+bool llama_vocab::get_escape_whitespaces() const {
+    return pimpl->escape_whitespaces;
+}
+
+bool llama_vocab::get_treat_whitespace_as_suffix() const {
+    return pimpl->treat_whitespace_as_suffix;
+}
+
+int llama_vocab::max_token_len() const {
+    return pimpl->max_token_len;
+}
+
+int llama_vocab::find_bpe_rank(const std::string & token_left, const std::string & token_right) const {
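+    // merge pairs are serialized as "left right" (see get_bpe_merges), so raw spaces/newlines would be ambiguous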
+    GGML_ASSERT(token_left.find(' ')   == std::string::npos);
+    GGML_ASSERT(token_left.find('\n')  == std::string::npos);
+    GGML_ASSERT(token_right.find(' ')  == std::string::npos);
+    GGML_ASSERT(token_right.find('\n') == std::string::npos);
+
+    auto it = pimpl->bpe_ranks.find(std::make_pair(token_left, token_right));
+    if (it == pimpl->bpe_ranks.end()) {
+        return -1;
+    }
+
+    return it->second;
+}
+
+std::vector<std::string> llama_vocab::get_bpe_merges() const {
+    std::vector<std::string> result(pimpl->bpe_ranks.size());
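+    // invert the (left, right) -> rank map into a rank-ordered list of "left right" strings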
+
+    for (const auto & pair : pimpl->bpe_ranks) {
+        result[pair.second] = pair.first.first + " " + pair.first.second;
+    }
+
+    return result;
+}
+
+std::vector<char> llama_vocab::get_precompiled_charsmap() const {
+    return pimpl->precompiled_charsmap;
+}
+
+int32_t llama_vocab::tokenize(
+                  const char * text,
+                     int32_t   text_len,
+                 llama_token * tokens,
+                     int32_t   n_tokens_max,
+                        bool   add_special,
+                        bool   parse_special) const {
+    auto res = tokenize(std::string(text, text_len), add_special, parse_special);
+    if (res.size() >= static_cast<size_t>(std::numeric_limits<int32_t>::max())) {
+        LLAMA_LOG_ERROR("%s: tokenization result size %zu exceeds int32_t limit\n", __func__, res.size());
+        return std::numeric_limits<int32_t>::min();
+    }
+
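+    // not enough room in the caller's buffer: report the required capacity as a negative count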
+    if (n_tokens_max < (int) res.size()) {
+        // LLAMA_LOG_ERROR("%s: too many tokens\n", __func__);
+        return -((int) res.size());
+    }
 
-                if (add_special && vocab.tokenizer_add_eos == 1) {
-                    GGML_ASSERT(vocab.special_eos_id != -1);
-                    output.push_back(vocab.special_eos_id);
-                }
-            } break;
-        case LLAMA_VOCAB_TYPE_NONE:
-            GGML_ABORT("fatal error");
+    for (size_t i = 0; i < res.size(); i++) {
+        tokens[i] = res[i];
     }
 
-    return output;
+    return res.size();
 }
 
-llama_token llama_byte_to_token_impl(const llama_vocab & vocab, uint8_t ch) {
-    GGML_ASSERT(llama_vocab_get_type(vocab) != LLAMA_VOCAB_TYPE_NONE);
-    static const char * hex = "0123456789ABCDEF";
-    switch (llama_vocab_get_type(vocab)) {
-        case LLAMA_VOCAB_TYPE_SPM:
-        case LLAMA_VOCAB_TYPE_UGM: {
-            const char buf[7] = { '<', '0', 'x', hex[ch >> 4], hex[ch & 15], '>', 0 };
-            auto token = vocab.token_to_id.find(buf);
-            if (token != vocab.token_to_id.end()) {
-                return (*token).second;
-            }
-            // Try to fall back to just the byte as a string
-            const char buf2[2] = { (char)ch, 0 };
-            return vocab.token_to_id.at(buf2);
-        }
-        case LLAMA_VOCAB_TYPE_WPM:
-        case LLAMA_VOCAB_TYPE_BPE: {
-            return vocab.token_to_id.at(unicode_byte_to_utf8(ch));
-        }
-        default:
-            GGML_ABORT("fatal error");
-    }
+std::vector<llama_token> llama_vocab::tokenize(
+        const std::string & raw_text,
+        bool add_special,
+        bool parse_special) const {
+    return pimpl->tokenize(raw_text, add_special, parse_special);
 }
 
-const char * llama_token_get_text_impl(const struct llama_vocab & vocab, llama_token token) {
-    GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
-    return vocab.id_to_token[token].text.c_str();
+const std::string & llama_vocab::token_to_piece(llama_token token) const {
+    return pimpl->token_to_piece(token);
 }
 
-float llama_token_get_score_impl(const struct llama_vocab & vocab, llama_token token) {
-    GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
-    return vocab.id_to_token[token].score;
+int32_t llama_vocab::token_to_piece(llama_token token, char * buf, int32_t length, int32_t lstrip, bool special) const {
+    return pimpl->token_to_piece(token, buf, length, lstrip, special);
 }
 
-llama_token_attr llama_token_get_attr_impl(const struct llama_vocab & vocab, llama_token token) {
-    GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
-    return vocab.id_to_token[token].attr;
+int32_t llama_vocab::detokenize(
+               const llama_token * tokens,
+                         int32_t   n_tokens,
+                            char * text,
+                         int32_t   text_len_max,
+                            bool   remove_special,
+                            bool   unparse_special) const {
+    return pimpl->detokenize(tokens, n_tokens, text, text_len_max, remove_special, unparse_special);
 }
 
-bool llama_token_is_eog_impl(const struct llama_vocab & vocab, llama_token token) {
-    return token != -1 && (
-        token == llama_token_eos_impl(vocab) ||
-        token == llama_token_eot_impl(vocab) ||
-        token == llama_token_eom_impl(vocab)
-    );
+std::string llama_vocab::detokenize(const std::vector<llama_token> & tokens, bool special) const {
+    std::string text;
+    text.resize(std::max(text.capacity(), tokens.size()));
+    int32_t n_chars = detokenize(tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
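+    // a negative char count means the buffer was too small; resize to the required size and retry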
+    if (n_chars < 0) {
+        text.resize(-n_chars);
+        n_chars = detokenize(tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
+        GGML_ASSERT(n_chars <= (int32_t)text.size());  // whitespace trimming is performed after per-token detokenization
+    }
+
+    text.resize(n_chars);
+
+    // NOTE: the original tokenizer decodes bytes after collecting the pieces.
+    return text;
 }
 
-bool llama_token_is_control_impl(const struct llama_vocab & vocab, llama_token token) {
-    return llama_is_control_token(vocab, token);
+void llama_vocab::print_info() const {
+    pimpl->print_info();
 }
 
-llama_token llama_token_bos_impl(const struct llama_vocab & vocab) {
-    return vocab.special_bos_id;
+//
+// interface implementation
+//
+
+int32_t llama_vocab_n_tokens(const struct llama_vocab * vocab) {
+    return vocab->n_tokens();
 }
 
-llama_token llama_token_eos_impl(const struct llama_vocab & vocab) {
-    return vocab.special_eos_id;
+// deprecated
+int32_t llama_n_vocab(const struct llama_vocab * vocab) {
+    return llama_vocab_n_tokens(vocab);
 }
 
-llama_token llama_token_cls_impl(const struct llama_vocab & vocab) {
-    return vocab.special_cls_id;
+enum llama_vocab_type llama_vocab_type(const struct llama_vocab * vocab) {
+    return vocab->get_type();
 }
 
-llama_token llama_token_sep_impl(const struct llama_vocab & vocab) {
-    return vocab.special_sep_id;
+const char * llama_vocab_get_text(const struct llama_vocab * vocab, llama_token token) {
+    return vocab->token_get_text(token);
 }
 
-llama_token llama_token_nl_impl(const struct llama_vocab & vocab) {
-    return vocab.linefeed_id;
+float llama_vocab_get_score(const struct llama_vocab * vocab, llama_token token) {
+    return vocab->token_get_score(token);
 }
 
-llama_token llama_token_pad_impl(const struct llama_vocab & vocab) {
-    return vocab.special_pad_id;
+enum llama_token_attr llama_vocab_get_attr(const struct llama_vocab * vocab, llama_token token) {
+    return vocab->token_get_attr(token);
 }
 
-int32_t llama_add_bos_token_impl(const struct llama_vocab & vocab) {
-    return vocab.tokenizer_add_bos;
+bool llama_vocab_is_eog(const struct llama_vocab * vocab, llama_token token) {
+    return vocab->is_eog(token);
 }
 
-int32_t llama_add_eos_token_impl(const struct llama_vocab & vocab) {
-    return vocab.tokenizer_add_eos;
+bool llama_vocab_is_control(const struct llama_vocab * vocab, llama_token token) {
+    return vocab->is_control(token);
 }
 
-llama_token llama_token_prefix_impl(const struct llama_vocab & vocab) {
-    return vocab.special_prefix_id;
+llama_token llama_vocab_bos(const struct llama_vocab * vocab) {
+    return vocab->token_bos();
 }
 
-llama_token llama_token_middle_impl(const struct llama_vocab & vocab) {
-    return vocab.special_middle_id;
+llama_token llama_vocab_eos(const struct llama_vocab * vocab) {
+    return vocab->token_eos();
 }
 
-llama_token llama_token_suffix_impl(const struct llama_vocab & vocab) {
-    return vocab.special_suffix_id;
+llama_token llama_vocab_eot(const struct llama_vocab * vocab) {
+    return vocab->token_eot();
 }
 
-llama_token llama_token_fim_pre_impl(const struct llama_vocab & vocab) {
-    return vocab.special_fim_pre_id;
+// deprecated
+llama_token llama_vocab_cls(const struct llama_vocab * vocab) {
+    return vocab->token_bos();
 }
 
-llama_token llama_token_fim_suf_impl(const struct llama_vocab & vocab) {
-    return vocab.special_fim_suf_id;
+llama_token llama_vocab_sep(const struct llama_vocab * vocab) {
+    return vocab->token_sep();
 }
 
-llama_token llama_token_fim_mid_impl(const struct llama_vocab & vocab) {
-    return vocab.special_fim_mid_id;
+llama_token llama_vocab_nl (const struct llama_vocab * vocab) {
+    return vocab->token_nl();
 }
 
-llama_token llama_token_fim_pad_impl(const struct llama_vocab & vocab) {
-    return vocab.special_fim_pad_id;
+llama_token llama_vocab_pad(const struct llama_vocab * vocab) {
+    return vocab->token_pad();
 }
 
-llama_token llama_token_fim_rep_impl(const struct llama_vocab & vocab) {
-    return vocab.special_fim_rep_id;
+bool llama_vocab_get_add_bos(const struct llama_vocab * vocab) {
+    return vocab->get_add_bos();
 }
 
-llama_token llama_token_fim_sep_impl(const struct llama_vocab & vocab) {
-    return vocab.special_fim_sep_id;
+bool llama_vocab_get_add_eos(const struct llama_vocab * vocab) {
+    return vocab->get_add_eos();
 }
 
-llama_token llama_token_eot_impl(const struct llama_vocab & vocab) {
-    return vocab.special_eot_id;
+bool llama_vocab_get_add_sep(const struct llama_vocab * vocab) {
+    return vocab->get_add_sep();
 }
 
-llama_token llama_token_eom_impl(const struct llama_vocab & vocab) {
-    return vocab.special_eom_id;
+llama_token llama_vocab_fim_pre(const struct llama_vocab * vocab) {
+    return vocab->token_fim_pre();
 }
 
-int32_t llama_tokenize_impl(
-    const struct llama_vocab & vocab,
-                  const char * text,
-                     int32_t   text_len,
-                 llama_token * tokens,
-                     int32_t   n_tokens_max,
-                        bool   add_special,
-                        bool   parse_special) {
-    auto res = llama_tokenize_internal(vocab, std::string(text, text_len), add_special, parse_special);
-    if (n_tokens_max < (int) res.size()) {
-        // LLAMA_LOG_ERROR("%s: too many tokens\n", __func__);
-        return -((int) res.size());
-    }
+llama_token llama_vocab_fim_suf(const struct llama_vocab * vocab) {
+    return vocab->token_fim_suf();
+}
 
-    for (size_t i = 0; i < res.size(); i++) {
-        tokens[i] = res[i];
-    }
+llama_token llama_vocab_fim_mid(const struct llama_vocab * vocab) {
+    return vocab->token_fim_mid();
+}
 
-    return res.size();
+llama_token llama_vocab_fim_pad(const struct llama_vocab * vocab) {
+    return vocab->token_fim_pad();
 }
 
-static std::string llama_decode_text(const std::string & text) {
-    std::string decoded_text;
+llama_token llama_vocab_fim_rep(const struct llama_vocab * vocab) {
+    return vocab->token_fim_rep();
+}
 
-    const auto cpts = unicode_cpts_from_utf8(text);
-    for (const auto cpt : cpts) {
-        const auto utf8 = unicode_cpt_to_utf8(cpt);
-        try {
-            decoded_text += unicode_utf8_to_byte(utf8);
-        } catch (const std::out_of_range & /*e*/) {
-            decoded_text += "[UNK_BYTE_0x";
-            for (const auto c : utf8) {
-                decoded_text += format("%02x", (uint8_t) c);
-            }
-            decoded_text += text + "]";
-        }
-    }
+llama_token llama_vocab_fim_sep(const struct llama_vocab * vocab) {
+    return vocab->token_fim_sep();
+}
 
-    return decoded_text;
+llama_token llama_vocab_mask(const struct llama_vocab * vocab) {
+    return vocab->token_mask();
 }
 
-// does not write null-terminator to buf
-int32_t llama_token_to_piece_impl(const struct llama_vocab & vocab, llama_token token, char * buf, int32_t length, int32_t lstrip, bool special) {
-    // ref: https://github.com/ggerganov/llama.cpp/pull/7587#discussion_r1620983843
-    static const int attr_special = LLAMA_TOKEN_ATTR_UNKNOWN | LLAMA_TOKEN_ATTR_CONTROL;
-    const llama_token_attr attr = llama_token_get_attr_impl(vocab, token);
-    if (!special && (attr & attr_special)) {
-        return 0;
-    }
+// deprecated
+const char * llama_token_get_text(const struct llama_vocab * vocab, llama_token token) {
+    return llama_vocab_get_text(vocab, token);
+}
 
-    // copy piece chars to output text buffer
-    // skip up to 'lstrip' leading spaces before copying
-    auto _try_copy = [=] (const char * token, size_t size) -> int32_t {
-        for (int32_t i = 0; i < lstrip && size && *token == ' '; ++i) {
-            token++;
-            size--;
-        }
-        if (length < (int32_t)size) {
-            return -(int32_t) size;
-        }
-        memcpy(buf, token, size);
-        return (int32_t) size;
-    };
+// deprecated
+float llama_token_get_score(const struct llama_vocab * vocab, llama_token token) {
+    return llama_vocab_get_score(vocab, token);
+}
 
-    // if we have a cache - use it
-    {
-        const auto & cache = vocab.cache_token_to_piece;
+// deprecated
+enum llama_token_attr llama_token_get_attr(const struct llama_vocab * vocab, llama_token token) {
+    return llama_vocab_get_attr(vocab, token);
+}
 
-        if (!cache.empty()) {
-            const auto & result = cache.at(token);
-            return _try_copy(result.data(), result.size());
-        }
-    }
+// deprecated
+bool llama_token_is_eog(const struct llama_vocab * vocab, llama_token token) {
+    return llama_vocab_is_eog(vocab, token);
+}
 
-    if (0 <= token && token < (int32_t) vocab.id_to_token.size()) {
-        const std::string & token_text = vocab.id_to_token[token].text;
-        switch (llama_vocab_get_type(vocab)) {
-            case LLAMA_VOCAB_TYPE_WPM:
-            case LLAMA_VOCAB_TYPE_SPM:
-            case LLAMA_VOCAB_TYPE_UGM: {
-                // NOTE: we accept all unsupported token types,
-                // suppressing them like CONTROL tokens.
-                if (attr & (attr_special | LLAMA_TOKEN_ATTR_USER_DEFINED)) {
-                    return _try_copy(token_text.data(), token_text.size());
-                } else if (attr & LLAMA_TOKEN_ATTR_NORMAL) {
-                    std::string result = token_text;
-                    llama_unescape_whitespace(result);
-                    return _try_copy(result.data(), result.size());
-                } else if (attr & LLAMA_TOKEN_ATTR_BYTE) {
-                    char byte = (char) llama_token_to_byte(vocab, token);
-                    return _try_copy((char*) &byte, 1);
-                }
-                break;
-            }
-            case LLAMA_VOCAB_TYPE_BPE: {
-                // NOTE: we accept all unsupported token types,
-                // suppressing them like CONTROL tokens.
-                if (attr & (attr_special | LLAMA_TOKEN_ATTR_USER_DEFINED)) {
-                    return _try_copy(token_text.data(), token_text.size());
-                } else if (attr & LLAMA_TOKEN_ATTR_NORMAL) {
-                    std::string result = llama_decode_text(token_text);
-                    return _try_copy(result.data(), result.size());
-                }
-                break;
-            }
-            default:
-                GGML_ABORT("fatal error");
-        }
-    }
+// deprecated
+bool llama_token_is_control(const struct llama_vocab * vocab, llama_token token) {
+    return llama_vocab_is_control(vocab, token);
+}
 
-    return 0;
+// deprecated
+llama_token llama_token_bos(const struct llama_vocab * vocab) {
+    return llama_vocab_bos(vocab);
 }
 
-int32_t llama_detokenize_impl(
-        const struct llama_vocab & vocab,
-               const llama_token * tokens,
-                         int32_t   n_tokens,
-                            char * text,
-                         int32_t   text_len_max,
-                            bool   remove_special,
-                            bool   unparse_special) {
-    int32_t avail = text_len_max;
-    int32_t total = 0;
+// deprecated
+llama_token llama_token_eos(const struct llama_vocab * vocab) {
+    return llama_vocab_eos(vocab);
+}
 
-    // remove the leading space
-    bool remove_space = vocab.tokenizer_add_space_prefix;
+// deprecated
+llama_token llama_token_eot(const struct llama_vocab * vocab) {
+    return llama_vocab_eot(vocab);
+}
 
-    if (remove_special && vocab.tokenizer_add_bos) {
-        if (n_tokens > 0 && tokens[0] == vocab.special_bos_id) {
-            remove_space = false;
-            n_tokens--;
-            tokens++;
-        }
-    }
+// deprecated
+llama_token llama_token_cls(const struct llama_vocab * vocab) {
+    //return llama_vocab_cls(vocab);
+    return llama_vocab_bos(vocab); // avoid deprecation warning
+}
 
-    if (remove_special && vocab.tokenizer_add_eos) {
-        if (n_tokens > 0 && tokens[n_tokens-1] == vocab.special_eos_id) {
-            n_tokens--;
-        }
-    }
+// deprecated
+llama_token llama_token_sep(const struct llama_vocab * vocab) {
+    return llama_vocab_sep(vocab);
+}
 
-    for (int32_t i = 0; i < n_tokens; ++i) {
-        GGML_ASSERT(avail >= 0);
-        int32_t n_chars = llama_token_to_piece_impl(vocab, tokens[i], text, avail, remove_space, unparse_special);
-        remove_space = false;
-        if (n_chars < 0) {
-            avail = 0;
-            total -= n_chars;
-        } else if (n_chars > 0) {
-            avail -= n_chars;
-            text  += n_chars;
-            total += n_chars;
-        }
-    }
+// deprecated
+llama_token llama_token_nl (const struct llama_vocab * vocab) {
+    return llama_vocab_nl(vocab);
+}
 
-    if (total > text_len_max) {
-        return -total;
-    }
+// deprecated
+llama_token llama_token_pad(const struct llama_vocab * vocab) {
+    return llama_vocab_pad(vocab);
+}
 
-    if (vocab.tokenizer_clean_spaces) {
-        text -= total;  // restart text
+// deprecated
+bool llama_add_bos_token(const struct llama_vocab * vocab) {
+    return llama_vocab_get_add_bos(vocab);
+}
 
-        // first pass: characters ?!.,  //TODO: where do these characters come from?
-        const int32_t total1 = total;
-        total = total ? 1 : 0;
-        for (int32_t i = 1; i < total1; ++i) {
-            const char x = text[i];
-            if (text[i - 1] == ' ') {
-                if (x == '?' || x == '!' || x == '.' || x == ',') {  // " ?", " !", " .", " ,"
-                    total--;  // remove space
-                }
-            }
-            text[total++] = x;
-        }
+// deprecated
+bool llama_add_eos_token(const struct llama_vocab * vocab) {
+    return llama_vocab_get_add_eos(vocab);
+}
 
-        // second pass: strip single apostrophe between spaces
-        const int32_t total2 = total;
-        total = total ? 1 : 0;
-        for (int32_t i = 1; i < total2; ++i) {
-            const char x = text[i];
-            if (x == '\'' && i + 1 < total2 && text[i - 1] == ' ' && text[i + 1] == ' ') {  // " ' "
-                total--;           // remove prev space
-                text[++i] = '\0';  // remove next space
-            }
-            text[total++] = x;
-        }
+// deprecated
+llama_token llama_token_fim_pre(const struct llama_vocab * vocab) {
+    return llama_vocab_fim_pre(vocab);
+}
 
-        // third pass: apostrophe contractions  //NOTE: this makes sense?
-        const int32_t total3 = total;
-        total = total ? 1 : 0;
-        for (int32_t i = 1; i < total3; ++i) {
-            const char x = text[i];
-            if (text[i - 1] == ' ') {
-                if (x == '\'' && i + 1 < total3) {
-                    const char x1 = text[i + 1];
-                    if (x1 == 't' || x1 == 'd') {  // " 't", " 'd"
-                        //total--;  // remove space
-                    } else if (x1 == 's' || x1 == 'm') {  // " 's", " 'm"
-                        total--;  // remove space
-                    } else if (i + 2 < total3) {
-                        const char x2 = text[i + 2];
-                        if ((x1 == 'l' && x2 == 'l')) {  // " 'll"
-                            //total--;  // remove space
-                        } else if ((x1 == 'r' && x2 == 'e') || (x1 == 'v' && x2 == 'e')) {  // " 're", " 've"
-                            total--;  // remove space
-                        } else {
-                            //total--;  // remove space
-                        }
-                    } else {
-                        //total--;  // remove space
-                    }
-                }
-            }
-            text[total++] = x;
-        }
-    }
+// deprecated
+llama_token llama_token_fim_suf(const struct llama_vocab * vocab) {
+    return llama_vocab_fim_suf(vocab);
+}
 
-    return total <= text_len_max ? total : -total;
+// deprecated
+llama_token llama_token_fim_mid(const struct llama_vocab * vocab) {
+    return llama_vocab_fim_mid(vocab);
 }
 
-std::string llama_detokenize(const struct llama_vocab& vocab, const std::vector<llama_token>& tokens, bool special) {
-    std::string text;
-    text.resize(std::max(text.capacity(), tokens.size()));
-    int32_t n_chars = llama_detokenize_impl(vocab, tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
-    if (n_chars < 0) {
-        text.resize(-n_chars);
-        n_chars = llama_detokenize_impl(vocab, tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
-        GGML_ASSERT(n_chars <= (int32_t)text.size());  // whitespace trimming is performed after per-token detokenization
-    }
+// deprecated
+llama_token llama_token_fim_pad(const struct llama_vocab * vocab) {
+    return llama_vocab_fim_pad(vocab);
+}
 
-    text.resize(n_chars);
+// deprecated
+llama_token llama_token_fim_rep(const struct llama_vocab * vocab) {
+    return llama_vocab_fim_rep(vocab);
+}
 
-    // NOTE: the original tokenizer decodes bytes after collecting the pieces.
-    return text;
+// deprecated
+llama_token llama_token_fim_sep(const struct llama_vocab * vocab) {
+    return llama_vocab_fim_sep(vocab);
+}
+
+//
+// tokenization
+//
+
+int32_t llama_tokenize(
+    const struct llama_vocab * vocab,
+                  const char * text,
+                     int32_t   text_len,
+                 llama_token * tokens,
+                     int32_t   n_tokens_max,
+                        bool   add_special,
+                        bool   parse_special) {
+    return vocab->tokenize(text, text_len, tokens, n_tokens_max, add_special, parse_special);
+}
+
+int32_t llama_token_to_piece(
+    const struct llama_vocab * vocab,
+                 llama_token   token,
+                        char * buf,
+                     int32_t   length,
+                     int32_t   lstrip,
+                        bool   special) {
+    return vocab->token_to_piece(token, buf, length, lstrip, special);
+}
+
+int32_t llama_detokenize(
+    const struct llama_vocab * vocab,
+           const llama_token * tokens,
+                     int32_t   n_tokens,
+                        char * text,
+                     int32_t   text_len_max,
+                        bool   remove_special,
+                        bool   unparse_special) {
+    return vocab->detokenize(tokens, n_tokens, text, text_len_max, remove_special, unparse_special);
 }
diff --git a/src/llama-vocab.h b/src/llama-vocab.h
index 64ff7cc08..6a37b0c18 100644
--- a/src/llama-vocab.h
+++ b/src/llama-vocab.h
@@ -1,155 +1,178 @@
 #pragma once
 
-#include "llama-impl.h"
+#include "llama.h"
 
 #include <string>
 #include <vector>
-#include <unordered_map>
-#include <map>
+#include <memory>
+
+// pre-tokenization types
+enum llama_vocab_pre_type {
+    LLAMA_VOCAB_PRE_TYPE_DEFAULT        = 0,
+    LLAMA_VOCAB_PRE_TYPE_LLAMA3         = 1,
+    LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM   = 2,
+    LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER = 3,
+    LLAMA_VOCAB_PRE_TYPE_FALCON         = 4,
+    LLAMA_VOCAB_PRE_TYPE_MPT            = 5,
+    LLAMA_VOCAB_PRE_TYPE_STARCODER      = 6,
+    LLAMA_VOCAB_PRE_TYPE_GPT2           = 7,
+    LLAMA_VOCAB_PRE_TYPE_REFACT         = 8,
+    LLAMA_VOCAB_PRE_TYPE_COMMAND_R      = 9,
+    LLAMA_VOCAB_PRE_TYPE_STABLELM2      = 10,
+    LLAMA_VOCAB_PRE_TYPE_QWEN2          = 11,
+    LLAMA_VOCAB_PRE_TYPE_OLMO           = 12,
+    LLAMA_VOCAB_PRE_TYPE_DBRX           = 13,
+    LLAMA_VOCAB_PRE_TYPE_SMAUG          = 14,
+    LLAMA_VOCAB_PRE_TYPE_PORO           = 15,
+    LLAMA_VOCAB_PRE_TYPE_CHATGLM3       = 16,
+    LLAMA_VOCAB_PRE_TYPE_CHATGLM4       = 17,
+    LLAMA_VOCAB_PRE_TYPE_VIKING         = 18,
+    LLAMA_VOCAB_PRE_TYPE_JAIS           = 19,
+    LLAMA_VOCAB_PRE_TYPE_TEKKEN         = 20,
+    LLAMA_VOCAB_PRE_TYPE_SMOLLM         = 21,
+    LLAMA_VOCAB_PRE_TYPE_CODESHELL      = 22,
+    LLAMA_VOCAB_PRE_TYPE_BLOOM          = 23,
+    LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH   = 24,
+    LLAMA_VOCAB_PRE_TYPE_EXAONE         = 25,
+    LLAMA_VOCAB_PRE_TYPE_CHAMELEON      = 26,
+    LLAMA_VOCAB_PRE_TYPE_MINERVA        = 27,
+    LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM  = 28,
+    LLAMA_VOCAB_PRE_TYPE_GPT4O          = 29,
+    LLAMA_VOCAB_PRE_TYPE_SUPERBPE       = 30,
+    LLAMA_VOCAB_PRE_TYPE_TRILLION       = 31,
+    LLAMA_VOCAB_PRE_TYPE_BAILINGMOE     = 32,
+    LLAMA_VOCAB_PRE_TYPE_LLAMA4         = 33,
+    LLAMA_VOCAB_PRE_TYPE_PIXTRAL        = 34,
+    LLAMA_VOCAB_PRE_TYPE_SEED_CODER     = 35,
+    LLAMA_VOCAB_PRE_TYPE_HUNYUAN        = 36,
+    LLAMA_VOCAB_PRE_TYPE_KIMI_K2        = 37,
+    LLAMA_VOCAB_PRE_TYPE_HUNYUAN_DENSE  = 38,
+};
 
-struct llama_vocab {
-    using id    = llama_token;
-    using token = std::string;
-    using tattr = llama_token_attr;
+struct LLM_KV;
+struct llama_model_loader;
 
+struct llama_vocab {
     struct token_data {
-        token text;
-        float score;
-        tattr attr;
+        std::string      text;
+        float            score;
+        llama_token_attr attr;
     };
 
-    enum llama_vocab_type     type     = LLAMA_VOCAB_TYPE_SPM;
-    enum llama_vocab_pre_type type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
+    llama_vocab();
+    ~llama_vocab();
 
-    int max_token_len = 0; // used for optimizing longest token search
+    void load(llama_model_loader & ml, const LLM_KV & kv);
 
-    uint32_t n_tokens() const;
+    std::string get_tokenizer_model() const;
+    std::string get_tokenizer_pre() const;
 
-    std::unordered_map<token, id> token_to_id;
-    std::vector<token_data>       id_to_token;
-
-    std::vector<id>    cache_special_tokens;
-    std::vector<token> cache_token_to_piece; // llama_token_to_piece(special = true);
-
-    std::map<std::pair<std::string, std::string>, int> bpe_ranks;
-
-    // default LLaMA special tokens
-    id special_bos_id  = 1;
-    id special_eos_id  = 2;
-    id special_unk_id  = 0;
-    id special_sep_id  = -1;
-    id special_pad_id  = -1;
-    id special_cls_id  = -1;
-    id special_mask_id = -1;
-
-    id linefeed_id       = 13;
-
-    // fim tokens
-    llama_token special_fim_pre_id = -1;
-    llama_token special_fim_suf_id = -1;
-    llama_token special_fim_mid_id = -1;
-    llama_token special_fim_pad_id = -1;
-    llama_token special_fim_rep_id = -1; // repo
-    llama_token special_fim_sep_id = -1; // file separator
-
-    id special_prefix_id = -1;
-    id special_suffix_id = -1;
-    id special_middle_id = -1;
-    id special_eot_id    = -1; // TODO: move above after "eos_id", and here add "file separator" token
-    id special_eom_id    = -1;
-
-    // tokenizer flags
-    bool tokenizer_add_space_prefix = false;
-    bool tokenizer_add_bos          = false;
-    bool tokenizer_add_eos          = false;
-    bool tokenizer_ignore_merges    = false;
-    bool tokenizer_clean_spaces     = false;  // clean_up_tokenization_spaces
-    bool tokenizer_remove_extra_whitespaces   = false;
-    bool tokenizer_escape_whitespaces         = true;
-    bool tokenizer_treat_whitespace_as_suffix = false;
-
-    std::vector<char> precompiled_charsmap;
+    enum llama_vocab_type     get_type()     const;
+    enum llama_vocab_pre_type get_pre_type() const;
+
+    uint32_t n_tokens() const;
+    uint32_t n_token_types() const;
+
+    std::string type_name() const;
+
+    bool is_normal      (llama_token id) const;
+    bool is_unknown     (llama_token id) const;
+    bool is_control     (llama_token id) const;
+    bool is_byte        (llama_token id) const;
+    bool is_user_defined(llama_token id) const;
+    bool is_unused      (llama_token id) const;
+    bool is_eog         (llama_token id) const;
+
+    uint8_t     token_to_byte(llama_token id) const;
+    llama_token byte_to_token(uint8_t ch)     const;
+
+    llama_token text_to_token(const std::string & text) const;
+
+    const token_data & get_token_data(llama_token id) const;
+
+    const char *     token_get_text (llama_token id) const;
+    float            token_get_score(llama_token id) const;
+    llama_token_attr token_get_attr (llama_token id) const;
+
+    llama_token token_bos() const;
+    llama_token token_eos() const;
+    llama_token token_eot() const;
+    llama_token token_eom() const;
+    llama_token token_unk() const;
+    llama_token token_sep() const;
+    llama_token token_nl () const;
+    llama_token token_pad() const;
+    llama_token token_mask() const;
+
+    llama_token token_prefix() const;
+    llama_token token_middle() const;
+    llama_token token_suffix() const;
+
+    llama_token token_fim_pre() const;
+    llama_token token_fim_suf() const;
+    llama_token token_fim_mid() const;
+    llama_token token_fim_pad() const;
+    llama_token token_fim_rep() const;
+    llama_token token_fim_sep() const;
+
+    bool get_add_space_prefix          () const;
+    bool get_add_bos                   () const;
+    bool get_add_eos                   () const;
+    bool get_add_sep                   () const;
+    bool get_ignore_merges             () const;
+    bool get_clean_spaces              () const;
+    bool get_remove_extra_whitespaces  () const;
+    bool get_escape_whitespaces        () const;
+    bool get_treat_whitespace_as_suffix() const;
+
+    int max_token_len() const;
 
     int find_bpe_rank(const std::string & token_left, const std::string & token_right) const;
+    std::vector<std::string> get_bpe_merges() const;
+
+    std::vector<char> get_precompiled_charsmap() const;
+
+    int32_t tokenize(
+                   const char * text,
+                      int32_t   text_len,
+                  llama_token * tokens,
+                      int32_t   n_tokens_max,
+                         bool   add_special,
+                         bool   parse_special) const;
+
+    std::vector<llama_token> tokenize(
+            const std::string & raw_text,
+                         bool   add_special,
+                         bool   parse_special = false) const;
+
+    // does not write null-terminator to buf
+    int32_t token_to_piece(
+                  llama_token   token,
+                         char * buf,
+                      int32_t   length,
+                      int32_t   lstrip,
+                         bool   special) const;
+
+    // use cached data
+    const std::string & token_to_piece(llama_token token) const;
+
+    int32_t detokenize(
+            const llama_token * tokens,
+                      int32_t   n_tokens,
+                         char * text,
+                      int32_t   text_len_max,
+                         bool   remove_special,
+                         bool   unparse_special) const;
+
+    std::string detokenize(
+            const std::vector<llama_token> & tokens,
+                                      bool   special) const;
+
+    void print_info() const;
+
+private:
+    struct impl;
+    std::unique_ptr<impl> pimpl;
 };
 
 const struct llama_vocab * llama_get_vocab(const struct llama_context * ctx);
-
-//
-// internal API
-//
-
-// TODO: rename to llama_tokenize_impl
-// TODO: This should probably be in llama.h
-std::vector<llama_token> llama_tokenize_internal(
-        const llama_vocab & vocab,
-        std::string raw_text,
-        bool add_special,
-        bool parse_special = false);
-
-llama_token llama_byte_to_token_impl(const llama_vocab & vocab, uint8_t ch);
-
-const char * llama_token_get_text_impl(const struct llama_vocab & vocab, llama_token token);
-
-float llama_token_get_score_impl(const struct llama_vocab & vocab, llama_token token);
-
-llama_token_attr llama_token_get_attr_impl(const struct llama_vocab & vocab, llama_token token);
-
-bool llama_token_is_eog_impl(const struct llama_vocab & vocab, llama_token token);
-
-bool llama_token_is_control_impl(const struct llama_vocab & vocab, llama_token token);
-
-llama_token llama_token_bos_impl(const struct llama_vocab & vocab);
-llama_token llama_token_eos_impl(const struct llama_vocab & vocab);
-llama_token llama_token_cls_impl(const struct llama_vocab & vocab);
-llama_token llama_token_sep_impl(const struct llama_vocab & vocab);
-llama_token llama_token_nl_impl (const struct llama_vocab & vocab);
-llama_token llama_token_pad_impl(const struct llama_vocab & vocab);
-
-int32_t llama_add_bos_token_impl(const struct llama_vocab & vocab);
-int32_t llama_add_eos_token_impl(const struct llama_vocab & vocab);
-
-llama_token llama_token_fim_pre_impl(const struct llama_vocab & vocab);
-llama_token llama_token_fim_suf_impl(const struct llama_vocab & vocab);
-llama_token llama_token_fim_mid_impl(const struct llama_vocab & vocab);
-llama_token llama_token_fim_pad_impl(const struct llama_vocab & vocab);
-llama_token llama_token_fim_rep_impl(const struct llama_vocab & vocab);
-llama_token llama_token_fim_sep_impl(const struct llama_vocab & vocab);
-
-llama_token llama_token_prefix_impl(const struct llama_vocab & vocab);
-llama_token llama_token_middle_impl(const struct llama_vocab & vocab);
-llama_token llama_token_suffix_impl(const struct llama_vocab & vocab);
-llama_token llama_token_eot_impl   (const struct llama_vocab & vocab);
-llama_token llama_token_eom_impl   (const struct llama_vocab & vocab);
-
-int32_t llama_tokenize_impl(
-        const struct llama_vocab & vocab,
-                      const char * text,
-                         int32_t   text_len,
-                     llama_token * tokens,
-                         int32_t   n_tokens_max,
-                            bool   add_special,
-                            bool   parse_special);
-
-// does not write null-terminator to buf
-int32_t llama_token_to_piece_impl(
-        const struct llama_vocab & vocab,
-                     llama_token   token,
-                            char * buf,
-                         int32_t   length,
-                         int32_t   lstrip,
-                            bool   special);
-
-int32_t llama_detokenize_impl(
-        const struct llama_vocab & vocab,
-               const llama_token * tokens,
-                         int32_t   n_tokens,
-                            char * text,
-                         int32_t   text_len_max,
-                            bool   remove_special,
-                            bool   unparse_special);
-
-std::string llama_detokenize(
-    const struct llama_vocab& vocab,
-    const std::vector<llama_token>& tokens,
-    bool   special);
diff --git a/src/llama.cpp b/src/llama.cpp
index fb9331ac5..c5fa10264 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -9,6 +9,9 @@
 #include "llama-vocab.h"
 #include "llama-grammar.h"
 #include "llama-sampling.h"
+#include "llama-arch.h"
+#include "llama-mmap.h"
+#include "llama-model-loader.h"
 
 #include "unicode.h"
 
@@ -178,85 +181,10 @@ static void zeros(std::ofstream & file, size_t n) {
     }
 }
 
-LLAMA_ATTRIBUTE_FORMAT(1, 2)
-static std::string format(const char * fmt, ...) {
-    va_list ap;
-    va_list ap2;
-    va_start(ap, fmt);
-    va_copy(ap2, ap);
-    int size = vsnprintf(NULL, 0, fmt, ap);
-    GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT
-    std::vector<char> buf(size + 1);
-    int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2);
-    GGML_ASSERT(size2 == size);
-    va_end(ap2);
-    va_end(ap);
-    return std::string(buf.data(), size);
-}
-
 //
 // gguf constants (sync with gguf.py)
 //
 
-enum llm_arch {
-    LLM_ARCH_LLAMA,
-    LLM_ARCH_LLAMA4,
-    LLM_ARCH_DECI,
-    LLM_ARCH_FALCON,
-    LLM_ARCH_BAICHUAN,
-    LLM_ARCH_GROK,
-    LLM_ARCH_GPT2,
-    LLM_ARCH_GPTJ,
-    LLM_ARCH_GPTNEOX,
-    LLM_ARCH_MPT,
-    LLM_ARCH_STARCODER,
-    LLM_ARCH_REFACT,
-    LLM_ARCH_BERT,
-    LLM_ARCH_NOMIC_BERT,
-    LLM_ARCH_JINA_BERT_V2,
-    LLM_ARCH_BLOOM,
-    LLM_ARCH_STABLELM,
-    LLM_ARCH_QWEN,
-    LLM_ARCH_QWEN2,
-    LLM_ARCH_QWEN2MOE,
-    LLM_ARCH_QWEN3,
-    LLM_ARCH_QWEN3MOE,
-    LLM_ARCH_PHI2,
-    LLM_ARCH_PHI3,
-    LLM_ARCH_PLAMO,
-    LLM_ARCH_CODESHELL,
-    LLM_ARCH_ORION,
-    LLM_ARCH_INTERNLM2,
-    LLM_ARCH_MINICPM,
-    LLM_ARCH_GEMMA,
-    LLM_ARCH_GEMMA2,
-    LLM_ARCH_GEMMA3,
-    LLM_ARCH_STARCODER2,
-    LLM_ARCH_MAMBA,
-    LLM_ARCH_XVERSE,
-    LLM_ARCH_COMMAND_R,
-    LLM_ARCH_DBRX,
-    LLM_ARCH_OLMO,
-    LLM_ARCH_OPENELM,
-    LLM_ARCH_ARCTIC,
-    LLM_ARCH_DEEPSEEK2,
-    LLM_ARCH_CHATGLM,
-    LLM_ARCH_GLM4,
-    LLM_ARCH_GLM4_MOE,
-    LLM_ARCH_BITNET,
-    LLM_ARCH_BITNET_25,
-    LLM_ARCH_BITNET_B158,
-    LLM_ARCH_T5,
-    LLM_ARCH_T5ENCODER,
-    LLM_ARCH_JAIS,
-    LLM_ARCH_GRANITE,
-    LLM_ARCH_GRANITE_MOE,
-    LLM_ARCH_COHERE2,
-    LLM_ARCH_DOTS1,
-    LLM_ARCH_HUNYUAN_MOE,
-    LLM_ARCH_UNKNOWN,
-};
-
 static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_LLAMA,           "llama"        },
     { LLM_ARCH_LLAMA4,          "llama4"       },
@@ -313,126 +241,19 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_COHERE2,         "cohere2"      },
     { LLM_ARCH_DOTS1,            "dots1"       },
     { LLM_ARCH_HUNYUAN_MOE,     "hunyuan-moe"  },
+    { LLM_ARCH_OPENAI_MOE,      "gpt-oss"      },
     { LLM_ARCH_UNKNOWN,         "(unknown)"    },
 };
 
-enum llm_kv {
-    LLM_KV_GENERAL_TYPE,
-    LLM_KV_GENERAL_ARCHITECTURE,
-    LLM_KV_GENERAL_QUANTIZATION_VERSION,
-    LLM_KV_GENERAL_ALIGNMENT,
-    LLM_KV_GENERAL_NAME,
-    LLM_KV_GENERAL_AUTHOR,
-    LLM_KV_GENERAL_VERSION,
-    LLM_KV_GENERAL_URL,
-    LLM_KV_GENERAL_DESCRIPTION,
-    LLM_KV_GENERAL_LICENSE,
-    LLM_KV_GENERAL_SOURCE_URL,
-    LLM_KV_GENERAL_SOURCE_HF_REPO,
-
-    LLM_KV_VOCAB_SIZE,
-    LLM_KV_CONTEXT_LENGTH,
-    LLM_KV_EMBEDDING_LENGTH,
-    LLM_KV_BLOCK_COUNT,
-    LLM_KV_LEADING_DENSE_BLOCK_COUNT,
-    LLM_KV_FEED_FORWARD_LENGTH,
-    LLM_KV_EXPERT_FEED_FORWARD_LENGTH,
-    LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH,
-    LLM_KV_USE_PARALLEL_RESIDUAL,
-    LLM_KV_TENSOR_DATA_LAYOUT,
-    LLM_KV_EXPERT_COUNT,
-    LLM_KV_EXPERT_USED_COUNT,
-    LLM_KV_EXPERT_SHARED_COUNT,
-    LLM_KV_EXPERT_WEIGHTS_SCALE,
-    LLM_KV_EXPERT_WEIGHTS_NORM,
-    LLM_KV_EXPERT_GATING_FUNC,
-    LLM_KV_NEXTN_PREDICT_LAYERS,
-    LLM_KV_POOLING_TYPE,
-    LLM_KV_LOGIT_SCALE,
-    LLM_KV_DECODER_START_TOKEN_ID,
-    LLM_KV_ATTN_LOGIT_SOFTCAPPING,
-    LLM_KV_FINAL_LOGIT_SOFTCAPPING,
-    LLM_KV_SWIN_NORM,
-    LLM_KV_RESCALE_EVERY_N_LAYERS,
-    LLM_KV_TIME_MIX_EXTRA_DIM,
-    LLM_KV_TIME_DECAY_EXTRA_DIM,
-    LLM_KV_RESIDUAL_SCALE,
-    LLM_KV_EMBEDDING_SCALE,
-    LLM_KV_TOKEN_SHIFT_COUNT,
-    LLM_KV_INTERLEAVE_MOE_LAYER_STEP,
-
-    LLM_KV_ATTENTION_HEAD_COUNT,
-    LLM_KV_ATTENTION_HEAD_COUNT_KV,
-    LLM_KV_ATTENTION_MAX_ALIBI_BIAS,
-    LLM_KV_ATTENTION_CLAMP_KQV,
-    LLM_KV_ATTENTION_KEY_LENGTH,
-    LLM_KV_ATTENTION_VALUE_LENGTH,
-    LLM_KV_ATTENTION_LAYERNORM_EPS,
-    LLM_KV_ATTENTION_LAYERNORM_RMS_EPS,
-    LLM_KV_ATTENTION_CAUSAL,
-    LLM_KV_ATTENTION_Q_LORA_RANK,
-    LLM_KV_ATTENTION_KV_LORA_RANK,
-    LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT,
-    LLM_KV_ATTENTION_SLIDING_WINDOW,
-    LLM_KV_ATTENTION_SCALE,
-
-    LLM_KV_ROPE_DIMENSION_COUNT,
-    LLM_KV_ROPE_FREQ_BASE,
-    LLM_KV_ROPE_SCALE_LINEAR,
-    LLM_KV_ROPE_SCALING_TYPE,
-    LLM_KV_ROPE_SCALING_FACTOR,
-    LLM_KV_ROPE_SCALING_ATTN_FACTOR,
-    LLM_KV_ROPE_SCALING_ORIG_CTX_LEN,
-    LLM_KV_ROPE_SCALING_FINETUNED,
-    LLM_KV_ROPE_SCALING_YARN_LOG_MUL,
-
-    LLM_KV_SPLIT_NO,
-    LLM_KV_SPLIT_COUNT,
-    LLM_KV_SPLIT_TENSORS_COUNT,
-
-    LLM_KV_SSM_INNER_SIZE,
-    LLM_KV_SSM_CONV_KERNEL,
-    LLM_KV_SSM_STATE_SIZE,
-    LLM_KV_SSM_TIME_STEP_RANK,
-
-    LLM_KV_TOKENIZER_MODEL,
-    LLM_KV_TOKENIZER_PRE,
-    LLM_KV_TOKENIZER_LIST,
-    LLM_KV_TOKENIZER_TOKEN_TYPE,
-    LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT,
-    LLM_KV_TOKENIZER_SCORES,
-    LLM_KV_TOKENIZER_MERGES,
-    LLM_KV_TOKENIZER_BOS_ID,
-    LLM_KV_TOKENIZER_EOS_ID,
-    LLM_KV_TOKENIZER_UNK_ID,
-    LLM_KV_TOKENIZER_SEP_ID,
-    LLM_KV_TOKENIZER_PAD_ID,
-    LLM_KV_TOKENIZER_CLS_ID,
-    LLM_KV_TOKENIZER_MASK_ID,
-    LLM_KV_TOKENIZER_ADD_BOS,
-    LLM_KV_TOKENIZER_ADD_EOS,
-    LLM_KV_TOKENIZER_ADD_PREFIX,
-    LLM_KV_TOKENIZER_REMOVE_EXTRA_WS,
-    LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP,
-    LLM_KV_TOKENIZER_HF_JSON,
-    LLM_KV_TOKENIZER_RWKV,
-    LLM_KV_TOKENIZER_CHAT_TEMPLATE,
-    LLM_KV_TOKENIZER_CHAT_TEMPLATE_N,
-    LLM_KV_TOKENIZER_FIM_PRE_ID,
-    LLM_KV_TOKENIZER_FIM_SUF_ID,
-    LLM_KV_TOKENIZER_FIM_MID_ID,
-    LLM_KV_TOKENIZER_FIM_PAD_ID,
-    LLM_KV_TOKENIZER_FIM_REP_ID,
-    LLM_KV_TOKENIZER_FIM_SEP_ID,
-    LLM_KV_TOKENIZER_PREFIX_ID,
-    LLM_KV_TOKENIZER_SUFFIX_ID,
-    LLM_KV_TOKENIZER_MIDDLE_ID,
-    LLM_KV_TOKENIZER_EOT_ID,
-    LLM_KV_TOKENIZER_EOM_ID,
-
-    LLM_KV_ADAPTER_TYPE,
-    LLM_KV_ADAPTER_LORA_ALPHA,
-};
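+// map an architecture name from GGUF metadata ("general.architecture") back to its llm_arch value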
+llm_arch llm_arch_from_string(const std::string & name) {
+    for (const auto & kv : LLM_ARCH_NAMES) { // NOLINT
+        if (kv.second == name) {
+            return kv.first;
+        }
+    }
+
+    return LLM_ARCH_UNKNOWN;
+}
 
 static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_GENERAL_TYPE,                  "general.type"                          },
@@ -525,6 +346,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_TOKENIZER_MASK_ID,              "tokenizer.ggml.mask_token_id"            },
     { LLM_KV_TOKENIZER_ADD_BOS,              "tokenizer.ggml.add_bos_token"            },
     { LLM_KV_TOKENIZER_ADD_EOS,              "tokenizer.ggml.add_eos_token"            },
+    { LLM_KV_TOKENIZER_ADD_SEP,              "tokenizer.ggml.add_sep_token"            },
     { LLM_KV_TOKENIZER_ADD_PREFIX,           "tokenizer.ggml.add_space_prefix"         },
     { LLM_KV_TOKENIZER_REMOVE_EXTRA_WS,      "tokenizer.ggml.remove_extra_whitespaces" },
     { LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, "tokenizer.ggml.precompiled_charsmap"     },
@@ -549,109 +371,6 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_ADAPTER_LORA_ALPHA,            "adapter.lora.alpha" },
 };
 
-struct LLM_KV {
-    LLM_KV(llm_arch arch, const char* suffix = nullptr);
-
-    llm_arch arch;
-    const char* suffix;
-    std::string operator()(llm_kv kv) const;
-};
-
-enum llm_tensor {
-    LLM_TENSOR_TOKEN_EMBD,
-    LLM_TENSOR_TOKEN_EMBD_NORM,
-    LLM_TENSOR_TOKEN_TYPES,
-    LLM_TENSOR_POS_EMBD,
-    LLM_TENSOR_OUTPUT,
-    LLM_TENSOR_OUTPUT_NORM,
-    LLM_TENSOR_ROPE_FREQS,
-    LLM_TENSOR_ROPE_FACTORS_LONG,
-    LLM_TENSOR_ROPE_FACTORS_SHORT,
-    LLM_TENSOR_ATTN_Q,
-    LLM_TENSOR_ATTN_K,
-    LLM_TENSOR_ATTN_V,
-    LLM_TENSOR_ATTN_QKV,
-    LLM_TENSOR_ATTN_OUT,
-    LLM_TENSOR_ATTN_NORM,
-    LLM_TENSOR_ATTN_NORM_2,
-    LLM_TENSOR_ATTN_OUT_NORM,
-    LLM_TENSOR_ATTN_POST_NORM,
-    LLM_TENSOR_ATTN_ROT_EMBD,
-    LLM_TENSOR_FFN_GATE_INP,
-    LLM_TENSOR_FFN_GATE_INP_SHEXP,
-    LLM_TENSOR_FFN_NORM,
-    LLM_TENSOR_FFN_POST_NORM,
-    LLM_TENSOR_FFN_GATE,
-    LLM_TENSOR_FFN_DOWN,
-    LLM_TENSOR_FFN_UP,
-    LLM_TENSOR_FFN_ACT,
-    LLM_TENSOR_FFN_DOWN_EXP,  // split experts for backward compatibility
-    LLM_TENSOR_FFN_GATE_EXP,
-    LLM_TENSOR_FFN_UP_EXP,
-    LLM_TENSOR_FFN_NORM_EXPS,
-    LLM_TENSOR_FFN_DOWN_EXPS, // merged experts
-    LLM_TENSOR_FFN_GATE_EXPS,
-    LLM_TENSOR_FFN_UP_EXPS,
-    LLM_TENSOR_FFN_DOWN_SHEXP,
-    LLM_TENSOR_FFN_GATE_SHEXP,
-    LLM_TENSOR_FFN_UP_SHEXP,
-    LLM_TENSOR_FFN_EXP_PROBS_B,
-    LLM_TENSOR_ATTN_Q_NORM,
-    LLM_TENSOR_ATTN_K_NORM,
-    LLM_TENSOR_LAYER_OUT_NORM,
-    LLM_TENSOR_SSM_IN,
-    LLM_TENSOR_SSM_CONV1D,
-    LLM_TENSOR_SSM_X,
-    LLM_TENSOR_SSM_DT,
-    LLM_TENSOR_SSM_A,
-    LLM_TENSOR_SSM_D,
-    LLM_TENSOR_SSM_OUT,
-    LLM_TENSOR_ATTN_Q_A,
-    LLM_TENSOR_ATTN_Q_B,
-    LLM_TENSOR_ATTN_KV_A_MQA,
-    LLM_TENSOR_ATTN_KV_B,
-    LLM_TENSOR_ATTN_K_B,
-    LLM_TENSOR_ATTN_V_B,
-    LLM_TENSOR_ATTN_Q_A_NORM,
-    LLM_TENSOR_ATTN_KV_A_NORM,
-    LLM_TENSOR_ATTN_SUB_NORM,
-    LLM_TENSOR_FFN_SUB_NORM,
-    LLM_TENSOR_DEC_ATTN_NORM,
-    LLM_TENSOR_DEC_ATTN_Q,
-    LLM_TENSOR_DEC_ATTN_K,
-    LLM_TENSOR_DEC_ATTN_V,
-    LLM_TENSOR_DEC_ATTN_OUT,
-    LLM_TENSOR_DEC_ATTN_REL_B,
-    LLM_TENSOR_DEC_CROSS_ATTN_NORM,
-    LLM_TENSOR_DEC_CROSS_ATTN_Q,
-    LLM_TENSOR_DEC_CROSS_ATTN_K,
-    LLM_TENSOR_DEC_CROSS_ATTN_V,
-    LLM_TENSOR_DEC_CROSS_ATTN_OUT,
-    LLM_TENSOR_DEC_CROSS_ATTN_REL_B,
-    LLM_TENSOR_DEC_FFN_NORM,
-    LLM_TENSOR_DEC_FFN_GATE,
-    LLM_TENSOR_DEC_FFN_DOWN,
-    LLM_TENSOR_DEC_FFN_UP,
-    LLM_TENSOR_DEC_OUTPUT_NORM,
-    LLM_TENSOR_ENC_ATTN_NORM,
-    LLM_TENSOR_ENC_ATTN_Q,
-    LLM_TENSOR_ENC_ATTN_K,
-    LLM_TENSOR_ENC_ATTN_V,
-    LLM_TENSOR_ENC_ATTN_OUT,
-    LLM_TENSOR_ENC_ATTN_REL_B,
-    LLM_TENSOR_ENC_FFN_NORM,
-    LLM_TENSOR_ENC_FFN_GATE,
-    LLM_TENSOR_ENC_FFN_DOWN,
-    LLM_TENSOR_ENC_FFN_UP,
-    LLM_TENSOR_ENC_OUTPUT_NORM,
-    LLM_TENSOR_NEXTN_EH_PROJ,
-    LLM_TENSOR_NEXTN_EMBED_TOKENS,
-    LLM_TENSOR_NEXTN_ENORM,
-    LLM_TENSOR_NEXTN_HNORM,
-    LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD,
-    LLM_TENSOR_NEXTN_SHARED_HEAD_NORM,
-};
-
 LLM_KV::LLM_KV(llm_arch arch, const char* suffix) : arch(arch), suffix(suffix) {}
 
 std::string LLM_KV::operator()(llm_kv kv) const {
@@ -1732,6 +1451,25 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_NA
             { LLM_TENSOR_FFN_UP_EXPS,     "blk.%d.ffn_up_exps" },
         },
     },
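+    // gpt-oss: standard attention tensors plus per-layer attention sinks and fused MoE expert tensors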
+    {
+        LLM_ARCH_OPENAI_MOE,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,         "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,        "output_norm" },
+            { LLM_TENSOR_OUTPUT,             "output" },
+            { LLM_TENSOR_ATTN_NORM,          "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_POST_NORM,     "blk.%d.post_attention_norm" },
+            { LLM_TENSOR_ATTN_Q,             "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K,             "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V,             "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,           "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_SINKS,         "blk.%d.attn_sinks" },
+            { LLM_TENSOR_FFN_GATE_INP,       "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_GATE_EXPS,      "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS,      "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS,        "blk.%d.ffn_up_exps" },
+        },
+    },
     {
         LLM_ARCH_UNKNOWN,
         {
@@ -1778,6 +1516,7 @@ enum llm_chat_template {
     LLM_CHAT_TEMPLATE_DOTS1,
     LLM_CHAT_TEMPLATE_HUNYUAN_MOE,
     LLM_CHAT_TEMPLATE_KIMI_K2,
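+    // gpt-oss (uses OpenAI's "harmony" chat format)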
+    LLM_CHAT_TEMPLATE_OPENAI_MOE,
     LLM_CHAT_TEMPLATE_UNKNOWN,
 };
 
@@ -1817,20 +1556,10 @@ static const std::map LLM_CHAT_TEMPLATES = {
     { "llama4",            LLM_CHAT_TEMPLATE_LLAMA4            },
     { "hunyuan-moe",       LLM_CHAT_TEMPLATE_HUNYUAN_MOE       },
     { "kimi-k2",           LLM_CHAT_TEMPLATE_KIMI_K2           },
+    { "gpt-oss",           LLM_CHAT_TEMPLATE_OPENAI_MOE        },
     { "bitnet",            LLM_CHAT_TEMPLATE_BITNET            },
 };
 
-
-static llm_arch llm_arch_from_string(const std::string & name) {
-    for (const auto & kv : LLM_ARCH_NAMES) { // NOLINT
-        if (kv.second == name) {
-            return kv.first;
-        }
-    }
-
-    return LLM_ARCH_UNKNOWN;
-}
-
 // helper to handle gguf constants
 // usage:
 //
@@ -1918,7 +1647,7 @@ static std::string gguf_data_to_str(enum gguf_type type, const void * data, int
     }
 }
 
-static std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i) {
+std::string gguf_kv_to_str(const gguf_context * ctx_gguf, int i) {
     const enum gguf_type type = gguf_get_kv_type(ctx_gguf, i);
 
     switch (type) {
@@ -1959,627 +1688,6 @@ static std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i) {
 // llama helpers
 //
 
-#if defined(_WIN32)
-static std::string llama_format_win_err(DWORD err) {
-    LPSTR buf;
-    size_t size = FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS,
-                                 NULL, err, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&buf, 0, NULL);
-    if (!size) {
-        return "FormatMessageA failed";
-    }
-    std::string ret(buf, size);
-    LocalFree(buf);
-    return ret;
-}
-#endif
-
-template <typename T>
-struct no_init {
-    T value;
-    no_init() { /* do nothing */ }
-};
-
-struct llama_file {
-
-#if defined(_WIN32)
-    // use FILE * so we don't have to re-open the file to mmap
-    FILE * fp;
-    HANDLE fp_win32;
-    size_t size;
-
-private:
-    std::string GetErrorMessageWin32(DWORD error_code) const {
-        std::string ret;
-        LPSTR lpMsgBuf = NULL;
-        DWORD bufLen = FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS,
-                                    NULL, error_code, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&lpMsgBuf, 0, NULL);
-        if (!bufLen) {
-            ret = format("Win32 error code: %s", error_code);
-        } else {
-            ret = lpMsgBuf;
-            LocalFree(lpMsgBuf);
-        }
-
-        return ret;
-    }
-
-public:
-
-    llama_file(const char * fname, const char * mode) {
-        fp = ggml_fopen(fname, mode);
-        if (fp == NULL) {
-            throw std::runtime_error(format("failed to open %s: %s", fname, strerror(errno)));
-        }
-        fp_win32 = (HANDLE) _get_osfhandle(_fileno(fp));
-        seek(0, SEEK_END);
-        size = tell();
-        seek(0, SEEK_SET);
-    }
-
-    size_t tell() const {
-        // SetFilePointerEx returns the current position when seeking relative 0 bytes
-        LARGE_INTEGER li;
-        li.QuadPart = 0;
-        BOOL ret = SetFilePointerEx(fp_win32, li, &li, FILE_CURRENT);
-        if (!ret) {
-            throw std::runtime_error(format("read error: %s", GetErrorMessageWin32(GetLastError()).c_str()));
-        }
-
-        return li.QuadPart;
-    }
-
-    void seek(size_t offset, int whence) const {
-        // no need to convert SEEK_* to FILE_*. The enums are the same.
-        // Still, keep static asserts to avoid failures in the future.
-        static_assert(SEEK_SET == FILE_BEGIN, "SEEK_SET != FILE_BEGIN");
-        static_assert(SEEK_CUR == FILE_CURRENT, "SEEK_CUR != FILE_CURRENT");
-        static_assert(SEEK_END == FILE_END, "SEEK_END != FILE_END");
-
-        LARGE_INTEGER li;
-        li.QuadPart = offset;
-        BOOL ret = SetFilePointerEx(fp_win32, li, NULL, whence);
-        if (!ret) {
-            throw std::runtime_error(format("read error: %s", GetErrorMessageWin32(GetLastError()).c_str()));
-        }
-    }
-
-    void read_raw(void * ptr, size_t len) const {
-        // On Win32 ReadFile is significant faster than fread which is again significant faster than std::fstream. Thus
-        // use the Win32 API to do file io instead of the C/C++ library functions.
-
-        // There are conditions under which ReadFile cannot read chunks >64MB.
-        // Thus split the operation into smaller chunks if len exceeds this limit.
-        size_t bytes_read = 0;
-        while (bytes_read < len) {
-            size_t chunk_size = std::min<size_t>(len - bytes_read, 64*1024*1024);
-            DWORD chunk_read = 0;
-            BOOL result = ReadFile(fp_win32, reinterpret_cast<char *>(ptr) + bytes_read, chunk_size, &chunk_read, NULL);
-            if (!result) {
-                throw std::runtime_error(format("read error: %s", GetErrorMessageWin32(GetLastError()).c_str()));
-            }
-            if (chunk_read < chunk_size || chunk_read == 0) {
-                throw std::runtime_error("unexpectedly reached end of file");
-            }
-
-            bytes_read += chunk_read;
-        } ;
-    }
-
-    uint32_t read_u32() const {
-        uint32_t val;
-        read_raw(&val, sizeof(val));
-        return val;
-    }
-
-    void write_raw(const void * ptr, size_t len) const {
-        // There are conditions under which WriteFile cannot write chunks >64MB.
-        // Thus split the operation into smaller chunks if len exceeds this limit.
-        size_t bytes_written = 0;
-        while (bytes_written < len) {
-            size_t chunk_size = std::min<size_t>(len - bytes_written, 64*1024*1024);
-            DWORD chunk_written = 0;
-            BOOL result = WriteFile(fp_win32, reinterpret_cast<char const *>(ptr) + bytes_written, chunk_size, &chunk_written, NULL);
-            if (!result) {
-                throw std::runtime_error(format("write error: %s", GetErrorMessageWin32(GetLastError()).c_str()));
-            }
-            if (chunk_written < chunk_size || chunk_written == 0) {
-                throw std::runtime_error("unexpectedly failed to write bytes");
-            }
-
-            bytes_written += chunk_written;
-        }
-    }
-
-    void write_u32(std::uint32_t val) const {
-        write_raw(&val, sizeof(val));
-    }
-
-    ~llama_file() {
-        if (fp) {
-            std::fclose(fp);
-        }
-    }
-#else
-    // use FILE * so we don't have to re-open the file to mmap
-    FILE * fp;
-    size_t size;
-
-    llama_file(const char * fname, const char * mode) {
-        fp = ggml_fopen(fname, mode);
-        if (fp == NULL) {
-            throw std::runtime_error(format("failed to open %s: %s", fname, strerror(errno)));
-        }
-        seek(0, SEEK_END);
-        size = tell();
-        seek(0, SEEK_SET);
-    }
-
-    size_t tell() const {
-#ifdef _WIN32
-        __int64 ret = _ftelli64(fp);
-#else
-        long ret = std::ftell(fp);
-#endif
-        if (ret == -1) {
-            throw std::runtime_error(format("ftell error: %s", strerror(errno)));
-        }
-
-        return (size_t) ret;
-    }
-
-    void seek(size_t offset, int whence) const {
-#ifdef _WIN32
-        int ret = _fseeki64(fp, (__int64) offset, whence);
-#else
-        int ret = std::fseek(fp, (long) offset, whence);
-#endif
-        if (ret != 0) {
-            throw std::runtime_error(format("seek error: %s", strerror(errno)));
-        }
-    }
-
-    void read_raw(void * ptr, size_t len) const {
-        if (len == 0) {
-            return;
-        }
-        errno = 0;
-        std::size_t ret = std::fread(ptr, len, 1, fp);
-        if (ferror(fp)) {
-            throw std::runtime_error(format("read error: %s", strerror(errno)));
-        }
-        if (ret != 1) {
-            throw std::runtime_error("unexpectedly reached end of file");
-        }
-    }
-
-    uint32_t read_u32() const {
-        uint32_t ret;
-        read_raw(&ret, sizeof(ret));
-        return ret;
-    }
-
-    void write_raw(const void * ptr, size_t len) const {
-        if (len == 0) {
-            return;
-        }
-        errno = 0;
-        size_t ret = std::fwrite(ptr, len, 1, fp);
-        if (ret != 1) {
-            throw std::runtime_error(format("write error: %s", strerror(errno)));
-        }
-    }
-
-    void write_u32(std::uint32_t val) const {
-        write_raw(&val, sizeof(val));
-    }
-
-    ~llama_file() {
-        if (fp) {
-            std::fclose(fp);
-        }
-    }
-#endif
-};
-using llama_files = std::vector<std::unique_ptr<llama_file>>;
-
-struct llama_mmap {
-    void * addr;
-    size_t size;
-    size_t mapped_page_size = 0;
-
-    llama_mmap(const llama_mmap &) = delete;
-
-#ifdef _POSIX_MAPPED_FILES
-    static constexpr bool SUPPORTED = true;
-
-    // list of mapped fragments (first_offset, last_offset)
-    std::vector<std::pair<size_t, size_t>> mapped_fragments;
-
-    llama_mmap(struct llama_file * file, size_t prefetch = (size_t) -1 /* -1 = max value */, bool numa = false, [[maybe_unused]] bool use_thp = false) {
-        size = file->size;
-        int fd = fileno(file->fp);
-        int flags = MAP_SHARED;
-        // prefetch/readahead impairs performance on NUMA systems
-        if (numa)  { prefetch = 0; }
-#ifdef __linux__
-        // advise the kernel to read the file sequentially (increases readahead)
-        if (posix_fadvise(fd, 0, 0, POSIX_FADV_SEQUENTIAL)) {
-            LLAMA_LOG_WARN("warning: posix_fadvise(.., POSIX_FADV_SEQUENTIAL) failed: %s\n",
-                    strerror(errno));
-        }
-        if (prefetch) { flags |= MAP_POPULATE; }
-        if (use_thp) {
-            size_t huge = get_default_huge_page_size();
-            auto size = huge*((file->size + huge - 1)/huge);
-            addr = mmap(nullptr, size, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_ANONYMOUS | MAP_HUGETLB, -1, 0);
-            if (addr != MAP_FAILED) {
-                printf("%s: using THP with page size %zu MiB ", __func__, huge/(1024*1024));
-                fflush(stdout);
-                size_t tot = 0;
-                while (tot < file->size) {
-                    auto n_read = pread(fd, static_cast<char *>(addr) + tot, file->size - tot, tot);
-                    if (n_read < 0) throw std::runtime_error(format("Reading into mapped huge pages failed at %zu (%s)", tot, strerror(errno)));
-                    printf(".");  fflush(stdout);
-                    tot += n_read;
-                }
-                printf(" done\n");
-                mapped_fragments.emplace_back(0, file->size);
-                mapped_page_size = huge;
-                return;
-            }
-            else {
-                fprintf(stderr, "%s: mmap with huge page size %zu MiB failed (%s)\n", __func__, huge/(1024*1024), strerror(errno));
-            }
-        }
-#endif
-        addr = mmap(NULL, file->size, PROT_READ, flags, fd, 0);
-        if (addr == MAP_FAILED) { // NOLINT
-            throw std::runtime_error(format("mmap failed: %s", strerror(errno)));
-        }
-
-        if (prefetch > 0) {
-            // advise the kernel to preload the mapped memory
-            if (posix_madvise(addr, std::min(file->size, prefetch), POSIX_MADV_WILLNEED)) {
-                LLAMA_LOG_WARN("warning: posix_madvise(.., POSIX_MADV_WILLNEED) failed: %s\n",
-                        strerror(errno));
-            }
-        }
-        if (numa) {
-            // advise the kernel not to use readahead
-            // (because the next page might not belong on the same node)
-            if (posix_madvise(addr, file->size, POSIX_MADV_RANDOM)) {
-                LLAMA_LOG_WARN("warning: posix_madvise(.., POSIX_MADV_RANDOM) failed: %s\n",
-                        strerror(errno));
-            }
-        }
-
-        // initialize list of mapped_fragments
-        mapped_fragments.emplace_back(0, file->size);
-    }
-
-    static void align_range(size_t * first, size_t * last, size_t page_size) {
-        // align first to the next page
-        size_t offset_in_page = *first & (page_size - 1);
-        size_t offset_to_page = offset_in_page == 0 ? 0 : page_size - offset_in_page;
-        *first += offset_to_page;
-
-        // align last to the previous page
-        *last = *last & ~(page_size - 1);
-
-        if (*last <= *first) {
-            *last = *first;
-        }
-    }
-
-    // partially unmap the file in the range [first, last)
-    void unmap_fragment(size_t first, size_t last) {
-        // note: this function must not be called multiple times with overlapping ranges
-        // otherwise, there is a risk of invalidating addresses that have been repurposed for other mappings
-        int page_size = mapped_page_size > 0 ? mapped_page_size : sysconf(_SC_PAGESIZE);
-        align_range(&first, &last, page_size);
-        size_t len = last - first;
-
-        if (len == 0) {
-            return;
-        }
-
-        GGML_ASSERT(first % page_size == 0);
-        GGML_ASSERT(last % page_size == 0);
-        GGML_ASSERT(last > first);
-
-        void * next_page_start = (uint8_t *) addr + first;
-
-        // unmap the range
-        if (munmap(next_page_start, len)) {
-            LLAMA_LOG_WARN("warning: munmap failed: %s\n", strerror(errno));
-        }
-
-        // update the list of mapped fragments to avoid unmapping the same range again in the destructor
-        std::vector<std::pair<size_t, size_t>> new_mapped_fragments;
-        for (const auto & frag : mapped_fragments) {
-            if (frag.first < first && frag.second > last) {
-                // the range is in the middle of the fragment, split it
-                new_mapped_fragments.emplace_back(frag.first, first);
-                new_mapped_fragments.emplace_back(last, frag.second);
-            } else if (frag.first < first && frag.second > first) {
-                // the range starts in the middle of the fragment
-                new_mapped_fragments.emplace_back(frag.first, first);
-            } else if (frag.first < last && frag.second > last) {
-                // the range ends in the middle of the fragment
-                new_mapped_fragments.emplace_back(last, frag.second);
-            } else if (frag.first >= first && frag.second <= last) {
-                // the range covers the entire fragment
-            } else {
-                // the range is outside the fragment
-                new_mapped_fragments.push_back(frag);
-            }
-        }
-        mapped_fragments = std::move(new_mapped_fragments);
-    }
-
-#ifdef __linux__
-    static int get_default_huge_page_size() {
-        int pg_size = 2048;
-        std::ifstream in("/proc/meminfo");
-        if (in) {
-            std::string line;
-            while (true) {
-                std::getline(in, line);
-                if (in.fail()) break;
-                if (auto pos = line.find("Hugepagesize:"); pos != std::string::npos) {
-                    std::istringstream str(line.data() + pos + 13);
-                    int aux;
-                    str >> aux;
-                    if (!str.fail()) pg_size = aux;
-                    break;
-                }
-            }
-        }
-        return pg_size * 1024;
-    }
-#endif
-
-    ~llama_mmap() {
-        for (const auto & frag : mapped_fragments) {
-            if (munmap((char *) addr + frag.first, frag.second - frag.first)) {
-                LLAMA_LOG_WARN("warning: munmap failed: %s\n", strerror(errno));
-            }
-        }
-    }
-#elif defined(_WIN32)
-    static constexpr bool SUPPORTED = true;
-
-    llama_mmap(struct llama_file * file, size_t prefetch = (size_t) -1, bool numa = false, [[maybe_unused]] bool use_thp = false) {
-        GGML_UNUSED(numa);
-
-        size = file->size;
-
-        HANDLE hFile = (HANDLE) _get_osfhandle(_fileno(file->fp));
-
-        HANDLE hMapping = CreateFileMappingA(hFile, NULL, PAGE_READONLY, 0, 0, NULL);
-
-        if (hMapping == NULL) {
-            DWORD error = GetLastError();
-            throw std::runtime_error(format("CreateFileMappingA failed: %s", llama_format_win_err(error).c_str()));
-        }
-
-        addr = MapViewOfFile(hMapping, FILE_MAP_READ, 0, 0, 0);
-        DWORD error = GetLastError();
-        CloseHandle(hMapping);
-
-        if (addr == NULL) {
-            throw std::runtime_error(format("MapViewOfFile failed: %s", llama_format_win_err(error).c_str()));
-        }
-
-        if (prefetch > 0) {
-#if _WIN32_WINNT >= 0x602
-            // PrefetchVirtualMemory is only present on Windows 8 and above, so we dynamically load it
-            BOOL (WINAPI *pPrefetchVirtualMemory) (HANDLE, ULONG_PTR, PWIN32_MEMORY_RANGE_ENTRY, ULONG);
-            HMODULE hKernel32 = GetModuleHandleW(L"kernel32.dll");
-
-            // may fail on pre-Windows 8 systems
-            pPrefetchVirtualMemory = reinterpret_cast<decltype(pPrefetchVirtualMemory)> (GetProcAddress(hKernel32, "PrefetchVirtualMemory"));
-
-            if (pPrefetchVirtualMemory) {
-                // advise the kernel to preload the mapped memory
-                WIN32_MEMORY_RANGE_ENTRY range;
-                range.VirtualAddress = addr;
-                range.NumberOfBytes = (SIZE_T) std::min(size, prefetch);
-                if (!pPrefetchVirtualMemory(GetCurrentProcess(), 1, &range, 0)) {
-                    LLAMA_LOG_WARN("warning: PrefetchVirtualMemory failed: %s\n",
-                            llama_format_win_err(GetLastError()).c_str());
-                }
-            }
-#else
-            throw std::runtime_error("PrefetchVirtualMemory unavailable");
-#endif
-        }
-    }
-
-    void unmap_fragment(size_t first, size_t last) {
-        // not supported
-        GGML_UNUSED(first);
-        GGML_UNUSED(last);
-    }
-
-    ~llama_mmap() {
-        if (!UnmapViewOfFile(addr)) {
-            LLAMA_LOG_WARN("warning: UnmapViewOfFile failed: %s\n",
-                    llama_format_win_err(GetLastError()).c_str());
-        }
-    }
-#else
-    static constexpr bool SUPPORTED = false;
-
-    llama_mmap(struct llama_file * file, size_t prefetch = -1, bool numa = false, bool use_thp = false) {
-        GGML_UNUSED(file);
-        GGML_UNUSED(prefetch);
-        GGML_UNUSED(numa);
-        GGML_UNUSED(use_thp);
-
-        throw std::runtime_error("mmap not supported");
-    }
-
-    void unmap_fragment(size_t first, size_t last) {
-        GGML_UNUSED(first);
-        GGML_UNUSED(last);
-
-        throw std::runtime_error("mmap not supported");
-    }
-#endif
-};
-using llama_mmaps = std::vector<std::unique_ptr<llama_mmap>>;
-
-// Represents some region of memory being locked using mlock or VirtualLock;
-// will automatically unlock on destruction.
-struct llama_mlock {
-    void * addr = NULL;
-    size_t size = 0;
-
-    bool failed_already = false;
-
-    llama_mlock() {}
-    llama_mlock(const llama_mlock &) = delete;
-
-    ~llama_mlock() {
-        if (size) {
-            raw_unlock(addr, size);
-        }
-    }
-
-    void init(void * ptr) {
-        GGML_ASSERT(addr == NULL && size == 0); // NOLINT
-        addr = ptr;
-    }
-
-    void grow_to(size_t target_size) {
-        GGML_ASSERT(addr);
-        if (failed_already) {
-            return;
-        }
-        size_t granularity = lock_granularity();
-        target_size = (target_size + granularity - 1) & ~(granularity - 1);
-        if (target_size > size) {
-            if (raw_lock((uint8_t *) addr + size, target_size - size)) {
-                size = target_size;
-            } else {
-                failed_already = true;
-            }
-        }
-    }
-
-#ifdef _POSIX_MEMLOCK_RANGE
-    static constexpr bool SUPPORTED = true;
-
-    static size_t lock_granularity() {
-        return (size_t) sysconf(_SC_PAGESIZE);
-    }
-
-    #ifdef __APPLE__
-        #define MLOCK_SUGGESTION \
-            "Try increasing the sysctl values 'vm.user_wire_limit' and 'vm.global_user_wire_limit' and/or " \
-            "decreasing 'vm.global_no_user_wire_amount'.  Also try increasing RLIMIT_MEMLOCK (ulimit -l).\n"
-    #else
-        #define MLOCK_SUGGESTION \
-            "Try increasing RLIMIT_MEMLOCK ('ulimit -l' as root).\n"
-    #endif
-
-    bool raw_lock(const void * addr, size_t size) const {
-        if (!mlock(addr, size)) {
-            return true;
-        }
-
-        char* errmsg = std::strerror(errno);
-        bool suggest = (errno == ENOMEM);
-
-        // Check if the resource limit is fine after all
-        struct rlimit lock_limit;
-        if (suggest && getrlimit(RLIMIT_MEMLOCK, &lock_limit)) {
-            suggest = false;
-        }
-        if (suggest && (lock_limit.rlim_max > lock_limit.rlim_cur + size)) {
-            suggest = false;
-        }
-
-        LLAMA_LOG_WARN("warning: failed to mlock %zu-byte buffer (after previously locking %zu bytes): %s\n%s",
-                size, this->size, errmsg, suggest ? MLOCK_SUGGESTION : "");
-        return false;
-    }
-
-    #undef MLOCK_SUGGESTION
-
-    static void raw_unlock(void * addr, size_t size) {
-        if (munlock(addr, size)) {
-            LLAMA_LOG_WARN("warning: failed to munlock buffer: %s\n", std::strerror(errno));
-        }
-    }
-#elif defined(_WIN32)
-    static constexpr bool SUPPORTED = true;
-
-    static size_t lock_granularity() {
-        SYSTEM_INFO si;
-        GetSystemInfo(&si);
-        return (size_t) si.dwPageSize;
-    }
-
-    bool raw_lock(void * ptr, size_t len) const {
-        for (int tries = 1; ; tries++) {
-            if (VirtualLock(ptr, len)) {
-                return true;
-            }
-            if (tries == 2) {
-                LLAMA_LOG_WARN("warning: failed to VirtualLock %zu-byte buffer (after previously locking %zu bytes): %s\n",
-                    len, size, llama_format_win_err(GetLastError()).c_str());
-                return false;
-            }
-
-            // It failed but this was only the first try; increase the working
-            // set size and try again.
-            SIZE_T min_ws_size, max_ws_size;
-            if (!GetProcessWorkingSetSize(GetCurrentProcess(), &min_ws_size, &max_ws_size)) {
-                LLAMA_LOG_WARN("warning: GetProcessWorkingSetSize failed: %s\n",
-                        llama_format_win_err(GetLastError()).c_str());
-                return false;
-            }
-            // Per MSDN: "The maximum number of pages that a process can lock
-            // is equal to the number of pages in its minimum working set minus
-            // a small overhead."
-            // Hopefully a megabyte is enough overhead:
-            size_t increment = len + 1048576;
-            // The minimum must be <= the maximum, so we need to increase both:
-            min_ws_size += increment;
-            max_ws_size += increment;
-            if (!SetProcessWorkingSetSize(GetCurrentProcess(), min_ws_size, max_ws_size)) {
-                LLAMA_LOG_WARN("warning: SetProcessWorkingSetSize failed: %s\n",
-                        llama_format_win_err(GetLastError()).c_str());
-                return false;
-            }
-        }
-    }
-
-    static void raw_unlock(void * ptr, size_t len) {
-        if (!VirtualUnlock(ptr, len)) {
-            LLAMA_LOG_WARN("warning: failed to VirtualUnlock buffer: %s\n",
-                    llama_format_win_err(GetLastError()).c_str());
-        }
-    }
-#else
-    static constexpr bool SUPPORTED = false;
-
-    static size_t lock_granularity() {
-        return (size_t) 65536;
-    }
-
-    bool raw_lock(const void * addr, size_t len) const {
-        LLAMA_LOG_WARN("warning: mlock not supported on this system\n");
-        return false;
-    }
-
-    static void raw_unlock(const void * addr, size_t len) {}
-#endif
-};
-using llama_mlocks = std::vector<std::unique_ptr<llama_mlock>>;
-
 // NOTE: avoid ever using this except for building the token_to_piece caches
 static std::string llama_token_to_piece(const struct llama_model * model, llama_token token, bool special) {
     std::string piece;
@@ -2597,7 +1705,7 @@ static std::string llama_token_to_piece(const struct llama_model * model, llama_
     return piece;
 }
 
-static ggml_backend_buffer_type_t llama_default_buffer_type_cpu(bool host_buffer) {
+ggml_backend_buffer_type_t llama_default_buffer_type_cpu(bool host_buffer) {
     ggml_backend_buffer_type_t buft = nullptr;
 
 #if defined(GGML_USE_CUDA)
@@ -2723,14 +1831,17 @@ static const size_t MiB = 1024*kiB;
 static const size_t GiB = 1024*MiB;
 
 enum llm_expert_gating_func_type {
-    LLM_EXPERT_GATING_FUNC_SOFTMAX = 1,
-    LLM_EXPERT_GATING_FUNC_SIGMOID = 2,
+    LLM_EXPERT_GATING_FUNC_TYPE_NONE             = 0,
+    LLM_EXPERT_GATING_FUNC_SOFTMAX               = 1,
+    LLM_EXPERT_GATING_FUNC_SIGMOID               = 2,
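+    // softmax applied to the selected experts' weights after top-k (used by the gpt-oss router)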
+    LLM_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT   = 3,
 };
 
 static const char * llama_expert_gating_func_name(llm_expert_gating_func_type type) {
     switch (type) {
         case LLM_EXPERT_GATING_FUNC_SOFTMAX: return "softmax";
         case LLM_EXPERT_GATING_FUNC_SIGMOID: return "sigmoid";
+        case LLM_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT: return "softmax_weight";
         default:                             return "unknown";
     }
 }
@@ -2982,107 +2093,114 @@ struct llama_layer_nextn {
 // TODO: separate into "llama_layer_enc" and "llama_layer_dec"
 struct llama_layer {
     // normalization
-    struct ggml_tensor * attn_norm;
-    struct ggml_tensor * attn_norm_b;
-    struct ggml_tensor * attn_norm_2;
-    struct ggml_tensor * attn_norm_2_b;
-    struct ggml_tensor * attn_q_norm;
-    struct ggml_tensor * attn_q_norm_b;
-    struct ggml_tensor * attn_k_norm;
-    struct ggml_tensor * attn_k_norm_b;
-    struct ggml_tensor * attn_out_norm;
-    struct ggml_tensor * attn_out_norm_b;
-    struct ggml_tensor * attn_q_a_norm;
-    struct ggml_tensor * attn_kv_a_norm;
-    struct ggml_tensor * attn_sub_norm;
-    struct ggml_tensor * attn_post_norm;
-    struct ggml_tensor * ffn_sub_norm;
-    struct ggml_tensor * attn_norm_cross;
-    struct ggml_tensor * attn_norm_enc;
+    struct ggml_tensor * attn_norm = nullptr;
+    struct ggml_tensor * attn_norm_b = nullptr;
+    struct ggml_tensor * attn_norm_2 = nullptr;
+    struct ggml_tensor * attn_norm_2_b = nullptr;
+    struct ggml_tensor * attn_q_norm = nullptr;
+    struct ggml_tensor * attn_q_norm_b = nullptr;
+    struct ggml_tensor * attn_k_norm = nullptr;
+    struct ggml_tensor * attn_k_norm_b = nullptr;
+    struct ggml_tensor * attn_out_norm = nullptr;
+    struct ggml_tensor * attn_out_norm_b = nullptr;
+    struct ggml_tensor * attn_q_a_norm = nullptr;
+    struct ggml_tensor * attn_kv_a_norm = nullptr;
+    struct ggml_tensor * attn_sub_norm = nullptr;
+    struct ggml_tensor * attn_post_norm = nullptr;
+    struct ggml_tensor * ffn_sub_norm = nullptr;
+    struct ggml_tensor * attn_norm_cross = nullptr;
+    struct ggml_tensor * attn_norm_enc = nullptr;
 
     // attention
-    struct ggml_tensor * wq;
-    struct ggml_tensor * wk;
-    struct ggml_tensor * wv;
-    struct ggml_tensor * wo;
-    struct ggml_tensor * wqkv;
-    struct ggml_tensor * wq_a;
-    struct ggml_tensor * wq_b;
-    struct ggml_tensor * wkv_a_mqa;
-    struct ggml_tensor * wkv_b;
-    struct ggml_tensor * wk_b;
-    struct ggml_tensor * wv_b;
-    struct ggml_tensor * wq_cross;
-    struct ggml_tensor * wk_cross;
-    struct ggml_tensor * wv_cross;
-    struct ggml_tensor * wo_cross;
-    struct ggml_tensor * wq_enc;
-    struct ggml_tensor * wk_enc;
-    struct ggml_tensor * wv_enc;
-    struct ggml_tensor * wo_enc;
+    struct ggml_tensor * wq = nullptr;
+    struct ggml_tensor * wk = nullptr;
+    struct ggml_tensor * wv = nullptr;
+    struct ggml_tensor * wo = nullptr;
+    struct ggml_tensor * wqkv = nullptr;
+    struct ggml_tensor * wq_a = nullptr;
+    struct ggml_tensor * wq_b = nullptr;
+    struct ggml_tensor * wkv_a_mqa = nullptr;
+    struct ggml_tensor * wkv_b = nullptr;
+    struct ggml_tensor * wk_b = nullptr;
+    struct ggml_tensor * wv_b = nullptr;
+    struct ggml_tensor * wq_cross = nullptr;
+    struct ggml_tensor * wk_cross = nullptr;
+    struct ggml_tensor * wv_cross = nullptr;
+    struct ggml_tensor * wo_cross = nullptr;
+    struct ggml_tensor * wq_enc = nullptr;
+    struct ggml_tensor * wk_enc = nullptr;
+    struct ggml_tensor * wv_enc = nullptr;
+    struct ggml_tensor * wo_enc = nullptr;
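+    // attention sinks: one learned logit per head folded into the softmax (gpt-oss)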
+    struct ggml_tensor * attn_sinks = nullptr;
 
     // attention bias
-    struct ggml_tensor * bq;
-    struct ggml_tensor * bk;
-    struct ggml_tensor * bv;
-    struct ggml_tensor * bo;
-    struct ggml_tensor * bqkv;
+    struct ggml_tensor * bq = nullptr;
+    struct ggml_tensor * bk = nullptr;
+    struct ggml_tensor * bv = nullptr;
+    struct ggml_tensor * bo = nullptr;
+    struct ggml_tensor * bqkv = nullptr;
 
     // relative position bias
-    struct ggml_tensor * attn_rel_b;
-    struct ggml_tensor * attn_rel_b_enc;
-    struct ggml_tensor * attn_rel_b_cross;
+    struct ggml_tensor * attn_rel_b = nullptr;
+    struct ggml_tensor * attn_rel_b_enc = nullptr;
+    struct ggml_tensor * attn_rel_b_cross = nullptr;
 
     // normalization
-    struct ggml_tensor * ffn_norm;
-    struct ggml_tensor * ffn_norm_b;
-    struct ggml_tensor * ffn_post_norm;
-    struct ggml_tensor * layer_out_norm;
-    struct ggml_tensor * layer_out_norm_b;
-    struct ggml_tensor * ffn_norm_exps;
-    struct ggml_tensor * ffn_norm_enc;
+    struct ggml_tensor * ffn_norm = nullptr;
+    struct ggml_tensor * ffn_norm_b = nullptr;
+    struct ggml_tensor * ffn_post_norm = nullptr;
+    struct ggml_tensor * layer_out_norm = nullptr;
+    struct ggml_tensor * layer_out_norm_b = nullptr;
+    struct ggml_tensor * ffn_norm_exps = nullptr;
+    struct ggml_tensor * ffn_norm_enc = nullptr;
 
     // ff
-    struct ggml_tensor * ffn_gate; // w1
-    struct ggml_tensor * ffn_down; // w2
-    struct ggml_tensor * ffn_up;   // w3
-    struct ggml_tensor * ffn_gate_enc;
-    struct ggml_tensor * ffn_down_enc;
-    struct ggml_tensor * ffn_up_enc;
+    struct ggml_tensor * ffn_gate = nullptr; // w1
+    struct ggml_tensor * ffn_down = nullptr; // w2
+    struct ggml_tensor * ffn_up = nullptr;   // w3
+    struct ggml_tensor * ffn_gate_enc = nullptr;
+    struct ggml_tensor * ffn_down_enc = nullptr;
+    struct ggml_tensor * ffn_up_enc = nullptr;
 
     // ff MoE
-    struct ggml_tensor * ffn_gate_inp;
-    struct ggml_tensor * ffn_gate_exps;
-    struct ggml_tensor * ffn_down_exps;
-    struct ggml_tensor * ffn_up_exps ;
+    struct ggml_tensor * ffn_gate_inp = nullptr;
+    struct ggml_tensor * ffn_gate_exps = nullptr;
+    struct ggml_tensor * ffn_down_exps = nullptr;
+    struct ggml_tensor * ffn_up_exps  = nullptr;
+
+    // ff MoE bias
+    struct ggml_tensor * ffn_gate_inp_b = nullptr;
+    struct ggml_tensor * ffn_gate_exps_b = nullptr;
+    struct ggml_tensor * ffn_down_exps_b = nullptr;
+    struct ggml_tensor * ffn_up_exps_b = nullptr;
 
     // ff shared expert (shexp)
-    struct ggml_tensor * ffn_gate_inp_shexp;
-    struct ggml_tensor * ffn_gate_shexp;
-    struct ggml_tensor * ffn_down_shexp;
-    struct ggml_tensor * ffn_up_shexp;
+    struct ggml_tensor * ffn_gate_inp_shexp = nullptr;
+    struct ggml_tensor * ffn_gate_shexp = nullptr;
+    struct ggml_tensor * ffn_down_shexp = nullptr;
+    struct ggml_tensor * ffn_up_shexp = nullptr;
 
     // ff bias
     struct ggml_tensor * ffn_gate_b = nullptr;
     struct ggml_tensor * ffn_down_b = nullptr; // b2
     struct ggml_tensor * ffn_up_b   = nullptr; // b3
-    struct ggml_tensor * ffn_act;
+    struct ggml_tensor * ffn_act = nullptr;
     struct ggml_tensor * ffn_exp_probs_b = nullptr;
 
     // mamba proj
-    struct ggml_tensor * ssm_in;
-    struct ggml_tensor * ssm_x;
-    struct ggml_tensor * ssm_dt;
-    struct ggml_tensor * ssm_out;
+    struct ggml_tensor * ssm_in = nullptr;
+    struct ggml_tensor * ssm_x = nullptr;
+    struct ggml_tensor * ssm_dt = nullptr;
+    struct ggml_tensor * ssm_out = nullptr;
 
     // mamba
-    struct ggml_tensor * ssm_conv1d;
-    struct ggml_tensor * ssm_a;
-    struct ggml_tensor * ssm_d;
+    struct ggml_tensor * ssm_conv1d = nullptr;
+    struct ggml_tensor * ssm_a = nullptr;
+    struct ggml_tensor * ssm_d = nullptr;
 
     // mamba bias
-    struct ggml_tensor * ssm_conv1d_b;
-    struct ggml_tensor * ssm_dt_b;
+    struct ggml_tensor * ssm_conv1d_b = nullptr;
+    struct ggml_tensor * ssm_dt_b = nullptr;
 
     // long rope factors
     struct ggml_tensor * rope_long  = nullptr;
@@ -3090,13 +2208,13 @@ struct llama_layer {
     struct ggml_tensor * rope_freqs = nullptr;
 
     // bitnet scale
-    struct ggml_tensor * wq_scale;
-    struct ggml_tensor * wk_scale;
-    struct ggml_tensor * wv_scale;
-    struct ggml_tensor * wo_scale;
-    struct ggml_tensor * ffn_gate_scale;
-    struct ggml_tensor * ffn_up_scale;
-    struct ggml_tensor * ffn_down_scale;
+    struct ggml_tensor * wq_scale = nullptr;
+    struct ggml_tensor * wk_scale = nullptr;
+    struct ggml_tensor * wv_scale = nullptr;
+    struct ggml_tensor * wo_scale = nullptr;
+    struct ggml_tensor * ffn_gate_scale = nullptr;
+    struct ggml_tensor * ffn_up_scale = nullptr;
+    struct ggml_tensor * ffn_down_scale = nullptr;
 
     struct llama_layer_nextn nextn;
 
@@ -4096,1158 +3214,15 @@ static uint32_t llama_kv_cache_get_padding(const struct llama_cparams & cparams)
 // model loading and saving
 //
 
-enum llama_fver {
-    GGUF_FILE_VERSION_V1 = 1,
-    GGUF_FILE_VERSION_V2 = 2,
-    GGUF_FILE_VERSION_V3 = 3,
-};
-
-static const char * llama_file_version_name(llama_fver version) {
-    switch (version) {
-        case GGUF_FILE_VERSION_V1: return "GGUF V1 (support until nov 2023)";
-        case GGUF_FILE_VERSION_V2: return "GGUF V2";
-        case GGUF_FILE_VERSION_V3: return "GGUF V3 (latest)";
-    }
-
-    return "unknown";
-}
-
-static std::string llama_format_tensor_shape(const std::vector<int64_t> & ne) {
-    char buf[256];
-    snprintf(buf, sizeof(buf), "%5" PRId64, ne.at(0));
-    for (size_t i = 1; i < ne.size(); i++) {
-        snprintf(buf + strlen(buf), sizeof(buf) - strlen(buf), ", %5" PRId64, ne.at(i));
-    }
-    return buf;
-}
-
-static std::string llama_format_tensor_shape(const struct ggml_tensor * t) {
-    char buf[256];
-    snprintf(buf, sizeof(buf), "%5" PRId64, t->ne[0]);
-    for (int i = 1; i < GGML_MAX_DIMS; i++) {
-        snprintf(buf + strlen(buf), sizeof(buf) - strlen(buf), ", %5" PRId64, t->ne[i]);
-    }
-    return buf;
-}
-
-namespace GGUFMeta {
-    template <typename T, gguf_type gt_, T (*gfun)(const gguf_context *, const int)>
-    struct GKV_Base_Type {
-        static constexpr gguf_type gt = gt_;
-
-        static T getter(const gguf_context * ctx, const int kid) {
-            return gfun(ctx, kid);
-        }
-    };
-
-    template<typename T> struct GKV_Base;
-
-    template<> struct GKV_Base<bool        >: GKV_Base_Type<bool,         GGUF_TYPE_BOOL,    gguf_get_val_bool> {};
-    template<> struct GKV_Base<uint8_t     >: GKV_Base_Type<uint8_t,      GGUF_TYPE_UINT8,   gguf_get_val_u8  > {};
-    template<> struct GKV_Base<uint16_t    >: GKV_Base_Type<uint16_t,     GGUF_TYPE_UINT16,  gguf_get_val_u16 > {};
-    template<> struct GKV_Base<uint32_t    >: GKV_Base_Type<uint32_t,     GGUF_TYPE_UINT32,  gguf_get_val_u32 > {};
-    template<> struct GKV_Base<uint64_t    >: GKV_Base_Type<uint64_t,     GGUF_TYPE_UINT64,  gguf_get_val_u64 > {};
-    template<> struct GKV_Base<int8_t      >: GKV_Base_Type<int8_t,       GGUF_TYPE_INT8,    gguf_get_val_i8  > {};
-    template<> struct GKV_Base<int16_t     >: GKV_Base_Type<int16_t,      GGUF_TYPE_INT16,   gguf_get_val_i16 > {};
-    template<> struct GKV_Base<int32_t     >: GKV_Base_Type<int32_t,      GGUF_TYPE_INT32,   gguf_get_val_i32 > {};
-    template<> struct GKV_Base<int64_t     >: GKV_Base_Type<int64_t,      GGUF_TYPE_INT64,   gguf_get_val_i64 > {};
-    template<> struct GKV_Base<float       >: GKV_Base_Type<float,        GGUF_TYPE_FLOAT32, gguf_get_val_f32 > {};
-    template<> struct GKV_Base<double      >: GKV_Base_Type<double,       GGUF_TYPE_FLOAT64, gguf_get_val_f64 > {};
-    template<> struct GKV_Base<const char *>: GKV_Base_Type<const char *, GGUF_TYPE_STRING,  gguf_get_val_str > {};
-
-    template<> struct GKV_Base<std::string> {
-        static constexpr gguf_type gt = GGUF_TYPE_STRING;
-
-        static std::string getter(const gguf_context * ctx, const int kid) {
-            return gguf_get_val_str(ctx, kid);
-        }
-    };
-
-    struct ArrayInfo {
-        const gguf_type gt;
-        const size_t length;
-        const void * data;
-    };
-
-    template<> struct GKV_Base<ArrayInfo> {
-        public:
-        static constexpr gguf_type gt = GGUF_TYPE_ARRAY;
-        static ArrayInfo getter(const gguf_context *ctx, const int k) {
-            return ArrayInfo {
-                gguf_get_arr_type(ctx, k),
-                size_t(gguf_get_arr_n(ctx, k)),
-                gguf_get_arr_data(ctx, k),
-            };
-        }
-    };
-
-    template<typename T>
-    class GKV : public GKV_Base<T> {
-        GKV() = delete;
-
-        public:
-        static T get_kv(const gguf_context * ctx, const int k) {
-            const enum gguf_type kt = gguf_get_kv_type(ctx, k);
-
-            if (kt != GKV::gt) {
-                throw std::runtime_error(format("key %s has wrong type %s but expected type %s",
-                    gguf_get_key(ctx, k), gguf_type_name(kt), gguf_type_name(GKV::gt)));
-            }
-            return GKV::getter(ctx, k);
-        }
-
-        static const char * override_type_to_str(const llama_model_kv_override_type ty) {
-            switch (ty) {
-                case LLAMA_KV_OVERRIDE_TYPE_BOOL:  return "bool";
-                case LLAMA_KV_OVERRIDE_TYPE_INT:   return "int";
-                case LLAMA_KV_OVERRIDE_TYPE_FLOAT: return "float";
-                case LLAMA_KV_OVERRIDE_TYPE_STR:   return "str";
-            }
-            return "unknown";
-        }
-
-        static bool validate_override(const llama_model_kv_override_type expected_type, const struct llama_model_kv_override * ovrd) {
-            if (!ovrd) { return false; }
-            if (ovrd->tag == expected_type) {
-                LLAMA_LOG_INFO("%s: Using metadata override (%5s) '%s' = ",
-                    __func__, override_type_to_str(ovrd->tag), ovrd->key);
-                switch (ovrd->tag) {
-                    case LLAMA_KV_OVERRIDE_TYPE_BOOL:  {
-                        LLAMA_LOG_INFO("%s\n", ovrd->val_bool ? "true" : "false");
-                    } break;
-                    case LLAMA_KV_OVERRIDE_TYPE_INT:   {
-                        LLAMA_LOG_INFO("%" PRId64 "\n", ovrd->val_i64);
-                    } break;
-                    case LLAMA_KV_OVERRIDE_TYPE_FLOAT: {
-                        LLAMA_LOG_INFO("%.6f\n", ovrd->val_f64);
-                    } break;
-                    case LLAMA_KV_OVERRIDE_TYPE_STR: {
-                        LLAMA_LOG_INFO("%s\n", ovrd->val_str);
-                    } break;
-                    default:
-                        // Shouldn't be possible to end up here, but just in case...
-                        throw std::runtime_error(
-                            format("Unsupported attempt to override %s type for metadata key %s\n",
-                                override_type_to_str(ovrd->tag), ovrd->key));
-                }
-                return true;
-            }
-            LLAMA_LOG_WARN("%s: Warning: Bad metadata override type for key '%s', expected %s but got %s\n",
-                __func__, ovrd->key, override_type_to_str(expected_type), override_type_to_str(ovrd->tag));
-            return false;
-        }
-
-        template<typename OT>
-        static typename std::enable_if<std::is_same<OT, bool>::value, bool>::type
-        try_override(OT & target, const struct llama_model_kv_override * ovrd) {
-            if (validate_override(LLAMA_KV_OVERRIDE_TYPE_BOOL, ovrd)) {
-                target = ovrd->val_bool;
-                return true;
-            }
-            return false;
-        }
-
-        template<typename OT>
-        static typename std::enable_if<!std::is_same<OT, bool>::value && std::is_integral<OT>::value, bool>::type
-        try_override(OT & target, const struct llama_model_kv_override * ovrd) {
-            if (validate_override(LLAMA_KV_OVERRIDE_TYPE_INT, ovrd)) {
-                target = ovrd->val_i64;
-                return true;
-            }
-            return false;
-        }
-
-        template<typename OT>
-        static typename std::enable_if<std::is_floating_point<OT>::value, bool>::type
-        try_override(T & target, const struct llama_model_kv_override * ovrd) {
-            if (validate_override(LLAMA_KV_OVERRIDE_TYPE_FLOAT, ovrd)) {
-                target = ovrd->val_f64;
-                return true;
-            }
-            return false;
-        }
-
-        template<typename OT>
-        static typename std::enable_if<std::is_same<OT, std::string>::value, bool>::type
-        try_override(T & target, const struct llama_model_kv_override * ovrd) {
-            if (validate_override(LLAMA_KV_OVERRIDE_TYPE_STR, ovrd)) {
-                target = ovrd->val_str;
-                return true;
-            }
-            return false;
-        }
-
-        static bool set(const gguf_context * ctx, const int k, T & target, const struct llama_model_kv_override * ovrd = nullptr) {
-            if (try_override(target, ovrd)) {
-                return true;
-            }
-            if (k < 0) { return false; }
-            target = get_kv(ctx, k);
-            return true;
-        }
-
-        static bool set(const gguf_context * ctx, const char * key, T & target, const struct llama_model_kv_override * ovrd = nullptr) {
-            return set(ctx, gguf_find_key(ctx, key), target, ovrd);
-        }
-
-        static bool set(const gguf_context * ctx, const std::string & key, T & target, const struct llama_model_kv_override * ovrd = nullptr) {
-            return set(ctx, key.c_str(), target, ovrd);
-        }
-    };
-}
-
-using llama_buf_map = std::unordered_map<uint32_t, ggml_backend_buffer_t>;
-
-// TODO: update when needed or think of some clever automatic way to do this
-static size_t llama_model_max_nodes(const llama_model & /*model*/) {
-    //if (model.arch == LLM_ARCH_LLAMA && model.hparams.n_layer > ??) { // llama-3 405B
-    //    return 32768;
-    //}
-
-    return 65536;
-}
-
-struct llama_model_loader {
-    int n_kv      = 0;
-    int n_tensors = 0;
-    int n_created = 0;
-
-    int64_t n_elements = 0;
-    size_t  n_bytes    = 0;
-
-    bool use_mmap = false;
-    bool check_tensors;
-    bool repack_tensors = false;
-    bool use_thp = false;
-
-    llama_files files;
-    llama_ftype ftype;
-    llama_fver  fver;
-
-    llama_mmaps mappings;
-
-    // Holds information on a model weight
-    struct llama_tensor_weight {
-        uint16_t  idx; // source file index
-        size_t   offs; // tensor data offset in the original file
-
-        ggml_tensor * tensor;
-
-        llama_tensor_weight(const llama_file * file, uint16_t idx, const char * name, const struct gguf_context * gguf_ctx, ggml_tensor * tensor) : idx(idx), tensor(tensor) {
-            const int tensor_idx = gguf_find_tensor(gguf_ctx, name);
-            offs = gguf_get_data_offset(gguf_ctx) + gguf_get_tensor_offset(gguf_ctx, tensor_idx);
-
-            if (offs + ggml_nbytes(tensor) < offs || offs + ggml_nbytes(tensor) > file->size) {
-                throw std::runtime_error(format("tensor '%s' data is not within the file bounds, model is corrupted or incomplete", name));
-            }
-        }
-    };
-    std::vector<llama_tensor_weight> weights;
-
-    std::unordered_map<std::string, llama_model_kv_override> kv_overrides;
-    const llama_model_tensor_buft_override * tensor_buft_overrides;
-
-    struct gguf_context * meta = NULL;
-    std::vector<ggml_context *> contexts;
-
-    std::string arch_name;
-    LLM_KV      llm_kv    = LLM_KV(LLM_ARCH_UNKNOWN);
-
-    llama_model_loader(const std::string & fname, bool use_mmap, bool check_tensors, bool repack_tensors, bool use_thp,
-            const llama_model_kv_override * param_overrides_p,
-            const llama_model_tensor_buft_override * param_tensor_buft_overrides_p) {
-        int trace = 0;
-        if (getenv("LLAMA_TRACE")) {
-            trace = atoi(getenv("LLAMA_TRACE"));
-        }
-
-        #ifdef _WIN32
-            // Only bump maxstdio if the user really wants large contexts:
-            #if defined(GGML_MAX_CONTEXTS) && (GGML_MAX_CONTEXTS > 512)
-                // Cap at MSVC's hard limit of 8192 - https://learn.microsoft.com/en-us/cpp/c-runtime-library/reference/setmaxstdio?view=msvc-160
-                #if (GGML_MAX_CONTEXTS > 8192)
-                    #define _GGML_STDIO_TARGET 8192
-                #else
-                    #define _GGML_STDIO_TARGET GGML_MAX_CONTEXTS
-                #endif
-                int _setmaxstdio_ret = _setmaxstdio(_GGML_STDIO_TARGET);
-                if (_setmaxstdio_ret == -1) {
-                    LLAMA_LOG_INFO("%s: failed to set max stdio to %d. (setmaxstdio returned -1)\n", __func__, _GGML_STDIO_TARGET);
-                } else {
-                    LLAMA_LOG_INFO("%s: max stdio successfully set to %d\n", __func__, _setmaxstdio_ret);
-                }
-            #endif // GGML_MAX_CONTEXTS > 512
-        #endif // _WIN32
-
-        if (param_overrides_p != nullptr) {
-            for (const struct llama_model_kv_override * p = param_overrides_p; p->key[0] != 0; p++) {
-                kv_overrides.insert({std::string(p->key), *p});
-            }
-        }
-
-        tensor_buft_overrides = param_tensor_buft_overrides_p;
-
-        struct ggml_context * ctx = NULL;
-        struct gguf_init_params params = {
-            /*.no_alloc = */ true,
-            /*.ctx      = */ &ctx,
-        };
-
-        meta = gguf_init_from_file(fname.c_str(), params);
-        if (!meta) {
-            throw std::runtime_error(format("%s: failed to load model from %s\n", __func__, fname.c_str()));
-        }
-
-        get_key(llm_kv(LLM_KV_GENERAL_ARCHITECTURE), arch_name, false);
-        llm_kv = LLM_KV(llm_arch_from_string(arch_name));
-
-        files.emplace_back(new llama_file(fname.c_str(), "rb"));
-        contexts.emplace_back(ctx);
-
-        // Save tensors data offset of the main file.
-        // For subsidiary files, `meta` tensor data offset must not be used,
-        // so we build a unified tensors index for weights.
-        for (ggml_tensor * cur = ggml_get_first_tensor(ctx); cur; cur = ggml_get_next_tensor(ctx, cur)) {
-            weights.emplace_back(files.back().get(), 0, cur->name, meta, cur);
-        }
-        uint16_t n_split = 0;
-        get_key(llm_kv(LLM_KV_SPLIT_COUNT), n_split, false);
-
-        // Load additional GGML contexts
-        if (n_split > 1) {
-            uint16_t idx = 0;
-            get_key(llm_kv(LLM_KV_SPLIT_NO), idx);
-            if (idx != 0) {
-                throw std::runtime_error(format("illegal split file: %d, model must be loaded with the first split", idx));
-            }
-
-            char split_prefix[PATH_MAX] = {0};
-            if (!llama_split_prefix(split_prefix, sizeof(split_prefix), fname.c_str(), idx, n_split)) {
-                throw std::runtime_error(format("invalid split file: %s", fname.c_str()));
-            }
-
-            if (trace > 0) {
-                LLAMA_LOG_INFO("%s: loading additional %d GGUFs\n", __func__, n_split);
-            }
-
-            char split_path[PATH_MAX] = {0};
-            for (idx = 1; idx < n_split; idx++) {
-                llama_split_path(split_path, sizeof(split_path), split_prefix, idx, n_split);
-
-                struct gguf_init_params split_params = {
-                    /*.no_alloc = */ true,
-                    /*.ctx      = */ &ctx,
-                };
-                struct gguf_context * ctx_gguf = gguf_init_from_file(split_path, split_params);
-                if (!ctx_gguf) {
-                    throw std::runtime_error(format("%s: failed to load GGUF split from %s\n", __func__, split_path));
-                }
-
-                files.emplace_back(new llama_file(split_path, "rb"));
-                contexts.emplace_back(ctx);
-
-                // Save tensors data offset info of the shard.
-                for (ggml_tensor * cur = ggml_get_first_tensor(ctx); cur; cur = ggml_get_next_tensor(ctx, cur)) {
-                    weights.emplace_back(files.back().get(), idx, cur->name, ctx_gguf, cur);
-                }
-
-                gguf_free(ctx_gguf);
-            }
-
-            get_key(llm_kv(LLM_KV_SPLIT_TENSORS_COUNT), n_tensors);
-
-            // sanity check
-            {
-                const int n_tensors_loaded = (int) weights.size();
-                if (n_tensors != n_tensors_loaded) {
-                    throw std::runtime_error(format("corrupted model: %d tensors expected but %d found", n_tensors, n_tensors_loaded));
-                }
-            }
-
-            LLAMA_LOG_INFO("%s: additional %d GGUFs metadata loaded.\n",  __func__, n_split - 1);
-        }
-
-        n_kv      = gguf_get_n_kv(meta);
-        n_tensors = weights.size();
-
-        fver = (enum llama_fver) gguf_get_version(meta);
-
-        std::set<std::string> tensor_names;
-        for (auto & w : weights) {
-            n_elements += ggml_nelements(w.tensor);
-            n_bytes    += ggml_nbytes(w.tensor);
-            // make sure there are no duplicated tensor names
-            const std::string name(w.tensor->name);
-            auto found = tensor_names.find(name);
-            if (found != tensor_names.end()) {
-                throw std::runtime_error(format("invalid model: tensor '%s' is duplicated", w.tensor->name));
-            }
-            tensor_names.insert(name);
-        }
-
-        LLAMA_LOG_INFO("%s: loaded meta data with %d key-value pairs and %d tensors from %s (version %s)\n",
-                __func__, n_kv, n_tensors, fname.c_str(), llama_file_version_name(fver));
-
-        // determine file type based on the number of tensors for each quantization and print meta data
-        // TODO: make optional
-        {
-            std::map<enum ggml_type, uint32_t> n_type;
-
-            uint32_t n_type_max = 0;
-            enum ggml_type type_max = GGML_TYPE_F32;
-
-            for (int i = 0; i < n_tensors; i++) {
-                const ggml_tensor * tensor = weights.at(i).tensor;
-                enum ggml_type type = tensor->type;
-
-                n_type[type]++;
-
-                if (n_type_max < n_type[type]) {
-                    n_type_max = n_type[type];
-                    type_max   = type;
-                }
-
-                if (trace > 0) {
-                    const uint16_t sid = weights.at(i).idx;
-                    LLAMA_LOG_INFO("%s: - tensor %4d, split %2d: %32s %-8s [ %s ]\n", __func__, i, sid, ggml_get_name(tensor), ggml_type_name(type), llama_format_tensor_shape(tensor).c_str());
-                }
-            }
-
-            switch (type_max) {
-                case GGML_TYPE_F32:     ftype = LLAMA_FTYPE_ALL_F32;        break;
-                case GGML_TYPE_F16:     ftype = LLAMA_FTYPE_MOSTLY_F16;     break;
-                case GGML_TYPE_BF16:    ftype = LLAMA_FTYPE_MOSTLY_BF16;    break;
-                case GGML_TYPE_BF16_R16:ftype = LLAMA_FTYPE_MOSTLY_BF16_R16;break;
-                case GGML_TYPE_Q4_0:    ftype = LLAMA_FTYPE_MOSTLY_Q4_0;    break;
-                case GGML_TYPE_Q4_1:    ftype = LLAMA_FTYPE_MOSTLY_Q4_1;    break;
-                case GGML_TYPE_Q5_0:    ftype = LLAMA_FTYPE_MOSTLY_Q5_0;    break;
-                case GGML_TYPE_Q5_1:    ftype = LLAMA_FTYPE_MOSTLY_Q5_1;    break;
-                case GGML_TYPE_Q6_0:    ftype = LLAMA_FTYPE_MOSTLY_Q6_0;    break;
-                case GGML_TYPE_Q8_0:    ftype = LLAMA_FTYPE_MOSTLY_Q8_0;    break;
-                case GGML_TYPE_Q8_KV:   ftype = LLAMA_FTYPE_MOSTLY_Q8_KV;   break;
-                case GGML_TYPE_Q2_K:    ftype = LLAMA_FTYPE_MOSTLY_Q2_K;    break;
-                case GGML_TYPE_Q3_K:    ftype = LLAMA_FTYPE_MOSTLY_Q3_K_M;  break;
-                case GGML_TYPE_Q3_K_R4: ftype = LLAMA_FTYPE_MOSTLY_Q3_K_R4; break;
-                case GGML_TYPE_Q4_K:    ftype = LLAMA_FTYPE_MOSTLY_Q4_K_M;  break;
-                case GGML_TYPE_Q4_K_R4: ftype = LLAMA_FTYPE_MOSTLY_Q4_K_R4; break;
-                case GGML_TYPE_Q5_K:    ftype = LLAMA_FTYPE_MOSTLY_Q5_K_M;  break;
-                case GGML_TYPE_Q5_K_R4: ftype = LLAMA_FTYPE_MOSTLY_Q5_K_R4; break;
-                case GGML_TYPE_Q6_K:    ftype = LLAMA_FTYPE_MOSTLY_Q6_K;    break;
-                case GGML_TYPE_Q6_K_R4: ftype = LLAMA_FTYPE_MOSTLY_Q6_K_R4; break;
-                case GGML_TYPE_Q8_K_R8: ftype = LLAMA_FTYPE_MOSTLY_Q8_K_R8; break;
-                case GGML_TYPE_Q8_KV_R8: ftype = LLAMA_FTYPE_MOSTLY_Q8_KV_R8; break;
-                case GGML_TYPE_IQ2_XXS: ftype = LLAMA_FTYPE_MOSTLY_IQ2_XXS; break;
-                case GGML_TYPE_IQ2_XXS_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ2_XXS_R4; break;
-                case GGML_TYPE_IQ2_XS:  ftype = LLAMA_FTYPE_MOSTLY_IQ2_XS;  break;
-                case GGML_TYPE_IQ2_XS_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ2_XS_R4; break;
-                case GGML_TYPE_IQ2_KS:  ftype = LLAMA_FTYPE_MOSTLY_IQ2_KS;  break;
-                case GGML_TYPE_IQ2_S:   ftype = LLAMA_FTYPE_MOSTLY_IQ2_M;   break;
-                case GGML_TYPE_IQ2_S_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ2_M_R4;break;
-                case GGML_TYPE_IQ3_XXS: ftype = LLAMA_FTYPE_MOSTLY_IQ3_XXS; break;
-                case GGML_TYPE_IQ3_XXS_R4: ftype = LLAMA_FTYPE_MOSTLY_IQ3_XXS_R4; break;
-                case GGML_TYPE_IQ1_KT:  ftype = LLAMA_FTYPE_MOSTLY_IQ1_KT;  break;
-                case GGML_TYPE_IQ2_KT:  ftype = LLAMA_FTYPE_MOSTLY_IQ2_KT;  break;
-                case GGML_TYPE_IQ3_KT:  ftype = LLAMA_FTYPE_MOSTLY_IQ3_KT;  break;
-                case GGML_TYPE_IQ4_KT:  ftype = LLAMA_FTYPE_MOSTLY_IQ4_KT;  break;
-                case GGML_TYPE_IQ1_S:   ftype = LLAMA_FTYPE_MOSTLY_IQ1_S;   break;
-                case GGML_TYPE_IQ1_S_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ1_S_R4;break;
-                case GGML_TYPE_IQ1_M_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ1_M_R4;break;
-                case GGML_TYPE_IQ1_M:   ftype = LLAMA_FTYPE_MOSTLY_IQ1_M;   break;
-                case GGML_TYPE_IQ1_BN:  ftype = LLAMA_FTYPE_MOSTLY_IQ1_BN;  break;
-                case GGML_TYPE_IQ2_BN:  ftype = LLAMA_FTYPE_MOSTLY_IQ2_BN;  break;
-                case GGML_TYPE_IQ2_BN_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ2_BN_R4;break;
-                case GGML_TYPE_IQ4_NL:  ftype = LLAMA_FTYPE_MOSTLY_IQ4_NL;  break;
-                case GGML_TYPE_IQ4_NL_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ4_NL_R4;break;
-                case GGML_TYPE_IQ4_XS_R8:ftype = LLAMA_FTYPE_MOSTLY_IQ4_XS_R8;break;
-                case GGML_TYPE_Q4_0_R8: ftype = LLAMA_FTYPE_MOSTLY_Q4_0_R8; break;
-                case GGML_TYPE_Q5_0_R4: ftype = LLAMA_FTYPE_MOSTLY_Q5_0_R4; break;
-                case GGML_TYPE_Q6_0_R4: ftype = LLAMA_FTYPE_MOSTLY_Q6_0_R4; break;
-                case GGML_TYPE_Q8_0_R8: ftype = LLAMA_FTYPE_MOSTLY_Q8_0_R8; break;
-                case GGML_TYPE_MXFP4:   ftype = LLAMA_FTYPE_MOSTLY_MXFP4;   break;
-                case GGML_TYPE_IQ4_XS:  ftype = LLAMA_FTYPE_MOSTLY_IQ4_XS;  break;
-                case GGML_TYPE_IQ4_KS:  ftype = LLAMA_FTYPE_MOSTLY_IQ4_KS;  break;
-                case GGML_TYPE_IQ4_KS_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ4_KS_R4;  break;
-                case GGML_TYPE_IQ5_KS_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ5_KS_R4;  break;
-                case GGML_TYPE_IQ4_KSS: ftype = LLAMA_FTYPE_MOSTLY_IQ4_KSS; break;
-                case GGML_TYPE_IQ5_KS:  ftype = LLAMA_FTYPE_MOSTLY_IQ5_KS;  break;
-                case GGML_TYPE_IQ2_K:   ftype = LLAMA_FTYPE_MOSTLY_IQ2_K;   break;
-                case GGML_TYPE_IQ2_K_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ2_K_R4;break;
-                case GGML_TYPE_IQ3_KS:  ftype = LLAMA_FTYPE_MOSTLY_IQ3_KS;  break;
-                case GGML_TYPE_IQ2_KL:  ftype = LLAMA_FTYPE_MOSTLY_IQ2_KL;  break;
-                case GGML_TYPE_IQ3_K:   ftype = LLAMA_FTYPE_MOSTLY_IQ3_K;   break;
-                case GGML_TYPE_IQ3_K_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ3_K_R4;break;
-                case GGML_TYPE_IQ4_K:   ftype = LLAMA_FTYPE_MOSTLY_IQ4_K;   break;
-                case GGML_TYPE_IQ4_K_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ4_K_R4;break;
-                case GGML_TYPE_IQ5_K:   ftype = LLAMA_FTYPE_MOSTLY_IQ5_K;   break;
-                case GGML_TYPE_IQ5_K_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ5_K_R4;break;
-                case GGML_TYPE_IQ6_K:   ftype = LLAMA_FTYPE_MOSTLY_IQ6_K;   break;
-                case GGML_TYPE_IQ3_S:   ftype = LLAMA_FTYPE_MOSTLY_IQ3_S;   break;
-                case GGML_TYPE_IQ3_S_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ3_S_R4;break;
-                case GGML_TYPE_Q4_0_4_4: ftype = LLAMA_FTYPE_MOSTLY_Q4_0_4_4; break;
-                case GGML_TYPE_Q4_0_4_8: ftype = LLAMA_FTYPE_MOSTLY_Q4_0_4_8; break;
-                case GGML_TYPE_Q4_0_8_8: ftype = LLAMA_FTYPE_MOSTLY_Q4_0_8_8; break;
-                default:
-                    {
-                        LLAMA_LOG_WARN("%s: unknown type %s\n", __func__, ggml_type_name(type_max));
-                        ftype = LLAMA_FTYPE_ALL_F32;
-                    } break;
-            }
-
-            // this is a way to mark that we have "guessed" the file type
-            ftype = (llama_ftype) (ftype | LLAMA_FTYPE_GUESSED);
-
-            {
-                const int kid = gguf_find_key(meta, "general.file_type"); // TODO: use LLM_KV
-                if (kid >= 0) {
-                    ftype = (llama_ftype) gguf_get_val_u32(meta, kid);
-                }
-            }
-
-            LLAMA_LOG_INFO("%s: Dumping metadata keys/values. Note: KV overrides do not apply in this output.\n", __func__);
-
-            for (int i = 0; i < n_kv; i++) {
-                const char * name           = gguf_get_key(meta, i);
-                const enum gguf_type type   = gguf_get_kv_type(meta, i);
-                const std::string type_name =
-                    type == GGUF_TYPE_ARRAY
-                    ? format("%s[%s,%d]", gguf_type_name(type), gguf_type_name(gguf_get_arr_type(meta, i)), gguf_get_arr_n(meta, i))
-                    : gguf_type_name(type);
-
-                std::string value          = gguf_kv_to_str(meta, i);
-                const size_t MAX_VALUE_LEN = 40;
-                if (value.size() > MAX_VALUE_LEN) {
-                    value = format("%s...", value.substr(0, MAX_VALUE_LEN - 3).c_str());
-                }
-                replace_all(value, "\n", "\\n");
-
-                LLAMA_LOG_INFO("%s: - kv %3d: %42s %-16s = %s\n", __func__, i, name, type_name.c_str(), value.c_str());
-            }
-
-            // print type counts
-            for (auto & kv : n_type) {
-                if (kv.second == 0) {
-                    continue;
-                }
-
-                LLAMA_LOG_INFO("%s: - type %4s: %4d tensors\n", __func__, ggml_type_name(kv.first), kv.second);
-            }
-        }
-
-        if (!llama_mmap::SUPPORTED) {
-            LLAMA_LOG_WARN("%s: mmap is not supported on this platform\n", __func__);
-            use_mmap = false;
-        }
-        if (repack_tensors) {
-            use_mmap = false;
-        }
-
-        this->use_mmap = use_mmap;
-        this->check_tensors = check_tensors;
-        this->repack_tensors = repack_tensors;
-        this->use_thp = use_thp;
-    }
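The split-loading path in the constructor above relies on `llama_split_prefix()` and `llama_split_path()` (defined elsewhere in the file) to map between the first shard's path and its siblings. To the best of my reading, the naming convention is `<prefix>-%05d-of-%05d.gguf` with a 1-based shard number in the file name; the helper below is an illustrative stand-in, not the real API:

```cpp
// Sketch of the GGUF shard naming assumed by the split-loading code above:
// "<prefix>-%05d-of-%05d.gguf", with the shard index printed 1-based.
// make_split_path() is an illustrative stand-in, not the real helper.
#include <cstdio>
#include <string>

static std::string make_split_path(const std::string & prefix, int idx, int n_split) {
    char buf[512];
    std::snprintf(buf, sizeof(buf), "%s-%05d-of-%05d.gguf", prefix.c_str(), idx + 1, n_split);
    return buf;
}

int main() {
    // Shard 1 (0-based) of a 3-way split:
    std::printf("%s\n", make_split_path("model", 1, 3).c_str());
    // prints: model-00002-of-00003.gguf
}
```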
-
-    ~llama_model_loader() {
-        if (meta) {
-            gguf_free(meta);
-        }
-        for (auto * ctx : contexts) {
-            ggml_free(ctx);
-        }
-    }
-
-    template<typename T>
-    typename std::enable_if<std::is_integral<T>::value, bool>::type
-    get_arr_n(const std::string & key, T & result, const bool required = true) {
-        const int kid = gguf_find_key(meta, key.c_str());
-
-        if (kid < 0) {
-            if (required) {
-                throw std::runtime_error(format("key not found in model: %s", key.c_str()));
-            }
-            return false;
-        }
-
-        struct GGUFMeta::ArrayInfo arr_info =
-            GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(meta, kid);
-
-
-        result = arr_info.length;
-        return true;
-    }
-
-    template<typename T>
-    typename std::enable_if<std::is_integral<T>::value, bool>::type
-    get_arr_n(const enum llm_kv kid, T & result, const bool required = true) {
-        return get_arr_n(llm_kv(kid), result, required);
-    }
-
-    template<typename T>
-    bool get_arr(const std::string & key, std::vector<T> & result, const bool required = true) {
-        const int kid = gguf_find_key(meta, key.c_str());
-
-        if (kid < 0 || gguf_get_kv_type(meta, kid) != GGUF_TYPE_ARRAY) {
-            if (required) {
-                throw std::runtime_error(format("array key not found in model: %s", key.c_str()));
-            }
-            return false;
-        }
-
-        struct GGUFMeta::ArrayInfo arr_info =
-            GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(meta, kid);
-
-        switch (arr_info.gt) {
-            case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T, float>::value)); break;
-            case GGUF_TYPE_INT32:   GGML_ASSERT(
-                                            (std::is_same<T,  int32_t>::value) ||
-                                            (std::is_same<T, uint32_t>::value));  break;
-            default:
-                throw std::runtime_error(format("%s is not a float32, int32 array", key.c_str()));
-        }
-
-        result.resize(arr_info.length);
-        result.assign((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length);
-
-        return true;
-    }
-
-    template<typename T, size_t N_MAX>
-    bool get_arr(const std::string & key, std::array<T, N_MAX> & result, const bool required = true) {
-        const int kid = gguf_find_key(meta, key.c_str());
-
-        if (kid < 0 || gguf_get_kv_type(meta, kid) != GGUF_TYPE_ARRAY) {
-            if (required) {
-                throw std::runtime_error(format("array key not found in model: %s", key.c_str()));
-            }
-            return false;
-        }
-
-        struct GGUFMeta::ArrayInfo arr_info =
-            GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(meta, kid);
-
-        switch (arr_info.gt) {
-            case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T, float>::value)); break;
-            case GGUF_TYPE_INT32:   GGML_ASSERT(
-                                            (std::is_same<T,  int32_t>::value) ||
-                                            (std::is_same<T, uint32_t>::value));  break;
-            default:
-                throw std::runtime_error(format("%s is not a float32, int32 array", key.c_str()));
-        }
-
-        if (arr_info.length > N_MAX) {
-            throw std::runtime_error(format("array length %u for key %s exceeds max %u", (uint32_t) arr_info.length, key.c_str(), (uint32_t) N_MAX));
-        }
-
-        std::copy((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length, result.begin());
-
-        return true;
-    }
-
-    template<typename T>
-    bool get_arr(const enum llm_kv kid, T & result, const bool required = true) {
-        return get_arr(llm_kv(kid), result, required);
-    }
-
-    template<typename T>
-    bool get_key(const std::string & key, T & result, const bool required = true) {
-        auto it = kv_overrides.find(key);
-
-        const struct llama_model_kv_override * override =
-            it != kv_overrides.end() ? &it->second : nullptr;
-
-        const bool found = GGUFMeta::GKV<T>::set(meta, key, result, override);
-
-        if (required && !found) {
-            throw std::runtime_error(format("key not found in model: %s", key.c_str()));
-        }
-
-        return found;
-    }
-
-    template<typename T>
-    bool get_key(const enum llm_kv kid, T & result, const bool required = true) {
-        return get_key(llm_kv(kid), result, required);
-    }
-
-    // get array of n <= N_MAX elements, or a single element repeated n times
-    template<typename T, size_t N_MAX>
-    bool get_key_or_arr(const std::string & key, std::array<T, N_MAX> & result, uint32_t n, const bool required = true) {
-        const int kid = gguf_find_key(meta, key.c_str());
-
-        if (kid < 0) {
-            if (required) {
-                throw std::runtime_error(format("key not found in model: %s", key.c_str()));
-            }
-            return false;
-        }
-
-        if (n > N_MAX) {
-            throw std::runtime_error(format("n > N_MAX: %u > %u for key %s", (uint32_t) n, (uint32_t) N_MAX, key.c_str()));
-        }
-
-        if (gguf_get_kv_type(meta, kid) == GGUF_TYPE_ARRAY) {
-            struct GGUFMeta::ArrayInfo arr_info =
-                GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(meta, kid);
-
-            if (n != arr_info.length) {
-                throw std::runtime_error(format("key %s has wrong array length; expected %u, got %u", key.c_str(), n, (uint32_t) arr_info.length));
-            }
-
-            return get_arr(key, result, required);
-        } else {
-            T value;
-
-            bool ok = get_key(key, value, required);
-            if (!ok) {
-                return false;
-            }
-
-            for (uint32_t i = 0; i < n; i++) {
-                result[i] = value;
-            }
-
-            return true;
-        }
-    }
-
-    template<typename T>
-    bool get_key_or_arr(const enum llm_kv kid, T & result, uint32_t n, const bool required = true) {
-        return get_key_or_arr(llm_kv(kid), result, n, required);
-    }
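Taken together, `get_key()`, `get_arr()`, and `get_key_or_arr()` give hyperparameter loading a uniform shape: resolve a key, honor any override, and either throw (required) or return false (optional). A self-contained toy illustrating the two calling conventions; the loader class and key names here are illustrative only:

```cpp
// Toy model of the required-vs-optional accessor pattern above.
// toy_loader and the key strings are illustrative, not the real API.
#include <cstdint>
#include <cstdio>
#include <stdexcept>
#include <string>

struct toy_loader {
    // required key: throws if missing (mirrors get_key(key, out) above)
    uint32_t require_u32(const std::string & key) {
        if (key != "llama.block_count") {
            throw std::runtime_error("key not found in model: " + key);
        }
        return 32;
    }
    // optional key: returns false instead of throwing (required = false)
    bool try_u32(const std::string & key, uint32_t & out) {
        if (key == "llama.expert_count") { out = 8; return true; }
        return false;
    }
};

int main() {
    toy_loader ml;
    uint32_t n_layer  = ml.require_u32("llama.block_count");        // must exist
    uint32_t n_expert = 0;
    bool has_experts  = ml.try_u32("llama.expert_count", n_expert); // may be absent
    std::printf("n_layer=%u n_expert=%u (found=%d)\n", n_layer, n_expert, has_experts);
}
```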
-
-    std::string get_arch_name() const {
-        return arch_name;
-    }
-
-    enum llm_arch get_arch() const {
-        return llm_kv.arch;
-    }
-
-    const char * get_tensor_name(int i) const {
-        return weights.at(i).tensor->name;
-    }
-
-    const llama_tensor_weight * get_weight(const char * name) const {
-        for (const auto & weight : weights) {
-            if (strcmp(name, weight.tensor->name) == 0) {
-                return &weight;
-            }
-        }
-        return nullptr;
-    }
-
-    const llama_tensor_weight * get_weight(int i) const {
-        return get_weight(get_tensor_name(i));
-    }
-
-    const llama_tensor_weight & require_weight(const char * name) const {
-        const llama_tensor_weight * weight = get_weight(name);
-        if (!weight) {
-            throw std::runtime_error(format("%s: tensor '%s' not found", __func__, name));
-        }
-        return *weight;
-    }
-
-    struct ggml_tensor * get_tensor_meta(const char * name) const {
-        const auto * weight = get_weight(name);
-        if (!weight) {
-            return nullptr;
-        }
-        return weight->tensor;
-    }
-
-    struct ggml_tensor * require_tensor_meta(const char * name) const {
-        struct ggml_tensor * tensor = get_tensor_meta(name);
-        if (!tensor) {
-            throw std::runtime_error(format("%s: tensor '%s' not found", __func__, name));
-        }
-        return tensor;
-    }
-
-    struct ggml_tensor * get_tensor_meta(int i) const {
-        return get_tensor_meta(get_tensor_name(i));
-    }
-
-    struct ggml_tensor * create_tensor_for(struct ggml_context * ctx, const struct ggml_tensor * cur, bool duplicated) {
-        struct ggml_tensor * tensor = ggml_dup_tensor(ctx, cur);
-        ggml_set_name(tensor, ggml_get_name(cur));
-
-        if (duplicated) {
-            size_data += ggml_nbytes(cur);
-        } else {
-            n_created++;
-        }
-
-        return tensor;
-    }
-
-    const struct ggml_tensor * check_tensor_dims(const std::string & name, const std::vector<int64_t> & ne, bool required) const {
-        const struct ggml_tensor * cur = get_tensor_meta(name.c_str());
-
-        if (cur == NULL) {
-            if (!required) {
-                return NULL;
-            }
-            throw std::runtime_error(format("%s: tensor '%s' not found", __func__, name.c_str()));
-        }
-
-        {
-            bool is_ok = true;
-            for (size_t i = 0; i < GGML_MAX_DIMS; ++i) {
-                if ((i < ne.size() && ne[i] != cur->ne[i]) || (i >= ne.size() && cur->ne[i] != 1)) {
-                    is_ok = false;
-                    break;
-                }
-            }
-            if (!is_ok) {
-                throw std::runtime_error(
-                        format("%s: tensor '%s' has wrong shape; expected %s, got %s",
-                            __func__, name.c_str(),
-                            llama_format_tensor_shape(ne).c_str(),
-                            llama_format_tensor_shape(cur).c_str()));
-            }
-        }
-
-        return cur;
-    }
-
-    static const int TENSOR_NOT_REQUIRED = 1 << 0;
-    static const int TENSOR_DUPLICATED   = 1 << 1;
-    static const int TENSOR_SKIP         = 1 << 2;
-
-    struct ggml_tensor * create_tensor(struct ggml_context * ctx, const std::string & name, const std::vector<int64_t> & ne, int flags = 0) {
-        const struct ggml_tensor * cur = check_tensor_dims(name, ne, !(flags & TENSOR_NOT_REQUIRED));
-
-        if (cur == NULL) {
-            return NULL;
-        }
-
-        // skip unused tensors
-        if (flags & TENSOR_SKIP) {
-            const size_t nbytes = ggml_nbytes(cur);
-            LLAMA_LOG_WARN("model has unused tensor %s (size = %zu bytes) -- ignoring\n", name.c_str(), nbytes);
-
-            size_data -= nbytes;
-            n_created++;
-
-            return nullptr;
-        }
-
-        return create_tensor_for(ctx, cur, flags & TENSOR_DUPLICATED);
-    }
-
-    struct ggml_tensor * create_tensor_as_view(struct ggml_context * ctx, struct ggml_tensor * base, const std::string & name, const std::vector<int64_t> & ne, size_t offset, bool required = true) {
-        const struct ggml_tensor * cur = check_tensor_dims(name, ne, required);
-
-        if (cur == NULL) {
-            return NULL;
-        }
-
-        if (cur->type != base->type) {
-            throw std::runtime_error(format("%s: tensor '%s' has wrong type; expected %s, got %s", __func__, name.c_str(), ggml_type_name(base->type), ggml_type_name(cur->type)));
-        }
-
-        std::array<int64_t, GGML_MAX_DIMS> dims;
-        for (size_t i = 0; i < GGML_MAX_DIMS; ++i) {
-            dims[i] = i < ne.size() ? ne[i] : 1;
-        }
-
-        struct ggml_tensor * tensor = ggml_view_4d(ctx, base,
-                                        dims[0], dims[1], dims[2], dims[3],
-                                        cur->nb[1], cur->nb[2], cur->nb[3],
-                                        offset);
-
-        ggml_set_name(tensor, name.c_str());
-
-        n_created++;
-
-        return tensor;
-    }
-
-    void done_getting_tensors() const {
-        if (n_created != n_tensors) {
-            throw std::runtime_error(format("%s: wrong number of tensors; expected %d, got %d", __func__, n_tensors, n_created));
-        }
-    }
-
-    void init_mappings(bool prefetch = true, llama_mlocks * mlock_mmaps = nullptr, bool use_thp = false) {
-        if (use_mmap) {
-            mappings.reserve(files.size());
-            mmaps_used.reserve(files.size());
-            for (const auto & file : files) {
-                std::unique_ptr<llama_mmap> mapping(new llama_mmap(file.get(), prefetch ? -1 : 0, ggml_is_numa(), use_thp));
-                mmaps_used.emplace_back(mapping->size, 0);
-                if (mlock_mmaps) {
-                    std::unique_ptr<llama_mlock> mlock_mmap(new llama_mlock());
-                    mlock_mmap->init(mapping->addr);
-                    mlock_mmaps->emplace_back(std::move(mlock_mmap));
-                }
-                mappings.emplace_back(std::move(mapping));
-            }
-        }
-
-        // compute the total size of all tensors for progress reporting
-        for (auto & w : weights) {
-            size_data += ggml_nbytes(w.tensor);
-        }
-    }
-
-    void get_mapping_range(size_t * first, size_t * last, void ** addr, int idx, ggml_context * ctx) const {
-        GGML_ASSERT(!mappings.empty());
-        const auto & mapping = mappings.at(idx);
-
-        *first = mapping->size;
-        *last  = 0;
-        *addr = mapping->addr;
-        for (ggml_tensor * tensor = ggml_get_first_tensor(ctx); tensor; tensor = ggml_get_next_tensor(ctx, tensor)) {
-            try {
-                const auto * weight = get_weight(ggml_get_name(tensor));
-                if (!weight) {
-                    continue;
-                }
-                if (weight->idx != idx) {
-                    continue;
-                }
-                *first = std::min(*first, weight->offs);
-                *last  = std::max(*last,  weight->offs + ggml_nbytes(tensor));
-            } catch(...) {
-                // the tensor is not in the model
-            }
-        }
-    }
-
-    // for backwards compatibility, does not support ggml-backend
-    void load_data_for(struct ggml_tensor * cur) const {
-        const auto & w = require_weight(ggml_get_name(cur));
-
-        if (use_mmap) {
-            const auto & mapping = mappings.at(w.idx);
-            if (cur->data == nullptr) {
-                cur->data = (uint8_t *)mapping->addr + w.offs;
-            } else {
-                memcpy(cur->data, (uint8_t *)mapping->addr + w.offs, ggml_nbytes(cur));
-            }
-        } else {
-            GGML_ASSERT(cur->data != nullptr);
-            GGML_ASSERT(w.idx < files.size());
-            const auto & file = files.at(w.idx);
-            file->seek(w.offs, SEEK_SET);
-            file->read_raw(cur->data, ggml_nbytes(cur));
-        }
-
-        if (check_tensors && !ggml_validate_row_data(cur->type, cur->data, ggml_nbytes(cur))) {
-            throw std::runtime_error(format("tensor '%s' has invalid data", ggml_get_name(cur)));
-        }
-    }
-
-    size_t size_done = 0;
-    size_t size_data = 0;
-    std::vector<std::pair<size_t, size_t>> mmaps_used;
-
-    // Returns false if cancelled by progress_callback
-    bool load_all_data(
-            struct ggml_context * ctx,
-            llama_buf_map & bufs_mmap,
-            llama_mlocks * lmlocks,
-            llama_progress_callback progress_callback,
-            void * progress_callback_user_data) {
-        GGML_ASSERT(size_data != 0 && "call init_mappings() first");
-
-        std::vector<no_init<uint8_t>> read_buf;
-        std::vector<std::future<std::pair<ggml_tensor *, bool>>> validation_result;
-
-#if defined(GGML_USE_CUDA)
-        // 4 staging buffers for async uploads; 1 MB each seems to be a good default for single NVMe drives.
-        // NVMe raid configurations might require more / larger buffers.
-        constexpr size_t n_buffers = 4;
-        constexpr size_t buffer_size = 1 * 1024 * 1024; // 1MB
-
-        std::vector<ggml_backend_buffer_t> host_buffers;
-        std::vector<void *> host_ptrs;
-        std::vector<ggml_backend_event_t> events;
-        size_t buffer_idx = 0; // buffer to use for async loads
-
-        ggml_backend_t cuda_backend = nullptr;
-        if (!use_mmap && !check_tensors) {
-            // When not using mmapped I/O, use async uploads from pinned memory to GPU memory.
-            // First determine if the CUDA backend is active, and if so, determine the device ID.
-            ggml_backend_buffer_t buf = bufs_mmap.count(0) ? bufs_mmap.at(0) : nullptr;
-            if (buf) {
-                ggml_backend_buffer_type_t buffer_type = ggml_backend_buffer_get_type(buf);
-                for (int i = 0; i < ggml_backend_cuda_get_device_count(); ++i) {
-                    auto * cuda_buffer_type = ggml_backend_cuda_buffer_type(i);
-                    if (buffer_type == cuda_buffer_type) {
-                        cuda_backend = ggml_backend_cuda_init(i);
-                        break;
-                    }
-                }
-            }
-
-            // If the CUDA backend is active, create pinned memory buffers and events for synchronisation.
-            if (cuda_backend) {
-                for (size_t idx = 0; idx < n_buffers; ++idx) {
-                    host_buffers.emplace_back(ggml_backend_buft_alloc_buffer(llama_default_buffer_type_cpu(true), buffer_size));
-                    host_ptrs.emplace_back(ggml_backend_buffer_get_base(host_buffers[idx]));
-                    events.emplace_back(ggml_backend_event_new(cuda_backend));
-                }
-            }
-        }
-#endif
-
-        for (struct ggml_tensor * cur = ggml_get_first_tensor(ctx); cur != NULL; cur = ggml_get_next_tensor(ctx, cur)) {
-            const auto * weight = get_weight(ggml_get_name(cur));
-            if (weight == nullptr) {
-                // this can happen with split experts models
-                continue;
-            }
-
-            if (progress_callback) {
-                if (!progress_callback((float) size_done / size_data, progress_callback_user_data)) {
-                    return false;
-                }
-            }
-
-            size_t n_size = ggml_nbytes(cur);
-
-            if (use_mmap) {
-                const auto & mapping = mappings.at(weight->idx);
-                ggml_backend_buffer_t buf_mmap = nullptr;
-                if (bufs_mmap.count(weight->idx)) {
-                    buf_mmap = bufs_mmap.at(weight->idx);
-                }
-                uint8_t * data = (uint8_t *) mapping->addr + weight->offs;
-
-                if (check_tensors) {
-                    validation_result.emplace_back(std::async(std::launch::async, [cur, data, n_size] {
-                        return std::make_pair(cur, ggml_validate_row_data(cur->type, data, n_size));
-                    }));
-                }
-
-                GGML_ASSERT(buf_mmap || cur->data); // either we have a buffer to allocate the tensor in, or it is already allocated
-                if (buf_mmap && cur->data == nullptr) {
-                    ggml_backend_tensor_alloc(buf_mmap, cur, data);
-                    if (lmlocks) {
-                        const auto & lmlock = lmlocks->at(weight->idx);
-                        lmlock->grow_to(weight->offs + n_size);
-                    }
-
-                    auto & mmap_used = mmaps_used[weight->idx];
-                    mmap_used.first  = std::min(mmap_used.first,  weight->offs);
-                    mmap_used.second = std::max(mmap_used.second, weight->offs + n_size);
-                } else {
-                    ggml_backend_tensor_set(cur, data, 0, n_size);
-                }
-            } else {
-                GGML_ASSERT(weight->idx < files.size());
-                const auto & file = files.at(weight->idx);
-                if (ggml_backend_buffer_is_host(cur->buffer)) {
-                    file->seek(weight->offs, SEEK_SET);
-                    file->read_raw(cur->data, n_size);
-                    if (check_tensors) {
-                        validation_result.emplace_back(std::async(std::launch::async, [cur, n_size] {
-                            return std::make_pair(cur, ggml_validate_row_data(cur->type, cur->data, n_size));
-                        }));
-                    }
-                } else {
-#if defined(GGML_USE_CUDA)
-                    // If cuda_backend is valid load the tensor in chunks to pinned memory and upload the buffers asynchronously to the GPU.
-                    if (cuda_backend) {
-                        file->seek(weight->offs, SEEK_SET);
-
-                        size_t bytes_read = 0;
-
-                        while (bytes_read < n_size) {
-                            size_t read_iteration = std::min(buffer_size, n_size - bytes_read);
-
-                            ggml_backend_event_synchronize(events[buffer_idx]);
-                            file->read_raw(host_ptrs[buffer_idx], read_iteration);
-                            ggml_backend_tensor_set_async(cuda_backend, cur, host_ptrs[buffer_idx], bytes_read, read_iteration);
-                            ggml_backend_event_record(events[buffer_idx]);
-
-                            bytes_read += read_iteration;
-                            ++buffer_idx;
-                            buffer_idx %= n_buffers;
-                        }
-                    }
-                    else
-#endif
-                    {
-                        read_buf.resize(n_size);
-                        file->seek(weight->offs, SEEK_SET);
-                        file->read_raw(read_buf.data(), n_size);
-                        ggml_backend_tensor_set(cur, read_buf.data(), 0, n_size);
-                        if (check_tensors && !ggml_validate_row_data(cur->type, read_buf.data(), n_size)) {
-                            throw std::runtime_error(format("tensor '%s' has invalid data", ggml_get_name(cur)));
-                        }
-                    }
-                }
-            }
-
-            size_done += n_size;
-        }
-
-#if defined(GGML_USE_CUDA)
-        // free temporary resources used for async cuda uploads
-        if (cuda_backend) {
-            for (size_t idx = 0; idx < n_buffers; ++idx) {
-                ggml_backend_event_synchronize(events[idx]);
-                ggml_backend_event_free(events[idx]);
-                ggml_backend_buffer_free(host_buffers[idx]);
-            }
-            ggml_backend_free(cuda_backend);
-        }
-#endif
-
-        // check validation results
-        bool validation_failed = false;
-        for (auto & future : validation_result) {
-            auto result = future.get();
-            if (!result.second) {
-                LLAMA_LOG_ERROR("%s: tensor '%s' has invalid data\n", __func__, ggml_get_name(result.first));
-                validation_failed = true;
-            }
-        }
-        if (validation_failed) {
-            throw std::runtime_error("found tensors with invalid data");
-        }
-
-        // check if this is the last call and do final cleanup
-        if (size_done >= size_data) {
-            // unmap offloaded tensors and metadata
-            if (use_mmap) {
-                for (uint32_t idx = 0; idx < mappings.size(); idx++) {
-                    const auto & mmap_used = mmaps_used.at(idx);
-                    auto & mapping = mappings.at(idx);
-                    mapping->unmap_fragment(0, mmap_used.first);
-                    if (mmap_used.second != 0) {
-                        mapping->unmap_fragment(mmap_used.second, mapping->size);
-                    }
-                }
-            }
-            if (progress_callback) {
-                // Even though the model is done loading, we still honor
-                // cancellation since we need to free allocations.
-                return progress_callback(1.0f, progress_callback_user_data);
-            }
-        }
-
-        return true;
-    }
-};
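The CUDA branch of `load_all_data()` above streams each tensor through a small ring of pinned staging buffers: wait on the buffer's previous event, read the next file chunk into it, start the async device copy, record a new event, and advance. A generic sketch of that round-robin, with the backend calls reduced to comments and the sizes shrunk for demonstration:

```cpp
// Generic sketch of the 4-buffer staging ring used in load_all_data():
// wait for the buffer's previous use to finish, refill it from the file,
// start the next async copy, then move on to the next buffer.
#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
    constexpr size_t n_buffers   = 4;
    constexpr size_t buffer_size = 8;           // tiny for demonstration
    size_t n_size = 30;                         // pretend tensor size in bytes
    std::vector<char> staging(n_buffers * buffer_size);

    size_t bytes_read = 0;
    size_t buffer_idx = 0;
    while (bytes_read < n_size) {
        size_t chunk = std::min(buffer_size, n_size - bytes_read);
        // 1. ggml_backend_event_synchronize(events[buffer_idx]) would go here
        // 2. file->read_raw(host_ptrs[buffer_idx], chunk) -- simulated:
        char * dst = staging.data() + buffer_idx * buffer_size;
        (void) dst;
        // 3. ggml_backend_tensor_set_async(...) + event_record(...) would go here
        std::printf("buffer %zu uploads bytes [%zu, %zu)\n",
                    buffer_idx, bytes_read, bytes_read + chunk);
        bytes_read += chunk;
        buffer_idx  = (buffer_idx + 1) % n_buffers;
    }
}
```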
+// TODO: update when needed or think of some clever automatic way to do this
+static size_t llama_model_max_nodes(const llama_model & /*model*/) {
+    //if (model.arch == LLM_ARCH_LLAMA && model.hparams.n_layer > ??) { // llama-3 405B
+    //    return 32768;
+    //}
 
-template<>
-bool llama_model_loader::get_key(const enum llm_kv kid, enum llama_pooling_type & result, const bool required) {
-    uint32_t tmp;
-    const bool found = get_key(kid, tmp, required);
-    if (found) {
-        result = (enum llama_pooling_type) tmp;
-    } else {
-        result = LLAMA_POOLING_TYPE_UNSPECIFIED;
-    }
-    return found;
+    return 65536;
 }
 
-
 //
 // load LLaMA models
 //
@@ -6283,728 +4258,54 @@ static void llm_load_hparams(
                     default: model.type = e_model::MODEL_UNKNOWN;
                 }
             } break;
-        case LLM_ARCH_DOTS1:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
-                ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT,   hparams.n_layer_dense_lead);
-                ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH,  hparams.n_ff_exp);
-                ml.get_key(LLM_KV_EXPERT_SHARED_COUNT,         hparams.n_expert_shared);
-                ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE,        hparams.expert_weights_scale);
-                ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM,         hparams.expert_weights_norm, false);
-                ml.get_key(LLM_KV_EXPERT_GATING_FUNC,          hparams.expert_gating_func, false);
-                switch (hparams.n_layer) {
-                    case 62: model.type = e_model::MODEL_142B; break;
-                    default: model.type = e_model::MODEL_UNKNOWN;
-                }
-            } break;
-        case LLM_ARCH_HUNYUAN_MOE:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS,       hparams.f_norm_rms_eps);
-                ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH,        hparams.n_ff_exp);
-                ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp);
-
-                switch (hparams.n_layer) {
-                    case 32: model.type = e_model::MODEL_80B_A13B; break;
-                    default: model.type = e_model::MODEL_UNKNOWN;
-                }
-            } break;
-        default: (void)0;
-    }
-
-    model.ftype = ml.ftype;
-
-    if (hparams.f_max_alibi_bias > 0.0f) {
-        hparams.use_alibi = true;
-    }
-
-    hparams.rope_type = llama_rope_type(&model);
-}
-
-static void llm_load_vocab(
-        llama_model_loader & ml,
-        llama_model & model) {
-    auto & vocab = model.vocab;
-
-    struct gguf_context * ctx = ml.meta;
-
-    const auto kv = LLM_KV(model.arch);
-
-    // determine vocab type
-    {
-        std::string tokenizer_model;
-        std::string tokenizer_pre;
-
-        ml.get_key(LLM_KV_TOKENIZER_MODEL, tokenizer_model);
-        ml.get_key(LLM_KV_TOKENIZER_PRE,   tokenizer_pre, false);
-
-        if (tokenizer_model == "no_vocab") {
-            vocab.type = LLAMA_VOCAB_TYPE_NONE;
-
-            // default special tokens
-            vocab.special_bos_id  = -1;
-            vocab.special_eos_id  = -1;
-            vocab.special_unk_id  = -1;
-            vocab.special_sep_id  = -1;
-            vocab.special_pad_id  = -1;
-            vocab.special_cls_id  = -1;
-            vocab.special_mask_id = -1;
-            vocab.linefeed_id     = -1;
-
-            return;
-        } else if (tokenizer_model == "llama") {
-            vocab.type = LLAMA_VOCAB_TYPE_SPM;
-
-            // default special tokens
-            vocab.special_bos_id  = 1;
-            vocab.special_eos_id  = 2;
-            vocab.special_unk_id  = 0;
-            vocab.special_sep_id  = -1;
-            vocab.special_pad_id  = -1;
-            vocab.special_cls_id  = -1;
-            vocab.special_mask_id = -1;
-        } else if (tokenizer_model == "bert") {
-            vocab.type = LLAMA_VOCAB_TYPE_WPM;
-
-            // default special tokens
-            vocab.special_bos_id  = -1;
-            vocab.special_eos_id  = -1;
-            vocab.special_unk_id  = 100;
-            vocab.special_sep_id  = 102;
-            vocab.special_pad_id  = 0;
-            vocab.special_cls_id  = 101;
-            vocab.special_mask_id = 103;
-        } else if (tokenizer_model == "gpt2") {
-            vocab.type = LLAMA_VOCAB_TYPE_BPE;
-
-            // read bpe merges and populate bpe ranks
-            const int merges_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_MERGES).c_str());
-            if (merges_keyidx == -1) {
-                throw std::runtime_error("cannot find tokenizer merges in model file\n");
-            }
-
-            const int n_merges = gguf_get_arr_n(ctx, merges_keyidx);
-            for (int i = 0; i < n_merges; i++) {
-                const std::string word = gguf_get_arr_str(ctx, merges_keyidx, i);
-                GGML_ASSERT(unicode_cpts_from_utf8(word).size() > 0);
-
-                std::string first;
-                std::string second;
-
-                const size_t pos = word.find(' ', 1);
-
-                if (pos != std::string::npos) {
-                    first  = word.substr(0, pos);
-                    second = word.substr(pos + 1);
-                }
-
-                vocab.bpe_ranks.emplace(std::make_pair(first, second), i);
-            }
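Each BPE merge entry above is a single string holding the two merge partners separated by a space, with the search starting at index 1 so that a leading space character can still belong to the left partner. A standalone sketch of that split (the merge strings are illustrative GPT-2-style entries):

```cpp
// Sketch of the merge-rule split above: an entry like "Ġ t" becomes the
// pair ("Ġ", "t"), ranked by its position in the merges array.
#include <cstdio>
#include <map>
#include <string>
#include <utility>
#include <vector>

int main() {
    // Illustrative merge strings in GGUF order (rank = index).
    // "\xC4\xA0" is the UTF-8 encoding of 'Ġ' (U+0120).
    std::vector<std::string> merges = { "\xC4\xA0 t", "h e", "i n" };

    std::map<std::pair<std::string, std::string>, int> bpe_ranks;
    for (size_t i = 0; i < merges.size(); i++) {
        const std::string & word = merges[i];
        // search from index 1 so a leading space can belong to the left part
        const size_t pos = word.find(' ', 1);
        if (pos != std::string::npos) {
            bpe_ranks.emplace(std::make_pair(word.substr(0, pos), word.substr(pos + 1)), (int) i);
        }
    }
    std::printf("loaded %zu merge rules\n", bpe_ranks.size());
}
```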
-
-            // default special tokens
-            if (model.arch == LLM_ARCH_DOTS1) {
-                vocab.special_bos_id = -1;
-            } else {
-                vocab.special_bos_id = 11;
-            }
-            vocab.special_eos_id  = 11;
-            vocab.special_unk_id  = -1;
-            vocab.special_sep_id  = -1;
-            vocab.special_pad_id  = -1;
-            vocab.special_cls_id  = -1;
-            vocab.special_mask_id = -1;
-        } else if (tokenizer_model == "t5") {
-            vocab.type = LLAMA_VOCAB_TYPE_UGM;
-
-            // default special tokens
-            vocab.special_bos_id  = -1;
-            vocab.special_eos_id  = 1;
-            vocab.special_unk_id  = 2;
-            vocab.special_sep_id  = -1;
-            vocab.special_pad_id  = 0;
-            vocab.special_cls_id  = -1;
-            vocab.special_mask_id = -1;
-
-            const int precompiled_charsmap_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP).c_str());
-            if (precompiled_charsmap_keyidx != -1) {
-                size_t n_precompiled_charsmap = gguf_get_arr_n(ctx, precompiled_charsmap_keyidx);
-                const char * precompiled_charsmap = (const char *) gguf_get_arr_data(ctx, precompiled_charsmap_keyidx);
-                vocab.precompiled_charsmap.assign(precompiled_charsmap, precompiled_charsmap + n_precompiled_charsmap);
-#ifdef IS_BIG_ENDIAN
-                // correct endianness of data in precompiled_charsmap binary blob
-                uint32_t * xcda_blob_size = (uint32_t *) &vocab.precompiled_charsmap[0];
-                *xcda_blob_size = __builtin_bswap32(*xcda_blob_size);
-                assert(*xcda_blob_size + sizeof(uint32_t) < n_precompiled_charsmap);
-                size_t xcda_array_size = *xcda_blob_size / sizeof(uint32_t);
-                uint32_t * xcda_array = (uint32_t *) &vocab.precompiled_charsmap[sizeof(uint32_t)];
-                for (size_t i = 0; i < xcda_array_size; ++i) {
-                    xcda_array[i] = __builtin_bswap32(xcda_array[i]);
-                }
-#endif
-            }
-        } else {
-            throw std::runtime_error(format("unknown tokenizer: '%s'", tokenizer_model.c_str()));
-        }
-
-        // for now, only BPE models have pre-tokenizers
-        if (vocab.type == LLAMA_VOCAB_TYPE_BPE) {
-            vocab.tokenizer_add_space_prefix = false;
-            vocab.tokenizer_clean_spaces = true;
-            if (tokenizer_pre.empty()) {
-                //!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
-                // OK - I don't feel like recreating the LLaMA-v3 models. Considering that, at least for now,
-                // LLaMA-v3 is the only model where we end up here, let's just force the pre-tokenizer to be
-                // llama3.
-                tokenizer_pre = "llama3";
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_LLAMA3;
-                LLAMA_LOG_WARN("%s: missing pre-tokenizer type, using: 'llama3'\n", __func__);
-                LLAMA_LOG_WARN("%s:                                             \n", __func__);
-                LLAMA_LOG_WARN("%s: ************************************        \n", __func__);
-                LLAMA_LOG_WARN("%s: GENERATION QUALITY MAY BE DEGRADED!         \n", __func__);
-                LLAMA_LOG_WARN("%s: CONSIDER REGENERATING THE MODEL             \n", __func__);
-                LLAMA_LOG_WARN("%s: ************************************        \n", __func__);
-                LLAMA_LOG_WARN("%s:                                             \n", __func__);
-                //vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
-                //!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
-            } else if (tokenizer_pre == "default") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
-            } else if (
-                    tokenizer_pre == "llama3"   ||
-                    tokenizer_pre == "llama-v3" ||
-                    tokenizer_pre == "llama-bpe") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_LLAMA3;
-                vocab.tokenizer_ignore_merges = true;
-                vocab.tokenizer_add_bos = true;
-            } else if (
-                    tokenizer_pre == "deepseek-llm") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM;
-                vocab.tokenizer_clean_spaces = false;
-            } else if (
-                    tokenizer_pre == "deepseek-coder") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER;
-                vocab.tokenizer_clean_spaces = false;
-            } else if (
-                    tokenizer_pre == "deepseek-v3") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM;
-                vocab.tokenizer_clean_spaces = false;
-            } else if (
-                    tokenizer_pre == "falcon") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_FALCON;
-            } else if (tokenizer_pre == "falcon3") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_FALCON_3;
-            } else if (tokenizer_pre == "falcon_e") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_FALCON_E;
-            } else if (
-                    tokenizer_pre == "mpt") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_MPT;
-            } else if (
-                    tokenizer_pre == "starcoder") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_STARCODER;
-            } else if (
-                    tokenizer_pre == "gpt-2"   ||
-                    tokenizer_pre == "phi-2"   ||
-                    tokenizer_pre == "jina-es" ||
-                    tokenizer_pre == "jina-de" ||
-                    tokenizer_pre == "jina-v2-es" ||
-                    tokenizer_pre == "jina-v2-de" ||
-                    tokenizer_pre == "jina-v2-code") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_GPT2;
-            } else if (
-                    tokenizer_pre == "refact") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_REFACT;
-            } else if (
-                tokenizer_pre == "command-r") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_COMMAND_R;
-                vocab.tokenizer_clean_spaces = false;
-            } else if (
-                tokenizer_pre == "qwen2" || tokenizer_pre == "deepseek-r1-qwen") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_QWEN2;
-                vocab.tokenizer_clean_spaces = false;
-            } else if (
-                tokenizer_pre == "stablelm2") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_STABLELM2;
-            } else if (
-                tokenizer_pre == "olmo") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_OLMO;
-            } else if (
-                tokenizer_pre == "dbrx") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DBRX;
-            } else if (
-                tokenizer_pre == "smaug-bpe") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_SMAUG;
-            } else if (
-                tokenizer_pre == "poro-chat") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_PORO;
-                vocab.tokenizer_clean_spaces = false;
-            } else if (
-                tokenizer_pre == "glm4" ||
-                tokenizer_pre == "chatglm-bpe") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_CHATGLM4;
-                vocab.special_bos_id  = -1;
-            } else if (
-                tokenizer_pre == "viking") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_VIKING;
-                vocab.tokenizer_clean_spaces = false;
-            } else if (
-                tokenizer_pre == "jais") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_JAIS;
-            } else if (
-                tokenizer_pre == "tekken") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_TEKKEN;
-                vocab.tokenizer_clean_spaces = false;
-                vocab.tokenizer_ignore_merges = true;
-                vocab.tokenizer_add_bos = true;
-            } else if (
-                tokenizer_pre == "smollm") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_SMOLLM;
-                vocab.tokenizer_clean_spaces = false;
-            } else if (
-                tokenizer_pre == "codeshell") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_CODESHELL;
-            } else if (
-                tokenizer_pre == "gpt-4o" ||
-                tokenizer_pre == "llama4") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_GPT4O;
-                vocab.tokenizer_clean_spaces = false;
-            } else if (
-                tokenizer_pre == "superbpe") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_SUPERBPE;
-                vocab.tokenizer_clean_spaces = false;
-            } else if (
-                tokenizer_pre == "trillion") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_TRILLION;
-                vocab.tokenizer_clean_spaces = false;
-            } else if (
-                tokenizer_pre == "bailingmoe") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_BAILINGMOE;
-                vocab.tokenizer_clean_spaces = false;
-            } else if (
-                tokenizer_pre == "seed-coder") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_SEED_CODER;
-                vocab.tokenizer_clean_spaces = false;
-            } else if (
-                tokenizer_pre == "hunyuan") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_HUNYUAN;
-                vocab.tokenizer_clean_spaces = false;
-            } else if (
-                tokenizer_pre == "kimi-k2") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_KIMI_K2;
-                vocab.tokenizer_clean_spaces = false;
-            } else {
-                throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
-            }
-        } else if (vocab.type == LLAMA_VOCAB_TYPE_SPM) {
-            vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
-            vocab.tokenizer_add_space_prefix = true;
-            vocab.tokenizer_clean_spaces = false;
-            vocab.tokenizer_add_bos = true;
-            vocab.tokenizer_add_eos = false;
-        } else if (vocab.type == LLAMA_VOCAB_TYPE_WPM) {
-            vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
-            vocab.tokenizer_add_space_prefix = false;
-            vocab.tokenizer_clean_spaces = true;
-            vocab.tokenizer_add_bos = true;
-            vocab.tokenizer_add_eos = false;
-        } else if (vocab.type == LLAMA_VOCAB_TYPE_UGM) {
-            vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
-            vocab.tokenizer_add_bos = false;
-            vocab.tokenizer_add_eos = true;
-        } else {
-            vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
-        }
-
-        ml.get_key(LLM_KV_TOKENIZER_ADD_PREFIX,      vocab.tokenizer_add_space_prefix,         false);
-        ml.get_key(LLM_KV_TOKENIZER_REMOVE_EXTRA_WS, vocab.tokenizer_remove_extra_whitespaces, false);
-    }
-
-    const int token_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_LIST).c_str());
-    if (token_idx == -1) {
-        throw std::runtime_error("cannot find tokenizer vocab in model file\n");
-    }
-
-    const float * scores = nullptr;
-    const int score_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_SCORES).c_str());
-    if (score_idx != -1) {
-        scores = (const float * ) gguf_get_arr_data(ctx, score_idx);
-    }
-
-    const int * toktypes = nullptr;
-    const int toktype_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_TOKEN_TYPE).c_str());
-    if (toktype_idx != -1) {
-        toktypes = (const int * ) gguf_get_arr_data(ctx, toktype_idx);
-    }
-
-    const uint32_t n_vocab = gguf_get_arr_n(ctx, token_idx);
-
-    vocab.id_to_token.resize(n_vocab);
-
-    for (uint32_t i = 0; i < n_vocab; i++) {
-        std::string word = gguf_get_arr_str(ctx, token_idx, i);
-        GGML_ASSERT(unicode_cpts_from_utf8(word).size() > 0);
-
-        vocab.token_to_id[word] = i;
-        vocab.max_token_len = std::max(vocab.max_token_len, (int) word.size());
-
-        auto & token_data = vocab.id_to_token[i];
-        token_data.text  = std::move(word);
-        token_data.score = scores ? scores[i] : 0.0f;
-        token_data.attr  = LLAMA_TOKEN_ATTR_NORMAL;
-
-        if (toktypes) {  //TODO: remove, required until per token attributes are available from GGUF file
-            switch(toktypes[i]) {
-                case LLAMA_TOKEN_TYPE_UNKNOWN:      token_data.attr = LLAMA_TOKEN_ATTR_UNKNOWN;      break;
-                case LLAMA_TOKEN_TYPE_UNUSED:       token_data.attr = LLAMA_TOKEN_ATTR_UNUSED;       break;
-                case LLAMA_TOKEN_TYPE_NORMAL:       token_data.attr = LLAMA_TOKEN_ATTR_NORMAL;       break;
-                case LLAMA_TOKEN_TYPE_CONTROL:      token_data.attr = LLAMA_TOKEN_ATTR_CONTROL;      break;
-                case LLAMA_TOKEN_TYPE_USER_DEFINED: token_data.attr = LLAMA_TOKEN_ATTR_USER_DEFINED; break;
-                case LLAMA_TOKEN_TYPE_BYTE:         token_data.attr = LLAMA_TOKEN_ATTR_BYTE;         break;
-                case LLAMA_TOKEN_TYPE_UNDEFINED:    token_data.attr = LLAMA_TOKEN_ATTR_UNDEFINED;    break;
-                default:                            token_data.attr = LLAMA_TOKEN_ATTR_UNDEFINED;    break;
-            }
-        }
-    }
-    GGML_ASSERT(vocab.id_to_token.size() == vocab.token_to_id.size());
-
-    // determine the newline token: LLaMA "<0x0A>" == 10 == '\n', Falcon 193 == '\n'
-    if (vocab.type == LLAMA_VOCAB_TYPE_SPM) {
-        // For Fill-In-the-Middle (FIM)/infill models which where converted
-        // prior to support of FIM special tokens in GGUF, the following
-        // will allow those models to continue to work. The general names
-        // of the known models are currently CodeLlama (LLM_ARCH_LLAMA) and
-        // CodeGemma (LLM_ARCH_GEMMA). This can potentially be removed once
-        // new versions of these models have been published.
-        std::string gen_name;
-        ml.get_key(LLM_KV_GENERAL_NAME, gen_name, false);
-
-        std::transform(gen_name.begin(), gen_name.end(), gen_name.begin(),
-            [](unsigned char c){ return std::tolower(c); });
-
-        if (gen_name.find("code") != std::string::npos) {
-            if (model.arch == LLM_ARCH_LLAMA
-              && 32010 < vocab.id_to_token.size()
-              && vocab.id_to_token[32007].text.find("<PRE>") != std::string::npos
-              && vocab.id_to_token[32008].text.find("<SUF>") != std::string::npos
-              && vocab.id_to_token[32009].text.find("<MID>") != std::string::npos
-              && vocab.id_to_token[32010].text.find("<EOT>") != std::string::npos) {
-                vocab.special_prefix_id = 32007;
-                vocab.special_suffix_id = 32008;
-                vocab.special_middle_id = 32009;
-                vocab.special_eot_id    = 32010;
-            } else if (model.arch == LLM_ARCH_GEMMA
-              && 107 < vocab.id_to_token.size()
-              && vocab.id_to_token[67].text == "<|fim_prefix|>"
-              && vocab.id_to_token[69].text == "<|fim_suffix|>"
-              && vocab.id_to_token[68].text == "<|fim_middle|>"
-              && vocab.id_to_token[107].text == "<end_of_turn>") {
-                vocab.special_prefix_id = 67;
-                vocab.special_suffix_id = 69;
-                vocab.special_middle_id = 68;
-                // TODO: this is not EOT, it is "file separator" token, needs fix
-                //       https://huggingface.co/google/codegemma-7b-it/blob/9b1d9231388358c04d90bd003458f5070d97db44/tokenizer_config.json#L565-L572
-                //vocab.special_eot_id    = 70;
-                vocab.special_eot_id    = 107;
-            }
-        }
-        try {
-            vocab.linefeed_id = llama_byte_to_token_impl(vocab, '\n');
-        } catch (const std::exception & e) {
-            LLAMA_LOG_WARN("%s: SPM vocabulary, but newline token not found: %s! Using special_pad_id instead.", __func__, e.what());
-            vocab.linefeed_id = vocab.special_pad_id;
-        }
-    } else if (vocab.type == LLAMA_VOCAB_TYPE_WPM) {
-        vocab.linefeed_id = vocab.special_pad_id;
-    } else {
-        const std::vector<int> ids = llama_tokenize_internal(vocab, "\xC4\x8A", false); // U+010A
-        GGML_ASSERT(!ids.empty() && "model vocab missing newline token");
-        vocab.linefeed_id = ids[0];
-    }
-
-    // special tokens
-    {
-        const std::vector<std::pair<enum llm_kv, int32_t &>> special_token_types = {
-            { LLM_KV_TOKENIZER_BOS_ID,    vocab.special_bos_id    },
-            { LLM_KV_TOKENIZER_EOS_ID,    vocab.special_eos_id    },
-            { LLM_KV_TOKENIZER_EOT_ID,    vocab.special_eot_id    },
-            { LLM_KV_TOKENIZER_EOM_ID,    vocab.special_eom_id    },
-            { LLM_KV_TOKENIZER_UNK_ID,    vocab.special_unk_id    },
-            { LLM_KV_TOKENIZER_SEP_ID,    vocab.special_sep_id    },
-            { LLM_KV_TOKENIZER_PAD_ID,    vocab.special_pad_id    },
-            { LLM_KV_TOKENIZER_CLS_ID,    vocab.special_cls_id    },
-            { LLM_KV_TOKENIZER_MASK_ID,   vocab.special_mask_id   },
-
-            { LLM_KV_TOKENIZER_FIM_PRE_ID, vocab.special_fim_pre_id },
-            { LLM_KV_TOKENIZER_FIM_SUF_ID, vocab.special_fim_suf_id },
-            { LLM_KV_TOKENIZER_FIM_MID_ID, vocab.special_fim_mid_id },
-            { LLM_KV_TOKENIZER_FIM_PAD_ID, vocab.special_fim_pad_id },
-            { LLM_KV_TOKENIZER_FIM_REP_ID, vocab.special_fim_rep_id },
-            { LLM_KV_TOKENIZER_FIM_SEP_ID, vocab.special_fim_sep_id },
-
-            { LLM_KV_TOKENIZER_PREFIX_ID, vocab.special_prefix_id },
-            { LLM_KV_TOKENIZER_SUFFIX_ID, vocab.special_suffix_id },
-            { LLM_KV_TOKENIZER_MIDDLE_ID, vocab.special_middle_id },
-        };
-
-        for (const auto & it : special_token_types) {
-            const std::string & key = kv(std::get<0>(it));
-            int32_t & id = std::get<1>(it);
-
-            uint32_t new_id;
-            if (!ml.get_key(std::get<0>(it), new_id, false)) {
-                continue;
-            }
-            if (new_id >= vocab.id_to_token.size()) {
-                LLAMA_LOG_WARN("%s: bad special token: '%s' = %ud, using default id %d\n",
-                    __func__, key.c_str(), new_id, id);
-            } else {
-                id = new_id;
-            }
-        }
-
-        // Handle add_bos_token and add_eos_token
-        {
-            bool temp = true;
-
-            if (ml.get_key(LLM_KV_TOKENIZER_ADD_BOS, temp, false)) {
-                vocab.tokenizer_add_bos = temp;
-            }
-            if (ml.get_key(LLM_KV_TOKENIZER_ADD_EOS, temp, false)) {
-                vocab.tokenizer_add_eos = temp;
-            }
-        }
-
-        // find EOT token: "<|eot_id|>", "<|im_end|>", "<end_of_turn>", etc.
-        //
-        // TODO: convert scripts should provide this token through the KV metadata LLAMA_KV_TOKENIZER_EOT_ID
-        //       for now, we apply this workaround to find the EOT token based on its text
-        if (vocab.special_eot_id == -1) {
-            for (const auto & t : vocab.token_to_id) {
-                if (
-                        // TODO: gemma "<end_of_turn>" is exported as a normal token, so the following check does not work
-                        //       need to fix convert script
-                        //vocab.id_to_token[t.second].type == LLAMA_TOKEN_TYPE_CONTROL &&
-                        (t.first == "<|eot_id|>" ||
-                         t.first == "<|im_end|>" ||
-                         t.first == "<|end|>" ||
-                         t.first == "" ||
-                         t.first == "<|endoftext|>"
-                        )
-                   ) {
-                    vocab.special_eot_id = t.second;
-                    break;
-                }
-            }
-        }
-
-        // find EOM token: "<|eom_id|>"
-        //
-        // TODO: convert scripts should provide this token through the KV metadata LLAMA_KV_TOKENIZER_EOM_ID
-        //       for now, we apply this workaround to find the EOM token based on its text
-        if (vocab.special_eom_id == -1) {
-            const auto & t = vocab.token_to_id.find("<|eom_id|>");
-            if (t != vocab.token_to_id.end()) {
-                vocab.special_eom_id = t->second;
-            }
-        }
-
-        for (const auto & t : vocab.token_to_id) {
-            // find FIM_PRE token: "<|fim_prefix|>", "<fim-prefix>", "<PRE>", etc.
-            if (vocab.special_fim_pre_id == -1) {
-                if (false
-                        || t.first == "<|fim_prefix|>"  // Qwen
-                        || t.first == ""
-                        || t.first == ""    // Granite
-                        || t.first == "<|fim▁begin|>" // DeepSeek
-                        || t.first == "
"
-                        || t.first == "▁
"          // CodeLlama
-                        || t.first == "<|code_prefix|>" // GLM-4.5
-                        ) {
-                    vocab.special_fim_pre_id = t.second;
-                    if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
-                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
-                                __func__, t.second, t.first.c_str());
-                                vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
-                    }
-                }
-            }
-
-            // find FIM_SUF token: "<|fim_suffix|>", "<fim-suffix>", "<SUF>", etc.
-            if (vocab.special_fim_suf_id == -1) {
-                if (false
-                        || t.first == "<|fim_suffix|>" // Qwen
-                        || t.first == ""
-                        || t.first == ""   // Granite
-                        || t.first == "<|fim▁hole|>" // DeepSeek
-                        || t.first == ""
-                        || t.first == "▁"         // CodeLlama
-                        || t.first == "<|code_suffix|>" // GLM-4.5
-                        ) {
-                    vocab.special_fim_suf_id = t.second;
-                    if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
-                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
-                                __func__, t.second, t.first.c_str());
-                        vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
-                    }
-                }
-            }
-
-            // find FIM_MID token: "<|fim_middle|>", "<fim-middle>", "<MID>", etc.
-            if (vocab.special_fim_mid_id == -1) {
-                if (false
-                        || t.first == "<|fim_middle|>" // Qwen
-                        || t.first == ""
-                        || t.first == ""   // Granite
-                        || t.first == "<|fim▁end|>"  // DeepSeek
-                        || t.first == ""
-                        || t.first == "▁"         // CodeLlama
-                        || t.first == "<|code_middle|>" // GLM-4.5
-                        ) {
-                    vocab.special_fim_mid_id = t.second;
-                    if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
-                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
-                                __func__, t.second, t.first.c_str());
-                        vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
-                    }
-                }
-            }
-
-            // find FIM_PAD token: "<|fim_pad|>", "<fim-pad>", "<PAD>", etc.
-            if (vocab.special_fim_pad_id == -1) {
-                if (false
-                        || t.first == "<|fim_pad|>" // Qwen
-                        || t.first == ""
-                        || t.first == ""   // Granite
-                        || t.first == ""
-                        ) {
-                    vocab.special_fim_pad_id = t.second;
-                    if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
-                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
-                                __func__, t.second, t.first.c_str());
-                        vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
-                    }
-                }
-            }
-
-            // find FIM_REP token: "<|fim_repo|>", "<fim-repo>", "<REPO>", etc.
-            if (vocab.special_fim_rep_id == -1) {
-                if (false
-                        || t.first == "<|fim_repo|>"  // Qwen
-                        || t.first == "<|repo_name|>"
-                        || t.first == ""
-                        || t.first == ""
-                        || t.first == ""    // Granite
-                        ) {
-                    vocab.special_fim_rep_id = t.second;
-                    if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
-                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
-                                __func__, t.second, t.first.c_str());
-                        vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
-                    }
-                }
-            }
-
-            // find FIM_SEP token: "<|file_sep|>"
-            if (vocab.special_fim_sep_id == -1) {
-                if (false
-                        || t.first == "<|file_sep|>" // Qwen
-                        ) {
-                    vocab.special_fim_sep_id = t.second;
-                    if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
-                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
-                                __func__, t.second, t.first.c_str());
-                        vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
-                    }
-                }
-            }
-        }
-
-    }
-
-    // build special tokens cache
-    {
-        for (llama_vocab::id id = 0; id < (llama_vocab::id)n_vocab; ++id) {
-            if (vocab.id_to_token[id].attr & (LLAMA_TOKEN_ATTR_CONTROL | LLAMA_TOKEN_ATTR_USER_DEFINED | LLAMA_TOKEN_ATTR_UNKNOWN)) {
-                vocab.cache_special_tokens.push_back(id);
-            }
-        }
-
-        std::sort(vocab.cache_special_tokens.begin(), vocab.cache_special_tokens.end(),
-            [&] (const llama_vocab::id a, const llama_vocab::id b) {
-                return vocab.id_to_token[a].text.size() > vocab.id_to_token[b].text.size();
-            }
-        );
-
-        LLAMA_LOG_INFO("%s: special tokens cache size = %u\n", __func__, (uint32_t)vocab.cache_special_tokens.size());
-    }
-
-    // build token to piece cache
-    {
-        size_t size_cache = 0;
-
-        std::vector<std::string> cache_token_to_piece(n_vocab);
-
-        for (uint32_t id = 0; id < n_vocab; ++id) {
-            cache_token_to_piece[id] = llama_token_to_piece(&model, id, true);
-
-            size_cache += cache_token_to_piece[id].size();
-        }
-
-        std::swap(vocab.cache_token_to_piece, cache_token_to_piece);
-
-        LLAMA_LOG_INFO("%s: token to piece cache size = %.4f MB\n", __func__, size_cache / 1024.0 / 1024.0);
-    }
-
-    // Handle per token attributes
-    //NOTE: Each model customizes per token attributes.
-    //NOTE: Per token attributes are missing from the GGUF file.
-    //TODO: Extract attributes from GGUF file.
-    {
-        auto _contains_any = [] (const std::string &str, const std::vector<std::string> &substrs) -> bool {
-            for (auto substr : substrs) {
-                if (str.find(substr) < std::string::npos) {
-                    return true;
+        case LLM_ARCH_DOTS1:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT,   hparams.n_layer_dense_lead);
+                ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH,  hparams.n_ff_exp);
+                ml.get_key(LLM_KV_EXPERT_SHARED_COUNT,         hparams.n_expert_shared);
+                ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE,        hparams.expert_weights_scale);
+                ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM,         hparams.expert_weights_norm, false);
+                ml.get_key(LLM_KV_EXPERT_GATING_FUNC,          hparams.expert_gating_func, false);
+                switch (hparams.n_layer) {
+                    case 62: model.type = e_model::MODEL_142B; break;
+                    default: model.type = e_model::MODEL_UNKNOWN;
                 }
-            }
-            return false;
-        };
+            } break;
+        case LLM_ARCH_HUNYUAN_MOE:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS,       hparams.f_norm_rms_eps);
+                ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH,        hparams.n_ff_exp);
+                ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp);
 
-        auto _set_tokenid_attr = [&] (const llama_vocab::id id, llama_token_attr attr, bool value) {
-            uint32_t current = vocab.id_to_token.at(id).attr;
-            current = value ? (current | attr) : (current & ~attr);
-            vocab.id_to_token[id].attr = (llama_token_attr) current;
-        };
+                switch (hparams.n_layer) {
+                    case 32: model.type = e_model::MODEL_80B_A13B; break;
+                    default: model.type = e_model::MODEL_UNKNOWN;
+                }
+            } break;
+        case LLM_ARCH_OPENAI_MOE:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH,  hparams.n_ff_exp);
+                ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW,    hparams.n_swa);
 
-        auto _set_token_attr = [&] (const std::string & token, llama_token_attr attr, bool value) {
-            _set_tokenid_attr(vocab.token_to_id.at(token), attr, value);
-        };
+                //TODO OAI_MOE: SWA
+                //hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
+                //hparams.set_swa_pattern(2);
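+                // NOTE: until SWA is wired through hparams here, build_openai_moe()
+                //       below hardcodes sliding_window_pattern = 2 (alternating SWA/full-attention layers)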
 
-        std::string model_name;
-        std::string tokenizer_pre;
+                // TODO: switch (hparams.n_layer)
 
-        ml.get_key(LLM_KV_GENERAL_NAME, model_name, false);
-        ml.get_key(LLM_KV_TOKENIZER_PRE, tokenizer_pre, false);
+            } break;
+        default: (void)0;
+    }
 
-        // model name to lowercase
-        std::transform(model_name.begin(), model_name.end(), model_name.begin(),
-            [] (const std::string::value_type x) {
-                return std::tolower(x);
-            }
-        );
+    model.ftype = ml.ftype;
 
-        // set attributes by model/tokenizer name
-        if (_contains_any(tokenizer_pre, {"jina-v2-de", "jina-v2-es", "jina-v2-code"})) {
-            _set_token_attr("", LLAMA_TOKEN_ATTR_LSTRIP, true);
-        } else if (_contains_any(model_name, {"phi-3", "phi3"})) {
-            for (auto id : vocab.cache_special_tokens) {
-                _set_tokenid_attr(id, LLAMA_TOKEN_ATTR_RSTRIP, true);
-            }
-            for (auto token : {""}) {
-                _set_token_attr(token, LLAMA_TOKEN_ATTR_RSTRIP, true);
-            }
-            for (auto token : {"", "", "<|endoftext|>"}) {
-                _set_token_attr(token, LLAMA_TOKEN_ATTR_RSTRIP, false);
-            }
-        }
+    if (hparams.f_max_alibi_bias > 0.0f) {
+        hparams.use_alibi = true;
     }
+
+    hparams.rope_type = llama_rope_type(&model);
 }
 
 static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
@@ -7045,10 +4346,6 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
     // hparams
     LLAMA_LOG_INFO("%s: format           = %s\n",     __func__, llama_file_version_name(ml.fver));
     LLAMA_LOG_INFO("%s: arch             = %s\n",     __func__, LLM_ARCH_NAMES.at(model.arch));
-    LLAMA_LOG_INFO("%s: vocab type       = %s\n",     __func__, llama_model_vocab_type_name(vocab.type));
-    LLAMA_LOG_INFO("%s: n_vocab          = %u\n",     __func__, hparams.n_vocab);
-    LLAMA_LOG_INFO("%s: n_merges         = %u\n",     __func__, (int) vocab.bpe_ranks.size());
-    LLAMA_LOG_INFO("%s: vocab_only       = %d\n",     __func__, hparams.vocab_only);
 
     if (!hparams.vocab_only) {
         LLAMA_LOG_INFO("%s: n_ctx_train      = %u\n",     __func__, hparams.n_ctx_train);
@@ -7128,31 +4425,6 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
     // general kv
     LLAMA_LOG_INFO("%s: general.name     = %s\n",    __func__, model.name.c_str());
 
-    // special tokens
-    if (vocab.special_bos_id    != -1) { LLAMA_LOG_INFO( "%s: BOS token        = %d '%s'\n", __func__, vocab.special_bos_id,  vocab.id_to_token[vocab.special_bos_id].text.c_str() );  }
-    if (vocab.special_eos_id    != -1) { LLAMA_LOG_INFO( "%s: EOS token        = %d '%s'\n", __func__, vocab.special_eos_id,  vocab.id_to_token[vocab.special_eos_id].text.c_str() );  }
-    if (vocab.special_unk_id    != -1) { LLAMA_LOG_INFO( "%s: UNK token        = %d '%s'\n", __func__, vocab.special_unk_id,  vocab.id_to_token[vocab.special_unk_id].text.c_str() );  }
-    if (vocab.special_sep_id    != -1) { LLAMA_LOG_INFO( "%s: SEP token        = %d '%s'\n", __func__, vocab.special_sep_id,  vocab.id_to_token[vocab.special_sep_id].text.c_str() );  }
-    if (vocab.special_pad_id    != -1) { LLAMA_LOG_INFO( "%s: PAD token        = %d '%s'\n", __func__, vocab.special_pad_id,  vocab.id_to_token[vocab.special_pad_id].text.c_str() );  }
-    if (vocab.special_cls_id    != -1) { LLAMA_LOG_INFO( "%s: CLS token        = %d '%s'\n", __func__, vocab.special_cls_id,  vocab.id_to_token[vocab.special_cls_id].text.c_str() );  }
-    if (vocab.special_mask_id   != -1) { LLAMA_LOG_INFO( "%s: MASK token       = %d '%s'\n", __func__, vocab.special_mask_id, vocab.id_to_token[vocab.special_mask_id].text.c_str() ); }
-
-    if (vocab.linefeed_id       != -1) { LLAMA_LOG_INFO( "%s: LF token         = %d '%s'\n", __func__, vocab.linefeed_id,       vocab.id_to_token[vocab.linefeed_id].text.c_str() );       }
- 
-    if (vocab.special_fim_pre_id != -1) { LLAMA_LOG_INFO( "%s: FIM PRE token    = %d '%s'\n", __func__, vocab.special_fim_pre_id, vocab.id_to_token.at(vocab.special_fim_pre_id).text.c_str() ); }
-    if (vocab.special_fim_suf_id != -1) { LLAMA_LOG_INFO( "%s: FIM SUF token    = %d '%s'\n", __func__, vocab.special_fim_suf_id, vocab.id_to_token.at(vocab.special_fim_suf_id).text.c_str() ); }
-    if (vocab.special_fim_mid_id != -1) { LLAMA_LOG_INFO( "%s: FIM MID token    = %d '%s'\n", __func__, vocab.special_fim_mid_id, vocab.id_to_token.at(vocab.special_fim_mid_id).text.c_str() ); }
-    if (vocab.special_fim_pad_id != -1) { LLAMA_LOG_INFO( "%s: FIM PAD token    = %d '%s'\n", __func__, vocab.special_fim_pad_id, vocab.id_to_token.at(vocab.special_fim_pad_id).text.c_str() ); }
-    if (vocab.special_fim_rep_id != -1) { LLAMA_LOG_INFO( "%s: FIM REP token    = %d '%s'\n", __func__, vocab.special_fim_rep_id, vocab.id_to_token.at(vocab.special_fim_rep_id).text.c_str() ); }
-    if (vocab.special_fim_sep_id != -1) { LLAMA_LOG_INFO( "%s: FIM SEP token    = %d '%s'\n", __func__, vocab.special_fim_sep_id, vocab.id_to_token.at(vocab.special_fim_sep_id).text.c_str() ); }
-
-    if (vocab.special_prefix_id != -1) { LLAMA_LOG_INFO( "%s: PRE token        = %d '%s'\n", __func__, vocab.special_prefix_id, vocab.id_to_token[vocab.special_prefix_id].text.c_str() ); }
-    if (vocab.special_suffix_id != -1) { LLAMA_LOG_INFO( "%s: SUF token        = %d '%s'\n", __func__, vocab.special_suffix_id, vocab.id_to_token[vocab.special_suffix_id].text.c_str() ); }
-    if (vocab.special_middle_id != -1) { LLAMA_LOG_INFO( "%s: MID token        = %d '%s'\n", __func__, vocab.special_middle_id, vocab.id_to_token[vocab.special_middle_id].text.c_str() ); }
-    if (vocab.special_eot_id    != -1) { LLAMA_LOG_INFO( "%s: EOT token        = %d '%s'\n", __func__, vocab.special_eot_id,    vocab.id_to_token[vocab.special_eot_id].text.c_str() );    }
-
-    LLAMA_LOG_INFO("%s: max token length = %d\n", __func__, vocab.max_token_len);
-
     if (model.arch == LLM_ARCH_DEEPSEEK2) {
         LLAMA_LOG_INFO("%s: n_layer_dense_lead   = %d\n",     __func__, hparams.n_layer_dense_lead);
         LLAMA_LOG_INFO("%s: n_lora_q             = %d\n",     __func__, hparams.n_lora_q);
@@ -7170,7 +4442,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
         LLAMA_LOG_INFO("%s: n_ff_shexp       = %d\n",     __func__, hparams.n_ff_shexp);
     }
 
-    if (model.arch == LLM_ARCH_QWEN3MOE) {
+    if (model.arch == LLM_ARCH_QWEN3MOE || model.arch == LLM_ARCH_OPENAI_MOE) {
         LLAMA_LOG_INFO("%s: n_ff_exp         = %d\n",     __func__, hparams.n_ff_exp);
     }
 
@@ -7180,6 +4452,8 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
         LLAMA_LOG_INFO("%s: f_attention_scale = %f\n", __func__, hparams.f_attention_scale);
     }
 
+    vocab.print_info();
+
 }
 
 static void llm_prepare_mla(llama_model & model, int mla) {
@@ -9239,7 +6513,7 @@ static bool llm_load_tensors(
                     if (model.output == NULL) {
                         model.output = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
                     }
-                    
+
                     for (int i = 0; i < n_layer; ++i) {
                         ggml_context * ctx_layer = ctx_for_layer(i);
                         ggml_context * ctx_split = ctx_for_layer_split(i);
@@ -9265,9 +6539,9 @@ static bool llm_load_tensors(
                         layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, flags);
 
                         // K/Q norm tensors (optional for GLM-4.5 355B variant)
-                        layer.attn_q_norm = create_tensor(ctx_layer, 
+                        layer.attn_q_norm = create_tensor(ctx_layer,
                             tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, llama_model_loader::TENSOR_NOT_REQUIRED | flags);
-                        layer.attn_k_norm = create_tensor(ctx_layer, 
+                        layer.attn_k_norm = create_tensor(ctx_layer,
                             tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, llama_model_loader::TENSOR_NOT_REQUIRED | flags);
 
                         layer.attn_post_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, flags);
@@ -9285,21 +6559,21 @@ static bool llm_load_tensors(
                             // MoE branch
                             const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used;
 
-                            layer.ffn_gate_exps = create_tensor(ctx_split, 
+                            layer.ffn_gate_exps = create_tensor(ctx_split,
                                 tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, flags);
-                            layer.ffn_down_exps = create_tensor(ctx_split, 
+                            layer.ffn_down_exps = create_tensor(ctx_split,
                                 tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, flags);
-                            layer.ffn_up_exps = create_tensor(ctx_split, 
+                            layer.ffn_up_exps = create_tensor(ctx_split,
                                 tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, flags);
 
                             // Shared expert
                             if (n_expert_shared > 0) {
                                 const int64_t n_ff_shexp = n_ff_exp * n_expert_shared;
-                                layer.ffn_gate_shexp     = create_tensor(ctx_split, 
+                                layer.ffn_gate_shexp     = create_tensor(ctx_split,
                                     tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp }, flags);
-                                layer.ffn_down_shexp = create_tensor(ctx_split, 
+                                layer.ffn_down_shexp = create_tensor(ctx_split,
                                     tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_shexp, n_embd }, flags);
-                                layer.ffn_up_shexp = create_tensor(ctx_split, 
+                                layer.ffn_up_shexp = create_tensor(ctx_split,
                                     tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp }, flags);
                             }
                         } else {
@@ -9794,6 +7068,48 @@ static bool llm_load_tensors(
                         layer.ffn_down_shexp = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, 0);
                     }
                 } break;
+            case LLM_ARCH_OPENAI_MOE:
+                {
+                    const int64_t n_ff_exp = hparams.n_ff_exp;
+
+                    model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    model.output      = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        ggml_context * ctx_layer = ctx_for_layer(i);
+                        ggml_context * ctx_split = ctx_for_layer_split(i);
+                        auto & layer = model.layers[i];
+
+                        layer.attn_norm      = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM,      "weight", i), {n_embd}, 0);
+                        layer.attn_post_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_head * n_rot}, 0);
+                        layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_head_kv * n_rot}, 0);
+                        layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_head_kv * n_rot}, 0);
+                        layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_rot, n_embd}, 0);
+
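+                        // per-head attention sink logits (per the gpt-oss design: one extra logit
+                        // per head folded into the attention softmax; see the *_add_sinks calls below)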
+                        layer.attn_sinks = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_SINKS, "weight", i), {n_head}, 0);
+
+                        layer.ffn_gate_inp  = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP,  "weight", i), {  n_embd, n_expert}, 0);
+                        layer.ffn_gate_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {  n_embd, n_ff_exp, n_expert}, 0);
+                        layer.ffn_down_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp,   n_embd, n_expert}, 0);
+                        layer.ffn_up_exps   = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {  n_embd, n_ff_exp, n_expert}, 0);
+
+                        // bias
+                        layer.bq = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q,   "bias", i), {n_head * n_rot}, 0);
+                        layer.bk = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K,   "bias", i), {n_head_kv * n_rot}, 0);
+                        layer.bv = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V,   "bias", i), {n_head_kv * n_rot}, 0);
+                        layer.bo = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0);
+
+                        layer.ffn_gate_inp_b  = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP,  "bias", i), {n_expert}, 0);
+                        layer.ffn_gate_exps_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_EXPS, "bias", i), {n_ff_exp, n_expert}, 0);
+                        layer.ffn_down_exps_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN_EXPS, "bias", i), {  n_embd, n_expert}, 0);
+                        layer.ffn_up_exps_b   = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP_EXPS,   "bias", i), {n_ff_exp, n_expert}, 0);
+                    }
+                } break;
             default:
                 throw std::runtime_error("unknown architecture");
         }
@@ -10011,15 +7327,16 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam
             throw std::runtime_error("error loading model hyperparameters: " + std::string(e.what()));
         }
         try {
-            llm_load_vocab(ml, model);
+            LLM_KV kv(model.arch);
+            model.vocab.load(ml, kv);
         } catch(const std::exception & e) {
             throw std::runtime_error("error loading model vocabulary: " + std::string(e.what()));
         }
 
         llm_load_print_meta(ml, model);
 
-        if (model.vocab.type != LLAMA_VOCAB_TYPE_NONE &&
-            model.hparams.n_vocab != model.vocab.id_to_token.size()) {
+        if (model.vocab.get_type() != LLAMA_VOCAB_TYPE_NONE &&
+            model.hparams.n_vocab != model.vocab.n_tokens()) {
             throw std::runtime_error("vocab size mismatch");
         }
 
@@ -10071,6 +7388,7 @@ enum llm_ffn_op_type {
     LLM_FFN_RELU,
     LLM_FFN_RELU_SQR,
     LLM_FFN_SWIGLU,
+    LLM_FFN_SWIGLU_OAI_MOE,
 };
 
 enum llm_ffn_gate_type {
@@ -10362,6 +7680,8 @@ static struct ggml_tensor * llm_build_ffn(
                 cur = ggml_swiglu(ctx, cur);
                 cb(cur, "ffn_swiglu", il);
             } break;
+        default:
+            GGML_ABORT("fatal error");
     }
 
     if (type_gate == LLM_FFN_PAR) {
@@ -10394,15 +7714,15 @@ static struct ggml_tensor * llm_build_ffn(
     return cur;
 }
 
-static struct ggml_tensor * llm_build_moe_ffn(
-        struct ggml_context * ctx,
-       struct llama_context & lctx,
-         struct ggml_tensor * cur,
-         struct ggml_tensor * gate_inp,
-         struct ggml_tensor * up_exps,
-         struct ggml_tensor * gate_exps,
-         struct ggml_tensor * down_exps,
-         struct ggml_tensor * exp_probs_b,
+static ggml_tensor * llm_build_moe_ffn(
+        ggml_context * ctx,
+       llama_context & lctx,
+         ggml_tensor * cur,
+         ggml_tensor * gate_inp,   ggml_tensor * gate_inp_b,
+         ggml_tensor * up_exps,    ggml_tensor * up_exps_b,
+         ggml_tensor * gate_exps,  ggml_tensor * gate_exps_b,
+         ggml_tensor * down_exps,  ggml_tensor * down_exps_b,
+         ggml_tensor * exp_probs_b,
                     int64_t   n_expert,
                     int64_t   n_expert_used,
             llm_ffn_op_type   type_op,
@@ -10411,7 +7731,7 @@ static struct ggml_tensor * llm_build_moe_ffn(
                       float   w_scale,
 llm_expert_gating_func_type   gating_op,
          const llm_build_cb & cb,
-                        int   il) {
+                        int   il, struct ggml_cgraph * graph = nullptr) {
     int64_t n_embd = cur->ne[0];
     int64_t n_tokens = cur->ne[1];
     bool weight_before_ffn = lctx.model.arch == LLM_ARCH_LLAMA4; // for llama4, we apply the sigmoid-ed weights before the FFN
@@ -10419,6 +7739,12 @@ llm_expert_gating_func_type   gating_op,
     ggml_tensor * logits = llm_build_lora_mm(lctx, ctx, gate_inp, cur); // [n_expert, n_tokens]
     cb(logits, "ffn_moe_logits", il);
 
+    if (gate_inp_b) {
+        logits = ggml_add(ctx, logits, gate_inp_b);
+        cb(logits, "ffn_moe_logits_biased", il);
+    }
+
     //ggml_tensor * probs = ggml_soft_max(ctx, logits); // [n_expert, n_tokens]
     ggml_tensor * probs = nullptr;
     switch (gating_op) {
@@ -10430,6 +7756,10 @@ llm_expert_gating_func_type   gating_op,
             {
                 probs = ggml_sigmoid(ctx, logits); // [n_expert, n_tokens]
             } break;
+        case LLM_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT:
+            {
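+                // softmax is deferred here: top-k expert selection runs on the raw
+                // logits, and the softmax is applied afterwards over the selected experts only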
+                probs = logits; // [n_expert, n_tokens]
+            } break;
         default:
             GGML_ABORT("fatal error");
     }
@@ -10459,6 +7789,13 @@ llm_expert_gating_func_type   gating_op,
             ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
     cb(weights, "ffn_moe_weights", il);
 
+    if (gating_op == LLM_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT) {
+        weights = ggml_reshape_2d(ctx, weights, n_expert_used, n_tokens);
+        weights = ggml_soft_max(ctx, weights); // [n_expert_used, n_tokens]
+        weights = ggml_reshape_3d(ctx, weights, 1, n_expert_used, n_tokens);
+        cb(weights, "ffn_moe_weights_softmax", il);
+    }
+
     if (norm_w) {
         weights = ggml_reshape_2d(ctx, weights, n_expert_used, n_tokens);
 
@@ -10485,9 +7822,23 @@ llm_expert_gating_func_type   gating_op,
         cb(cur, "ffn_moe_weighted", il);
     }
 
+    // Expert biases are handled via ggml_moe_up_gate_ext below, so the fused op can
+    // still be used when up/gate biases are present; only the activation type matters.
+    //
+    //bool can_use_fmoe = !up_exps_b && !gate_exps_b && (type_op == LLM_FFN_SILU || type_op == LLM_FFN_GELU);
+    bool can_use_fmoe = type_op == LLM_FFN_SILU || type_op == LLM_FFN_GELU || type_op == LLM_FFN_SWIGLU_OAI_MOE;
+
     ggml_tensor * par;
-    if (lctx.cparams.fused_moe_up_gate && up_exps->type == gate_exps->type) {
-        par = ggml_moe_up_gate(ctx, up_exps, gate_exps, cur, selected_experts, type_op == LLM_FFN_SILU ? GGML_UNARY_OP_SILU : GGML_UNARY_OP_GELU);
+    if (can_use_fmoe && lctx.cparams.fused_moe_up_gate && up_exps->type == gate_exps->type) {
+        if (up_exps_b || gate_exps_b) {
+            par = ggml_moe_up_gate_ext(ctx, up_exps, gate_exps, cur, selected_experts, up_exps_b, gate_exps_b,
+                    type_op == LLM_FFN_SILU ? GGML_UNARY_OP_SILU :
+                    type_op == LLM_FFN_GELU ? GGML_UNARY_OP_GELU : GGML_UNARY_OP_SWIGLU_OAI);
+        } else {
+            GGML_ASSERT(type_op != LLM_FFN_SWIGLU_OAI_MOE);
+            par = ggml_moe_up_gate(ctx, up_exps, gate_exps, cur, selected_experts,
+                    type_op == LLM_FFN_SILU ? GGML_UNARY_OP_SILU : GGML_UNARY_OP_GELU);
+        }
     } else {
         ggml_tensor * up = llm_build_lora_mm_id(lctx, ctx, up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
         cb(up, "ffn_moe_up", il);
@@ -10495,41 +7846,49 @@ llm_expert_gating_func_type   gating_op,
         ggml_tensor * gate = llm_build_lora_mm_id(lctx, ctx, gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
         cb(gate, "ffn_moe_gate", il);
 
-        // This is equivalent to the commented out code below
-        par = ggml_fused_mul_unary(ctx, gate, up, type_op == LLM_FFN_SILU ? GGML_UNARY_OP_SILU : GGML_UNARY_OP_GELU);
+        if (graph) {
+            // So we can potentially fuse the up and gate mul_mat_id
+            ggml_build_forward_expand(graph, up);
+            ggml_build_forward_expand(graph, gate);
+        }
+
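+        // ggml_add_id adds, per token, the bias row of each selected expert
+        // (indexed via selected_experts) to the matching slice of the activations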
+        if (up_exps_b) {
+            up = ggml_add_id(ctx, up, up_exps_b, selected_experts);
+            cb(up, "ffn_moe_up_biased", il);
+        }
+
+        if (gate_exps_b) {
+            gate = ggml_add_id(ctx, gate, gate_exps_b, selected_experts);
+            cb(gate, "ffn_moe_gate_biased", il);
+        }
+
+        if (type_op == LLM_FFN_SILU || type_op == LLM_FFN_GELU) {
+            par = ggml_fused_mul_unary(ctx, gate, up, type_op == LLM_FFN_SILU ? GGML_UNARY_OP_SILU : GGML_UNARY_OP_GELU);
+        } else if (type_op == LLM_FFN_SWIGLU_OAI_MOE) {
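+            // clamped SwiGLU as used by gpt-oss (alpha/limit per the gpt-oss reference):
+            // gate is clamped from above to `limit`, up to [-limit, limit], then
+            // out = gate * sigmoid(alpha * gate) * (up + 1)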
+            constexpr float alpha = 1.702f;
+            constexpr float limit = 7.0f;
+            par = ggml_swiglu_oai(ctx, gate, up, alpha, limit);
+        } else {
+            GGML_ABORT("fatal error");
+        }
     }
     cb(par, "ffn_moe_gate_par", il);
 
     ggml_tensor * experts = llm_build_lora_mm_id(lctx, ctx, down_exps, par, selected_experts); // [n_embd, n_expert_used, n_tokens]
     cb(experts, "ffn_moe_down", il);
 
+    if (down_exps_b) {
+        experts = ggml_add_id(ctx, experts, down_exps_b, selected_experts);
+        cb(experts, "ffn_moe_down_biased", il);
+    }
+
     if (!weight_before_ffn) {
         experts = ggml_mul(ctx, experts, weights);
         cb(cur, "ffn_moe_weighted", il);
     }
 
-//#ifdef GGML_USE_VULKAN
-//    // aggregate experts
-//    ggml_tensor * moe_out = nullptr;
-//    //ggml_tensor * first_expert = nullptr;
-//    for (int i = 0; i < n_expert_used; ++i) {
-//        ggml_tensor * cur_expert = ggml_view_2d(ctx, experts, n_embd, n_tokens,
-//                experts->nb[2], i*experts->nb[1]);
-//
-//        if (i == 0) {
-//            moe_out = cur_expert;
-//        } else {
-//            moe_out = ggml_add(ctx, moe_out, cur_expert);
-//        }
-//    }
-//
-//    if (n_expert_used == 1) {
-//        // avoid returning a non-contiguous tensor
-//        moe_out = ggml_cont(ctx, moe_out);
-//    }
-//
-//    return moe_out;
-//#else
     if (n_expert_used == 1) {
         return ggml_cont(ctx, ggml_view_2d(ctx, experts, n_embd, n_tokens, experts->nb[2], 0));
     }
@@ -10538,10 +7897,38 @@ llm_expert_gating_func_type   gating_op,
                              ggml_view_2d(ctx, experts, n_embd, n_tokens, experts->nb[2], experts->nb[1]));
     }
     return ggml_multi_add(ctx, ggml_view_2d(ctx, experts, n_embd, n_tokens, experts->nb[2], 0), n_expert_used);
-//#endif
 
 }
 
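+// Convenience overload for architectures without expert biases: forwards nullptr for
+// all bias tensors to the extended version above.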
+static ggml_tensor * llm_build_moe_ffn(
+        struct ggml_context * ctx,
+       struct llama_context & lctx,
+         struct ggml_tensor * cur,
+         struct ggml_tensor * gate_inp,
+         struct ggml_tensor * up_exps,
+         struct ggml_tensor * gate_exps,
+         struct ggml_tensor * down_exps,
+         struct ggml_tensor * exp_probs_b,
+                    int64_t   n_expert,
+                    int64_t   n_expert_used,
+            llm_ffn_op_type   type_op,
+                       bool   norm_w,
+                       bool   scale_w,
+                      float   w_scale,
+llm_expert_gating_func_type   gating_op,
+         const llm_build_cb & cb,
+                        int   il, struct ggml_cgraph * graph = nullptr) {
+    return llm_build_moe_ffn(ctx, lctx, cur,
+            gate_inp,   nullptr,
+            up_exps,    nullptr,
+            gate_exps,  nullptr,
+            down_exps,  nullptr,
+            exp_probs_b,
+            n_expert, n_expert_used,
+            type_op, norm_w, scale_w, w_scale,
+            gating_op, cb, il, graph);
+}
+
 static struct ggml_tensor * llm_build_kqv(
         struct ggml_context * ctx,
        struct llama_context & lctx,
@@ -10555,7 +7942,8 @@ static struct ggml_tensor * llm_build_kqv(
                     int32_t   n_kv,
                     float     kq_scale,
          const llm_build_cb & cb,
-                    int       il) {
+                    int       il,
+                ggml_tensor * sinks = nullptr) {
     const llama_model   & model   = lctx.model;
     const llama_hparams & hparams = lctx.model.hparams;
     const llama_cparams & cparams = lctx.cparams;
@@ -10602,6 +7990,7 @@ static struct ggml_tensor * llm_build_kqv(
 
         cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
                                   hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
+        ggml_flash_attn_ext_add_sinks(cur, sinks);
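+        // (no-op when sinks == nullptr, which is the default for all other callers)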
 
         // Some models produced NaNs/gibberish when FA is computed with f16 precision on CUDA
         // For DeepSeek-2, it is perfectly fine with fp16 for PP, but I get gibberish when uding fp16 for TG.
@@ -10625,7 +8014,7 @@ static struct ggml_tensor * llm_build_kqv(
         cb(v, "v", il);
 
         auto kq_size = k->ne[1]*q->ne[1]*q->ne[2]*sizeof(float)/(1024*1024);
-        if (cparams.attn_max_batch == 0 || cparams.attn_max_batch >= kq_size || k->ne[2] != q->ne[2] || v->ne[2] != q->ne[2]) {
+        if (cparams.attn_max_batch == 0 || cparams.attn_max_batch >= kq_size || k->ne[2] != q->ne[2] || v->ne[2] != q->ne[2] || sinks) {
             struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
             cb(kq, "kq", il);
 
@@ -10660,6 +8049,7 @@ static struct ggml_tensor * llm_build_kqv(
                         1.0f / hparams.f_attn_logit_softcapping, hparams.f_attn_logit_softcapping);
             } else {
                 kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
+                ggml_soft_max_add_sinks(kq, sinks);
             }
             cb(kq, "kq_soft_max_ext", il);
 
@@ -10757,7 +8147,8 @@ static struct ggml_tensor * llm_build_kv(
                     int32_t   n_kv,
                     float     kq_scale,
          const llm_build_cb & cb,
-                    int       il) {
+                    int       il,
+                ggml_tensor * sinks = nullptr) {
     const llama_hparams & hparams = lctx.model.hparams;
     const llama_cparams & cparams = lctx.cparams;
 
@@ -10772,7 +8163,7 @@ static struct ggml_tensor * llm_build_kv(
     struct ggml_tensor * cur;
 
     cur  = llm_build_kqv(ctx, lctx, kv, graph, wo, wo_b,
-            q_cur, kq_mask, n_tokens, n_kv, kq_scale, cb, il);
+            q_cur, kq_mask, n_tokens, n_kv, kq_scale, cb, il, sinks);
     cb(cur, "kqv_out", il);
 
     return cur;
@@ -16406,24 +13797,24 @@ struct llm_build_context {
     struct ggml_cgraph * build_glm4_moe() {
         // create a new graph
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
-    
+
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-    
+
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
-    
+
         // input embeddings
         inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
-    
+
         // position embeddings
         struct ggml_tensor * inp_pos = build_inp_pos();
-    
+
         // attention KV cache input
         //auto * inp_attn = build_attn_inp_kv_unified();
 
         struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-    
+
         // output token IDs (for last layer cropping)
         struct ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -16432,13 +13823,13 @@ struct llm_build_context {
         const int n_transformer_layers = n_layer - hparams.nextn_predict_layers;
         for (int il = 0; il < n_transformer_layers; ++il) {
             struct ggml_tensor * inpSA = inpL;
-    
+
             // Pre-attention norm
             cur = llm_build_norm(ctx0, inpL, hparams,
                                  model.layers[il].attn_norm, NULL,
                                  LLM_NORM_RMS, cb, il);
             cb(cur, "attn_norm", il);
-    
+
             // self-attention
             {
                 // Q, K, V projections
@@ -16447,24 +13838,24 @@ struct llm_build_context {
                     Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
                 }
                 cb(Qcur, "Qcur", il);
-    
+
                 struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
                 if (model.layers[il].bk) {
                     Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
                 }
                 cb(Kcur, "Kcur", il);
-    
+
                 struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
                 if (model.layers[il].bv) {
                     Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
                 }
                 cb(Vcur, "Vcur", il);
-    
+
                 // reshape for multi-head
                 Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
                 Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
                 // Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
-    
+
                 // Apply Q/K norm if available (GLM-4.5 355B variant)
                 if (model.layers[il].attn_q_norm) {
                     Qcur = llm_build_norm(ctx0, Qcur, hparams,
@@ -16478,7 +13869,7 @@ struct llm_build_context {
                                          LLM_NORM_RMS, cb, il);
                     cb(Kcur, "Kcur_normed", il);
                 }
-    
+
                 // apply RoPE
                 Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr,
                                      n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
@@ -16489,7 +13880,7 @@ struct llm_build_context {
                 cb(Qcur, "Qcur", il);
                 cb(Kcur, "Kcur", il);
                 cb(Vcur, "Vcur", il);
-    
+
                 // build attention KV (no unified cache)
                 cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                                    model.layers[il].wo, NULL,
@@ -16497,7 +13888,7 @@ struct llm_build_context {
                                    n_tokens, kv_head, n_kv,
                                    1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
-    
+
             // crop output on last layer
             if (il == n_transformer_layers - 1 && inp_out_ids) {
                 // skip computing output for unused tokens
@@ -16505,17 +13896,17 @@ struct llm_build_context {
                 cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
                 inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
             }
-    
+
             // residual connection for attention output
             struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
             cb(ffn_inp, "ffn_inp", il);
-    
+
             // Post-attention norm
             cur = llm_build_norm(ctx0, ffn_inp, hparams,
                                  model.layers[il].attn_post_norm, NULL,
                                  LLM_NORM_RMS, cb, il);
             cb(cur, "post_attn_norm", il);
-    
+
             if ((uint32_t) il < hparams.n_layer_dense_lead) {
                 // dense FFN
                 cur = llm_build_ffn(ctx0, lctx, cur,
@@ -16548,29 +13939,29 @@ struct llm_build_context {
                                                 NULL,
                                                 LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
                     cb(shared_out, "ffn_shexp_out", il);
-        
+
                     cur = ggml_add(ctx0, routed_out, shared_out);
                     cb(cur, "ffn_out", il);
                 }
             }
-    
+
             // residual and context vector
             cur = ggml_add(ctx0, cur, ffn_inp);
             cur = lctx.cvec.apply_to(ctx0, cur, il);
             cb(cur, "l_out", il);
-    
+
             // prepare next layer input
             inpL = cur;
         }
 
         cur = inpL;
-    
+
         // final norm
         cur = llm_build_norm(ctx0, cur, hparams,
                              model.output_norm, NULL,
                              LLM_NORM_RMS, cb, -1);
         cb(cur, "result_norm", -1);
-    
+
         // lm head
         cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output", -1);
@@ -17990,6 +15381,126 @@ struct llm_build_context {
 
         return gf;
     }
+
+    struct ggml_cgraph * build_openai_moe() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        struct ggml_tensor * KQ_mask     = build_inp_KQ_mask();
+        struct ggml_tensor * KQ_mask_swa = build_inp_KQ_mask_swa();
+        //const int64_t n_embd_head = hparams.n_embd_head_v;
+        const float kq_scale = 1.0f / sqrtf(float(n_rot)); //float(n_embd_head));
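+        // the head size equals the rotary dimension for this architecture (see the
+        // Q/K reshapes below), hence n_rot in the scale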
+
+        //auto * inp_attn = build_attn_inp_kv_unified_iswa();
+
+        const int sliding_window_pattern = 2;
+
+        for (int il = 0; il < n_layer; ++il) {
+            const bool is_sliding = il % sliding_window_pattern < (sliding_window_pattern - 1);
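+            // with a pattern of 2, even layers use the sliding-window mask and odd layers see the full context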
+            ggml_tensor * inpSA = inpL;
+
+            struct ggml_tensor * KQ_mask_l = is_sliding ? KQ_mask_swa : KQ_mask;
+
+            // norm
+            cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, cb, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                if (model.layers[il].bq) {
+                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                    cb(Qcur, "Qcur", il);
+                }
+
+                ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                if (model.layers[il].bk) {
+                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                if (model.layers[il].bv) {
+                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                    cb(Vcur, "Vcur", il);
+                }
+
+                Qcur = ggml_rope_ext(ctx0, ggml_reshape_3d(ctx0, Qcur, n_rot, n_head, n_tokens), inp_pos, nullptr,
+                                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor,
+                                    beta_fast, beta_slow);
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_ext(ctx0, ggml_reshape_3d(ctx0, Kcur, n_rot, n_head_kv, n_tokens), inp_pos, nullptr,
+                                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor,
+                                    attn_factor, beta_fast, beta_slow);
+                cb(Kcur, "Kcur", il);
+
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo,
+                        Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, kq_scale, cb, il, model.layers[il].attn_sinks);
+
+                cb(cur, "attn_out", il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            cur = ffn_inp;
+            cur = llm_build_norm(ctx0, cur,  hparams, model.layers[il].attn_post_norm, nullptr, LLM_NORM_RMS, cb, il);
+            cb(cur, "attn_post_norm", il);
+
+            // MoE branch
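+            // (gpt-oss has no dense FFN layers and no shared experts: every layer routes through the experts)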
+            cur = llm_build_moe_ffn(ctx0, lctx, cur,
+                    model.layers[il].ffn_gate_inp,  model.layers[il].ffn_gate_inp_b,
+                    model.layers[il].ffn_up_exps,   model.layers[il].ffn_up_exps_b,
+                    model.layers[il].ffn_gate_exps, model.layers[il].ffn_gate_exps_b,
+                    model.layers[il].ffn_down_exps, model.layers[il].ffn_down_exps_b,
+                    nullptr,
+                    n_expert, n_expert_used,
+                    LLM_FFN_SWIGLU_OAI_MOE, false,
+                    false, 0.0,
+                    LLM_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT,
+                    cb, il, gf);
+            cb(cur, "ffn_moe_out", il);
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, nullptr, LLM_NORM_RMS, cb, -1);
+        cb(cur, "result_norm", -1);
+
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
+        cb(cur, "result_output", -1);
+
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
 };
 
 static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector & ids) {
@@ -18086,9 +15597,9 @@ static struct ggml_cgraph * llama_build_graph(
 
     struct ggml_cgraph * result = NULL;
 
-    const llama_vocab * vocab = llama_get_vocab(&lctx);
-    llama_token bos = llama_token_bos_impl(*vocab);
-    llama_token eos = llama_token_eos_impl(*vocab);
+    const llama_vocab * vocab = &lctx.model.vocab; //llama_get_vocab(&lctx);
+    llama_token bos = vocab->token_bos();
+    llama_token eos = vocab->token_eos();
     bool is_warming_up = lctx.n_eval == 0 && (batch.n_tokens == 1 && (batch.token[0] == ((bos != -1) ? bos : eos)));
     struct llm_build_context llm(lctx, batch, cb, worst_case, is_warming_up);
 
@@ -18297,6 +15808,10 @@ static struct ggml_cgraph * llama_build_graph(
             {
                 result = llm.build_hunyuan_moe();
             } break;
+        case LLM_ARCH_OPENAI_MOE:
+            {
+                result = llm.build_openai_moe();
+            } break;
         default:
             GGML_ABORT("fatal error");
     }
@@ -22018,7 +19533,7 @@ uint32_t llama_n_seq_max(const struct llama_context * ctx) {
 }
 
 enum llama_vocab_type llama_vocab_type(const struct llama_model * model) {
-    return model->vocab.type;
+    return model->vocab.get_type();
 }
 
 const struct llama_vocab* llama_get_model_vocab(const struct llama_model* model) {
@@ -22089,6 +19604,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
         case LLM_ARCH_CODESHELL:
         case LLM_ARCH_DOTS1:
         case LLM_ARCH_HUNYUAN_MOE:
+        case LLM_ARCH_OPENAI_MOE:
             return LLAMA_ROPE_TYPE_NEOX;
 
         // all model arches should be listed explicitly here
@@ -22189,7 +19705,7 @@ const char* llama_model_chat_template(const struct llama_model* model, const cha
         // one-off fix for very popular models (so we are not flooded with issues)
         // do not extend this list unless absolutely necessary
         // Mistral-Small-2503 does not have built-in chat template
-        llama_vocab_pre_type pre_type = model->vocab.type_pre;
+        llama_vocab_pre_type pre_type = model->vocab.get_pre_type();
         if (!name && pre_type == LLAMA_VOCAB_PRE_TYPE_TEKKEN && model->layers.size() == 40) {
             return "mistral-v7-tekken";
         }
@@ -23307,7 +20823,7 @@ static bool llama_state_load_file_internal(struct llama_context * ctx, const cha
 
     // restore the context state
     {
-        const size_t n_state_size_cur = file.size - file.tell();
+        const size_t n_state_size_cur = file.size() - file.tell();
 
         llama_data_read_file data_ctx(&file);
         const size_t n_read = llama_state_set_data_internal(ctx, data_ctx);
@@ -23444,7 +20960,7 @@ static size_t llama_state_seq_load_file_internal(struct llama_context * ctx, con
 
     // restore the context state
     {
-        const size_t state_size = file.size - file.tell();
+        const size_t state_size = file.size() - file.tell();
         llama_data_read_file data_ctx(&file);
         const size_t nread = llama_state_seq_set_data_internal(ctx, data_ctx, dest_seq_id);
         if (!nread) {
@@ -23710,101 +21226,102 @@ float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id
 //
 
 const char * llama_token_get_text(const struct llama_model * model, llama_token token) {
-    return llama_token_get_text_impl(model->vocab, token);
+    return model->vocab.token_get_text(token);
 }
 
 float llama_token_get_score(const struct llama_model * model, llama_token token) {
-    return llama_token_get_score_impl(model->vocab, token);
+    return model->vocab.token_get_score(token);
 }
 
 enum llama_token_attr llama_token_get_attr(const struct llama_model * model, llama_token token) {
-    return llama_token_get_attr_impl(model->vocab, token);
+    return model->vocab.token_get_attr(token);
 }
 
 bool llama_token_is_eog(const struct llama_model * model, llama_token token) {
-    return llama_token_is_eog_impl(model->vocab, token);
+    return model->vocab.is_eog(token);
 }
 
 bool llama_token_is_control(const struct llama_model * model, llama_token token) {
-    return llama_token_is_control_impl(model->vocab, token);
+    return model->vocab.is_control(token);
 }
 
 llama_token llama_token_bos(const struct llama_model * model) {
-    return llama_token_bos_impl(model->vocab);
+    return model->vocab.token_bos();
 }
 
 llama_token llama_token_eos(const struct llama_model * model) {
-    return llama_token_eos_impl(model->vocab);
+    return model->vocab.token_eos();
 }
 
-llama_token llama_token_cls(const struct llama_model * model) {
-    return llama_token_cls_impl(model->vocab);
-}
+// What is cls?
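+// (CLS is the BERT-style classification token; upstream llama.cpp later deprecated this accessor
+// and points callers at the BOS token instead.)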
+//llama_token llama_token_cls(const struct llama_model * model) {
+//    return llama_token_cls_impl(model->vocab);
+//}
 
 llama_token llama_token_sep(const struct llama_model * model) {
-    return llama_token_sep_impl(model->vocab);
+    return model->vocab.token_sep();
 }
 
 llama_token llama_token_nl (const struct llama_model * model) {
-    return llama_token_nl_impl(model->vocab);
+    return model->vocab.token_nl();
 }
 
 llama_token llama_token_pad(const struct llama_model * model) {
-    return llama_token_pad_impl(model->vocab);
+    return model->vocab.token_pad();
 }
 
 int32_t llama_add_bos_token(const struct llama_model * model) {
-    return llama_add_bos_token_impl(model->vocab);
+    return model->vocab.get_add_bos();
 }
 
 int32_t llama_add_eos_token(const struct llama_model * model) {
-    return llama_add_eos_token_impl(model->vocab);
+    return model->vocab.get_add_eos();
 }
 
 llama_token llama_token_prefix(const struct llama_model * model) {
-    return llama_token_prefix_impl(model->vocab);
+    return model->vocab.token_prefix();
 }
 
 llama_token llama_token_middle(const struct llama_model * model) {
-    return llama_token_middle_impl(model->vocab);
+    return model->vocab.token_middle();
 }
 
 llama_token llama_token_suffix(const struct llama_model * model) {
-    return llama_token_suffix_impl(model->vocab);
+    return model->vocab.token_suffix();
 }
 
 llama_token llama_token_eot(const struct llama_model * model) {
-    return llama_token_eot_impl(model->vocab);
+    return model->vocab.token_eot();
 }
 
 // deprecated
 llama_token llama_token_fim_pre(const struct llama_model * model) {
-    return llama_token_fim_pre_impl(model->vocab);
+    return model->vocab.token_fim_pre();
 }
 
 // deprecated
 llama_token llama_token_fim_suf(const struct llama_model * model) {
-    return llama_token_fim_suf_impl(model->vocab);
+    return model->vocab.token_fim_suf();
 }
 
 // deprecated
 llama_token llama_token_fim_mid(const struct llama_model * model) {
-    return llama_token_fim_mid_impl(model->vocab);
+    return model->vocab.token_fim_mid();
 }
 
 // deprecated
 llama_token llama_token_fim_pad(const struct llama_model * model) {
-    return llama_token_fim_pad_impl(model->vocab);
+    return model->vocab.token_fim_pad();
 }
 
 // deprecated
 llama_token llama_token_fim_rep(const struct llama_model * model) {
-    return llama_token_fim_rep_impl(model->vocab);
+    return model->vocab.token_fim_rep();
 }
 
 // deprecated
 llama_token llama_token_fim_sep(const struct llama_model * model) {
-    return llama_token_fim_sep_impl(model->vocab);
+    return model->vocab.token_fim_sep();
 }
 
 //
@@ -23819,7 +21336,7 @@ int32_t llama_tokenize(
                      int32_t   n_tokens_max,
                         bool   add_special,
                         bool   parse_special) {
-    return llama_tokenize_impl(model->vocab, text, text_len, tokens, n_tokens_max, add_special, parse_special);
+    return model->vocab.tokenize(text, text_len, tokens, n_tokens_max, add_special, parse_special);
 }
 
 int32_t llama_token_to_piece(
@@ -23829,7 +21346,7 @@ int32_t llama_token_to_piece(
                      int32_t   length,
                      int32_t   lstrip,
                         bool   special) {
-    return llama_token_to_piece_impl(model->vocab, token, buf, length, lstrip, special);
+    return model->vocab.token_to_piece(token, buf, length, lstrip, special);
 }
 
 int32_t llama_detokenize(
@@ -23840,7 +21357,7 @@ int32_t llama_detokenize(
                      int32_t   text_len_max,
                         bool   remove_special,
                         bool   unparse_special) {
-    return llama_detokenize_impl(model->vocab, tokens, n_tokens, text, text_len_max, remove_special, unparse_special);
+    return model->vocab.detokenize(tokens, n_tokens, text, text_len_max, remove_special, unparse_special);
 }
 
 //
@@ -23956,6 +21473,8 @@ static llm_chat_template llama_chat_detect_template(const std::string & tmpl) {
         return LLM_CHAT_TEMPLATE_HUNYUAN_MOE;
     } else if (tmpl_contains("<|im_middle|>") && tmpl_contains("<|im_end|>")) {
         return LLM_CHAT_TEMPLATE_KIMI_K2;
+    } else if (tmpl_contains("<|start|>") && tmpl_contains("<|channel|>")) {
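+        // gpt-oss "Harmony" templates are recognized by their <|start|>/<|channel|> special tokens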
+        return LLM_CHAT_TEMPLATE_OPENAI_MOE;
     }
     return LLM_CHAT_TEMPLATE_UNKNOWN;
 }
@@ -24416,6 +21935,16 @@ static int32_t llama_chat_apply_template_internal(
         if (add_ass) {
             ss << "<|im_assistant|>assistant<|im_middle|>";
         }
+    } else if (tmpl == LLM_CHAT_TEMPLATE_OPENAI_MOE) {
+        // OpenAI MoE (based on Harmony chat template)
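+        // a single user turn thus renders as "<|start|>user<|message|>Hi<|end|><|start|>assistant";
+        // the <|channel|> header that full Harmony replies carry is left for the model to emit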
+        for (auto message : chat) {
+            std::string role(message->role);
+            ss << "<|start|>" << role << "<|message|>" << message->content;
+            ss << (role == "assistant" ? "<|return|>" : "<|end|>");
+        }
+        if (add_ass) {
+            ss << "<|start|>assistant";
+        }
     } else {
         // template not supported
         return -1;
diff --git a/src/unicode.cpp b/src/unicode.cpp
index c911fd262..65f366517 100644
--- a/src/unicode.cpp
+++ b/src/unicode.cpp
@@ -5,20 +5,19 @@
 #include "unicode.h"
 #include "unicode-data.h"
 
+#include 
 #include 
+#include 
 #include 
 #include 
+#include 
 #include 
 #include 
 #include 
 #include 
 #include 
-#include 
 #include 
 #include 
-#include 
-#include 
-#include 
 
 size_t unicode_len_utf8(char src) {
     const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
@@ -26,7 +25,7 @@ size_t unicode_len_utf8(char src) {
     return lookup[highbits];
 }
 
-static std::string unicode_cpts_to_utf8(const std::vector<uint32_t>& cps) {
+static std::string unicode_cpts_to_utf8(const std::vector<uint32_t> & cps) {
     std::string result;
     for (size_t i = 0; i < cps.size(); ++i) {
         result.append(unicode_cpt_to_utf8(cps[i]));
@@ -34,7 +33,7 @@ static std::string unicode_cpts_to_utf8(const std::vector& cps) {
     return result;
 }
 
-uint32_t unicode_cpt_from_utf8(const std::string& utf8, size_t& offset) {
+uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset) {
     assert(offset < utf8.size());
     if (!(utf8[offset + 0] & 0x80)) {
         auto result = utf8[offset + 0];
@@ -45,7 +44,7 @@ uint32_t unicode_cpt_from_utf8(const std::string& utf8, size_t& offset) {
         throw std::invalid_argument("invalid character");
     }
     if (!(utf8[offset + 0] & 0x20)) {
-        if (offset + 1 >= utf8.size() || !((utf8[offset + 1] & 0xc0) == 0x80)) {
+        if (offset + 1 >= utf8.size() || ! ((utf8[offset + 1] & 0xc0) == 0x80)) {
             throw std::invalid_argument("invalid character");
         }
         auto result = ((utf8[offset + 0] & 0x1f) << 6) | (utf8[offset + 1] & 0x3f);
@@ -53,7 +52,7 @@ uint32_t unicode_cpt_from_utf8(const std::string& utf8, size_t& offset) {
         return result;
     }
     if (!(utf8[offset + 0] & 0x10)) {
-        if (offset + 2 >= utf8.size() || !((utf8[offset + 1] & 0xc0) == 0x80) || !((utf8[offset + 2] & 0xc0) == 0x80)) {
+        if (offset + 2 >= utf8.size() || ! ((utf8[offset + 1] & 0xc0) == 0x80) || ! ((utf8[offset + 2] & 0xc0) == 0x80)) {
             throw std::invalid_argument("invalid character");
         }
         auto result = ((utf8[offset + 0] & 0x0f) << 12) | ((utf8[offset + 1] & 0x3f) << 6) | (utf8[offset + 2] & 0x3f);
@@ -61,7 +60,7 @@ uint32_t unicode_cpt_from_utf8(const std::string& utf8, size_t& offset) {
         return result;
     }
     if (!(utf8[offset + 0] & 0x08)) {
-        if (offset + 3 >= utf8.size() || !((utf8[offset + 1] & 0xc0) == 0x80) || !((utf8[offset + 2] & 0xc0) == 0x80) || !((utf8[offset + 3] & 0xc0) == 0x80)) {
+        if (offset + 3 >= utf8.size() || ! ((utf8[offset + 1] & 0xc0) == 0x80) || ! ((utf8[offset + 2] & 0xc0) == 0x80) || !((utf8[offset + 3] & 0xc0) == 0x80)) {
             throw std::invalid_argument("invalid character");
         }
         auto result = ((utf8[offset + 0] & 0x07) << 18) | ((utf8[offset + 1] & 0x3f) << 12) | ((utf8[offset + 2] & 0x3f) << 6) | (utf8[offset + 3] & 0x3f);
@@ -71,15 +70,15 @@ uint32_t unicode_cpt_from_utf8(const std::string& utf8, size_t& offset) {
     throw std::invalid_argument("failed to convert utf8 to codepoint");
 }
 
-//static std::vector<uint16_t> unicode_cpt_to_utf16(uint32_t cp) {
+//static std::vector<uint16_t> unicode_cpt_to_utf16(uint32_t cpt) {
 //    std::vector<uint16_t> result;
-//    if (/* 0x0000 <= cp && */ cp <= 0xffff) {
-//        result.emplace_back(cp);
+//    if (/* 0x0000 <= cpt && */ cpt <= 0xffff) {
+//        result.emplace_back(cpt);
 //        return result;
 //    }
-//    if (0x10000 <= cp && cp <= 0x10ffff) {
-//        result.emplace_back(0xd800 | ((cp - 0x10000) >> 10));
-//        result.emplace_back(0xdc00 | ((cp - 0x10000) & 0x03ff));
+//    if (0x10000 <= cpt && cpt <= 0x10ffff) {
+//        result.emplace_back(0xd800 | ((cpt - 0x10000) >> 10));
+//        result.emplace_back(0xdc00 | ((cpt - 0x10000) & 0x03ff));
 //        return result;
 //    }
 //    throw std::invalid_argument("failed to convert codepoint to utf16");
@@ -120,14 +119,14 @@ uint32_t unicode_cpt_from_utf8(const std::string& utf8, size_t& offset) {
 //    return result;
 //}
 
-static std::vector<codepoint_flags> unicode_cpt_flags_array() {
-    std::vector<codepoint_flags> cpt_flags(MAX_CODEPOINTS, codepoint_flags::UNDEFINED);
+static std::vector<unicode_cpt_flags> unicode_cpt_flags_array() {
+    std::vector<unicode_cpt_flags> cpt_flags(MAX_CODEPOINTS, unicode_cpt_flags::UNDEFINED);
 
-    assert(unicode_ranges_flags.front().first == 0);
-    assert(unicode_ranges_flags.back().first == MAX_CODEPOINTS);
+    assert (unicode_ranges_flags.begin()[0].first == 0);
+    assert (unicode_ranges_flags.begin()[unicode_ranges_flags.size()-1].first == MAX_CODEPOINTS);
     for (size_t i = 1; i < unicode_ranges_flags.size(); ++i) {
-        const auto range_ini = unicode_ranges_flags[i - 1];  // codepoint_ini, flags
-        const auto range_end = unicode_ranges_flags[i];    // codepoint_end, flags
+        const auto range_ini = unicode_ranges_flags.begin()[i-1];  // codepoint_ini, flags
+        const auto range_end = unicode_ranges_flags.begin()[i];    // codepoint_end, flags
         for (uint32_t cpt = range_ini.first; cpt < range_end.first; ++cpt) {
             cpt_flags[cpt] = range_ini.second;
         }
@@ -145,7 +144,7 @@ static std::vector unicode_cpt_flags_array() {
         cpt_flags[p.second].is_uppercase = true;
     }
 
-    for (auto& range : unicode_ranges_nfd) {  // start, last, nfd
+    for (auto &range : unicode_ranges_nfd) {  // start, last, nfd
         cpt_flags[range.nfd].is_nfd = true;
     }
 
@@ -200,55 +199,38 @@ static std::unordered_map unicode_utf8_to_byte_map() {
     return map;
 }
 
-static inline bool  is_valid_utf8(const std::string& str) {
-    int remaining_bytes = 0; // bytes still expected in the current multi-byte character
-    for (unsigned char c : str) {
-        if (remaining_bytes == 0) {
-            if ((c & 0x80) == 0x00) continue;          // 1-byte character
-            else if ((c & 0xE0) == 0xC0) remaining_bytes = 1; // 2-byte
-            else if ((c & 0xF0) == 0xE0) remaining_bytes = 2; // 3-byte
-            else if ((c & 0xF8) == 0xF0) remaining_bytes = 3; // 4-byte
-            else return false; // invalid leading byte
-        }
-        else {
-            // check that continuation bytes match 10xxxxxx
-            if ((c & 0xC0) != 0x80)
-            {
-                return false;
-            }
-            remaining_bytes--;
-        }
-    }
-    return (remaining_bytes == 0); // make sure the final multi-byte character is complete
-}
-
-static inline std::wstring unicode_wstring_from_utf8(const std::string& s) {
+static inline std::wstring unicode_wstring_from_utf8(const std::string & s) {
 #if defined(__clang__)
     // disable C++17 deprecation warning for std::codecvt_utf8
 #    pragma clang diagnostic push
 #    pragma clang diagnostic ignored "-Wdeprecated-declarations"
+#elif defined(__GNUC__)
+#    pragma GCC diagnostic push
+#    pragma GCC diagnostic ignored "-Wdeprecated-declarations"
 #endif
-    bool isvalid = is_valid_utf8(s);
+
     std::wstring_convert<std::codecvt_utf8<wchar_t>> conv;
 
 #if defined(__clang__)
 #    pragma clang diagnostic pop
+#elif defined(__GNUC__)
+#    pragma GCC diagnostic pop
 #endif
 
     return conv.from_bytes(s);
 }
 
-static std::vector<std::string> unicode_byte_encoding_process(const std::vector<std::string>& bpe_words) {
+static std::vector<std::string> unicode_byte_encoding_process(const std::vector<std::string> & bpe_words) {
     std::vector<std::string> bpe_encoded_words;
-    for (const auto& word : bpe_words) {
+    for (const auto & word : bpe_words) {
         std::string text_utf;
-        auto utf_word = unicode_cpts_from_utf8(word);
+        auto utf_word =  unicode_cpts_from_utf8(word);
         for (size_t i = 0; i < utf_word.size(); ++i) {
             text_utf += unicode_cpt_to_utf8(utf_word[i]);
         }
 
         std::string encoded_token;
-        for (char& c : text_utf) {
+        for (char & c : text_utf) {
             encoded_token += unicode_byte_to_utf8(c);
         }
         bpe_encoded_words.emplace_back(encoded_token);
@@ -257,7 +239,7 @@ static std::vector unicode_byte_encoding_process(const std::vector<
 }
 
 // GPT2 system regex:  's|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+
-static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string& text, const std::vector<size_t>& offsets) {
+static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string & text, const std::vector<size_t> & offsets) {
     std::vector<size_t> bpe_offsets; // store the offset of each word
     bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
 
@@ -271,16 +253,16 @@ static std::vector unicode_regex_split_custom_gpt2(const std::string& te
         start = offset_end;
 
         static const uint32_t OUT_OF_RANGE = 0xFFFFFFFF;
-        auto _get_cpt = [&](const size_t pos) -> uint32_t {
+        auto _get_cpt = [&] (const size_t pos) -> uint32_t {
             return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE;
         };
 
-        auto _get_flags = [&](const size_t pos) -> codepoint_flags {
-            return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : codepoint_flags{};
+        auto _get_flags = [&] (const size_t pos) -> unicode_cpt_flags {
+            return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags_from_cpt(cpts[pos]) : unicode_cpt_flags{};
         };
 
         size_t _prev_end = offset_ini;
-        auto _add_token = [&](const size_t end) -> size_t {
+        auto _add_token = [&] (const size_t end) -> size_t {
             assert(_prev_end <= end && end <= offset_end);
             size_t len = end - _prev_end;
             if (len > 0) {
@@ -296,29 +278,29 @@ static std::vector unicode_regex_split_custom_gpt2(const std::string& te
             return len;
         };
 
-        for (size_t pos = offset_ini; pos < offset_end; /*pos++*/) {
+        for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) {
             const uint32_t cpt = _get_cpt(pos);
             const auto flags = _get_flags(pos);
 
             // regex: 's|'t|'re|'ve|'m|'ll|'d
-            if (cpt == '\'' && pos + 1 < offset_end) {
-                uint32_t cpt_next = _get_cpt(pos + 1);
+            if (cpt == '\'' && pos+1 < offset_end) {
+                uint32_t cpt_next = _get_cpt(pos+1);
                 if (cpt_next == 's' || cpt_next == 't' || cpt_next == 'm' || cpt_next == 'd') {
-                    pos += _add_token(pos + 2);
+                    pos += _add_token(pos+2);
                     continue;
                 }
-                if (pos + 2 < offset_end) {
-                    uint32_t cpt_next_next = _get_cpt(pos + 2);
+                if (pos+2 < offset_end) {
+                    uint32_t cpt_next_next = _get_cpt(pos+2);
                     if ((cpt_next == 'r' && cpt_next_next == 'e') ||
                         (cpt_next == 'v' && cpt_next_next == 'e') ||
                         (cpt_next == 'l' && cpt_next_next == 'l')) {
-                        pos += _add_token(pos + 3);
+                        pos += _add_token(pos+3);
                         continue;
                     }
                 }
             }
 
-            auto flags2 = (cpt == ' ' ? _get_flags(pos + 1) : flags);
+            auto flags2 = (cpt == ' ' ? _get_flags(pos+1) : flags);
             // regex: ?\p{L}+
             if (flags2.is_letter) {
                 pos += (cpt == ' ');
@@ -348,12 +330,12 @@ static std::vector unicode_regex_split_custom_gpt2(const std::string& te
             }
 
             size_t num_whitespaces = 0;
-            while (_get_flags(pos + num_whitespaces).is_whitespace) {
+            while (_get_flags(pos+num_whitespaces).is_whitespace) {
                 num_whitespaces++;
             }
 
             // regex: \s+(?!\S)
-            if (num_whitespaces > 1 && _get_cpt(pos + num_whitespaces) != OUT_OF_RANGE) {
+            if (num_whitespaces > 1 && _get_cpt(pos+num_whitespaces) != OUT_OF_RANGE) {
                 pos += num_whitespaces - 1;
                 _add_token(pos);
                 continue;
@@ -374,11 +356,10 @@ static std::vector unicode_regex_split_custom_gpt2(const std::string& te
     return bpe_offsets;
 }
 
-// K2 system regex patterns (from tokenization_kimi.py):
-// [\p{Han}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+
-static std::vector<size_t> unicode_regex_split_custom_kimi_k2(const std::string & text, const std::vector<size_t> & offsets) {
-    std::vector<size_t> bpe_offsets;
-    bpe_offsets.reserve(offsets.size());
+// LLAMA3 system regex: "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"
+static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string & text, const std::vector<size_t> & offsets) {
+    std::vector<size_t> bpe_offsets; // store the offset of each word
+    bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
 
     const auto cpts = unicode_cpts_from_utf8(text);
 
@@ -394,8 +375,8 @@ static std::vector unicode_regex_split_custom_kimi_k2(const std::string
             return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE;
         };
 
-        auto _get_flags = [&] (const size_t pos) -> codepoint_flags {
-            return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : codepoint_flags{};
+        auto _get_flags = [&] (const size_t pos) -> unicode_cpt_flags {
+            return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags_from_cpt(cpts[pos]) : unicode_cpt_flags{};
         };
 
         size_t _prev_end = offset_ini;
@@ -406,6 +387,12 @@ static std::vector unicode_regex_split_custom_kimi_k2(const std::string
                 bpe_offsets.push_back(len);
             }
             _prev_end = end;
+            //if (len > 0) {
+            //    std::string s = "";
+            //    for(size_t p = end-len; p < end; p++)
+            //        s += unicode_cpt_to_utf8(cpts[p]);
+            //    printf(">>> '%s'\n", s.c_str());
+            //}
             return len;
         };
 
@@ -413,75 +400,41 @@ static std::vector unicode_regex_split_custom_kimi_k2(const std::string
             const uint32_t cpt = _get_cpt(pos);
             const auto flags = _get_flags(pos);
 
-            // Pattern 1: [\p{Han}]+ (Chinese characters)
-            if (unicode_cpt_is_han(cpt)) {
-                while (unicode_cpt_is_han(_get_cpt(pos))) {
-                    pos++;
+            // regex: (?i:'s|'t|'re|'ve|'m|'ll|'d) // case insensitive
+            if (cpt == '\'' && pos+1 < offset_end) {
+                uint32_t cpt_next = unicode_tolower(_get_cpt(pos+1));
+                if (cpt_next == 's' || cpt_next == 't' || cpt_next == 'm' || cpt_next == 'd') {
+                    pos += _add_token(pos+2);
+                    continue;
                 }
-                _add_token(pos);
-                continue;
-            }
-
-            // Pattern 2 & 3: Letter words excluding Han characters with optional contractions
-            // [^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+(?:'s|'t|'re|'ve|'m|'ll|'d)?
-            // [^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*(?:'s|'t|'re|'ve|'m|'ll|'d)?
-            // Check if current char is a letter OR if current char could be a leading char and next char is a letter
-            bool is_letter_pattern = (flags.is_letter && !unicode_cpt_is_han(cpt)) ||
-                                     (!(cpt == '\r' || cpt == '\n' || flags.is_letter || flags.is_number) &&
-                                      _get_flags(pos + 1).is_letter && !unicode_cpt_is_han(_get_cpt(pos + 1)));
-
-            if (is_letter_pattern) {
-                // Handle optional leading non-letter/non-number character
-                bool has_leading_char = false;
-                if (!(cpt == '\r' || cpt == '\n' || flags.is_letter || flags.is_number)) {
-                    has_leading_char = true;
-                    pos++;
+                if (pos+2 < offset_end) {
+                    uint32_t cpt_next_next = unicode_tolower(_get_cpt(pos+2));
+                    if ((cpt_next == 'r' && cpt_next_next == 'e') ||
+                        (cpt_next == 'v' && cpt_next_next == 'e') ||
+                        (cpt_next == 'l' && cpt_next_next == 'l')) {
+                        pos += _add_token(pos+3);
+                        continue;
+                    }
                 }
+            }
 
-                // Match letter sequence (excluding Han characters)
-                bool has_letters = false;
-                while (_get_flags(pos).is_letter && !unicode_cpt_is_han(_get_cpt(pos))) {
-                    has_letters = true;
+            // regex: [^\r\n\p{L}\p{N}]?\p{L}+
+            if (!(cpt == '\r' || cpt == '\n' || flags.is_number)) {
+                if (flags.is_letter || _get_flags(pos+1).is_letter) {  // one or more letters
                     pos++;
-                }
-
-                // Only proceed if we found letters (after potentially skipping leading char)
-                if (has_letters || (!has_leading_char && _get_flags(pos).is_letter && !unicode_cpt_is_han(_get_cpt(pos)))) {
-                    if (!has_letters) pos++; // consume the first letter if we didn't already
-
-                    // Continue consuming letters
-                    while (_get_flags(pos).is_letter && !unicode_cpt_is_han(_get_cpt(pos))) {
+                    while (_get_flags(pos).is_letter) {
                         pos++;
                     }
-
-                    // Check for optional contractions (?:'s|'t|'re|'ve|'m|'ll|'d)
-                    if (_get_cpt(pos) == '\'' && pos + 1 < offset_end) {
-                        uint32_t cpt_next = unicode_tolower(_get_cpt(pos + 1));
-                        if (cpt_next == 's' || cpt_next == 't' || cpt_next == 'm' || cpt_next == 'd') {
-                            pos += 2;
-                        } else if (pos + 2 < offset_end) {
-                            uint32_t cpt_next_next = unicode_tolower(_get_cpt(pos + 2));
-                            if ((cpt_next == 'r' && cpt_next_next == 'e') ||
-                                (cpt_next == 'v' && cpt_next_next == 'e') ||
-                                (cpt_next == 'l' && cpt_next_next == 'l')) {
-                                pos += 3;
-                            }
-                        }
-                    }
-
                     _add_token(pos);
                     continue;
-                } else if (has_leading_char) {
-                    // We consumed a leading char but found no letters, backtrack
-                    pos--;
                 }
             }
 
-            // Pattern 4: \p{N}{1,3} (numbers 1-3 digits)
+            // regex: \p{N}{1,3}
             if (flags.is_number) {
                 size_t ini = pos;
                 while (_get_flags(pos).is_number) {
-                    if (++pos - ini >= 3) {
+                    if (++pos - ini >= 3 ) {
                         _add_token(pos);
                         ini = pos;
                     }
@@ -490,14 +443,13 @@ static std::vector unicode_regex_split_custom_kimi_k2(const std::string
                 continue;
             }
 
-            // Pattern 5:  ?[^\s\p{L}\p{N}]+[\r\n]* (optional space + non-word chars + optional newlines)
-            auto flags2 = (cpt == ' ' ? _get_flags(pos + 1) : flags);
-            if (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number) && flags2.as_uint()) {
+            // regex: ?[^\s\p{L}\p{N}]+[\r\n]*
+            auto flags2 = (cpt == ' ' ? _get_flags(pos+1) : flags);
+            if (!(flags2.is_whitespace | flags2.is_letter | flags2.is_number) && flags.as_uint()) {
                 pos += (cpt == ' ');
-                while (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number) && flags2.as_uint()) {
+                while (!(flags2.is_whitespace | flags2.is_letter | flags2.is_number) && flags2.as_uint()) {
                     flags2 = _get_flags(++pos);
                 }
-                // Match optional [\r\n]*
                 uint32_t cpt2 = _get_cpt(pos);
                 while (cpt2 == '\r' || cpt2 == '\n') {
                     cpt2 = _get_cpt(++pos);
@@ -506,39 +458,38 @@ static std::vector unicode_regex_split_custom_kimi_k2(const std::string
                 continue;
             }
 
-            // Count whitespace characters
             size_t num_whitespaces = 0;
             size_t last_end_r_or_n = 0;
-            while (_get_flags(pos + num_whitespaces).is_whitespace) {
-                uint32_t cpt2 = _get_cpt(pos + num_whitespaces);
+            while (_get_flags(pos+num_whitespaces).is_whitespace) {
+                uint32_t cpt2 = _get_cpt(pos+num_whitespaces);
                 if (cpt2 == '\r' || cpt2 == '\n') {
                     last_end_r_or_n = pos + num_whitespaces + 1;
                 }
                 num_whitespaces++;
             }
 
-            // Pattern 6: \s*[\r\n]+ (whitespace with newlines)
+            // regex: \s*[\r\n]+
             if (last_end_r_or_n > 0) {
                 pos = last_end_r_or_n;
                 _add_token(pos);
                 continue;
             }
 
-            // Pattern 7: \s+(?!\S) (trailing whitespace)
-            if (num_whitespaces > 1 && _get_cpt(pos + num_whitespaces) != OUT_OF_RANGE) {
+            // regex: \s+(?!\S)
+            if (num_whitespaces > 1 && _get_cpt(pos+num_whitespaces) != OUT_OF_RANGE) {
                 pos += num_whitespaces - 1;
                 _add_token(pos);
                 continue;
             }
 
-            // Pattern 8: \s+ (general whitespace)
+            // regex: \s+
             if (num_whitespaces > 0) {
                 pos += num_whitespaces;
                 _add_token(pos);
                 continue;
             }
 
-            // No matches - consume single character
+            // no matches
             _add_token(++pos);
         }
     }
@@ -546,10 +497,71 @@ static std::vector unicode_regex_split_custom_kimi_k2(const std::string
     return bpe_offsets;
 }
 
-// LLAMA3 system regex: "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"
-static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string& text, const std::vector<size_t>& offsets) {
+// use std::wregex to split the text
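+// (both the matched spans and the gaps between matches are emitted as offsets, so concatenating
+// the resulting pieces reproduces the input)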
+static std::vector<size_t> unicode_regex_split_stl(const std::wstring & wtext, const std::wstring & regex_expr, const std::vector<size_t> & offsets) {
+    std::wregex expr(regex_expr);
+    std::vector<size_t> bpe_offsets; // store the offset of each word
+    bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
+    size_t start = 0;
+    for (auto offset : offsets) {
+        std::wcregex_iterator it(wtext.data() + start, wtext.data() + start + offset, expr);
+        std::wcregex_iterator end;
+
+        int64_t start_idx = 0;
+        while (it != end) {
+            std::wcmatch match = *it;
+            if (match.position() > start_idx) {
+                bpe_offsets.emplace_back(match.position() - start_idx);
+            }
+            bpe_offsets.emplace_back(match.length());
+            start_idx = match.position() + match.length();
+            ++it;
+        }
+
+        if (start_idx < (int64_t) offset) {
+            bpe_offsets.emplace_back(offset - start_idx);
+        }
+        start += offset;
+    }
+
+    return bpe_offsets;
+}
+
+// use std::regex to split the text
+static std::vector<size_t> unicode_regex_split_stl(const std::string & text, const std::string & regex_expr, const std::vector<size_t> & offsets) {
+    std::regex expr(regex_expr);
     std::vector<size_t> bpe_offsets; // store the offset of each word
     bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
+    size_t start = 0;
+    for (auto offset : offsets) {
+        std::cregex_iterator it(text.data() + start, text.data() + start + offset, expr);
+        std::cregex_iterator end;
+
+        int64_t start_idx = 0;
+        while (it != end) {
+            std::cmatch match = *it;
+            if (match.position() > start_idx) {
+                bpe_offsets.emplace_back(match.position() - start_idx);
+            }
+            bpe_offsets.emplace_back(match.length());
+            start_idx = match.position() + match.length();
+            ++it;
+        }
+
+        if (start_idx < (int64_t) offset) {
+            bpe_offsets.emplace_back(offset - start_idx);
+        }
+        start += offset;
+    }
+
+    return bpe_offsets;
+}
+
+// K2 system regex patterns (from tokenization_kimi.py):
+// [\p{Han}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+
+static std::vector<size_t> unicode_regex_split_custom_kimi_k2(const std::string & text, const std::vector<size_t> & offsets) {
+    std::vector<size_t> bpe_offsets;
+    bpe_offsets.reserve(offsets.size());
 
     const auto cpts = unicode_cpts_from_utf8(text);
 
@@ -561,66 +573,94 @@ static std::vector unicode_regex_split_custom_llama3(const std::string&
         start = offset_end;
 
         static const uint32_t OUT_OF_RANGE = 0xFFFFFFFF;
-        auto _get_cpt = [&](const size_t pos) -> uint32_t {
+        auto _get_cpt = [&] (const size_t pos) -> uint32_t {
             return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE;
         };
 
-        auto _get_flags = [&](const size_t pos) -> codepoint_flags {
-            return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : codepoint_flags{};
+        auto _get_flags = [&] (const size_t pos) -> unicode_cpt_flags {
+            return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags_from_cpt(cpts[pos]) : unicode_cpt_flags{};
         };
 
         size_t _prev_end = offset_ini;
-        auto _add_token = [&](const size_t end) -> size_t {
+        auto _add_token = [&] (const size_t end) -> size_t {
             assert(_prev_end <= end && end <= offset_end);
             size_t len = end - _prev_end;
             if (len > 0) {
                 bpe_offsets.push_back(len);
             }
             _prev_end = end;
-            //if (len > 0) {
-            //    std::string s = "";
-            //    for(size_t p = end-len; p < end; p++)
-            //        s += unicode_cpt_to_utf8(cpts[p]);
-            //    printf(">>> '%s'\n", s.c_str());
-            //}
             return len;
         };
 
-        for (size_t pos = offset_ini; pos < offset_end; /*pos++*/) {
+        for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) {
             const uint32_t cpt = _get_cpt(pos);
             const auto flags = _get_flags(pos);
 
-            // regex: (?i:'s|'t|'re|'ve|'m|'ll|'d) // case insensitive
-            if (cpt == '\'' && pos + 1 < offset_end) {
-                uint32_t cpt_next = unicode_tolower(_get_cpt(pos + 1));
-                if (cpt_next == 's' || cpt_next == 't' || cpt_next == 'm' || cpt_next == 'd') {
-                    pos += _add_token(pos + 2);
-                    continue;
-                }
-                if (pos + 2 < offset_end) {
-                    uint32_t cpt_next_next = unicode_tolower(_get_cpt(pos + 2));
-                    if ((cpt_next == 'r' && cpt_next_next == 'e') ||
-                        (cpt_next == 'v' && cpt_next_next == 'e') ||
-                        (cpt_next == 'l' && cpt_next_next == 'l')) {
-                        pos += _add_token(pos + 3);
-                        continue;
-                    }
+            // Pattern 1: [\p{Han}]+ (Chinese characters)
+            if (unicode_cpt_is_han(cpt)) {
+                while (unicode_cpt_is_han(_get_cpt(pos))) {
+                    pos++;
                 }
+                _add_token(pos);
+                continue;
             }
 
-            // regex: [^\r\n\p{L}\p{N}]?\p{L}+
-            if (!(cpt == '\r' || cpt == '\n' || flags.is_number)) {
-                if (flags.is_letter || _get_flags(pos + 1).is_letter) {  // one or more letters
+            // Pattern 2 & 3: Letter words excluding Han characters with optional contractions
+            // [^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+(?:'s|'t|'re|'ve|'m|'ll|'d)?
+            // [^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*(?:'s|'t|'re|'ve|'m|'ll|'d)?
+            // Check if current char is a letter OR if current char could be a leading char and next char is a letter
+            bool is_letter_pattern = (flags.is_letter && !unicode_cpt_is_han(cpt)) ||
+                                     (!(cpt == '\r' || cpt == '\n' || flags.is_letter || flags.is_number) &&
+                                      _get_flags(pos + 1).is_letter && !unicode_cpt_is_han(_get_cpt(pos + 1)));
+
+            if (is_letter_pattern) {
+                // Handle optional leading non-letter/non-number character
+                bool has_leading_char = false;
+                if (!(cpt == '\r' || cpt == '\n' || flags.is_letter || flags.is_number)) {
+                    has_leading_char = true;
                     pos++;
-                    while (_get_flags(pos).is_letter) {
+                }
+
+                // Match letter sequence (excluding Han characters)
+                bool has_letters = false;
+                while (_get_flags(pos).is_letter && !unicode_cpt_is_han(_get_cpt(pos))) {
+                    has_letters = true;
+                    pos++;
+                }
+
+                // Only proceed if we found letters (after potentially skipping leading char)
+                if (has_letters || (!has_leading_char && _get_flags(pos).is_letter && !unicode_cpt_is_han(_get_cpt(pos)))) {
+                    if (!has_letters) pos++; // consume the first letter if we didn't already
+
+                    // Continue consuming letters
+                    while (_get_flags(pos).is_letter && !unicode_cpt_is_han(_get_cpt(pos))) {
                         pos++;
                     }
+
+                    // Check for optional contractions (?:'s|'t|'re|'ve|'m|'ll|'d)
+                    if (_get_cpt(pos) == '\'' && pos + 1 < offset_end) {
+                        uint32_t cpt_next = unicode_tolower(_get_cpt(pos + 1));
+                        if (cpt_next == 's' || cpt_next == 't' || cpt_next == 'm' || cpt_next == 'd') {
+                            pos += 2;
+                        } else if (pos + 2 < offset_end) {
+                            uint32_t cpt_next_next = unicode_tolower(_get_cpt(pos + 2));
+                            if ((cpt_next == 'r' && cpt_next_next == 'e') ||
+                                (cpt_next == 'v' && cpt_next_next == 'e') ||
+                                (cpt_next == 'l' && cpt_next_next == 'l')) {
+                                pos += 3;
+                            }
+                        }
+                    }
+
                     _add_token(pos);
                     continue;
+                } else if (has_leading_char) {
+                    // We consumed a leading char but found no letters, backtrack
+                    pos--;
                 }
             }
 
-            // regex: \p{N}{1,3}
+            // Pattern 4: \p{N}{1,3} (numbers 1-3 digits)
             if (flags.is_number) {
                 size_t ini = pos;
                 while (_get_flags(pos).is_number) {
@@ -633,13 +673,14 @@ static std::vector unicode_regex_split_custom_llama3(const std::string&
                 continue;
             }
 
-            // regex: ?[^\s\p{L}\p{N}]+[\r\n]*
+            // Pattern 5:  ?[^\s\p{L}\p{N}]+[\r\n]* (optional space + non-word chars + optional newlines)
             auto flags2 = (cpt == ' ' ? _get_flags(pos + 1) : flags);
-            if (!(flags2.is_whitespace | flags2.is_letter | flags2.is_number) && flags.as_uint()) {
+            if (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number) && flags2.as_uint()) {
                 pos += (cpt == ' ');
-                while (!(flags2.is_whitespace | flags2.is_letter | flags2.is_number) && flags2.as_uint()) {
+                while (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number) && flags2.as_uint()) {
                     flags2 = _get_flags(++pos);
                 }
+                // Match optional [\r\n]*
                 uint32_t cpt2 = _get_cpt(pos);
                 while (cpt2 == '\r' || cpt2 == '\n') {
                     cpt2 = _get_cpt(++pos);
@@ -648,6 +689,7 @@ static std::vector unicode_regex_split_custom_llama3(const std::string&
                 continue;
             }
 
+            // Count whitespace characters
             size_t num_whitespaces = 0;
             size_t last_end_r_or_n = 0;
             while (_get_flags(pos + num_whitespaces).is_whitespace) {
@@ -658,28 +700,28 @@ static std::vector unicode_regex_split_custom_llama3(const std::string&
                 num_whitespaces++;
             }
 
-            // regex: \s*[\r\n]+
+            // Pattern 6: \s*[\r\n]+ (whitespace with newlines)
             if (last_end_r_or_n > 0) {
                 pos = last_end_r_or_n;
                 _add_token(pos);
                 continue;
             }
 
-            // regex: \s+(?!\S)
+            // Pattern 7: \s+(?!\S) (trailing whitespace)
             if (num_whitespaces > 1 && _get_cpt(pos + num_whitespaces) != OUT_OF_RANGE) {
                 pos += num_whitespaces - 1;
                 _add_token(pos);
                 continue;
             }
 
-            // regex: \s+
+            // Pattern 8: \s+ (general whitespace)
             if (num_whitespaces > 0) {
                 pos += num_whitespaces;
                 _add_token(pos);
                 continue;
             }
 
-            // no matches
+            // No matches - consume single character
             _add_token(++pos);
         }
     }
@@ -687,79 +729,17 @@ static std::vector unicode_regex_split_custom_llama3(const std::string&
     return bpe_offsets;
 }
 
-// use std::wregex to split the text
-static std::vector<size_t> unicode_regex_split_stl(const std::wstring& wtext, const std::wstring& regex_expr, const std::vector<size_t>& offsets) {
-    std::wregex expr(regex_expr);
-    std::vector<size_t> bpe_offsets; // store the offset of each word
-    bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
-    size_t start = 0;
-    for (auto offset : offsets) {
-        std::wcregex_iterator it(wtext.data() + start, wtext.data() + start + offset, expr);
-        std::wcregex_iterator end;
-
-        int64_t start_idx = 0;
-        while (it != end) {
-            std::wcmatch match = *it;
-            if (match.position() > start_idx) {
-                bpe_offsets.emplace_back(match.position() - start_idx);
-            }
-            bpe_offsets.emplace_back(match.length());
-            start_idx = match.position() + match.length();
-            ++it;
-        }
-
-        if (start_idx < (int64_t)offset) {
-            bpe_offsets.emplace_back(offset - start_idx);
-        }
-        start += offset;
-    }
-
-    return bpe_offsets;
-}
-
-// use std::regex to split the text
-static std::vector<size_t> unicode_regex_split_stl(const std::string& text, const std::string& regex_expr, const std::vector<size_t>& offsets) {
-    std::regex expr(regex_expr);
-    std::vector<size_t> bpe_offsets; // store the offset of each word
-    bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
-    size_t start = 0;
-    for (auto offset : offsets) {
-        std::cregex_iterator it(text.data() + start, text.data() + start + offset, expr);
-        std::cregex_iterator end;
-
-        int64_t start_idx = 0;
-        while (it != end) {
-            std::cmatch match = *it;
-            if (match.position() > start_idx) {
-                bpe_offsets.emplace_back(match.position() - start_idx);
-            }
-            bpe_offsets.emplace_back(match.length());
-            start_idx = match.position() + match.length();
-            ++it;
-        }
-
-        if (start_idx < (int64_t)offset) {
-            bpe_offsets.emplace_back(offset - start_idx);
-        }
-        start += offset;
-    }
-
-    return bpe_offsets;
-}
-
-static std::vector<size_t> unicode_regex_split_custom(const std::string& text, const std::string& regex_expr, const std::vector<size_t>& offsets) {
+static std::vector<size_t> unicode_regex_split_custom(const std::string & text, const std::string & regex_expr, const std::vector<size_t> & offsets) {
     std::vector<size_t> bpe_offsets;
 
     if (regex_expr == "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)") {
         bpe_offsets = unicode_regex_split_custom_gpt2(text, offsets);
-    }
-    else if (
-        regex_expr == "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" ||
-        regex_expr == "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+") {
+    } else if (
+            regex_expr == "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" ||
+            regex_expr == "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+") {
 
         bpe_offsets = unicode_regex_split_custom_llama3(text, offsets);
-    }
-    else if (regex_expr == "\\p{Han}+") {
+    } else if (regex_expr == "\\p{Han}+") {
         // K2's first pattern - handle all K2 patterns together
         bpe_offsets = unicode_regex_split_custom_kimi_k2(text, offsets);
     }
@@ -771,71 +751,100 @@ static std::vector unicode_regex_split_custom(const std::string& text, c
 // interface
 //
 
-std::string unicode_cpt_to_utf8(uint32_t cp) {
+std::string unicode_cpt_to_utf8(uint32_t cpt) {
     std::string result;
 
-    if (/* 0x00 <= cp && */ cp <= 0x7f) {
-        result.push_back(cp);
+    if (/* 0x00 <= cpt && */ cpt <= 0x7f) {
+        result.push_back(cpt);
         return result;
     }
-    if (0x80 <= cp && cp <= 0x7ff) {
-        result.push_back(0xc0 | ((cp >> 6) & 0x1f));
-        result.push_back(0x80 | (cp & 0x3f));
+    if (0x80 <= cpt && cpt <= 0x7ff) {
+        result.push_back(0xc0 | ((cpt >> 6) & 0x1f));
+        result.push_back(0x80 | (cpt & 0x3f));
         return result;
     }
-    if (0x800 <= cp && cp <= 0xffff) {
-        result.push_back(0xe0 | ((cp >> 12) & 0x0f));
-        result.push_back(0x80 | ((cp >> 6) & 0x3f));
-        result.push_back(0x80 | (cp & 0x3f));
+    if (0x800 <= cpt && cpt <= 0xffff) {
+        result.push_back(0xe0 | ((cpt >> 12) & 0x0f));
+        result.push_back(0x80 | ((cpt >> 6) & 0x3f));
+        result.push_back(0x80 | (cpt & 0x3f));
         return result;
     }
-    if (0x10000 <= cp && cp <= 0x10ffff) {
-        result.push_back(0xf0 | ((cp >> 18) & 0x07));
-        result.push_back(0x80 | ((cp >> 12) & 0x3f));
-        result.push_back(0x80 | ((cp >> 6) & 0x3f));
-        result.push_back(0x80 | (cp & 0x3f));
+    if (0x10000 <= cpt && cpt <= 0x10ffff) {
+        result.push_back(0xf0 | ((cpt >> 18) & 0x07));
+        result.push_back(0x80 | ((cpt >> 12) & 0x3f));
+        result.push_back(0x80 | ((cpt >> 6) & 0x3f));
+        result.push_back(0x80 | (cpt & 0x3f));
         return result;
     }
 
     throw std::invalid_argument("invalid codepoint");
 }
 
-std::vector<uint32_t> unicode_cpts_normalize_nfd(const std::vector<uint32_t>& cpts) {
-    auto comp = [](const uint32_t cpt, const range_nfd& range) {
+std::vector<uint32_t> unicode_cpts_normalize_nfd(const std::vector<uint32_t> & cpts) {
+    auto comp = [] (const uint32_t cpt, const range_nfd & range) {
         return cpt < range.first;
     };
     std::vector<uint32_t> result(cpts.size());
     for (size_t i = 0; i < cpts.size(); ++i) {
         const uint32_t cpt = cpts[i];
-        auto it = std::upper_bound(unicode_ranges_nfd.cbegin(), unicode_ranges_nfd.cend(), cpt, comp) - 1;
+        auto it = std::upper_bound(unicode_ranges_nfd.begin(), unicode_ranges_nfd.end(), cpt, comp) - 1;
         result[i] = (it->first <= cpt && cpt <= it->last) ? it->nfd : cpt;
     }
     return result;
 }
 
-std::vector<uint32_t> unicode_cpts_from_utf8(const std::string& utf8) {
+std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8) {
     std::vector<uint32_t> result;
     result.reserve(utf8.size());
     size_t offset = 0;
     while (offset < utf8.size()) {
-        result.push_back(unicode_cpt_from_utf8(utf8, offset));
+        try {
+            result.push_back(unicode_cpt_from_utf8(utf8, offset));
+        }
+        catch (const std::invalid_argument & /*ex*/) {
+            // Silently ignore invalid UTF-8 input to avoid leaking the exception beyond llama_tokenize
+            ++offset;
+            result.emplace_back(0xFFFD); // replacement character
+        }
     }
     return result;
 }
 
-codepoint_flags unicode_cpt_flags(const uint32_t cp) {
-    static const codepoint_flags undef(codepoint_flags::UNDEFINED);
+unicode_cpt_flags unicode_cpt_flags_from_cpt(const uint32_t cpt) {
+    static const unicode_cpt_flags undef(unicode_cpt_flags::UNDEFINED);
     static const auto cpt_flags = unicode_cpt_flags_array();
-    return cp < cpt_flags.size() ? cpt_flags[cp] : undef;
+    return cpt < cpt_flags.size() ? cpt_flags[cpt] : undef;
 }
 
-codepoint_flags unicode_cpt_flags(const std::string& utf8) {
-    static const codepoint_flags undef(codepoint_flags::UNDEFINED);
+unicode_cpt_flags unicode_cpt_flags_from_utf8(const std::string & utf8) {
+    static const unicode_cpt_flags undef(unicode_cpt_flags::UNDEFINED);
     if (utf8.empty()) {
         return undef;  // undefined
     }
     size_t offset = 0;
-    return unicode_cpt_flags(unicode_cpt_from_utf8(utf8, offset));
+    return unicode_cpt_flags_from_cpt(unicode_cpt_from_utf8(utf8, offset));
+}
+
+std::string unicode_byte_to_utf8(uint8_t byte) {
+    static std::unordered_map<uint8_t, std::string> map = unicode_byte_to_utf8_map();
+    return map.at(byte);
+}
+
+uint8_t unicode_utf8_to_byte(const std::string & utf8) {
+    static std::unordered_map<std::string, uint8_t> map = unicode_utf8_to_byte_map();
+    return map.at(utf8);
+}
+
+uint32_t unicode_tolower(uint32_t cpt) {
+    // binary search
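+    // (assumes unicode_map_lowercase is sorted by codepoint, each entry mapping cpt -> lowercase cpt)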
+    auto it = std::lower_bound(unicode_map_lowercase.begin(), unicode_map_lowercase.end(), cpt,
+        [](const std::pair<uint32_t, uint32_t> & pair, uint32_t value) {
+            return pair.first < value;
+        });
+    if (it != unicode_map_lowercase.end() && it->first == cpt) {
+        return it->second;
+    }
+    return cpt;  // Return the original code point if no lowercase mapping is found
 }
 
 bool unicode_cpt_is_han(uint32_t cpt) {
@@ -870,53 +879,37 @@ bool unicode_cpt_is_han(uint32_t cpt) {
     return false;
 }
 
-std::string unicode_byte_to_utf8(uint8_t byte) {
-    static std::unordered_map map = unicode_byte_to_utf8_map();
-    return map.at(byte);
-}
-
-uint8_t unicode_utf8_to_byte(const std::string& utf8) {
-    static std::unordered_map<std::string, uint8_t> map = unicode_utf8_to_byte_map();
-    return map.at(utf8);
-}
-
-uint32_t unicode_tolower(uint32_t cp) {
-    auto it = unicode_map_lowercase.find(cp);
-    return it == unicode_map_lowercase.end() ? cp : it->second;
-}
-
-std::vector<std::string> unicode_regex_split(const std::string& text, const std::vector<std::string>& regex_exprs) {
+std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs) {
     // unicode categories
     static const std::map<std::string, int> k_ucat_enum = {
-        { "\\p{N}", codepoint_flags::NUMBER },
-        { "\\p{L}", codepoint_flags::LETTER },
-        { "\\p{P}", codepoint_flags::PUNCTUATION },
-        { "\\p{M}", codepoint_flags::ACCENT_MARK },
-        { "\\p{S}", codepoint_flags::SYMBOL },
+        { "\\p{N}", unicode_cpt_flags::NUMBER },
+        { "\\p{L}", unicode_cpt_flags::LETTER },
+        { "\\p{P}", unicode_cpt_flags::PUNCTUATION },
+        { "\\p{M}", unicode_cpt_flags::ACCENT_MARK },
+        { "\\p{S}", unicode_cpt_flags::SYMBOL },
     };
 
     static const std::map<int, int> k_ucat_cpt = {
-        { codepoint_flags::NUMBER,        0xD1 },
-        { codepoint_flags::LETTER,        0xD2 },
-        { codepoint_flags::PUNCTUATION,   0xD3 },
-        { codepoint_flags::ACCENT_MARK,   0xD4 },
-        { codepoint_flags::SYMBOL,        0xD5 },
-
+        { unicode_cpt_flags::NUMBER,      0xD1 },
+        { unicode_cpt_flags::LETTER,      0xD2 },
+        { unicode_cpt_flags::PUNCTUATION, 0xD3 },
+        { unicode_cpt_flags::ACCENT_MARK, 0xD4 },
+        { unicode_cpt_flags::SYMBOL,      0xD5 },
     };
 
     static const std::map<int, std::string> k_ucat_map = {
-        { codepoint_flags::NUMBER,        "\x30-\x39" }, // 0-9
-        { codepoint_flags::LETTER,        "\x41-\x5A\x61-\x7A" }, // A-Za-z
-        { codepoint_flags::PUNCTUATION,   "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\}i
-        { codepoint_flags::ACCENT_MARK, "" }, // no sub-128 codepoints
-        { codepoint_flags::SYMBOL,      "\\\x24\\\x2B\x3C-\x3E\x5E\x60\\\x7C" }, // $+<=>^`|
+        { unicode_cpt_flags::NUMBER,      "\x30-\x39" }, // 0-9
+        { unicode_cpt_flags::LETTER,      "\x41-\x5A\x61-\x7A" }, // A-Za-z
+        { unicode_cpt_flags::PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\}
+        { unicode_cpt_flags::ACCENT_MARK, "" }, // no sub-128 codepoints
+        { unicode_cpt_flags::SYMBOL,      "\\\x24\\\x2B\x3C-\x3E\x5E\x60\\\x7C" }, // $+<=>^`|
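+        // (collapse scheme: every codepoint is replaced by a single byte; members of a category map
+        // to the 0xD1..0xD5 markers above, and \p{X} in the regex is rewritten to match either that
+        // marker or the ASCII ranges listed here)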
     };
 
     // compute collapsed codepoints only if needed by at least one regex
     bool need_collapse = false;
-    for (auto& regex_expr : regex_exprs) {
+    for (const auto & regex_expr : regex_exprs) {
         // search for unicode categories
-        for (const auto& ucat : k_ucat_enum) {
+        for (const auto & ucat : k_ucat_enum) {
             if (std::string::npos != regex_expr.find(ucat.first)) {
                 need_collapse = true;
                 break;
@@ -927,7 +920,7 @@ std::vector unicode_regex_split(const std::string& text, const std:
     const auto cpts = unicode_cpts_from_utf8(text);
 
     // generate a "collapsed" representation of the text, where all codepoints are replaced by a single byte
-    // ref: https://github.com/ggerganov/llama.cpp/pull/6920#issuecomment-2081479935
+    // ref: https://github.com/ggml-org/llama.cpp/pull/6920#issuecomment-2081479935
     std::string text_collapsed;
     if (need_collapse) {
         // collapse all unicode categories
@@ -940,25 +933,23 @@ std::vector unicode_regex_split(const std::string& text, const std:
                 continue;
             }
 
-            const auto flags = unicode_cpt_flags(cpts[i]);
+            const auto flags = unicode_cpt_flags_from_cpt(cpts[i]);
 
             if (flags.is_whitespace) {
                 //NOTE: C++ std::regex \s does not match 0x85; Rust and Python regex do.
                 //text_collapsed[i] = (char) 0x85;  //  as whitespace fallback
-                text_collapsed[i] = (char)0x0B;    //  as whitespace fallback
-            }
-            else if (k_ucat_cpt.find(flags.category_flag()) != k_ucat_cpt.end()) {
+                text_collapsed[i] = (char) 0x0B;    //  as whitespace fallback
+            } else if (k_ucat_cpt.find(flags.category_flag()) != k_ucat_cpt.end()) {
                 text_collapsed[i] = k_ucat_cpt.at(flags.category_flag());
-            }
-            else {
-                text_collapsed[i] = (char)0xD0; // fallback
+            } else {
+                text_collapsed[i] = (char) 0xD0; // fallback
             }
         }
     }
 
     std::vector<size_t> bpe_offsets = { cpts.size() };
 
-    for (auto& regex_expr : regex_exprs) {
+    for (const auto & regex_expr : regex_exprs) {
         // first, see if we have an efficient custom regex implementation
         auto tmp = unicode_regex_split_custom(text, regex_expr, bpe_offsets);
 
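For a concrete picture of the collapse pass: sub-128 codepoints are kept as-is (the branch ending at the `continue;` at the top of this hunk), non-ASCII whitespace becomes 0x0B, categorized codepoints become their placeholder byte, and everything else becomes 0xD0, so the collapsed string stays exactly one byte per codepoint. A self-contained toy version, with toy_is_ws and toy_category standing in for the real unicode_cpt_flags_from_cpt lookup:

    #include <cstdint>
    #include <cstdio>
    #include <string>
    #include <vector>

    static bool toy_is_ws(uint32_t cpt)    { return cpt == 0x00A0; } // NBSP
    static char toy_category(uint32_t cpt) {
        if (cpt == 0x00B2) return (char) 0xD1; // '²' -> NUMBER placeholder
        if (cpt == 0x00E9) return (char) 0xD2; // 'é' -> LETTER placeholder
        return 0;                              // uncategorized
    }

    static std::string collapse(const std::vector<uint32_t> & cpts) {
        std::string out(cpts.size(), '\0');
        for (size_t i = 0; i < cpts.size(); ++i) {
            if (cpts[i] < 128) { out[i] = (char) cpts[i]; continue; } // ASCII passes through
            if (toy_is_ws(cpts[i])) {
                out[i] = (char) 0x0B;                                 // whitespace fallback
            } else if (char c = toy_category(cpts[i])) {
                out[i] = c;                                           // category placeholder
            } else {
                out[i] = (char) 0xD0;                                 // generic fallback
            }
        }
        return out;
    }

    int main() {
        // "héllo" + NBSP + "1²"  collapses to  "h\xD2llo\x0B" "1\xD1"
        const std::vector<uint32_t> cpts = { 'h', 0x00E9, 'l', 'l', 'o', 0x00A0, '1', 0x00B2 };
        std::printf("%zu bytes, one per codepoint\n", collapse(cpts).size());
    }
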
@@ -972,7 +963,7 @@ std::vector<std::string> unicode_regex_split(const std::string& text, const std:
             // if a unicode category is used in the regex, we use the collapsed text and replace the unicode category
             // with the corresponding collapsed representation
             bool use_collapsed = false;
-            for (auto& ucat : k_ucat_enum) {
+            for (const auto & ucat : k_ucat_enum) {
                 if (std::string::npos != regex_expr.find(ucat.first)) {
                     use_collapsed = true;
                     break;
@@ -1031,15 +1022,14 @@ std::vector<std::string> unicode_regex_split(const std::string& text, const std:
                 //printf("text_collapsed: %s\n", text_collapsed.c_str());
                 //printf("regex_expr_collapsed: %s\n", regex_expr_collapsed.c_str());
                 bpe_offsets = unicode_regex_split_stl(text_collapsed, regex_expr_collapsed, bpe_offsets);
-            }
-            else {
+            } else {
                 // no unicode category used, we can use std::wregex directly
                 const std::wstring wregex_expr = unicode_wstring_from_utf8(regex_expr);
 
                 // std::wregex \s does not match non-ASCII whitespaces, using 0x0B as fallback
                 std::wstring wtext(cpts.begin(), cpts.end());
                 for (size_t i = 0; i < wtext.size(); ++i) {
-                    if (wtext[i] > 0x7F && unicode_cpt_flags(wtext[i]).is_whitespace) {
+                    if (wtext[i] > 0x7F && unicode_cpt_flags_from_cpt(wtext[i]).is_whitespace) {
                         wtext[i] = 0x0B;
                     }
                 }
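
Why 0x0B specifically: with the default locale, std::wregex's \s matches the ASCII whitespace set, which includes the vertical tab 0x0B, but generally not non-ASCII whitespace such as U+00A0 (no-break space). A quick probe of the behavior being worked around (results depend on locale and standard library; typical builds print nbsp=0 vtab=1):

    #include <cstdio>
    #include <regex>
    #include <string>

    int main() {
        const std::wregex ws(L"\\s");
        const bool nbsp = std::regex_match(std::wstring(1, wchar_t(0x00A0)), ws);
        const bool vtab = std::regex_match(std::wstring(1, wchar_t(0x000B)), ws);
        std::printf("nbsp=%d vtab=%d\n", nbsp, vtab);
    }
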
@@ -1048,8 +1038,7 @@ std::vector<std::string> unicode_regex_split(const std::string& text, const std:
                 //printf("regex_expr: %s\n", regex_expr.c_str());
                 bpe_offsets = unicode_regex_split_stl(wtext, wregex_expr, bpe_offsets);
             }
-        }
-        catch (std::regex_error& e) {
+        } catch (std::regex_error & e) {
             fprintf(stderr, "Failed to process regex: '%s'\n", regex_expr.c_str());
             fprintf(stderr, "Regex error: %s\n", e.what());
             throw std::runtime_error("Failed to process regex");
@@ -1060,7 +1049,7 @@ std::vector<std::string> unicode_regex_split(const std::string& text, const std:
     bpe_words.reserve(bpe_offsets.size()); // reserve memory for the approximate size
 
     size_t start = 0;
-    for (size_t& offset : bpe_offsets) {
+    for (size_t & offset : bpe_offsets) {
         bpe_words.emplace_back();
         for (size_t i = start; i < start + offset; ++i) {
             bpe_words.back() += unicode_cpt_to_utf8(cpts[i]);
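
One detail worth noting in this final loop: bpe_offsets holds run lengths rather than absolute positions, so each iteration re-encodes cpts[start, start + offset) and then advances start (the start += offset advance sits just past the context shown). A standalone sketch of the reassembly, with to_utf8 standing in for unicode_cpt_to_utf8:

    #include <cstdint>
    #include <string>
    #include <vector>

    // Offsets are run lengths: consecutive runs tile the codepoint array
    // exactly once, left to right.
    static std::vector<std::string> offsets_to_words(
            const std::vector<uint32_t> & cpts,
            const std::vector<size_t>   & offsets,
            std::string (*to_utf8)(uint32_t)) {
        std::vector<std::string> words;
        words.reserve(offsets.size());
        size_t start = 0;
        for (const size_t offset : offsets) {
            std::string word;
            for (size_t i = start; i < start + offset; ++i) {
                word += to_utf8(cpts[i]); // re-encode each codepoint of this run
            }
            words.push_back(std::move(word));
            start += offset; // next run begins where this one ended
        }
        return words;
    }
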
diff --git a/src/unicode.h b/src/unicode.h
index 48940239c..0a5fa2a78 100644
--- a/src/unicode.h
+++ b/src/unicode.h
@@ -4,9 +4,7 @@
 #include <string>
 #include <vector>
 
-// TODO: prefix all symbols with "llama_"
-
-struct codepoint_flags {
+struct unicode_cpt_flags {
     enum {
         UNDEFINED       = 0x0001,
         NUMBER          = 0x0002,  // regex: \p{N}
@@ -35,7 +33,7 @@ struct codepoint_flags {
     uint16_t is_nfd         : 1;
 
     // decode from uint16
-    inline codepoint_flags(const uint16_t flags=0) {
+    inline unicode_cpt_flags(const uint16_t flags = 0) {
         *reinterpret_cast<uint16_t*>(this) = flags;
     }
 
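The constructor kept above bit-copies a packed 16-bit word straight onto the struct's bitfields, so each row of the generated flags table decodes into named booleans with no shifting or masking. A toy mirror of the idea (bitfield layout is implementation-defined; this assumes the common LSB-first allocation, matching the enum where UNDEFINED is bit 0 and NUMBER is bit 1):

    #include <cstdint>
    #include <cstdio>

    struct toy_flags {
        uint16_t is_undefined : 1; // bit 0, enum value 0x0001
        uint16_t is_number    : 1; // bit 1, enum value 0x0002 (\p{N})
        uint16_t is_letter    : 1; // bit 2, enum value 0x0004 (\p{L})
        uint16_t rest         : 13;

        toy_flags(uint16_t packed = 0) {
            *reinterpret_cast<uint16_t *>(this) = packed; // bit-copy onto the fields
        }
    };

    static_assert(sizeof(toy_flags) == 2, "must stay exactly one uint16_t wide");

    int main() {
        const toy_flags f(0x0002); // NUMBER bit set
        std::printf("number=%d letter=%d\n", f.is_number, f.is_letter); // number=1 letter=0
    }
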
@@ -50,19 +48,20 @@ struct codepoint_flags {
 
 size_t unicode_len_utf8(char src);
 
-std::string unicode_cpt_to_utf8(uint32_t cp);
-uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset);
+std::string unicode_cpt_to_utf8  (uint32_t cpt);
+uint32_t    unicode_cpt_from_utf8(const std::string & utf8, size_t & offset);
+
 std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8);
 
 std::vector<uint32_t> unicode_cpts_normalize_nfd(const std::vector<uint32_t> & cpts);
 
-codepoint_flags unicode_cpt_flags(const uint32_t cp);
-codepoint_flags unicode_cpt_flags(const std::string & utf8);
+unicode_cpt_flags unicode_cpt_flags_from_cpt (uint32_t cpt);
+unicode_cpt_flags unicode_cpt_flags_from_utf8(const std::string & utf8);
 
 std::string unicode_byte_to_utf8(uint8_t byte);
-uint8_t unicode_utf8_to_byte(const std::string & utf8);
+uint8_t     unicode_utf8_to_byte(const std::string & utf8);
 
-uint32_t unicode_tolower(uint32_t cp);
+uint32_t unicode_tolower(uint32_t cpt);
 
 bool unicode_cpt_is_han(uint32_t cpt);
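
Taken together, the header changes rename the struct to unicode_cpt_flags and, because that identifier now belongs to the type, split the old unicode_cpt_flags() accessor overloads into two distinct names. Call sites migrate as sketched below (illustrative function and placeholder variables, assuming unicode.h is on the include path):

    #include <cstdint>
    #include <string>

    #include "unicode.h"

    void migrate_example(uint32_t cpt, const std::string & utf8) {
        // before: codepoint_flags f = unicode_cpt_flags(cpt);
        const unicode_cpt_flags by_cpt  = unicode_cpt_flags_from_cpt(cpt);
        // before: codepoint_flags g = unicode_cpt_flags(utf8);
        const unicode_cpt_flags by_utf8 = unicode_cpt_flags_from_utf8(utf8);
        (void) by_cpt;
        (void) by_utf8;
    }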