diff --git a/common/chat-parser.h b/common/chat-parser.h index 7c660e539..1e7a3f949 100644 --- a/common/chat-parser.h +++ b/common/chat-parser.h @@ -24,9 +24,9 @@ class common_chat_msg_parser { std::string prelude; std::vector groups; }; - + common_chat_msg_parser(const std::string & input, bool is_partial, const common_chat_syntax & syntax); - + // Accessors const std::string & input() const { return input_; } size_t pos() const { return pos_; } @@ -42,7 +42,7 @@ class common_chat_msg_parser { } pos_ = pos; } - + void move_back(size_t n) { if (pos_ < n) { throw std::runtime_error("Can't move back that far!"); @@ -56,46 +56,46 @@ class common_chat_msg_parser { // Content manipulation void add_content(const std::string & content); void add_reasoning_content(const std::string & reasoning_content); - + // Tool call manipulation void add_tool_call(const common_chat_tool_call & tool_call); bool add_tool_call(const std::string & name, const std::string & id, const std::string & arguments); bool add_tool_call(const json & tool_call); bool add_tool_calls(const json & arr); void clear_tools(); - + // Parsing utilities std::string consume_rest(); bool try_consume_literal(const std::string & literal); void consume_literal(const std::string & literal); bool try_parse_reasoning(const std::string & start_think, const std::string & end_think); - + // Regex-based parsing methods (new) std::optional try_find_regex(const common_regex & regex, size_t from = std::string::npos, bool add_prelude_to_content = true); find_regex_result consume_regex(const common_regex & regex); std::optional try_consume_regex(const common_regex & regex); - + // Progressive parsing primitives (for Phase 4) std::optional try_find_literal(const std::string & literal); bool consume_spaces(); void set_healing_marker(const std::string & marker); - - + + // Main parsing entry point void parse(); - + // Finishing void finish(); - + // Result extraction common_chat_msg result_and_reset(); - + // Advanced JSON parsing (following original llama.cpp patterns) struct consume_json_result { json value; bool is_partial; }; - + std::optional try_consume_json(); common_json consume_json(); consume_json_result consume_json_with_dumped_args( @@ -112,8 +112,8 @@ class common_chat_msg_parser { void parse_kimi_k2_format(); void parse_deepseek_r1_format(); void parse_generic_format(); - - + + // JSON parsing utilities (enhanced streaming support) struct json_parse_result { json value; @@ -121,11 +121,11 @@ class common_chat_msg_parser { bool is_partial; std::string healing_marker; }; - + // Partial detection utilities bool detect_partial_function_call(const std::string& content); void handle_partial_detection(); - + // Legacy find_literal for compatibility std::optional try_find_literal_legacy(const std::string & literal); }; @@ -133,4 +133,4 @@ class common_chat_msg_parser { // Main parsing function (public API) common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax); -// Content-only parsing for fallback scenarios (static internal function) \ No newline at end of file +// Content-only parsing for fallback scenarios (static internal function) diff --git a/common/chat.cpp b/common/chat.cpp index f62c28011..8be086333 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -220,7 +220,7 @@ void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) { // Check for the new tools array format first (no DeepSeek markers) auto original_pos = builder.pos(); - + // First, try the tools array format 
for content like "function\n```json\n{"tools": [...]}" if (builder.try_find_regex(function_regex_simple)) { builder.move_to(original_pos); @@ -231,7 +231,7 @@ void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) { // Fall through to try standard DeepSeek patterns } } - + // If tools array format didn't work, try XML-wrapped format builder.move_to(original_pos); try { @@ -240,7 +240,7 @@ void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) { } catch (const common_chat_msg_partial_exception&) { // Fall through to try standard DeepSeek patterns } - + // If XML wrapper format didn't work, try standard DeepSeek patterns builder.move_to(original_pos); try { @@ -278,7 +278,7 @@ void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) { throw; // Re-throw for partial mode } } - + // Add any remaining content (critical for responses without tool calls) builder.add_content(builder.consume_rest()); } @@ -286,19 +286,19 @@ void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) { // Parse DeepSeek R1 tools array format following original llama.cpp parse_prefixed_json_tool_call_array pattern static void parse_deepseek_r1_tools_array(common_chat_msg_parser & builder) { static const common_regex prefix("function\n```json\n"); - - + + if (auto res = builder.try_find_regex(prefix)) { // Parse JSON and manually process tools array to convert arguments to strings auto json_result = builder.try_consume_json(); if (!json_result) { throw common_chat_msg_partial_exception("invalid JSON"); } - - + + // DeepSeek R1 format has "tools" array, manually process each tool if (json_result->json.contains("tools") && json_result->json.at("tools").is_array()) { - + // Manually create tool calls array with string arguments (following original pattern) json tools_with_dumped_args = json::array(); for (const auto& tool : json_result->json.at("tools")) { @@ -310,15 +310,15 @@ static void parse_deepseek_r1_tools_array(common_chat_msg_parser & builder) { tools_with_dumped_args.push_back(formatted_tool); } } - - + + if (!builder.add_tool_calls(tools_with_dumped_args) || !json_result->healing_marker.marker.empty()) { throw common_chat_msg_partial_exception("incomplete tool call array"); } } else { throw common_chat_msg_partial_exception("tools key not found or not array"); } - + // Consume closing ``` builder.try_consume_regex(common_regex("```")); } else { @@ -326,41 +326,41 @@ static void parse_deepseek_r1_tools_array(common_chat_msg_parser & builder) { } } -// Parse DeepSeek R1 XML-wrapped format following original Hermes-2-Pro pattern +// Parse DeepSeek R1 XML-wrapped format following original Hermes-2-Pro pattern static void parse_deepseek_r1_xml_wrapped(common_chat_msg_parser & builder) { - + // Pattern for: \nfunctionFunctionName\n```json\n{...}\n```\n static const common_regex xml_pattern( "\\s*" // Opening XML tag - "function([^\\n]+)" // Function name after "function" + "function([^\\n]+)" // Function name after "function" "\\s*```json\\s*" // JSON block start ); - + if (auto res = builder.try_find_regex(xml_pattern)) { - + // Extract function name from capture group std::string function_name = builder.str(res->groups[1]); - + // Parse JSON arguments auto json_result = builder.try_consume_json(); if (!json_result) { throw common_chat_msg_partial_exception("invalid JSON in XML wrapper"); } - - + + // Create single tool call following original pattern json tool_call; tool_call["name"] = function_name; tool_call["arguments"] = json_result->json.dump(); // Convert to 
string - + json tool_calls_array = json::array(); tool_calls_array.push_back(tool_call); - - + + if (!builder.add_tool_calls(tool_calls_array) || !json_result->healing_marker.marker.empty()) { throw common_chat_msg_partial_exception("incomplete XML wrapped tool call"); } - + // Consume closing ```\n builder.try_consume_regex(common_regex("```\\s*")); } else { @@ -384,6 +384,15 @@ static void common_chat_parse_kimi_k2(common_chat_msg_parser & builder) { builder.add_content(kimi_k2::clean_content(builder.input())); } +static void common_chat_parse_gpt_oss(common_chat_msg_parser & builder) { + // TODO @ngxson : this won't work with --special enabled, we should fix that + builder.try_parse_reasoning("<|channel|>analysis<|message|>", "<|start|>assistant<|channel|>final<|message|>"); + if (!builder.syntax().enable_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } +} + // Main parsing dispatch function static void common_chat_parse(common_chat_msg_parser & builder) { switch (builder.syntax().format) { @@ -399,6 +408,9 @@ static void common_chat_parse(common_chat_msg_parser & builder) { case COMMON_CHAT_FORMAT_KIMI_K2: common_chat_parse_kimi_k2(builder); break; + case COMMON_CHAT_FORMAT_GPT_OSS: + common_chat_parse_gpt_oss(builder); + break; default: throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format)); } @@ -432,6 +444,19 @@ const char* common_chat_format_name(common_chat_format format) { case COMMON_CHAT_FORMAT_GENERIC: return "generic"; case COMMON_CHAT_FORMAT_DEEPSEEK_R1: return "deepseek_r1"; case COMMON_CHAT_FORMAT_KIMI_K2: return "kimi_k2"; + case COMMON_CHAT_FORMAT_GPT_OSS: return "GPT-OSS"; default: return "unknown"; } -} \ No newline at end of file +} + +const char * common_reasoning_format_name(common_reasoning_format format) { + switch (format) { + case COMMON_REASONING_FORMAT_NONE: return "none"; + case COMMON_REASONING_FORMAT_AUTO: return "auto"; + case COMMON_REASONING_FORMAT_DEEPSEEK: return "deepseek"; + case COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY: return "deepseek-legacy"; + default: + throw std::runtime_error("Unknown reasoning format"); + } +} + diff --git a/common/chat.h b/common/chat.h index 5899ef1a1..83e31566d 100644 --- a/common/chat.h +++ b/common/chat.h @@ -13,20 +13,20 @@ struct common_chat_templates; struct common_string_range { size_t begin; size_t end; - + common_string_range(size_t begin, size_t end) : begin(begin), end(end) { if (begin > end) { throw std::runtime_error("Invalid range"); } } - + // prevent default ctor common_string_range() = delete; - + bool empty() const { return begin == end; } - + bool operator==(const common_string_range & other) const { return begin == other.begin && end == other.end; } @@ -40,7 +40,7 @@ struct common_chat_tool_call { bool operator==(const common_chat_tool_call & other) const { return name == other.name && arguments == other.arguments && id == other.id; } - + bool operator!=(const common_chat_tool_call & other) const { return !(*this == other); } @@ -65,10 +65,10 @@ struct common_chat_msg { std::string tool_call_id; bool empty() const { - return content.empty() && content_parts.empty() && tool_calls.empty() && + return content.empty() && content_parts.empty() && tool_calls.empty() && reasoning_content.empty() && tool_name.empty() && tool_call_id.empty(); } - + void ensure_tool_call_ids_set(std::vector & ids_cache, const std::function & gen_tool_call_id) { for (auto i = 0u; i < tool_calls.size(); i++) { if (ids_cache.size() <= i) { @@ -91,7 
+91,7 @@ struct common_chat_msg { && tool_name == other.tool_name && tool_call_id == other.tool_call_id; } - + bool operator!=(const common_chat_msg & other) const { return !(*this == other); } @@ -110,7 +110,7 @@ struct common_chat_msg_diff { && tool_call_index == other.tool_call_index && tool_call_delta == other.tool_call_delta; } - + bool operator!=(const common_chat_msg_diff & other) const { return !(*this == other); } @@ -132,18 +132,20 @@ enum common_chat_format { COMMON_CHAT_FORMAT_CONTENT_ONLY, COMMON_CHAT_FORMAT_GENERIC, COMMON_CHAT_FORMAT_DEEPSEEK_R1, + COMMON_CHAT_FORMAT_GPT_OSS, COMMON_CHAT_FORMAT_KIMI_K2, // Our custom format (keep last for backward compatibility) }; enum common_reasoning_format { COMMON_REASONING_FORMAT_NONE, + COMMON_REASONING_FORMAT_AUTO, COMMON_REASONING_FORMAT_DEEPSEEK, COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY, }; struct common_chat_syntax { common_chat_format format = COMMON_CHAT_FORMAT_KIMI_K2; - common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE; + common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_AUTO; //COMMON_REASONING_FORMAT_NONE; // Whether reasoning_content should be inlined in the content (e.g. for reasoning_format=deepseek in stream mode) bool reasoning_in_content = false; bool thinking_forced_open = false; @@ -165,11 +167,12 @@ class common_chat_msg_partial_exception : public std::runtime_error { // Format detection from chat template common_chat_format common_chat_format_detect(const std::string & chat_template); const char* common_chat_format_name(common_chat_format format); +const char* common_reasoning_format_name(common_reasoning_format format); // Main parsing function (entry point for original llama.cpp compatibility) common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax); -// Forward declare parser class +// Forward declare parser class class common_chat_msg_parser; // Format-specific parsing functions (accessible from chat-parser) diff --git a/examples/sweep-bench/sweep-bench.cpp b/examples/sweep-bench/sweep-bench.cpp index 31dd3ce0e..865f65f06 100644 --- a/examples/sweep-bench/sweep-bench.cpp +++ b/examples/sweep-bench/sweep-bench.cpp @@ -61,7 +61,7 @@ int main(int argc, char ** argv) { const llama_vocab * vocab = llama_get_vocab(ctx); - llama_token bos = llama_token_bos_impl(*vocab); + llama_token bos = vocab->token_bos(); //llama_token eos = llama_token_eos_impl(*vocab); const unsigned int n_vocab = llama_n_vocab(model); diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 2c2613928..d6350f6e2 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -325,6 +325,16 @@ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \ GGML_TENSOR_LOCALS(size_t, nb, dst, nb) +#define GGML_TENSOR_TERNARY_OP_LOCALS \ + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \ + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \ + GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \ + GGML_TENSOR_LOCALS(size_t, nb1, src1, nb) \ + GGML_TENSOR_LOCALS(int64_t, ne2, src2, ne) \ + GGML_TENSOR_LOCALS(size_t, nb2, src2, nb) \ + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \ + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + #define GGML_TENSOR_BINARY_OP_LOCALS01 \ GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \ GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \ @@ -571,6 +581,7 @@ extern "C" { GGML_OP_DUP, GGML_OP_ADD, + GGML_OP_ADD_ID, GGML_OP_ADD1, GGML_OP_ACC, GGML_OP_SUB, @@ -674,6 +685,7 @@ extern "C" { GGML_UNARY_OP_HARDSWISH, GGML_UNARY_OP_HARDSIGMOID, GGML_UNARY_OP_SWIGLU, + GGML_UNARY_OP_SWIGLU_OAI, 
GGML_UNARY_OP_COUNT, }; @@ -1028,6 +1040,13 @@ extern "C" { struct ggml_tensor * b, enum ggml_type type); + // dst[i0, i1, i2] = a[i0, i1, i2] + b[i0, ids[i1, i2]] + GGML_API struct ggml_tensor * ggml_add_id( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * ids); + GGML_API struct ggml_tensor * ggml_add1( struct ggml_context * ctx, struct ggml_tensor * a, @@ -1268,6 +1287,13 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a); + GGML_API struct ggml_tensor * ggml_swiglu_oai( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + float alpha, + float limit); + // a - x // b - dy GGML_API struct ggml_tensor * ggml_silu_back( @@ -1370,6 +1396,16 @@ extern "C" { struct ggml_tensor * ids, enum ggml_unary_op op); + GGML_API struct ggml_tensor * ggml_moe_up_gate_ext( + struct ggml_context * ctx, + struct ggml_tensor * a_up, + struct ggml_tensor * a_gate, + struct ggml_tensor * b, + struct ggml_tensor * ids, + struct ggml_tensor * a_up_b, + struct ggml_tensor * a_gate_b, + enum ggml_unary_op op); + // A: m columns, n rows, // B: p columns, n rows, // result is m columns, p rows @@ -1662,6 +1698,11 @@ extern "C" { float scale, float max_bias); + GGML_API void ggml_soft_max_add_sinks( + struct ggml_tensor * a, + struct ggml_tensor * sinks); + + GGML_API struct ggml_tensor * ggml_soft_max_back( struct ggml_context * ctx, struct ggml_tensor * a, @@ -1998,6 +2039,10 @@ extern "C" { struct ggml_tensor * a, enum ggml_prec prec); + GGML_API void ggml_flash_attn_ext_add_sinks( + struct ggml_tensor * a, + struct ggml_tensor * sinks); + // TODO: needs to be adapted to ggml_flash_attn_ext GGML_API struct ggml_tensor * ggml_flash_attn_back( struct ggml_context * ctx, diff --git a/ggml/src/ggml-alloc.c b/ggml/src/ggml-alloc.c index 67c1ba18f..5bdfa9420 100644 --- a/ggml/src/ggml-alloc.c +++ b/ggml/src/ggml-alloc.c @@ -43,6 +43,7 @@ static bool ggml_op_can_inplace(enum ggml_op op) { case GGML_OP_DIAG_MASK_ZERO: case GGML_OP_DIAG_MASK_INF: case GGML_OP_ADD: + case GGML_OP_ADD_ID: case GGML_OP_ADD1: case GGML_OP_SUB: case GGML_OP_MUL: diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index 9372c05c1..d3868f2ef 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -37,6 +37,7 @@ #include "ggml-cuda/unary.cuh" #include "ggml-cuda/upscale.cuh" #include "ggml-cuda/conv-transpose-1d.cuh" +#include "ggml-cuda/add-id.cuh" #include #include @@ -2220,6 +2221,24 @@ static __global__ void k_copy_dst_from_contiguous(char * __restrict__ dst_origin } } +//static __global__ void k_quick_add(uint32_t n, uint32_t n_per_row, const float * src1, const float * src2, float * dst) { +// +// for (uint32_t j = threadIdx.x; j < n; j += blockDim.x) { +// dst[j] = src1[j] + src2[j % n_per_row]; +// } +//} + +static __global__ void k_quick_add(uint32_t n_per_row, const float * src1, const float * src2, float * dst) { + + uint32_t row = blockIdx.x; + const float * src1_row = src1 + row*n_per_row; + float * dst_row = dst + row*n_per_row; + + for (uint32_t j = threadIdx.x; j < n_per_row; j += blockDim.x) { + dst_row[j] = src1_row[j] + src2[j]; + } +} + static inline bool prepare_row_mappigs(ggml_backend_cuda_context& ctx, int64_t n_as, int64_t n_ids, const ggml_tensor * ids, std::vector& moe_counts, std::vector& cum_moe_counts, ggml_cuda_pool_alloc& dev_row_mapping) { @@ -2270,7 +2289,7 @@ static inline bool prepare_row_mappigs(ggml_backend_cuda_context& ctx, int64_t n return is_ser; } -static void 
ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { +static bool ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * next) { const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; const ggml_tensor * ids = dst->src[2]; @@ -2319,7 +2338,25 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * 0, src0->ne[1], 1, src1_padded_col_size, stream); CUDA_CHECK(cudaGetLastError()); - return; + if (next && next->op == GGML_OP_MUL_MAT_ID && next->src[0]->type == src0->type && src1 == next->src[1] && + ggml_are_same_shape(src0, next->src[0]) && + ggml_backend_buffer_is_cuda(next->src[0]->buffer) && + ggml_backend_buffer_is_cuda(next->buffer) && + !ggml_backend_buffer_is_cuda_split(next->src[0]->buffer)) { + ggml_backend_cuda_buffer_context * next_src0_ctx = (ggml_backend_cuda_buffer_context *) next->src[0]->buffer->context; + ggml_backend_cuda_buffer_context * next_dst_ctx = (ggml_backend_cuda_buffer_context *) next->buffer->context; + if (next_src0_ctx->device == device_id && + next_dst_ctx->device == device_id) { + local_dst.data = next->data; + ggml_cuda_op_mul_mat_vec_q_id(ctx, next->src[0], &local_src1, ids, &local_dst, + (const char *)next->src[0]->data, nullptr, src1_quantized.get(), (float *)next->data, + 0, src0->ne[1], 1, src1_padded_col_size, stream); + CUDA_CHECK(cudaGetLastError()); + return true; + } + } + + return false; } } @@ -2442,6 +2479,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * } } } + return false; } static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * next) { @@ -2470,6 +2508,8 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor src0_2_ctx->device == device_id && src1_ctx->device == device_id && dst_ctx->device == device_id) { + //printf("%s(%s, %s): %ld x %ld x %ld, %ld x %ld x %ld, %ld x %ld x %ld\n", __func__, src0_1->name, src0_2->name, + // src0->ne[0], src0->ne[1], src0->ne[2], src1->ne[0], src1->ne[1], src1->ne[2], ids->ne[0], ids->ne[1], ids->ne[2]); // Fast TG path const int64_t n_ids = ids->ne[0]; auto stream = ctx.stream(device_id, 0); @@ -2505,12 +2545,26 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor 0, src0_1->ne[1], 1, src1_padded_col_size, stream); CUDA_CHECK(cudaGetLastError()); + if (dst->src[4]) { + ggml_cuda_add_id((const float *)local_dst.data, (const float *)dst->src[4]->data, + (const int32_t *)ids->data, (float *)local_dst.data, + local_dst.ne[0], local_dst.ne[2], local_dst.ne[1], local_dst.ne[0], local_dst.ne[2], + local_dst.nb[1], local_dst.nb[2], dst->src[4]->nb[1], ids->nb[2], stream); + } + local_dst.data = dst_gate_contiguous.get(); ggml_cuda_op_mul_mat_vec_q_id(ctx, src0_2, &local_src1, ids, &local_dst, (const char *)src0_2->data, (const float *)src1->data, src1_quantized.get(), (float *)dst_gate_contiguous.get(), 0, src0_2->ne[1], 1, src1_padded_col_size, stream); CUDA_CHECK(cudaGetLastError()); + if (dst->src[5]) { + ggml_cuda_add_id((const float *)local_dst.data, (const float *)dst->src[5]->data, + (const int32_t *)ids->data, (float *)local_dst.data, + local_dst.ne[0], local_dst.ne[2], local_dst.ne[1], local_dst.ne[0], local_dst.ne[2], + local_dst.nb[1], local_dst.nb[2], dst->src[5]->nb[1], ids->nb[2], stream); + } + if (next && next->op == GGML_OP_MUL_MAT_ID && ggml_is_quantized(next->src[0]->type) && ggml_backend_buffer_is_cuda(next->src[0]->buffer) && 
!ggml_backend_buffer_is_cuda_split(next->src[0]->buffer) && @@ -2518,8 +2572,15 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor ggml_backend_buffer_is_cuda(next->buffer) && ((ggml_backend_cuda_buffer_context *)next->buffer->context)->device == device_id) { - ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], dst->ne[0]*n_ids, - (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)dst_gate_contiguous.get()); + auto unary_op = (ggml_unary_op)dst->op_params[0]; + if (unary_op == GGML_UNARY_OP_SWIGLU_OAI) { + ggml_swiglu_oai_cuda_f32((const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), + (float *)dst_gate_contiguous.get(), dst->ne[0]*n_ids, dst->ne[0], dst->ne[0], dst->ne[0], 1.702f, 7.0f, stream); + } else { + ggml_fused_mul_unary(ctx, unary_op, dst->ne[0]*n_ids, + (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), + (float *)dst_gate_contiguous.get()); + } CUDA_CHECK(cudaGetLastError()); const int64_t dst_padded_col_size = GGML_PAD(dst->ne[0], MATRIX_ROW_PADDING); @@ -2555,8 +2616,14 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor return true; } else { CUDA_CHECK(cudaMemsetAsync(dst->data, 0, ggml_nbytes(dst), stream)); - ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], ggml_nelements(dst), - (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)dst->data); + auto unary_op = (ggml_unary_op)dst->op_params[0]; + if (unary_op == GGML_UNARY_OP_SWIGLU_OAI) { + ggml_swiglu_oai_cuda_f32((const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), + (float *)dst->data, dst->ne[0]*n_ids, dst->ne[0], dst->ne[0], dst->ne[0], 1.702f, 7.0f, stream); + } else { + ggml_fused_mul_unary(ctx, unary_op, ggml_nelements(dst), + (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)dst->data); + } CUDA_CHECK(cudaGetLastError()); return false; } @@ -2624,7 +2691,7 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor final_src.nb[3] = final_src.nb[2]; } - if (ne12 == 1) { + if (false && ne12 == 1) { ggml_cuda_pool_alloc dst_up_contiguous(ctx.pool(), sizeof(float)*dst_row.ne[0]); ggml_cuda_pool_alloc dst_gate_contiguous(ctx.pool(), sizeof(float)*dst_row.ne[0]); if (fuse_down) { @@ -2761,6 +2828,14 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor } CUDA_CHECK(cudaGetLastError()); + if (dst->src[4]) { + dim3 block_dims(std::min(uint32_t(dst_row.ne[0]), 768u)); + dim3 grid_dims(num_src1_rows); + k_quick_add<<>>(dst_row.ne[0], (const float *)dst_row.data, + (const float *)((const char *)dst->src[4]->data + i02*dst->src[4]->nb[1]), (float *)dst_row.data); + CUDA_CHECK(cudaGetLastError()); + } + dst_row.data = dst_gate_contiguous.get(); if (use_quantized_src1) { ggml_cuda_op_mul_mat_q(ctx, &src0_2_row, &src1_row, &dst_row, (const char *)src0_2_row.data, nullptr, src1_quantized.get(), (float *)dst_row.data, @@ -2770,8 +2845,24 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor } CUDA_CHECK(cudaGetLastError()); - ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], ggml_nelements(&dst_row), - (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)dst_gate_contiguous.get()); + if (dst->src[5]) { + dim3 block_dims(std::min(uint32_t(dst_row.ne[0]), 768u)); + dim3 grid_dims(num_src1_rows); + 
k_quick_add<<>>(dst_row.ne[0], (const float *)dst_row.data, + (const float *)((const char *)dst->src[5]->data + i02*dst->src[5]->nb[1]), (float *)dst_row.data); + CUDA_CHECK(cudaGetLastError()); + } + + auto unary_op = (ggml_unary_op)dst->op_params[0]; + if (unary_op == GGML_UNARY_OP_SWIGLU_OAI) { + ggml_swiglu_oai_cuda_f32((const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), + (float *)dst_gate_contiguous.get(), ggml_nelements(&dst_row), dst_row.ne[0], dst_row.ne[0], dst_row.ne[0], + 1.702f, 7.0f, stream); + } else { + ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], ggml_nelements(&dst_row), + (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), + (float *)dst_gate_contiguous.get()); + } CUDA_CHECK(cudaGetLastError()); if (fuse_down) { @@ -2851,6 +2942,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_ADD: ggml_cuda_op_add(ctx, dst); break; + case GGML_OP_ADD_ID: + ggml_cuda_op_add_id(ctx, dst); + break; case GGML_OP_MULTI_ADD: ggml_cuda_op_multi_add(ctx, dst); break; @@ -2877,6 +2971,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_UNARY_OP_SWIGLU: ggml_cuda_op_swiglu(ctx, dst); break; + case GGML_UNARY_OP_SWIGLU_OAI: + ggml_cuda_op_swiglu_oai(ctx, dst); + break; case GGML_UNARY_OP_GELU_QUICK: ggml_cuda_op_gelu_quick(ctx, dst); break; @@ -2938,7 +3035,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg } break; case GGML_OP_MUL_MAT_ID: - ggml_cuda_mul_mat_id(ctx, dst); + skip_next = ggml_cuda_mul_mat_id(ctx, dst, next); break; case GGML_OP_MOE_FUSED_UP_GATE: skip_next = ggml_cuda_up_gate_unary(ctx, dst, next); @@ -3440,6 +3537,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons case GGML_UNARY_OP_GELU: case GGML_UNARY_OP_SILU: case GGML_UNARY_OP_SWIGLU: + case GGML_UNARY_OP_SWIGLU_OAI: case GGML_UNARY_OP_RELU: case GGML_UNARY_OP_SIGMOID: case GGML_UNARY_OP_HARDSIGMOID: @@ -3629,6 +3727,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons case GGML_OP_PERMUTE: case GGML_OP_TRANSPOSE: case GGML_OP_ADD: + case GGML_OP_ADD_ID: case GGML_OP_MULTI_ADD: case GGML_OP_MUL: case GGML_OP_DIV: diff --git a/ggml/src/ggml-cuda/add-id.cu b/ggml/src/ggml-cuda/add-id.cu new file mode 100644 index 000000000..34e66954b --- /dev/null +++ b/ggml/src/ggml-cuda/add-id.cu @@ -0,0 +1,72 @@ +#include "add-id.cuh" + +static __global__ void add_id_kernel( + const float * src0, const float * src1, const int32_t * src2, float * dst, + int64_t ne0, int64_t ne1, + size_t nb01, size_t nb02, + size_t nb11, + size_t nb21 + ) { + + const int64_t i1 = blockIdx.x; + const int64_t i2 = blockIdx.y; + + const int i11 = *(int32_t *) ((char *) src2 + i1*sizeof(int32_t) + i2*nb21); + + const size_t nb1 = ne0 * sizeof(float); + const size_t nb2 = ne1 * nb1; + + float * dst_row = (float *)((char *)dst + i1*nb1 + i2*nb2); + const float * src0_row = (const float *)((char *)src0 + i1*nb01 + i2*nb02); + const float * src1_row = (const float *)((char *)src1 + i11*nb11); + + for (int64_t i0 = threadIdx.x; i0 < ne0; i0 += blockDim.x) { + dst_row[i0] = src0_row[i0] + src1_row[i0]; + } +} + +void ggml_cuda_op_add_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + const ggml_tensor * src2 = dst->src[2]; + + GGML_TENSOR_TERNARY_OP_LOCALS + + GGML_ASSERT(dst->type == GGML_TYPE_F32); + 
GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(src2->type == GGML_TYPE_I32); + + GGML_ASSERT(nb00 == sizeof(float)); + GGML_ASSERT(nb10 == sizeof(float)); + GGML_ASSERT(nb20 == sizeof(int32_t)); + + const float * src0_d = (const float *)src0->data; + const float * src1_d = (const float *)src1->data; + const int32_t * src2_d = (const int32_t *)src2->data; + float * dst_d = (float *)dst->data; + + int threads = std::min((int)ne00, 768); // cols + dim3 blocks(ne01, ne02); // n_experts_used, n_tokens + add_id_kernel<<>>( + src0_d, src1_d, src2_d, dst_d, + ne0, ne1, + nb01, nb02, + nb11, + nb21 + ); +} + +void ggml_cuda_add_id(const float * src0, const float * src1, const int32_t * src2, float * dst, + int64_t ne00, int64_t ne01, int64_t ne02, + int64_t ne0, int64_t ne1, size_t nb01, size_t nb02, size_t nb11, size_t nb21, cudaStream_t stream) { + int threads = std::min((int)ne00, 768); // cols + dim3 blocks(ne01, ne02); // n_experts_used, n_tokens + add_id_kernel<<>>( + src0, src1, src2, dst, + ne0, ne1, + nb01, nb02, + nb11, + nb21 + ); +} diff --git a/ggml/src/ggml-cuda/add-id.cuh b/ggml/src/ggml-cuda/add-id.cuh new file mode 100644 index 000000000..175d6800e --- /dev/null +++ b/ggml/src/ggml-cuda/add-id.cuh @@ -0,0 +1,8 @@ +#include "common.cuh" + +void ggml_cuda_op_add_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_add_id(const float * src0, const float * src1, const int32_t * src2, float * dst, + int64_t ne00, int64_t ne01, int64_t ne02, + int64_t ne0, int64_t ne1, size_t nb01, size_t nb02, size_t nb11, size_t nb21, cudaStream_t stream); + diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index c856a44b6..55088c335 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -108,6 +108,23 @@ static const char * cu_get_error_str(CUresult err) { #define CU_CHECK(err) CUDA_CHECK_GEN(err, CUDA_SUCCESS, cu_get_error_str) #endif +#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA) +# define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes) \ + do { \ + static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = { false }; \ + const int id = ggml_cuda_get_device(); \ + if (!shared_memory_limit_raised[id]) { \ + CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes)); \ + shared_memory_limit_raised[id] = true; \ + } \ + } while (0) +#else +# define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes) \ + do { \ + GGML_UNUSED(nbytes); \ + } while (0) +#endif // !(defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + #if CUDART_VERSION >= 11100 || defined(GGML_USE_MUSA) #define GGML_CUDA_ASSUME(x) __builtin_assume(x) #else diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index 5e8ee0f63..8e7fadba2 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -22,6 +22,7 @@ typedef void (* fattn_kernel_t)( const char * __restrict__ K, const char * __restrict__ V, const char * __restrict__ mask, + const char * __restrict__ sinks, float * __restrict__ dst, float2 * __restrict__ dst_meta, const float scale, @@ -747,6 +748,7 @@ void launch_fattn( const ggml_tensor * V = dst->src[2]; const ggml_tensor * mask = dst->src[3]; + const ggml_tensor * sinks = dst->src[4]; ggml_tensor * KQV = dst; @@ -837,6 +839,7 @@ void launch_fattn( K_data, V_data, mask ? ((const char *) mask->data) : nullptr, + sinks ? ((const char *) sinks->data) : nullptr, (parallel_blocks) == 1 ? 
(float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, scale, max_bias, m0, m1, softcap, n_head_log2, Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], @@ -1008,7 +1011,8 @@ void launch_fattn_mma( const ggml_tensor * K = dst->src[1]; const ggml_tensor * V = dst->src[2]; - const ggml_tensor * mask = dst->src[3]; + const ggml_tensor * mask = dst->src[3]; + const ggml_tensor * sinks = dst->src[4]; ggml_tensor * KQV = dst; @@ -1162,6 +1166,7 @@ void launch_fattn_mma( K_data, V_data, mask ? ((const char *) mask->data) : nullptr, + sinks ? ((const char *)sinks->data) : nullptr, !stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr, scale, max_bias, m0, m1, n_head_log2, logit_softcap, Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index af38071d2..14444832c 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -425,6 +425,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const half2 * const __restrict__ K_h2, const half2 * const __restrict__ V_h2, const half2 * const __restrict__ mask_h2, + const float * const __restrict__ sinks_f, float2 * const __restrict__ dstk, float2 * const __restrict__ dstk_fixup, const float scale, @@ -584,6 +585,52 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( } } + // If attention sinks are used, potentially re-scale if KQ_max is small. + // Also add the sink as a value to KQ_rowsum, this is done after synchonization of KQ_rowsum + // so it's being done unconditionally for every thread. + if (!is_fixup && (np == 1 || threadIdx.y % np == 0) && sinks_f) { + float KQ_max_scale[cols_per_thread]; +#pragma unroll + for (int col = 0; col < cols_per_thread; ++col) { + static_assert(ntiles == 1 || ntiles == 2, "ntiles > 2 not implemented"); + const int jc = ntiles == 1 ? 2*tile_C_VKQ::get_j(col/2) + col % 2 : tile_C_VKQ_16::get_i(col); + const float sink = sinks_f[jc % ncols2]; + + const float KQ_max_new = fmaxf(KQ_max[col], sink); + const float KQ_max_diff = KQ_max[col] - KQ_max_new; + KQ_max_scale[col] = expf(KQ_max_diff); + KQ_max[col] = KQ_max_new; + + *((uint32_t *) &KQ_max_scale[col]) *= KQ_max_diff >= SOFTMAX_FTZ_THRESHOLD; + + const float KQ_max_add = expf(sink - KQ_max_new); + KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_max_add; + } + + if (ntiles == 1) { + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]); +#pragma unroll + for (int i = 0; i < D/tile_C_VKQ::I; ++i) { +#pragma unroll + for (int l = 0; l < tile_C_VKQ::ne; ++l) { + VKQ_C[i].x[l] *= KQ_max_scale_h2; + } + } + } else { +#pragma unroll + for (int col = 0; col < cols_per_thread; ++col) { + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]); +#pragma unroll + for (int i = 0; i < D/tile_C_VKQ_16::J; ++i) { +#pragma unroll + for (int l0 = 0; l0 < tile_C_VKQ_16::ne; l0 += 2) { + VKQ_C_16[i*ntiles/2 + col/2].x[l0 + col % 2] *= KQ_max_scale_h2; + } + } + } + } + } + // Write VKQ accumulators to shared memory in column-major format. // It's faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM. // Also for np > 1 the combination is done via these values in shared memory. 
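The block added above folds a per-head "attention sink" into the streaming softmax: the sink acts as one extra logit that contributes to the normalizer (KQ_rowsum) but has no value vector, so if it exceeds the current running maximum the existing accumulators are rescaled by exp(old_max - new_max) and exp(sink - new_max) is added to the row sum. A minimal scalar sketch of the same idea, using illustrative names rather than the kernel's tile layout:

```cpp
// Scalar sketch only (not the CUDA kernel): one query row, 1-D values for brevity.
#include <algorithm>
#include <cmath>
#include <vector>

float attend_with_sink(const std::vector<float> & logits, // scaled/masked KQ values
                       const std::vector<float> & v,      // matching value contributions
                       float sink) {                      // per-head sink logit
    float running_max = -INFINITY;
    float running_sum = 0.0f;   // plays the role of KQ_rowsum
    float acc         = 0.0f;   // plays the role of the VKQ accumulator

    for (size_t i = 0; i < logits.size(); ++i) {
        const float new_max = std::max(running_max, logits[i]);
        const float scale   = std::exp(running_max - new_max); // rescale old state
        const float p       = std::exp(logits[i] - new_max);
        running_sum = scale*running_sum + p;
        acc         = scale*acc + p*v[i];
        running_max = new_max;
    }

    // Sink handling, done once after the K/V loop as in the block above:
    // it may raise the running max (so rescale), and it adds probability mass
    // to the denominator only - there is no value vector for the sink.
    const float new_max = std::max(running_max, sink);
    const float scale   = std::exp(running_max - new_max);
    running_sum = scale*running_sum + std::exp(sink - new_max);
    acc        *= scale;

    return acc / running_sum; // final normalization, as in the combine step
}
```

The kernel does the same thing once per tile: KQ_max_scale rescales the VKQ accumulators, and exp(sink - KQ_max) is added to KQ_rowsum before the final normalization.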
@@ -823,6 +870,7 @@ static __global__ void flash_attn_mma_ext_f16( const char * __restrict__ K, const char * __restrict__ V, const char * __restrict__ mask, + const char * __restrict__ sinks, float * __restrict__ dst, float2 * __restrict__ dst_meta, const float scale, @@ -896,6 +944,7 @@ static __global__ void flash_attn_mma_ext_f16( const half2 * V_h2 = (const half2 *) (V + nb12*(channel*ncols2 / gqa_ratio)); // K and V have same shape const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr; float2 * dstk = ((float2 *) dst) + channel*(ncols2 * D/2); + const float * sinks_f = sinks ? (const float *) sinks + channel * ncols2 : nullptr; const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f; @@ -906,12 +955,12 @@ static __global__ void flash_attn_mma_ext_f16( if (kb0_start == 0) { constexpr bool needs_fixup = false; // CUDA block is working on an entire tile. flash_attn_ext_f16_process_tile - (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap, + (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); } else { constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile. flash_attn_ext_f16_process_tile - (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap, + (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); } @@ -934,6 +983,7 @@ static __global__ void flash_attn_mma_ext_f16( const half2 * V_h2 = (const half2 *) (V + nb12*(channel*ncols2 / gqa_ratio)); // K and V have same shape const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr; float2 * dstk = ((float2 *) dst) + channel*(ncols2 * D/2); + const float * sinks_f = sinks ? (const float *) sinks + channel*ncols2 : nullptr; const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f; @@ -943,10 +993,10 @@ static __global__ void flash_attn_mma_ext_f16( constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks. constexpr bool needs_fixup = false; flash_attn_ext_f16_process_tile - (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap, + (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); #else - GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); + GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks); GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1); GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00); diff --git a/ggml/src/ggml-cuda/fattn-new-mma.cu b/ggml/src/ggml-cuda/fattn-new-mma.cu index 6178b3e5a..c37a618ff 100644 --- a/ggml/src/ggml-cuda/fattn-new-mma.cu +++ b/ggml/src/ggml-cuda/fattn-new-mma.cu @@ -43,37 +43,37 @@ struct fattn_mma_f16_config; // Perhaps the 256 head size needs a closer look // to see if this implementation is better. 
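Before the head-size configs below, a brief aside on the GGML_OP_ADD_ID operator introduced earlier in this patch. Its declared semantics in ggml.h are dst[i0, i1, i2] = a[i0, i1, i2] + b[i0, ids[i1, i2]]; the MoE paths use it to add a per-expert bias row selected by the routed expert ids. A minimal host-side sketch of those semantics, with illustrative contiguous indexing (the real op works on ggml byte strides nb01/nb02/nb11/nb21):

```cpp
// Reference semantics of add_id on contiguous float data (sketch, not the CUDA kernel).
// a:   [ne0, ne1, ne2]   activations (ne1 = experts used per token, ne2 = tokens)
// b:   [ne0, n_expert]   per-expert bias rows
// ids: [ne1, ne2]        expert index chosen for each (slot, token)
#include <cstdint>
#include <vector>

void add_id_ref(const std::vector<float> & a, const std::vector<float> & b,
                const std::vector<int32_t> & ids, std::vector<float> & dst,
                int64_t ne0, int64_t ne1, int64_t ne2) {
    for (int64_t i2 = 0; i2 < ne2; ++i2) {             // token
        for (int64_t i1 = 0; i1 < ne1; ++i1) {         // expert slot
            const int32_t expert = ids[i2*ne1 + i1];   // which expert's bias to add
            for (int64_t i0 = 0; i0 < ne0; ++i0) {     // feature
                dst[(i2*ne1 + i1)*ne0 + i0] =
                    a[(i2*ne1 + i1)*ne0 + i0] + b[expert*ne0 + i0];
            }
        }
    }
}
```

The CUDA kernel added earlier (add-id.cu) parallelizes exactly this loop nest: one block per (i1, i2) pair, with threads striding over i0.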
// -//template <> -//struct fattn_mma_f16_config< 64, 64> { -// static constexpr int nbatch_fa = 64; -// static constexpr int nwarps_max = 4; -// static constexpr bool Q_in_reg = true; -// static constexpr int nstages_target = 2; -// -// static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { -// return 32; -// } -// -// static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) { -// return 32; -// } -// -// static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) { -// return 32; -// } -// -// static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) { -// return 32; -// } -// -// static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) { -// return 32; -// } -// -// static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) { -// return 32; -// } -//}; +template <> +struct fattn_mma_f16_config< 64, 64> { + static constexpr int nbatch_fa = 64; + static constexpr int nwarps_max = 4; + static constexpr bool Q_in_reg = true; + static constexpr int nstages_target = 2; + + static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { + return 32; + } + + static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) { + return 32; + } + + static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) { + return 32; + } + + static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) { + return 32; + } + + static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) { + return 32; + } + + static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) { + return 32; + } +}; // //template <> //struct fattn_mma_f16_config< 80, 80> { @@ -493,7 +493,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( (V_h2 + k_VKQ_0*stride_V, tile_V, nbatch_V2, stride_V); } else { constexpr bool use_cp_async = nstages == 1; - if constexpr (ncols2 > 1 || mask_h2) { + if (ncols2 > 1 || mask_h2) { flash_attn_ext_f16_load_mask(mask_h2 + k_VKQ_0/2, tile_mask, stride_mask); } } @@ -576,7 +576,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( float KQ_rowsum_add[cols_per_thread] = {0.0f}; if constexpr (ntiles == 1) { - if constexpr (ncols2 > 1 || mask_h2) { + if (ncols2 > 1 || mask_h2) { #pragma unroll for (int i00 = 0; i00 < c::nbatch_fa; i00 += np*tile_C_KQ::I) { const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ::I; @@ -818,6 +818,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const half2 * const __restrict__ K_h2, const half2 * const __restrict__ V_h2, const half2 * const __restrict__ mask_h2, + const float * const __restrict__ sinks_f, float2 * const __restrict__ dstk, float2 * const __restrict__ dstk_fixup, const float scale, @@ -975,6 +976,52 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( __syncthreads(); } + // If attention sinks are used, potentially re-scale if KQ_max is small. + // Also add the sink as a value to KQ_rowsum, this is done after synchonization of KQ_rowsum + // so it's being done unconditionally for every thread. + if (!is_fixup && (np == 1 || threadIdx.y % np == 0) && sinks_f) { + float KQ_max_scale[cols_per_thread]; +#pragma unroll + for (int col = 0; col < cols_per_thread; ++col) { + static_assert(ntiles == 1 || ntiles == 2, "ntiles > 2 not implemented"); + const int jc = ntiles == 1 ? 
2*tile_C_VKQ::get_j(col/2) + col % 2 : tile_C_VKQ_16::get_i(col); + const float sink = sinks_f[jc % ncols2]; + + const float KQ_max_new = fmaxf(KQ_max[col], sink); + const float KQ_max_diff = KQ_max[col] - KQ_max_new; + KQ_max_scale[col] = expf(KQ_max_diff); + KQ_max[col] = KQ_max_new; + + *((uint32_t *) &KQ_max_scale[col]) *= KQ_max_diff >= SOFTMAX_FTZ_THRESHOLD; + + const float KQ_max_add = expf(sink - KQ_max_new); + KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_max_add; + } + + if (ntiles == 1) { + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]); +#pragma unroll + for (int i = 0; i < DV/tile_C_VKQ::I; ++i) { +#pragma unroll + for (int l = 0; l < tile_C_VKQ::ne; ++l) { + VKQ_C[i].x[l] *= KQ_max_scale_h2; + } + } + } else { +#pragma unroll + for (int col = 0; col < cols_per_thread; ++col) { + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]); +#pragma unroll + for (int i = 0; i < DV/tile_C_VKQ_16::J; ++i) { +#pragma unroll + for (int l0 = 0; l0 < tile_C_VKQ_16::ne; l0 += 2) { + VKQ_C_16[i*ntiles/2 + col/2].x[l0 + col % 2] *= KQ_max_scale_h2; + } + } + } + } + } + // Finally, sum up partial KQ rowsums. // The partial sums are spread across 8/4 threads each, does not need full reduce. { @@ -1222,7 +1269,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( } } #else - GGML_UNUSED(Q_f2); GGML_UNUSED(K_h2); GGML_UNUSED(V_h2); + GGML_UNUSED(Q_f2); GGML_UNUSED(K_h2); GGML_UNUSED(V_h2); GGML_UNUSED(sinks_f); GGML_UNUSED(mask_h2); GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup); GGML_UNUSED(scale); GGML_UNUSED(slope); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_Q1); @@ -1239,6 +1286,7 @@ static __global__ void flash_attn_ext_f16( const char * __restrict__ K, const char * __restrict__ V, const char * __restrict__ mask, + const char * __restrict__ sinks, float * __restrict__ dst, float2 * __restrict__ dst_meta, const float scale, @@ -1323,6 +1371,7 @@ static __global__ void flash_attn_ext_f16( const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio)); const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr; float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2); + const float * sinks_f = sinks ? (const float *) sinks + channel*ncols2 : nullptr; const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio)); @@ -1335,12 +1384,12 @@ static __global__ void flash_attn_ext_f16( if (kb0_start == 0) { constexpr bool needs_fixup = false; // CUDA block is working on an entire tile. flash_attn_ext_f16_process_tile - (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap, + (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); } else { constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile. flash_attn_ext_f16_process_tile - (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap, + (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); } @@ -1362,6 +1411,7 @@ static __global__ void flash_attn_ext_f16( const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio)); const half2 * mask_h2 = ncols2 > 1 || mask ? 
(const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr; float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2); + const float * sinks_f = sinks ? (const float *) sinks + channel*ncols2 : nullptr; const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio)); @@ -1373,7 +1423,7 @@ static __global__ void flash_attn_ext_f16( constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks. constexpr bool needs_fixup = false; flash_attn_ext_f16_process_tile - (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap, + (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); #else GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); @@ -1535,7 +1585,8 @@ static void launch_fattn_new_mma( const ggml_tensor * K = dst->src[1]; const ggml_tensor * V = dst->src[2]; - const ggml_tensor * mask = dst->src[3]; + const ggml_tensor * mask = dst->src[3]; + const ggml_tensor * sinks = dst->src[4]; ggml_tensor * KQV = dst; @@ -1709,6 +1760,7 @@ static void launch_fattn_new_mma( K_data, V_data, mask ? ((const char *) mask->data) : nullptr, + sinks ? ((const char *)sinks->data) : nullptr, !stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr, scale, max_bias, m0, m1, logit_softcap, n_head_log2, Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], @@ -1853,6 +1905,11 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_con GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); const int gqa_ratio = Q->ne[2] / K->ne[2]; + if (use_gqa_opt && gqa_ratio % 16 == 0) { + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst); + return; + } + if (use_gqa_opt && gqa_ratio % 8 == 0) { ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst); return; @@ -1878,8 +1935,6 @@ void ggml_cuda_flash_attn_ext_mma_new(ggml_backend_cuda_context & ctx, ggml_tens const ggml_tensor * V = dst->src[2]; const ggml_tensor * mask = dst->src[3]; - GGML_ASSERT(Q->ne[0] == 576 && K->ne[0] == 576 && V->ne[0] == 512); - float max_bias = 0.0f; memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float)); @@ -1888,6 +1943,12 @@ void ggml_cuda_flash_attn_ext_mma_new(ggml_backend_cuda_context & ctx, ggml_tens GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); const int gqa_ratio = Q->ne[2] / K->ne[2]; + + if (K->ne[0] == 64 && V->ne[0] == 64) { + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<64, 64>(ctx, dst); + return; + } + GGML_ASSERT(Q->ne[0] == 576 && K->ne[0] == 576 && V->ne[0] == 512); GGML_ASSERT(gqa_ratio % 16 == 0); ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst); diff --git a/ggml/src/ggml-cuda/fattn-tile-f16.cu b/ggml/src/ggml-cuda/fattn-tile-f16.cu index 420f0bb08..c79b4821d 100644 --- a/ggml/src/ggml-cuda/fattn-tile-f16.cu +++ b/ggml/src/ggml-cuda/fattn-tile-f16.cu @@ -20,6 +20,7 @@ static __global__ void flash_attn_tile_ext_f16( const char * __restrict__ K, const char * __restrict__ V, const char * __restrict__ mask, + const char * __restrict__ sinks, float * __restrict__ dst, float2 * __restrict__ dst_meta, const float scale, diff --git a/ggml/src/ggml-cuda/fattn-tile-f32.cu b/ggml/src/ggml-cuda/fattn-tile-f32.cu index f525f1bbf..3d1926ce4 100644 --- a/ggml/src/ggml-cuda/fattn-tile-f32.cu +++ b/ggml/src/ggml-cuda/fattn-tile-f32.cu @@ -20,6 +20,7 @@ static __global__ void flash_attn_tile_ext_f32( const char * 
__restrict__ K, const char * __restrict__ V, const char * __restrict__ mask, + const char * __restrict__ sinks, float * __restrict__ dst, float2 * __restrict__ dst_meta, const float scale, diff --git a/ggml/src/ggml-cuda/fattn-vec-f16.cuh b/ggml/src/ggml-cuda/fattn-vec-f16.cuh index 2cf4f4ef9..95dd0e963 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f16.cuh @@ -17,6 +17,7 @@ static __global__ void flash_attn_vec_ext_f16( const char * __restrict__ K, const char * __restrict__ V, const char * __restrict__ mask, + const char * __restrict__ sinks, float * __restrict__ dst, float2 * __restrict__ dst_meta, const float scale, @@ -71,6 +72,7 @@ static __global__ void flash_attn_vec_ext_f16( V += nb22*(blockIdx.y / gqa_ratio); const half * maskh = (const half *) mask + ne11*ic0; + const float * sinksf = (const float *) (sinks); const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1); const half slopeh = __float2half(slopef); @@ -270,6 +272,39 @@ static __global__ void flash_attn_vec_ext_f16( __syncthreads(); } + if (sinksf) { + const half sink = __float2half(sinksf[blockIdx.y]); + +#pragma unroll + for (int j = 0; j < ncols; ++j) { + if (threadIdx.x == 0) { + kqmax_shared[j][threadIdx.y] = fmaxf(kqmax[j], sink); + } + } + + __syncthreads(); + +#pragma unroll + for (int j = 0; j < ncols; ++j) { + half kqmax_new_j = kqmax_shared[j][threadIdx.y]; + kqmax_new_j = warp_reduce_max(kqmax_new_j); + + const half KQ_max_scale = hexp(kqmax[j] - kqmax_new_j); + kqmax[j] = kqmax_new_j; + + const half val = hexp(sink - kqmax[j]); + kqsum[j] = kqsum[j]*KQ_max_scale; + + if (tid == 0) { + kqsum[j] += val; + } + + VKQ[j] *= __half2half2(KQ_max_scale); + } + + __syncthreads(); + } + #pragma unroll for (int j = 0; j < ncols; ++j) { kqsum[j] = warp_reduce_sum(kqsum[j]); diff --git a/ggml/src/ggml-cuda/fattn-vec-f32.cuh b/ggml/src/ggml-cuda/fattn-vec-f32.cuh index c91cef3dc..a97b3737d 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f32.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f32.cuh @@ -17,6 +17,7 @@ static __global__ void flash_attn_vec_ext_f32( const char * __restrict__ K, const char * __restrict__ V, const char * __restrict__ mask, + const char * __restrict__ sinks, float * __restrict__ dst, float2 * __restrict__ dst_meta, const float scale, @@ -69,6 +70,7 @@ static __global__ void flash_attn_vec_ext_f32( K += nb12*(blockIdx.y / gqa_ratio); V += nb22*(blockIdx.y / gqa_ratio); // K and V have same shape const half * maskh = (const half *) mask + ne11*ic0; + const float * sinksf = (const float *) (sinks); const float slope = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1); @@ -254,6 +256,39 @@ static __global__ void flash_attn_vec_ext_f32( __syncthreads(); } + if (sinksf) { + const float sink = sinksf[blockIdx.y]; + +#pragma unroll + for (int j = 0; j < ncols; ++j) { + if (threadIdx.x == 0) { + kqmax_shared[j][threadIdx.y] = fmaxf(kqmax[j], sink); + } + } + + __syncthreads(); + +#pragma unroll + for (int j = 0; j < ncols; ++j) { + float kqmax_new_j = kqmax_shared[j][threadIdx.y]; + kqmax_new_j = warp_reduce_max(kqmax_new_j); + + const float KQ_max_scale = expf(kqmax[j] - kqmax_new_j); + kqmax[j] = kqmax_new_j; + + const float val = expf(sink - kqmax[j]); + kqsum[j] = kqsum[j]*KQ_max_scale; + + if (tid == 0) { + kqsum[j] += val; + } + + VKQ[j] *= KQ_max_scale; + } + + __syncthreads(); + } + #pragma unroll for (int j = 0; j < ncols; ++j) { kqsum[j] = warp_reduce_sum(kqsum[j]); diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cuh 
b/ggml/src/ggml-cuda/fattn-wmma-f16.cuh index d39c6a6e5..e1e2ec6ed 100644 --- a/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cuh @@ -5,6 +5,8 @@ // SPDX-License-Identifier: MIT // +// TODO: attention sinks !!! + #include "common.cuh" #include "fattn-common.cuh" @@ -22,6 +24,7 @@ static __global__ void flash_attn_ext_f16( const char * __restrict__ K, const char * __restrict__ V, const char * __restrict__ mask, + const char * __restrict__ sinks, float * __restrict__ dst, float2 * __restrict__ dst_meta, const float scale, @@ -93,6 +96,7 @@ static __global__ void flash_attn_ext_f16( const half * V_h = (const half *) (V + nb22*(blockIdx.y / gqa_ratio)); // K and V have same shape const half * maskh = (const half *) mask + (nb31/sizeof(half))* ic0; const half2 * mask2 = (const half2 *) mask + (nb31/sizeof(half))*(ic0/2); + const float * sinks_f = sinks ? (const float *)sinks + blockIdx.y : nullptr; const int stride_Q = nb01 / sizeof(float); const int stride_K = nb11 / sizeof(half); diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 725b443dd..ffcaf219b 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -539,7 +539,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst return; } - // As mentioned above, the new new MMA is slower than then the new MMA. + // As mentioned above, the new-new MMA is slower than the new MMA. ggml_cuda_flash_attn_ext_mma_f16(ctx, dst); //ggml_cuda_flash_attn_ext_mma_new(ctx, dst); } diff --git a/ggml/src/ggml-cuda/softmax.cu b/ggml/src/ggml-cuda/softmax.cu index c006301fa..4c9aa7edf 100644 --- a/ggml/src/ggml-cuda/softmax.cu +++ b/ggml/src/ggml-cuda/softmax.cu @@ -19,7 +19,7 @@ __device__ float __forceinline__ t2f32(half val) { } template -static __global__ void soft_max_f32(const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2, float cap_params0, float cap_params1, bool do_softcap) { +static __global__ void soft_max_f32_nosinks(const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2, float cap_params0, float cap_params1, bool do_softcap) { const int ncols = ncols_template == 0 ?
ncols_par : ncols_template; const int tid = threadIdx.x; @@ -124,7 +124,7 @@ static __global__ void soft_max_f32(const float * x, const T * mask, float * dst } template -static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, float cap_params0, float cap_params1, bool do_softcap, cudaStream_t stream) { +static void soft_max_f32_cuda_nosinks(const float * x, const T * mask, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, float cap_params0, float cap_params1, bool do_softcap, cudaStream_t stream) { int nth = WARP_SIZE; while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2; const dim3 block_dims(nth, 1, 1); @@ -142,39 +142,40 @@ static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, cons if (shmem < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) { switch (ncols_x) { case 32: - soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap); + soft_max_f32_nosinks<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap); break; case 64: - soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap); + soft_max_f32_nosinks<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap); break; case 128: - soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap); + soft_max_f32_nosinks<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap); break; case 256: - soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap); + soft_max_f32_nosinks<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap); break; case 512: - soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap); + soft_max_f32_nosinks<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap); break; case 1024: - soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap); + soft_max_f32_nosinks<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap); break; case 2048: - soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap); + soft_max_f32_nosinks<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap); break; case 4096: - soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap); + soft_max_f32_nosinks<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap); break; default: - soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap); + soft_max_f32_nosinks<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap); break; } } else { const size_t 
shmem_low = WARP_SIZE*sizeof(float); - soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap); + soft_max_f32_nosinks<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap); } } +#if 0 void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; @@ -205,13 +206,14 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { if (use_f16) { const half * src1_dd = (const half *)src1_d; - soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, 0, 0, false, stream); + soft_max_f32_cuda_nosinks(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, 0, 0, false, stream); } else { const float * src1_dd = (const float *)src1_d; - soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, 0, 0, false, stream); + soft_max_f32_cuda_nosinks(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, 0, 0, false, stream); } } +#endif void ggml_cuda_op_soft_cap_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; @@ -241,10 +243,283 @@ void ggml_cuda_op_soft_cap_max(ggml_backend_cuda_context & ctx, ggml_tensor * ds if (use_f16) { const half * src1_dd = (const half *)src1_d; - soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, params[0], params[1], params[2], params[3], true, stream); + soft_max_f32_cuda_nosinks(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, params[0], params[1], params[2], params[3], true, stream); } else { const float * src1_dd = (const float *)src1_d; - soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, params[0], params[1], params[2], params[3], true, stream); + soft_max_f32_cuda_nosinks(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, params[0], params[1], params[2], params[3], true, stream); } } + +struct soft_max_params { + + int64_t nheads; + uint32_t n_head_log2; + int64_t ncols; + int64_t nrows_x; + int64_t nrows_y; + int64_t ne00; + int64_t ne01; + int64_t ne02; + int64_t ne03; + int64_t nb11; + int64_t nb12; + int64_t nb13; + + int64_t ne12; + int64_t ne13; + float scale; + float max_bias; + float m0; + float m1; +}; + +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wpass-failed" +#endif // __clang__ +template +static __global__ void soft_max_f32( + const float * x, const T * mask, const float * sinks, float * dst, const soft_max_params p) { + const int ncols = ncols_template == 0 ? p.ncols : ncols_template; + + const int tid = threadIdx.x; + + const int64_t i03 = blockIdx.z; + const int64_t i02 = blockIdx.y; + const int64_t i01 = blockIdx.x; + + //TODO: noncontigous inputs/outputs + const int rowx = blockIdx.x + blockIdx.y * gridDim.x + blockIdx.z * gridDim.x * gridDim.y; + + const int64_t i11 = i01; + const int64_t i12 = i02 % p.ne12; + const int64_t i13 = i03 % p.ne13; + + x += int64_t(rowx)*ncols; + mask += (i11*p.nb11 + i12*p.nb12 + i13*p.nb13) / sizeof(T) * (mask != nullptr); + dst += int64_t(rowx)*ncols; + + const int block_size = block_size_template == 0 ? 
blockDim.x : block_size_template; + + const int warp_id = threadIdx.x / WARP_SIZE; + const int lane_id = threadIdx.x % WARP_SIZE; + + const float slope = get_alibi_slope(p.max_bias, i02, p.n_head_log2, p.m0, p.m1); + + extern __shared__ float data_soft_max_f32[]; + float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication + // shared memory buffer to cache values between iterations: + float * vals = use_shared ? buf_iw + WARP_SIZE : dst; + + float max_val = sinks ? sinks[i02] : -INFINITY; + +#pragma unroll + for (int col0 = 0; col0 < ncols; col0 += block_size) { + const int col = col0 + tid; + + if (ncols_template == 0 && col >= ncols) { + break; + } + + const float val = x[col]*p.scale + (mask ? slope*t2f32(mask[col]) : 0.0f); + + vals[col] = val; + max_val = max(max_val, val); + } + + // find the max value in the block + max_val = warp_reduce_max(max_val); + if (block_size > WARP_SIZE) { + if (warp_id == 0) { + buf_iw[lane_id] = -INFINITY; + } + __syncthreads(); + + if (lane_id == 0) { + buf_iw[warp_id] = max_val; + } + __syncthreads(); + + max_val = buf_iw[lane_id]; + max_val = warp_reduce_max(max_val); + } + + float tmp = 0.0f; // partial sum +#pragma unroll + for (int col0 = 0; col0 < ncols; col0 += block_size) { + const int col = col0 + tid; + + if (ncols_template == 0 && col >= ncols) { + break; + } + + const float val = expf(vals[col] - max_val); + tmp += val; + vals[col] = val; + } + + // find the sum of exps in the block + tmp = warp_reduce_sum(tmp); + if (block_size > WARP_SIZE) { + __syncthreads(); + if (warp_id == 0) { + buf_iw[lane_id] = 0.0f; + } + __syncthreads(); + + if (lane_id == 0) { + buf_iw[warp_id] = tmp; + } + __syncthreads(); + + tmp = buf_iw[lane_id]; + tmp = warp_reduce_sum(tmp); + } + + if (sinks) { + tmp += expf(sinks[i02] - max_val); + } + + const float inv_sum = 1.0f / tmp; + +#pragma unroll + for (int col0 = 0; col0 < ncols; col0 += block_size) { + const int col = col0 + tid; + + if (ncols_template == 0 && col >= ncols) { + return; + } + + dst[col] = vals[col] * inv_sum; + } +} +#ifdef __clang__ +#pragma clang diagnostic pop +#endif // __clang__ + +template +static void launch_soft_max_kernels(const float * x, const T * mask, const float * sinks, float * dst, + const soft_max_params & p, cudaStream_t stream, dim3 block_dims, dim3 block_nums, size_t nbytes_shared) +{ + const int id = ggml_cuda_get_device(); + const size_t smpbo = ggml_cuda_info().devices[id].smpbo; + + auto launch_kernel = [=](auto I) -> bool { + constexpr int ncols = decltype(I)::value; + constexpr int block = (ncols > 1024 ? 
1024 : ncols); + + if (p.ncols == ncols) { + CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32), smpbo); + soft_max_f32<<>> + (x, mask, sinks, dst, p); + return true; + } + return false; + }; + + // unary fold over launch_kernel + if ((launch_kernel(std::integral_constant{}) || ...)) { + return; + } + + //default case + CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32), smpbo); + soft_max_f32<<>>(x, mask, sinks, dst, p); +} + +template +static void soft_max_f32_cuda(const float * x, const T * mask, const float * sinks, float * dst, const soft_max_params & params, cudaStream_t stream) { + int nth = WARP_SIZE; + const int64_t ncols_x = params.ncols; + + while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2; + const dim3 block_dims(nth, 1, 1); + const dim3 block_nums(params.ne01, params.ne02, params.ne03); + const size_t nbytes_shared = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float); + static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted."); + + + const int id = ggml_cuda_get_device(); + const size_t smpbo = ggml_cuda_info().devices[id].smpbo; + + + if (nbytes_shared <= smpbo) { + launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(x, mask, sinks, dst, params, stream, block_dims, block_nums, nbytes_shared); + } else { + const size_t nbytes_shared_low = WARP_SIZE*sizeof(float); + soft_max_f32<<>>(x, mask, sinks, dst, params); + } +} + +void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + const ggml_tensor * src2 = dst->src[2]; + + const float * src0_d = (const float *) src0->data; + const void * src1_d = src1 ? (const void *) src1->data : nullptr; + const void * src2_d = src2 ? (const void *) src2->data : nullptr; + float * dst_d = (float *) dst->data; + + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional + + const int64_t nrows_x = ggml_nrows(src0); + const int64_t nrows_y = src0->ne[1]; + + const int64_t ne00 = src0->ne[0]; + + float scale = 1.0f; + float max_bias = 0.0f; + + memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float)); + memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float)); + + const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16); + + const int64_t nb11 = src1 ? src1->nb[1] : 1; + const int64_t nb12 = src1 ? src1->nb[2] : 1; + const int64_t nb13 = src1 ? src1->nb[3] : 1; + + const int64_t ne12 = src1 ? src1->ne[2] : 1; + const int64_t ne13 = src1 ? 
src1->ne[3] : 1; + + const uint32_t n_head = src0->ne[2]; + const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); + + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + + soft_max_params params = {}; + params.nheads = src0->ne[2]; + params.n_head_log2 = n_head_log2; + params.ncols = ne00; + params.nrows_x = nrows_x; + params.nrows_y = nrows_y; + params.ne00 = src0->ne[0]; + params.ne01 = src0->ne[1]; + params.ne02 = src0->ne[2]; + params.ne03 = src0->ne[3]; + params.nb11 = nb11; + params.nb12 = nb12; + params.nb13 = nb13; + params.ne12 = ne12; + params.ne13 = ne13; + params.scale = scale; + params.max_bias = max_bias; + params.m0 = m0; + params.m1 = m1; + + if (use_f16) { + soft_max_f32_cuda(src0_d, (const half *) src1_d, (const float *) src2_d, dst_d, params, stream); + } else { + soft_max_f32_cuda(src0_d, (const float *) src1_d, (const float *) src2_d, dst_d, params, stream); + } +} + diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu index 6312f25c5..a6742b980 100644 --- a/ggml/src/ggml-cuda/unary.cu +++ b/ggml/src/ggml-cuda/unary.cu @@ -470,3 +470,83 @@ void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { sqrt_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream); } + +template +static __global__ void swiglu_oai_kernel(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1, float alpha, float limit) { + const int64_t i = int64_t(blockDim.x)*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + + // perform base op and multiply with gate (either offset in same tensor or a separate one) + const int64_t j0 = (i / n) * o0 + (i % n); + const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n); + + float xi = x[j0]; + float gi = g[j1]; + xi = fminf(xi, limit); + gi = fmaxf(fminf(gi, limit), -limit); + + float out_glu = xi / (1.0f + expf(-xi * alpha)); + out_glu = out_glu * (1.0f + gi); + + dst[i] = out_glu; +} + +template +static void swiglu_oai_cuda(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1, const float alpha, const float limit, cudaStream_t stream) { + const int64_t num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE; + swiglu_oai_kernel<<>>(x, g, dst, k, n, o0, o1, alpha, limit); +} + +void ggml_cuda_op_swiglu_oai(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + void * src0_d = src0->data; + void * src1_d = src1 ? src1->data : src0->data; + const int64_t src0_o = src0->nb[1]; + const int64_t src1_o = src1 ? src1->nb[1] : src0->nb[1]; + void * dst_d = dst->data; + const int64_t nc = src1 ? 
src0->ne[0] : src0->ne[0] / 2; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(ggml_is_contiguous_1(src0)); + GGML_ASSERT(src0->nb[0] == ggml_element_size(src0)); + GGML_ASSERT(ggml_is_contiguous(dst)); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + GGML_ASSERT(src0->type == dst->type); + GGML_ASSERT(dst->ne[0] == nc); + GGML_ASSERT(ggml_nrows(dst) == ggml_nrows(src0)); + + if (src1) { + GGML_ASSERT(ggml_is_contiguous_1(src1)); + GGML_ASSERT(src1->nb[0] == ggml_element_size(src1)); + GGML_ASSERT(src1->ne[0] == nc); + GGML_ASSERT(src0->type == src1->type); + } + + //const int32_t swapped = ((const int32_t *) dst->op_params)[1]; + const int32_t swapped = false; //ggml_get_op_params_i32(dst, 1); + const float * op_params = (const float *)dst->op_params; + const float alpha = op_params[2]; + const float limit = op_params[3]; + + float * src0_p = (float *) src0_d; + float * src1_p = (float *) src1_d; + + if (!src1) { + src0_p += swapped ? nc : 0; + src1_p += swapped ? 0 : nc; + } + + swiglu_oai_cuda(src0_p, src1_p, (float *)dst_d, ggml_nelements(dst), nc, + src0_o / sizeof(float), src1_o / sizeof(float), alpha, limit, stream); +} + +void ggml_swiglu_oai_cuda_f32(const float * x, const float * g, float * dst, const int64_t k, const int64_t n, + const int64_t o0, const int64_t o1, const float alpha, const float limit, cudaStream_t stream) { + swiglu_oai_cuda(x, g, dst, k, n, o0, o1, alpha, limit, stream); +} diff --git a/ggml/src/ggml-cuda/unary.cuh b/ggml/src/ggml-cuda/unary.cuh index 9bcd30a84..9da1f8ca2 100644 --- a/ggml/src/ggml-cuda/unary.cuh +++ b/ggml/src/ggml-cuda/unary.cuh @@ -47,3 +47,9 @@ void ggml_fused_mul_unary(ggml_backend_cuda_context & ctx, ggml_unary_op op, int64_t nelements, const float * x, const float * y, float * z); void ggml_cuda_op_multi_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_swiglu_oai(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_swiglu_oai_cuda_f32(const float * x, const float * g, float * dst, const int64_t k, const int64_t n, + const int64_t o0, const int64_t o1, const float alpha, const float limit, cudaStream_t stream); + diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index f3a23727b..695dc7227 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -2823,7 +2823,6 @@ inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const int32_t inline static void ggml_vec_set_bf16(const int n, ggml_bf16_t * x, const ggml_bf16_t v) { for (int i = 0; i < n; ++i) x[i] = v; } -inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] + y[i]; } inline static void ggml_vec_add1_f32(const int n, float * z, const float * x, const float v) { for (int i = 0; i < n; ++i) z[i] = x[i] + v; } inline static void ggml_vec_acc_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] += x[i]; } inline static void ggml_vec_acc1_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] += v; } @@ -2834,6 +2833,19 @@ inline static void ggml_vec_neg_f32 (const int n, float * y, const float * x) inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]*y[i]; } inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]/y[i]; } +inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, 
const float * y) { + int i = 0; +#if defined(__AVX2__) + for (; i + 7 < n; i += 8) { + __m256 vx = _mm256_loadu_ps(x + i); + __m256 vy = _mm256_loadu_ps(y + i); + __m256 vz = _mm256_add_ps(vx, vy); + _mm256_storeu_ps(z + i, vz); + } +#endif + for (; i < n; ++i) z[i] = x[i] + y[i]; +} + static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * restrict x, size_t bx, const float * restrict y, size_t by, int nrc) { assert(nrc == 1); UNUSED(nrc); @@ -4004,6 +4016,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "DUP", "ADD", + "ADD_ID", "ADD1", "ACC", "SUB", @@ -4092,13 +4105,14 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "CROSS_ENTROPY_LOSS_BACK", }; -static_assert(GGML_OP_COUNT == 81, "GGML_OP_COUNT != 81"); +static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", "x", "x+y", + "x[i]+y", "x+y", "view(x,nb,offset)+=y->x", "x-y", @@ -4187,7 +4201,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "cross_entropy_loss_back(x,y)", }; -static_assert(GGML_OP_COUNT == 81, "GGML_OP_COUNT != 81"); +static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -4207,9 +4221,10 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = { "HARDSWISH", "HARDSIGMOID", "SWIGLU", + "SWIGLU_OAI", }; -static_assert(GGML_UNARY_OP_COUNT == 14, "GGML_UNARY_OP_COUNT != 14"); +static_assert(GGML_UNARY_OP_COUNT == 15, "GGML_UNARY_OP_COUNT != 15"); static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN"); @@ -5917,6 +5932,29 @@ struct ggml_tensor * ggml_add_cast( return ggml_add_cast_impl(ctx, a, b, type); } +// ggml_add_id + +struct ggml_tensor * ggml_add_id( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * ids) { + + GGML_ASSERT(a->ne[0] == b->ne[0]); + GGML_ASSERT(a->ne[1] == ids->ne[0]); + GGML_ASSERT(a->ne[2] == ids->ne[1]); + GGML_ASSERT(ids->type == GGML_TYPE_I32); + + struct ggml_tensor * result = ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_ADD_ID; + result->src[0] = a; + result->src[1] = b; + result->src[2] = ids; + + return result; +} + // ggml_add1 static struct ggml_tensor * ggml_add1_impl( @@ -6662,6 +6700,36 @@ struct ggml_tensor * ggml_swiglu( return result; } +struct ggml_tensor * ggml_swiglu_oai( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + float alpha, + float limit) { + + GGML_ASSERT(ggml_is_contiguous_1(a)); + if (b) { + GGML_ASSERT(ggml_is_contiguous_1(b)); + GGML_ASSERT(ggml_are_same_shape(a, b)); + GGML_ASSERT(a->type == b->type); + } + + int64_t ne[4] = {a->ne[0]/2, a->ne[1], a->ne[2], a->ne[3]}; + + struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, GGML_MAX_DIMS, b ? 
a->ne : ne, NULL, 0); + + result->op = GGML_OP_UNARY; + result->grad = NULL; + result->src[0] = a; + result->src[1] = b; + + ggml_set_op_params_i32(result, 0, (int32_t) GGML_UNARY_OP_SWIGLU_OAI); + ggml_set_op_params_f32(result, 2, alpha); + ggml_set_op_params_f32(result, 3, limit); + + return result; +} + // ggml_silu_back struct ggml_tensor * ggml_silu_back( @@ -7017,6 +7085,66 @@ struct ggml_tensor * ggml_moe_up_gate( result->src[1] = as_gate; result->src[2] = b; result->src[3] = ids; + result->src[4] = NULL; + result->src[5] = NULL; + + ggml_set_op_params_i32(result, 0, (int32_t) op); + + return result; +} + +struct ggml_tensor * ggml_moe_up_gate_ext( + struct ggml_context * ctx, + struct ggml_tensor * as_up, + struct ggml_tensor * as_gate, + struct ggml_tensor * b, + struct ggml_tensor * ids, + struct ggml_tensor * as_up_b, + struct ggml_tensor * as_gate_b, + enum ggml_unary_op op) { + + if (!as_up_b && !as_gate_b) { + return ggml_moe_up_gate(ctx, as_up, as_gate, b, ids, op); + } + + if (as_up->type != as_gate->type || !ggml_are_same_shape(as_up, as_gate)) { + struct ggml_tensor * result_up = ggml_mul_mat_id(ctx, as_up, b, ids); + if (as_up_b) { + result_up = ggml_add_id(ctx, result_up, as_up_b, ids); + } + struct ggml_tensor * result_gate = ggml_mul_mat_id(ctx, as_gate, b, ids); + if (as_gate_b) { + result_gate = ggml_add_id(ctx, result_gate, as_gate_b, ids); + } + return ggml_fused_mul_unary(ctx, result_gate, result_up, op); + } + + GGML_ASSERT(!ggml_is_transposed(as_up)); + GGML_ASSERT(!ggml_is_transposed(as_gate)); + GGML_ASSERT(ids->type == GGML_TYPE_I32); + + GGML_ASSERT(as_up->ne[3] == 1); // as is 3d (one matrix per expert) + GGML_ASSERT(b->ne[3] == 1); // b is 3d + GGML_ASSERT(ids->ne[2] == 1 && ids->ne[3] == 1); // ids is 2d + GGML_ASSERT(ids->ne[1] == b->ne[2]); // must have an expert list per b row + GGML_ASSERT(as_up->ne[0] == b->ne[0]); // can_mul_mat + GGML_ASSERT(ids->ne[0] % b->ne[1] == 0); // can broadcast + + GGML_ASSERT(as_up->ne[1] == as_up_b->ne[0]); + GGML_ASSERT(as_gate->ne[1] == as_gate_b->ne[0]); + bool is_node = false; + + const int64_t ne[4] = { as_up->ne[1], ids->ne[0], b->ne[2], 1 }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); + + result->op = GGML_OP_MOE_FUSED_UP_GATE; + result->grad = is_node ? 
ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = as_up; + result->src[1] = as_gate; + result->src[2] = b; + result->src[3] = ids; + result->src[4] = as_up_b; + result->src[5] = as_gate_b; ggml_set_op_params_i32(result, 0, (int32_t) op); @@ -7970,6 +8098,22 @@ struct ggml_tensor * ggml_soft_max_ext( return ggml_soft_max_impl(ctx, a, mask, scale, max_bias, false); } +void ggml_soft_max_add_sinks( + struct ggml_tensor * a, + struct ggml_tensor * sinks) { + if (!sinks) { + a->src[2] = NULL; + return; + } + + GGML_ASSERT(a->op == GGML_OP_SOFT_MAX); + GGML_ASSERT(a->src[2] == NULL); + GGML_ASSERT(a->src[0]->ne[2] == sinks->ne[0]); + GGML_ASSERT(sinks->type == GGML_TYPE_F32); + + a->src[2] = sinks; +} + // ggml_soft_max_back static struct ggml_tensor * ggml_soft_max_back_impl( @@ -8833,6 +8977,22 @@ void ggml_flash_attn_ext_set_prec( ggml_set_op_params_i32(a, 3, prec_i32); // scale is on first pos, max_bias on second } +void ggml_flash_attn_ext_add_sinks( + struct ggml_tensor * a, + struct ggml_tensor * sinks) { + if (!sinks) { + a->src[4] = NULL; + return; + } + + GGML_ASSERT(a->op == GGML_OP_FLASH_ATTN_EXT); + GGML_ASSERT(a->src[4] == NULL); + GGML_ASSERT(a->src[0]->ne[2] == sinks->ne[0]); + GGML_ASSERT(sinks->type == GGML_TYPE_F32); + + a->src[4] = sinks; +} + // ggml_flash_attn_back struct ggml_tensor * ggml_flash_attn_back( @@ -11497,6 +11657,77 @@ static void ggml_compute_forward_multi_add( } } +// ggml_compute_forward_add_id + +static void ggml_compute_forward_add_id_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + const struct ggml_tensor * src2 = dst->src[2]; + + GGML_ASSERT(dst->type == GGML_TYPE_F32); + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(src2->type == GGML_TYPE_I32); + + GGML_ASSERT(src0->nb[0] == sizeof(float)); + GGML_ASSERT(src1->nb[0] == sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = ggml_nrows(src0); + + GGML_TENSOR_TERNARY_OP_LOCALS + + GGML_ASSERT( nb0 == sizeof(float)); + GGML_ASSERT(nb10 == sizeof(float)); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int ir = ir0; ir < ir1; ++ir) { + // src0 indices + const int i3 = ir/(ne2*ne1); + const int i2 = (ir - i3*ne2*ne1)/ne1; + const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + + // src1 indices + const int i11 = *(int32_t *) ((char *) src2->data + i1*nb20 + i2*nb21); + + GGML_ASSERT(i11 >= 0 && i11 < ne11); + + ggml_vec_add_f32(ne0, + (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), + (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), + (float *) ((char *) src1->data + i11*nb11)); + } +} + +static void ggml_compute_forward_add_id( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_add_id_f32(params, dst); + } break; + default: + { + GGML_ABORT("unsupported type for ggml_compute_forward_add_id: %s", ggml_type_name(src0->type)); + } + } +} + // ggml_compute_forward_add1 static void ggml_compute_forward_add1_f32( @@ -13760,6 +13991,93 @@ static void ggml_compute_forward_swiglu( } } +// ggml_compute_forward_swiglu_oai + +static void ggml_compute_forward_swiglu_oai_f32( + const struct ggml_compute_params 
* params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + char * src0_d = (char *) src0->data; + char * src1_d = (char *) (src1 ? src1->data : src0->data); + const size_t src0_o = src0->nb[1]; + const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1]; + + GGML_ASSERT(ggml_is_contiguous_1(src0)); + GGML_ASSERT(ggml_is_contiguous_1(dst)); + + if (src1) { + GGML_ASSERT(ggml_is_contiguous_1(src1)); + GGML_ASSERT(src0->type == src1->type); + } + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2; + const int nr = ggml_nrows(src0); + + GGML_ASSERT(dst->ne[0] == nc); + GGML_ASSERT(ggml_nrows(dst) == nr); + + const int32_t swapped = false; //ggml_get_op_params_i32(dst, 1); + const float alpha = ggml_get_op_params_f32(dst, 2); + const float limit = ggml_get_op_params_f32(dst, 3); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + float * src0_p = (float *) (src0_d + i1*src0_o); + float * src1_p = (float *) (src1_d + i1*src1_o); + float * dst_p = (float *) ((char *) dst->data + i1*(dst->nb[1])); + + if (!src1) { + src0_p += swapped ? nc : 0; + src1_p += swapped ? 0 : nc; + } + + for (int k = 0; k < nc; k++) { + const float x = MIN(src0_p[k], limit); + const float y = MAX(MIN(src1_p[k], limit), -limit); + const float out_glu = x / (1.f + expf(alpha * (-x))); + dst_p[k] = out_glu * (y + 1.f); + } + +#ifndef NDEBUG + for (int k = 0; k < nc; k++) { + const float x = dst_p[k]; + GGML_UNUSED(x); + assert(!isnan(x)); + assert(!isinf(x)); + } +#endif + } +} + +static void ggml_compute_forward_swiglu_oai( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_swiglu_oai_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + // ggml_compute_forward_fused_mul_unary static void ggml_compute_forward_fused_mul_unary_f32( @@ -15167,6 +15485,8 @@ static void ggml_compute_forward_mul_mat_id_up_gate( const struct ggml_tensor * src1 = dst->src[2]; const struct ggml_tensor * ids = dst->src[3]; + const struct ggml_tensor * up_b = dst->src[4]; + const struct ggml_tensor * gate_b = dst->src[5]; const struct ggml_tensor * src0_1 = dst->src[0]; const struct ggml_tensor * src0_2 = dst->src[1]; const struct ggml_tensor * src0 = src0_1; // so GGML_TENSOR_BINARY_OP_LOCALS works @@ -15191,6 +15511,9 @@ static void ggml_compute_forward_mul_mat_id_up_gate( GGML_ASSERT(nb2 <= nb3); GGML_ASSERT(ne13 == 1); + const size_t nb41 = up_b ? up_b->nb[1] : 0; + const size_t nb51 = up_b ? gate_b->nb[1] : 0; + // row groups const int n_ids = ids->ne[0]; // n_expert_used const int n_as = ne02; // n_expert @@ -15278,6 +15601,8 @@ static void ggml_compute_forward_mul_mat_id_up_gate( const char * src0_1_cur = (const char *) src0_1->data + cur_a*nb02; const char * src0_2_cur = (const char *) src0_2->data + cur_a*nb02; + const char * up_b_cur = up_b ? (const char *)up_b->data + cur_a*nb41 : NULL; + const char * gate_b_cur = gate_b ? (const char *)gate_b->data + cur_a*nb51 : NULL; const void * wdata = (src1->type == vec_dot_type) ? 
src1->data : params->wdata; const size_t row_size = ggml_row_size(vec_dot_type, ne10); @@ -15288,6 +15613,7 @@ static void ggml_compute_forward_mul_mat_id_up_gate( if (!iqk_moe_fused_up_gate(nr0, nr1, ne00, ne11, dst->op_params[0], type, src0_1_cur, src0_2_cur, nb01, vec_dot_type, (const char *)wdata, row_size, + up_b_cur, gate_b_cur, (float *)dst->data, nb1, nb2, matrix_rows + cur_a*ne12, ith, nth)) GGML_ABORT("fatal error"); @@ -16645,6 +16971,7 @@ static void ggml_compute_forward_soft_max_f32( const struct ggml_tensor * src0 = dst->src[0]; const struct ggml_tensor * src1 = dst->src[1]; + const struct ggml_tensor * src2 = dst->src[2]; assert(ggml_is_contiguous(dst)); assert(ggml_are_same_shape(src0, dst)); @@ -16662,6 +16989,13 @@ static void ggml_compute_forward_soft_max_f32( GGML_TENSOR_UNARY_OP_LOCALS + const int64_t nb11 = src1 ? src1->nb[1] : 1; + const int64_t nb12 = src1 ? src1->nb[2] : 1; + const int64_t nb13 = src1 ? src1->nb[3] : 1; + + const int64_t ne12 = src1 ? src1->ne[2] : 1; + const int64_t ne13 = src1 ? src1->ne[3] : 1; + //const int64_t ne11 = src1 ? src1->ne[1] : 1; // TODO: is this supposed to be ceil instead of floor? @@ -16673,67 +17007,80 @@ static void ggml_compute_forward_soft_max_f32( const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); const int nc = src0->ne[0]; - const int nr = ggml_nrows(src0); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith; const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16); - for (int i1 = ir0; i1 < ir1; i1++) { - // ALiBi - const uint32_t h = (i1/ne01)%ne02; // head - const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f; + // sinks + const float * sk = src2 ? (float *)((char *) src2->data) : NULL; - float * sp = (float *)((char *) src0->data + i1*src0->nb[1]); - float * dp = (float *)((char *) dst->data + i1*dst->nb[1]); - - // broadcast the mask across rows - ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL; - float * mp_f32 = src1 ? (float *)((char *) src1->data) + (i1%ne01)*ne00 : NULL; + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = ith; i01 < ne01; i01 += nth) { + const int64_t i11 = i01; + const int64_t i12 = i02%ne12; + const int64_t i13 = i03%ne13; + + // ALiBi + const uint32_t h = i02; // head + const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f; + + float * sp = (float *)((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); + float * dp = (float *)((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3); + + // broadcast the mask across rows + ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL; + float * mp_f32 = src1 ? 
(float *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL; + + ggml_vec_cpy_f32 (ne00, wp, sp); + ggml_vec_scale_f32(ne00, wp, scale); + if (mp_f32) { + if (use_f16) { + for (int i = 0; i < ne00; ++i) { + wp[i] += slope*GGML_FP16_TO_FP32(mp_f16[i]); + } + } else { + for (int i = 0; i < ne00; ++i) { + wp[i] += slope*mp_f32[i]; + } + } + } - ggml_vec_cpy_f32 (nc, wp, sp); - ggml_vec_scale_f32(nc, wp, scale); - if (mp_f32) { - if (use_f16) { - for (int i = 0; i < nc; ++i) { - wp[i] += slope*GGML_FP16_TO_FP32(mp_f16[i]); +#ifndef NDEBUG + for (int i = 0; i < ne00; ++i) { + //printf("p[%d] = %f\n", i, p[i]); + assert(!isnan(wp[i])); } - } else { - for (int i = 0; i < nc; ++i) { - wp[i] += slope*mp_f32[i]; +#endif + + float max = -INFINITY; + ggml_vec_max_f32(ne00, &max, wp); + + // if we have sinks, make a correction as if they were included in the softmax + if (sk) { + max = MAX(max, sk[i02]); } - } - } -//#ifndef NDEBUG -// for (int i = 0; i < nc; ++i) { -// //printf("p[%d] = %f\n", i, p[i]); -// assert(!isnan(wp[i])); -// } -//#endif + ggml_float sum = ggml_vec_soft_max_f32(ne00, dp, wp, max); + assert(sum > 0.0); - float max = -INFINITY; - ggml_vec_max_f32(nc, &max, wp); + if (sk) { + sum += (ggml_float) expf(sk[i02] - max); + } - ggml_float sum = ggml_vec_soft_max_f32(nc, dp, wp, max); - //assert(sum > 0.0); + sum = 1.0/sum; + ggml_vec_scale_f32(ne00, dp, sum); - sum = 1.0/sum; - ggml_vec_scale_f32(nc, dp, sum); +#ifndef NDEBUG + for (int i = 0; i < ne00; ++i) { + assert(!isnan(dp[i])); + assert(!isinf(dp[i])); + } +#endif -//#ifndef NDEBUG -// for (int i = 0; i < nc; ++i) { -// assert(!isnan(dp[i])); -// assert(!isinf(dp[i])); -// } -//#endif + } + } } } @@ -16755,7 +17102,6 @@ static void ggml_compute_forward_soft_max( } } - // ggml_compute_forward_soft_max_back static void ggml_compute_forward_soft_max_back_f32( @@ -18308,12 +18654,14 @@ static void ggml_compute_forward_argsort_thresh( static void ggml_compute_forward_flash_attn_ext_f16( const struct ggml_compute_params * params, - const struct ggml_tensor * q, - const struct ggml_tensor * k, - const struct ggml_tensor * v, - const struct ggml_tensor * mask, struct ggml_tensor * dst) { + const struct ggml_tensor * q = dst->src[0]; + const struct ggml_tensor * k = dst->src[1]; + const struct ggml_tensor * v = dst->src[2]; + const struct ggml_tensor * mask = dst->src[3]; + const struct ggml_tensor * sinks = dst->src[4]; + GGML_TENSOR_LOCALS(int64_t, neq, q, ne) GGML_TENSOR_LOCALS(size_t, nbq, q, nb) GGML_TENSOR_LOCALS(int64_t, nek, k, ne) @@ -18383,6 +18731,7 @@ static void ggml_compute_forward_flash_attn_ext_f16( } #if GGML_USE_IQK_MULMAT + // For now we do not implement sinks in the iqk FA implementation if (iqk_flash_attn_noalibi(q->type, mask->type, max_bias, q->ne[3], q->ne[2], q->nb[3], q->nb[2], k->ne[3], k->ne[2], k->nb[3], k->nb[2], @@ -18390,7 +18739,7 @@ static void ggml_compute_forward_flash_attn_ext_f16( dst->ne[2], dst->ne[1], dst->nb[1], k->type, v->type, Dk, Dv, neq1, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1], - q->data, k->data, v->data, mask->data, + q->data, k->data, v->data, mask->data, sinks ? 
sinks->data : NULL, scale, softcap, (float *)dst->data, params->wdata, (barrier_t)ggml_barrier, (void *)params->shared, ith, nth)) return; @@ -18447,6 +18796,9 @@ static void ggml_compute_forward_flash_attn_ext_f16( ggml_vec_dot_t const kq_vec_dot = type_traits[k->type].vec_dot; ggml_to_float_t const v_to_float = type_traits[v->type].to_float; + GGML_ASSERT(( q_to_vec_dot) && "fattn: unsupported K-type"); + GGML_ASSERT((v->type == GGML_TYPE_F32 || v_to_float ) && "fattn: unsupported V-type"); + const int64_t Dkv = MAX(Dk, Dv); // loop over n_batch and n_head @@ -18552,6 +18904,22 @@ static void ggml_compute_forward_flash_attn_ext_f16( } } + if (sinks) { + const float s = ((float *)((char *) sinks->data))[h]; + + float ms = 1.0f; + float vs = 1.0f; + + if (s > M) { + ms = expf(M - s); + ggml_vec_scale_f32(Dv, VKQ32, ms); + } else { + vs = expf(s - M); + } + + S = S*ms + vs; + } + // V /= S const float S_inv = 1.0f/S; ggml_vec_scale_f32(Dv, VKQ32, S_inv); @@ -18571,17 +18939,13 @@ static void ggml_compute_forward_flash_attn_ext_f16( static void ggml_compute_forward_flash_attn_ext( const struct ggml_compute_params * params, - const struct ggml_tensor * q, - const struct ggml_tensor * k, - const struct ggml_tensor * v, - const struct ggml_tensor * mask, struct ggml_tensor * dst) { switch (dst->op_params[3]) { case GGML_PREC_DEFAULT: case GGML_PREC_F32: { // uses F32 accumulators - ggml_compute_forward_flash_attn_ext_f16(params, q, k, v, mask, dst); + ggml_compute_forward_flash_attn_ext_f16(params, dst); } break; default: { @@ -19350,6 +19714,10 @@ static void ggml_compute_forward_unary( { ggml_compute_forward_swiglu(params, dst); } break; + case GGML_UNARY_OP_SWIGLU_OAI: + { + ggml_compute_forward_swiglu_oai(params, dst); + } break; case GGML_UNARY_OP_HARDSWISH: { ggml_compute_forward_hardswish(params, dst); @@ -19898,6 +20266,10 @@ static bool ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_add(params, tensor); } break; + case GGML_OP_ADD_ID: + { + ggml_compute_forward_add_id(params, tensor); + } break; case GGML_OP_ADD1: { ggml_compute_forward_add1(params, tensor); @@ -20136,7 +20508,7 @@ static bool ggml_compute_forward(struct ggml_compute_params * params, struct ggm } break; case GGML_OP_FLASH_ATTN_EXT: { - ggml_compute_forward_flash_attn_ext(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor); + ggml_compute_forward_flash_attn_ext(params, tensor); } break; case GGML_OP_FLASH_ATTN_BACK: { @@ -20486,6 +20858,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor src1->grad = ggml_add_or_set(ctx, src1->grad, tensor->grad, zero_table); } } break; + case GGML_OP_ADD_ID: + { + GGML_ABORT("fatal error"); // TODO: implement + } break; case GGML_OP_ADD1: { if (src0->grad) { @@ -21719,6 +22095,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_OP_DUP: case GGML_OP_CONT: case GGML_OP_ADD: + case GGML_OP_ADD_ID: case GGML_OP_ADD1: case GGML_OP_ACC: case GGML_OP_MULTI_ADD: @@ -21758,6 +22135,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_UNARY_OP_GELU_QUICK: case GGML_UNARY_OP_SILU: case GGML_UNARY_OP_SWIGLU: + case GGML_UNARY_OP_SWIGLU_OAI: { n_tasks = n_threads; } break; @@ -21952,6 +22330,7 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa } } break; case GGML_OP_ADD: + case GGML_OP_ADD_ID: case GGML_OP_ADD1: { if (ggml_is_quantized(node->src[0]->type)) { diff --git 
a/ggml/src/iqk/fa/iqk_fa_128_128.cpp b/ggml/src/iqk/fa/iqk_fa_128_128.cpp index 52eb289da..e5db1720f 100644 --- a/ggml/src/iqk/fa/iqk_fa_128_128.cpp +++ b/ggml/src/iqk/fa/iqk_fa_128_128.cpp @@ -19,26 +19,26 @@ IQK_FA_CASE(iqk_fa_128_128) { if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types if (nk%64 == 0) { iqk_flash_helper_T<128, 128, 64>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv, - q, ck, cv, cm, scale, softcap, qkv, M, S); + q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S); return true; } iqk_flash_helper_T<128, 128, 32>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv, - q, ck, cv, cm, scale, softcap, qkv, M, S); + q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S); return true; } #endif if (nk%128 == 0) { return iqk_flash_helper_T<128, 128, 128>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv, - q, ck, cv, cm, scale, softcap, qkv, M, S); + q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S); } if (nk%64 == 0) { return iqk_flash_helper_T<128, 128, 64>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv, - q, ck, cv, cm, scale, softcap, qkv, M, S); + q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S); } return iqk_flash_helper_T<128, 128, 32>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv, - q, ck, cv, cm, scale, softcap, qkv, M, S); + q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S); } diff --git a/ggml/src/iqk/fa/iqk_fa_192_128.cpp b/ggml/src/iqk/fa/iqk_fa_192_128.cpp index 6c4c51fb6..9cd62afd0 100644 --- a/ggml/src/iqk/fa/iqk_fa_192_128.cpp +++ b/ggml/src/iqk/fa/iqk_fa_192_128.cpp @@ -19,26 +19,26 @@ IQK_FA_CASE(iqk_fa_192_128) { if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types if (nk%64 == 0) { iqk_flash_helper_T<192, 128, 64>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv, - q, ck, cv, cm, scale, softcap, qkv, M, S); + q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S); return true; } iqk_flash_helper_T<192, 128, 32>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv, - q, ck, cv, cm, scale, softcap, qkv, M, S); + q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S); return true; } #endif if (nk%128 == 0) { return iqk_flash_helper_T<192, 128, 128>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv, - q, ck, cv, cm, scale, softcap, qkv, M, S); + q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S); } if (nk%64 == 0) { return iqk_flash_helper_T<192, 128, 64>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv, - q, ck, cv, cm, scale, softcap, qkv, M, S); + q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S); } return iqk_flash_helper_T<192, 128, 32>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv, - q, ck, cv, cm, scale, softcap, qkv, M, S); + q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S); } diff --git a/ggml/src/iqk/fa/iqk_fa_256_256.cpp b/ggml/src/iqk/fa/iqk_fa_256_256.cpp index b0bc35e34..a8565de2c 100644 --- a/ggml/src/iqk/fa/iqk_fa_256_256.cpp +++ b/ggml/src/iqk/fa/iqk_fa_256_256.cpp @@ -19,26 +19,26 @@ IQK_FA_CASE(iqk_fa_256_256) { if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types if (nk%64 == 0) { iqk_flash_helper_T<256, 256, 64>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv, - q, ck, cv, cm, scale, softcap, qkv, M, S); + q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S); return true; } iqk_flash_helper_T<256, 256, 32>(nq, nk, 
stride_q, stride_k, stride_v, stride_m, stride_qkv, - q, ck, cv, cm, scale, softcap, qkv, M, S); + q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S); return true; } #endif if (nk%128 == 0) { return iqk_flash_helper_T<256, 256, 128>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv, - q, ck, cv, cm, scale, softcap, qkv, M, S); + q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S); } if (nk%64 == 0) { return iqk_flash_helper_T<256, 256, 64>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv, - q, ck, cv, cm, scale, softcap, qkv, M, S); + q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S); } return iqk_flash_helper_T<256, 256, 32>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv, - q, ck, cv, cm, scale, softcap, qkv, M, S); + q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S); } diff --git a/ggml/src/iqk/fa/iqk_fa_576_512.cpp b/ggml/src/iqk/fa/iqk_fa_576_512.cpp index 5174be30a..9517eaa15 100644 --- a/ggml/src/iqk/fa/iqk_fa_576_512.cpp +++ b/ggml/src/iqk/fa/iqk_fa_576_512.cpp @@ -9,7 +9,8 @@ namespace { template inline void iqk_deepseek_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv, - const float * q, const char * mask, float scale, float softcap, float * qkv, float * M, float * S) { + const float * q, const char * mask, float scale, float softcap, float * qkv, + const float * sinkf, float * M, float * S) { auto update = [&nq1, &mask, &q, &qkv, &M, &S, stride_q, stride_m, stride_qkv] (int n) { nq1 -= n; if (nq1 == 0) return true; @@ -21,29 +22,29 @@ inline void iqk_deepseek_helper(KHelper& kh, VHelper& vh, }; if (nq1 >= 16) { int n_step = nq1/16; - FlashAttn<576, 512, 16, step_k> fa(scale, softcap); + FlashAttn<576, 512, 16, step_k> fa(scale, softcap, sinkf); fa.compute(kh, vh, 16*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S); if (update(16*n_step)) return; } if (nq1 >= 8) { int n_step = nq1/8; - FlashAttn<576, 512, 8, step_k> fa(scale, softcap); + FlashAttn<576, 512, 8, step_k> fa(scale, softcap, sinkf); fa.compute(kh, vh, 8*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S); if (update(8*n_step)) return; } if (nq1 >= 4) { int n_step = nq1/4; - FlashAttn<576, 512, 4, step_k> fa(scale, softcap); + FlashAttn<576, 512, 4, step_k> fa(scale, softcap, sinkf); fa.compute(kh, vh, 4*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S); if (update(4*n_step)) return; } if (nq1 >= 2) { int n_step = nq1/2; - FlashAttn<576, 512, 2, step_k> fa(scale, softcap); + FlashAttn<576, 512, 2, step_k> fa(scale, softcap, sinkf); fa.compute(kh, vh, 2*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S); if (update(2*n_step)) return; } - FlashAttn<576, 512, 1, step_k> fa(scale, softcap); + FlashAttn<576, 512, 1, step_k> fa(scale, softcap, sinkf); fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S); } @@ -51,37 +52,37 @@ template inline bool iqk_deepseek_helper(ggml_type type_k, int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv, const float * q, const char * k, const char * v, const char * mask, - float scale, float softcap, float * qkv, float * M, float * S) { + float scale, float softcap, float * qkv, const float * sinkf, float * M, float * S) { if (type_k == GGML_TYPE_Q8_0) { HelperQ80 kh((const char *)k, stride_k); HelperQ80 vh((const char *)v, stride_v); - iqk_deepseek_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); + 
iqk_deepseek_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S); return true; } if (type_k == GGML_TYPE_Q8_0_R8) { HelperQ80R8<576> kh((const char *)k, stride_k); HelperQ80 vh((const char *)v, stride_v); - iqk_deepseek_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); + iqk_deepseek_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S); return true; } if (type_k == GGML_TYPE_Q6_0) { HelperQ60 kh((const char *)k, stride_k); HelperQ60 vh((const char *)v, stride_v); - iqk_deepseek_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); + iqk_deepseek_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S); return true; } #if GGML_IQK_FA_ALL_QUANTS if (type_k == GGML_TYPE_Q8_KV) { HelperQ8KV<576> kh((const char *)k, stride_k); HelperQ8KV<512> vh((const char *)v, stride_v); - iqk_deepseek_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); + iqk_deepseek_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S); return true; } #endif if (type_k == GGML_TYPE_F16) { HelperF16 kh((const char *)k, stride_k); HelperF16 vh((const char *)v, stride_v); - iqk_deepseek_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); + iqk_deepseek_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S); return true; } #ifdef __AVX512BF16__ @@ -89,10 +90,10 @@ inline bool iqk_deepseek_helper(ggml_type type_k, HelperBF16<576, step_k> kh((const char *)k, stride_k); HelperBF16<512, step_k> vh((const char *)v, stride_v); if (nq1 % 8 == 0) { - FlashAttnBF16<576, 512, 8, step_k> fa(scale, softcap); + FlashAttnBF16<576, 512, 8, step_k> fa(scale, softcap, sinkf); fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); } else { - FlashAttnBF16<576, 512, 1, step_k> fa(scale, softcap); + FlashAttnBF16<576, 512, 1, step_k> fa(scale, softcap, sinkf); fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); } return true; @@ -113,7 +114,7 @@ IQK_FA_CASE(iqk_fa_576_512) { } stride_q /= sizeof(float); // q stride as float return iqk_deepseek_helper<32>(type_k, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv, - q, (const char *)k, (const char *)v, (const char *)mask, scale, softcap, qkv, M, S); + q, (const char *)k, (const char *)v, (const char *)mask, scale, softcap, qkv, sinkf, M, S); } diff --git a/ggml/src/iqk/fa/iqk_fa_64_64.cpp b/ggml/src/iqk/fa/iqk_fa_64_64.cpp index 652f682ba..84b9bad02 100644 --- a/ggml/src/iqk/fa/iqk_fa_64_64.cpp +++ b/ggml/src/iqk/fa/iqk_fa_64_64.cpp @@ -19,26 +19,26 @@ IQK_FA_CASE(iqk_fa_64_64) { if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types if (nk%64 == 0) { iqk_flash_helper_T<64, 64, 64>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv, - q, ck, cv, cm, scale, softcap, qkv, M, S); + q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S); return true; } iqk_flash_helper_T<64, 64, 32>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv, - q, ck, cv, cm, scale, softcap, qkv, M, S); + q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S); return true; } #endif if (nk%128 == 0) { return iqk_flash_helper_T<64, 64, 128>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, 
stride_m, stride_qkv, - q, ck, cv, cm, scale, softcap, qkv, M, S); + q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S); } if (nk%64 == 0) { return iqk_flash_helper_T<64, 64, 64>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv, - q, ck, cv, cm, scale, softcap, qkv, M, S); + q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S); } return iqk_flash_helper_T<64, 64, 32>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv, - q, ck, cv, cm, scale, softcap, qkv, M, S); + q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S); } diff --git a/ggml/src/iqk/fa/iqk_fa_96_96.cpp b/ggml/src/iqk/fa/iqk_fa_96_96.cpp index fed49cb04..44544f8ba 100644 --- a/ggml/src/iqk/fa/iqk_fa_96_96.cpp +++ b/ggml/src/iqk/fa/iqk_fa_96_96.cpp @@ -19,26 +19,26 @@ IQK_FA_CASE(iqk_fa_96_96) { if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types if (nk%64 == 0) { iqk_flash_helper_T<96, 96, 64>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv, - q, ck, cv, cm, scale, softcap, qkv, M, S); + q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S); return true; } iqk_flash_helper_T<96, 96, 32>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv, - q, ck, cv, cm, scale, softcap, qkv, M, S); + q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S); return true; } #endif if (nk%128 == 0) { return iqk_flash_helper_T<96, 96, 128>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv, - q, ck, cv, cm, scale, softcap, qkv, M, S); + q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S); } if (nk%64 == 0) { return iqk_flash_helper_T<96, 96, 64>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv, - q, ck, cv, cm, scale, softcap, qkv, M, S); + q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S); } return iqk_flash_helper_T<96, 96, 32>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv, - q, ck, cv, cm, scale, softcap, qkv, M, S); + q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S); } diff --git a/ggml/src/iqk/fa/iqk_fa_templates.h b/ggml/src/iqk/fa/iqk_fa_templates.h index 6de2acea3..1971c4729 100644 --- a/ggml/src/iqk/fa/iqk_fa_templates.h +++ b/ggml/src/iqk/fa/iqk_fa_templates.h @@ -1141,10 +1141,25 @@ struct FlashQKV { } template - inline void normalize_and_store_1row(const FMS& fms, int j, const qkv_cache_t * R, float * qkv) const { + inline void normalize_and_store_1row(const FMS& fms, int j, qkv_cache_t * R, float * qkv, const float * sinkf) const { static_assert(q_step == FMS::q_step); - GGML_ASSERT(fms.S[j] > 0); - auto norm = F16::set1(1/fms.S[j]); + float S = fms.S[j]; + if (sinkf) { + float s = *sinkf; + if (s > fms.M[j]) { + float m = expf(fms.M[j] - s); + auto vm = F16::set1(m); + for (int i = 0; i < D/F16::block_size; ++i) { + auto Ri = R + F16::block_size*i; + F16::store(Ri, F16::mul(vm, F16::load(Ri))); + } + S = S*m + 1; + } else { + S += expf(s - fms.M[j]); + } + } + GGML_ASSERT(S > 0); + auto norm = F16::set1(1/S); //auto norm = F16::set1(fms.S[j] > 0 ? 
1/fms.S[j] : 0.f); for (int i = 0; i < D/F16::block_size; ++i) { auto r = F16::load(R + F16::block_size*i); @@ -1153,7 +1168,7 @@ struct FlashQKV { } template - inline void normalize_and_store(const FMS& fms, int nq1, int stride_qkv, float * qkv, float * M, float * S) const { + inline void normalize_and_store(const FMS& fms, int nq1, int stride_qkv, float * qkv, const float * sinkf, float * M, float * S) { static_assert(q_step == FMS::q_step); if (M && S) { std::memcpy(M, fms.M, nq1*sizeof(float)); @@ -1173,7 +1188,7 @@ struct FlashQKV { } else { auto R = qkv_cache; for (int j = 0; j < nq1; ++j) { - normalize_and_store_1row(fms, j, R, qkv); + normalize_and_store_1row(fms, j, R, qkv, sinkf); qkv += stride_qkv; R += D; } @@ -1181,7 +1196,7 @@ struct FlashQKV { } template - inline void normalize_and_store(const FMS& fms, int stride_qkv, float * qkv, float * M, float * S) const { + inline void normalize_and_store(const FMS& fms, int stride_qkv, float * qkv, const float * sinkf, float * M, float * S) { static_assert(q_step == FMS::q_step); if (M && S) { std::memcpy(M, fms.M, q_step*sizeof(float)); @@ -1201,7 +1216,7 @@ struct FlashQKV { } else { auto R = qkv_cache; for (int j = 0; j < q_step; ++j) { - normalize_and_store_1row(fms, j, R, qkv); + normalize_and_store_1row(fms, j, R, qkv, sinkf); qkv += stride_qkv; R += D; } @@ -1332,7 +1347,7 @@ void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, in FlashMS& fms, FlashQKV& fqkv, const float * q, const char * mask, float * qkv, - float * M, float * S) { + const float * sinkf, float * M, float * S) { #ifdef __aarch64__ float16_t q_f16[Dk*q_step]; #endif @@ -1356,7 +1371,7 @@ void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, in vh.next_block(k_step); mr += k_step*sizeof(ggml_half); } - fqkv.normalize_and_store(fms, stride_qkv, qkv, M, S); + fqkv.normalize_and_store(fms, stride_qkv, qkv, sinkf, M, S); q += q_step*stride_q; mask += q_step*stride_m; @@ -1383,7 +1398,7 @@ void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, in vh.next_block(k_step); mr += k_step*sizeof(ggml_half); } - fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv, M, S); + fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv, sinkf, M, S); } } @@ -1392,7 +1407,7 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, FlashMS& fms, FlashQKV& fqkv, const float * q, const char * mask, float * qkv, - float * M, float * S, char * qptr) { + const float * sinkf, float * M, float * S, char * qptr) { auto q8 = (typename KHelper::block_q8 *)qptr; if constexpr (q_step > 1 && std::is_same_v) { if (nq1 == q_step) { @@ -1412,7 +1427,7 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, vh.next_block(k_step); mr += k_step*sizeof(ggml_half); } - fqkv.normalize_and_store(fms, stride_qkv, qkv, M, S); + fqkv.normalize_and_store(fms, stride_qkv, qkv, sinkf, M, S); return; } } @@ -1449,10 +1464,10 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, } #if FA_TIMING t1 = Perf::cur_time(); - fqkv.normalize_and_store(fms, stride_qkv, qkv, M, S); + fqkv.normalize_and_store(fms, stride_qkv, qkv, sinkf, M, S); perf.accum_nolock(3, t1); #else - fqkv.normalize_and_store(fms, stride_qkv, qkv, M, S); + fqkv.normalize_and_store(fms, stride_qkv, qkv, sinkf, M, S); #endif q += q_step*stride_q; @@ -1474,7 +1489,7 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, vh.next_block(k_step); mr += k_step*sizeof(ggml_half); 
} - fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv, M, S); + fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv, sinkf, M, S); } #if FA_TIMING Perf::instance().add(perf); @@ -1504,7 +1519,7 @@ struct FlashAttn { static_assert(k_step%F16::block_size == 0); static_assert(q_step <= 4 || q_step%4 == 0); - FlashAttn(float scale, float softcap) : fms(scale, softcap) {} + FlashAttn(float scale, float softcap, const float * sinkf) : fms(scale, softcap), sinkf(sinkf) {} template void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv, @@ -1533,7 +1548,7 @@ struct FlashAttn { HelperQ80R8 khr4(nk1, kh); #endif compute_helper_q, VHelper, FlashQKfp32>( - khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S, qptr); + khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, sinkf, M, S, qptr); return; } @@ -1547,29 +1562,30 @@ struct FlashAttn { HelperQ8KVR8 khr4(nk1, kh); #endif compute_helper_q, VHelper, FlashQKfp32>( - khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S, qptr); + khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, sinkf, M, S, qptr); return; } #endif } compute_helper_q>( - kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S, qptr); + kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, sinkf, M, S, qptr); } else { typename KHelper::block_q8 q8[q_step*(Dk/KHelper::block_size_q)]; compute_helper_q>( - kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S, (char *)q8); + kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, sinkf, M, S, (char *)q8); } } else { compute_helper>( - kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S); + kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, sinkf, M, S); } } FlashMS fms; FlashQKV fqkv; + const float * sinkf; }; @@ -1927,7 +1943,7 @@ struct FlashAttnBF16 { static_assert(k_step%32 == 0); static_assert(q_step <= 4 || q_step%4 == 0); - FlashAttnBF16(float scale, float softcap) : fms(scale, softcap) {} + FlashAttnBF16(float scale, float softcap, const float * sinkf) : fms(scale, softcap), sinkf(sinkf) {} template void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv, @@ -1967,7 +1983,7 @@ struct FlashAttnBF16 { #if FA_TIMING t1 = Perf::cur_time(); #endif - fqkv.normalize_and_store(fms, stride_qkv, qkv, M, S); + fqkv.normalize_and_store(fms, stride_qkv, qkv, sinkf, M, S); #if FA_TIMING perf.accum_nolock(4, t1); #endif @@ -1990,7 +2006,7 @@ struct FlashAttnBF16 { vh.next_block(k_step); mr += k_step*sizeof(ggml_half); } - fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv, M, S); + fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv, sinkf, M, S); } #if FA_TIMING Perf::instance().add(perf); @@ -1999,12 +2015,14 @@ struct FlashAttnBF16 { FlashMS fms; FlashQKV fqkv; + const float * sinkf; }; #endif template inline void iqk_flash_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv, - const float * q, const char * mask, float scale, float softcap, float * qkv, float * M, float * S) { + const float * q, const char * mask, float scale, float softcap, float * qkv, + const float * sinkf, float * M, float * S) { auto update = [&nq1, &mask, &q, &qkv, &M, &S, stride_q, stride_m, stride_qkv] (int n) { nq1 -= n; @@ -2018,48 +2036,48 @@ inline void iqk_flash_helper(KHelper& 
kh, VHelper& vh, int nq1, int nk1, int str if (nk1 >= 512) { if (nq1 >= 128) { int n_step = nq1/128; - FlashAttn fa(scale, softcap); + FlashAttn fa(scale, softcap, sinkf); fa.compute(kh, vh, 128*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); if (update(128*n_step)) return; } if (nq1 >= 64) { int n_step = nq1/64; - FlashAttn fa(scale, softcap); + FlashAttn fa(scale, softcap, sinkf); fa.compute(kh, vh, 64*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); if (update(64*n_step)) return; } if (nq1 >= 32) { int n_step = nq1/32; - FlashAttn fa(scale, softcap); + FlashAttn fa(scale, softcap, sinkf); fa.compute(kh, vh, 32*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); if (update(32*n_step)) return; } if (nq1 >= 16) { int n_step = nq1/16; - FlashAttn fa(scale, softcap); + FlashAttn fa(scale, softcap, sinkf); fa.compute(kh, vh, 16*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); if (update(16*n_step)) return; } } if (nq1 >= 8) { int n_step = nq1/8; - FlashAttn fa(scale, softcap); + FlashAttn fa(scale, softcap, sinkf); fa.compute(kh, vh, 8*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); if (update(8*n_step)) return; } else if (nq1 >= 4) { int n_step = nq1/4; - FlashAttn fa(scale, softcap); + FlashAttn fa(scale, softcap, sinkf); fa.compute(kh, vh, 4*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); if (update(4*n_step)) return; } else if (nq1 >= 2) { int n_step = nq1/2; - FlashAttn fa(scale, softcap); + FlashAttn fa(scale, softcap, sinkf); fa.compute(kh, vh, 2*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); if (update(2*n_step)) return; } - FlashAttn fa(scale, softcap); + FlashAttn fa(scale, softcap, sinkf); fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); } @@ -2067,26 +2085,26 @@ inline void iqk_flash_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int str template inline void iqk_flash_helper_T(int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv, const float * q, const char * k, const char * v, const char * mask, - float scale, float softcap, float * qkv, float * M, float * S) { + float scale, float softcap, float * qkv, const float * sinkf, float * M, float * S) { HelperBF16 kh(k, stride_k); HelperBF16 vh(v, stride_v); if (nk1 >= 4096) { if (nq1 >= 64) { - FlashAttnBF16 fa(scale, softcap); + FlashAttnBF16 fa(scale, softcap, sinkf); fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); return; } else if (nq1 >= 16) { - FlashAttnBF16 fa(scale, softcap); + FlashAttnBF16 fa(scale, softcap, sinkf); fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); return; } } if (nq1 >= 8) { - FlashAttnBF16 fa(scale, softcap); + FlashAttnBF16 fa(scale, softcap, sinkf); fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); } else { - FlashAttnBF16 fa(scale, softcap); + FlashAttnBF16 fa(scale, softcap, sinkf); fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); } } @@ -2096,43 +2114,43 @@ template inline bool iqk_flash_helper_T(KHelper& kh, ggml_type type_v, int nq1, int nk1, int stride_q, int stride_v, int stride_m, int stride_qkv, const float * q, const char * v, const char * mask, - float scale, float softcap, float * qkv, float * M, float * S) { 
+ float scale, float softcap, float * qkv, const float * sinkf, float * M, float * S) { switch (type_v) { case GGML_TYPE_F16: { HelperF16 vh(v, stride_v); - iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); + iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S); } break; #ifdef __AVX512BF16__ case GGML_TYPE_BF16: { HelperBF16 vh(v, stride_v); - iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); + iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S); } break; #endif case GGML_TYPE_Q8_0: { HelperQ80 vh(v, stride_v); - iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); + iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S); } break; case GGML_TYPE_Q8_KV: { HelperQ8KV vh(v, stride_v); - iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); + iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S); } break; case GGML_TYPE_Q6_0: { HelperQ60 vh(v, stride_v); - iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); + iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S); } break; #if GGML_IQK_FA_ALL_QUANTS case GGML_TYPE_Q4_0: { HelperQ40 vh(v, stride_v); - iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); + iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S); } break; case GGML_TYPE_Q4_1: { HelperQ41 vh(v, stride_v); - iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); + iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S); } break; case GGML_TYPE_IQ4_NL: { HelperIQ4nl vh(v, stride_v); - iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); + iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S); } break; #endif default: return false; @@ -2144,42 +2162,42 @@ template inline bool iqk_flash_helper_T(ggml_type type_k, ggml_type type_v, int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv, const float * q, const char * k, const char * v, const char * mask, - float scale, float softcap, float * qkv, float * M, float * S) { + float scale, float softcap, float * qkv, const float * sinkf, float * M, float * S) { bool result = false; switch (type_k) { case GGML_TYPE_F16: { HelperF16 kh(k, stride_k); - result = iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S); + result = iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, sinkf, M, S); } break; case GGML_TYPE_Q8_0: { HelperQ80 kh(k, stride_k); - result = iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S); + result = iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, sinkf, M, S); } break; case 
GGML_TYPE_Q8_0_R8: { HelperQ80R8 kh(k, stride_k); - result = iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S); + result = iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, sinkf, M, S); } break; case GGML_TYPE_Q6_0: { HelperQ60 kh(k, stride_k); - result = iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S); + result = iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, sinkf, M, S); } break; #if GGML_IQK_FA_ALL_QUANTS case GGML_TYPE_Q8_KV: { HelperQ8KV kh(k, stride_k); - result = iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S); + result = iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, sinkf, M, S); } break; case GGML_TYPE_Q4_0: { HelperQ40 kh(k, stride_k); - result = iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S); + result = iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, sinkf, M, S); } break; case GGML_TYPE_Q4_1: { HelperQ41 kh(k, stride_k); - result = iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S); + result = iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, sinkf, M, S); } break; case GGML_TYPE_IQ4_NL: { HelperIQ4nl kh(k, stride_k); - result = iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S); + result = iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, sinkf, M, S); } break; #endif default: break; @@ -2194,7 +2212,7 @@ inline bool iqk_flash_helper_T(ggml_type type_k, ggml_type type_v, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv,\ const float * q, const void * k, const void * v, const void * mask,\ float scale, float softcap,\ - float * qkv, float * M, float * S) + float * qkv, const float * sinkf, float * M, float * S) IQK_FA_CASE(iqk_fa_576_512); IQK_FA_CASE(iqk_fa_192_128); diff --git a/ggml/src/iqk/iqk_flash_attn.cpp b/ggml/src/iqk/iqk_flash_attn.cpp index 9a974ae74..19e6cd25d 100644 --- a/ggml/src/iqk/iqk_flash_attn.cpp +++ b/ggml/src/iqk/iqk_flash_attn.cpp @@ -66,6 +66,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float const void * k, // k matrix. Assumed to be fp16, nq x nk elements const void * v, // v matrix. Assumed to be fp16, nq x nk elements const void * mask, // mask. If not null, assumed to be fp16. nq x nk elements + const void * sinks, // mask. If not null, assumed to be fp16. 
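(Note: as used later in this file — `auto sinksf = sinks ? (const float *)sinks + iq2 : nullptr;` and `float s = ((const float *)sinks)[iq2];` — the new `sinks` argument is not an fp16 mask; when non-null it appears to hold one fp32 attention-sink logit per attention head.)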
nq x nk elements float scale, // scale applied before softmax float softcap, // if > 0, a "soft-cap" operation is applied before softmax float * qkv, // v*softmax(scale*(k*q)) @@ -139,7 +140,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float auto work_this_thread = (float *)(result_buffer + ith*size_thread); if (!iqk_flash_attn_impl(int_type_k, int_type_v, Dk, Dv, nq_this_thread, nek1/gcd_k, nbq2, stride_k, stride_v, 0, Dv, //Dk*sizeof(uint16_t), Dv, - (const float *)qth, (const void *)kth, (const void *)vth, (const void *)mth, + (const float *)qth, (const void *)kth, (const void *)vth, (const void *)mth, nullptr, 0, scale, softcap, work_this_thread, work_this_thread + (Dv+0)*nq_this_thread, work_this_thread + (Dv+1)*nq_this_thread)) return false; @@ -182,51 +183,6 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float if (neq3 == 1 && rk2 > 1 && rk2 == rv2 && neq1 == 1 && nth >= 1 && nek2*nek1 >= 32*nth) { auto result_size = (Dv + 16)*rk2*sizeof(float); int gcd = simple_gcd(nek2, nth); - if (false && gcd > 1) { - int nth_g = nth/gcd; - int ith_g = ith%nth_g; - int nek1_32 = nek1/32; - int nek1_pt = (nek1_32 + nth_g - 1)/nth_g; - int ith_mid = nth_g; - if (nek1_pt*nth_g > nek1_32) { - ith_mid = nek1_32 - nth_g*(nek1_pt - 1); - } - nek1_pt *= 32; - int nek1_mid = ith_mid*nek1_pt; - int nek1_thread = ith_g < ith_mid ? nek1_pt : nek1_pt - 32; - for (int ik02 = ith/nth_g; ik02 < nek2; ik02 += gcd) { - int ik01 = ith_g < ith_mid ? ith_g*nek1_pt : nek1_mid + (ith_g - ith_mid)*nek1_thread; - auto this_result = (float *)((char *)work_buffer + (ik02*nth_g + ith_g)*result_size); - auto this_q = (const float *)((const char *)q + ik02*rk2*nbq2); - auto this_k = (const char *)k + ik01*stride_k + ik02*nbk2; - auto this_v = (const char *)v + ik01*stride_v + ik02*nbv2; - auto this_m = (const char *)mask + ik01*sizeof(uint16_t); // we don't have ggml_half available here - if (!iqk_flash_attn_impl(int_type_k, int_type_v, - Dk, Dv, rk2, nek1_thread, nbq2, stride_k, stride_v, 0, Dv, - this_q, (const void *)this_k, (const void *)this_v, (const void *)this_m, - scale, softcap, this_result, this_result + (Dv+0)*rk2, this_result + (Dv+1)*rk2)) return false; - } - - barrier(barrier_data); - - for (int iq2 = ith; iq2 < neq2; iq2 += nth) { - int ik02 = iq2/rk2; - int il = iq2 - ik02*rk2; - auto Racc = qkv + iq2*nb1/sizeof(float); - float M = -INFINITY, S = 0; - for (int ig = 0; ig < nth_g; ++ig) { - int istep_k = ik02*nth_g + ig; - auto this_result = (float *)((char *)work_buffer + istep_k*result_size); - const float * R = this_result + il*Dv; - const float * Mj = this_result + Dv*rk2; - const float * Sj = Mj + rk2; - accumulate_qkv(Dv, M, S, Mj[il], Sj[il], Racc, R); - } - float norm = S > 0 ? 
1/S : 1; - for (int i = 0; i < Dv; ++i) Racc[i] *= norm; - } - return true; - } int nth_k = nth/gcd; int nek2_k = nek2/gcd; int nchunk = nek2_k*nek1/32; @@ -259,7 +215,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float auto this_m = (const char *)mask + ik01*sizeof(uint16_t); // we don't have ggml_half available here if (!iqk_flash_attn_impl(int_type_k, int_type_v, Dk, Dv, rk2, this_nk, nbq2, stride_k, stride_v, 0, Dv, - this_q, (const void *)this_k, (const void *)this_v, (const void *)this_m, + this_q, (const void *)this_k, (const void *)this_v, (const void *)this_m, nullptr, 0, scale, softcap, this_result, this_result + (Dv+0)*rk2, this_result + (Dv+1)*rk2)) return false; } @@ -281,6 +237,16 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float const float * Sj = Mj + rk2; accumulate_qkv(Dv, M, S, Mj[il], Sj[il], Racc, R); } + if (sinks) { + float s = ((const float *)sinks)[iq2]; + if (s > M) { + float m = expf(M - s); + for (int i = 0; i < Dv; ++i) Racc[i] *= m; + S = S*m + 1; + } else { + S += expf(s - M); + } + } float norm = S > 0 ? 1/S : 1; for (int i = 0; i < Dv; ++i) Racc[i] *= norm; } @@ -306,6 +272,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float int counter = 0; for (int64_t iq3 = 0; iq3 < neq3; iq3++) { for (int64_t iq2 = 0; iq2 < neq2; iq2++) { + auto sinksf = sinks ? (const float *)sinks + iq2 : nullptr; if (counter++ % (nth/ntg) == ith/ntg) { int iq1 = (ith%ntg)*neq1g; int this_neq1 = std::min(neq1g, neq1-iq1); @@ -314,7 +281,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float (const float *)((const char *)q + iq2*nbq2 + iq3*nbq3 + iq1*stride_q), (const void *)((const char *)k + iq2/rk2*nbk2 + iq3/rk3*nbk3), (const void *)((const char *)v + iq2/rv2*nbv2 + iq3/rv3*nbv3), - (const void *)((const char *)mask + iq1*stride_m), + (const void *)((const char *)mask + iq1*stride_m), sinksf, 1, scale, softcap, (float *)((char *)qkv + (iq3*ne2*ne1 + iq2 + iq1*ne1)*nb1), nullptr, nullptr)) return false; } diff --git a/ggml/src/iqk/iqk_flash_impl.h b/ggml/src/iqk/iqk_flash_impl.h index 6f62e56b1..8db97731b 100644 --- a/ggml/src/iqk/iqk_flash_impl.h +++ b/ggml/src/iqk/iqk_flash_impl.h @@ -23,6 +23,8 @@ bool iqk_flash_attn_impl(int type_k, // type of k const void * k, // k matrix. Assumed to be fp16, nq x nk elements const void * v, // v matrix. Assumed to be fp16, nq x nk elements const void * mask, // mask. If not null, assumed to be fp16. 
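The sink handling added in iqk_flash_attn.cpp above can be read as the scalar sketch below (illustrative only, not part of the patch; `fold_in_sink` is a hypothetical name and `expf` comes from `<cmath>`): a sink is one extra logit `s` per head that enters the softmax denominator without contributing a value row, so the accumulated output `Racc` only needs rescaling when `s` exceeds the running maximum `M`.

#include <cmath>

// M: running max logit, S: running softmax denominator, Racc: accumulated V rows of length Dv.
static inline void fold_in_sink(float s, float & M, float & S, float * Racc, int Dv) {
    if (s > M) {
        const float m = expf(M - s);              // rescale previous contributions to the new maximum s
        for (int i = 0; i < Dv; ++i) Racc[i] *= m;
        S = S*m + 1.0f;                           // the sink itself contributes exp(s - s) == 1
        M = s;                                    // kept for completeness; the patch normalizes immediately afterwards
    } else {
        S += expf(s - M);                         // sink contributes exp(s - M) to the denominator
    }
}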
nq x nk elements + const float * sinksf, // attention sinks + int nsinks, // number of sinks float scale, // scale applied before softmax float softcap, // if > 0, a "soft-cap" operation is applied before softmax float * qkv, // v*softmax(scale*(k*q)) diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index f8624bfc2..041ad1650 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -120,16 +120,21 @@ struct MulMat { funcs[n_left-1](n, vx, bx, info, nrc_x); } } - inline void gelu(int n, const float * src, float * dst); - inline void relu(int n, const float * src, float * dst); - inline void silu(int n, const float * src, float * dst); - inline void activate(ggml_unary_op op, int n, const float * src, float * dst) { + inline static void gelu(int n, const float * src, float * dst); + inline static void relu(int n, const float * src, float * dst); + inline static void silu(int n, const float * src, float * dst); + inline static void swiglu_oai(int n, const float * src, float * dst); + inline static void clamp_oai(int n, float *x); + inline static void activate(ggml_unary_op op, int n, const float * src, float * dst) { if (op == GGML_UNARY_OP_GELU) gelu(n, src, dst); else if (op == GGML_UNARY_OP_RELU) relu(n, src, dst); else if (op == GGML_UNARY_OP_SILU) silu(n, src, dst); + else if (op == GGML_UNARY_OP_SWIGLU_OAI) swiglu_oai(n, src, dst); else GGML_ABORT("fatal error"); } - inline void mul_mat_up_gate_NxM(int n, const void * vx_up, const void * vx_gate, size_t bx, DataInfo& info, int nrc_x, int nrc_y, int unary_op) { + inline void mul_mat_up_gate_NxM(int n, const void * vx_up, const void * vx_gate, size_t bx, + const float * up_b, const float * gate_b, + DataInfo& info, int nrc_x, int nrc_y, int unary_op) { #ifdef __aarch64__ constexpr int k_x_step = 64; //8192; // Tiling does not seem to help on my M2 Max (but difference to tiling is small) #else @@ -137,6 +142,29 @@ struct MulMat { #endif auto op = ggml_unary_op(unary_op); float tmp[k_x_step*16]; + auto process = [&tmp, n, op, vx_gate, vx_up, gate_b, up_b, bx, xstep = k_x_step] (mul_mat_t func, const DataInfo& this_info, int ix, int this_nrc_x, int ny) { + func(n, (const void *)((const char *)vx_gate + ix*bx), bx, this_info, this_nrc_x); + for (int ky = 0; ky < ny; ++ky) { + if (gate_b) { + auto b = gate_b + ix; + auto x = this_info.dst_row(ky); + for (int j = 0; j < this_nrc_x; ++j) x[j] += b[j]; + } + activate(op, this_nrc_x, this_info.dst_row(ky), tmp + ky*xstep); + } + func(n, (const void *)((const char *)vx_up + ix*bx), bx, this_info, this_nrc_x); + for (int ky = 0; ky < ny; ++ky) { + auto result = this_info.dst_row(ky); + if (up_b) { + auto b = up_b + ix; + for (int j = 0; j < this_nrc_x; ++j) result[j] += b[j]; + } + if (op == GGML_UNARY_OP_SWIGLU_OAI) { + clamp_oai(this_nrc_x, result); + } + for (int j = 0; j < this_nrc_x; ++j) result[j] *= tmp[ky*xstep + j]; + } + }; if (func16 && nrc_y >= 16) { int n_step = (nrc_y - info.cur_y)/16; for (int ix = 0; ix < nrc_x; ix += k_x_step) { @@ -144,15 +172,7 @@ struct MulMat { this_info.s += ix; int this_nrc_x = ix + k_x_step <= nrc_x ? 
k_x_step : nrc_x - ix; for (int iy = 0; iy < n_step; ++iy) { - func16(n, (const void *)((const char *)vx_gate + ix*bx), bx, this_info, this_nrc_x); - for (int ky = 0; ky < 16; ++ky) { - activate(op, this_nrc_x, this_info.dst_row(ky), tmp + ky*k_x_step); - } - func16(n, (const void *)((const char *)vx_up + ix*bx), bx, this_info, this_nrc_x); - for (int ky = 0; ky < 16; ++ky) { - auto result = this_info.dst_row(ky); - for (int j = 0; j < this_nrc_x; ++j) result[j] *= tmp[ky*k_x_step + j]; - } + process(func16, this_info, ix, this_nrc_x, 16); this_info.cur_y += 16; } } @@ -175,23 +195,11 @@ struct MulMat { this_info.s += ix; int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix; for (int iy = 0; iy < my1; ++iy) { - funcs[ny1-1](n, (const void *)((const char *)vx_gate + ix*bx), bx, this_info, this_nrc_x); - for (int ky = 0; ky < ny1; ++ky) activate(op, this_nrc_x, this_info.dst_row(ky), tmp + ky*k_x_step); - funcs[ny1-1](n, (const void *)((const char *)vx_up + ix*bx), bx, this_info, this_nrc_x); - for (int ky = 0; ky < ny1; ++ky) { - auto result = this_info.dst_row(ky); - for (int j = 0; j < this_nrc_x; ++j) result[j] *= tmp[ky*k_x_step + j]; - } + process(funcs[ny1-1], this_info, ix, this_nrc_x, ny1); this_info.cur_y += ny1; } for (int iy = 0; iy < my2; ++iy) { - funcs[ny2-1](n, (const void *)((const char *)vx_gate + ix*bx), bx, this_info, this_nrc_x); - for (int ky = 0; ky < ny2; ++ky) activate(op, this_nrc_x, this_info.dst_row(ky), tmp + ky*k_x_step); - funcs[ny2-1](n, (const void *)((const char *)vx_up + ix*bx), bx, this_info, this_nrc_x); - for (int ky = 0; ky < ny2; ++ky) { - auto result = this_info.dst_row(ky); - for (int j = 0; j < this_nrc_x; ++j) result[j] *= tmp[ky*k_x_step + j]; - } + process(funcs[ny2-1], this_info, ix, this_nrc_x, ny2); this_info.cur_y += ny2; } } @@ -203,13 +211,7 @@ struct MulMat { this_info.s += ix; int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix; for (int iy = 0; iy < n_step; ++iy) { - funcs[ny-1](n, (const void *)((const char *)vx_gate + ix*bx), bx, this_info, this_nrc_x); - for (int ky = 0; ky < ny; ++ky) activate(op, this_nrc_x, this_info.dst_row(ky), tmp + ky*k_x_step); - funcs[ny-1](n, (const void *)((const char *)vx_up + ix*bx), bx, this_info, this_nrc_x); - for (int ky = 0; ky < ny; ++ky) { - auto result = this_info.dst_row(ky); - for (int j = 0; j < this_nrc_x; ++j) result[j] *= tmp[ky*k_x_step + j]; - } + process(funcs[ny-1], this_info, ix, this_nrc_x, ny); this_info.cur_y += ny; } } @@ -222,13 +224,7 @@ struct MulMat { auto this_info = info; this_info.s += ix; int this_nrc_x = ix + k_x_step <= nrc_x ? 
k_x_step : nrc_x - ix; - funcs[n_left-1](n, (const void *)((const char *)vx_gate + ix*bx), bx, this_info, this_nrc_x); - for (int ky = 0; ky < n_left; ++ky) activate(op, this_nrc_x, this_info.dst_row(ky), tmp + ky*k_x_step); - funcs[n_left-1](n, (const void *)((const char *)vx_up + ix*bx), bx, this_info, this_nrc_x); - for (int ky = 0; ky < n_left; ++ky) { - auto result = this_info.dst_row(ky); - for (int j = 0; j < this_nrc_x; ++j) result[j] *= tmp[ky*k_x_step + j]; - } + process(funcs[n_left-1], this_info, ix, this_nrc_x, n_left); } } } @@ -731,6 +727,7 @@ extern "C" IQK_API bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11, extern "C" IQK_API bool iqk_moe_fused_up_gate(long Nx, long Ny, long ne00, int ne11, int unary_op, int typeA, const void * Aup, const void * Agate, long strideA, int typeB, const void * B, long strideB, + const char * up_b_c, const char * gate_b_c, float * C, long nb1, long nb2, const void * vrow_mapping, int ith, int nth) { const mmid_row_mapping * row_mapping = (const mmid_row_mapping *)vrow_mapping; @@ -774,7 +771,9 @@ extern "C" IQK_API bool iqk_moe_fused_up_gate(long Nx, long Ny, long ne00, int n if (!iqk_convert_repack(typeA, ne00, (const char *)Agate + (first_x + ix)*strideA, strideA, Xg, ne00, this_nrc_x)) { GGML_ABORT("Fatal error"); } - mm.mul_mat_up_gate_NxM(ne00, Xu, Xg, row_size_qx, this_info, this_nrc_x, Ny, unary_op); + auto up_b = up_b_c ? (const float *)up_b_c + first_x + ix : nullptr; + auto gate_b = gate_b_c ? (const float *)gate_b_c + first_x + ix : nullptr; + mm.mul_mat_up_gate_NxM(ne00, Xu, Xg, row_size_qx, up_b, gate_b, this_info, this_nrc_x, Ny, unary_op); } return true; @@ -795,7 +794,10 @@ extern "C" IQK_API bool iqk_moe_fused_up_gate(long Nx, long Ny, long ne00, int n nrc_x *= num_rows; DataInfo info{C + first_x, (const char *)B, nb1/sizeof(float), row_size_qy, 0, ne11, row_mapping, nb2/sizeof(float)}; - mm.mul_mat_up_gate_NxM(ne00, (const char *)Aup + row_size_qx*first_x, (const char *)Agate + row_size_qx*first_x, row_size_qx, info, nrc_x, Ny, unary_op); + auto up_b = up_b_c ? (const float *)up_b_c + first_x : nullptr; + auto gate_b = gate_b_c ? 
(const float *)gate_b_c + first_x : nullptr; + mm.mul_mat_up_gate_NxM(ne00, (const char *)Aup + row_size_qx*first_x, (const char *)Agate + row_size_qx*first_x, row_size_qx, + up_b, gate_b, info, nrc_x, Ny, unary_op); return true; } @@ -993,6 +995,46 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) { namespace { +// TODO: these swiglu_oai constants shouldn't be hard coded +constexpr float k_swiglu_oai_alpha = 1.702f; +constexpr float k_swiglu_oai_limit = 7.f; + +void MulMat::swiglu_oai(int n, const float * x, float * y) { +// int i = 0; +//#if defined __AVX512F__ && defined __AVX512DQ__ +// { +// auto max = _mm512_set1_ps(k_swiglu_oai_limit); +// auto alpha = _mm512_set1_ps(-k_swiglu_oai_alpha); +// for (; i + 15 < n; i += 16) { +// auto xc = v_clamp_max(_mm512_loadu_ps(x + i), max); +// _mm512_storeu_ps(y + i, v_silu_oai(xc, alpha)); +// } +// } +//#endif +//#if defined __AVX2__ && defined __FMA__ +// if (i + 7 < n) { +// auto max = _mm256_set1_ps(k_swiglu_oai_limit); +// auto alpha = _mm256_set1_ps(-k_swiglu_oai_alpha); +// for (; i + 7 < n; i += 8) { +// auto xc = v_clamp_max(_mm256_loadu_ps(x + i), max); +// _mm256_storeu_ps(y + i, v_silu_oai(xc, alpha)); +// } +// } +//#endif +// for (; i < n; ++i) { +// auto xi = std::min(x[i], k_swiglu_oai_limit); +// y[i] = xi / (1.0f + expf(-xi * k_swiglu_oai_alpha)); +// } + for (int i = 0; i < n; ++i) { + auto xi = std::min(x[i], k_swiglu_oai_limit); + y[i] = xi / (1.0f + expf(-xi * k_swiglu_oai_alpha)); + } +} + +void MulMat::clamp_oai(int n, float * x) { + for (int i = 0; i < n; ++i) x[i] = 1.f + std::max(std::min(x[i], k_swiglu_oai_limit), -k_swiglu_oai_limit); +} + #if defined(__ARM_NEON) && defined(__aarch64__) void MulMat::gelu(int n, const float * x, float * y) { constexpr float GELU_COEF_A = 0.044715f; @@ -1040,6 +1082,37 @@ void MulMat::gelu(int n, const float * x, float * y) { for (; i < n; ++i) y[i] = 0.5f*x[i]*(1.0f + tanhf(SQRT_2_OVER_PI*x[i]*(1.0f + GELU_COEF_A*x[i]*x[i]))); } +//void MulMat::swiglu_oai(int n, const float * x, float * y) { +// int i = 0; +//#if defined __AVX512F__ && defined __AVX512DQ__ +// { +// auto limit = _mm512_set1_ps(k_swiglu_oai_limit); +// auto alpha = _mm512_set1_ps(k_swiglu_oai_alpha); +// for (; i + 15 < n; i += 16) { +// auto xi = _mm512_loadu_ps(x + i); +// auto mask = _mm512_cmp +// +// } +// __m512 c1 = _mm512_set1_ps(GELU_COEF_A); +// __m512 c2 = _mm512_set1_ps(2.f*SQRT_2_OVER_PI); +// for (; i + 15 < n; i += 16) _mm512_storeu_ps(y + i, v_gelu(_mm512_loadu_ps(x + i), c1, c2)); +// } +//#endif +//#if defined __AVX2__ && defined __FMA__ +// if (i + 7 < n) { +// __m256 c1 = _mm256_set1_ps(GELU_COEF_A); +// __m256 c2 = _mm256_set1_ps(2.f*SQRT_2_OVER_PI); +// for (; i + 7 < n; i += 8) _mm256_storeu_ps(y + i, v_gelu(_mm256_loadu_ps(x + i), c1, c2)); +// +// } +//#endif +// for (; i < n; ++i) { +// auto xi = std::min(x[i], k_swiglu_oai_limit); +// y[i] = xi / (1.0f + expf(-xi * k_swiglu_oai_alpha)); +// } +//} + + void MulMat::silu(int n, const float * x, float * y) { int i = 0; #if defined __AVX512F__ && defined __AVX512DQ__ @@ -1188,6 +1261,8 @@ bool iqk_flash_attn_impl(int int_type_k, // type of k const void * k, // k matrix. Assumed to be fp16, nq x nk elements const void * v, // v matrix. Assumed to be fp16, nq x nk elements const void * mask, // mask. If not null, assumed to be fp16. nq x nk elements + const float * sinksf, // mask. If not null, assumed to be fp16. 
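Taken together, the `swiglu_oai`/`clamp_oai` helpers above and the fused up/gate loop reduce, per output element, to roughly the scalar computation below (illustrative sketch, not part of the patch; `swiglu_oai_ref` is a hypothetical name, the constants mirror `k_swiglu_oai_alpha`/`k_swiglu_oai_limit`, and the bias terms are zero when `up_b`/`gate_b` are null):

#include <algorithm>
#include <cmath>

// One output element of the fused up/gate SWIGLU-OAI path: the gate is clamped from above,
// passed through an alpha-scaled sigmoid gate, and multiplied by the (clamped, shifted) up value.
static inline float swiglu_oai_ref(float up, float gate, float up_bias, float gate_bias,
                                   float alpha = 1.702f, float limit = 7.0f) {
    gate = std::min(gate + gate_bias, limit);                        // swiglu_oai clamps the gate from above only
    const float act = gate / (1.0f + expf(-alpha * gate));           // x * sigmoid(alpha * x)
    up = 1.0f + std::max(std::min(up + up_bias, limit), -limit);     // clamp_oai: clamp to [-limit, limit], then + 1
    return up * act;
}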
nq x nk elements + [[maybe_unused]] int nsinks, float scale, // scale applied before softmax float softcap, // if > 0, a "soft-cap" operation is applied before softmax float * qkv, // v*softmax(scale*(k*q)) @@ -1197,32 +1272,32 @@ bool iqk_flash_attn_impl(int int_type_k, // type of k if (Dk == 576 && Dv == 512) { return iqk_fa_576_512(int_type_k, int_type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, - q, k, v, mask, scale, softcap, qkv, M, S); + q, k, v, mask, scale, softcap, qkv, sinksf, M, S); } if (Dk == 192 && Dv == 128) { return iqk_fa_192_128(int_type_k, int_type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, - q, k, v, mask, scale, softcap, qkv, M, S); + q, k, v, mask, scale, softcap, qkv, sinksf, M, S); } if (Dk == 256 && Dv == 256) { return iqk_fa_256_256(int_type_k, int_type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, - q, k, v, mask, scale, softcap, qkv, M, S); + q, k, v, mask, scale, softcap, qkv, sinksf, M, S); } if (Dk == 128 && Dv == 128) { return iqk_fa_128_128(int_type_k, int_type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, - q, k, v, mask, scale, softcap, qkv, M, S); + q, k, v, mask, scale, softcap, qkv, sinksf, M, S); } if (Dk == 96 && Dv == 96) { return iqk_fa_96_96(int_type_k, int_type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, - q, k, v, mask, scale, softcap, qkv, M, S); + q, k, v, mask, scale, softcap, qkv, sinksf, M, S); } if (Dk == 64 && Dv == 64) { return iqk_fa_64_64(int_type_k, int_type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, - q, k, v, mask, scale, softcap, qkv, M, S); + q, k, v, mask, scale, softcap, qkv, sinksf, M, S); } return false; diff --git a/ggml/src/iqk/iqk_mul_mat.h b/ggml/src/iqk/iqk_mul_mat.h index 87722f6fe..b131095b0 100644 --- a/ggml/src/iqk/iqk_mul_mat.h +++ b/ggml/src/iqk/iqk_mul_mat.h @@ -32,6 +32,7 @@ IQK_API bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11, IQK_API bool iqk_moe_fused_up_gate(long Nx, long Ny, long ne00, int ne11, int unary_op, int typeA, const void * Aup, const void * Agate, long strideA, int typeB, const void * B, long strideB, + const char * up_b, const char * gate_b, float * C, long nb1, long nb2, const void * vrow_mapping, int ith, int nth); IQK_API int iqk_dequant_type(int type, int Ny); @@ -57,6 +58,7 @@ IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float max_bias, const void * k, // k matrix. Assumed to be fp16, nq x nk elements const void * v, // v matrix. Assumed to be fp16, nq x nk elements const void * mask, // mask. If not null, assumed to be fp16. nq x nk elements + const void * sinks, // mask. If not null, assumed to be fp16. 
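In the `iqk_moe_fused_up_gate` declaration above, the new `up_b` / `gate_b` parameters are optional biases: they may be null, and when present the implementation in iqk_mul_mat.cpp reinterprets them as `const float *`, offsets them by the first processed row (`first_x + ix`), and adds them to the gate and up results before the activation and clamp steps shown in the scalar sketch earlier.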
nq x nk elements float scale, // scale applied before softmax float softcap, // if > 0, a "soft-cap" operation is applied before softmax float * qkv, // v*softmax(scale*(k*q)) diff --git a/ggml/src/iqk/iqk_utils.h b/ggml/src/iqk/iqk_utils.h index 194bf9b81..435ae4dd2 100644 --- a/ggml/src/iqk/iqk_utils.h +++ b/ggml/src/iqk/iqk_utils.h @@ -61,6 +61,13 @@ static inline float32x4_t v_silu(float32x4_t x) { const float32x4_t one_plus_exp_neg_x = vaddq_f32(one, exp_neg_x); return vdivq_f32(x, one_plus_exp_neg_x); } +static inline float32x4_t v_silu_oai(float32x4_t x, float32x4_t alpha) { + const float32x4_t one = vdupq_n_f32(1.0f); + const float32x4_t neg_x = vmulq_f32(alpha, x); + const float32x4_t exp_neg_x = v_expf(neg_x); + const float32x4_t one_plus_exp_neg_x = vaddq_f32(one, exp_neg_x); + return vdivq_f32(x, one_plus_exp_neg_x); +} static inline float32x4_t v_gelu(float32x4_t x, float32x4_t c1, float32x4_t c2) { const float32x4_t one = vdupq_n_f32(1.0f); float32x4_t arg = vfmaq_f32(one, c1, vmulq_f32(x, x)); @@ -131,6 +138,17 @@ static inline __m512 v_silu(__m512 x) { const __m512 one_plus_exp_neg_x = _mm512_add_ps(one, exp_neg_x); return _mm512_div_ps(x, one_plus_exp_neg_x); } +static inline __m512 v_silu_oai(__m512 x, __m512 alpha) { + const __m512 one = _mm512_set1_ps(1); + const __m512 neg_x = _mm512_mul_ps(alpha, x); + const __m512 exp_neg_x = v_expf(neg_x); + const __m512 one_plus_exp_neg_x = _mm512_add_ps(one, exp_neg_x); + return _mm512_div_ps(x, one_plus_exp_neg_x); +} +static inline __m512 v_clamp_max(__m512 x, __m512 max) { + auto mask = _mm512_cmp_ps_mask(x, max, _CMP_GT_OQ); + return _mm512_mask_blend_ps(mask, x, max); +} #endif // __AVX512__ #if defined(__AVX2__) && defined(__FMA__) @@ -195,12 +213,23 @@ static inline __m256 v_gelu(__m256 x, __m256 c1, __m256 c2) { } static inline __m256 v_silu(__m256 x) { const __m256 one = _mm256_set1_ps(1); - const __m256 zero = _mm256_setzero_ps(); + const __m256 zero = _mm256_setzero_ps(); const __m256 neg_x = _mm256_sub_ps(zero, x); const __m256 exp_neg_x = v_expf(neg_x); const __m256 one_plus_exp_neg_x = _mm256_add_ps(one, exp_neg_x); return _mm256_div_ps(x, one_plus_exp_neg_x); } +static inline __m256 v_silu_oai(__m256 x, __m256 alpha) { + const __m256 one = _mm256_set1_ps(1); + const __m256 neg_x = _mm256_mul_ps(alpha, x); + const __m256 exp_neg_x = v_expf(neg_x); + const __m256 one_plus_exp_neg_x = _mm256_add_ps(one, exp_neg_x); + return _mm256_div_ps(x, one_plus_exp_neg_x); +} +static inline __m256 v_clamp_max(__m256 x, __m256 max) { + auto mask = _mm256_cmp_ps(x, max, _CMP_GT_OQ); + return _mm256_or_ps(_mm256_and_ps(mask, max), _mm256_andnot_ps(mask, x)); +} #endif // __AVX2__ diff --git a/include/llama.h b/include/llama.h index c68fa2298..482f02b15 100644 --- a/include/llama.h +++ b/include/llama.h @@ -70,50 +70,52 @@ extern "C" { typedef int32_t llama_seq_id; enum llama_vocab_type { - LLAMA_VOCAB_TYPE_NONE = 0, // For models without vocab - LLAMA_VOCAB_TYPE_SPM = 1, // LLaMA tokenizer based on byte-level BPE with byte fallback - LLAMA_VOCAB_TYPE_BPE = 2, // GPT-2 tokenizer based on byte-level BPE - LLAMA_VOCAB_TYPE_WPM = 3, // BERT tokenizer based on WordPiece - LLAMA_VOCAB_TYPE_UGM = 4, // T5 tokenizer based on Unigram + LLAMA_VOCAB_TYPE_NONE = 0, // For models without vocab + LLAMA_VOCAB_TYPE_SPM = 1, // LLaMA tokenizer based on byte-level BPE with byte fallback + LLAMA_VOCAB_TYPE_BPE = 2, // GPT-2 tokenizer based on byte-level BPE + LLAMA_VOCAB_TYPE_WPM = 3, // BERT tokenizer based on WordPiece + LLAMA_VOCAB_TYPE_UGM = 4, // 
T5 tokenizer based on Unigram + LLAMA_VOCAB_TYPE_RWKV = 5, // RWKV tokenizer based on greedy tokenization + LLAMA_VOCAB_TYPE_PLAMO2 = 6, // PLaMo-2 tokenizer based on Aho-Corasick with dynamic programming }; // pre-tokenization types - enum llama_vocab_pre_type { - LLAMA_VOCAB_PRE_TYPE_DEFAULT = 0, - LLAMA_VOCAB_PRE_TYPE_LLAMA3 = 1, - LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM = 2, - LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER = 3, - LLAMA_VOCAB_PRE_TYPE_FALCON = 4, - LLAMA_VOCAB_PRE_TYPE_MPT = 5, - LLAMA_VOCAB_PRE_TYPE_STARCODER = 6, - LLAMA_VOCAB_PRE_TYPE_GPT2 = 7, - LLAMA_VOCAB_PRE_TYPE_REFACT = 8, - LLAMA_VOCAB_PRE_TYPE_COMMAND_R = 9, - LLAMA_VOCAB_PRE_TYPE_STABLELM2 = 10, - LLAMA_VOCAB_PRE_TYPE_QWEN2 = 11, - LLAMA_VOCAB_PRE_TYPE_OLMO = 12, - LLAMA_VOCAB_PRE_TYPE_DBRX = 13, - LLAMA_VOCAB_PRE_TYPE_SMAUG = 14, - LLAMA_VOCAB_PRE_TYPE_PORO = 15, - LLAMA_VOCAB_PRE_TYPE_CHATGLM3 = 16, - LLAMA_VOCAB_PRE_TYPE_CHATGLM4 = 17, - LLAMA_VOCAB_PRE_TYPE_VIKING = 18, - LLAMA_VOCAB_PRE_TYPE_JAIS = 19, - LLAMA_VOCAB_PRE_TYPE_TEKKEN = 20, - LLAMA_VOCAB_PRE_TYPE_SMOLLM = 21, - LLAMA_VOCAB_PRE_TYPE_CODESHELL = 22, - LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 28, //llama.cpp lists this as 28 - LLAMA_VOCAB_PRE_TYPE_GPT4O = 29, - LLAMA_VOCAB_PRE_TYPE_SUPERBPE = 30, - LLAMA_VOCAB_PRE_TYPE_TRILLION = 31, - LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32, - LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33, - LLAMA_VOCAB_PRE_TYPE_FALCON_3 = 34, - LLAMA_VOCAB_PRE_TYPE_FALCON_E = 35, - LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 36, //llama.cpp lists this as 35 - LLAMA_VOCAB_PRE_TYPE_HUNYUAN = 37, //llama.cpp lists this as 36 - LLAMA_VOCAB_PRE_TYPE_KIMI_K2 = 38, //llama.cpp lists this as 37 - }; + //enum llama_vocab_pre_type { + // LLAMA_VOCAB_PRE_TYPE_DEFAULT = 0, + // LLAMA_VOCAB_PRE_TYPE_LLAMA3 = 1, + // LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM = 2, + // LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER = 3, + // LLAMA_VOCAB_PRE_TYPE_FALCON = 4, + // LLAMA_VOCAB_PRE_TYPE_MPT = 5, + // LLAMA_VOCAB_PRE_TYPE_STARCODER = 6, + // LLAMA_VOCAB_PRE_TYPE_GPT2 = 7, + // LLAMA_VOCAB_PRE_TYPE_REFACT = 8, + // LLAMA_VOCAB_PRE_TYPE_COMMAND_R = 9, + // LLAMA_VOCAB_PRE_TYPE_STABLELM2 = 10, + // LLAMA_VOCAB_PRE_TYPE_QWEN2 = 11, + // LLAMA_VOCAB_PRE_TYPE_OLMO = 12, + // LLAMA_VOCAB_PRE_TYPE_DBRX = 13, + // LLAMA_VOCAB_PRE_TYPE_SMAUG = 14, + // LLAMA_VOCAB_PRE_TYPE_PORO = 15, + // LLAMA_VOCAB_PRE_TYPE_CHATGLM3 = 16, + // LLAMA_VOCAB_PRE_TYPE_CHATGLM4 = 17, + // LLAMA_VOCAB_PRE_TYPE_VIKING = 18, + // LLAMA_VOCAB_PRE_TYPE_JAIS = 19, + // LLAMA_VOCAB_PRE_TYPE_TEKKEN = 20, + // LLAMA_VOCAB_PRE_TYPE_SMOLLM = 21, + // LLAMA_VOCAB_PRE_TYPE_CODESHELL = 22, + // LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 28, //llama.cpp lists this as 28 + // LLAMA_VOCAB_PRE_TYPE_GPT4O = 29, + // LLAMA_VOCAB_PRE_TYPE_SUPERBPE = 30, + // LLAMA_VOCAB_PRE_TYPE_TRILLION = 31, + // LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32, + // LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33, + // LLAMA_VOCAB_PRE_TYPE_FALCON_3 = 34, + // LLAMA_VOCAB_PRE_TYPE_FALCON_E = 35, + // LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 36, //llama.cpp lists this as 35 + // LLAMA_VOCAB_PRE_TYPE_HUNYUAN = 37, //llama.cpp lists this as 36 + // LLAMA_VOCAB_PRE_TYPE_KIMI_K2 = 38, //llama.cpp lists this as 37 + //}; // note: these values should be synchronized with ggml_rope // TODO: maybe move this enum to ggml.h (ggml_rope_type) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index f86137c43..0d5ab7357 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -17,6 +17,8 @@ add_library(llama llama-vocab.cpp llama-grammar.cpp llama-sampling.cpp + llama-mmap.cpp + llama-model-loader.cpp unicode.h 
unicode.cpp unicode-data.cpp diff --git a/src/llama-arch.h b/src/llama-arch.h new file mode 100644 index 000000000..76ac44c48 --- /dev/null +++ b/src/llama-arch.h @@ -0,0 +1,288 @@ +#pragma once + +#include + +enum llm_arch { + LLM_ARCH_LLAMA, + LLM_ARCH_LLAMA4, + LLM_ARCH_DECI, + LLM_ARCH_FALCON, + LLM_ARCH_BAICHUAN, + LLM_ARCH_GROK, + LLM_ARCH_GPT2, + LLM_ARCH_GPTJ, + LLM_ARCH_GPTNEOX, + LLM_ARCH_MPT, + LLM_ARCH_STARCODER, + LLM_ARCH_REFACT, + LLM_ARCH_BERT, + LLM_ARCH_NOMIC_BERT, + LLM_ARCH_JINA_BERT_V2, + LLM_ARCH_BLOOM, + LLM_ARCH_STABLELM, + LLM_ARCH_QWEN, + LLM_ARCH_QWEN2, + LLM_ARCH_QWEN2MOE, + LLM_ARCH_QWEN3, + LLM_ARCH_QWEN3MOE, + LLM_ARCH_PHI2, + LLM_ARCH_PHI3, + LLM_ARCH_PLAMO, + LLM_ARCH_CODESHELL, + LLM_ARCH_ORION, + LLM_ARCH_INTERNLM2, + LLM_ARCH_MINICPM, + LLM_ARCH_GEMMA, + LLM_ARCH_GEMMA2, + LLM_ARCH_GEMMA3, + LLM_ARCH_STARCODER2, + LLM_ARCH_MAMBA, + LLM_ARCH_XVERSE, + LLM_ARCH_COMMAND_R, + LLM_ARCH_DBRX, + LLM_ARCH_OLMO, + LLM_ARCH_OPENELM, + LLM_ARCH_ARCTIC, + LLM_ARCH_DEEPSEEK2, + LLM_ARCH_CHATGLM, + LLM_ARCH_GLM4, + LLM_ARCH_GLM4_MOE, + LLM_ARCH_BITNET, + LLM_ARCH_BITNET_25, + LLM_ARCH_BITNET_B158, + LLM_ARCH_T5, + LLM_ARCH_T5ENCODER, + LLM_ARCH_JAIS, + LLM_ARCH_GRANITE, + LLM_ARCH_GRANITE_MOE, + LLM_ARCH_COHERE2, + LLM_ARCH_DOTS1, + LLM_ARCH_HUNYUAN_MOE, + LLM_ARCH_OPENAI_MOE, + LLM_ARCH_UNKNOWN, +}; + +enum llm_kv { + LLM_KV_GENERAL_TYPE, + LLM_KV_GENERAL_ARCHITECTURE, + LLM_KV_GENERAL_QUANTIZATION_VERSION, + LLM_KV_GENERAL_ALIGNMENT, + LLM_KV_GENERAL_NAME, + LLM_KV_GENERAL_AUTHOR, + LLM_KV_GENERAL_VERSION, + LLM_KV_GENERAL_URL, + LLM_KV_GENERAL_DESCRIPTION, + LLM_KV_GENERAL_LICENSE, + LLM_KV_GENERAL_SOURCE_URL, + LLM_KV_GENERAL_SOURCE_HF_REPO, + + LLM_KV_VOCAB_SIZE, + LLM_KV_CONTEXT_LENGTH, + LLM_KV_EMBEDDING_LENGTH, + LLM_KV_BLOCK_COUNT, + LLM_KV_LEADING_DENSE_BLOCK_COUNT, + LLM_KV_FEED_FORWARD_LENGTH, + LLM_KV_EXPERT_FEED_FORWARD_LENGTH, + LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, + LLM_KV_USE_PARALLEL_RESIDUAL, + LLM_KV_TENSOR_DATA_LAYOUT, + LLM_KV_EXPERT_COUNT, + LLM_KV_EXPERT_USED_COUNT, + LLM_KV_EXPERT_SHARED_COUNT, + LLM_KV_EXPERT_WEIGHTS_SCALE, + LLM_KV_EXPERT_WEIGHTS_NORM, + LLM_KV_EXPERT_GATING_FUNC, + LLM_KV_NEXTN_PREDICT_LAYERS, + LLM_KV_POOLING_TYPE, + LLM_KV_LOGIT_SCALE, + LLM_KV_DECODER_START_TOKEN_ID, + LLM_KV_ATTN_LOGIT_SOFTCAPPING, + LLM_KV_FINAL_LOGIT_SOFTCAPPING, + LLM_KV_SWIN_NORM, + LLM_KV_RESCALE_EVERY_N_LAYERS, + LLM_KV_TIME_MIX_EXTRA_DIM, + LLM_KV_TIME_DECAY_EXTRA_DIM, + LLM_KV_RESIDUAL_SCALE, + LLM_KV_EMBEDDING_SCALE, + LLM_KV_TOKEN_SHIFT_COUNT, + LLM_KV_INTERLEAVE_MOE_LAYER_STEP, + + LLM_KV_ATTENTION_HEAD_COUNT, + LLM_KV_ATTENTION_HEAD_COUNT_KV, + LLM_KV_ATTENTION_MAX_ALIBI_BIAS, + LLM_KV_ATTENTION_CLAMP_KQV, + LLM_KV_ATTENTION_KEY_LENGTH, + LLM_KV_ATTENTION_VALUE_LENGTH, + LLM_KV_ATTENTION_LAYERNORM_EPS, + LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, + LLM_KV_ATTENTION_CAUSAL, + LLM_KV_ATTENTION_Q_LORA_RANK, + LLM_KV_ATTENTION_KV_LORA_RANK, + LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, + LLM_KV_ATTENTION_SLIDING_WINDOW, + LLM_KV_ATTENTION_SCALE, + + LLM_KV_ROPE_DIMENSION_COUNT, + LLM_KV_ROPE_FREQ_BASE, + LLM_KV_ROPE_SCALE_LINEAR, + LLM_KV_ROPE_SCALING_TYPE, + LLM_KV_ROPE_SCALING_FACTOR, + LLM_KV_ROPE_SCALING_ATTN_FACTOR, + LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, + LLM_KV_ROPE_SCALING_FINETUNED, + LLM_KV_ROPE_SCALING_YARN_LOG_MUL, + + LLM_KV_SPLIT_NO, + LLM_KV_SPLIT_COUNT, + LLM_KV_SPLIT_TENSORS_COUNT, + + LLM_KV_SSM_INNER_SIZE, + LLM_KV_SSM_CONV_KERNEL, + LLM_KV_SSM_STATE_SIZE, + LLM_KV_SSM_TIME_STEP_RANK, + + LLM_KV_TOKENIZER_MODEL, + 
LLM_KV_TOKENIZER_PRE, + LLM_KV_TOKENIZER_LIST, + LLM_KV_TOKENIZER_TOKEN_TYPE, + LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, + LLM_KV_TOKENIZER_SCORES, + LLM_KV_TOKENIZER_MERGES, + LLM_KV_TOKENIZER_BOS_ID, + LLM_KV_TOKENIZER_EOS_ID, + LLM_KV_TOKENIZER_UNK_ID, + LLM_KV_TOKENIZER_SEP_ID, + LLM_KV_TOKENIZER_PAD_ID, + LLM_KV_TOKENIZER_CLS_ID, + LLM_KV_TOKENIZER_MASK_ID, + LLM_KV_TOKENIZER_ADD_BOS, + LLM_KV_TOKENIZER_ADD_EOS, + LLM_KV_TOKENIZER_ADD_SEP, + LLM_KV_TOKENIZER_ADD_PREFIX, + LLM_KV_TOKENIZER_REMOVE_EXTRA_WS, + LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, + LLM_KV_TOKENIZER_HF_JSON, + LLM_KV_TOKENIZER_RWKV, + LLM_KV_TOKENIZER_CHAT_TEMPLATE, + LLM_KV_TOKENIZER_CHAT_TEMPLATE_N, + LLM_KV_TOKENIZER_FIM_PRE_ID, + LLM_KV_TOKENIZER_FIM_SUF_ID, + LLM_KV_TOKENIZER_FIM_MID_ID, + LLM_KV_TOKENIZER_FIM_PAD_ID, + LLM_KV_TOKENIZER_FIM_REP_ID, + LLM_KV_TOKENIZER_FIM_SEP_ID, + LLM_KV_TOKENIZER_PREFIX_ID, + LLM_KV_TOKENIZER_SUFFIX_ID, + LLM_KV_TOKENIZER_MIDDLE_ID, + LLM_KV_TOKENIZER_EOT_ID, + LLM_KV_TOKENIZER_EOM_ID, + + LLM_KV_ADAPTER_TYPE, + LLM_KV_ADAPTER_LORA_ALPHA, +}; + +struct LLM_KV { + LLM_KV(llm_arch arch, const char* suffix = nullptr); + + llm_arch arch; + const char* suffix; + std::string operator()(llm_kv kv) const; +}; + +enum llm_tensor { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_TOKEN_EMBD_NORM, + LLM_TENSOR_TOKEN_TYPES, + LLM_TENSOR_POS_EMBD, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_ROPE_FREQS, + LLM_TENSOR_ROPE_FACTORS_LONG, + LLM_TENSOR_ROPE_FACTORS_SHORT, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_QKV, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_NORM_2, + LLM_TENSOR_ATTN_OUT_NORM, + LLM_TENSOR_ATTN_POST_NORM, + LLM_TENSOR_ATTN_ROT_EMBD, + LLM_TENSOR_ATTN_SINKS, + LLM_TENSOR_FFN_GATE_INP, + LLM_TENSOR_FFN_GATE_INP_SHEXP, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_POST_NORM, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + LLM_TENSOR_FFN_ACT, + LLM_TENSOR_FFN_DOWN_EXP, // split experts for backward compatibility + LLM_TENSOR_FFN_GATE_EXP, + LLM_TENSOR_FFN_UP_EXP, + LLM_TENSOR_FFN_NORM_EXPS, + LLM_TENSOR_FFN_DOWN_EXPS, // merged experts + LLM_TENSOR_FFN_GATE_EXPS, + LLM_TENSOR_FFN_UP_EXPS, + LLM_TENSOR_FFN_DOWN_SHEXP, + LLM_TENSOR_FFN_GATE_SHEXP, + LLM_TENSOR_FFN_UP_SHEXP, + LLM_TENSOR_FFN_EXP_PROBS_B, + LLM_TENSOR_ATTN_Q_NORM, + LLM_TENSOR_ATTN_K_NORM, + LLM_TENSOR_LAYER_OUT_NORM, + LLM_TENSOR_SSM_IN, + LLM_TENSOR_SSM_CONV1D, + LLM_TENSOR_SSM_X, + LLM_TENSOR_SSM_DT, + LLM_TENSOR_SSM_A, + LLM_TENSOR_SSM_D, + LLM_TENSOR_SSM_OUT, + LLM_TENSOR_ATTN_Q_A, + LLM_TENSOR_ATTN_Q_B, + LLM_TENSOR_ATTN_KV_A_MQA, + LLM_TENSOR_ATTN_KV_B, + LLM_TENSOR_ATTN_K_B, + LLM_TENSOR_ATTN_V_B, + LLM_TENSOR_ATTN_Q_A_NORM, + LLM_TENSOR_ATTN_KV_A_NORM, + LLM_TENSOR_ATTN_SUB_NORM, + LLM_TENSOR_FFN_SUB_NORM, + LLM_TENSOR_DEC_ATTN_NORM, + LLM_TENSOR_DEC_ATTN_Q, + LLM_TENSOR_DEC_ATTN_K, + LLM_TENSOR_DEC_ATTN_V, + LLM_TENSOR_DEC_ATTN_OUT, + LLM_TENSOR_DEC_ATTN_REL_B, + LLM_TENSOR_DEC_CROSS_ATTN_NORM, + LLM_TENSOR_DEC_CROSS_ATTN_Q, + LLM_TENSOR_DEC_CROSS_ATTN_K, + LLM_TENSOR_DEC_CROSS_ATTN_V, + LLM_TENSOR_DEC_CROSS_ATTN_OUT, + LLM_TENSOR_DEC_CROSS_ATTN_REL_B, + LLM_TENSOR_DEC_FFN_NORM, + LLM_TENSOR_DEC_FFN_GATE, + LLM_TENSOR_DEC_FFN_DOWN, + LLM_TENSOR_DEC_FFN_UP, + LLM_TENSOR_DEC_OUTPUT_NORM, + LLM_TENSOR_ENC_ATTN_NORM, + LLM_TENSOR_ENC_ATTN_Q, + LLM_TENSOR_ENC_ATTN_K, + LLM_TENSOR_ENC_ATTN_V, + LLM_TENSOR_ENC_ATTN_OUT, + LLM_TENSOR_ENC_ATTN_REL_B, + LLM_TENSOR_ENC_FFN_NORM, + LLM_TENSOR_ENC_FFN_GATE, + LLM_TENSOR_ENC_FFN_DOWN, + 
LLM_TENSOR_ENC_FFN_UP, + LLM_TENSOR_ENC_OUTPUT_NORM, + LLM_TENSOR_NEXTN_EH_PROJ, + LLM_TENSOR_NEXTN_EMBED_TOKENS, + LLM_TENSOR_NEXTN_ENORM, + LLM_TENSOR_NEXTN_HNORM, + LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, + LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, +}; + +llm_arch llm_arch_from_string(const std::string & name); diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp index b123d7331..c5c29d217 100644 --- a/src/llama-grammar.cpp +++ b/src/llama-grammar.cpp @@ -486,9 +486,9 @@ void llama_grammar_sample_impl(const struct llama_grammar * grammar, const struc for (size_t i = 0; i < candidates->size; ++i) { const llama_token id = candidates->data[i].id; - const std::string & piece = vocab->cache_token_to_piece.at(id); + const std::string & piece = vocab->token_to_piece(id); - if (llama_token_is_eog_impl(*vocab, id)) { + if (vocab->is_eog(id)) { if (!allow_eog) { candidates->data[i].logit = -INFINITY; } @@ -511,7 +511,7 @@ void llama_grammar_sample_impl(const struct llama_grammar * grammar, const struc void llama_grammar_accept_token_impl(struct llama_grammar * grammar, const struct llama_vocab * vocab, const struct llama_sampling * smpl, llama_token token) { const int64_t t_start_sample_us = ggml_time_us(); - if (llama_token_is_eog_impl(*vocab, token)) { + if (vocab->is_eog(token)) { for (const auto & stack : grammar->stacks) { if (stack.empty()) { return; @@ -520,7 +520,7 @@ void llama_grammar_accept_token_impl(struct llama_grammar * grammar, const struc GGML_ABORT("fatal error"); } - const std::string & piece = vocab->cache_token_to_piece.at(token); + const std::string & piece = vocab->token_to_piece(token); // Note terminating 0 in decoded string const auto decoded = decode_utf8(piece, grammar->partial_utf8); diff --git a/src/llama-impl.h b/src/llama-impl.h index a50f60cfd..cd4e0730a 100644 --- a/src/llama-impl.h +++ b/src/llama-impl.h @@ -10,6 +10,11 @@ #define LLAMA_API_INTERNAL #include "llama.h" #include +#include +#include +#include +#include +#include #ifdef __GNUC__ #ifdef __MINGW32__ @@ -33,6 +38,7 @@ void llama_log_callback_default(ggml_log_level level, const char * text, void * #define LLAMA_LOG_INFO(...) llama_log_internal(GGML_LOG_LEVEL_INFO , __VA_ARGS__) #define LLAMA_LOG_WARN(...) llama_log_internal(GGML_LOG_LEVEL_WARN , __VA_ARGS__) #define LLAMA_LOG_ERROR(...) llama_log_internal(GGML_LOG_LEVEL_ERROR, __VA_ARGS__) +#define LLAMA_LOG_DEBUG(...) llama_log_internal(GGML_LOG_LEVEL_DEBUG, __VA_ARGS__) // // helpers @@ -166,3 +172,49 @@ struct ring_buffer { size_t pos = 0; std::vector data; }; + +LLAMA_ATTRIBUTE_FORMAT(1, 2) +static std::string format(const char * fmt, ...) 
{ + va_list ap; + va_list ap2; + va_start(ap, fmt); + va_copy(ap2, ap); + int size = vsnprintf(NULL, 0, fmt, ap); + GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT + std::vector buf(size + 1); + int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2); + GGML_ASSERT(size2 == size); + va_end(ap2); + va_end(ap); + return std::string(buf.data(), size); +} + +static std::string llama_format_tensor_shape(const std::vector & ne) { + char buf[256]; + snprintf(buf, sizeof(buf), "%5" PRId64, ne.at(0)); + for (size_t i = 1; i < ne.size(); i++) { + snprintf(buf + strlen(buf), sizeof(buf) - strlen(buf), ", %5" PRId64, ne.at(i)); + } + return buf; +} + +static std::string llama_format_tensor_shape(const struct ggml_tensor * t) { + char buf[256]; + snprintf(buf, sizeof(buf), "%5" PRId64, t->ne[0]); + for (int i = 1; i < GGML_MAX_DIMS; i++) { + snprintf(buf + strlen(buf), sizeof(buf) - strlen(buf), ", %5" PRId64, t->ne[i]); + } + return buf; +} + +template +struct no_init { + T value; + no_init() { /* do nothing */ } +}; + + +struct gguf_context; +std::string gguf_kv_to_str(const gguf_context * ctx_gguf, int i); + +ggml_backend_buffer_type_t llama_default_buffer_type_cpu(bool host_buffer); diff --git a/src/llama-mmap.cpp b/src/llama-mmap.cpp new file mode 100644 index 000000000..4a65c9cb6 --- /dev/null +++ b/src/llama-mmap.cpp @@ -0,0 +1,650 @@ +#include "llama-mmap.h" + +#include "llama-impl.h" + +#include "ggml.h" + +#include +#include +#include +#include +#include +#include +#include + +#ifdef __has_include + #if __has_include() + #include + #if defined(_POSIX_MAPPED_FILES) + #include + #include + #endif + #if defined(_POSIX_MEMLOCK_RANGE) + #include + #endif + #endif +#endif + +#if defined(_WIN32) + #define WIN32_LEAN_AND_MEAN + #ifndef NOMINMAX + #define NOMINMAX + #endif + #include + #ifndef PATH_MAX + #define PATH_MAX MAX_PATH + #endif + #include +#endif + +#if defined(__APPLE__) +#include +#endif + +// TODO: consider moving to llama-impl.h if needed in more places +#if defined(_WIN32) +static std::string llama_format_win_err(DWORD err) { + LPSTR buf; + size_t size = FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, + NULL, err, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&buf, 0, NULL); + if (!size) { + return "FormatMessageA failed"; + } + std::string ret(buf, size); + LocalFree(buf); + return ret; +} +#endif + +// llama_file + +struct llama_file::impl { +#if defined(_WIN32) + HANDLE fp_win32; + std::string GetErrorMessageWin32(DWORD error_code) const { + std::string ret; + LPSTR lpMsgBuf = NULL; + DWORD bufLen = FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, + NULL, error_code, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&lpMsgBuf, 0, NULL); + if (!bufLen) { + ret = format("Win32 error code: %lx", error_code); + } else { + ret = lpMsgBuf; + LocalFree(lpMsgBuf); + } + + return ret; + } + + impl(const char * fname, const char * mode) { + fp = ggml_fopen(fname, mode); + if (fp == NULL) { + throw std::runtime_error(format("failed to open %s: %s", fname, strerror(errno))); + } + fp_win32 = (HANDLE) _get_osfhandle(_fileno(fp)); + seek(0, SEEK_END); + size = tell(); + seek(0, SEEK_SET); + } + + size_t tell() const { + LARGE_INTEGER li; + li.QuadPart = 0; + BOOL ret = SetFilePointerEx(fp_win32, li, &li, FILE_CURRENT); + if (!ret) { + throw std::runtime_error(format("read error: %s", GetErrorMessageWin32(GetLastError()).c_str())); + } + + return li.QuadPart; + } + + void 
seek(size_t offset, int whence) const { + static_assert(SEEK_SET == FILE_BEGIN, "SEEK_SET != FILE_BEGIN"); + static_assert(SEEK_CUR == FILE_CURRENT, "SEEK_CUR != FILE_CURRENT"); + static_assert(SEEK_END == FILE_END, "SEEK_END != FILE_END"); + + LARGE_INTEGER li; + li.QuadPart = offset; + BOOL ret = SetFilePointerEx(fp_win32, li, NULL, whence); + if (!ret) { + throw std::runtime_error(format("read error: %s", GetErrorMessageWin32(GetLastError()).c_str())); + } + } + + void read_raw(void * ptr, size_t len) const { + size_t bytes_read = 0; + while (bytes_read < len) { + size_t chunk_size = std::min(len - bytes_read, 64*1024*1024); + DWORD chunk_read = 0; + BOOL result = ReadFile(fp_win32, reinterpret_cast(ptr) + bytes_read, chunk_size, &chunk_read, NULL); + if (!result) { + throw std::runtime_error(format("read error: %s", GetErrorMessageWin32(GetLastError()).c_str())); + } + if (chunk_read < chunk_size || chunk_read == 0) { + throw std::runtime_error("unexpectedly reached end of file"); + } + + bytes_read += chunk_read; + } + } + + uint32_t read_u32() const { + uint32_t val; + read_raw(&val, sizeof(val)); + return val; + } + + void write_raw(const void * ptr, size_t len) const { + size_t bytes_written = 0; + while (bytes_written < len) { + size_t chunk_size = std::min(len - bytes_written, 64*1024*1024); + DWORD chunk_written = 0; + BOOL result = WriteFile(fp_win32, reinterpret_cast(ptr) + bytes_written, chunk_size, &chunk_written, NULL); + if (!result) { + throw std::runtime_error(format("write error: %s", GetErrorMessageWin32(GetLastError()).c_str())); + } + if (chunk_written < chunk_size || chunk_written == 0) { + throw std::runtime_error("unexpectedly failed to write bytes"); + } + + bytes_written += chunk_written; + } + } + + void write_u32(uint32_t val) const { + write_raw(&val, sizeof(val)); + } + + ~impl() { + if (fp) { + std::fclose(fp); + } + } +#else + impl(const char * fname, const char * mode) { + fp = ggml_fopen(fname, mode); + if (fp == NULL) { + throw std::runtime_error(format("failed to open %s: %s", fname, strerror(errno))); + } + seek(0, SEEK_END); + size = tell(); + seek(0, SEEK_SET); + } + + size_t tell() const { +// TODO: this ifdef is never true? +#ifdef _WIN32 + __int64 ret = _ftelli64(fp); +#else + long ret = std::ftell(fp); +#endif + if (ret == -1) { + throw std::runtime_error(format("ftell error: %s", strerror(errno))); + } + + return (size_t) ret; + } + + void seek(size_t offset, int whence) const { +// TODO: this ifdef is never true? 
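(The TODO above seems accurate: this is the `#else` branch of the outer `#if defined(_WIN32)` at the top of `llama_file::impl`, so `_WIN32` cannot be defined here and the inner `_ftelli64`/`_fseeki64` branches in `tell()` and `seek()` are unreachable; only the plain `ftell`/`fseek` calls actually run.)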
+#ifdef _WIN32 + int ret = _fseeki64(fp, (__int64) offset, whence); +#else + int ret = std::fseek(fp, (long) offset, whence); +#endif + if (ret != 0) { + throw std::runtime_error(format("seek error: %s", strerror(errno))); + } + } + + void read_raw(void * ptr, size_t len) const { + if (len == 0) { + return; + } + errno = 0; + std::size_t ret = std::fread(ptr, len, 1, fp); + if (ferror(fp)) { + throw std::runtime_error(format("read error: %s", strerror(errno))); + } + if (ret != 1) { + throw std::runtime_error("unexpectedly reached end of file"); + } + } + + uint32_t read_u32() const { + uint32_t ret; + read_raw(&ret, sizeof(ret)); + return ret; + } + + void write_raw(const void * ptr, size_t len) const { + if (len == 0) { + return; + } + errno = 0; + size_t ret = std::fwrite(ptr, len, 1, fp); + if (ret != 1) { + throw std::runtime_error(format("write error: %s", strerror(errno))); + } + } + + void write_u32(uint32_t val) const { + write_raw(&val, sizeof(val)); + } + + ~impl() { + if (fp) { + std::fclose(fp); + } + } +#endif + + FILE * fp; + size_t size; +}; + +llama_file::llama_file(const char * fname, const char * mode) : pimpl(std::make_unique(fname, mode)) {} +llama_file::~llama_file() = default; + +size_t llama_file::tell() const { return pimpl->tell(); } +size_t llama_file::size() const { return pimpl->size; } + +int llama_file::file_id() const { +#ifdef _WIN32 + return _fileno(pimpl->fp); +#else +#if defined(fileno) + return fileno(pimpl->fp); +#else + return ::fileno(pimpl->fp); +#endif +#endif +} + +void llama_file::seek(size_t offset, int whence) const { pimpl->seek(offset, whence); } +void llama_file::read_raw(void * ptr, size_t len) const { pimpl->read_raw(ptr, len); } + +uint32_t llama_file::read_u32() const { return pimpl->read_u32(); } + +void llama_file::write_raw(const void * ptr, size_t len) const { pimpl->write_raw(ptr, len); } +void llama_file::write_u32(uint32_t val) const { pimpl->write_u32(val); } + +// llama_mmap + +struct llama_mmap::impl { +#ifdef _POSIX_MAPPED_FILES + std::vector> mapped_fragments; + + impl(struct llama_file * file, size_t prefetch, bool numa, bool use_thp) { + size = file->size(); + int fd = file->file_id(); + int flags = MAP_SHARED; + if (numa) { prefetch = 0; } +#ifdef __linux__ + if (posix_fadvise(fd, 0, 0, POSIX_FADV_SEQUENTIAL)) { + LLAMA_LOG_WARN("warning: posix_fadvise(.., POSIX_FADV_SEQUENTIAL) failed: %s\n", + strerror(errno)); + } + if (prefetch) { flags |= MAP_POPULATE; } + if (use_thp) { + size_t huge = get_default_huge_page_size(); + auto size = huge*((file->size() + huge - 1)/huge); + addr = mmap(nullptr, size, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_ANONYMOUS | MAP_HUGETLB, -1, 0); + if (addr != MAP_FAILED) { + printf("%s: using THP with page size %zu MiB ", __func__, huge/(1024*1024)); + fflush(stdout); + size_t tot = 0; + while (tot < file->size()) { + auto n_read = pread(fd, static_cast(addr) + tot, file->size() - tot, tot); + if (n_read < 0) throw std::runtime_error(format("Reading into mapped huge pages failed at %zu (%s)", tot, strerror(errno))); + printf("."); fflush(stdout); + tot += n_read; + } + printf(" done\n"); + mapped_fragments.emplace_back(0, file->size()); + mapped_page_size = huge; + return; + } + else { + fprintf(stderr, "%s: mmap with huge page size %zu MiB failed (%s)\n", __func__, huge/(1024*1024), strerror(errno)); + } + } +#endif + addr = mmap(NULL, file->size(), PROT_READ, flags, fd, 0); + if (addr == MAP_FAILED) { + throw std::runtime_error(format("mmap failed: %s", strerror(errno))); + } + + if 
(prefetch > 0) { + if (posix_madvise(addr, std::min(file->size(), prefetch), POSIX_MADV_WILLNEED)) { + LLAMA_LOG_WARN("warning: posix_madvise(.., POSIX_MADV_WILLNEED) failed: %s\n", + strerror(errno)); + } + } + if (numa) { + if (posix_madvise(addr, file->size(), POSIX_MADV_RANDOM)) { + LLAMA_LOG_WARN("warning: posix_madvise(.., POSIX_MADV_RANDOM) failed: %s\n", + strerror(errno)); + } + } + + mapped_fragments.emplace_back(0, file->size()); + } + +#ifdef __linux__ + static int get_default_huge_page_size() { + int pg_size = 2048; + std::ifstream in("/proc/meminfo"); + if (in) { + std::string line; + while (true) { + std::getline(in, line); + if (in.fail()) break; + if (auto pos = line.find("Hugepagesize:"); pos != std::string::npos) { + std::istringstream str(line.data() + pos + 13); + int aux; + str >> aux; + if (!str.fail()) pg_size = aux; + break; + } + } + } + return pg_size * 1024; + } +#endif + + + static void align_range(size_t * first, size_t * last, size_t page_size) { + size_t offset_in_page = *first & (page_size - 1); + size_t offset_to_page = offset_in_page == 0 ? 0 : page_size - offset_in_page; + *first += offset_to_page; + + *last = *last & ~(page_size - 1); + + if (*last <= *first) { + *last = *first; + } + } + + void unmap_fragment(size_t first, size_t last) { + int page_size = mapped_page_size > 0 ? mapped_page_size : sysconf(_SC_PAGESIZE); + align_range(&first, &last, page_size); + size_t len = last - first; + + if (len == 0) { + return; + } + + GGML_ASSERT(first % page_size == 0); + GGML_ASSERT(last % page_size == 0); + GGML_ASSERT(last > first); + + void * next_page_start = (uint8_t *) addr + first; + + if (munmap(next_page_start, len)) { + LLAMA_LOG_WARN("warning: munmap failed: %s\n", strerror(errno)); + } + + std::vector> new_mapped_fragments; + for (const auto & frag : mapped_fragments) { + if (frag.first < first && frag.second > last) { + new_mapped_fragments.emplace_back(frag.first, first); + new_mapped_fragments.emplace_back(last, frag.second); + } else if (frag.first < first && frag.second > first) { + new_mapped_fragments.emplace_back(frag.first, first); + } else if (frag.first < last && frag.second > last) { + new_mapped_fragments.emplace_back(last, frag.second); + } else if (frag.first >= first && frag.second <= last) { + } else { + new_mapped_fragments.push_back(frag); + } + } + mapped_fragments = std::move(new_mapped_fragments); + } + + ~impl() { + for (const auto & frag : mapped_fragments) { + if (munmap((char *) addr + frag.first, frag.second - frag.first)) { + LLAMA_LOG_WARN("warning: munmap failed: %s\n", strerror(errno)); + } + } + } +#elif defined(_WIN32) + impl(struct llama_file * file, size_t prefetch, bool numa, [[maybe_unused]] bool use_thp) { + GGML_UNUSED(numa); + + size = file->size(); + + HANDLE hFile = (HANDLE) _get_osfhandle(file->file_id()); + + HANDLE hMapping = CreateFileMappingA(hFile, NULL, PAGE_READONLY, 0, 0, NULL); + + if (hMapping == NULL) { + DWORD error = GetLastError(); + throw std::runtime_error(format("CreateFileMappingA failed: %s", llama_format_win_err(error).c_str())); + } + + addr = MapViewOfFile(hMapping, FILE_MAP_READ, 0, 0, 0); + DWORD error = GetLastError(); + CloseHandle(hMapping); + + if (addr == NULL) { + throw std::runtime_error(format("MapViewOfFile failed: %s", llama_format_win_err(error).c_str())); + } + + if (prefetch > 0) { +#if _WIN32_WINNT >= 0x602 + BOOL (WINAPI *pPrefetchVirtualMemory) (HANDLE, ULONG_PTR, PWIN32_MEMORY_RANGE_ENTRY, ULONG); + HMODULE hKernel32 = GetModuleHandleW(L"kernel32.dll"); + + 
pPrefetchVirtualMemory = (decltype(pPrefetchVirtualMemory))(void *) GetProcAddress(hKernel32, "PrefetchVirtualMemory"); + + if (pPrefetchVirtualMemory) { + WIN32_MEMORY_RANGE_ENTRY range; + range.VirtualAddress = addr; + range.NumberOfBytes = (SIZE_T) std::min(size, prefetch); + if (!pPrefetchVirtualMemory(GetCurrentProcess(), 1, &range, 0)) { + LLAMA_LOG_WARN("warning: PrefetchVirtualMemory failed: %s\n", + llama_format_win_err(GetLastError()).c_str()); + } + } +#else + LLAMA_LOG_DEBUG("skipping PrefetchVirtualMemory because _WIN32_WINNT < 0x602\n"); +#endif + } + } + + void unmap_fragment(size_t first, size_t last) { + GGML_UNUSED(first); + GGML_UNUSED(last); + } + + ~impl() { + if (!UnmapViewOfFile(addr)) { + LLAMA_LOG_WARN("warning: UnmapViewOfFile failed: %s\n", + llama_format_win_err(GetLastError()).c_str()); + } + } +#else + impl(struct llama_file * file, size_t prefetch, bool numa, [[maybe_unused]] bool use_thp) { + GGML_UNUSED(file); + GGML_UNUSED(prefetch); + GGML_UNUSED(numa); + + throw std::runtime_error("mmap not supported"); + } + + void unmap_fragment(size_t first, size_t last) { + GGML_UNUSED(first); + GGML_UNUSED(last); + + throw std::runtime_error("mmap not supported"); + } +#endif + + void * addr; + size_t size; + size_t mapped_page_size = 0; +}; + +llama_mmap::llama_mmap(struct llama_file * file, size_t prefetch, bool numa, bool use_thp) : + pimpl(std::make_unique(file, prefetch, numa, use_thp)) {} +llama_mmap::~llama_mmap() = default; + +size_t llama_mmap::size() const { return pimpl->size; } +void * llama_mmap::addr() const { return pimpl->addr; } + +void llama_mmap::unmap_fragment(size_t first, size_t last) { pimpl->unmap_fragment(first, last); } + +#if defined(_POSIX_MEMLOCK_RANGE) || defined(_WIN32) +const bool llama_mmap::SUPPORTED = true; +#else +const bool llama_mmap::SUPPORTED = false; +#endif + +// llama_mlock + +struct llama_mlock::impl { +#ifdef _POSIX_MEMLOCK_RANGE + static size_t lock_granularity() { + return (size_t) sysconf(_SC_PAGESIZE); + } + + bool raw_lock(const void * addr, size_t size) const { + if (!mlock(addr, size)) { + return true; + } + +#ifdef __APPLE__ +#define MLOCK_SUGGESTION \ + "Try increasing the sysctl values 'vm.user_wire_limit' and 'vm.global_user_wire_limit' and/or " \ + "decreasing 'vm.global_no_user_wire_amount'. Also try increasing RLIMIT_MEMLOCK (ulimit -l).\n" +#else +#define MLOCK_SUGGESTION \ + "Try increasing RLIMIT_MEMLOCK ('ulimit -l' as root).\n" +#endif + + char* errmsg = std::strerror(errno); + bool suggest = (errno == ENOMEM); +#if defined(TARGET_OS_VISION) || defined(TARGET_OS_TV) || defined(_AIX) + // visionOS/tvOS dont't support RLIMIT_MEMLOCK + // Skip resource limit checks on visionOS/tvOS + suggest = false; +#else + struct rlimit lock_limit; + if (suggest && getrlimit(RLIMIT_MEMLOCK, &lock_limit)) { + suggest = false; + } + if (suggest && (lock_limit.rlim_max > lock_limit.rlim_cur + size)) { + suggest = false; + } +#endif + + LLAMA_LOG_WARN("warning: failed to mlock %zu-byte buffer (after previously locking %zu bytes): %s\n%s", + size, this->size, errmsg, suggest ? 
MLOCK_SUGGESTION : ""); + return false; + } + + static void raw_unlock(void * addr, size_t size) { + if (munlock(addr, size)) { + LLAMA_LOG_WARN("warning: failed to munlock buffer: %s\n", std::strerror(errno)); + } + } +#elif defined(_WIN32) + static size_t lock_granularity() { + SYSTEM_INFO si; + GetSystemInfo(&si); + return (size_t) si.dwPageSize; + } + + bool raw_lock(void * ptr, size_t len) const { + for (int tries = 1; ; tries++) { + if (VirtualLock(ptr, len)) { + return true; + } + if (tries == 2) { + LLAMA_LOG_WARN("warning: failed to VirtualLock %zu-byte buffer (after previously locking %zu bytes): %s\n", + len, size, llama_format_win_err(GetLastError()).c_str()); + return false; + } + + SIZE_T min_ws_size, max_ws_size; + if (!GetProcessWorkingSetSize(GetCurrentProcess(), &min_ws_size, &max_ws_size)) { + LLAMA_LOG_WARN("warning: GetProcessWorkingSetSize failed: %s\n", + llama_format_win_err(GetLastError()).c_str()); + return false; + } + size_t increment = len + 1048576; + min_ws_size += increment; + max_ws_size += increment; + if (!SetProcessWorkingSetSize(GetCurrentProcess(), min_ws_size, max_ws_size)) { + LLAMA_LOG_WARN("warning: SetProcessWorkingSetSize failed: %s\n", + llama_format_win_err(GetLastError()).c_str()); + return false; + } + } + } + + static void raw_unlock(void * ptr, size_t len) { + if (!VirtualUnlock(ptr, len)) { + LLAMA_LOG_WARN("warning: failed to VirtualUnlock buffer: %s\n", + llama_format_win_err(GetLastError()).c_str()); + } + } +#else + static size_t lock_granularity() { + return (size_t) 65536; + } + + bool raw_lock(const void * addr, size_t len) const { + LLAMA_LOG_WARN("warning: mlock not supported on this system\n"); + return false; + } + + static void raw_unlock(const void * addr, size_t len) {} +#endif + + impl() : addr(NULL), size(0), failed_already(false) {} + + void init(void * ptr) { + GGML_ASSERT(addr == NULL && size == 0); + addr = ptr; + } + + void grow_to(size_t target_size) { + GGML_ASSERT(addr); + if (failed_already) { + return; + } + size_t granularity = lock_granularity(); + target_size = (target_size + granularity - 1) & ~(granularity - 1); + if (target_size > size) { + if (raw_lock((uint8_t *) addr + size, target_size - size)) { + size = target_size; + } else { + failed_already = true; + } + } + } + + void * addr; + size_t size; + + bool failed_already; +}; + +llama_mlock::llama_mlock() : pimpl(std::make_unique()) {} +llama_mlock::~llama_mlock() = default; + +void llama_mlock::init(void * ptr) { pimpl->init(ptr); } +void llama_mlock::grow_to(size_t target_size) { pimpl->grow_to(target_size); } + +#if defined(_POSIX_MEMLOCK_RANGE) || defined(_WIN32) +const bool llama_mlock::SUPPORTED = true; +#else +const bool llama_mlock::SUPPORTED = false; +#endif + +size_t llama_path_max() { + return PATH_MAX; +} diff --git a/src/llama-mmap.h b/src/llama-mmap.h new file mode 100644 index 000000000..a1efa068f --- /dev/null +++ b/src/llama-mmap.h @@ -0,0 +1,68 @@ +#pragma once + +#include +#include +#include + +struct llama_file; +struct llama_mmap; +struct llama_mlock; + +using llama_files = std::vector>; +using llama_mmaps = std::vector>; +using llama_mlocks = std::vector>; + +struct llama_file { + llama_file(const char * fname, const char * mode); + ~llama_file(); + + size_t tell() const; + size_t size() const; + + int file_id() const; // fileno overload + + void seek(size_t offset, int whence) const; + + void read_raw(void * ptr, size_t len) const; + uint32_t read_u32() const; + + void write_raw(const void * ptr, size_t len) const; + void 
write_u32(uint32_t val) const; + +private: + struct impl; + std::unique_ptr pimpl; +}; + +struct llama_mmap { + llama_mmap(const llama_mmap &) = delete; + llama_mmap(struct llama_file * file, size_t prefetch = (size_t) -1, bool numa = false, bool use_thp = false); + ~llama_mmap(); + + size_t size() const; + void * addr() const; + + void unmap_fragment(size_t first, size_t last); + + static const bool SUPPORTED; + +private: + struct impl; + std::unique_ptr pimpl; +}; + +struct llama_mlock { + llama_mlock(); + ~llama_mlock(); + + void init(void * ptr); + void grow_to(size_t target_size); + + static const bool SUPPORTED; + +private: + struct impl; + std::unique_ptr pimpl; +}; + +size_t llama_path_max(); diff --git a/src/llama-model-loader.cpp b/src/llama-model-loader.cpp new file mode 100644 index 000000000..915764615 --- /dev/null +++ b/src/llama-model-loader.cpp @@ -0,0 +1,1082 @@ +#include "llama-model-loader.h" +#include "llama-impl.h" +#include "llama-mmap.h" +#include "ggml.h" +//#include "ggml-backend.h" + +#ifdef GGML_USE_CUDA +# include "ggml-cuda.h" +#elif defined(GGML_USE_VULKAN) +# include "ggml-vulkan.h" +#elif defined(GGML_USE_SYCL) +# include "ggml-sycl.h" +#elif defined(GGML_USE_KOMPUTE) +# include "ggml-kompute.h" +#elif defined(GGML_USE_CANN) +# include "ggml-cann.h" +#endif + +#include +#include +#include +#include + +#if defined(_WIN32) + #define WIN32_LEAN_AND_MEAN + #ifndef NOMINMAX + #define NOMINMAX + #endif + #include + #ifndef PATH_MAX + #define PATH_MAX MAX_PATH + #endif + #include +#endif + +#define LLAMA_API_INTERNAL + +namespace GGUFMeta { + template + struct GKV_Base_Type { + static constexpr gguf_type gt = gt_; + + static T getter(const gguf_context * ctx, const int kid) { + return gfun(ctx, kid); + } + }; + + template struct GKV_Base; + + template<> struct GKV_Base: GKV_Base_Type {}; + template<> struct GKV_Base: GKV_Base_Type {}; + template<> struct GKV_Base: GKV_Base_Type {}; + template<> struct GKV_Base: GKV_Base_Type {}; + template<> struct GKV_Base: GKV_Base_Type {}; + template<> struct GKV_Base: GKV_Base_Type {}; + template<> struct GKV_Base: GKV_Base_Type {}; + template<> struct GKV_Base: GKV_Base_Type {}; + template<> struct GKV_Base: GKV_Base_Type {}; + template<> struct GKV_Base: GKV_Base_Type {}; + template<> struct GKV_Base: GKV_Base_Type {}; + template<> struct GKV_Base: GKV_Base_Type {}; + + template<> struct GKV_Base { + static constexpr gguf_type gt = GGUF_TYPE_STRING; + + static std::string getter(const gguf_context * ctx, const int kid) { + return gguf_get_val_str(ctx, kid); + } + }; + + struct ArrayInfo { + const gguf_type gt; + const size_t length; + const void * data; + }; + + template<> struct GKV_Base { + public: + static constexpr gguf_type gt = GGUF_TYPE_ARRAY; + static ArrayInfo getter(const gguf_context *ctx, const int k) { + return ArrayInfo { + gguf_get_arr_type(ctx, k), + size_t(gguf_get_arr_n(ctx, k)), + gguf_get_arr_data(ctx, k), + }; + } + }; + + template + class GKV : public GKV_Base { + GKV() = delete; + + public: + static T get_kv(const gguf_context * ctx, const int k) { + const enum gguf_type kt = gguf_get_kv_type(ctx, k); + + if (kt != GKV::gt) { + throw std::runtime_error(format("key %s has wrong type %s but expected type %s", + gguf_get_key(ctx, k), gguf_type_name(kt), gguf_type_name(GKV::gt))); + } + return GKV::getter(ctx, k); + } + + static const char * override_type_to_str(const llama_model_kv_override_type ty) { + switch (ty) { + case LLAMA_KV_OVERRIDE_TYPE_BOOL: return "bool"; + case LLAMA_KV_OVERRIDE_TYPE_INT: 
return "int"; + case LLAMA_KV_OVERRIDE_TYPE_FLOAT: return "float"; + case LLAMA_KV_OVERRIDE_TYPE_STR: return "str"; + } + return "unknown"; + } + + static bool validate_override(const llama_model_kv_override_type expected_type, const struct llama_model_kv_override * ovrd) { + if (!ovrd) { return false; } + if (ovrd->tag == expected_type) { + LLAMA_LOG_INFO("%s: Using metadata override (%5s) '%s' = ", + __func__, override_type_to_str(ovrd->tag), ovrd->key); + switch (ovrd->tag) { + case LLAMA_KV_OVERRIDE_TYPE_BOOL: { + LLAMA_LOG_INFO("%s\n", ovrd->val_bool ? "true" : "false"); + } break; + case LLAMA_KV_OVERRIDE_TYPE_INT: { + LLAMA_LOG_INFO("%" PRId64 "\n", ovrd->val_i64); + } break; + case LLAMA_KV_OVERRIDE_TYPE_FLOAT: { + LLAMA_LOG_INFO("%.6f\n", ovrd->val_f64); + } break; + case LLAMA_KV_OVERRIDE_TYPE_STR: { + LLAMA_LOG_INFO("%s\n", ovrd->val_str); + } break; + default: + // Shouldn't be possible to end up here, but just in case... + throw std::runtime_error( + format("Unsupported attempt to override %s type for metadata key %s\n", + override_type_to_str(ovrd->tag), ovrd->key)); + } + return true; + } + LLAMA_LOG_WARN("%s: Warning: Bad metadata override type for key '%s', expected %s but got %s\n", + __func__, ovrd->key, override_type_to_str(expected_type), override_type_to_str(ovrd->tag)); + return false; + } + + template + static typename std::enable_if::value, bool>::type + try_override(OT & target, const struct llama_model_kv_override * ovrd) { + if (validate_override(LLAMA_KV_OVERRIDE_TYPE_BOOL, ovrd)) { + target = ovrd->val_bool; + return true; + } + return false; + } + + template + static typename std::enable_if::value && std::is_integral::value, bool>::type + try_override(OT & target, const struct llama_model_kv_override * ovrd) { + if (validate_override(LLAMA_KV_OVERRIDE_TYPE_INT, ovrd)) { + target = ovrd->val_i64; + return true; + } + return false; + } + + template + static typename std::enable_if::value, bool>::type + try_override(T & target, const struct llama_model_kv_override * ovrd) { + if (validate_override(LLAMA_KV_OVERRIDE_TYPE_FLOAT, ovrd)) { + target = ovrd->val_f64; + return true; + } + return false; + } + + template + static typename std::enable_if::value, bool>::type + try_override(T & target, const struct llama_model_kv_override * ovrd) { + if (validate_override(LLAMA_KV_OVERRIDE_TYPE_STR, ovrd)) { + target = ovrd->val_str; + return true; + } + return false; + } + + static bool set(const gguf_context * ctx, const int k, T & target, const struct llama_model_kv_override * ovrd = nullptr) { + if (try_override(target, ovrd)) { + return true; + } + if (k < 0) { return false; } + target = get_kv(ctx, k); + return true; + } + + static bool set(const gguf_context * ctx, const char * key, T & target, const struct llama_model_kv_override * ovrd = nullptr) { + return set(ctx, gguf_find_key(ctx, key), target, ovrd); + } + + static bool set(const gguf_context * ctx, const std::string & key, T & target, const struct llama_model_kv_override * ovrd = nullptr) { + return set(ctx, key.c_str(), target, ovrd); + } + }; +} + +llama_model_loader::llama_model_loader(const std::string & fname, bool use_mmap, bool check_tensors, bool repack_tensors, bool use_thp, + const llama_model_kv_override * param_overrides_p, + const llama_model_tensor_buft_override * param_tensor_buft_overrides_p) { + int trace = 0; + if (getenv("LLAMA_TRACE")) { + trace = atoi(getenv("LLAMA_TRACE")); + } + +#ifdef _WIN32 + // Only bump maxstdio if the user really wants large contexts: +#if 
defined(GGML_MAX_CONTEXTS) && (GGML_MAX_CONTEXTS > 512) + // Cap at MSVC's hard limit of 8192 - https://learn.microsoft.com/en-us/cpp/c-runtime-library/reference/setmaxstdio?view=msvc-160 +#if (GGML_MAX_CONTEXTS > 8192) +#define _GGML_STDIO_TARGET 8192 +#else +#define _GGML_STDIO_TARGET GGML_MAX_CONTEXTS +#endif + int _setmaxstdio_ret = _setmaxstdio(_GGML_STDIO_TARGET); + if (_setmaxstdio_ret == -1) { + LLAMA_LOG_INFO("%s: failed to set max stdio to %d. (setmaxstdio returned -1)\n", __func__, _GGML_STDIO_TARGET); + } else { + LLAMA_LOG_INFO("%s: max stdio successfully set to %d\n", __func__, _setmaxstdio_ret); + } +#endif // GGML_MAX_CONTEXTS > 512 +#endif // _WIN32 + + if (param_overrides_p != nullptr) { + for (const struct llama_model_kv_override * p = param_overrides_p; p->key[0] != 0; p++) { + kv_overrides.insert({std::string(p->key), *p}); + } + } + + tensor_buft_overrides = param_tensor_buft_overrides_p; + + struct ggml_context * ctx = NULL; + struct gguf_init_params params = { + /*.no_alloc = */ true, + /*.ctx = */ &ctx, + }; + + meta = gguf_init_from_file(fname.c_str(), params); + if (!meta) { + throw std::runtime_error(format("%s: failed to load model from %s\n", __func__, fname.c_str())); + } + + get_key(llm_kv(LLM_KV_GENERAL_ARCHITECTURE), arch_name, false); + llm_kv = LLM_KV(llm_arch_from_string(arch_name)); + + files.emplace_back(new llama_file(fname.c_str(), "rb")); + contexts.emplace_back(ctx); + + // Save tensors data offset of the main file. + // For subsidiary files, `meta` tensor data offset must not be used, + // so we build a unified tensors index for weights. + for (ggml_tensor * cur = ggml_get_first_tensor(ctx); cur; cur = ggml_get_next_tensor(ctx, cur)) { + weights.emplace_back(files.back().get(), 0, cur->name, meta, cur); + } + uint16_t n_split = 0; + get_key(llm_kv(LLM_KV_SPLIT_COUNT), n_split, false); + + // Load additional GGML contexts + if (n_split > 1) { + uint16_t idx = 0; + get_key(llm_kv(LLM_KV_SPLIT_NO), idx); + if (idx != 0) { + throw std::runtime_error(format("illegal split file: %d, model must be loaded with the first split", idx)); + } + + char split_prefix[PATH_MAX] = {0}; + if (!llama_split_prefix(split_prefix, sizeof(split_prefix), fname.c_str(), idx, n_split)) { + throw std::runtime_error(format("invalid split file: %s", fname.c_str())); + } + + if (trace > 0) { + LLAMA_LOG_INFO("%s: loading additional %d GGUFs\n", __func__, n_split); + } + + char split_path[PATH_MAX] = {0}; + for (idx = 1; idx < n_split; idx++) { + llama_split_path(split_path, sizeof(split_path), split_prefix, idx, n_split); + + struct gguf_init_params split_params = { + /*.no_alloc = */ true, + /*.ctx = */ &ctx, + }; + struct gguf_context * ctx_gguf = gguf_init_from_file(split_path, split_params); + if (!ctx_gguf) { + throw std::runtime_error(format("%s: failed to load GGUF split from %s\n", __func__, split_path)); + } + + files.emplace_back(new llama_file(split_path, "rb")); + contexts.emplace_back(ctx); + + // Save tensors data offset info of the shard. 
+ for (ggml_tensor * cur = ggml_get_first_tensor(ctx); cur; cur = ggml_get_next_tensor(ctx, cur)) { + weights.emplace_back(files.back().get(), idx, cur->name, ctx_gguf, cur); + } + + gguf_free(ctx_gguf); + } + + get_key(llm_kv(LLM_KV_SPLIT_TENSORS_COUNT), n_tensors); + + // sanity check + { + const int n_tensors_loaded = (int) weights.size(); + if (n_tensors != n_tensors_loaded) { + throw std::runtime_error(format("corrupted model: %d tensors expected but %d found", n_tensors, n_tensors_loaded)); + } + } + + LLAMA_LOG_INFO("%s: additional %d GGUFs metadata loaded.\n", __func__, n_split - 1); + } + + n_kv = gguf_get_n_kv(meta); + n_tensors = weights.size(); + + fver = (enum llama_fver) gguf_get_version(meta); + + std::set tensor_names; + for (auto & w : weights) { + n_elements += ggml_nelements(w.tensor); + n_bytes += ggml_nbytes(w.tensor); + // make sure there is no duplicated tensor names + const std::string name(w.tensor->name); + auto found = tensor_names.find(name); + if (found != tensor_names.end()) { + throw std::runtime_error(format("invalid model: tensor '%s' is duplicated", w.tensor->name)); + } + tensor_names.insert(name); + } + + LLAMA_LOG_INFO("%s: loaded meta data with %d key-value pairs and %d tensors from %s (version %s)\n", + __func__, n_kv, n_tensors, fname.c_str(), llama_file_version_name(fver)); + + // determine file type based on the number of tensors for each quantization and print meta data + // TODO: make optional + { + std::map n_type; + + uint32_t n_type_max = 0; + enum ggml_type type_max = GGML_TYPE_F32; + + for (int i = 0; i < n_tensors; i++) { + const ggml_tensor * tensor = weights.at(i).tensor; + enum ggml_type type = tensor->type; + + n_type[type]++; + + if (n_type_max < n_type[type]) { + n_type_max = n_type[type]; + type_max = type; + } + + if (trace > 0) { + const uint16_t sid = weights.at(i).idx; + LLAMA_LOG_INFO("%s: - tensor %4d, split %2d: %32s %-8s [ %s ]\n", __func__, i, sid, ggml_get_name(tensor), ggml_type_name(type), llama_format_tensor_shape(tensor).c_str()); + } + } + + switch (type_max) { + case GGML_TYPE_F32: ftype = LLAMA_FTYPE_ALL_F32; break; + case GGML_TYPE_F16: ftype = LLAMA_FTYPE_MOSTLY_F16; break; + case GGML_TYPE_BF16: ftype = LLAMA_FTYPE_MOSTLY_BF16; break; + case GGML_TYPE_BF16_R16:ftype = LLAMA_FTYPE_MOSTLY_BF16_R16;break; + case GGML_TYPE_Q4_0: ftype = LLAMA_FTYPE_MOSTLY_Q4_0; break; + case GGML_TYPE_Q4_1: ftype = LLAMA_FTYPE_MOSTLY_Q4_1; break; + case GGML_TYPE_Q5_0: ftype = LLAMA_FTYPE_MOSTLY_Q5_0; break; + case GGML_TYPE_Q5_1: ftype = LLAMA_FTYPE_MOSTLY_Q5_1; break; + case GGML_TYPE_Q6_0: ftype = LLAMA_FTYPE_MOSTLY_Q6_0; break; + case GGML_TYPE_Q8_0: ftype = LLAMA_FTYPE_MOSTLY_Q8_0; break; + case GGML_TYPE_Q8_KV: ftype = LLAMA_FTYPE_MOSTLY_Q8_KV; break; + case GGML_TYPE_Q2_K: ftype = LLAMA_FTYPE_MOSTLY_Q2_K; break; + case GGML_TYPE_Q3_K: ftype = LLAMA_FTYPE_MOSTLY_Q3_K_M; break; + case GGML_TYPE_Q3_K_R4: ftype = LLAMA_FTYPE_MOSTLY_Q3_K_R4; break; + case GGML_TYPE_Q4_K: ftype = LLAMA_FTYPE_MOSTLY_Q4_K_M; break; + case GGML_TYPE_Q4_K_R4: ftype = LLAMA_FTYPE_MOSTLY_Q4_K_R4; break; + case GGML_TYPE_Q5_K: ftype = LLAMA_FTYPE_MOSTLY_Q5_K_M; break; + case GGML_TYPE_Q5_K_R4: ftype = LLAMA_FTYPE_MOSTLY_Q5_K_R4; break; + case GGML_TYPE_Q6_K: ftype = LLAMA_FTYPE_MOSTLY_Q6_K; break; + case GGML_TYPE_Q6_K_R4: ftype = LLAMA_FTYPE_MOSTLY_Q6_K_R4; break; + case GGML_TYPE_Q8_K_R8: ftype = LLAMA_FTYPE_MOSTLY_Q8_K_R8; break; + case GGML_TYPE_Q8_KV_R8: ftype = LLAMA_FTYPE_MOSTLY_Q8_KV_R8; break; + case GGML_TYPE_IQ2_XXS: ftype = 
LLAMA_FTYPE_MOSTLY_IQ2_XXS; break; + case GGML_TYPE_IQ2_XXS_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ2_XXS_R4; break; + case GGML_TYPE_IQ2_XS: ftype = LLAMA_FTYPE_MOSTLY_IQ2_XS; break; + case GGML_TYPE_IQ2_XS_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ2_XS_R4; break; + case GGML_TYPE_IQ2_KS: ftype = LLAMA_FTYPE_MOSTLY_IQ2_KS; break; + case GGML_TYPE_IQ2_S: ftype = LLAMA_FTYPE_MOSTLY_IQ2_M; break; + case GGML_TYPE_IQ2_S_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ2_M_R4;break; + case GGML_TYPE_IQ3_XXS: ftype = LLAMA_FTYPE_MOSTLY_IQ3_XXS; break; + case GGML_TYPE_IQ3_XXS_R4: ftype = LLAMA_FTYPE_MOSTLY_IQ3_XXS_R4; break; + case GGML_TYPE_IQ1_KT: ftype = LLAMA_FTYPE_MOSTLY_IQ1_KT; break; + case GGML_TYPE_IQ2_KT: ftype = LLAMA_FTYPE_MOSTLY_IQ2_KT; break; + case GGML_TYPE_IQ3_KT: ftype = LLAMA_FTYPE_MOSTLY_IQ3_KT; break; + case GGML_TYPE_IQ4_KT: ftype = LLAMA_FTYPE_MOSTLY_IQ4_KT; break; + case GGML_TYPE_IQ1_S: ftype = LLAMA_FTYPE_MOSTLY_IQ1_S; break; + case GGML_TYPE_IQ1_S_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ1_S_R4;break; + case GGML_TYPE_IQ1_M_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ1_M_R4;break; + case GGML_TYPE_IQ1_M: ftype = LLAMA_FTYPE_MOSTLY_IQ1_M; break; + case GGML_TYPE_IQ1_BN: ftype = LLAMA_FTYPE_MOSTLY_IQ1_BN; break; + case GGML_TYPE_IQ2_BN: ftype = LLAMA_FTYPE_MOSTLY_IQ2_BN; break; + case GGML_TYPE_IQ2_BN_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ2_BN_R4;break; + case GGML_TYPE_IQ4_NL: ftype = LLAMA_FTYPE_MOSTLY_IQ4_NL; break; + case GGML_TYPE_IQ4_NL_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ4_NL_R4;break; + case GGML_TYPE_IQ4_XS_R8:ftype = LLAMA_FTYPE_MOSTLY_IQ4_XS_R8;break; + case GGML_TYPE_Q4_0_R8: ftype = LLAMA_FTYPE_MOSTLY_Q4_0_R8; break; + case GGML_TYPE_Q5_0_R4: ftype = LLAMA_FTYPE_MOSTLY_Q5_0_R4; break; + case GGML_TYPE_Q6_0_R4: ftype = LLAMA_FTYPE_MOSTLY_Q6_0_R4; break; + case GGML_TYPE_Q8_0_R8: ftype = LLAMA_FTYPE_MOSTLY_Q8_0_R8; break; + case GGML_TYPE_MXFP4: ftype = LLAMA_FTYPE_MOSTLY_MXFP4; break; + case GGML_TYPE_IQ4_XS: ftype = LLAMA_FTYPE_MOSTLY_IQ4_XS; break; + case GGML_TYPE_IQ4_KS: ftype = LLAMA_FTYPE_MOSTLY_IQ4_KS; break; + case GGML_TYPE_IQ4_KS_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ4_KS_R4; break; + case GGML_TYPE_IQ5_KS_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ5_KS_R4; break; + case GGML_TYPE_IQ4_KSS: ftype = LLAMA_FTYPE_MOSTLY_IQ4_KSS; break; + case GGML_TYPE_IQ5_KS: ftype = LLAMA_FTYPE_MOSTLY_IQ5_KS; break; + case GGML_TYPE_IQ2_K: ftype = LLAMA_FTYPE_MOSTLY_IQ2_K; break; + case GGML_TYPE_IQ2_K_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ2_K_R4;break; + case GGML_TYPE_IQ3_KS: ftype = LLAMA_FTYPE_MOSTLY_IQ3_KS; break; + case GGML_TYPE_IQ2_KL: ftype = LLAMA_FTYPE_MOSTLY_IQ2_KL; break; + case GGML_TYPE_IQ3_K: ftype = LLAMA_FTYPE_MOSTLY_IQ3_K; break; + case GGML_TYPE_IQ3_K_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ3_K_R4;break; + case GGML_TYPE_IQ4_K: ftype = LLAMA_FTYPE_MOSTLY_IQ4_K; break; + case GGML_TYPE_IQ4_K_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ4_K_R4;break; + case GGML_TYPE_IQ5_K: ftype = LLAMA_FTYPE_MOSTLY_IQ5_K; break; + case GGML_TYPE_IQ5_K_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ5_K_R4;break; + case GGML_TYPE_IQ6_K: ftype = LLAMA_FTYPE_MOSTLY_IQ6_K; break; + case GGML_TYPE_IQ3_S: ftype = LLAMA_FTYPE_MOSTLY_IQ3_S; break; + case GGML_TYPE_IQ3_S_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ3_S_R4;break; + case GGML_TYPE_Q4_0_4_4: ftype = LLAMA_FTYPE_MOSTLY_Q4_0_4_4; break; + case GGML_TYPE_Q4_0_4_8: ftype = LLAMA_FTYPE_MOSTLY_Q4_0_4_8; break; + case GGML_TYPE_Q4_0_8_8: ftype = LLAMA_FTYPE_MOSTLY_Q4_0_8_8; break; + default: + { + LLAMA_LOG_WARN("%s: unknown type %s\n", __func__, ggml_type_name(type_max)); + ftype = LLAMA_FTYPE_ALL_F32; + } break; + } + + // this is a way to mark that 
we have "guessed" the file type + ftype = (llama_ftype) (ftype | LLAMA_FTYPE_GUESSED); + + { + const int kid = gguf_find_key(meta, "general.file_type"); // TODO: use LLM_KV + if (kid >= 0) { + ftype = (llama_ftype) gguf_get_val_u32(meta, kid); + } + } + + LLAMA_LOG_INFO("%s: Dumping metadata keys/values. Note: KV overrides do not apply in this output.\n", __func__); + + for (int i = 0; i < n_kv; i++) { + const char * name = gguf_get_key(meta, i); + const enum gguf_type type = gguf_get_kv_type(meta, i); + const std::string type_name = + type == GGUF_TYPE_ARRAY + ? format("%s[%s,%d]", gguf_type_name(type), gguf_type_name(gguf_get_arr_type(meta, i)), gguf_get_arr_n(meta, i)) + : gguf_type_name(type); + + std::string value = gguf_kv_to_str(meta, i); + const size_t MAX_VALUE_LEN = 40; + if (value.size() > MAX_VALUE_LEN) { + value = format("%s...", value.substr(0, MAX_VALUE_LEN - 3).c_str()); + } + replace_all(value, "\n", "\\n"); + + LLAMA_LOG_INFO("%s: - kv %3d: %42s %-16s = %s\n", __func__, i, name, type_name.c_str(), value.c_str()); + } + + // print type counts + for (auto & kv : n_type) { + if (kv.second == 0) { + continue; + } + + LLAMA_LOG_INFO("%s: - type %4s: %4d tensors\n", __func__, ggml_type_name(kv.first), kv.second); + } + } + + if (!llama_mmap::SUPPORTED) { + LLAMA_LOG_WARN("%s: mmap is not supported on this platform\n", __func__); + use_mmap = false; + } + if (repack_tensors) { + use_mmap = false; + } + + this->use_mmap = use_mmap; + this->check_tensors = check_tensors; + this->repack_tensors = repack_tensors; + this->use_thp = use_thp; +} + +llama_model_loader::~llama_model_loader() { + if (meta) { + gguf_free(meta); + } + for (auto * ctx : contexts) { + ggml_free(ctx); + } +} + +template +typename std::enable_if::value, bool>::type +llama_model_loader::get_arr_n(const std::string & key, T & result, const bool required) { + const int kid = gguf_find_key(meta, key.c_str()); + + if (kid < 0) { + if (required) { + throw std::runtime_error(format("key not found in model: %s", key.c_str())); + } + return false; + } + + struct GGUFMeta::ArrayInfo arr_info = + GGUFMeta::GKV::get_kv(meta, kid); + + + result = arr_info.length; + return true; +} + +template +typename std::enable_if::value, bool>::type +llama_model_loader::get_arr_n(const enum llm_kv kid, T & result, const bool required) { + return get_arr_n(llm_kv(kid), result, required); +} + +template +bool llama_model_loader::get_arr(const std::string & key, std::vector & result, const bool required) { + const int kid = gguf_find_key(meta, key.c_str()); + + if (kid < 0 || gguf_get_kv_type(meta, kid) != GGUF_TYPE_ARRAY) { + if (required) { + throw std::runtime_error(format("array key not found in model: %s", key.c_str())); + } + return false; + } + + struct GGUFMeta::ArrayInfo arr_info = + GGUFMeta::GKV::get_kv(meta, kid); + + switch (arr_info.gt) { + case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same::value)); break; + case GGUF_TYPE_INT32: GGML_ASSERT( + (std::is_same::value) || + (std::is_same::value)); break; + default: + throw std::runtime_error(format("%s is not a float32, int32 array", key.c_str())); + } + + result.resize(arr_info.length); + result.assign((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length); + + return true; +} + +template +bool llama_model_loader::get_arr(const std::string & key, std::array & result, const bool required) { + const int kid = gguf_find_key(meta, key.c_str()); + + if (kid < 0 || gguf_get_kv_type(meta, kid) != GGUF_TYPE_ARRAY) { + if (required) { + throw 
std::runtime_error(format("array key not found in model: %s", key.c_str())); + } + return false; + } + + struct GGUFMeta::ArrayInfo arr_info = + GGUFMeta::GKV::get_kv(meta, kid); + + switch (arr_info.gt) { + case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same::value)); break; + case GGUF_TYPE_INT32: GGML_ASSERT( + (std::is_same::value) || + (std::is_same::value)); break; + default: + throw std::runtime_error(format("%s is not a float32, int32 array", key.c_str())); + } + + if (arr_info.length > N_MAX) { + throw std::runtime_error(format("array length %u for key %s exceeds max %u", (uint32_t) arr_info.length, key.c_str(), (uint32_t) N_MAX)); + } + + std::copy((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length, result.begin()); + + return true; +} + +template +bool llama_model_loader::get_arr(const enum llm_kv kid, T & result, const bool required) { + return get_arr(llm_kv(kid), result, required); +} + +template +bool llama_model_loader::get_key(const std::string & key, T & result, const bool required) { + auto it = kv_overrides.find(key); + + const struct llama_model_kv_override * override = + it != kv_overrides.end() ? &it->second : nullptr; + + const bool found = GGUFMeta::GKV::set(meta, key, result, override); + + if (required && !found) { + throw std::runtime_error(format("key not found in model: %s", key.c_str())); + } + + return found; +} + +template +bool llama_model_loader::get_key(const enum llm_kv kid, T & result, const bool required) { + return get_key(llm_kv(kid), result, required); +} + +// get array of n <= N_MAX elements, or a single element repeated n times +template +bool llama_model_loader::get_key_or_arr(const std::string & key, std::array & result, uint32_t n, const bool required) { + const int kid = gguf_find_key(meta, key.c_str()); + + if (kid < 0) { + if (required) { + throw std::runtime_error(format("key not found in model: %s", key.c_str())); + } + return false; + } + + if (n > N_MAX) { + throw std::runtime_error(format("n > N_MAX: %u > %u for key %s", (uint32_t) n, (uint32_t) N_MAX, key.c_str())); + } + + if (gguf_get_kv_type(meta, kid) == GGUF_TYPE_ARRAY) { + struct GGUFMeta::ArrayInfo arr_info = + GGUFMeta::GKV::get_kv(meta, kid); + + if (n != arr_info.length) { + throw std::runtime_error(format("key %s has wrong array length; expected %u, got %u", key.c_str(), n, (uint32_t) arr_info.length)); + } + + return get_arr(key, result, required); + } else { + T value; + + bool ok = get_key(key, value, required); + if (!ok) { + return false; + } + + for (uint32_t i = 0; i < n; i++) { + result[i] = value; + } + + return true; + } +} + +template +bool llama_model_loader::get_key_or_arr(const enum llm_kv kid, T & result, uint32_t n, const bool required) { + return get_key_or_arr(llm_kv(kid), result, n, required); +} + +const char * llama_model_loader::get_tensor_name(int i) const { + return weights.at(i).tensor->name; +} + +const llama_model_loader::llama_tensor_weight * llama_model_loader::get_weight(const char * name) const { + for (const auto & weight : weights) { + if (strcmp(name, weight.tensor->name) == 0) { + return &weight; + } + } + return nullptr; +} + +const llama_model_loader::llama_tensor_weight & llama_model_loader::require_weight(const char * name) const { + const llama_tensor_weight * weight = get_weight(name); + if (!weight) { + throw std::runtime_error(format("%s: tensor '%s' not found", __func__, name)); + } + return *weight; +} + +struct ggml_tensor * llama_model_loader::get_tensor_meta(const char * name) const { + const auto * weight = 
get_weight(name); + if (!weight) { + return nullptr; + } + return weight->tensor; +} + +struct ggml_tensor * llama_model_loader::require_tensor_meta(const char * name) const { + struct ggml_tensor * tensor = get_tensor_meta(name); + if (!tensor) { + throw std::runtime_error(format("%s: tensor '%s' not found", __func__, name)); + } + return tensor; +} + +struct ggml_tensor * llama_model_loader::create_tensor_for(struct ggml_context * ctx, const struct ggml_tensor * cur, bool duplicated) { + struct ggml_tensor * tensor = ggml_dup_tensor(ctx, cur); + ggml_set_name(tensor, ggml_get_name(cur)); + + if (duplicated) { + size_data += ggml_nbytes(cur); + } else { + n_created++; + } + + return tensor; +} + +const struct ggml_tensor * llama_model_loader::check_tensor_dims(const std::string & name, const std::vector & ne, bool required) const { + const struct ggml_tensor * cur = get_tensor_meta(name.c_str()); + + if (cur == NULL) { + if (!required) { + return NULL; + } + throw std::runtime_error(format("%s: tensor '%s' not found", __func__, name.c_str())); + } + + { + bool is_ok = true; + for (size_t i = 0; i < GGML_MAX_DIMS; ++i) { + if ((i < ne.size() && ne[i] != cur->ne[i]) || (i >= ne.size() && cur->ne[i] != 1)) { + is_ok = false; + break; + } + } + if (!is_ok) { + throw std::runtime_error( + format("%s: tensor '%s' has wrong shape; expected %s, got %s", + __func__, name.c_str(), + llama_format_tensor_shape(ne).c_str(), + llama_format_tensor_shape(cur).c_str())); + } + } + + return cur; +} + +struct ggml_tensor * llama_model_loader::create_tensor(struct ggml_context * ctx, const std::string & name, + const std::vector & ne, int flags) { + const struct ggml_tensor * cur = check_tensor_dims(name, ne, !(flags & TENSOR_NOT_REQUIRED)); + + if (cur == NULL) { + return NULL; + } + + // skip unused tensors + if (flags & TENSOR_SKIP) { + const size_t nbytes = ggml_nbytes(cur); + LLAMA_LOG_WARN("model has unused tensor %s (size = %zu bytes) -- ignoring\n", name.c_str(), nbytes); + + size_data -= nbytes; + n_created++; + + return nullptr; + } + + return create_tensor_for(ctx, cur, flags & TENSOR_DUPLICATED); +} + +struct ggml_tensor * llama_model_loader::create_tensor_as_view(struct ggml_context * ctx, struct ggml_tensor * base, + const std::string & name, const std::vector & ne, size_t offset, bool required) { + const struct ggml_tensor * cur = check_tensor_dims(name, ne, required); + + if (cur == NULL) { + return NULL; + } + + if (cur->type != base->type) { + throw std::runtime_error(format("%s: tensor '%s' has wrong type; expected %s, got %s", __func__, name.c_str(), ggml_type_name(base->type), ggml_type_name(cur->type))); + } + + std::array dims; + for (size_t i = 0; i < GGML_MAX_DIMS; ++i) { + dims[i] = i < ne.size() ? ne[i] : 1; + } + + struct ggml_tensor * tensor = ggml_view_4d(ctx, base, + dims[0], dims[1], dims[2], dims[3], + cur->nb[1], cur->nb[2], cur->nb[3], + offset); + + ggml_set_name(tensor, name.c_str()); + + n_created++; + + return tensor; +} + +void llama_model_loader::done_getting_tensors() const { + if (n_created != n_tensors) { + throw std::runtime_error(format("%s: wrong number of tensors; expected %d, got %d", __func__, n_tensors, n_created)); + } +} + +void llama_model_loader::init_mappings(bool prefetch, llama_mlocks * mlock_mmaps, bool use_thp) { + if (use_mmap) { + mappings.reserve(files.size()); + mmaps_used.reserve(files.size()); + for (const auto & file : files) { + std::unique_ptr mapping(new llama_mmap(file.get(), prefetch ? 
-1 : 0, ggml_is_numa(), use_thp)); + mmaps_used.emplace_back(mapping->size(), 0); + if (mlock_mmaps) { + std::unique_ptr mlock_mmap(new llama_mlock()); + mlock_mmap->init(mapping->addr()); + mlock_mmaps->emplace_back(std::move(mlock_mmap)); + } + mappings.emplace_back(std::move(mapping)); + } + } + + // compute the total size of all tensors for progress reporting + for (auto & w : weights) { + size_data += ggml_nbytes(w.tensor); + } +} + +void llama_model_loader::get_mapping_range(size_t * first, size_t * last, void ** addr, int idx, ggml_context * ctx) const { + GGML_ASSERT(!mappings.empty()); + const auto & mapping = mappings.at(idx); + + *first = mapping->size(); + *last = 0; + *addr = mapping->addr(); + for (ggml_tensor * tensor = ggml_get_first_tensor(ctx); tensor; tensor = ggml_get_next_tensor(ctx, tensor)) { + try { + const auto * weight = get_weight(ggml_get_name(tensor)); + if (!weight) { + continue; + } + if (weight->idx != idx) { + continue; + } + *first = std::min(*first, weight->offs); + *last = std::max(*last, weight->offs + ggml_nbytes(tensor)); + } catch(...) { + // the tensor is not in the model + } + } +} + +// for backwards compatibility, does not support ggml-backend +void llama_model_loader::load_data_for(struct ggml_tensor * cur) const { + const auto & w = require_weight(ggml_get_name(cur)); + + if (use_mmap) { + const auto & mapping = mappings.at(w.idx); + if (cur->data == nullptr) { + cur->data = (uint8_t *)mapping->addr() + w.offs; + } else { + memcpy(cur->data, (uint8_t *)mapping->addr() + w.offs, ggml_nbytes(cur)); + } + } else { + GGML_ASSERT(cur->data != nullptr); + GGML_ASSERT(w.idx < files.size()); + const auto & file = files.at(w.idx); + file->seek(w.offs, SEEK_SET); + file->read_raw(cur->data, ggml_nbytes(cur)); + } + + if (check_tensors && !ggml_validate_row_data(cur->type, cur->data, ggml_nbytes(cur))) { + throw std::runtime_error(format("tensor '%s' has invalid data", ggml_get_name(cur))); + } +} + +// Returns false if cancelled by progress_callback +bool llama_model_loader::load_all_data( + struct ggml_context * ctx, + llama_buf_map & bufs_mmap, + llama_mlocks * lmlocks, + llama_progress_callback progress_callback, + void * progress_callback_user_data) { + GGML_ASSERT(size_data != 0 && "call init_mappings() first"); + + std::vector> read_buf; + std::vector>> validation_result; + +#if defined(GGML_USE_CUDA) + // 4 staging buffers for async uploads, each sized 1MB seems to be a good default for single NVMe drives. + // NVMe raid configurations might require more / larger buffers. + constexpr size_t n_buffers = 4; + constexpr size_t buffer_size = 1 * 1024 * 1024; // 1MB + + std::vector host_buffers; + std::vector host_ptrs; + std::vector events; + size_t buffer_idx = 0; // buffer to use for async loads + + ggml_backend_t cuda_backend = nullptr; + if (!use_mmap && !check_tensors) { + // When not using mmaped io use async uploads from pinned memory to GPU memory. + // First determine if the CUDA backend is active, and if so, determine the device ID. + ggml_backend_buffer_t buf = bufs_mmap.count(0) ? bufs_mmap.at(0) : nullptr; + if (buf) { + ggml_backend_buffer_type_t buffer_type = ggml_backend_buffer_get_type(buf); + for (int i = 0; i < ggml_backend_cuda_get_device_count(); ++i) { + auto * cuda_buffer_type = ggml_backend_cuda_buffer_type(i); + if (buffer_type == cuda_buffer_type) { + cuda_backend = ggml_backend_cuda_init(i); + break; + } + } + } + + // If the cuda backend is active create pinned memory buffers and events for synchronisation. 
+ if (cuda_backend) { + for (size_t idx = 0; idx < n_buffers; ++idx) { + host_buffers.emplace_back(ggml_backend_buft_alloc_buffer(llama_default_buffer_type_cpu(true), buffer_size)); + host_ptrs.emplace_back(ggml_backend_buffer_get_base(host_buffers[idx])); + events.emplace_back(ggml_backend_event_new(cuda_backend)); + } + } + } +#endif + + for (struct ggml_tensor * cur = ggml_get_first_tensor(ctx); cur != NULL; cur = ggml_get_next_tensor(ctx, cur)) { + const auto * weight = get_weight(ggml_get_name(cur)); + if (weight == nullptr) { + // this can happen with split experts models + continue; + } + + if (progress_callback) { + if (!progress_callback((float) size_done / size_data, progress_callback_user_data)) { + return false; + } + } + + size_t n_size = ggml_nbytes(cur); + + if (use_mmap) { + const auto & mapping = mappings.at(weight->idx); + ggml_backend_buffer_t buf_mmap = nullptr; + if (bufs_mmap.count(weight->idx)) { + buf_mmap = bufs_mmap.at(weight->idx); + } + uint8_t * data = (uint8_t *) mapping->addr() + weight->offs; + + if (check_tensors) { + validation_result.emplace_back(std::async(std::launch::async, [cur, data, n_size] { + return std::make_pair(cur, ggml_validate_row_data(cur->type, data, n_size)); + })); + } + + GGML_ASSERT(buf_mmap || cur->data); // either we have a buffer to allocate the tensor in, or it is already allocated + if (buf_mmap && cur->data == nullptr) { + ggml_backend_tensor_alloc(buf_mmap, cur, data); + if (lmlocks) { + const auto & lmlock = lmlocks->at(weight->idx); + lmlock->grow_to(weight->offs + n_size); + } + + auto & mmap_used = mmaps_used[weight->idx]; + mmap_used.first = std::min(mmap_used.first, weight->offs); + mmap_used.second = std::max(mmap_used.second, weight->offs + n_size); + } else { + ggml_backend_tensor_set(cur, data, 0, n_size); + } + } else { + GGML_ASSERT(weight->idx < files.size()); + const auto & file = files.at(weight->idx); + if (ggml_backend_buffer_is_host(cur->buffer)) { + file->seek(weight->offs, SEEK_SET); + file->read_raw(cur->data, n_size); + if (check_tensors) { + validation_result.emplace_back(std::async(std::launch::async, [cur, n_size] { + return std::make_pair(cur, ggml_validate_row_data(cur->type, cur->data, n_size)); + })); + } + } else { +#if defined(GGML_USE_CUDA) + // If cuda_backend is valid load the tensor in chunks to pinned memory and upload the buffers asynchronously to the GPU. 
+ if (cuda_backend) { + file->seek(weight->offs, SEEK_SET); + + size_t bytes_read = 0; + + while (bytes_read < n_size) { + size_t read_iteration = std::min(buffer_size, n_size - bytes_read); + + ggml_backend_event_synchronize(events[buffer_idx]); + file->read_raw(host_ptrs[buffer_idx], read_iteration); + ggml_backend_tensor_set_async(cuda_backend, cur, host_ptrs[buffer_idx], bytes_read, read_iteration); + ggml_backend_event_record(events[buffer_idx]); + + bytes_read += read_iteration; + ++buffer_idx; + buffer_idx %= n_buffers; + } + } + else +#endif + { + read_buf.resize(n_size); + file->seek(weight->offs, SEEK_SET); + file->read_raw(read_buf.data(), n_size); + ggml_backend_tensor_set(cur, read_buf.data(), 0, n_size); + if (check_tensors && !ggml_validate_row_data(cur->type, read_buf.data(), n_size)) { + throw std::runtime_error(format("tensor '%s' has invalid data", ggml_get_name(cur))); + } + } + } + } + + size_done += n_size; + } + +#if defined(GGML_USE_CUDA) + // free temporary resources used for async cuda uploads + if (cuda_backend) { + for (size_t idx = 0; idx < n_buffers;++idx) { + ggml_backend_event_synchronize(events[idx]); + ggml_backend_event_free(events[idx]); + ggml_backend_buffer_free(host_buffers[idx]); + } + ggml_backend_free(cuda_backend); + } +#endif + + // check validation results + bool validation_failed = false; + for (auto & future : validation_result) { + auto result = future.get(); + if (!result.second) { + LLAMA_LOG_ERROR("%s: tensor '%s' has invalid data\n", __func__, ggml_get_name(result.first)); + validation_failed = true; + } + } + if (validation_failed) { + throw std::runtime_error("found tensors with invalid data"); + } + + // check if this is the last call and do final cleanup + if (size_done >= size_data) { + // unmap offloaded tensors and metadata + if (use_mmap) { + for (uint32_t idx = 0; idx < mappings.size(); idx++) { + const auto & mmap_used = mmaps_used.at(idx); + auto & mapping = mappings.at(idx); + mapping->unmap_fragment(0, mmap_used.first); + if (mmap_used.second != 0) { + mapping->unmap_fragment(mmap_used.second, mapping->size()); + } + } + } + if (progress_callback) { + // Even though the model is done loading, we still honor + // cancellation since we need to free allocations. 
+ return progress_callback(1.0f, progress_callback_user_data); + } + } + + return true; +} + +template<> +bool llama_model_loader::get_key(const enum llm_kv kid, enum llama_pooling_type & result, const bool required) { + uint32_t tmp; + const bool found = get_key(kid, tmp, required); + if (found) { + result = (enum llama_pooling_type) tmp; + } else { + result = LLAMA_POOLING_TYPE_UNSPECIFIED; + } + return found; +} +template bool llama_model_loader::get_key (enum llm_kv kid, bool & result, bool required); +template bool llama_model_loader::get_key (enum llm_kv kid, float & result, bool required); +template bool llama_model_loader::get_key (enum llm_kv kid, uint32_t & result, bool required); +template bool llama_model_loader::get_key(enum llm_kv kid, std::string & result, bool required); + +template bool llama_model_loader::get_key_or_arr>(enum llm_kv kid, std::array & result, uint32_t n, bool required); +template bool llama_model_loader::get_key_or_arr>(enum llm_kv kid, std::array & result, uint32_t n, bool required); + +template std::enable_if::value, bool>::type llama_model_loader::get_arr_n(enum llm_kv, unsigned int&, bool); diff --git a/src/llama-model-loader.h b/src/llama-model-loader.h new file mode 100644 index 000000000..4240b0d18 --- /dev/null +++ b/src/llama-model-loader.h @@ -0,0 +1,169 @@ +#pragma once + +#include "llama.h" +#include "llama-impl.h" +#include "llama-mmap.h" +#include "llama-arch.h" + +#include +#include +#include +#include +#include + +enum llama_fver { + GGUF_FILE_VERSION_V1 = 1, + GGUF_FILE_VERSION_V2 = 2, + GGUF_FILE_VERSION_V3 = 3, +}; + +static const char * llama_file_version_name(llama_fver version) { + switch (version) { + case GGUF_FILE_VERSION_V1: return "GGUF V1 (support until nov 2023)"; + case GGUF_FILE_VERSION_V2: return "GGUF V2"; + case GGUF_FILE_VERSION_V3: return "GGUF V3 (latest)"; + } + + return "unknown"; +} + +using llama_buf_map = std::unordered_map; + +struct llama_model_loader { + int n_kv = 0; + int n_tensors = 0; + int n_created = 0; + + int64_t n_elements = 0; + size_t n_bytes = 0; + + bool use_mmap = false; + bool check_tensors; + bool repack_tensors = false; + bool use_thp = false; + + llama_files files; + llama_ftype ftype; + llama_fver fver; + + llama_mmaps mappings; + + // Holds information on a model weight + struct llama_tensor_weight { + uint16_t idx; // source file index + size_t offs; // tensor data offset in the original file + + ggml_tensor * tensor; + + llama_tensor_weight(const llama_file * file, uint16_t idx, const char * name, const struct gguf_context * gguf_ctx, ggml_tensor * tensor) : idx(idx), tensor(tensor) { + const int tensor_idx = gguf_find_tensor(gguf_ctx, name); + offs = gguf_get_data_offset(gguf_ctx) + gguf_get_tensor_offset(gguf_ctx, tensor_idx); + + if (offs + ggml_nbytes(tensor) < offs || offs + ggml_nbytes(tensor) > file->size()) { + throw std::runtime_error(format("tensor '%s' data is not within the file bounds, model is corrupted or incomplete", name)); + } + } + }; + std::vector weights; + + std::unordered_map kv_overrides; + const llama_model_tensor_buft_override * tensor_buft_overrides; + + gguf_context * meta = NULL; + std::vector contexts; + + std::string arch_name; + LLM_KV llm_kv = LLM_KV(LLM_ARCH_UNKNOWN); + + llama_model_loader(const std::string & fname, bool use_mmap, bool check_tensors, bool repack_tensors, bool use_thp, + const llama_model_kv_override * param_overrides_p, + const llama_model_tensor_buft_override * param_tensor_buft_overrides_p); + + ~llama_model_loader(); + + template + 
typename std::enable_if::value, bool>::type + get_arr_n(const std::string & key, T & result, const bool required = true); + + template + typename std::enable_if::value, bool>::type + get_arr_n(const enum llm_kv kid, T & result, const bool required = true); + + template + bool get_arr(const std::string & key, std::vector & result, const bool required = true); + + template + bool get_arr(const std::string & key, std::array & result, const bool required = true); + + template + bool get_arr(const enum llm_kv kid, T & result, const bool required = true); + + template + bool get_key(const std::string & key, T & result, const bool required = true); + + template + bool get_key(const enum llm_kv kid, T & result, const bool required = true); + + // get array of n <= N_MAX elements, or a single element repeated n times + template + bool get_key_or_arr(const std::string & key, std::array & result, uint32_t n, const bool required = true); + + template + bool get_key_or_arr(const enum llm_kv kid, T & result, uint32_t n, const bool required = true); + + const std::string& get_arch_name() const { return arch_name; } + + enum llm_arch get_arch() const { return llm_kv.arch; } + + const char * get_tensor_name(int i) const; + + const llama_tensor_weight * get_weight(const char * name) const; + + const llama_tensor_weight * get_weight(int i) const { + return get_weight(get_tensor_name(i)); + } + + const llama_tensor_weight & require_weight(const char * name) const; + + struct ggml_tensor * get_tensor_meta(const char * name) const; + + struct ggml_tensor * require_tensor_meta(const char * name) const; + + struct ggml_tensor * get_tensor_meta(int i) const { + return get_tensor_meta(get_tensor_name(i)); + } + + struct ggml_tensor * create_tensor_for(struct ggml_context * ctx, const struct ggml_tensor * cur, bool duplicated); + + const struct ggml_tensor * check_tensor_dims(const std::string & name, const std::vector & ne, bool required) const; + + static const int TENSOR_NOT_REQUIRED = 1 << 0; + static const int TENSOR_DUPLICATED = 1 << 1; + static const int TENSOR_SKIP = 1 << 2; + + struct ggml_tensor * create_tensor(struct ggml_context * ctx, const std::string & name, const std::vector & ne, int flags = 0); + + struct ggml_tensor * create_tensor_as_view(struct ggml_context * ctx, struct ggml_tensor * base, + const std::string & name, const std::vector & ne, size_t offset, bool required = true); + + void done_getting_tensors() const; + + void init_mappings(bool prefetch = true, llama_mlocks * mlock_mmaps = nullptr, bool use_thp = false); + + void get_mapping_range(size_t * first, size_t * last, void ** addr, int idx, ggml_context * ctx) const; + + // for backwards compatibility, does not support ggml-backend + void load_data_for(struct ggml_tensor * cur) const; + + size_t size_done = 0; + size_t size_data = 0; + std::vector> mmaps_used; + + // Returns false if cancelled by progress_callback + bool load_all_data( + struct ggml_context * ctx, + llama_buf_map & bufs_mmap, + llama_mlocks * lmlocks, + llama_progress_callback progress_callback, + void * progress_callback_user_data); +}; diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 40d9963dc..0d806d8a8 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -734,7 +734,7 @@ llama_token llama_sample_token_impl(struct llama_sampling * smpl, llama_token_da // Ported from Koboldcpp, original PR: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am) static void get_overlapping_token_sequences(const llama_vocab& vocab, 
const std::string& str, std::unordered_multimap>& token_sequences, int max_tail_len = -1) { for (llama_token token_id = 0; token_id < (llama_token)vocab.n_tokens(); token_id++) { - std::string word = llama_detokenize(vocab, { token_id }, true); + auto word = vocab.detokenize( { token_id }, true); if (word.find(str) != std::string::npos) { token_sequences.emplace(token_id, std::vector()); } @@ -751,7 +751,8 @@ static void get_overlapping_token_sequences(const llama_vocab& vocab, const std: } } if (match) { - std::vector tokenization = llama_tokenize_internal(vocab, str.substr(i), false, false); + auto tokenization = vocab.tokenize(str.substr(i), false, false); + //std::vector tokenization = llama_tokenize_internal(vocab, str.substr(i), false, false); if (max_tail_len >= 0 && tokenization.size() > (size_t)max_tail_len) { tokenization.resize(max_tail_len); } diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index a2bc72d9b..b92adfd4c 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -1,37 +1,29 @@ #include "llama-vocab.h" +#include "ggml.h" +#include "llama-impl.h" +#include "llama-model-loader.h" + #include "unicode.h" #include #include +#include #include -#include +#include #include #include #include +#include +#include #include -#include +#include +#include // // helpers // -LLAMA_ATTRIBUTE_FORMAT(1, 2) -static std::string format(const char * fmt, ...) { - va_list ap; - va_list ap2; - va_start(ap, fmt); - va_copy(ap2, ap); - int size = vsnprintf(NULL, 0, fmt, ap); - GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT - std::vector buf(size + 1); - int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2); - GGML_ASSERT(size2 == size); - va_end(ap2); - va_end(ap); - return std::string(buf.data(), size); -} - struct naive_trie { naive_trie() : has_value(false), value(0) { } @@ -50,7 +42,7 @@ struct naive_trie { res.first->second.insert(key + 1, len - 1, value); } } - std::pair get_longest_prefix(const char * key, size_t len, size_t offset = 0) { + std::pair get_longest_prefix(const char * key, size_t len, size_t offset = 0) const { if (len == 0 || offset == len) { return std::make_pair(key, offset); } @@ -58,107 +50,31 @@ struct naive_trie { auto res = children.find(c); if (res != children.end()) { return res->second.get_longest_prefix(key, len, offset + 1); - } else { - return std::make_pair(key, offset); } + + return std::make_pair(key, offset); } - struct naive_trie * traverse(const char c) { + const struct naive_trie * traverse(const char c) const { auto res = children.find(c); if (res != children.end()) { return &res->second; - } else { - return NULL; } + + return NULL; } std::map children; bool has_value; llama_token value; }; -uint32_t llama_vocab::n_tokens() const { - return (uint32_t)id_to_token.size(); -} // -// impl +// tokenizers // -int llama_vocab::find_bpe_rank(const std::string & token_left, const std::string & token_right) const { - GGML_ASSERT(token_left.find(' ') == std::string::npos); - GGML_ASSERT(token_left.find('\n') == std::string::npos); - GGML_ASSERT(token_right.find(' ') == std::string::npos); - GGML_ASSERT(token_right.find('\n') == std::string::npos); - - auto it = bpe_ranks.find(std::make_pair(token_left, token_right)); - if (it == bpe_ranks.end()) { - return -1; - } - - return it->second; -} - -static enum llama_vocab_type llama_vocab_get_type(const llama_vocab & vocab) { - return vocab.type; -} - -static bool llama_is_normal_token(const llama_vocab & vocab, llama_token id) { - GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE); - return 
vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_NORMAL; -} - -static bool llama_is_unknown_token(const llama_vocab & vocab, llama_token id) { - GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE); - return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_UNKNOWN; -} - -static bool llama_is_control_token(const llama_vocab & vocab, llama_token id) { - GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE); - return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_CONTROL; -} - -static bool llama_is_byte_token(const llama_vocab & vocab, llama_token id) { - GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE); - return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_BYTE; -} - -static bool llama_is_user_defined_token(const llama_vocab & vocab, llama_token id) { - GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE); - return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_USER_DEFINED; -} - -static bool llama_is_unused_token(const llama_vocab & vocab, llama_token id) { - GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE); - return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_UNUSED; -} - -static uint8_t llama_token_to_byte(const llama_vocab & vocab, llama_token id) { - GGML_ASSERT(llama_vocab_get_type(vocab) != LLAMA_VOCAB_TYPE_NONE); - GGML_ASSERT(llama_is_byte_token(vocab, id)); - const auto & token_data = vocab.id_to_token.at(id); - switch (llama_vocab_get_type(vocab)) { - case LLAMA_VOCAB_TYPE_SPM: - case LLAMA_VOCAB_TYPE_UGM: { - auto buf = token_data.text.substr(3, 2); - return strtol(buf.c_str(), NULL, 16); - } - case LLAMA_VOCAB_TYPE_BPE: { - GGML_ABORT("fatal error"); - //return unicode_utf8_to_byte(token_data.text); // TODO: why is this here after GGML_ASSERT? - } - case LLAMA_VOCAB_TYPE_WPM: { - GGML_ABORT("fatal error"); - } - default: - GGML_ABORT("fatal error"); - } -} - -static void llama_escape_whitespace(std::string & text) { - replace_all(text, " ", "\xe2\x96\x81"); -} - -static void llama_unescape_whitespace(std::string & word) { - replace_all(word, "\xe2\x96\x81", " "); -} +struct llm_tokenizer { + llm_tokenizer() {} + virtual ~llm_tokenizer() = default; +}; struct llm_symbol { using index = int; @@ -190,10 +106,14 @@ struct llm_bigram_spm { size_t size; }; -struct llm_tokenizer_spm { - llm_tokenizer_spm(const llama_vocab & vocab) : vocab(vocab) {} +struct llm_tokenizer_spm : llm_tokenizer { + llm_tokenizer_spm(const llama_vocab & /*vocab*/) {} +}; + +struct llm_tokenizer_spm_session { + llm_tokenizer_spm_session(const llama_vocab & vocab) : vocab(vocab) {} - void tokenize(const std::string & text, std::vector & output) { + void tokenize(const std::string & text, std::vector & output) { // split string into utf8 chars int index = 0; size_t offs = 0; @@ -210,7 +130,7 @@ struct llm_tokenizer_spm { } // seed the work queue with all possible 2-character tokens. - for (size_t i = 1; i < symbols.size(); ++i) { + for (int i = 1; i < (int) symbols.size(); ++i) { try_add_bigram(i - 1, i); } @@ -252,13 +172,13 @@ struct llm_tokenizer_spm { } private: - void resegment(llm_symbol & symbol, std::vector & output) { + void resegment(llm_symbol & symbol, std::vector & output) { auto text = std::string(symbol.text, symbol.n); - auto token = vocab.token_to_id.find(text); + auto token = vocab.text_to_token(text); // Do we need to support is_unused? - if (token != vocab.token_to_id.end()) { - output.push_back((*token).second); + if (token != LLAMA_TOKEN_NULL) { + output.push_back(token); return; } @@ -268,13 +188,13 @@ struct llm_tokenizer_spm { // output any symbols that did not form tokens as bytes. 
output.reserve(output.size() + symbol.n); for (int j = 0; j < (int)symbol.n; ++j) { - llama_vocab::id token_id = llama_byte_to_token_impl(vocab, symbol.text[j]); - output.push_back(token_id); + llama_token id = vocab.byte_to_token(symbol.text[j]); + output.push_back(id); } return; } - resegment(symbols[p->second.first], output); + resegment(symbols[p->second.first], output); resegment(symbols[p->second.second], output); } @@ -282,19 +202,18 @@ struct llm_tokenizer_spm { if (left == -1 || right == -1) { return; } - const std::string text = std::string(symbols[left].text, symbols[left].n + symbols[right].n); - auto token = vocab.token_to_id.find(text); + auto token = vocab.text_to_token(text); - if (token == vocab.token_to_id.end()) { + if (token == LLAMA_TOKEN_NULL) { return; } - if (static_cast((*token).second) >= vocab.id_to_token.size()) { + if (static_cast(token) >= vocab.n_tokens()) { return; } - const auto & tok_data = vocab.id_to_token[(*token).second]; + const auto & tok_data = vocab.get_token_data(token); llm_bigram_spm bigram; bigram.left = left; @@ -309,10 +228,11 @@ struct llm_tokenizer_spm { } const llama_vocab & vocab; + // currently unused + // const llm_tokenizer_spm * spm_tokenizer; std::vector symbols; llm_bigram_spm::queue work_queue; - std::map> rev_merge; }; @@ -324,6 +244,21 @@ struct llm_tokenizer_spm { // TODO: there are a lot of common parts between spm and bpe tokenizers, should be refactored and reused +template, typename Compare = std::less> +class llama_priority_queue : public std::priority_queue { +public: + using std::priority_queue::priority_queue; + + T pop_move() { + T item = std::move(this->c.front()); + std::pop_heap(this->c.begin(), this->c.end(), this->comp); + this->c.pop_back(); + return item; + } + + void pop() = delete; +}; + struct llm_bigram_bpe { struct comparator { bool operator()(const llm_bigram_bpe & l, const llm_bigram_bpe & r) const { @@ -332,7 +267,7 @@ struct llm_bigram_bpe { }; using queue_storage = std::vector; - using queue = std::priority_queue; + using queue = llama_priority_queue; llm_symbol::index left; llm_symbol::index right; std::string text; @@ -340,10 +275,10 @@ struct llm_bigram_bpe { size_t size; }; -struct llm_tokenizer_bpe { - llm_tokenizer_bpe(const llama_vocab & vocab): vocab(vocab) { - GGML_ASSERT(vocab.type == LLAMA_VOCAB_TYPE_BPE); - switch (vocab.type_pre) { +struct llm_tokenizer_bpe : llm_tokenizer { + llm_tokenizer_bpe(const llama_vocab & vocab) { + GGML_ASSERT(vocab.get_type() == LLAMA_VOCAB_TYPE_BPE); + switch (vocab.get_pre_type()) { case LLAMA_VOCAB_PRE_TYPE_LLAMA3: regex_exprs = { // original regex from tokenizer.json @@ -371,6 +306,7 @@ struct llm_tokenizer_bpe { }; break; case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM: + case LLAMA_VOCAB_PRE_TYPE_HUNYUAN_DENSE: regex_exprs = { "\\p{N}{1,3}", "[一-龥぀-ゟ゠-ヿ]+", @@ -393,25 +329,13 @@ struct llm_tokenizer_bpe { "[0-9][0-9][0-9]", }; break; - case LLAMA_VOCAB_PRE_TYPE_FALCON_3: - regex_exprs = { - "[\\p{P}\\$\\+<=>\\^~\\|`]+", - "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", - "[0-9]", - }; - break; - case LLAMA_VOCAB_PRE_TYPE_FALCON_E: - regex_exprs = { - "[\\p{P}\\$\\+<=>\\^~\\|`]+", - "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", - "[0-9]", - }; - break; case LLAMA_VOCAB_PRE_TYPE_STARCODER: case LLAMA_VOCAB_PRE_TYPE_REFACT: case LLAMA_VOCAB_PRE_TYPE_COMMAND_R: case LLAMA_VOCAB_PRE_TYPE_SMOLLM: case LLAMA_VOCAB_PRE_TYPE_CODESHELL: + case LLAMA_VOCAB_PRE_TYPE_EXAONE: + case LLAMA_VOCAB_PRE_TYPE_MINERVA: 
regex_exprs = { "\\p{N}", "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", @@ -421,6 +345,7 @@ struct llm_tokenizer_bpe { case LLAMA_VOCAB_PRE_TYPE_MPT: case LLAMA_VOCAB_PRE_TYPE_OLMO: case LLAMA_VOCAB_PRE_TYPE_JAIS: + case LLAMA_VOCAB_PRE_TYPE_TRILLION: regex_exprs = { "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", }; @@ -435,6 +360,8 @@ struct llm_tokenizer_bpe { }; break; case LLAMA_VOCAB_PRE_TYPE_PORO: + case LLAMA_VOCAB_PRE_TYPE_BLOOM: + case LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH: regex_exprs = { " ?[^(\\s|.,!?…。,、।۔،)]+", }; @@ -457,6 +384,20 @@ struct llm_tokenizer_bpe { "[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))*((?=[\\p{L}])([^A-Z]))+|[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))+((?=[\\p{L}])([^A-Z]))*|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", }; break; + case LLAMA_VOCAB_PRE_TYPE_CHAMELEON: + // Note: in theory, the special token (sentinel and image token) regex_exprs below + // are unnecessary, as they are split in `tokenizer_st_partition` anyway. + // However, since the upstream pre-tokenizer uses them, they are also + // included here (see https://huggingface.co/facebook/chameleon-7b). + regex_exprs = { + "", // Sentinel tokens + "(IMGIMG)((A|B|C|D|E|F|G|H|I){1,4})Z", // Image tokens + "([\\t\\n]| | )", // directly from tokenizer.json + "\\p{N}", // Individual digits + "[\\p{P}!-/:-@\\[-`{-~]", // Punctuation, Isolated + "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", + }; + break; case LLAMA_VOCAB_PRE_TYPE_GPT4O: regex_exprs = { // original regex from tokenizer.json @@ -485,7 +426,7 @@ struct llm_tokenizer_bpe { "'(?:[sSdDmMtT]|[lL][lL]|[vV][eE]|[rR][eE])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]|\\s+(?!\\S)|\\s+", }; break; - case LLAMA_VOCAB_PRE_TYPE_SEED_CODER: + case LLAMA_VOCAB_PRE_TYPE_SEED_CODER: regex_exprs = { // original regex from tokenizer.json // "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1}| ?[^\\s\\p{L}\\p{N}\r\n]+|\\s*[\r\n]+|\\s+(?!\\S)|\\s+" @@ -504,36 +445,42 @@ struct llm_tokenizer_bpe { } } - void append(const llama_vocab::id token_id, std::vector & output) const { + std::vector regex_exprs; +}; + +struct llm_tokenizer_bpe_session { + llm_tokenizer_bpe_session(const llama_vocab & vocab, const llm_tokenizer_bpe & tokenizer) : vocab(vocab), tokenizer(tokenizer) {} + + static void append(const llama_token token_id, std::vector & output) { output.push_back(token_id); } - bool append_bos(std::vector & output) const { - if (vocab.tokenizer_add_bos) { - GGML_ASSERT(vocab.special_bos_id != -1); - output.push_back(vocab.special_bos_id); + bool append_bos(std::vector & output) const { + if (vocab.get_add_bos()) { + GGML_ASSERT(vocab.token_bos() != LLAMA_TOKEN_NULL); + output.push_back(vocab.token_bos()); return true; } return false; } - bool append_eos(std::vector & output) const { - if (vocab.tokenizer_add_eos) { - GGML_ASSERT(vocab.special_eos_id != -1); - output.push_back(vocab.special_eos_id); + bool append_eos(std::vector & output) const { + if (vocab.get_add_eos()) { + GGML_ASSERT(vocab.token_eos() != LLAMA_TOKEN_NULL); + output.push_back(vocab.token_eos()); return true; } return false; } - void check_double_bos_eos(const std::vector & output) const { - if (vocab.tokenizer_add_bos && output.size() >= 2 && output[1] == vocab.special_bos_id) { + void check_double_bos_eos(const std::vector & output) const { + if (vocab.get_add_bos() && output.size() >= 2 && output[1] == 
vocab.token_bos()) { LLAMA_LOG_WARN( "%s: Added a BOS token to the prompt as specified by the model but the prompt " "also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. " "Are you sure this is what you want?\n", __FUNCTION__); } - if (vocab.tokenizer_add_eos && output.size() >= 2 && *(output.end()-2) == vocab.special_eos_id) { + if (vocab.get_add_eos() && output.size() >= 2 && *(output.end()-2) == vocab.token_eos()) { LLAMA_LOG_WARN( "%s: Added a EOS token to the prompt as specified by the model but the prompt " "also ends with a EOS token. So now the final prompt ends with 2 EOS tokens. " @@ -541,21 +488,21 @@ struct llm_tokenizer_bpe { } } - void tokenize(const std::string & text, std::vector & output) { + void tokenize(const std::string & text, std::vector & output) { int final_prev_index = -1; - - const auto word_collection = unicode_regex_split(text, regex_exprs); + const auto word_collection = unicode_regex_split(text, tokenizer.regex_exprs); symbols_final.clear(); - for (auto & word : word_collection) { + for (const auto & word : word_collection) { work_queue = llm_bigram_bpe::queue(); symbols.clear(); int index = 0; size_t offset = 0; - if (vocab.tokenizer_ignore_merges && vocab.token_to_id.find(word) != vocab.token_to_id.end()) { + //if (vocab.tokenizer_ignore_merges && vocab.token_to_id.find(word) != vocab.token_to_id.end()) { + if (vocab.get_ignore_merges() && vocab.text_to_token(word) != LLAMA_TOKEN_NULL) { symbols.emplace_back(llm_symbol{-1, -1, word.c_str(), word.size()}); offset = word.size(); } @@ -571,14 +518,13 @@ struct llm_tokenizer_bpe { index++; symbols.emplace_back(sym); } - for (size_t i = 1; i < symbols.size(); ++i) { + for (int i = 1; i < (int) symbols.size(); ++i) { add_new_bigram(i - 1, i); } // build token(s) while (!work_queue.empty()) { - auto bigram = work_queue.top(); - work_queue.pop(); + auto bigram = work_queue.pop_move(); auto & left_symbol = symbols[bigram.left]; auto & right_symbol = symbols[bigram.right]; @@ -630,18 +576,18 @@ struct llm_tokenizer_bpe { } const std::string str = std::string(symbol.text, symbol.n); - const auto token = vocab.token_to_id.find(str); + const auto token = vocab.text_to_token(str); - if (token == vocab.token_to_id.end()) { + if (token == LLAMA_TOKEN_NULL) { for (auto j = str.begin(); j != str.end(); ++j) { std::string byte_str(1, *j); - auto token_multibyte = vocab.token_to_id.find(byte_str); - if (token_multibyte != vocab.token_to_id.end()) { - output.push_back(token_multibyte->second); + auto token_multibyte = vocab.text_to_token(byte_str); + if (token_multibyte != LLAMA_TOKEN_NULL) { + output.push_back(token_multibyte); } } } else { - output.push_back((*token).second); + output.push_back(token); } } } @@ -652,7 +598,6 @@ struct llm_tokenizer_bpe { if (left == -1 || right == -1) { return; } - std::string left_token = std::string(symbols[left].text, symbols[left].n); std::string right_token = std::string(symbols[right].text, symbols[right].n); @@ -676,12 +621,10 @@ struct llm_tokenizer_bpe { } const llama_vocab & vocab; - - std::vector regex_exprs; + const llm_tokenizer_bpe & tokenizer; std::vector symbols; std::vector symbols_final; - llm_bigram_bpe::queue work_queue; }; @@ -689,15 +632,16 @@ struct llm_tokenizer_bpe { // WPM tokenizer // -struct llm_tokenizer_wpm { - llm_tokenizer_wpm(const llama_vocab & vocab): vocab(vocab) {} +struct llm_tokenizer_wpm : llm_tokenizer { + llm_tokenizer_wpm(const llama_vocab & /*vocab*/) {} +}; - void tokenize(const std::string & text, std::vector & output) 
const { - const auto & token_map = vocab.token_to_id; +struct llm_tokenizer_wpm_session { + llm_tokenizer_wpm_session(const llama_vocab & vocab) : vocab(vocab) {} + void tokenize(const std::string & text, std::vector<llama_token> & output) { // normalize and split by whitespace std::vector<std::string> words = preprocess(text); - // bos token prepended already // find the longest tokens that form the words @@ -718,10 +662,10 @@ struct llm_tokenizer_wpm { for (int i = 0; i < n; ++i) { // loop through possible match length bool match = false; - for (int j = std::min(n, i + vocab.max_token_len + 1); j > i; j--) { - auto it = token_map.find(word1.substr(i, j - i)); - if (it != token_map.end()) { - output.push_back(it->second); + for (int j = std::min(n, i + vocab.max_token_len() + 1); j > i; j--) { + auto id = vocab.text_to_token(word1.substr(i, j - i)); + if (id != LLAMA_TOKEN_NULL) { + output.push_back(id); match = true; i = j - 1; break; @@ -736,18 +680,18 @@ struct llm_tokenizer_wpm { // we didn't find any matches for this word if (current_tokens == output.size()) { - output.push_back(vocab.special_unk_id); + output.push_back(vocab.token_unk()); } } } // TODO: reduce string copies by using cpts_offs array - std::vector<std::string> preprocess(const std::string & text) const { + static std::vector<std::string> preprocess(const std::string & text) { const std::vector<uint32_t> cpts_nfd = unicode_cpts_normalize_nfd(unicode_cpts_from_utf8(text)); std::vector<std::string> words(1, ""); for (const uint32_t cpt : cpts_nfd) { - const auto flags = unicode_cpt_flags(cpt); + const auto flags = unicode_cpt_flags_from_cpt(cpt); if (flags.is_whitespace) { if (words.back().size()) { // finish previous word if any @@ -794,53 +738,56 @@ struct llm_tokenizer_wpm { //(cpt >= 0xFF00 && cpt <= 0xFFEF); } +private: const llama_vocab & vocab; + // currently unused + // const llm_tokenizer_wpm * wpm_tokenizer; }; // // UGM tokenizer // -struct llm_tokenizer_ugm { - llm_tokenizer_ugm(const llama_vocab & vocab) : vocab(vocab) { - if (vocab.precompiled_charsmap.size() > 0) { +struct llm_tokenizer_ugm : llm_tokenizer { + llm_tokenizer_ugm(const llama_vocab & vocab, const std::vector<char> & precompiled_charsmap) { + if (precompiled_charsmap.size() > 0) { size_t charsmap_offset = 0; // First four bytes of precompiled_charsmap contains length of binary // blob containing XOR-compressed compact double array (XCDA) entries - uint32_t xcda_blob_size = *(const uint32_t *) &vocab.precompiled_charsmap[0]; + uint32_t xcda_blob_size = *(const uint32_t *) &precompiled_charsmap[0]; charsmap_offset += sizeof(xcda_blob_size); - if (xcda_blob_size + charsmap_offset >= vocab.precompiled_charsmap.size()) { + if (xcda_blob_size + charsmap_offset >= precompiled_charsmap.size()) { throw std::runtime_error("Index out of array bounds in precompiled charsmap!"); } // Next xcda_blob_size bytes contain entries of XOR-compressed compact // double array (XCDA). Each entry is bit-packed into a 32-bit integer. - xcda_array = (const uint32_t *) &vocab.precompiled_charsmap[charsmap_offset]; + xcda_array = (const uint32_t *) &precompiled_charsmap[charsmap_offset]; xcda_array_size = xcda_blob_size / sizeof(uint32_t); charsmap_offset += xcda_blob_size; // Remaining bytes of precompiled charsmap contain null-terminated // replacement strings for prefixes matched by the XCDA.
- prefix_replacements = &vocab.precompiled_charsmap[charsmap_offset]; - prefix_replacements_size = vocab.precompiled_charsmap.size() - charsmap_offset; + prefix_replacements = &precompiled_charsmap[charsmap_offset]; + prefix_replacements_size = precompiled_charsmap.size() - charsmap_offset; } - for (unsigned int id = 0; id < vocab.id_to_token.size(); ++id) { - const auto &token_data = vocab.id_to_token[id]; + for (uint32_t id = 0; id < vocab.n_tokens(); ++id) { + const auto & token_data = vocab.get_token_data(id); - if (llama_is_normal_token(vocab, id)) { + if (vocab.is_normal(id)) { min_score = std::min(min_score, token_data.score); max_score = std::max(max_score, token_data.score); } - if (llama_is_normal_token(vocab, id) || - llama_is_user_defined_token(vocab, id) || - llama_is_unused_token(vocab, id)) { + if (vocab.is_normal(id) || + vocab.is_user_defined(id) || + vocab.is_unused(id)) { token_matcher.insert(token_data.text.data(), token_data.text.size(), id); } - if (llama_is_user_defined_token(vocab, id)) { + if (vocab.is_user_defined(id)) { user_defined_token_matcher.insert(token_data.text.data(), token_data.text.size()); } } @@ -848,6 +795,29 @@ struct llm_tokenizer_ugm { unknown_token_score = min_score - unknown_token_score_penalty; } + // escaped space symbol - U+2581 (Lower One Eighth Block) + const std::string escaped_space = "\xE2\x96\x81"; + + const char * prefix_replacements = NULL; + size_t prefix_replacements_size = 0; + + const uint32_t * xcda_array = NULL; + size_t xcda_array_size = 0; + + struct naive_trie user_defined_token_matcher; + + float min_score = FLT_MAX; + float max_score = -FLT_MAX; + + float unknown_token_score_penalty = 10.0; + float unknown_token_score; + + struct naive_trie token_matcher; +}; + +struct llm_tokenizer_ugm_session { + llm_tokenizer_ugm_session(const llama_vocab & vocab, const llm_tokenizer_ugm & tokenizer) : vocab(vocab), tokenizer(tokenizer) {} + /* This implementation is based on SentencePiece optimized Viterbi algorithm for * unigram language models. The general idea is to: * - move along the input sequence in steps of one UTF code point, @@ -861,7 +831,7 @@ struct llm_tokenizer_ugm { * After processing the whole sequence we backtrack from the end to get * the best tokenization. 
*/ - void tokenize(const std::string & text, std::vector & output) { + void tokenize(const std::string & text, std::vector & output) { // get current size of output (for reversal later) size_t output_size = output.size(); @@ -874,9 +844,9 @@ struct llm_tokenizer_ugm { } // initialize score_sum to -FLT_MAX so it will be always lower than sums of token scores - std::vector tokenization_results(input_len + 1, {vocab.special_unk_id, 0, -FLT_MAX}); + std::vector tokenization_results(input_len + 1, {vocab.token_unk(), 0, -DBL_MAX}); // at the beginning tokenization score is zero - tokenization_results[0] = { vocab.special_unk_id, 0, 0 }; + tokenization_results[0] = { vocab.token_unk(), 0, 0 }; for (size_t input_offset = 0; input_offset < input_len;) { size_t prefix_offset = input_offset; @@ -886,7 +856,7 @@ struct llm_tokenizer_ugm { // traverse the token matcher trie to find a matching token bool single_codepoint_token_found = false; const struct best_tokenization & current_best = tokenization_results[input_offset]; - struct naive_trie * node = token_matcher.traverse(normalized[prefix_offset++]); + const struct naive_trie * node = tokenizer.token_matcher.traverse(normalized[prefix_offset++]); while (prefix_offset <= input_len && node != NULL) { // check if we found valid token in prefix @@ -896,17 +866,17 @@ struct llm_tokenizer_ugm { single_codepoint_token_found = true; } llama_token token_id = node->value; - const auto & token_data = vocab.id_to_token[token_id]; + const auto & token_data = vocab.get_token_data(token_id); // we set the user-defined token scores to 0 to make them more likely to be selected // (normal token scores are log probabilities, so they are negative) // score type is double here to make tokenization results exactly // the same as in the HF tokenizer using SentencePiece - const double token_score = llama_is_user_defined_token(vocab, token_id) ? 0.0 : token_data.score; + const double token_score = vocab.is_user_defined(token_id) ? 
0.0 : token_data.score; const double challenger_score = current_best.score_sum + token_score; struct best_tokenization & current_champ = tokenization_results[prefix_offset]; if (challenger_score > current_champ.score_sum) { - struct best_tokenization challenger = { token_id, input_offset, (float) challenger_score }; + struct best_tokenization challenger = { token_id, input_offset, challenger_score }; current_champ = challenger; } } @@ -916,11 +886,11 @@ struct llm_tokenizer_ugm { // if we didn't find a valid token corresponding to the whole UTF code point // then use unknown token as the tokenization of this UTF code point if (!single_codepoint_token_found) { - const double challenger_score = current_best.score_sum + unknown_token_score; + const double challenger_score = current_best.score_sum + tokenizer.unknown_token_score; prefix_offset = input_offset + n_utf8_code_units; struct best_tokenization & current_champ = tokenization_results[prefix_offset]; if (challenger_score > current_champ.score_sum) { - struct best_tokenization challenger = { vocab.special_unk_id, input_offset, (float) challenger_score }; + struct best_tokenization challenger = { vocab.token_unk(), input_offset, challenger_score }; current_champ = challenger; } } @@ -933,7 +903,7 @@ struct llm_tokenizer_ugm { // merge sequences of consecutive unknown tokens into single unknown tokens bool is_prev_unknown = false; for (struct best_tokenization & tokenization = tokenization_results[input_len]; ; tokenization = tokenization_results[tokenization.input_offset]) { - bool is_unknown = tokenization.token_id == vocab.special_unk_id; + bool is_unknown = tokenization.token_id == vocab.token_unk(); if (!(is_prev_unknown && is_unknown)) { output.push_back(tokenization.token_id); } @@ -948,7 +918,6 @@ struct llm_tokenizer_ugm { } private: - const llama_vocab & vocab; // helper structure for returning normalization results struct normalization_result { @@ -961,11 +930,11 @@ struct llm_tokenizer_ugm { normalized->clear(); normalized->reserve(input.size() * 3); - const std::string space = vocab.tokenizer_escape_whitespaces ? escaped_space : " "; + const std::string space = vocab.get_escape_whitespaces() ? tokenizer.escaped_space : " "; - bool shall_prepend_space = !vocab.tokenizer_treat_whitespace_as_suffix && vocab.tokenizer_add_space_prefix; - bool shall_append_space = vocab.tokenizer_treat_whitespace_as_suffix && vocab.tokenizer_add_space_prefix; - bool shall_merge_spaces = vocab.tokenizer_remove_extra_whitespaces; + const bool shall_prepend_space = !vocab.get_treat_whitespace_as_suffix() && vocab.get_add_space_prefix(); + const bool shall_append_space = vocab.get_treat_whitespace_as_suffix() && vocab.get_add_space_prefix(); + const bool shall_merge_spaces = vocab.get_remove_extra_whitespaces(); bool is_space_prepended = false; bool processing_non_ws = false; @@ -1006,7 +975,7 @@ struct llm_tokenizer_ugm { /* * This structure is a view wrapper for XOR-compressed double array (XCDA) * See Shunsuke Kanda (2018). Space- and Time-Efficient String Dictionaries. 
- * Eeach bit-packed entry contains: + * Each bit-packed entry contains: * - BASE array value in bits 10-30 * - LCHECK array value in bits 0-7 * - LEAF array value in bit 9 @@ -1043,13 +1012,21 @@ struct llm_tokenizer_ugm { size_t xcda_array_size; }; + // this structure stores the best tokenization so far at input_offset + struct best_tokenization { + llama_token token_id; + size_t input_offset; + double score_sum; + }; + struct normalization_result normalize_prefix(const std::string & input, size_t input_offset) { if (input_offset == input.size()) { return { &input[input_offset], 0, 0 }; } // if input prefix matches some user-defined token return this token as normalization result - auto user_defined_token_match = user_defined_token_matcher.get_longest_prefix(&input[input_offset], input.size() - input_offset); + auto user_defined_token_match = + tokenizer.user_defined_token_matcher.get_longest_prefix(&input[input_offset], input.size() - input_offset); if (user_defined_token_match.second > 0) { return { &input[input_offset], user_defined_token_match.second, user_defined_token_match.second }; } @@ -1057,8 +1034,8 @@ struct llm_tokenizer_ugm { size_t longest_prefix_length = 0; size_t longest_prefix_offset = 0; - if (xcda_array_size > 0) { - struct xcda_array_view xcda_view(xcda_array, xcda_array_size); + if (tokenizer.xcda_array_size > 0) { + struct xcda_array_view xcda_view(tokenizer.xcda_array, tokenizer.xcda_array_size); // Find the longest normalized sequence matching the input prefix by walking // the XOR-compressed compact double array (XCDA) starting from the root node @@ -1094,722 +1071,2720 @@ struct llm_tokenizer_ugm { if (longest_prefix_length > 0) { // we have a match, so return the replacement sequence - if (longest_prefix_offset >= prefix_replacements_size) { + if (longest_prefix_offset >= tokenizer.prefix_replacements_size) { throw std::runtime_error("Index out of array bounds in precompiled charsmap!"); } - const char * prefix_replacement = &prefix_replacements[longest_prefix_offset]; + const char * prefix_replacement = &(tokenizer.prefix_replacements)[longest_prefix_offset]; return { prefix_replacement, strlen(prefix_replacement), longest_prefix_length }; - } else { - // check if the input prefix contains a valid sequence of UTF-8 code units - try { - // if yes, return this sequence unmodified - size_t prefix_offset = input_offset; - unicode_cpt_from_utf8(input, prefix_offset); - return { &input[input_offset], prefix_offset - input_offset, prefix_offset - input_offset }; - } catch (std::invalid_argument & /*ex*/) { - // if no, consume 1 byte and return U+FFFD - REPLACEMENT CHARACTER - return { "\xEF\xBF\xBD", 3, 1 }; - } } - } - // escaped space symbol - U+2581 (Lower One Eighth Block) - const std::string escaped_space = "\xE2\x96\x81"; + // check if the input prefix contains a valid sequence of UTF-8 code units + try { + // if yes, return this sequence unmodified + size_t prefix_offset = input_offset; + unicode_cpt_from_utf8(input, prefix_offset); + return { &input[input_offset], prefix_offset - input_offset, prefix_offset - input_offset }; + } catch (std::invalid_argument & /*ex*/) { + // if no, consume 1 byte and return U+FFFD - REPLACEMENT CHARACTER + return { "\xEF\xBF\xBD", 3, 1 }; + } + } - const char * prefix_replacements = NULL; - size_t prefix_replacements_size = 0; + const llama_vocab & vocab; + const llm_tokenizer_ugm & tokenizer; +}; - const uint32_t * xcda_array = NULL; - size_t xcda_array_size = 0; +// +// RWKV tokenizer +// - struct naive_trie 
user_defined_token_matcher; +static std::vector llama_unescape_rwkv_token(const std::string & escaped) { + std::vector output; + output.reserve(escaped.size()); + + // Parser state + bool escaping = false; + uint8_t hex_remaining = 0; + uint8_t hex_acc = 0; + + // Step through characters, performing parsing + for (const char & c : escaped) { + // If we're parsing a hex code, interpret the next character + if (hex_remaining != 0) { + uint8_t value = (c >= 'a') ? (c - 'a' + 10) : (c - '0'); + hex_acc = (hex_acc << 4) + value; + + hex_remaining -= 1; + if (hex_remaining == 0) { + output.push_back(hex_acc); + hex_acc = 0; + } - // this structure stores the best tokenization so far at input_offset - struct best_tokenization { - llama_token token_id; - size_t input_offset; - float score_sum; - }; + continue; + } - float min_score = FLT_MAX; - float max_score = -FLT_MAX; + // If we got an escape character, interpret it + if (escaping) { + if (c == 't') { + output.push_back('\t'); + } else if (c == 'n') { + output.push_back('\n'); + } else if (c == 'r') { + output.push_back('\r'); + } else if (c == 'x') { + hex_remaining = 2; + } else { + output.push_back(c); + } - float unknown_token_score_penalty = 10.0; - float unknown_token_score; + escaping = false; + continue; + } - struct naive_trie token_matcher; -}; + if (c == '\\') { + escaping = true; + continue; + } -// -// (de-) tokenize -// + output.push_back(c); + } -typedef enum FRAGMENT_BUFFER_VARIANT_TYPE { - FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN, - FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT -} FRAGMENT_BUFFER_VARIANT_TYPE; + return output; +} -struct fragment_buffer_variant { - fragment_buffer_variant(llama_vocab::id _token) - : - type(FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN), - token(_token), - raw_text(_dummy), - offset(0), - length(0) {} +struct llm_tokenizer_rwkv : llm_tokenizer { + llm_tokenizer_rwkv(const llama_vocab & vocab) { + // RWKV supports arbitrary byte tokens, but the vocab struct only supports string tokens. + // For now, we decode the vocab here into the lookup we'll use for tokenization. 
- fragment_buffer_variant(const std::string & _raw_text, int64_t _offset, int64_t _length) - : - type(FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT), - token((llama_vocab::id) - 1), - raw_text(_raw_text), - offset(_offset), - length(_length){ - GGML_ASSERT(_offset >= 0); - GGML_ASSERT(_length >= 1); - GGML_ASSERT(offset + length <= raw_text.length()); + // build trie + for (uint32_t id = 0; id < vocab.n_tokens(); ++id) { + const auto & data = vocab.get_token_data(id); + const auto text = llama_unescape_rwkv_token(data.text); + token_matcher.insert((const char *) text.data(), text.size(), id); } + } - const FRAGMENT_BUFFER_VARIANT_TYPE type; - const llama_vocab::id token; - const std::string _dummy; - const std::string & raw_text; - const uint64_t offset; - const uint64_t length; + struct naive_trie token_matcher; }; -// #define PRETOKENIZERDEBUG +struct llm_tokenizer_rwkv_session { + llm_tokenizer_rwkv_session(const llama_vocab & vocab, const llm_tokenizer_rwkv & tokenizer) : vocab(vocab), tokenizer(tokenizer) {} + + void tokenize(const std::string & text, std::vector & output) { + uint32_t position = 0; + while (position < text.size()) { + const struct naive_trie * node = tokenizer.token_matcher.traverse(text[position]); + if (node == NULL) { + // no matching token found, add unknown token + output.push_back(vocab.token_unk()); + position += 1; + continue; + } -static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list & buffer, bool parse_special) { - // for each special token - for (const llama_vocab::id special_id : vocab.cache_special_tokens) { - const auto & data = vocab.id_to_token[special_id]; - const auto & special_token = data.text; + // traverse the trie to find the longest matching token + uint32_t token_id = 0; + uint32_t token_length = 0; + while (node != NULL) { + if (node->has_value) { + token_id = node->value; + token_length = position + 1; + } + node = node->traverse(text[++position]); + } - if (!parse_special && (data.attr & (LLAMA_TOKEN_ATTR_CONTROL | LLAMA_TOKEN_ATTR_UNKNOWN))) { - // Ignore control and unknown tokens when parse_special == false - continue; - // User-defined tokens are still pre-tokenized before everything else - // ref: https://github.com/huggingface/tokenizers/blob/fdd26ba9a3f0c133427aab0423888cbde91362d7/tokenizers/src/tokenizer/mod.rs#L726 - // This is mostly relevant for neox-style tokenizers (mpt, olmo, stablelm, etc.) 
+ // add the longest matching token + output.push_back(token_id); + position = token_length; } + } - // for each text fragment - std::forward_list::iterator it = buffer.begin(); - while (it != buffer.end()) { - auto & fragment = (*it); +private: + const llama_vocab & vocab; + const llm_tokenizer_rwkv & tokenizer; +}; - // if a fragment is text ( not yet processed ) - if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { - auto & raw_text = fragment.raw_text; +struct llm_tokenizer_plamo2 : llm_tokenizer { + llm_tokenizer_plamo2(const llama_vocab & vocab) { + build(vocab); + } - auto raw_text_base_offset = fragment.offset; - auto raw_text_base_length = fragment.length; + void build(const llama_vocab & vocab) { + // Reset internal structures + tokens_.clear(); + bytes_.assign(256, 0); + to_suffix_id_.clear(); + table_.clear(); + + // Build token list and byte mapping + std::unordered_map suffix_to_score; + std::unordered_map token_to_id; + + for (size_t token_id = 0; token_id < vocab.n_tokens(); ++token_id) { + const auto & entry = vocab.get_token_data(token_id); + tokens_.push_back(entry.text); + token_to_id[entry.text] = static_cast(token_id); + + // Handle byte tokens + if (vocab.is_byte(token_id)) { + if (entry.text.length() == 6 && entry.text.substr(0, 3) == "<0x" && entry.text.back() == '>') { + std::string hex_str = entry.text.substr(3, 2); + int byte_val = std::stoi(hex_str, nullptr, 16); + bytes_[byte_val] = static_cast(token_id); + } + continue; + } - // loop over the text - while (true) { - // find the first occurrence of a given special token in this fragment - // passing offset argument only limit the "search area" but match coordinates - // are still relative to the source full raw_text - auto match = raw_text.find(special_token, raw_text_base_offset); + // Add token and all its suffixes to suffix_to_score + suffix_to_score[entry.text] = entry.score; - // no occurrences found, stop processing this fragment for a given special token - if (match == std::string::npos) break; + // Extract suffixes character by character (UTF-8 aware) + std::vector cpts = unicode_cpts_from_utf8(entry.text); + for (size_t i = 1; i < cpts.size(); ++i) { + std::string suffix; + for (size_t j = i; j < cpts.size(); ++j) { + suffix += unicode_cpt_to_utf8(cpts[j]); + } + if (suffix_to_score.find(suffix) == suffix_to_score.end()) { + suffix_to_score[suffix] = std::numeric_limits::quiet_NaN(); + } + } + } - // check if match is within bounds of offset <-> length - if (match + special_token.length() > raw_text_base_offset + raw_text_base_length) break; + // Check that all byte tokens are set + for (int i = 0; i < 256; ++i) { + if (bytes_[i] == 0) { + throw std::runtime_error("Byte token for <0x" + std::to_string(i) + "> is not set"); + } + } -#ifdef PRETOKENIZERDEBUG - LLAMA_LOG_WARN("FF: (%ld %ld %ld) '%s'\n", raw_text->length(), raw_text_base_offset, raw_text_base_length, raw_text->substr(raw_text_base_offset, raw_text_base_length).c_str()); -#endif - auto source = std::distance(buffer.begin(), it); + // Build suffix list in lexicographical order of reversed strings + std::vector suffixes; + for (const auto & pair : suffix_to_score) { + suffixes.push_back(pair.first); + } + suffixes.push_back(""); // Empty suffix - // if match is further than base offset - // then we have some text to the left of it - if (match > raw_text_base_offset) { - // left - const int64_t left_reminder_offset = raw_text_base_offset + 0; - int64_t left_reminder_length = match - raw_text_base_offset; + 
std::sort(suffixes.begin(), suffixes.end(), [](const std::string & a, const std::string & b) { + std::string rev_a(a.rbegin(), a.rend()); + std::string rev_b(b.rbegin(), b.rend()); + return rev_a < rev_b; + }); - if (data.attr & LLAMA_TOKEN_ATTR_LSTRIP) { - while (left_reminder_length > 0 && isspace(raw_text[left_reminder_offset + left_reminder_length - 1])) { - left_reminder_length--; - } - } + // Build suffix_to_id and to_suffix_id_ + std::unordered_map suffix_to_id; + int32_t num_pieces = 0; - if (left_reminder_length > 0) { - buffer.emplace_after(it, raw_text, left_reminder_offset, left_reminder_length); - it++; - } + for (const auto & suffix : suffixes) { + suffix_to_id[suffix] = num_pieces; + if (!suffix.empty()) { + std::vector cpts = unicode_cpts_from_utf8(suffix); -#ifdef PRETOKENIZERDEBUG - LLAMA_LOG_WARN("FL: (%ld %ld) '%s'\n", left_reminder_offset, left_reminder_length, raw_text->substr(left_reminder_offset, left_reminder_length).c_str()); -#endif + std::string remaining; + for (size_t i = 1; i < cpts.size(); ++i) { + remaining += unicode_cpt_to_utf8(cpts[i]); + } + + int64_t piece_code = (static_cast(cpts[0]) << 32) | suffix_to_id[remaining]; + to_suffix_id_[piece_code] = num_pieces; + + // Count number of pieces for this suffix + int32_t pieces_for_suffix = 1; // sentinel row + for (int32_t piece_length = static_cast(cpts.size()); piece_length > 0; --piece_length) { + std::string piece; + for (int32_t i = 0; i < piece_length; ++i) { + piece += unicode_cpt_to_utf8(cpts[i]); } + if (suffix_to_score.find(piece) != suffix_to_score.end()) { + pieces_for_suffix++; + } + } + num_pieces += pieces_for_suffix; + } else { + num_pieces++; // Empty suffix contributes one piece (sentinel row) + } + } - // special token - buffer.emplace_after(it, special_id); - it++; + // Build flattened table + table_.resize(num_pieces, std::vector(4, 0)); + int32_t table_idx = 0; - // right - if (match + special_token.length() < raw_text_base_offset + raw_text_base_length) { - int64_t right_reminder_offset = match + special_token.length(); - int64_t right_reminder_length = raw_text_base_length - ((match - raw_text_base_offset) + special_token.length()); + for (const auto & suffix : suffixes) { + // Add all prefixes of the suffix to the table (in decreasing order of length) + std::vector cpts = unicode_cpts_from_utf8(suffix); + for (int32_t piece_length = static_cast(cpts.size()); piece_length > 0; --piece_length) { + std::string piece; + for (int32_t i = 0; i < piece_length; ++i) { + piece += unicode_cpt_to_utf8(cpts[i]); + } - if (data.attr & LLAMA_TOKEN_ATTR_RSTRIP) { - while (right_reminder_length > 0 && isspace(raw_text[right_reminder_offset])) { - right_reminder_offset++; - right_reminder_length--; - } - } + auto score_it = suffix_to_score.find(piece); + if (score_it == suffix_to_score.end()) { + continue; + } - if (right_reminder_length > 0) { - buffer.emplace_after(it, raw_text, right_reminder_offset, right_reminder_length); - it++; - } + table_[table_idx][TABLE_PIECE_LENGTH] = piece_length; + auto token_it = token_to_id.find(piece); + table_[table_idx][TABLE_TOKEN_ID] = (token_it != token_to_id.end()) ? token_it->second : -1; -#ifdef PRETOKENIZERDEBUG - LLAMA_LOG_WARN("FR: (%ld %ld) '%s'\n", right_reminder_offset, right_reminder_length, raw_text->substr(right_reminder_offset, right_reminder_length).c_str()); -#endif + float score = score_it->second; + table_[table_idx][TABLE_SCORE] = std::isfinite(score) ? 
+ static_cast(std::round(score * 1e4)) : INVALID_SCORE; + table_[table_idx][TABLE_PIECE_ID] = suffix_to_id[piece]; - if (source == 0) { - buffer.erase_after(buffer.before_begin()); - } else { - buffer.erase_after(std::next(buffer.begin(), (source-1))); + table_idx++; + } + + // Add sentinel row + table_[table_idx][TABLE_PIECE_LENGTH] = 1; + table_[table_idx][TABLE_TOKEN_ID] = -1; + table_[table_idx][TABLE_SCORE] = UNKNOWN_SCORE; + table_idx++; + } + } + + std::vector encode(const std::string & text) const { + std::vector unicode_data = unicode_cpts_from_utf8(text); + // Skip the first code point if it is a BOM (Byte Order Mark) + if (!unicode_data.empty() && unicode_data[0] == 0xFEFF) { + unicode_data.erase(unicode_data.begin()); + } + + if (unicode_data.empty()) { + return {}; + } + + const size_t data_len = unicode_data.size(); + + // Initialize scores array (dynamic programming) + std::vector scores(data_len + 1, static_cast(1) << 60); + scores[data_len] = 0; + + // Path array to track best tokenization + std::vector> path(data_len + 1, std::vector(3, 0)); + + int32_t suffix_id = 0; + + // Process from end to beginning + for (int i = static_cast(data_len) - 1; i >= 0; --i) { + uint32_t c = unicode_data[i]; + + // Find next suffix ID + for (size_t p = suffix_id; p < table_.size(); ++p) { + int64_t piece_code = (static_cast(c) << 32) | table_[p][TABLE_PIECE_ID]; + auto it = to_suffix_id_.find(piece_code); + suffix_id = (it != to_suffix_id_.end()) ? it->second : 0; + + if (suffix_id > 0 || table_[p][TABLE_SCORE] == UNKNOWN_SCORE) { + break; + } + } + + // Update best path + for (size_t p = suffix_id; p < table_.size(); ++p) { + int32_t score = table_[p][TABLE_SCORE]; + if (score > INVALID_SCORE) { + int32_t piece_length = table_[p][TABLE_PIECE_LENGTH]; + int64_t s = scores[i + piece_length] - score; + + if (s < scores[i]) { + scores[i] = s; + path[i][PATH_TOKEN_LENGTH] = piece_length; + path[i][PATH_TOKEN_ID] = table_[p][TABLE_TOKEN_ID]; + path[i][PATH_NUM_TOKENS] = path[i + piece_length][PATH_NUM_TOKENS] + 1; + + if (score == UNKNOWN_SCORE) { + // Add UTF-8 byte count + path[i][PATH_NUM_TOKENS] += (c >= 0x80) + (c >= 0x800) + (c >= 0x10000); } + } + } - // repeat for the right side - raw_text_base_offset = right_reminder_offset; - raw_text_base_length = right_reminder_length; + if (score == UNKNOWN_SCORE) { + break; + } + } + } -#ifdef PRETOKENIZERDEBUG - LLAMA_LOG_WARN("RR: (%ld %ld) '%s'\n", raw_text_base_offset, raw_text_base_length, raw_text->substr(raw_text_base_offset, raw_text_base_length).c_str()); -#endif + // Decode the best path + std::vector token_ids; + token_ids.reserve(path[0][PATH_NUM_TOKENS]); + + int pos = 0; + while (pos < static_cast(data_len)) { + if (path[pos][PATH_TOKEN_ID] >= 0) { + token_ids.push_back(path[pos][PATH_TOKEN_ID]); + } else { + // Fall back to byte tokens + uint32_t c = unicode_data[pos]; + int s = 1 + (c >= 0x80) + (c >= 0x800) + (c >= 0x10000); + + for (int i = 0; i < s; ++i) { + uint8_t b; + if (s == 1) { + b = c; } else { - if (source == 0) { - buffer.erase_after(buffer.before_begin()); + if (i == 0) { + b = (0xF00 >> s) & 0xFF; } else { - buffer.erase_after(std::next(buffer.begin(), (source-1))); + b = 0x80; } - break; } + token_ids.push_back(bytes_[b | ((c >> ((s - i - 1) * 6)) & 0x3F)]); } } - it++; + + assert(path[pos][PATH_TOKEN_LENGTH] > 0); + pos += path[pos][PATH_TOKEN_LENGTH]; } + + return token_ids; } -} +private: + // Constants for table structure + static constexpr int32_t TABLE_PIECE_LENGTH = 0; + static constexpr int32_t 
TABLE_TOKEN_ID = 1; + static constexpr int32_t TABLE_SCORE = 2; + static constexpr int32_t TABLE_PIECE_ID = 3; -std::vector llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool add_special, bool parse_special) { - std::vector output; - std::forward_list fragment_buffer; + // Constants for path array + static constexpr int32_t PATH_TOKEN_LENGTH = 0; + static constexpr int32_t PATH_TOKEN_ID = 1; + static constexpr int32_t PATH_NUM_TOKENS = 2; - if (!raw_text.empty()) { - fragment_buffer.emplace_front(raw_text, 0, raw_text.length()); - tokenizer_st_partition(vocab, fragment_buffer, parse_special); + // Score constants + static constexpr int32_t INVALID_SCORE = -20000000; + static constexpr int32_t UNKNOWN_SCORE = -10000000; + + // List of tokens in the vocabulary + std::vector tokens_; + + // Mapping from byte code point to token ID (for byte fallback) + std::vector bytes_; + + // Mapping from piece code to suffix ID + std::unordered_map to_suffix_id_; + + // Flattened table representing the Trie structure + // Each row contains: [piece_length, token_id, score, piece_id] + std::vector> table_; +}; + +struct llm_tokenizer_plamo2_session { + llm_tokenizer_plamo2_session(const llm_tokenizer_plamo2 & tokenizer) : tokenizer(tokenizer) {} + + void tokenize(const std::string & text, std::vector & output) { + std::vector tokens = tokenizer.encode(text); + output.insert(output.end(), tokens.begin(), tokens.end()); } - switch (vocab.type) { - case LLAMA_VOCAB_TYPE_SPM: - { - // OG tokenizer behavior: - // - // tokenizer.encode('', add_special_tokens=True) returns [1] - // tokenizer.encode('', add_special_tokens=False) returns [] +private: + const llm_tokenizer_plamo2 & tokenizer; +}; - bool is_prev_special = true; // prefix with space if first token +// +// impl +// - if (add_special && vocab.tokenizer_add_bos) { - GGML_ASSERT(vocab.special_bos_id != -1); - output.push_back(vocab.special_bos_id); - is_prev_special = true; - } +typedef enum FRAGMENT_BUFFER_VARIANT_TYPE { + FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN, + FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT +} FRAGMENT_BUFFER_VARIANT_TYPE; - for (const auto & fragment : fragment_buffer) { - if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { - auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length); +struct fragment_buffer_variant { + fragment_buffer_variant(llama_token _token) + : + type(FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN), + token(_token), + raw_text(_dummy), + offset(0), + length(0) {} - // prefix with space if previous is special - if (vocab.tokenizer_add_space_prefix && is_prev_special) { - raw_text = " " + raw_text; - } + fragment_buffer_variant(const std::string & _raw_text, int64_t _offset, int64_t _length) + : + type(FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT), + token((llama_token) - 1), + raw_text(_raw_text), + offset(_offset), + length(_length){ + GGML_ASSERT(_offset >= 0); + GGML_ASSERT(_length >= 1); + GGML_ASSERT(offset + length <= raw_text.length()); + } -#ifdef PRETOKENIZERDEBUG - LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str()); -#endif - llm_tokenizer_spm tokenizer(vocab); - llama_escape_whitespace(raw_text); - tokenizer.tokenize(raw_text, output); - is_prev_special = false; - } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN) - output.push_back(fragment.token); - is_prev_special = true; - } - } + const FRAGMENT_BUFFER_VARIANT_TYPE type; + const llama_token token; + const std::string _dummy; + const std::string & 
raw_text; + const uint64_t offset; + const uint64_t length; +}; - if (add_special && vocab.tokenizer_add_bos && output.size() >= 2 && output[1] == vocab.special_bos_id) { - LLAMA_LOG_WARN( - "%s: Added a BOS token to the prompt as specified by the model but the prompt " - "also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. " - "Are you sure this is what you want?\n", __FUNCTION__); - } +struct llama_vocab::impl { + uint32_t n_token_types = 0; // for BERT-style token types + + std::string tokenizer_model; + std::string tokenizer_pre; + + enum llama_vocab_type type = LLAMA_VOCAB_TYPE_SPM; + enum llama_vocab_pre_type pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT; + + int max_token_len = 0; // used for optimizing longest token search + + // default LLaMA special tokens + // TODO: should we set all of these to LLAMA_TOKEN_NULL? + llama_token special_bos_id = 1; + llama_token special_eos_id = 2; + llama_token special_eot_id = LLAMA_TOKEN_NULL; + llama_token special_eom_id = LLAMA_TOKEN_NULL; + llama_token special_unk_id = 0; + llama_token special_sep_id = LLAMA_TOKEN_NULL; + llama_token special_pad_id = LLAMA_TOKEN_NULL; + llama_token special_mask_id = LLAMA_TOKEN_NULL; + + llama_token linefeed_id = 13; + + // fim tokens + llama_token special_fim_pre_id = LLAMA_TOKEN_NULL; + llama_token special_fim_suf_id = LLAMA_TOKEN_NULL; + llama_token special_fim_mid_id = LLAMA_TOKEN_NULL; + llama_token special_fim_pad_id = LLAMA_TOKEN_NULL; + llama_token special_fim_rep_id = LLAMA_TOKEN_NULL; // repo + llama_token special_fim_sep_id = LLAMA_TOKEN_NULL; // file separator + + // tokenizer flags + bool add_space_prefix = false; + bool add_bos = false; + bool add_eos = false; + bool add_sep = false; + bool ignore_merges = false; + bool clean_spaces = false; // clean_up_tokenization_spaces + bool remove_extra_whitespaces = false; + bool escape_whitespaces = true; + bool treat_whitespace_as_suffix = false; + + std::unordered_map token_to_id; + std::vector id_to_token; + + std::vector cache_special_tokens; + std::vector cache_token_to_piece; // llama_token_to_piece(special = true); + struct pair_hash { + size_t operator()(const std::pair & p) const { + return std::hash{}(p.first) ^ //create some hash for pair + (std::hash{}(p.second) << 1); + } + }; + std::unordered_map, int, pair_hash> bpe_ranks; - if (add_special && vocab.tokenizer_add_eos) { - GGML_ASSERT(vocab.special_eos_id != -1); - output.push_back(vocab.special_eos_id); - } - } break; - case LLAMA_VOCAB_TYPE_BPE: - { - llm_tokenizer_bpe tokenizer(vocab); + // set of all tokens that cause "end of generation" + std::set special_eog_ids; - if (add_special) { - tokenizer.append_bos(output); - } - for (const auto & fragment : fragment_buffer) { - if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { - auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length); + std::unique_ptr tokenizer; -#ifdef PRETOKENIZERDEBUG - LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str()); -#endif - tokenizer.tokenize(raw_text, output); - } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN) - tokenizer.append(fragment.token, output); - } - } + std::vector precompiled_charsmap; - if (add_special) { - tokenizer.append_eos(output); - tokenizer.check_double_bos_eos(output); + impl(const llama_vocab & vocab) : vocab(vocab) { + } + + ~impl() = default; + + void load(llama_model_loader & ml, const LLM_KV & kv); + + enum llama_vocab_type get_type() const; + + 
std::string type_name() const; + + bool is_normal (llama_token id) const; + bool is_unknown (llama_token id) const; + bool is_control (llama_token id) const; + bool is_byte (llama_token id) const; + bool is_user_defined(llama_token id) const; + bool is_unused (llama_token id) const; + bool is_eog (llama_token id) const; + + uint8_t token_to_byte(llama_token id) const; + + llama_token_attr token_get_attr(llama_token id) const; + + void init_tokenizer(enum llama_vocab_type type); + + void tokenizer_st_partition(std::forward_list & buffer, bool parse_special) const; + + std::string token_to_piece_for_cache( + llama_token token, + bool special) const; + + + std::vector tokenize( + const std::string & raw_text, + bool add_special, + bool parse_special = false) const; + + int32_t tokenize( + const char * text, + int32_t text_len, + llama_token * tokens, + int32_t n_tokens_max, + bool add_special, + bool parse_special) const; + + // does not write null-terminator to buf + int32_t token_to_piece( + llama_token token, + char * buf, + int32_t length, + int32_t lstrip, + bool special) const; + + // use cached data + const std::string & token_to_piece(llama_token token) const; + + int32_t detokenize( + const llama_token * tokens, + int32_t n_tokens, + char * text, + int32_t text_len_max, + bool remove_special, + bool unparse_special) const; + + std::string detokenize( + const std::vector & tokens, + bool special) const; + + void print_info() const; + +private: + const llama_vocab & vocab; +}; + +void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { + gguf_context * ctx = ml.meta; + + // determine vocab type + { + ml.get_key(LLM_KV_TOKENIZER_MODEL, tokenizer_model); + ml.get_key(LLM_KV_TOKENIZER_PRE, tokenizer_pre, false); + + ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, n_token_types, false); + + if (tokenizer_model == "no_vocab" || tokenizer_model == "none") { + type = LLAMA_VOCAB_TYPE_NONE; + + // default special tokens + special_bos_id = LLAMA_TOKEN_NULL; + special_eos_id = LLAMA_TOKEN_NULL; + special_unk_id = LLAMA_TOKEN_NULL; + special_sep_id = LLAMA_TOKEN_NULL; + special_pad_id = LLAMA_TOKEN_NULL; + special_mask_id = LLAMA_TOKEN_NULL; + linefeed_id = LLAMA_TOKEN_NULL; + + // read vocab size from metadata + uint32_t n_tokens = 0; + if (ml.get_key(LLM_KV_VOCAB_SIZE, n_tokens, false)) { + LLAMA_LOG_WARN("%s: adding %u dummy tokens\n", __func__, n_tokens); + id_to_token.resize(n_tokens); + } + + return; + } + + if (tokenizer_model == "llama") { + type = LLAMA_VOCAB_TYPE_SPM; + + // default special tokens + special_bos_id = 1; + special_eos_id = 2; + special_unk_id = 0; + special_sep_id = LLAMA_TOKEN_NULL; + special_pad_id = LLAMA_TOKEN_NULL; + special_mask_id = LLAMA_TOKEN_NULL; + } else if (tokenizer_model == "bert") { + type = LLAMA_VOCAB_TYPE_WPM; + + // default special tokens + special_bos_id = 101; + special_eos_id = LLAMA_TOKEN_NULL; + special_unk_id = 100; + special_sep_id = 102; + special_pad_id = 0; + special_mask_id = 103; + + add_sep = true; + } else if (tokenizer_model == "gpt2") { + type = LLAMA_VOCAB_TYPE_BPE; + + // read bpe merges and populate bpe ranks + const int merges_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_MERGES).c_str()); + if (merges_keyidx == -1) { + throw std::runtime_error("cannot find tokenizer merges in model file\n"); + } + + const int n_merges = gguf_get_arr_n(ctx, merges_keyidx); + for (int i = 0; i < n_merges; i++) { + const std::string word = gguf_get_arr_str(ctx, merges_keyidx, i); + //GGML_ASSERT(unicode_cpts_from_utf8(word).size() > 
0); + + std::string first; + std::string second; + + const size_t pos = word.find(' ', 1); + + if (pos != std::string::npos) { + first = word.substr(0, pos); + second = word.substr(pos + 1); } - } break; - case LLAMA_VOCAB_TYPE_WPM: - { - if (add_special) { - GGML_ASSERT(vocab.special_cls_id != -1); - output.push_back(vocab.special_cls_id); + + bpe_ranks.emplace(std::make_pair(first, second), i); + } + + // default special tokens + special_bos_id = 11; + special_eos_id = 11; + special_unk_id = LLAMA_TOKEN_NULL; + special_sep_id = LLAMA_TOKEN_NULL; + special_pad_id = LLAMA_TOKEN_NULL; + special_mask_id = LLAMA_TOKEN_NULL; + } else if (tokenizer_model == "t5") { + type = LLAMA_VOCAB_TYPE_UGM; + + // default special tokens + special_bos_id = LLAMA_TOKEN_NULL; + special_eos_id = 1; + special_unk_id = 2; + special_sep_id = LLAMA_TOKEN_NULL; + special_pad_id = 0; + special_mask_id = LLAMA_TOKEN_NULL; + + const int precompiled_charsmap_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP).c_str()); + if (precompiled_charsmap_keyidx != -1) { + const gguf_type pc_type = gguf_get_arr_type(ctx, precompiled_charsmap_keyidx); + GGML_ASSERT(pc_type == GGUF_TYPE_INT8 || pc_type == GGUF_TYPE_UINT8); + + const size_t n_precompiled_charsmap = gguf_get_arr_n(ctx, precompiled_charsmap_keyidx); + const char * pc = (const char *) gguf_get_arr_data(ctx, precompiled_charsmap_keyidx); + precompiled_charsmap.assign(pc, pc + n_precompiled_charsmap); +#ifdef IS_BIG_ENDIAN + // correct endiannes of data in precompiled_charsmap binary blob + uint32_t * xcda_blob_size = (uint32_t *) &precompiled_charsmap[0]; + *xcda_blob_size = __builtin_bswap32(*xcda_blob_size); + assert(*xcda_blob_size + sizeof(uint32_t) < n_precompiled_charsmap); + size_t xcda_array_size = *xcda_blob_size / sizeof(uint32_t); + uint32_t * xcda_array = (uint32_t *) &precompiled_charsmap[sizeof(uint32_t)]; + for (size_t i = 0; i < xcda_array_size; ++i) { + xcda_array[i] = __builtin_bswap32(xcda_array[i]); } +#endif + } + } else if (tokenizer_model == "rwkv") { + type = LLAMA_VOCAB_TYPE_RWKV; + + // default special tokens + special_bos_id = LLAMA_TOKEN_NULL; + special_eos_id = LLAMA_TOKEN_NULL; + special_unk_id = LLAMA_TOKEN_NULL; + special_sep_id = LLAMA_TOKEN_NULL; + special_pad_id = LLAMA_TOKEN_NULL; + } else if (tokenizer_model == "plamo2") { + type = LLAMA_VOCAB_TYPE_PLAMO2; + + // PLaMo-2 default special tokens (these will be overridden by model config) + special_bos_id = 1; // <|plamo:bos|> + special_eos_id = 2; // <|plamo:eos|> + special_unk_id = 0; // <|plamo:unk|> + special_sep_id = LLAMA_TOKEN_NULL; + special_pad_id = 3; // <|plamo:pad|> + special_mask_id = LLAMA_TOKEN_NULL; + } else { + throw std::runtime_error(format("unknown tokenizer: '%s'", tokenizer_model.c_str())); + } + + // for now, only BPE models have pre-tokenizers + if (type == LLAMA_VOCAB_TYPE_BPE) { + add_space_prefix = false; + clean_spaces = true; + if (tokenizer_pre.empty()) { + LLAMA_LOG_WARN("%s: missing pre-tokenizer type, using: 'default'\n", __func__); + LLAMA_LOG_WARN("%s: \n", __func__); + LLAMA_LOG_WARN("%s: ************************************ \n", __func__); + LLAMA_LOG_WARN("%s: GENERATION QUALITY WILL BE DEGRADED! 
\n", __func__); + LLAMA_LOG_WARN("%s: CONSIDER REGENERATING THE MODEL \n", __func__); + LLAMA_LOG_WARN("%s: ************************************ \n", __func__); + LLAMA_LOG_WARN("%s: \n", __func__); + pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT; + } else if (tokenizer_pre == "default") { + pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT; + } else if ( + tokenizer_pre == "llama3" || + tokenizer_pre == "llama-v3" || + tokenizer_pre == "llama-bpe"|| + tokenizer_pre == "falcon3" || + tokenizer_pre == "falcon-h1" || + tokenizer_pre == "pixtral" || + tokenizer_pre == "midm-2.0" || + tokenizer_pre == "lfm2") { + pre_type = LLAMA_VOCAB_PRE_TYPE_LLAMA3; + ignore_merges = true; + add_bos = true; + } else if ( + tokenizer_pre == "deepseek-llm") { + pre_type = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM; + clean_spaces = false; + } else if ( + tokenizer_pre == "deepseek-coder") { + pre_type = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER; + clean_spaces = false; + } else if ( + tokenizer_pre == "deepseek-v3") { + pre_type = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM; + clean_spaces = false; + } else if ( + tokenizer_pre == "falcon") { + pre_type = LLAMA_VOCAB_PRE_TYPE_FALCON; + } else if ( + tokenizer_pre == "mpt") { + pre_type = LLAMA_VOCAB_PRE_TYPE_MPT; + } else if ( + tokenizer_pre == "starcoder") { + pre_type = LLAMA_VOCAB_PRE_TYPE_STARCODER; + } else if ( + tokenizer_pre == "gpt-2" || + tokenizer_pre == "phi-2" || + tokenizer_pre == "jina-es" || + tokenizer_pre == "jina-de" || + tokenizer_pre == "gigachat" || + tokenizer_pre == "jina-v2-es" || + tokenizer_pre == "jina-v2-de" || + tokenizer_pre == "a.x-4.0" || + tokenizer_pre == "mellum") { + pre_type = LLAMA_VOCAB_PRE_TYPE_GPT2; + } else if ( + tokenizer_pre == "jina-v1-en" || + tokenizer_pre == "jina-v2-code" || + tokenizer_pre == "roberta-bpe") { + pre_type = LLAMA_VOCAB_PRE_TYPE_GPT2; + add_sep = true; + } else if ( + tokenizer_pre == "refact") { + pre_type = LLAMA_VOCAB_PRE_TYPE_REFACT; + } else if ( + tokenizer_pre == "command-r") { + pre_type = LLAMA_VOCAB_PRE_TYPE_COMMAND_R; + clean_spaces = false; + } else if ( + tokenizer_pre == "qwen2" || + tokenizer_pre == "deepseek-r1-qwen") { + pre_type = LLAMA_VOCAB_PRE_TYPE_QWEN2; + clean_spaces = false; + } else if ( + tokenizer_pre == "stablelm2") { + pre_type = LLAMA_VOCAB_PRE_TYPE_STABLELM2; + } else if ( + tokenizer_pre == "olmo") { + pre_type = LLAMA_VOCAB_PRE_TYPE_OLMO; + } else if ( + tokenizer_pre == "dbrx") { + pre_type = LLAMA_VOCAB_PRE_TYPE_DBRX; + } else if ( + tokenizer_pre == "smaug-bpe") { + pre_type = LLAMA_VOCAB_PRE_TYPE_SMAUG; + } else if ( + tokenizer_pre == "poro-chat") { + pre_type = LLAMA_VOCAB_PRE_TYPE_PORO; + clean_spaces = false; + } else if ( + tokenizer_pre == "glm4" || + tokenizer_pre == "chatglm-bpe") { + pre_type = LLAMA_VOCAB_PRE_TYPE_CHATGLM4; + special_bos_id = LLAMA_TOKEN_NULL; + } else if ( + tokenizer_pre == "viking") { + pre_type = LLAMA_VOCAB_PRE_TYPE_VIKING; + clean_spaces = false; + } else if ( + tokenizer_pre == "jais") { + pre_type = LLAMA_VOCAB_PRE_TYPE_JAIS; + } else if ( + tokenizer_pre == "tekken") { + pre_type = LLAMA_VOCAB_PRE_TYPE_TEKKEN; + clean_spaces = false; + ignore_merges = true; + add_bos = true; + } else if ( + tokenizer_pre == "smollm") { + pre_type = LLAMA_VOCAB_PRE_TYPE_SMOLLM; + clean_spaces = false; + } else if ( + tokenizer_pre == "codeshell") { + pre_type = LLAMA_VOCAB_PRE_TYPE_CODESHELL; + } else if ( + tokenizer_pre == "bloom") { + pre_type = LLAMA_VOCAB_PRE_TYPE_BLOOM; + } else if ( + tokenizer_pre == "gpt3-finnish") { + pre_type = LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH; + } 
else if ( + tokenizer_pre == "exaone") { + pre_type = LLAMA_VOCAB_PRE_TYPE_EXAONE; + } else if ( + tokenizer_pre == "exaone4") { + pre_type = LLAMA_VOCAB_PRE_TYPE_GPT2; + } else if ( + tokenizer_pre == "chameleon") { + pre_type = LLAMA_VOCAB_PRE_TYPE_CHAMELEON; + add_bos = true; + clean_spaces = false; + } else if ( + tokenizer_pre == "minerva-7b") { + pre_type = LLAMA_VOCAB_PRE_TYPE_MINERVA; + } else if ( + tokenizer_pre == "megrez") { + pre_type = LLAMA_VOCAB_PRE_TYPE_QWEN2; + } else if ( + tokenizer_pre == "gpt-4o" || + tokenizer_pre == "llama4") { + pre_type = LLAMA_VOCAB_PRE_TYPE_GPT4O; + clean_spaces = false; + } else if ( + tokenizer_pre == "superbpe") { + pre_type = LLAMA_VOCAB_PRE_TYPE_SUPERBPE; + clean_spaces = false; + } else if ( + tokenizer_pre == "trillion") { + pre_type = LLAMA_VOCAB_PRE_TYPE_TRILLION; + clean_spaces = false; + } else if ( + tokenizer_pre == "bailingmoe") { + pre_type = LLAMA_VOCAB_PRE_TYPE_BAILINGMOE; + clean_spaces = false; + } else if ( + tokenizer_pre == "seed-coder") { + pre_type = LLAMA_VOCAB_PRE_TYPE_SEED_CODER; + clean_spaces = false; + } else if ( + tokenizer_pre == "hunyuan") { + pre_type = LLAMA_VOCAB_PRE_TYPE_HUNYUAN; + clean_spaces = false; + } else if ( + tokenizer_pre == "hunyuan-dense") { + pre_type = LLAMA_VOCAB_PRE_TYPE_HUNYUAN_DENSE; + clean_spaces = false; + } else if ( + tokenizer_pre == "kimi-k2") { + pre_type = LLAMA_VOCAB_PRE_TYPE_KIMI_K2; + clean_spaces = false; + } else { + throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str())); + } + } else if (type == LLAMA_VOCAB_TYPE_SPM) { + pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT; + add_space_prefix = true; + clean_spaces = false; + add_bos = true; + add_eos = false; + } else if (type == LLAMA_VOCAB_TYPE_WPM) { + pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT; + add_space_prefix = false; + clean_spaces = true; + add_bos = true; + add_eos = false; + add_sep = true; + } else if (type == LLAMA_VOCAB_TYPE_UGM) { + pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT; + add_bos = false; + add_eos = true; + } else if (type == LLAMA_VOCAB_TYPE_RWKV) { + pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT; + add_space_prefix = false; + clean_spaces = false; + add_bos = false; + add_eos = false; + } else { + pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT; + } - llm_tokenizer_wpm tokenizer(vocab); + ml.get_key(LLM_KV_TOKENIZER_ADD_PREFIX, add_space_prefix, false); + ml.get_key(LLM_KV_TOKENIZER_REMOVE_EXTRA_WS, remove_extra_whitespaces, false); + } - for (const auto & fragment : fragment_buffer) { - if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { - auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length); + const int token_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_LIST).c_str()); + if (token_idx == -1) { + throw std::runtime_error("cannot find tokenizer vocab in model file\n"); + } -#ifdef PRETOKENIZERDEBUG - LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str()); -#endif - tokenizer.tokenize(raw_text, output); - } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN) - output.push_back(fragment.token); + const float * scores = nullptr; + const int score_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_SCORES).c_str()); + if (score_idx != -1) { + scores = (const float * ) gguf_get_arr_data(ctx, score_idx); + } + + const int * toktypes = nullptr; + const int toktype_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_TOKEN_TYPE).c_str()); + if (toktype_idx != -1) { + toktypes = (const int * ) 
gguf_get_arr_data(ctx, toktype_idx); + } + + uint32_t n_tokens = gguf_get_arr_n(ctx, token_idx); + id_to_token.resize(n_tokens); + + for (uint32_t i = 0; i < n_tokens; i++) { + std::string word = gguf_get_arr_str(ctx, token_idx, i); + if (word.empty()) { + LLAMA_LOG_WARN("%s: empty token at index %u\n", __func__, i); + word = "[EMPTY_" + std::to_string(i) + "]"; + } + + token_to_id[word] = i; + max_token_len = std::max(max_token_len, (int) word.size()); + + auto & token_data = id_to_token[i]; + token_data.text = std::move(word); + token_data.score = scores ? scores[i] : 0.0f; + token_data.attr = LLAMA_TOKEN_ATTR_NORMAL; + + if (toktypes) { //TODO: remove, required until per token attributes are available from GGUF file + switch(toktypes[i]) { + case LLAMA_TOKEN_TYPE_UNKNOWN: token_data.attr = LLAMA_TOKEN_ATTR_UNKNOWN; break; + case LLAMA_TOKEN_TYPE_UNUSED: token_data.attr = LLAMA_TOKEN_ATTR_UNUSED; break; + case LLAMA_TOKEN_TYPE_NORMAL: token_data.attr = LLAMA_TOKEN_ATTR_NORMAL; break; + case LLAMA_TOKEN_TYPE_CONTROL: token_data.attr = LLAMA_TOKEN_ATTR_CONTROL; break; + case LLAMA_TOKEN_TYPE_USER_DEFINED: token_data.attr = LLAMA_TOKEN_ATTR_USER_DEFINED; break; + case LLAMA_TOKEN_TYPE_BYTE: token_data.attr = LLAMA_TOKEN_ATTR_BYTE; break; + case LLAMA_TOKEN_TYPE_UNDEFINED: token_data.attr = LLAMA_TOKEN_ATTR_UNDEFINED; break; + default: token_data.attr = LLAMA_TOKEN_ATTR_UNDEFINED; break; + } + } + } + GGML_ASSERT(id_to_token.size() == token_to_id.size()); + + init_tokenizer(type); + + // determine the newline token: LLaMA "<0x0A>" == 10 == '\n', Falcon 193 == '\n' + if (type == LLAMA_VOCAB_TYPE_SPM) { + try { + linefeed_id = vocab.byte_to_token('\n'); + } catch (const std::exception & e) { + LLAMA_LOG_WARN("%s: SPM vocabulary, but newline token not found: %s! 
Using special_pad_id instead.", __func__, e.what()); + linefeed_id = special_pad_id; + } + } else if (type == LLAMA_VOCAB_TYPE_WPM) { + linefeed_id = special_pad_id; + } else if (type == LLAMA_VOCAB_TYPE_RWKV) { + const std::vector ids = tokenize("\n", false); + GGML_ASSERT(!ids.empty() && "model vocab missing newline token"); + linefeed_id = ids[0]; + } else { + const std::vector ids = tokenize("\n", false); + + //GGML_ASSERT(!ids.empty() && "model vocab missing newline token"); + if (ids.empty()) { + LLAMA_LOG_WARN("%s: model vocab missing newline token, using special_pad_id instead\n", __func__); + linefeed_id = special_pad_id; + } else { + linefeed_id = ids[0]; + } + } + + // special tokens + { + const std::vector> special_token_types = { + { LLM_KV_TOKENIZER_BOS_ID, special_bos_id }, + { LLM_KV_TOKENIZER_EOS_ID, special_eos_id }, + { LLM_KV_TOKENIZER_EOT_ID, special_eot_id }, + { LLM_KV_TOKENIZER_EOM_ID, special_eom_id }, + { LLM_KV_TOKENIZER_UNK_ID, special_unk_id }, + { LLM_KV_TOKENIZER_SEP_ID, special_sep_id }, + { LLM_KV_TOKENIZER_PAD_ID, special_pad_id }, + { LLM_KV_TOKENIZER_MASK_ID, special_mask_id }, + { LLM_KV_TOKENIZER_FIM_PRE_ID, special_fim_pre_id }, + { LLM_KV_TOKENIZER_FIM_SUF_ID, special_fim_suf_id }, + { LLM_KV_TOKENIZER_FIM_MID_ID, special_fim_mid_id }, + { LLM_KV_TOKENIZER_FIM_PAD_ID, special_fim_pad_id }, + { LLM_KV_TOKENIZER_FIM_REP_ID, special_fim_rep_id }, + { LLM_KV_TOKENIZER_FIM_SEP_ID, special_fim_sep_id }, + + // deprecated + { LLM_KV_TOKENIZER_PREFIX_ID, special_fim_pre_id }, + { LLM_KV_TOKENIZER_SUFFIX_ID, special_fim_suf_id }, + { LLM_KV_TOKENIZER_MIDDLE_ID, special_fim_mid_id }, + }; + + for (const auto & it : special_token_types) { + const std::string & key = kv(std::get<0>(it)); + int32_t & id = std::get<1>(it); + + uint32_t new_id; + if (!ml.get_key(std::get<0>(it), new_id, false)) { + continue; + } + if (new_id >= id_to_token.size()) { + LLAMA_LOG_WARN("%s: bad special token: '%s' = %u, using default id %d\n", + __func__, key.c_str(), new_id, id); + } else { + id = new_id; + } + } + + // Handle add_bos, add_eos and add_sep + { + bool temp = true; + + if (ml.get_key(LLM_KV_TOKENIZER_ADD_BOS, temp, false)) { + add_bos = temp; + } + if (ml.get_key(LLM_KV_TOKENIZER_ADD_EOS, temp, false)) { + add_eos = temp; + } + if (ml.get_key(LLM_KV_TOKENIZER_ADD_SEP, temp, false)) { + add_sep = temp; + } + } + + // auto-detect special tokens by text + // TODO: convert scripts should provide these tokens through the KV metadata LLM_KV_TOKENIZER_... + // for now, we apply this workaround to find the tokens based on their text + + for (const auto & t : token_to_id) { + // find EOT token: "<|eot_id|>", "<|im_end|>", "", etc. + if (special_eot_id == LLAMA_TOKEN_NULL) { + if (false + || t.first == "<|eot_id|>" + || t.first == "<|im_end|>" + || t.first == "<|end|>" + || t.first == "" + || t.first == "<|endoftext|>" + || t.first == "" + || t.first == "_" + || t.first == "<|end▁of▁sentence|>" // DeepSeek + || t.first == "" // smoldocling + ) { + special_eot_id = t.second; + if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { + LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. 
its type will be overridden\n", + __func__, t.second, t.first.c_str()); + id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL; } } + } - if (add_special) { - GGML_ASSERT(vocab.special_sep_id != -1); - output.push_back(vocab.special_sep_id); + // find EOM token: "<|eom_id|>" + if (special_eom_id == LLAMA_TOKEN_NULL) { + if (false + || t.first == "<|eom_id|>" + ) { + special_eom_id = t.second; + if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { + LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n", + __func__, t.second, t.first.c_str()); + id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL; + } } - } break; - case LLAMA_VOCAB_TYPE_UGM: - { - llm_tokenizer_ugm tokenizer(vocab); + } - if (add_special && vocab.tokenizer_add_bos != 0) { - GGML_ASSERT(vocab.special_bos_id != -1); - output.push_back(vocab.special_bos_id); + // find FIM_PRE token: "<|fim_prefix|>", "", "
", etc.
+            if (special_fim_pre_id == LLAMA_TOKEN_NULL) {
+                if (false
+                        || t.first == "<|fim_prefix|>"  // Qwen
+                        || t.first == ""
+                        || t.first == ""    // Granite
+                        || t.first == "<|fim▁begin|>" // DeepSeek
+                        || t.first == "
"
+                        || t.first == "▁
"          // CodeLlama
+                        || t.first == "<|code_prefix|>" // GLM-4.5
+                        ) {
+                    special_fim_pre_id = t.second;
+                    if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
+                        id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                    }
                 }
+            }
 
-                for (const auto & fragment : fragment_buffer) {
-                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
-                        auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
-#ifdef PRETOKENIZERDEBUG
-                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
-#endif
-                        tokenizer.tokenize(raw_text, output);
-                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
-                        output.push_back(fragment.token);
+            // find FIM_SUF token: "<|fim_suffix|>", "", "", etc.
+            if (special_fim_suf_id == LLAMA_TOKEN_NULL) {
+                if (false
+                        || t.first == "<|fim_suffix|>" // Qwen
+                        || t.first == ""
+                        || t.first == ""   // Granite
+                        || t.first == "<|fim▁hole|>" // DeepSeek
+                        || t.first == ""
+                        || t.first == "▁"         // CodeLlama
+                        || t.first == "<|code_suffix|>" // GLM-4.5
+                        ) {
+                    special_fim_suf_id = t.second;
+                    if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
+                        id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
                     }
                 }
+            }
 
-                if (add_special && vocab.tokenizer_add_bos != 0 && output.size() >= 2 && output[1] == vocab.special_bos_id) {
-                    LLAMA_LOG_WARN(
-                        "%s: Added a BOS token to the prompt as specified by the model but the prompt "
-                        "also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. "
-                        "Are you sure this is what you want?\n", __FUNCTION__);
+            // find FIM_MID token: "<|fim_middle|>", "", "", etc.
+            if (special_fim_mid_id == LLAMA_TOKEN_NULL) {
+                if (false
+                        || t.first == "<|fim_middle|>" // Qwen
+                        || t.first == ""
+                        || t.first == ""   // Granite
+                        || t.first == "<|fim▁end|>"  // DeepSeek
+                        || t.first == ""
+                        || t.first == "▁"         // CodeLlama
+                        || t.first == "<|code_middle|>" // GLM-4.5
+                        ) {
+                    special_fim_mid_id = t.second;
+                    if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
+                        id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                    }
+                }
+            }
+
+            // find FIM_PAD token: "<|fim_pad|>", "", "", etc.
+            if (special_fim_pad_id == LLAMA_TOKEN_NULL) {
+                if (false
+                        || t.first == "<|fim_pad|>" // Qwen
+                        || t.first == ""
+                        || t.first == ""   // Granite
+                        || t.first == ""
+                        ) {
+                    special_fim_pad_id = t.second;
+                    if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
+                        id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                    }
+                }
+            }
+
+            // find FIM_REP token: "<|fim_repo|>", "", "", etc.
+            if (special_fim_rep_id == LLAMA_TOKEN_NULL) {
+                if (false
+                        || t.first == "<|fim_repo|>"  // Qwen
+                        || t.first == "<|repo_name|>"
+                        || t.first == ""
+                        || t.first == ""
+                        || t.first == ""    // Granite
+                        ) {
+                    special_fim_rep_id = t.second;
+                    if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
+                        id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                    }
+                }
+            }
+
+            // find FIM_SEP token: "<|file_sep|>"
+            if (special_fim_sep_id == LLAMA_TOKEN_NULL) {
+                if (false
+                        || t.first == "<|file_sep|>" // Qwen
+                        ) {
+                    special_fim_sep_id = t.second;
+                    if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
+                        id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                    }
+                }
+            }
+        }
+
+        // maintain a list of tokens that cause end-of-generation
+        // this is currently determined based on the token text, which is obviously not ideal
+        // ref: https://github.com/ggerganov/llama.cpp/issues/9606
+        special_eog_ids.clear();
+
+        if (special_fim_pad_id != LLAMA_TOKEN_NULL && special_eog_ids.count(special_fim_pad_id) == 0) {
+            special_eog_ids.insert(special_fim_pad_id);
+        }
+
+        if (special_fim_rep_id != LLAMA_TOKEN_NULL && special_eog_ids.count(special_fim_rep_id) == 0) {
+            special_eog_ids.insert(special_fim_rep_id);
+        }
+
+        if (special_fim_sep_id != LLAMA_TOKEN_NULL && special_eog_ids.count(special_fim_sep_id) == 0) {
+            special_eog_ids.insert(special_fim_sep_id);
+        }
+
+        for (const auto & t : token_to_id) {
+            if (false
+                    || t.first == "<|eot_id|>"
+                    || t.first == "<|im_end|>"
+                    || t.first == "<|end|>"
+                    || t.first == "<|return|>" // o200k_harmony
+                    || t.first == "<|call|>"   // o200k_harmony
+                    || t.first == ""
+                    || t.first == "<|endoftext|>"
+                    || t.first == "<|eom_id|>"
+                    || t.first == ""
+                    || t.first == "_"
+                    || t.first == "<|end_of_text|>"
+                    || t.first == "" // smoldocling
+               ) {
+                special_eog_ids.insert(t.second);
+                if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                    LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                            __func__, t.second, t.first.c_str());
+                    id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                }
+            } else {
+                // token is control, but not marked as EOG -> print a debug log
+                if (id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL && special_eog_ids.count(t.second) == 0) {
+                    LLAMA_LOG_DEBUG("%s: control token: %6d '%s' is not marked as EOG\n",
+                            __func__, t.second, t.first.c_str());
+                }
+            }
+        }
+
+        // @ngxson : quick hack for gpt-oss, always render these tokens
+        for (const auto & t : token_to_id) {
+            if (t.first == "<|channel|>" || t.first == "<|message|>" || t.first == "<|start|>") {
+                id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_USER_DEFINED;
+            }
+        }
+
+        // sanity checks
+        if (special_eos_id != LLAMA_TOKEN_NULL && special_eog_ids.count(special_eos_id) == 0) {
+            special_eog_ids.insert(special_eos_id);
+            LLAMA_LOG_WARN("%s: special_eos_id is not in special_eog_ids - the tokenizer config may be incorrect\n", __func__);
+        }
+
+        if (special_eot_id != LLAMA_TOKEN_NULL && special_eog_ids.count(special_eot_id) == 0) {
+            special_eog_ids.insert(special_eot_id);
+            LLAMA_LOG_WARN("%s: special_eot_id is not in special_eog_ids - the tokenizer config may be incorrect\n", __func__);
+        }
+
+        if (special_eom_id != LLAMA_TOKEN_NULL && special_eog_ids.count(special_eom_id) == 0) {
+            special_eog_ids.insert(special_eom_id);
+            LLAMA_LOG_WARN("%s: special_eom_id is not in special_eog_ids - the tokenizer config may be incorrect\n", __func__);
+        }
+
+        // TODO: workaround for o200k_harmony tokenizer: the "<|end|>" token should not be EOG
+        //       we don't have a good way to detect this, so for now, if we have "<|return|>" and "<|call|>" tokens,
+        //       we remove the "<|end|>" token from the EOG list
+        {
+            bool has_return = false;
+            bool has_call   = false;
+            bool has_end    = false;
+
+            llama_token end_id = LLAMA_TOKEN_NULL;
+
+            LLAMA_LOG_INFO("%s: printing all EOG tokens:\n", __func__);
+            for (auto tid : special_eog_ids) {
+                LLAMA_LOG_INFO("%s:   - %d ('%s')\n", __func__, tid, id_to_token[tid].text.c_str());
+
+                if (id_to_token[tid].text == "<|return|>") {
+                    has_return = true;
+                } else if (id_to_token[tid].text == "<|call|>") {
+                    has_call = true;
+                } else if (id_to_token[tid].text == "<|end|>") {
+                    has_end = true;
+                    end_id = tid;
+                }
+            }
+
+            if (has_return && has_call && has_end) {
+                special_eog_ids.erase(end_id);
+                LLAMA_LOG_WARN("%s: special_eog_ids contains both '<|return|>' and '<|call|>' tokens, removing '<|end|>' token from EOG list\n", __func__);
+            }
+        }
+    }
+
+    // build special tokens cache
+    {
+        for (llama_token id = 0; id < (llama_token) n_tokens; ++id) {
+            if (id_to_token[id].attr & (LLAMA_TOKEN_ATTR_CONTROL | LLAMA_TOKEN_ATTR_USER_DEFINED | LLAMA_TOKEN_ATTR_UNKNOWN)) {
+                cache_special_tokens.push_back(id);
+            }
+        }
+
+        std::sort(cache_special_tokens.begin(), cache_special_tokens.end(),
+            [&] (const llama_token a, const llama_token b) {
+                return id_to_token[a].text.size() > id_to_token[b].text.size();
+            }
+        );
+
+        LLAMA_LOG_INFO("%s: special tokens cache size = %u\n", __func__, (uint32_t) cache_special_tokens.size());
+    }
+
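>
> The cache above is sorted by descending token-text length. That ordering matters later, in `tokenizer_st_partition()`: if a special token's text contains another special token's text, the longer one has to get a chance to match before the shorter one splits the fragment. A minimal standalone sketch of the effect (not part of the patch; both token strings are hypothetical):
>
> ```cpp
> // Why longest-first matters: "<|a|><|b|>" contains "<|a|>", so it must be tried first.
> #include <algorithm>
> #include <iostream>
> #include <string>
> #include <vector>
>
> int main() {
>     std::vector<std::string> specials = { "<|a|>", "<|a|><|b|>" };  // hypothetical special tokens
>
>     // same ordering rule as the cache above: longest text first
>     std::sort(specials.begin(), specials.end(),
>               [](const std::string & a, const std::string & b) { return a.size() > b.size(); });
>
>     const std::string text = "hello<|a|><|b|>world";
>     for (const auto & s : specials) {
>         const auto pos = text.find(s);
>         if (pos != std::string::npos) {
>             std::cout << "first match: '" << s << "' at offset " << pos << "\n";
>             break;  // longest-first: the combined token wins instead of just "<|a|>"
>         }
>     }
>     return 0;
> }
> ```
>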
+    // build token to piece cache
+    {
+        size_t size_cache = 0;
+
+        std::vector cache(n_tokens);
+
+        for (uint32_t id = 0; id < n_tokens; ++id) {
+            cache[id] = token_to_piece_for_cache(id, true);
+
+            size_cache += cache[id].size();
+        }
+
+        std::swap(cache_token_to_piece, cache);
+
+        LLAMA_LOG_INFO("%s: token to piece cache size = %.4f MB\n", __func__, size_cache / 1024.0 / 1024.0);
+    }
+
+    // Handle per token attributes
+    //NOTE: Each model customizes per token attributes.
+    //NOTE: Per token attributes are missing from the GGUF file.
+    //TODO: Extract attributes from GGUF file.
+    {
+        auto _contains_any = [] (const std::string & str, const std::vector & substrs) -> bool {
+            for (const auto & substr : substrs) {
+                if (str.find(substr) != std::string::npos) {
+                    return true;
+                }
+            }
+            return false;
+        };
+
+        auto _set_tokenid_attr = [&] (const llama_token id, llama_token_attr attr, bool value) {
+            uint32_t current = id_to_token.at(id).attr;
+            current = value ? (current | attr) : (current & ~attr);
+            id_to_token[id].attr = (llama_token_attr) current;
+        };
+
+        auto _set_token_attr = [&] (const std::string & token, llama_token_attr attr, bool value) {
+            _set_tokenid_attr(token_to_id.at(token), attr, value);
+        };
+
+        std::string model_name;
+        std::string tokenizer_pre;
+        std::string general_arch;
+
+        ml.get_key(LLM_KV_GENERAL_NAME,  model_name,    false);
+        ml.get_key(LLM_KV_TOKENIZER_PRE, tokenizer_pre, false);
+        ml.get_key(LLM_KV_GENERAL_ARCHITECTURE, general_arch, false);
+
+        // model name to lowercase
+        std::transform(model_name.begin(), model_name.end(), model_name.begin(),
+            [] (const std::string::value_type x) {
+                return std::tolower(x);
+            }
+        );
+
+        // set attributes by model/tokenizer/architecture name
+        if (false
+                || _contains_any(tokenizer_pre, {"jina-v2-de", "jina-v2-es", "jina-v2-code"})
+                || _contains_any(general_arch, {"nomic-bert-moe"})
+           ) {
+            if (token_to_id.count("") == 0) {
+                LLAMA_LOG_WARN("%s: Mask token is missing in vocab, please reconvert model!\n", __func__);
+            } else {
+                _set_token_attr("", LLAMA_TOKEN_ATTR_LSTRIP, true);
+            }
+        } else if (_contains_any(model_name, {"phi-3", "phi3"})) {
+            for (auto id : cache_special_tokens) {
+                _set_tokenid_attr(id, LLAMA_TOKEN_ATTR_RSTRIP, true);
+            }
+            for (const auto * token : {""}) {
+                _set_token_attr(token, LLAMA_TOKEN_ATTR_RSTRIP, true);
+            }
+            for (const auto * token : {"", "", "<|endoftext|>"}) {
+                _set_token_attr(token, LLAMA_TOKEN_ATTR_RSTRIP, false);
+            }
+        }
+    }
+}
+
+enum llama_vocab_type llama_vocab::impl::get_type() const {
+    return type;
+}
+
+std::string llama_vocab::impl::type_name() const{
+    switch (type) {
+        case LLAMA_VOCAB_TYPE_NONE:   return "no vocab";
+        case LLAMA_VOCAB_TYPE_SPM:    return "SPM";
+        case LLAMA_VOCAB_TYPE_BPE:    return "BPE";
+        case LLAMA_VOCAB_TYPE_WPM:    return "WPM";
+        case LLAMA_VOCAB_TYPE_UGM:    return "UGM";
+        case LLAMA_VOCAB_TYPE_RWKV:   return "RWKV";
+        case LLAMA_VOCAB_TYPE_PLAMO2: return "PLaMo2";
+        default:                      return "unknown";
+    }
+}
+
+bool llama_vocab::impl::is_normal(llama_token id) const {
+    GGML_ASSERT(type != LLAMA_VOCAB_TYPE_NONE);
+    return id_to_token[id].attr & LLAMA_TOKEN_ATTR_NORMAL;
+}
+
+bool llama_vocab::impl::is_unknown(llama_token id) const {
+    GGML_ASSERT(type != LLAMA_VOCAB_TYPE_NONE);
+    return id_to_token[id].attr & LLAMA_TOKEN_ATTR_UNKNOWN;
+}
+
+bool llama_vocab::impl::is_control(llama_token id) const {
+    GGML_ASSERT(type != LLAMA_VOCAB_TYPE_NONE);
+    return id_to_token[id].attr & LLAMA_TOKEN_ATTR_CONTROL;
+}
+
+bool llama_vocab::impl::is_byte(llama_token id) const {
+    GGML_ASSERT(type != LLAMA_VOCAB_TYPE_NONE);
+    return id_to_token[id].attr & LLAMA_TOKEN_ATTR_BYTE;
+}
+
+bool llama_vocab::impl::is_user_defined(llama_token id) const {
+    GGML_ASSERT(type != LLAMA_VOCAB_TYPE_NONE);
+    return id_to_token[id].attr & LLAMA_TOKEN_ATTR_USER_DEFINED;
+}
+
+bool llama_vocab::impl::is_unused(llama_token id) const {
+    GGML_ASSERT(type != LLAMA_VOCAB_TYPE_NONE);
+    return id_to_token[id].attr & LLAMA_TOKEN_ATTR_UNUSED;
+}
+
+bool llama_vocab::impl::is_eog(llama_token id) const {
+    return id != LLAMA_TOKEN_NULL && special_eog_ids.count(id) > 0;
+}
+
+uint8_t llama_vocab::impl::token_to_byte(llama_token id) const {
+    GGML_ASSERT(get_type() != LLAMA_VOCAB_TYPE_NONE);
+    GGML_ASSERT(is_byte(id));
+    const auto & token_data = id_to_token.at(id);
+    switch (get_type()) {
+        case LLAMA_VOCAB_TYPE_SPM:
+        case LLAMA_VOCAB_TYPE_UGM: {
+            auto buf = token_data.text.substr(3, 2);
+            return strtol(buf.c_str(), NULL, 16);
+        }
+        case LLAMA_VOCAB_TYPE_BPE: {
+            GGML_ABORT("fatal error");
+        }
+        case LLAMA_VOCAB_TYPE_WPM: {
+            GGML_ABORT("fatal error");
+        }
+        default:
+            GGML_ABORT("fatal error");
+    }
+}
+
+llama_token_attr llama_vocab::impl::token_get_attr(llama_token id) const {
+    GGML_ASSERT(type != LLAMA_VOCAB_TYPE_NONE);
+    return id_to_token.at(id).attr;
+}
+
+void llama_vocab::impl::init_tokenizer(enum llama_vocab_type type) {
+    LLAMA_LOG_DEBUG("%s: initializing tokenizer for type %d\n", __func__, type);
+
+    switch (type) {
+        case LLAMA_VOCAB_TYPE_SPM:
+            tokenizer = std::make_unique(vocab);
+            break;
+        case LLAMA_VOCAB_TYPE_BPE:
+            tokenizer = std::make_unique(vocab);
+            break;
+        case LLAMA_VOCAB_TYPE_WPM:
+            tokenizer = std::make_unique(vocab);
+            break;
+        case LLAMA_VOCAB_TYPE_UGM:
+            tokenizer = std::make_unique(vocab, precompiled_charsmap);
+            break;
+        case LLAMA_VOCAB_TYPE_RWKV:
+            tokenizer = std::make_unique(vocab);
+            break;
+        case LLAMA_VOCAB_TYPE_PLAMO2:
+            tokenizer = std::make_unique(vocab);
+            break;
+        default:
+            GGML_ABORT("unsupported vocab type");
+    }
+}
+
+//
+// (de-) tokenize
+//
+
+// #define PRETOKENIZERDEBUG
+
+void llama_vocab::impl::tokenizer_st_partition(std::forward_list & buffer, bool parse_special) const {
+    // for each special token
+    for (const llama_token special_id : cache_special_tokens) {
+        const auto & data = vocab.get_token_data(special_id);
+        const auto & text = data.text;
+
+        if (!parse_special && (data.attr & (LLAMA_TOKEN_ATTR_CONTROL | LLAMA_TOKEN_ATTR_UNKNOWN))) {
+            // Ignore control and unknown tokens when parse_special == false
+            continue;
+            // User-defined tokens are still pre-tokenized before everything else
+            // ref: https://github.com/huggingface/tokenizers/blob/fdd26ba9a3f0c133427aab0423888cbde91362d7/tokenizers/src/tokenizer/mod.rs#L726
+            // This is mostly relevant for neox-style tokenizers (mpt, olmo, stablelm, etc.)
+        }
+
+        // for each text fragment
+        std::forward_list::iterator it = buffer.begin();
+        while (it != buffer.end()) {
+            auto & fragment = (*it);
+
+            // if a fragment is text ( not yet processed )
+            if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
+                const auto & raw_text = fragment.raw_text;
+
+                auto raw_text_base_offset = fragment.offset;
+                auto raw_text_base_length = fragment.length;
+
+                // loop over the text
+                while (true) {
+                    // find the first occurrence of a given special token in this fragment
+                    //  passing offset argument only limit the "search area" but match coordinates
+                    //  are still relative to the source full raw_text
+                    //  string_view begins at pos 0 for the same reason
+                    auto match = std::string_view(raw_text.data(), raw_text_base_offset + raw_text_base_length).find(text, raw_text_base_offset);
+
+                    // no occurrences found, stop processing this fragment for a given special token
+                    if (match == std::string::npos) break;
+
+#ifdef PRETOKENIZERDEBUG
+                    LLAMA_LOG_WARN("FF: (%ld %ld %ld) '%s'\n", raw_text->length(), raw_text_base_offset, raw_text_base_length, raw_text->substr(raw_text_base_offset, raw_text_base_length).c_str());
+#endif
+                    auto source = std::distance(buffer.begin(), it);
+
+                    // if match is further than base offset
+                    //  then we have some text to the left of it
+                    if (match > raw_text_base_offset) {
+                        // left
+                        const int64_t left_reminder_offset = raw_text_base_offset + 0;
+                        int64_t left_reminder_length = match - raw_text_base_offset;
+
+                        if (data.attr & LLAMA_TOKEN_ATTR_LSTRIP) {
+                            while (left_reminder_length > 0 && isspace(raw_text[left_reminder_offset + left_reminder_length - 1])) {
+                                left_reminder_length--;
+                            }
+                        }
+
+                        if (left_reminder_length > 0) {
+                            buffer.emplace_after(it, raw_text, left_reminder_offset, left_reminder_length);
+                            it++;
+                        }
+
+#ifdef PRETOKENIZERDEBUG
+                        LLAMA_LOG_WARN("FL: (%ld %ld) '%s'\n", left_reminder_offset, left_reminder_length, raw_text->substr(left_reminder_offset, left_reminder_length).c_str());
+#endif
+                    }
+
+                    // special token
+                    buffer.emplace_after(it, special_id);
+                    it++;
+
+                    // right
+                    if (match + text.length() < raw_text_base_offset + raw_text_base_length) {
+                        int64_t right_reminder_offset = match + text.length();
+                        int64_t right_reminder_length = raw_text_base_length - ((match - raw_text_base_offset) + text.length());
+
+                        if (data.attr & LLAMA_TOKEN_ATTR_RSTRIP) {
+                            while (right_reminder_length > 0 && isspace(raw_text[right_reminder_offset])) {
+                                right_reminder_offset++;
+                                right_reminder_length--;
+                            }
+                        }
+
+                        if (right_reminder_length > 0) {
+                            buffer.emplace_after(it, raw_text, right_reminder_offset, right_reminder_length);
+                            it++;
+                        }
+
+#ifdef PRETOKENIZERDEBUG
+                        LLAMA_LOG_WARN("FR: (%ld %ld) '%s'\n", right_reminder_offset, right_reminder_length, raw_text->substr(right_reminder_offset, right_reminder_length).c_str());
+#endif
+
+                        if (source == 0) {
+                            buffer.erase_after(buffer.before_begin());
+                        } else {
+                            buffer.erase_after(std::next(buffer.begin(), (source - 1)));
+                        }
+
+                        // repeat for the right side
+                        raw_text_base_offset = right_reminder_offset;
+                        raw_text_base_length = right_reminder_length;
+
+#ifdef PRETOKENIZERDEBUG
+                        LLAMA_LOG_WARN("RR: (%ld %ld) '%s'\n", raw_text_base_offset, raw_text_base_length, raw_text->substr(raw_text_base_offset, raw_text_base_length).c_str());
+#endif
+                    } else {
+                        if (source == 0) {
+                            buffer.erase_after(buffer.before_begin());
+                        } else {
+                            buffer.erase_after(std::next(buffer.begin(), (source - 1)));
+                        }
+                        break;
+                    }
                 }
+            }
+            it++;
+        }
+    }
+}
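>
> For readers following the fragment bookkeeping above, here is a minimal standalone sketch of the same partitioning idea, reduced to a single special token and a plain variant list (not part of the patch; the token text and id are made up):
>
> ```cpp
> // Split raw text into "raw text" and "special token" fragments around one special token.
> #include <iostream>
> #include <string>
> #include <variant>
> #include <vector>
>
> int main() {
>     const std::string special_text = "<|sep|>";  // hypothetical special token text
>     const int         special_id   = 42;         // hypothetical token id
>
>     const std::string raw = "left<|sep|>right<|sep|>";
>
>     // a fragment is either a span of raw text or an already-resolved token id
>     std::vector<std::variant<std::string, int>> fragments;
>
>     size_t pos = 0;
>     while (pos < raw.size()) {
>         const size_t match = raw.find(special_text, pos);
>         if (match == std::string::npos) {
>             fragments.emplace_back(raw.substr(pos));  // trailing raw text
>             break;
>         }
>         if (match > pos) {
>             fragments.emplace_back(raw.substr(pos, match - pos));  // raw text before the token
>         }
>         fragments.emplace_back(special_id);                        // the special token itself
>         pos = match + special_text.size();
>     }
>
>     for (const auto & f : fragments) {
>         if (std::holds_alternative<int>(f)) {
>             std::cout << "[token " << std::get<int>(f) << "]\n";
>         } else {
>             std::cout << "'" << std::get<std::string>(f) << "'\n";
>         }
>     }
>     return 0;
> }
> ```
>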
+
+// NOTE: avoid ever using this except for building the token_to_piece caches
+std::string llama_vocab::impl::token_to_piece_for_cache(llama_token token, bool special) const {
+    std::string piece;
+    piece.resize(piece.capacity());  // using string internal cache
+    const int n_chars = vocab.token_to_piece(token, &piece[0], piece.size(), 0, special);
+    if (n_chars < 0) {
+        piece.resize(-n_chars);
+        int check = vocab.token_to_piece(token, &piece[0], piece.size(), 0, special);
+        GGML_ASSERT(check == -n_chars);
+    }
+    else {
+        piece.resize(n_chars);
+    }
+
+    return piece;
+}
+
+static void llama_escape_whitespace(std::string & text) {
+    replace_all(text, " ", "\xe2\x96\x81");
+}
+
+static void llama_unescape_whitespace(std::string & word) {
+    replace_all(word, "\xe2\x96\x81", " ");
+}
+
+static std::string llama_decode_text(const std::string & text) {
+    std::string decoded_text;
+
+    const auto cpts = unicode_cpts_from_utf8(text);
+    for (const auto cpt : cpts) {
+        const auto utf8 = unicode_cpt_to_utf8(cpt);
+        try {
+            decoded_text += unicode_utf8_to_byte(utf8);
+        } catch (const std::out_of_range & /*e*/) {
+            decoded_text += "[UNK_BYTE_0x";
+            for (const auto c : utf8) {
+                decoded_text += format("%02x", (uint8_t) c);
+            }
+            decoded_text += text + "]";
+        }
+    }
+
+    return decoded_text;
+}
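>
> `llama_decode_text()` undoes byte-level BPE: every codepoint of the stored token text is mapped back to a raw byte through `unicode_utf8_to_byte()`. As a rough standalone illustration only (not part of the patch), assuming the usual GPT-2 style byte-to-codepoint table; the real mapping used here lives behind `unicode_utf8_to_byte()` / `unicode_byte_to_utf8()`:
>
> ```cpp
> // GPT-2 style byte <-> codepoint table: printable bytes keep their own codepoint,
> // every other byte is shifted into the 256+ range so token text stays valid text.
> #include <cstdint>
> #include <iostream>
> #include <map>
>
> int main() {
>     std::map<uint8_t, uint32_t> byte_to_cpt;  // byte -> codepoint (encode direction)
>     uint32_t next = 0;
>     for (int b = 0; b < 256; ++b) {
>         const bool printable = (b >= 33 && b <= 126) || (b >= 161 && b <= 172) || (b >= 174 && b <= 255);
>         byte_to_cpt[(uint8_t) b] = printable ? (uint32_t) b : 256 + next++;
>     }
>
>     std::map<uint32_t, uint8_t> cpt_to_byte;  // codepoint -> byte (decode direction)
>     for (const auto & e : byte_to_cpt) {
>         cpt_to_byte[e.second] = e.first;
>     }
>
>     // e.g. a space (0x20) is outside the printable ranges, so it gets remapped:
>     std::cout << "byte 0x20 -> codepoint U+" << std::hex << byte_to_cpt[0x20] << "\n";  // U+120 ('Ġ')
>     return 0;
> }
> ```
>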
+
+std::vector llama_vocab::impl::tokenize(
+        const std::string & raw_text,
+        bool add_special,
+        bool parse_special) const {
+    GGML_ASSERT(tokenizer && "Tokenizer not initialized. Call llama_vocab::init_tokenizer() first.");
+
+    std::vector output;
+    std::forward_list fragment_buffer;
+
+    if (!raw_text.empty()) {
+        fragment_buffer.emplace_front(raw_text, 0, raw_text.length());
+        tokenizer_st_partition(fragment_buffer, parse_special);
+    }
+
+    switch (get_type()) {
+        case LLAMA_VOCAB_TYPE_SPM:
+            {
+                // OG tokenizer behavior:
+                //
+                // tokenizer.encode('', add_special_tokens=True)  returns [1]
+                // tokenizer.encode('', add_special_tokens=False) returns []
+
+                bool is_prev_special = true;  // prefix with space if first token
+
+                if (add_special && add_bos) {
+                    GGML_ASSERT(special_bos_id != LLAMA_TOKEN_NULL);
+                    output.push_back(special_bos_id);
+                    is_prev_special = true;
+                }
+
+                for (const auto & fragment : fragment_buffer) {
+                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
+                        std::string text;
+
+                        // prefix with space if previous is special
+                        if (add_space_prefix && is_prev_special) {
+                            text = ' ';
+                        }
+
+                        text += fragment.raw_text.substr(fragment.offset, fragment.length);
+
+#ifdef PRETOKENIZERDEBUG
+                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", text.length(), fragment.offset, fragment.length, text.c_str());
+#endif
+                        llama_escape_whitespace(text);
+                        llm_tokenizer_spm_session session(vocab);
+                        session.tokenize(text, output);
+                        is_prev_special = false;
+                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
+                        output.push_back(fragment.token);
+                        is_prev_special = true;
+                    }
+                }
+
+                if (add_special && add_bos && output.size() >= 2 && output[1] == special_bos_id) {
+                    LLAMA_LOG_WARN(
+                        "%s: Added a BOS token to the prompt as specified by the model but the prompt "
+                        "also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. "
+                        "Are you sure this is what you want?\n", __FUNCTION__);
+                }
+
+                if (add_special && add_eos) {
+                    GGML_ASSERT(special_eos_id != LLAMA_TOKEN_NULL);
+                    output.push_back(special_eos_id);
+                }
+            } break;
+        case LLAMA_VOCAB_TYPE_BPE:
+            {
+                llm_tokenizer_bpe_session session(vocab, *static_cast(tokenizer.get()));
+                // the session needs BPE-specific methods that do not exist on the base llm_tokenizer,
+                // so the stored tokenizer is cast to the BPE tokenizer object here
+                if (add_special) {
+                    session.append_bos(output);
+                }
+                for (const auto & fragment : fragment_buffer) {
+                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
+                        std::string text = fragment.raw_text.substr(fragment.offset, fragment.length);
+
+#ifdef PRETOKENIZERDEBUG
+                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", text.length(), fragment.offset, fragment.length, text.c_str());
+#endif
+                        session.tokenize(text, output);
+                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
+                        session.append(fragment.token, output);
+                    }
+                }
+
+                if (add_special) {
+                    session.append_eos(output);
+                    session.check_double_bos_eos(output);
+                }
+            } break;
+        case LLAMA_VOCAB_TYPE_WPM:
+            {
+                if (add_special) {
+                    GGML_ASSERT(special_bos_id != LLAMA_TOKEN_NULL);
+                    output.push_back(special_bos_id);
+                }
+
+                llm_tokenizer_wpm_session session(vocab);
+
+                for (const auto & fragment : fragment_buffer) {
+                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
+                        std::string text = fragment.raw_text.substr(fragment.offset, fragment.length);
+
+#ifdef PRETOKENIZERDEBUG
+                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", text.length(), fragment.offset, fragment.length, text.c_str());
+#endif
+                        session.tokenize(text, output);
+                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
+                        output.push_back(fragment.token);
+                    }
+                }
+
+                if (add_special) {
+                    GGML_ASSERT(special_sep_id != LLAMA_TOKEN_NULL);
+                    output.push_back(special_sep_id);
+                }
+            } break;
+        case LLAMA_VOCAB_TYPE_UGM:
+            {
+                if (add_special && add_bos) {
+                    GGML_ASSERT(special_bos_id != LLAMA_TOKEN_NULL);
+                    output.push_back(special_bos_id);
+                }
+                llm_tokenizer_ugm_session session(vocab, *static_cast(tokenizer.get()));
+
+                for (const auto & fragment : fragment_buffer) {
+                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
+                        std::string text = fragment.raw_text.substr(fragment.offset, fragment.length);
+#ifdef PRETOKENIZERDEBUG
+                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", text.length(), fragment.offset, fragment.length, text.c_str());
+#endif
+                        session.tokenize(text, output);
+                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
+                        output.push_back(fragment.token);
+                    }
+                }
+
+                if (add_special && add_bos && output.size() >= 2 && output[1] == special_bos_id) {
+                    LLAMA_LOG_WARN(
+                        "%s: Added a BOS token to the prompt as specified by the model but the prompt "
+                        "also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. "
+                        "Are you sure this is what you want?\n", __FUNCTION__);
+                }
+
+                if (add_special && add_eos) {
+                    GGML_ASSERT(special_eos_id != LLAMA_TOKEN_NULL);
+                    output.push_back(special_eos_id);
+                }
+            } break;
+        case LLAMA_VOCAB_TYPE_RWKV:
+            {
+                llm_tokenizer_rwkv_session session(vocab, *static_cast(tokenizer.get()));
+                for (const auto & fragment : fragment_buffer) {
+                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
+                        std::string text = fragment.raw_text.substr(fragment.offset, fragment.length);
+
+#ifdef PRETOKENIZERDEBUG
+                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", text.length(), fragment.offset, fragment.length, text.c_str());
+#endif
+
+                        session.tokenize(text, output);
+                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
+                        output.push_back(fragment.token);
+                    }
+                }
+            } break;
+        case LLAMA_VOCAB_TYPE_PLAMO2:
+            {
+                llm_tokenizer_plamo2_session session(*static_cast(tokenizer.get()));
+                for (const auto & fragment : fragment_buffer) {
+                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
+                        std::string text = fragment.raw_text.substr(fragment.offset, fragment.length);
+
+#ifdef PRETOKENIZERDEBUG
+                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", text.length(), fragment.offset, fragment.length, text.c_str());
+#endif
+
+                        session.tokenize(text, output);
+                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
+                        output.push_back(fragment.token);
+                    }
+                }
+            } break;
+        case LLAMA_VOCAB_TYPE_NONE:
+            GGML_ABORT("fatal error");
+    }
+
+    return output;
+}
+
+int32_t llama_vocab::impl::token_to_piece(llama_token token, char * buf, int32_t length, int32_t lstrip, bool special) const {
+    // ref: https://github.com/ggerganov/llama.cpp/pull/7587#discussion_r1620983843
+    static const int attr_special = LLAMA_TOKEN_ATTR_UNKNOWN | LLAMA_TOKEN_ATTR_CONTROL;
+    const llama_token_attr attr = token_get_attr(token);
+    if (!special && (attr & attr_special)) {
+        return 0;
+    }
+
+    // copy piece chars to output text buffer
+    // skip up to 'lstrip' leading spaces before copying
+    auto _try_copy = [=] (const char * token, size_t size) -> int32_t {
+        if (size >= static_cast(std::numeric_limits::max())) {
+            GGML_ABORT("invalid token size: %zu exceeds int32_t limit", size);
+        }
+
+        for (int32_t i = 0; i < lstrip && size && *token == ' '; ++i) {
+            token++;
+            size--;
+        }
+        if (length < (int32_t)size) {
+            return -(int32_t) size;
+        }
+        memcpy(buf, token, size);
+        return (int32_t) size;
+    };
+
+    // if we have a cache - use it
+    {
+        const auto & cache = cache_token_to_piece;
+
+        if (!cache.empty()) {
+            const auto & result = cache.at(token);
+            return _try_copy(result.data(), result.size());
+        }
+    }
+
+    if (0 <= token && token < (int32_t) id_to_token.size()) {
+        const std::string & token_text = id_to_token[token].text;
+        switch (get_type()) {
+            case LLAMA_VOCAB_TYPE_WPM:
+            case LLAMA_VOCAB_TYPE_SPM:
+            case LLAMA_VOCAB_TYPE_UGM: {
+                // NOTE: we accept all unsupported token types,
+                // suppressing them like CONTROL tokens.
+                if (attr & (attr_special | LLAMA_TOKEN_ATTR_USER_DEFINED)) {
+                    return _try_copy(token_text.data(), token_text.size());
+                }
+                if (attr & LLAMA_TOKEN_ATTR_NORMAL) {
+                    std::string result = token_text;
+                    llama_unescape_whitespace(result);
+                    return _try_copy(result.data(), result.size());
+                }
+                if (attr & LLAMA_TOKEN_ATTR_BYTE) {
+                    char byte = (char) token_to_byte(token);
+                    return _try_copy((char*) &byte, 1);
+                }
+                break;
+            }
+            case LLAMA_VOCAB_TYPE_BPE: {
+                // NOTE: we accept all unsupported token types,
+                // suppressing them like CONTROL tokens.
+                if (attr & (attr_special | LLAMA_TOKEN_ATTR_USER_DEFINED)) {
+                    return _try_copy(token_text.data(), token_text.size());
+                }
+                if (attr & LLAMA_TOKEN_ATTR_NORMAL) {
+                    std::string result = llama_decode_text(token_text);
+                    return _try_copy(result.data(), result.size());
+                }
+                break;
+            }
+            case LLAMA_VOCAB_TYPE_RWKV: {
+                std::vector result = llama_unescape_rwkv_token(token_text);
+
+                // If we don't have enough space, return an error
+                if (result.size() > (size_t)length) {
+                    return -(int)result.size();
+                }
+
+                memcpy(buf, result.data(), result.size());
+                return (int)result.size();
+            }
+            case LLAMA_VOCAB_TYPE_PLAMO2: {
+                // PLaMo-2 uses similar token handling as BPE/SPM
+                if (vocab.is_byte(token)) {
+                    // Handle byte tokens like <0xXX>
+                    if (token_text.length() == 6 && token_text.substr(0, 3) == "<0x" && token_text.back() == '>') {
+                        int hex_val = std::stoi(token_text.substr(3, 2), nullptr, 16);
+                        if (length < 1) {
+                            return -1;
+                        }
+                        buf[0] = static_cast(hex_val);
+                        return 1;
+                    }
+                }
+
+                // Normal token - just copy the text
+                std::string result = token_text;
+                return _try_copy(result.data(), result.size());
+            }
+            default:
+                GGML_ABORT("fatal error");
+        }
+    }
+
+    return 0;
+}
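>
> A small usage sketch (not part of the patch) of the sizing convention above: a negative return value means the buffer was too small, and its magnitude is the number of bytes required, so callers resize and retry. The helper name is hypothetical and the snippet assumes the `llama_vocab` / `llama_token` declarations from this file's headers are visible; the call itself mirrors the one used by `token_to_piece_for_cache()` earlier.
>
> ```cpp
> #include <algorithm>
> #include <cstdint>
> #include <string>
>
> // hypothetical helper: render one token as a std::string
> std::string piece_of(const llama_vocab & vocab, llama_token token) {
>     std::string buf(8, '\0');  // deliberately small first guess
>     int32_t n = vocab.token_to_piece(token, &buf[0], (int32_t) buf.size(), /*lstrip=*/0, /*special=*/true);
>     if (n < 0) {
>         buf.resize(-n);        // the call reported how many bytes it needs
>         n = vocab.token_to_piece(token, &buf[0], (int32_t) buf.size(), 0, true);
>     }
>     buf.resize(std::max<int32_t>(n, 0));
>     return buf;
> }
> ```
>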
+
+const std::string & llama_vocab::impl::token_to_piece(llama_token token) const {
+    return cache_token_to_piece.at(token);
+}
+
+int32_t llama_vocab::impl::detokenize(
+               const llama_token * tokens,
+                         int32_t   n_tokens,
+                            char * text,
+                         int32_t   text_len_max,
+                            bool   remove_special,
+                            bool   unparse_special) const {
+    if (type == LLAMA_VOCAB_TYPE_NONE) {
+        return 0;
+    }
+
+    GGML_ASSERT(tokenizer && "Tokenizer not initialized. Call llama_vocab::init_tokenizer() first.");
+
+    int32_t avail = text_len_max;
+    int32_t total = 0;
+
+    // remove the leading space
+    bool remove_space = add_space_prefix;
+
+    if (remove_special && add_bos) {
+        if (n_tokens > 0 && tokens[0] == special_bos_id) {
+            remove_space = false;
+            n_tokens--;
+            tokens++;
+        }
+    }
+
+    if (remove_special && add_eos) {
+        if (n_tokens > 0 && tokens[n_tokens - 1] == special_eos_id) {
+            n_tokens--;
+        }
+    }
+
+    for (int32_t i = 0; i < n_tokens; ++i) {
+        GGML_ASSERT(avail >= 0);
+        int32_t n_chars = token_to_piece(tokens[i], text, avail, remove_space, unparse_special);
+        remove_space = false;
+        if (n_chars < 0) {
+            avail = 0;
+            total -= n_chars;
+        } else if (n_chars > 0) {
+            avail -= n_chars;
+            text  += n_chars;
+            total += n_chars;
+        }
+    }
+
+    if (total > text_len_max) {
+        return -total;
+    }
+
+    if (clean_spaces) {
+        text -= total;  // restart text
+
+        // first pass: characters ?!.,  //TODO: where do these characters come from?
+        const int32_t total1 = total;
+        total = total ? 1 : 0;
+        for (int32_t i = 1; i < total1; ++i) {
+            const char x = text[i];
+            if (text[i - 1] == ' ') {
+                if (x == '?' || x == '!' || x == '.' || x == ',') {  // " ?", " !", " .", " ,"
+                    total--;  // remove space
+                }
+            }
+            text[total++] = x;
+        }
+
+        // second pass: strip single apostrophe between spaces
+        const int32_t total2 = total;
+        total = total ? 1 : 0;
+        for (int32_t i = 1; i < total2; ++i) {
+            const char x = text[i];
+            if (x == '\'' && i + 1 < total2 && text[i - 1] == ' ' && text[i + 1] == ' ') {  // " ' "
+                total--;           // remove prev space
+                text[++i] = '\0';  // remove next space
+            }
+            text[total++] = x;
+        }
+
+        // third pass: apostrophe contractions  //NOTE: this makes sense?
+        const int32_t total3 = total;
+        total = total ? 1 : 0;
+        for (int32_t i = 1; i < total3; ++i) {
+            const char x = text[i];
+            if (text[i - 1] == ' ') {
+                if (x == '\'' && i + 1 < total3) {
+                    const char x1 = text[i + 1];
+                    if (x1 == 't' || x1 == 'd') {  // " 't", " 'd"
+                        //total--;  // remove space
+                    } else if (x1 == 's' || x1 == 'm') {  // " 's", " 'm"
+                        total--;  // remove space
+                    } else if (i + 2 < total3) {
+                        const char x2 = text[i + 2];
+                        if ((x1 == 'l' && x2 == 'l')) {  // " 'll"
+                            //total--;  // remove space
+                        } else if ((x1 == 'r' && x2 == 'e') || (x1 == 'v' && x2 == 'e')) {  // " 're", " 've"
+                            total--;  // remove space
+                        } else {
+                            //total--;  // remove space
+                        }
+                    } else {
+                        //total--;  // remove space
+                    }
+                }
+            }
+            text[total++] = x;
+        }
+    }
+
+    return total <= text_len_max ? total : -total;
+}
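>
> The `clean_spaces` passes above are easier to see on a concrete string. A minimal sketch (not part of the patch), using `std::string` instead of the in-place char buffer, of the first pass and of the contraction subset that actually drops the space (" 's", " 'm", " 're", " 've"); " 't", " 'd" and " 'll" keep it:
>
> ```cpp
> #include <iostream>
> #include <string>
>
> // local helper for the sketch: replace every occurrence of `from` with `to`
> static void replace_all_in(std::string & s, const std::string & from, const std::string & to) {
>     for (size_t pos = 0; (pos = s.find(from, pos)) != std::string::npos; pos += to.size()) {
>         s.replace(pos, from.size(), to);
>     }
> }
>
> int main() {
>     std::string text = "Hello , world ! It 's fine , is n't it ?";
>
>     // pass 1: drop the space before ? ! . ,
>     for (const char * p : { " ,", " !", " .", " ?" }) {
>         replace_all_in(text, p, std::string(1, p[1]));
>     }
>     // pass 3 (subset): " 's", " 'm", " 're", " 've" lose the space
>     for (const char * p : { " 's", " 'm", " 're", " 've" }) {
>         replace_all_in(text, p, std::string(p).substr(1));
>     }
>
>     std::cout << text << "\n";  // "Hello, world! It's fine, is n't it?"
>     return 0;
> }
> ```
>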
+
+void llama_vocab::impl::print_info() const {
+    LLAMA_LOG_INFO("%s: vocab type       = %s\n",     __func__, type_name().c_str());
+    LLAMA_LOG_INFO("%s: n_vocab          = %u\n",     __func__, vocab.n_tokens());
+    LLAMA_LOG_INFO("%s: n_merges         = %u\n",     __func__, (uint32_t) bpe_ranks.size());
+
+    // special tokens
+    if (special_bos_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: BOS token        = %d '%s'\n", __func__, special_bos_id,     id_to_token.at(special_bos_id).text.c_str() );  }
+    if (special_eos_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: EOS token        = %d '%s'\n", __func__, special_eos_id,     id_to_token.at(special_eos_id).text.c_str() );  }
+    if (special_eot_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: EOT token        = %d '%s'\n", __func__, special_eot_id,     id_to_token.at(special_eot_id).text.c_str() );  }
+    if (special_eom_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: EOM token        = %d '%s'\n", __func__, special_eom_id,     id_to_token.at(special_eom_id).text.c_str() );  }
+    if (special_unk_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: UNK token        = %d '%s'\n", __func__, special_unk_id,     id_to_token.at(special_unk_id).text.c_str() );  }
+    if (special_sep_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: SEP token        = %d '%s'\n", __func__, special_sep_id,     id_to_token.at(special_sep_id).text.c_str() );  }
+    if (special_pad_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: PAD token        = %d '%s'\n", __func__, special_pad_id,     id_to_token.at(special_pad_id).text.c_str() );  }
+    if (special_mask_id != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: MASK token       = %d '%s'\n", __func__, special_mask_id,    id_to_token.at(special_mask_id).text.c_str() ); }
+
+    if (linefeed_id != LLAMA_TOKEN_NULL)        { LLAMA_LOG_INFO( "%s: LF token         = %d '%s'\n", __func__, linefeed_id,        id_to_token.at(linefeed_id).text.c_str() ); }
+
+    if (special_fim_pre_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM PRE token    = %d '%s'\n", __func__, special_fim_pre_id, id_to_token.at(special_fim_pre_id).text.c_str() ); }
+    if (special_fim_suf_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM SUF token    = %d '%s'\n", __func__, special_fim_suf_id, id_to_token.at(special_fim_suf_id).text.c_str() ); }
+    if (special_fim_mid_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM MID token    = %d '%s'\n", __func__, special_fim_mid_id, id_to_token.at(special_fim_mid_id).text.c_str() ); }
+    if (special_fim_pad_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM PAD token    = %d '%s'\n", __func__, special_fim_pad_id, id_to_token.at(special_fim_pad_id).text.c_str() ); }
+    if (special_fim_rep_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM REP token    = %d '%s'\n", __func__, special_fim_rep_id, id_to_token.at(special_fim_rep_id).text.c_str() ); }
+    if (special_fim_sep_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM SEP token    = %d '%s'\n", __func__, special_fim_sep_id, id_to_token.at(special_fim_sep_id).text.c_str() ); }
+
+    for (const auto & id : special_eog_ids) {
+        LLAMA_LOG_INFO( "%s: EOG token        = %d '%s'\n", __func__, id, id_to_token.at(id).text.c_str() );
+    }
+
+    LLAMA_LOG_INFO("%s: max token length = %d\n", __func__, max_token_len);
+}
+
+llama_vocab::llama_vocab() : pimpl(new impl(*this)) {
+}
+
+llama_vocab::~llama_vocab() {
+}
+
+void llama_vocab::load(llama_model_loader & ml, const LLM_KV & kv) {
+    pimpl->load(ml, kv);
+}
+
+std::string llama_vocab::get_tokenizer_model() const {
+    return pimpl->tokenizer_model;
+}
+
+std::string llama_vocab::get_tokenizer_pre() const {
+    return pimpl->tokenizer_pre;
+}
+
+enum llama_vocab_type llama_vocab::get_type() const {
+    return pimpl->type;
+}
+
+enum llama_vocab_pre_type llama_vocab::get_pre_type() const {
+    return pimpl->pre_type;
+}
+
+uint32_t llama_vocab::n_tokens() const {
+    return (uint32_t) pimpl->id_to_token.size();
+}
+
+uint32_t llama_vocab::n_token_types() const {
+    return (uint32_t) pimpl->n_token_types;
+}
+
+std::string llama_vocab::type_name() const{
+    return pimpl->type_name();
+}
+
+bool llama_vocab::is_normal(llama_token id) const {
+    return pimpl->is_normal(id);
+}
+
+bool llama_vocab::is_unknown(llama_token id) const {
+    return pimpl->is_unknown(id);
+}
+
+bool llama_vocab::is_control(llama_token id) const {
+    return pimpl->is_control(id);
+}
+
+bool llama_vocab::is_byte(llama_token id) const {
+    return pimpl->is_byte(id);
+}
+
+bool llama_vocab::is_user_defined(llama_token id) const {
+    return pimpl->is_user_defined(id);
+}
+
+bool llama_vocab::is_unused(llama_token id) const {
+    return pimpl->is_unused(id);
+}
+
+bool llama_vocab::is_eog(llama_token id) const {
+    return pimpl->is_eog(id);
+}
+
+uint8_t llama_vocab::token_to_byte(llama_token id) const {
+    return pimpl->token_to_byte(id);
+}
+
+llama_token llama_vocab::byte_to_token(uint8_t ch) const {
+    GGML_ASSERT(get_type() != LLAMA_VOCAB_TYPE_NONE);
+    static const char * hex = "0123456789ABCDEF";
+    switch (get_type()) {
+        case LLAMA_VOCAB_TYPE_SPM:
+        case LLAMA_VOCAB_TYPE_UGM: {
+            const char buf[7] = { '<', '0', 'x', hex[ch >> 4], hex[ch & 15], '>', 0 };
+            auto token = pimpl->token_to_id.find(buf);
+            if (token != pimpl->token_to_id.end()) {
+                return (*token).second;
+            }
+            // Try to fall back to just the byte as a string
+            const char buf2[2] = { (char)ch, 0 };
+            return pimpl->token_to_id.at(buf2);
+        }
+        case LLAMA_VOCAB_TYPE_WPM:
+        case LLAMA_VOCAB_TYPE_BPE: {
+            return pimpl->token_to_id.at(unicode_byte_to_utf8(ch));
+        }
+        case LLAMA_VOCAB_TYPE_PLAMO2: {
+            // PLaMo-2 uses byte tokens in format <0xXX>
+            char hex_str[8];
+            snprintf(hex_str, sizeof(hex_str), "<0x%02X>", ch);
+            return pimpl->token_to_id.at(hex_str);
+        }
+        default:
+            GGML_ABORT("fatal error");
+    }
+}
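>
> A minimal standalone sketch (not part of the patch) of the "<0xXX>" byte-token round trip used by `byte_to_token()` here and by `token_to_byte()` earlier in the file:
>
> ```cpp
> #include <cstdint>
> #include <cstdio>
> #include <cstdlib>
> #include <string>
>
> int main() {
>     const uint8_t byte = 0xE2;
>
>     // encode: byte -> token text, as done for the SPM/UGM/PLaMo-2 byte-fallback tokens
>     char buf[8];
>     snprintf(buf, sizeof(buf), "<0x%02X>", byte);
>     const std::string token_text = buf;           // "<0xE2>"
>
>     // decode: token text -> byte, mirroring token_to_byte(): the hex digits sit at positions 3..4
>     const uint8_t back = (uint8_t) strtol(token_text.substr(3, 2).c_str(), nullptr, 16);
>
>     printf("%s -> 0x%02X\n", token_text.c_str(), back);  // "<0xE2> -> 0xE2"
>     return 0;
> }
> ```
>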
+
+llama_token llama_vocab::text_to_token(const std::string & text) const {
+    GGML_ASSERT(pimpl->type != LLAMA_VOCAB_TYPE_NONE);
+    auto it = pimpl->token_to_id.find(text);
+    if (it != pimpl->token_to_id.end()) {
+        return (*it).second;
+    }
+    return LLAMA_TOKEN_NULL;
+}
+
+const llama_vocab::token_data & llama_vocab::get_token_data(llama_token id) const {
+    GGML_ASSERT(pimpl->type != LLAMA_VOCAB_TYPE_NONE);
+    return pimpl->id_to_token.at(id);
+}
+
+const char * llama_vocab::token_get_text(llama_token id) const {
+    GGML_ASSERT(pimpl->type != LLAMA_VOCAB_TYPE_NONE);
+    return pimpl->id_to_token.at(id).text.c_str();
+}
+
+float llama_vocab::token_get_score(llama_token id) const {
+    GGML_ASSERT(pimpl->type != LLAMA_VOCAB_TYPE_NONE);
+    return pimpl->id_to_token.at(id).score;
+}
+
+llama_token_attr llama_vocab::token_get_attr(llama_token id) const {
+    return pimpl->token_get_attr(id);
+}
+
+llama_token llama_vocab::token_bos() const {
+    return pimpl->special_bos_id;
+}
+
+llama_token llama_vocab::token_eos() const {
+    return pimpl->special_eos_id;
+}
+
+llama_token llama_vocab::token_eot() const {
+    return pimpl->special_eot_id;
+}
+
+llama_token llama_vocab::token_eom() const {
+    return pimpl->special_eom_id;
+}
+
+llama_token llama_vocab::token_unk() const {
+    return pimpl->special_unk_id;
+}
+
+llama_token llama_vocab::token_sep() const {
+    return pimpl->special_sep_id;
+}
+
+llama_token llama_vocab::token_nl() const {
+    return pimpl->linefeed_id;
+}
+
+llama_token llama_vocab::token_pad() const {
+    return pimpl->special_pad_id;
+}
+
+llama_token llama_vocab::token_prefix() const {
+    return pimpl->special_fim_pre_id;
+}
+
+llama_token llama_vocab::token_middle() const {
+    return pimpl->special_fim_mid_id;
+}
+
+llama_token llama_vocab::token_suffix() const {
+    return pimpl->special_fim_suf_id;
+}
+
+llama_token llama_vocab::token_fim_pre() const {
+    return pimpl->special_fim_pre_id;
+}
+
+llama_token llama_vocab::token_fim_suf() const {
+    return pimpl->special_fim_suf_id;
+}
+
+llama_token llama_vocab::token_fim_mid() const {
+    return pimpl->special_fim_mid_id;
+}
+
+llama_token llama_vocab::token_fim_pad() const {
+    return pimpl->special_fim_pad_id;
+}
+
+llama_token llama_vocab::token_fim_rep() const {
+    return pimpl->special_fim_rep_id;
+}
+
+llama_token llama_vocab::token_fim_sep() const {
+    return pimpl->special_fim_sep_id;
+}
+
+llama_token llama_vocab::token_mask() const {
+    return pimpl->special_mask_id;
+}
+
+bool llama_vocab::get_add_space_prefix() const {
+    return pimpl->add_space_prefix;
+}
+
+bool llama_vocab::get_add_bos() const {
+    return pimpl->add_bos;
+}
+
+bool llama_vocab::get_add_eos() const {
+    return pimpl->add_eos;
+}
+
+bool llama_vocab::get_add_sep() const {
+    return pimpl->add_sep;
+}
+
+bool llama_vocab::get_ignore_merges() const {
+    return pimpl->ignore_merges;
+}
+
+bool llama_vocab::get_clean_spaces() const {
+    return pimpl->clean_spaces;
+}
+
+bool llama_vocab::get_remove_extra_whitespaces() const {
+    return pimpl->remove_extra_whitespaces;
+}
+
+bool llama_vocab::get_escape_whitespaces() const {
+    return pimpl->escape_whitespaces;
+}
+
+bool llama_vocab::get_treat_whitespace_as_suffix() const {
+    return pimpl->treat_whitespace_as_suffix;
+}
+
+int llama_vocab::max_token_len() const {
+    return pimpl->max_token_len;
+}
+
+int llama_vocab::find_bpe_rank(const std::string & token_left, const std::string & token_right) const {
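+    // merges are serialized as space-separated pairs (see get_bpe_merges), so neither token may contain a space or newline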
+    GGML_ASSERT(token_left.find(' ')   == std::string::npos);
+    GGML_ASSERT(token_left.find('\n')  == std::string::npos);
+    GGML_ASSERT(token_right.find(' ')  == std::string::npos);
+    GGML_ASSERT(token_right.find('\n') == std::string::npos);
+
+    auto it = pimpl->bpe_ranks.find(std::make_pair(token_left, token_right));
+    if (it == pimpl->bpe_ranks.end()) {
+        return -1;
+    }
+
+    return it->second;
+}
+
+std::vector<std::string> llama_vocab::get_bpe_merges() const {
+    std::vector<std::string> result(pimpl->bpe_ranks.size());
+
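+    // bpe_ranks maps (left, right) token pairs to their merge rank; rebuild the merges as "left right" strings ordered by rank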
+    for (const auto & pair : pimpl->bpe_ranks) {
+        result[pair.second] = pair.first.first + " " + pair.first.second;
+    }
+
+    return result;
+}
+
+std::vector<char> llama_vocab::get_precompiled_charsmap() const {
+    return pimpl->precompiled_charsmap;
+}
+
+int32_t llama_vocab::tokenize(
+                  const char * text,
+                     int32_t   text_len,
+                 llama_token * tokens,
+                     int32_t   n_tokens_max,
+                        bool   add_special,
+                        bool   parse_special) const {
+    auto res = tokenize(std::string(text, text_len), add_special, parse_special);
+    if (res.size() >= static_cast<size_t>(std::numeric_limits<int32_t>::max())) {
+        LLAMA_LOG_ERROR("%s: tokenization result size %zu exceeds int32_t limit\n", __func__, res.size());
+        return std::numeric_limits<int32_t>::min();
+    }
+
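+    // not enough room in the caller's buffer: return the negated number of tokens required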
+    if (n_tokens_max < (int) res.size()) {
+        // LLAMA_LOG_ERROR("%s: too many tokens\n", __func__);
+        return -((int) res.size());
+    }
 
-                if (add_special && vocab.tokenizer_add_eos == 1) {
-                    GGML_ASSERT(vocab.special_eos_id != -1);
-                    output.push_back(vocab.special_eos_id);
-                }
-            } break;
-        case LLAMA_VOCAB_TYPE_NONE:
-            GGML_ABORT("fatal error");
+    for (size_t i = 0; i < res.size(); i++) {
+        tokens[i] = res[i];
     }
 
-    return output;
+    return res.size();
 }
 
-llama_token llama_byte_to_token_impl(const llama_vocab & vocab, uint8_t ch) {
-    GGML_ASSERT(llama_vocab_get_type(vocab) != LLAMA_VOCAB_TYPE_NONE);
-    static const char * hex = "0123456789ABCDEF";
-    switch (llama_vocab_get_type(vocab)) {
-        case LLAMA_VOCAB_TYPE_SPM:
-        case LLAMA_VOCAB_TYPE_UGM: {
-            const char buf[7] = { '<', '0', 'x', hex[ch >> 4], hex[ch & 15], '>', 0 };
-            auto token = vocab.token_to_id.find(buf);
-            if (token != vocab.token_to_id.end()) {
-                return (*token).second;
-            }
-            // Try to fall back to just the byte as a string
-            const char buf2[2] = { (char)ch, 0 };
-            return vocab.token_to_id.at(buf2);
-        }
-        case LLAMA_VOCAB_TYPE_WPM:
-        case LLAMA_VOCAB_TYPE_BPE: {
-            return vocab.token_to_id.at(unicode_byte_to_utf8(ch));
-        }
-        default:
-            GGML_ABORT("fatal error");
-    }
+std::vector<llama_token> llama_vocab::tokenize(
+        const std::string & raw_text,
+        bool add_special,
+        bool parse_special) const {
+    return pimpl->tokenize(raw_text, add_special, parse_special);
 }
 
-const char * llama_token_get_text_impl(const struct llama_vocab & vocab, llama_token token) {
-    GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
-    return vocab.id_to_token[token].text.c_str();
+const std::string & llama_vocab::token_to_piece(llama_token token) const {
+    return pimpl->token_to_piece(token);
 }
 
-float llama_token_get_score_impl(const struct llama_vocab & vocab, llama_token token) {
-    GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
-    return vocab.id_to_token[token].score;
+int32_t llama_vocab::token_to_piece(llama_token token, char * buf, int32_t length, int32_t lstrip, bool special) const {
+    return pimpl->token_to_piece(token, buf, length, lstrip, special);
 }
 
-llama_token_attr llama_token_get_attr_impl(const struct llama_vocab & vocab, llama_token token) {
-    GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
-    return vocab.id_to_token[token].attr;
+int32_t llama_vocab::detokenize(
+               const llama_token * tokens,
+                         int32_t   n_tokens,
+                            char * text,
+                         int32_t   text_len_max,
+                            bool   remove_special,
+                            bool   unparse_special) const {
+    return pimpl->detokenize(tokens, n_tokens, text, text_len_max, remove_special, unparse_special);
 }
 
-bool llama_token_is_eog_impl(const struct llama_vocab & vocab, llama_token token) {
-    return token != -1 && (
-        token == llama_token_eos_impl(vocab) ||
-        token == llama_token_eot_impl(vocab) ||
-        token == llama_token_eom_impl(vocab)
-    );
+std::string llama_vocab::detokenize(const std::vector<llama_token> & tokens, bool special) const {
+    std::string text;
+    text.resize(std::max(text.capacity(), tokens.size()));
+    int32_t n_chars = detokenize(tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
+    if (n_chars < 0) {
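+        // a negative result encodes the required buffer size; resize and detokenize again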
+        text.resize(-n_chars);
+        n_chars = detokenize(tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
+        GGML_ASSERT(n_chars <= (int32_t)text.size());  // whitespace trimming is performed after per-token detokenization
+    }
+
+    text.resize(n_chars);
+
+    // NOTE: the original tokenizer decodes bytes after collecting the pieces.
+    return text;
 }
 
-bool llama_token_is_control_impl(const struct llama_vocab & vocab, llama_token token) {
-    return llama_is_control_token(vocab, token);
+void llama_vocab::print_info() const {
+    pimpl->print_info();
 }
 
-llama_token llama_token_bos_impl(const struct llama_vocab & vocab) {
-    return vocab.special_bos_id;
+//
+// interface implementation
+//
+
+int32_t llama_vocab_n_tokens(const struct llama_vocab * vocab) {
+    return vocab->n_tokens();
 }
 
-llama_token llama_token_eos_impl(const struct llama_vocab & vocab) {
-    return vocab.special_eos_id;
+// deprecated
+int32_t llama_n_vocab(const struct llama_vocab * vocab) {
+    return llama_vocab_n_tokens(vocab);
 }
 
-llama_token llama_token_cls_impl(const struct llama_vocab & vocab) {
-    return vocab.special_cls_id;
+enum llama_vocab_type llama_vocab_type(const struct llama_vocab * vocab) {
+    return vocab->get_type();
 }
 
-llama_token llama_token_sep_impl(const struct llama_vocab & vocab) {
-    return vocab.special_sep_id;
+const char * llama_vocab_get_text(const struct llama_vocab * vocab, llama_token token) {
+    return vocab->token_get_text(token);
 }
 
-llama_token llama_token_nl_impl(const struct llama_vocab & vocab) {
-    return vocab.linefeed_id;
+float llama_vocab_get_score(const struct llama_vocab * vocab, llama_token token) {
+    return vocab->token_get_score(token);
 }
 
-llama_token llama_token_pad_impl(const struct llama_vocab & vocab) {
-    return vocab.special_pad_id;
+enum llama_token_attr llama_vocab_get_attr(const struct llama_vocab * vocab, llama_token token) {
+    return vocab->token_get_attr(token);
 }
 
-int32_t llama_add_bos_token_impl(const struct llama_vocab & vocab) {
-    return vocab.tokenizer_add_bos;
+bool llama_vocab_is_eog(const struct llama_vocab * vocab, llama_token token) {
+    return vocab->is_eog(token);
 }
 
-int32_t llama_add_eos_token_impl(const struct llama_vocab & vocab) {
-    return vocab.tokenizer_add_eos;
+bool llama_vocab_is_control(const struct llama_vocab * vocab, llama_token token) {
+    return vocab->is_control(token);
 }
 
-llama_token llama_token_prefix_impl(const struct llama_vocab & vocab) {
-    return vocab.special_prefix_id;
+llama_token llama_vocab_bos(const struct llama_vocab * vocab) {
+    return vocab->token_bos();
 }
 
-llama_token llama_token_middle_impl(const struct llama_vocab & vocab) {
-    return vocab.special_middle_id;
+llama_token llama_vocab_eos(const struct llama_vocab * vocab) {
+    return vocab->token_eos();
 }
 
-llama_token llama_token_suffix_impl(const struct llama_vocab & vocab) {
-    return vocab.special_suffix_id;
+llama_token llama_vocab_eot(const struct llama_vocab * vocab) {
+    return vocab->token_eot();
 }
 
-llama_token llama_token_fim_pre_impl(const struct llama_vocab & vocab) {
-    return vocab.special_fim_pre_id;
+// deprecated
+llama_token llama_vocab_cls(const struct llama_vocab * vocab) {
+    return vocab->token_bos();
 }
 
-llama_token llama_token_fim_suf_impl(const struct llama_vocab & vocab) {
-    return vocab.special_fim_suf_id;
+llama_token llama_vocab_sep(const struct llama_vocab * vocab) {
+    return vocab->token_sep();
 }
 
-llama_token llama_token_fim_mid_impl(const struct llama_vocab & vocab) {
-    return vocab.special_fim_mid_id;
+llama_token llama_vocab_nl (const struct llama_vocab * vocab) {
+    return vocab->token_nl();
 }
 
-llama_token llama_token_fim_pad_impl(const struct llama_vocab & vocab) {
-    return vocab.special_fim_pad_id;
+llama_token llama_vocab_pad(const struct llama_vocab * vocab) {
+    return vocab->token_pad();
 }
 
-llama_token llama_token_fim_rep_impl(const struct llama_vocab & vocab) {
-    return vocab.special_fim_rep_id;
+bool llama_vocab_get_add_bos(const struct llama_vocab * vocab) {
+    return vocab->get_add_bos();
 }
 
-llama_token llama_token_fim_sep_impl(const struct llama_vocab & vocab) {
-    return vocab.special_fim_sep_id;
+bool llama_vocab_get_add_eos(const struct llama_vocab * vocab) {
+    return vocab->get_add_eos();
 }
 
-llama_token llama_token_eot_impl(const struct llama_vocab & vocab) {
-    return vocab.special_eot_id;
+bool llama_vocab_get_add_sep(const struct llama_vocab * vocab) {
+    return vocab->get_add_sep();
 }
 
-llama_token llama_token_eom_impl(const struct llama_vocab & vocab) {
-    return vocab.special_eom_id;
+llama_token llama_vocab_fim_pre(const struct llama_vocab * vocab) {
+    return vocab->token_fim_pre();
 }
 
-int32_t llama_tokenize_impl(
-    const struct llama_vocab & vocab,
-                  const char * text,
-                     int32_t   text_len,
-                 llama_token * tokens,
-                     int32_t   n_tokens_max,
-                        bool   add_special,
-                        bool   parse_special) {
-    auto res = llama_tokenize_internal(vocab, std::string(text, text_len), add_special, parse_special);
-    if (n_tokens_max < (int) res.size()) {
-        // LLAMA_LOG_ERROR("%s: too many tokens\n", __func__);
-        return -((int) res.size());
-    }
+llama_token llama_vocab_fim_suf(const struct llama_vocab * vocab) {
+    return vocab->token_fim_suf();
+}
 
-    for (size_t i = 0; i < res.size(); i++) {
-        tokens[i] = res[i];
-    }
+llama_token llama_vocab_fim_mid(const struct llama_vocab * vocab) {
+    return vocab->token_fim_mid();
+}
 
-    return res.size();
+llama_token llama_vocab_fim_pad(const struct llama_vocab * vocab) {
+    return vocab->token_fim_pad();
 }
 
-static std::string llama_decode_text(const std::string & text) {
-    std::string decoded_text;
+llama_token llama_vocab_fim_rep(const struct llama_vocab * vocab) {
+    return vocab->token_fim_rep();
+}
 
-    const auto cpts = unicode_cpts_from_utf8(text);
-    for (const auto cpt : cpts) {
-        const auto utf8 = unicode_cpt_to_utf8(cpt);
-        try {
-            decoded_text += unicode_utf8_to_byte(utf8);
-        } catch (const std::out_of_range & /*e*/) {
-            decoded_text += "[UNK_BYTE_0x";
-            for (const auto c : utf8) {
-                decoded_text += format("%02x", (uint8_t) c);
-            }
-            decoded_text += text + "]";
-        }
-    }
+llama_token llama_vocab_fim_sep(const struct llama_vocab * vocab) {
+    return vocab->token_fim_sep();
+}
 
-    return decoded_text;
+llama_token llama_vocab_mask(const struct llama_vocab * vocab) {
+    return vocab->token_mask();
 }
 
-// does not write null-terminator to buf
-int32_t llama_token_to_piece_impl(const struct llama_vocab & vocab, llama_token token, char * buf, int32_t length, int32_t lstrip, bool special) {
-    // ref: https://github.com/ggerganov/llama.cpp/pull/7587#discussion_r1620983843
-    static const int attr_special = LLAMA_TOKEN_ATTR_UNKNOWN | LLAMA_TOKEN_ATTR_CONTROL;
-    const llama_token_attr attr = llama_token_get_attr_impl(vocab, token);
-    if (!special && (attr & attr_special)) {
-        return 0;
-    }
+// deprecated
+const char * llama_token_get_text(const struct llama_vocab * vocab, llama_token token) {
+    return llama_vocab_get_text(vocab, token);
+}
 
-    // copy piece chars to output text buffer
-    // skip up to 'lstrip' leading spaces before copying
-    auto _try_copy = [=] (const char * token, size_t size) -> int32_t {
-        for (int32_t i = 0; i < lstrip && size && *token == ' '; ++i) {
-            token++;
-            size--;
-        }
-        if (length < (int32_t)size) {
-            return -(int32_t) size;
-        }
-        memcpy(buf, token, size);
-        return (int32_t) size;
-    };
+// deprecated
+float llama_token_get_score(const struct llama_vocab * vocab, llama_token token) {
+    return llama_vocab_get_score(vocab, token);
+}
 
-    // if we have a cache - use it
-    {
-        const auto & cache = vocab.cache_token_to_piece;
+// deprecated
+enum llama_token_attr llama_token_get_attr(const struct llama_vocab * vocab, llama_token token) {
+    return llama_vocab_get_attr(vocab, token);
+}
 
-        if (!cache.empty()) {
-            const auto & result = cache.at(token);
-            return _try_copy(result.data(), result.size());
-        }
-    }
+// deprecated
+bool llama_token_is_eog(const struct llama_vocab * vocab, llama_token token) {
+    return llama_vocab_is_eog(vocab, token);
+}
 
-    if (0 <= token && token < (int32_t) vocab.id_to_token.size()) {
-        const std::string & token_text = vocab.id_to_token[token].text;
-        switch (llama_vocab_get_type(vocab)) {
-            case LLAMA_VOCAB_TYPE_WPM:
-            case LLAMA_VOCAB_TYPE_SPM:
-            case LLAMA_VOCAB_TYPE_UGM: {
-                // NOTE: we accept all unsupported token types,
-                // suppressing them like CONTROL tokens.
-                if (attr & (attr_special | LLAMA_TOKEN_ATTR_USER_DEFINED)) {
-                    return _try_copy(token_text.data(), token_text.size());
-                } else if (attr & LLAMA_TOKEN_ATTR_NORMAL) {
-                    std::string result = token_text;
-                    llama_unescape_whitespace(result);
-                    return _try_copy(result.data(), result.size());
-                } else if (attr & LLAMA_TOKEN_ATTR_BYTE) {
-                    char byte = (char) llama_token_to_byte(vocab, token);
-                    return _try_copy((char*) &byte, 1);
-                }
-                break;
-            }
-            case LLAMA_VOCAB_TYPE_BPE: {
-                // NOTE: we accept all unsupported token types,
-                // suppressing them like CONTROL tokens.
-                if (attr & (attr_special | LLAMA_TOKEN_ATTR_USER_DEFINED)) {
-                    return _try_copy(token_text.data(), token_text.size());
-                } else if (attr & LLAMA_TOKEN_ATTR_NORMAL) {
-                    std::string result = llama_decode_text(token_text);
-                    return _try_copy(result.data(), result.size());
-                }
-                break;
-            }
-            default:
-                GGML_ABORT("fatal error");
-        }
-    }
+// deprecated
+bool llama_token_is_control(const struct llama_vocab * vocab, llama_token token) {
+    return llama_vocab_is_control(vocab, token);
+}
 
-    return 0;
+// deprecated
+llama_token llama_token_bos(const struct llama_vocab * vocab) {
+    return llama_vocab_bos(vocab);
 }
 
-int32_t llama_detokenize_impl(
-        const struct llama_vocab & vocab,
-               const llama_token * tokens,
-                         int32_t   n_tokens,
-                            char * text,
-                         int32_t   text_len_max,
-                            bool   remove_special,
-                            bool   unparse_special) {
-    int32_t avail = text_len_max;
-    int32_t total = 0;
+// deprecated
+llama_token llama_token_eos(const struct llama_vocab * vocab) {
+    return llama_vocab_eos(vocab);
+}
 
-    // remove the leading space
-    bool remove_space = vocab.tokenizer_add_space_prefix;
+// deprecated
+llama_token llama_token_eot(const struct llama_vocab * vocab) {
+    return llama_vocab_eot(vocab);
+}
 
-    if (remove_special && vocab.tokenizer_add_bos) {
-        if (n_tokens > 0 && tokens[0] == vocab.special_bos_id) {
-            remove_space = false;
-            n_tokens--;
-            tokens++;
-        }
-    }
+// deprecated
+llama_token llama_token_cls(const struct llama_vocab * vocab) {
+    //return llama_vocab_cls(vocab);
+    return llama_vocab_bos(vocab); // avoid deprecation warning
+}
 
-    if (remove_special && vocab.tokenizer_add_eos) {
-        if (n_tokens > 0 && tokens[n_tokens-1] == vocab.special_eos_id) {
-            n_tokens--;
-        }
-    }
+// deprecated
+llama_token llama_token_sep(const struct llama_vocab * vocab) {
+    return llama_vocab_sep(vocab);
+}
 
-    for (int32_t i = 0; i < n_tokens; ++i) {
-        GGML_ASSERT(avail >= 0);
-        int32_t n_chars = llama_token_to_piece_impl(vocab, tokens[i], text, avail, remove_space, unparse_special);
-        remove_space = false;
-        if (n_chars < 0) {
-            avail = 0;
-            total -= n_chars;
-        } else if (n_chars > 0) {
-            avail -= n_chars;
-            text  += n_chars;
-            total += n_chars;
-        }
-    }
+// deprecated
+llama_token llama_token_nl (const struct llama_vocab * vocab) {
+    return llama_vocab_nl(vocab);
+}
 
-    if (total > text_len_max) {
-        return -total;
-    }
+// deprecated
+llama_token llama_token_pad(const struct llama_vocab * vocab) {
+    return llama_vocab_pad(vocab);
+}
 
-    if (vocab.tokenizer_clean_spaces) {
-        text -= total;  // restart text
+// deprecated
+bool llama_add_bos_token(const struct llama_vocab * vocab) {
+    return llama_vocab_get_add_bos(vocab);
+}
 
-        // first pass: characters ?!.,  //TODO: where do these characters come from?
-        const int32_t total1 = total;
-        total = total ? 1 : 0;
-        for (int32_t i = 1; i < total1; ++i) {
-            const char x = text[i];
-            if (text[i - 1] == ' ') {
-                if (x == '?' || x == '!' || x == '.' || x == ',') {  // " ?", " !", " .", " ,"
-                    total--;  // remove space
-                }
-            }
-            text[total++] = x;
-        }
+// deprecated
+bool llama_add_eos_token(const struct llama_vocab * vocab) {
+    return llama_vocab_get_add_eos(vocab);
+}
 
-        // second pass: strip single apostrophe between spaces
-        const int32_t total2 = total;
-        total = total ? 1 : 0;
-        for (int32_t i = 1; i < total2; ++i) {
-            const char x = text[i];
-            if (x == '\'' && i + 1 < total2 && text[i - 1] == ' ' && text[i + 1] == ' ') {  // " ' "
-                total--;           // remove prev space
-                text[++i] = '\0';  // remove next space
-            }
-            text[total++] = x;
-        }
+// deprecated
+llama_token llama_token_fim_pre(const struct llama_vocab * vocab) {
+    return llama_vocab_fim_pre(vocab);
+}
 
-        // third pass: apostrophe contractions  //NOTE: this makes sense?
-        const int32_t total3 = total;
-        total = total ? 1 : 0;
-        for (int32_t i = 1; i < total3; ++i) {
-            const char x = text[i];
-            if (text[i - 1] == ' ') {
-                if (x == '\'' && i + 1 < total3) {
-                    const char x1 = text[i + 1];
-                    if (x1 == 't' || x1 == 'd') {  // " 't", " 'd"
-                        //total--;  // remove space
-                    } else if (x1 == 's' || x1 == 'm') {  // " 's", " 'm"
-                        total--;  // remove space
-                    } else if (i + 2 < total3) {
-                        const char x2 = text[i + 2];
-                        if ((x1 == 'l' && x2 == 'l')) {  // " 'll"
-                            //total--;  // remove space
-                        } else if ((x1 == 'r' && x2 == 'e') || (x1 == 'v' && x2 == 'e')) {  // " 're", " 've"
-                            total--;  // remove space
-                        } else {
-                            //total--;  // remove space
-                        }
-                    } else {
-                        //total--;  // remove space
-                    }
-                }
-            }
-            text[total++] = x;
-        }
-    }
+// deprecated
+llama_token llama_token_fim_suf(const struct llama_vocab * vocab) {
+    return llama_vocab_fim_suf(vocab);
+}
 
-    return total <= text_len_max ? total : -total;
+// deprecated
+llama_token llama_token_fim_mid(const struct llama_vocab * vocab) {
+    return llama_vocab_fim_mid(vocab);
 }
 
-std::string llama_detokenize(const struct llama_vocab& vocab, const std::vector& tokens, bool special) {
-    std::string text;
-    text.resize(std::max(text.capacity(), tokens.size()));
-    int32_t n_chars = llama_detokenize_impl(vocab, tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
-    if (n_chars < 0) {
-        text.resize(-n_chars);
-        n_chars = llama_detokenize_impl(vocab, tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
-        GGML_ASSERT(n_chars <= (int32_t)text.size());  // whitespace trimming is performed after per-token detokenization
-    }
+// deprecated
+llama_token llama_token_fim_pad(const struct llama_vocab * vocab) {
+    return llama_vocab_fim_pad(vocab);
+}
 
-    text.resize(n_chars);
+// deprecated
+llama_token llama_token_fim_rep(const struct llama_vocab * vocab) {
+    return llama_vocab_fim_rep(vocab);
+}
 
-    // NOTE: the original tokenizer decodes bytes after collecting the pieces.
-    return text;
+// deprecated
+llama_token llama_token_fim_sep(const struct llama_vocab * vocab) {
+    return llama_vocab_fim_sep(vocab);
+}
+
+//
+// tokenization
+//
+
+int32_t llama_tokenize(
+    const struct llama_vocab * vocab,
+                  const char * text,
+                     int32_t   text_len,
+                 llama_token * tokens,
+                     int32_t   n_tokens_max,
+                        bool   add_special,
+                        bool   parse_special) {
+    return vocab->tokenize(text, text_len, tokens, n_tokens_max, add_special, parse_special);
+}
+
+int32_t llama_token_to_piece(
+    const struct llama_vocab * vocab,
+                 llama_token   token,
+                        char * buf,
+                     int32_t   length,
+                     int32_t   lstrip,
+                        bool   special) {
+    return vocab->token_to_piece(token, buf, length, lstrip, special);
+}
+
+int32_t llama_detokenize(
+    const struct llama_vocab * vocab,
+           const llama_token * tokens,
+                     int32_t   n_tokens,
+                        char * text,
+                     int32_t   text_len_max,
+                        bool   remove_special,
+                        bool   unparse_special) {
+    return vocab->detokenize(tokens, n_tokens, text, text_len_max, remove_special, unparse_special);
 }
diff --git a/src/llama-vocab.h b/src/llama-vocab.h
index 64ff7cc08..6a37b0c18 100644
--- a/src/llama-vocab.h
+++ b/src/llama-vocab.h
@@ -1,155 +1,178 @@
 #pragma once
 
-#include "llama-impl.h"
+#include "llama.h"
 
 #include <string>
 #include <vector>
-#include <unordered_map>
-#include <map>
+#include <memory>
+
+// pre-tokenization types
+enum llama_vocab_pre_type {
+    LLAMA_VOCAB_PRE_TYPE_DEFAULT        = 0,
+    LLAMA_VOCAB_PRE_TYPE_LLAMA3         = 1,
+    LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM   = 2,
+    LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER = 3,
+    LLAMA_VOCAB_PRE_TYPE_FALCON         = 4,
+    LLAMA_VOCAB_PRE_TYPE_MPT            = 5,
+    LLAMA_VOCAB_PRE_TYPE_STARCODER      = 6,
+    LLAMA_VOCAB_PRE_TYPE_GPT2           = 7,
+    LLAMA_VOCAB_PRE_TYPE_REFACT         = 8,
+    LLAMA_VOCAB_PRE_TYPE_COMMAND_R      = 9,
+    LLAMA_VOCAB_PRE_TYPE_STABLELM2      = 10,
+    LLAMA_VOCAB_PRE_TYPE_QWEN2          = 11,
+    LLAMA_VOCAB_PRE_TYPE_OLMO           = 12,
+    LLAMA_VOCAB_PRE_TYPE_DBRX           = 13,
+    LLAMA_VOCAB_PRE_TYPE_SMAUG          = 14,
+    LLAMA_VOCAB_PRE_TYPE_PORO           = 15,
+    LLAMA_VOCAB_PRE_TYPE_CHATGLM3       = 16,
+    LLAMA_VOCAB_PRE_TYPE_CHATGLM4       = 17,
+    LLAMA_VOCAB_PRE_TYPE_VIKING         = 18,
+    LLAMA_VOCAB_PRE_TYPE_JAIS           = 19,
+    LLAMA_VOCAB_PRE_TYPE_TEKKEN         = 20,
+    LLAMA_VOCAB_PRE_TYPE_SMOLLM         = 21,
+    LLAMA_VOCAB_PRE_TYPE_CODESHELL      = 22,
+    LLAMA_VOCAB_PRE_TYPE_BLOOM          = 23,
+    LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH   = 24,
+    LLAMA_VOCAB_PRE_TYPE_EXAONE         = 25,
+    LLAMA_VOCAB_PRE_TYPE_CHAMELEON      = 26,
+    LLAMA_VOCAB_PRE_TYPE_MINERVA        = 27,
+    LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM  = 28,
+    LLAMA_VOCAB_PRE_TYPE_GPT4O          = 29,
+    LLAMA_VOCAB_PRE_TYPE_SUPERBPE       = 30,
+    LLAMA_VOCAB_PRE_TYPE_TRILLION       = 31,
+    LLAMA_VOCAB_PRE_TYPE_BAILINGMOE     = 32,
+    LLAMA_VOCAB_PRE_TYPE_LLAMA4         = 33,
+    LLAMA_VOCAB_PRE_TYPE_PIXTRAL        = 34,
+    LLAMA_VOCAB_PRE_TYPE_SEED_CODER     = 35,
+    LLAMA_VOCAB_PRE_TYPE_HUNYUAN        = 36,
+    LLAMA_VOCAB_PRE_TYPE_KIMI_K2        = 37,
+    LLAMA_VOCAB_PRE_TYPE_HUNYUAN_DENSE  = 38,
+};
 
-struct llama_vocab {
-    using id    = llama_token;
-    using token = std::string;
-    using tattr = llama_token_attr;
+struct LLM_KV;
+struct llama_model_loader;
 
+struct llama_vocab {
     struct token_data {
-        token text;
-        float score;
-        tattr attr;
+        std::string      text;
+        float            score;
+        llama_token_attr attr;
     };
 
-    enum llama_vocab_type     type     = LLAMA_VOCAB_TYPE_SPM;
-    enum llama_vocab_pre_type type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
+    llama_vocab();
+    ~llama_vocab();
 
-    int max_token_len = 0; // used for optimizing longest token search
+    void load(llama_model_loader & ml, const LLM_KV & kv);
 
-    uint32_t n_tokens() const;
+    std::string get_tokenizer_model() const;
+    std::string get_tokenizer_pre() const;
 
-    std::unordered_map<token, id> token_to_id;
-    std::vector<token_data>       id_to_token;
-
-    std::vector<id>    cache_special_tokens;
-    std::vector<token> cache_token_to_piece; // llama_token_to_piece(special = true);
-
-    std::map<std::pair<std::string, std::string>, int> bpe_ranks;
-
-    // default LLaMA special tokens
-    id special_bos_id  = 1;
-    id special_eos_id  = 2;
-    id special_unk_id  = 0;
-    id special_sep_id  = -1;
-    id special_pad_id  = -1;
-    id special_cls_id  = -1;
-    id special_mask_id = -1;
-
-    id linefeed_id       = 13;
-
-    // fim tokens
-    llama_token special_fim_pre_id = -1;
-    llama_token special_fim_suf_id = -1;
-    llama_token special_fim_mid_id = -1;
-    llama_token special_fim_pad_id = -1;
-    llama_token special_fim_rep_id = -1; // repo
-    llama_token special_fim_sep_id = -1; // file separator
-
-    id special_prefix_id = -1;
-    id special_suffix_id = -1;
-    id special_middle_id = -1;
-    id special_eot_id    = -1; // TODO: move above after "eos_id", and here add "file separator" token
-    id special_eom_id    = -1;
-
-    // tokenizer flags
-    bool tokenizer_add_space_prefix = false;
-    bool tokenizer_add_bos          = false;
-    bool tokenizer_add_eos          = false;
-    bool tokenizer_ignore_merges    = false;
-    bool tokenizer_clean_spaces     = false;  // clean_up_tokenization_spaces
-    bool tokenizer_remove_extra_whitespaces   = false;
-    bool tokenizer_escape_whitespaces         = true;
-    bool tokenizer_treat_whitespace_as_suffix = false;
-
-    std::vector<char> precompiled_charsmap;
+    enum llama_vocab_type     get_type()     const;
+    enum llama_vocab_pre_type get_pre_type() const;
+
+    uint32_t n_tokens() const;
+    uint32_t n_token_types() const;
+
+    std::string type_name() const;
+
+    bool is_normal      (llama_token id) const;
+    bool is_unknown     (llama_token id) const;
+    bool is_control     (llama_token id) const;
+    bool is_byte        (llama_token id) const;
+    bool is_user_defined(llama_token id) const;
+    bool is_unused      (llama_token id) const;
+    bool is_eog         (llama_token id) const;
+
+    uint8_t     token_to_byte(llama_token id) const;
+    llama_token byte_to_token(uint8_t ch)     const;
+
+    llama_token text_to_token(const std::string & text) const;
+
+    const token_data & get_token_data(llama_token id) const;
+
+    const char *     token_get_text (llama_token id) const;
+    float            token_get_score(llama_token id) const;
+    llama_token_attr token_get_attr (llama_token id) const;
+
+    llama_token token_bos() const;
+    llama_token token_eos() const;
+    llama_token token_eot() const;
+    llama_token token_eom() const;
+    llama_token token_unk() const;
+    llama_token token_sep() const;
+    llama_token token_nl () const;
+    llama_token token_pad() const;
+    llama_token token_mask() const;
+
+    llama_token token_prefix() const;
+    llama_token token_middle() const;
+    llama_token token_suffix() const;
+
+    llama_token token_fim_pre() const;
+    llama_token token_fim_suf() const;
+    llama_token token_fim_mid() const;
+    llama_token token_fim_pad() const;
+    llama_token token_fim_rep() const;
+    llama_token token_fim_sep() const;
+
+    bool get_add_space_prefix          () const;
+    bool get_add_bos                   () const;
+    bool get_add_eos                   () const;
+    bool get_add_sep                   () const;
+    bool get_ignore_merges             () const;
+    bool get_clean_spaces              () const;
+    bool get_remove_extra_whitespaces  () const;
+    bool get_escape_whitespaces        () const;
+    bool get_treat_whitespace_as_suffix() const;
+
+    int max_token_len() const;
 
     int find_bpe_rank(const std::string & token_left, const std::string & token_right) const;
+    std::vector<std::string> get_bpe_merges() const;
+
+    std::vector<char> get_precompiled_charsmap() const;
+
+    int32_t tokenize(
+                   const char * text,
+                      int32_t   text_len,
+                  llama_token * tokens,
+                      int32_t   n_tokens_max,
+                         bool   add_special,
+                         bool   parse_special) const;
+
+    std::vector<llama_token> tokenize(
+            const std::string & raw_text,
+                         bool   add_special,
+                         bool   parse_special = false) const;
+
+    // does not write null-terminator to buf
+    int32_t token_to_piece(
+                  llama_token   token,
+                         char * buf,
+                      int32_t   length,
+                      int32_t   lstrip,
+                         bool   special) const;
+
+    // use cached data
+    const std::string & token_to_piece(llama_token token) const;
+
+    int32_t detokenize(
+            const llama_token * tokens,
+                      int32_t   n_tokens,
+                         char * text,
+                      int32_t   text_len_max,
+                         bool   remove_special,
+                         bool   unparse_special) const;
+
+    std::string detokenize(
+            const std::vector<llama_token> & tokens,
+                                      bool   special) const;
+
+    void print_info() const;
+
+private:
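+    // token tables, caches and tokenizer state live behind a pimpl to keep this header lightweight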
+    struct impl;
+    std::unique_ptr<impl> pimpl;
 };
 
 const struct llama_vocab * llama_get_vocab(const struct llama_context * ctx);
-
-//
-// internal API
-//
-
-// TODO: rename to llama_tokenize_impl
-// TODO: This should probably be in llama.h
-std::vector<llama_token> llama_tokenize_internal(
-        const llama_vocab & vocab,
-        std::string raw_text,
-        bool add_special,
-        bool parse_special = false);
-
-llama_token llama_byte_to_token_impl(const llama_vocab & vocab, uint8_t ch);
-
-const char * llama_token_get_text_impl(const struct llama_vocab & vocab, llama_token token);
-
-float llama_token_get_score_impl(const struct llama_vocab & vocab, llama_token token);
-
-llama_token_attr llama_token_get_attr_impl(const struct llama_vocab & vocab, llama_token token);
-
-bool llama_token_is_eog_impl(const struct llama_vocab & vocab, llama_token token);
-
-bool llama_token_is_control_impl(const struct llama_vocab & vocab, llama_token token);
-
-llama_token llama_token_bos_impl(const struct llama_vocab & vocab);
-llama_token llama_token_eos_impl(const struct llama_vocab & vocab);
-llama_token llama_token_cls_impl(const struct llama_vocab & vocab);
-llama_token llama_token_sep_impl(const struct llama_vocab & vocab);
-llama_token llama_token_nl_impl (const struct llama_vocab & vocab);
-llama_token llama_token_pad_impl(const struct llama_vocab & vocab);
-
-int32_t llama_add_bos_token_impl(const struct llama_vocab & vocab);
-int32_t llama_add_eos_token_impl(const struct llama_vocab & vocab);
-
-llama_token llama_token_fim_pre_impl(const struct llama_vocab & vocab);
-llama_token llama_token_fim_suf_impl(const struct llama_vocab & vocab);
-llama_token llama_token_fim_mid_impl(const struct llama_vocab & vocab);
-llama_token llama_token_fim_pad_impl(const struct llama_vocab & vocab);
-llama_token llama_token_fim_rep_impl(const struct llama_vocab & vocab);
-llama_token llama_token_fim_sep_impl(const struct llama_vocab & vocab);
-
-llama_token llama_token_prefix_impl(const struct llama_vocab & vocab);
-llama_token llama_token_middle_impl(const struct llama_vocab & vocab);
-llama_token llama_token_suffix_impl(const struct llama_vocab & vocab);
-llama_token llama_token_eot_impl   (const struct llama_vocab & vocab);
-llama_token llama_token_eom_impl   (const struct llama_vocab & vocab);
-
-int32_t llama_tokenize_impl(
-        const struct llama_vocab & vocab,
-                      const char * text,
-                         int32_t   text_len,
-                     llama_token * tokens,
-                         int32_t   n_tokens_max,
-                            bool   add_special,
-                            bool   parse_special);
-
-// does not write null-terminator to buf
-int32_t llama_token_to_piece_impl(
-        const struct llama_vocab & vocab,
-                     llama_token   token,
-                            char * buf,
-                         int32_t   length,
-                         int32_t   lstrip,
-                            bool   special);
-
-int32_t llama_detokenize_impl(
-        const struct llama_vocab & vocab,
-               const llama_token * tokens,
-                         int32_t   n_tokens,
-                            char * text,
-                         int32_t   text_len_max,
-                            bool   remove_special,
-                            bool   unparse_special);
-
-std::string llama_detokenize(
-    const struct llama_vocab& vocab,
-    const std::vector<llama_token>& tokens,
-    bool   special);
diff --git a/src/llama.cpp b/src/llama.cpp
index fb9331ac5..c5fa10264 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -9,6 +9,9 @@
 #include "llama-vocab.h"
 #include "llama-grammar.h"
 #include "llama-sampling.h"
+#include "llama-arch.h"
+#include "llama-mmap.h"
+#include "llama-model-loader.h"
 
 #include "unicode.h"
 
@@ -178,85 +181,10 @@ static void zeros(std::ofstream & file, size_t n) {
     }
 }
 
-LLAMA_ATTRIBUTE_FORMAT(1, 2)
-static std::string format(const char * fmt, ...) {
-    va_list ap;
-    va_list ap2;
-    va_start(ap, fmt);
-    va_copy(ap2, ap);
-    int size = vsnprintf(NULL, 0, fmt, ap);
-    GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT
-    std::vector<char> buf(size + 1);
-    int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2);
-    GGML_ASSERT(size2 == size);
-    va_end(ap2);
-    va_end(ap);
-    return std::string(buf.data(), size);
-}
-
 //
 // gguf constants (sync with gguf.py)
 //
 
-enum llm_arch {
-    LLM_ARCH_LLAMA,
-    LLM_ARCH_LLAMA4,
-    LLM_ARCH_DECI,
-    LLM_ARCH_FALCON,
-    LLM_ARCH_BAICHUAN,
-    LLM_ARCH_GROK,
-    LLM_ARCH_GPT2,
-    LLM_ARCH_GPTJ,
-    LLM_ARCH_GPTNEOX,
-    LLM_ARCH_MPT,
-    LLM_ARCH_STARCODER,
-    LLM_ARCH_REFACT,
-    LLM_ARCH_BERT,
-    LLM_ARCH_NOMIC_BERT,
-    LLM_ARCH_JINA_BERT_V2,
-    LLM_ARCH_BLOOM,
-    LLM_ARCH_STABLELM,
-    LLM_ARCH_QWEN,
-    LLM_ARCH_QWEN2,
-    LLM_ARCH_QWEN2MOE,
-    LLM_ARCH_QWEN3,
-    LLM_ARCH_QWEN3MOE,
-    LLM_ARCH_PHI2,
-    LLM_ARCH_PHI3,
-    LLM_ARCH_PLAMO,
-    LLM_ARCH_CODESHELL,
-    LLM_ARCH_ORION,
-    LLM_ARCH_INTERNLM2,
-    LLM_ARCH_MINICPM,
-    LLM_ARCH_GEMMA,
-    LLM_ARCH_GEMMA2,
-    LLM_ARCH_GEMMA3,
-    LLM_ARCH_STARCODER2,
-    LLM_ARCH_MAMBA,
-    LLM_ARCH_XVERSE,
-    LLM_ARCH_COMMAND_R,
-    LLM_ARCH_DBRX,
-    LLM_ARCH_OLMO,
-    LLM_ARCH_OPENELM,
-    LLM_ARCH_ARCTIC,
-    LLM_ARCH_DEEPSEEK2,
-    LLM_ARCH_CHATGLM,
-    LLM_ARCH_GLM4,
-    LLM_ARCH_GLM4_MOE,
-    LLM_ARCH_BITNET,
-    LLM_ARCH_BITNET_25,
-    LLM_ARCH_BITNET_B158,
-    LLM_ARCH_T5,
-    LLM_ARCH_T5ENCODER,
-    LLM_ARCH_JAIS,
-    LLM_ARCH_GRANITE,
-    LLM_ARCH_GRANITE_MOE,
-    LLM_ARCH_COHERE2,
-    LLM_ARCH_DOTS1,
-    LLM_ARCH_HUNYUAN_MOE,
-    LLM_ARCH_UNKNOWN,
-};
-
 static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_LLAMA,           "llama"        },
     { LLM_ARCH_LLAMA4,          "llama4"       },
@@ -313,126 +241,19 @@ static const std::map LLM_ARCH_NAMES = {
     { LLM_ARCH_COHERE2,         "cohere2"      },
     { LLM_ARCH_DOTS1,            "dots1"       },
     { LLM_ARCH_HUNYUAN_MOE,     "hunyuan-moe"  },
+    { LLM_ARCH_OPENAI_MOE,      "gpt-oss"      },
     { LLM_ARCH_UNKNOWN,         "(unknown)"    },
 };
 
-enum llm_kv {
-    LLM_KV_GENERAL_TYPE,
-    LLM_KV_GENERAL_ARCHITECTURE,
-    LLM_KV_GENERAL_QUANTIZATION_VERSION,
-    LLM_KV_GENERAL_ALIGNMENT,
-    LLM_KV_GENERAL_NAME,
-    LLM_KV_GENERAL_AUTHOR,
-    LLM_KV_GENERAL_VERSION,
-    LLM_KV_GENERAL_URL,
-    LLM_KV_GENERAL_DESCRIPTION,
-    LLM_KV_GENERAL_LICENSE,
-    LLM_KV_GENERAL_SOURCE_URL,
-    LLM_KV_GENERAL_SOURCE_HF_REPO,
-
-    LLM_KV_VOCAB_SIZE,
-    LLM_KV_CONTEXT_LENGTH,
-    LLM_KV_EMBEDDING_LENGTH,
-    LLM_KV_BLOCK_COUNT,
-    LLM_KV_LEADING_DENSE_BLOCK_COUNT,
-    LLM_KV_FEED_FORWARD_LENGTH,
-    LLM_KV_EXPERT_FEED_FORWARD_LENGTH,
-    LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH,
-    LLM_KV_USE_PARALLEL_RESIDUAL,
-    LLM_KV_TENSOR_DATA_LAYOUT,
-    LLM_KV_EXPERT_COUNT,
-    LLM_KV_EXPERT_USED_COUNT,
-    LLM_KV_EXPERT_SHARED_COUNT,
-    LLM_KV_EXPERT_WEIGHTS_SCALE,
-    LLM_KV_EXPERT_WEIGHTS_NORM,
-    LLM_KV_EXPERT_GATING_FUNC,
-    LLM_KV_NEXTN_PREDICT_LAYERS,
-    LLM_KV_POOLING_TYPE,
-    LLM_KV_LOGIT_SCALE,
-    LLM_KV_DECODER_START_TOKEN_ID,
-    LLM_KV_ATTN_LOGIT_SOFTCAPPING,
-    LLM_KV_FINAL_LOGIT_SOFTCAPPING,
-    LLM_KV_SWIN_NORM,
-    LLM_KV_RESCALE_EVERY_N_LAYERS,
-    LLM_KV_TIME_MIX_EXTRA_DIM,
-    LLM_KV_TIME_DECAY_EXTRA_DIM,
-    LLM_KV_RESIDUAL_SCALE,
-    LLM_KV_EMBEDDING_SCALE,
-    LLM_KV_TOKEN_SHIFT_COUNT,
-    LLM_KV_INTERLEAVE_MOE_LAYER_STEP,
-
-    LLM_KV_ATTENTION_HEAD_COUNT,
-    LLM_KV_ATTENTION_HEAD_COUNT_KV,
-    LLM_KV_ATTENTION_MAX_ALIBI_BIAS,
-    LLM_KV_ATTENTION_CLAMP_KQV,
-    LLM_KV_ATTENTION_KEY_LENGTH,
-    LLM_KV_ATTENTION_VALUE_LENGTH,
-    LLM_KV_ATTENTION_LAYERNORM_EPS,
-    LLM_KV_ATTENTION_LAYERNORM_RMS_EPS,
-    LLM_KV_ATTENTION_CAUSAL,
-    LLM_KV_ATTENTION_Q_LORA_RANK,
-    LLM_KV_ATTENTION_KV_LORA_RANK,
-    LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT,
-    LLM_KV_ATTENTION_SLIDING_WINDOW,
-    LLM_KV_ATTENTION_SCALE,
-
-    LLM_KV_ROPE_DIMENSION_COUNT,
-    LLM_KV_ROPE_FREQ_BASE,
-    LLM_KV_ROPE_SCALE_LINEAR,
-    LLM_KV_ROPE_SCALING_TYPE,
-    LLM_KV_ROPE_SCALING_FACTOR,
-    LLM_KV_ROPE_SCALING_ATTN_FACTOR,
-    LLM_KV_ROPE_SCALING_ORIG_CTX_LEN,
-    LLM_KV_ROPE_SCALING_FINETUNED,
-    LLM_KV_ROPE_SCALING_YARN_LOG_MUL,
-
-    LLM_KV_SPLIT_NO,
-    LLM_KV_SPLIT_COUNT,
-    LLM_KV_SPLIT_TENSORS_COUNT,
-
-    LLM_KV_SSM_INNER_SIZE,
-    LLM_KV_SSM_CONV_KERNEL,
-    LLM_KV_SSM_STATE_SIZE,
-    LLM_KV_SSM_TIME_STEP_RANK,
-
-    LLM_KV_TOKENIZER_MODEL,
-    LLM_KV_TOKENIZER_PRE,
-    LLM_KV_TOKENIZER_LIST,
-    LLM_KV_TOKENIZER_TOKEN_TYPE,
-    LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT,
-    LLM_KV_TOKENIZER_SCORES,
-    LLM_KV_TOKENIZER_MERGES,
-    LLM_KV_TOKENIZER_BOS_ID,
-    LLM_KV_TOKENIZER_EOS_ID,
-    LLM_KV_TOKENIZER_UNK_ID,
-    LLM_KV_TOKENIZER_SEP_ID,
-    LLM_KV_TOKENIZER_PAD_ID,
-    LLM_KV_TOKENIZER_CLS_ID,
-    LLM_KV_TOKENIZER_MASK_ID,
-    LLM_KV_TOKENIZER_ADD_BOS,
-    LLM_KV_TOKENIZER_ADD_EOS,
-    LLM_KV_TOKENIZER_ADD_PREFIX,
-    LLM_KV_TOKENIZER_REMOVE_EXTRA_WS,
-    LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP,
-    LLM_KV_TOKENIZER_HF_JSON,
-    LLM_KV_TOKENIZER_RWKV,
-    LLM_KV_TOKENIZER_CHAT_TEMPLATE,
-    LLM_KV_TOKENIZER_CHAT_TEMPLATE_N,
-    LLM_KV_TOKENIZER_FIM_PRE_ID,
-    LLM_KV_TOKENIZER_FIM_SUF_ID,
-    LLM_KV_TOKENIZER_FIM_MID_ID,
-    LLM_KV_TOKENIZER_FIM_PAD_ID,
-    LLM_KV_TOKENIZER_FIM_REP_ID,
-    LLM_KV_TOKENIZER_FIM_SEP_ID,
-    LLM_KV_TOKENIZER_PREFIX_ID,
-    LLM_KV_TOKENIZER_SUFFIX_ID,
-    LLM_KV_TOKENIZER_MIDDLE_ID,
-    LLM_KV_TOKENIZER_EOT_ID,
-    LLM_KV_TOKENIZER_EOM_ID,
-
-    LLM_KV_ADAPTER_TYPE,
-    LLM_KV_ADAPTER_LORA_ALPHA,
-};
+llm_arch llm_arch_from_string(const std::string & name) {
+    for (const auto & kv : LLM_ARCH_NAMES) { // NOLINT
+        if (kv.second == name) {
+            return kv.first;
+        }
+    }
+
+    return LLM_ARCH_UNKNOWN;
+}
 
 static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_GENERAL_TYPE,                  "general.type"                          },
@@ -525,6 +346,7 @@ static const std::map LLM_KV_NAMES = {
     { LLM_KV_TOKENIZER_MASK_ID,              "tokenizer.ggml.mask_token_id"            },
     { LLM_KV_TOKENIZER_ADD_BOS,              "tokenizer.ggml.add_bos_token"            },
     { LLM_KV_TOKENIZER_ADD_EOS,              "tokenizer.ggml.add_eos_token"            },
+    { LLM_KV_TOKENIZER_ADD_SEP,              "tokenizer.ggml.add_sep_token"            },
     { LLM_KV_TOKENIZER_ADD_PREFIX,           "tokenizer.ggml.add_space_prefix"         },
     { LLM_KV_TOKENIZER_REMOVE_EXTRA_WS,      "tokenizer.ggml.remove_extra_whitespaces" },
     { LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, "tokenizer.ggml.precompiled_charsmap"     },
@@ -549,109 +371,6 @@ static const std::map LLM_KV_NAMES = {
     { LLM_KV_ADAPTER_LORA_ALPHA,            "adapter.lora.alpha" },
 };
 
-struct LLM_KV {
-    LLM_KV(llm_arch arch, const char* suffix = nullptr);
-
-    llm_arch arch;
-    const char* suffix;
-    std::string operator()(llm_kv kv) const;
-};
-
-enum llm_tensor {
-    LLM_TENSOR_TOKEN_EMBD,
-    LLM_TENSOR_TOKEN_EMBD_NORM,
-    LLM_TENSOR_TOKEN_TYPES,
-    LLM_TENSOR_POS_EMBD,
-    LLM_TENSOR_OUTPUT,
-    LLM_TENSOR_OUTPUT_NORM,
-    LLM_TENSOR_ROPE_FREQS,
-    LLM_TENSOR_ROPE_FACTORS_LONG,
-    LLM_TENSOR_ROPE_FACTORS_SHORT,
-    LLM_TENSOR_ATTN_Q,
-    LLM_TENSOR_ATTN_K,
-    LLM_TENSOR_ATTN_V,
-    LLM_TENSOR_ATTN_QKV,
-    LLM_TENSOR_ATTN_OUT,
-    LLM_TENSOR_ATTN_NORM,
-    LLM_TENSOR_ATTN_NORM_2,
-    LLM_TENSOR_ATTN_OUT_NORM,
-    LLM_TENSOR_ATTN_POST_NORM,
-    LLM_TENSOR_ATTN_ROT_EMBD,
-    LLM_TENSOR_FFN_GATE_INP,
-    LLM_TENSOR_FFN_GATE_INP_SHEXP,
-    LLM_TENSOR_FFN_NORM,
-    LLM_TENSOR_FFN_POST_NORM,
-    LLM_TENSOR_FFN_GATE,
-    LLM_TENSOR_FFN_DOWN,
-    LLM_TENSOR_FFN_UP,
-    LLM_TENSOR_FFN_ACT,
-    LLM_TENSOR_FFN_DOWN_EXP,  // split experts for backward compatibility
-    LLM_TENSOR_FFN_GATE_EXP,
-    LLM_TENSOR_FFN_UP_EXP,
-    LLM_TENSOR_FFN_NORM_EXPS,
-    LLM_TENSOR_FFN_DOWN_EXPS, // merged experts
-    LLM_TENSOR_FFN_GATE_EXPS,
-    LLM_TENSOR_FFN_UP_EXPS,
-    LLM_TENSOR_FFN_DOWN_SHEXP,
-    LLM_TENSOR_FFN_GATE_SHEXP,
-    LLM_TENSOR_FFN_UP_SHEXP,
-    LLM_TENSOR_FFN_EXP_PROBS_B,
-    LLM_TENSOR_ATTN_Q_NORM,
-    LLM_TENSOR_ATTN_K_NORM,
-    LLM_TENSOR_LAYER_OUT_NORM,
-    LLM_TENSOR_SSM_IN,
-    LLM_TENSOR_SSM_CONV1D,
-    LLM_TENSOR_SSM_X,
-    LLM_TENSOR_SSM_DT,
-    LLM_TENSOR_SSM_A,
-    LLM_TENSOR_SSM_D,
-    LLM_TENSOR_SSM_OUT,
-    LLM_TENSOR_ATTN_Q_A,
-    LLM_TENSOR_ATTN_Q_B,
-    LLM_TENSOR_ATTN_KV_A_MQA,
-    LLM_TENSOR_ATTN_KV_B,
-    LLM_TENSOR_ATTN_K_B,
-    LLM_TENSOR_ATTN_V_B,
-    LLM_TENSOR_ATTN_Q_A_NORM,
-    LLM_TENSOR_ATTN_KV_A_NORM,
-    LLM_TENSOR_ATTN_SUB_NORM,
-    LLM_TENSOR_FFN_SUB_NORM,
-    LLM_TENSOR_DEC_ATTN_NORM,
-    LLM_TENSOR_DEC_ATTN_Q,
-    LLM_TENSOR_DEC_ATTN_K,
-    LLM_TENSOR_DEC_ATTN_V,
-    LLM_TENSOR_DEC_ATTN_OUT,
-    LLM_TENSOR_DEC_ATTN_REL_B,
-    LLM_TENSOR_DEC_CROSS_ATTN_NORM,
-    LLM_TENSOR_DEC_CROSS_ATTN_Q,
-    LLM_TENSOR_DEC_CROSS_ATTN_K,
-    LLM_TENSOR_DEC_CROSS_ATTN_V,
-    LLM_TENSOR_DEC_CROSS_ATTN_OUT,
-    LLM_TENSOR_DEC_CROSS_ATTN_REL_B,
-    LLM_TENSOR_DEC_FFN_NORM,
-    LLM_TENSOR_DEC_FFN_GATE,
-    LLM_TENSOR_DEC_FFN_DOWN,
-    LLM_TENSOR_DEC_FFN_UP,
-    LLM_TENSOR_DEC_OUTPUT_NORM,
-    LLM_TENSOR_ENC_ATTN_NORM,
-    LLM_TENSOR_ENC_ATTN_Q,
-    LLM_TENSOR_ENC_ATTN_K,
-    LLM_TENSOR_ENC_ATTN_V,
-    LLM_TENSOR_ENC_ATTN_OUT,
-    LLM_TENSOR_ENC_ATTN_REL_B,
-    LLM_TENSOR_ENC_FFN_NORM,
-    LLM_TENSOR_ENC_FFN_GATE,
-    LLM_TENSOR_ENC_FFN_DOWN,
-    LLM_TENSOR_ENC_FFN_UP,
-    LLM_TENSOR_ENC_OUTPUT_NORM,
-    LLM_TENSOR_NEXTN_EH_PROJ,
-    LLM_TENSOR_NEXTN_EMBED_TOKENS,
-    LLM_TENSOR_NEXTN_ENORM,
-    LLM_TENSOR_NEXTN_HNORM,
-    LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD,
-    LLM_TENSOR_NEXTN_SHARED_HEAD_NORM,
-};
-
 LLM_KV::LLM_KV(llm_arch arch, const char* suffix) : arch(arch), suffix(suffix) {}
 
 std::string LLM_KV::operator()(llm_kv kv) const {
@@ -1732,6 +1451,25 @@ static const std::map> LLM_TENSOR_NA
             { LLM_TENSOR_FFN_UP_EXPS,     "blk.%d.ffn_up_exps" },
         },
     },
+    {
+        LLM_ARCH_OPENAI_MOE,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,         "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,        "output_norm" },
+            { LLM_TENSOR_OUTPUT,             "output" },
+            { LLM_TENSOR_ATTN_NORM,          "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_POST_NORM,     "blk.%d.post_attention_norm" },
+            { LLM_TENSOR_ATTN_Q,             "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K,             "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V,             "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,           "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_SINKS,         "blk.%d.attn_sinks" },
+            { LLM_TENSOR_FFN_GATE_INP,       "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_GATE_EXPS,      "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS,      "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS,        "blk.%d.ffn_up_exps" },
+        },
+    },
     {
         LLM_ARCH_UNKNOWN,
         {
@@ -1778,6 +1516,7 @@ enum llm_chat_template {
     LLM_CHAT_TEMPLATE_DOTS1,
     LLM_CHAT_TEMPLATE_HUNYUAN_MOE,
     LLM_CHAT_TEMPLATE_KIMI_K2,
+    LLM_CHAT_TEMPLATE_OPENAI_MOE,
     LLM_CHAT_TEMPLATE_UNKNOWN,
 };
 
@@ -1817,20 +1556,10 @@ static const std::map LLM_CHAT_TEMPLATES = {
     { "llama4",            LLM_CHAT_TEMPLATE_LLAMA4            },
     { "hunyuan-moe",       LLM_CHAT_TEMPLATE_HUNYUAN_MOE       },
     { "kimi-k2",           LLM_CHAT_TEMPLATE_KIMI_K2           },
+    { "gpt-oss",           LLM_CHAT_TEMPLATE_OPENAI_MOE        },
     { "bitnet",            LLM_CHAT_TEMPLATE_BITNET            },
 };
 
-
-static llm_arch llm_arch_from_string(const std::string & name) {
-    for (const auto & kv : LLM_ARCH_NAMES) { // NOLINT
-        if (kv.second == name) {
-            return kv.first;
-        }
-    }
-
-    return LLM_ARCH_UNKNOWN;
-}
-
 // helper to handle gguf constants
 // usage:
 //
@@ -1918,7 +1647,7 @@ static std::string gguf_data_to_str(enum gguf_type type, const void * data, int
     }
 }
 
-static std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i) {
+std::string gguf_kv_to_str(const gguf_context * ctx_gguf, int i) {
     const enum gguf_type type = gguf_get_kv_type(ctx_gguf, i);
 
     switch (type) {
@@ -1959,627 +1688,6 @@ static std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i) {
 // llama helpers
 //
 
-#if defined(_WIN32)
-static std::string llama_format_win_err(DWORD err) {
-    LPSTR buf;
-    size_t size = FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS,
-                                 NULL, err, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&buf, 0, NULL);
-    if (!size) {
-        return "FormatMessageA failed";
-    }
-    std::string ret(buf, size);
-    LocalFree(buf);
-    return ret;
-}
-#endif
-
-template <typename T>
-struct no_init {
-    T value;
-    no_init() { /* do nothing */ }
-};
-
-struct llama_file {
-
-#if defined(_WIN32)
-    // use FILE * so we don't have to re-open the file to mmap
-    FILE * fp;
-    HANDLE fp_win32;
-    size_t size;
-
-private:
-    std::string GetErrorMessageWin32(DWORD error_code) const {
-        std::string ret;
-        LPSTR lpMsgBuf = NULL;
-        DWORD bufLen = FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS,
-                                    NULL, error_code, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&lpMsgBuf, 0, NULL);
-        if (!bufLen) {
-            ret = format("Win32 error code: %s", error_code);
-        } else {
-            ret = lpMsgBuf;
-            LocalFree(lpMsgBuf);
-        }
-
-        return ret;
-    }
-
-public:
-
-    llama_file(const char * fname, const char * mode) {
-        fp = ggml_fopen(fname, mode);
-        if (fp == NULL) {
-            throw std::runtime_error(format("failed to open %s: %s", fname, strerror(errno)));
-        }
-        fp_win32 = (HANDLE) _get_osfhandle(_fileno(fp));
-        seek(0, SEEK_END);
-        size = tell();
-        seek(0, SEEK_SET);
-    }
-
-    size_t tell() const {
-        // SetFilePointerEx returns the current position when seeking relative 0 bytes
-        LARGE_INTEGER li;
-        li.QuadPart = 0;
-        BOOL ret = SetFilePointerEx(fp_win32, li, &li, FILE_CURRENT);
-        if (!ret) {
-            throw std::runtime_error(format("read error: %s", GetErrorMessageWin32(GetLastError()).c_str()));
-        }
-
-        return li.QuadPart;
-    }
-
-    void seek(size_t offset, int whence) const {
-        // no need to convert SEEK_* to FILE_*. The enums are the same.
-        // Still, keep static asserts to avoid failures in the future.
-        static_assert(SEEK_SET == FILE_BEGIN, "SEEK_SET != FILE_BEGIN");
-        static_assert(SEEK_CUR == FILE_CURRENT, "SEEK_CUR != FILE_CURRENT");
-        static_assert(SEEK_END == FILE_END, "SEEK_END != FILE_END");
-
-        LARGE_INTEGER li;
-        li.QuadPart = offset;
-        BOOL ret = SetFilePointerEx(fp_win32, li, NULL, whence);
-        if (!ret) {
-            throw std::runtime_error(format("read error: %s", GetErrorMessageWin32(GetLastError()).c_str()));
-        }
-    }
-
-    void read_raw(void * ptr, size_t len) const {
-        // On Win32 ReadFile is significant faster than fread which is again significant faster than std::fstream. Thus
-        // use the Win32 API to do file io instead of the C/C++ library functions.
-
-        // There are conditions under which ReadFile cannot read chunks >64MB.
-        // Thus split the operation into smaller chunks if len exceeds this limit.
-        size_t bytes_read = 0;
-        while (bytes_read < len) {
-            size_t chunk_size = std::min(len - bytes_read, 64*1024*1024);
-            DWORD chunk_read = 0;
-            BOOL result = ReadFile(fp_win32, reinterpret_cast<char *>(ptr) + bytes_read, chunk_size, &chunk_read, NULL);
-            if (!result) {
-                throw std::runtime_error(format("read error: %s", GetErrorMessageWin32(GetLastError()).c_str()));
-            }
-            if (chunk_read < chunk_size || chunk_read == 0) {
-                throw std::runtime_error("unexpectedly reached end of file");
-            }
-
-            bytes_read += chunk_read;
-        } ;
-    }
-
-    uint32_t read_u32() const {
-        uint32_t val;
-        read_raw(&val, sizeof(val));
-        return val;
-    }
-
-    void write_raw(const void * ptr, size_t len) const {
-        // There are conditions under which WriteFile cannot write chunks >64MB.
-        // Thus split the operation into smaller chunks if len exceeds this limit.
-        size_t bytes_written = 0;
-        while (bytes_written < len) {
-            size_t chunk_size = std::min(len - bytes_written, 64*1024*1024);
-            DWORD chunk_written = 0;
-            BOOL result = WriteFile(fp_win32, reinterpret_cast<const char *>(ptr) + bytes_written, chunk_size, &chunk_written, NULL);
-            if (!result) {
-                throw std::runtime_error(format("write error: %s", GetErrorMessageWin32(GetLastError()).c_str()));
-            }
-            if (chunk_written < chunk_size || chunk_written == 0) {
-                throw std::runtime_error("unexpectedly failed to write bytes");
-            }
-
-            bytes_written += chunk_written;
-        }
-    }
-
-    void write_u32(std::uint32_t val) const {
-        write_raw(&val, sizeof(val));
-    }
-
-    ~llama_file() {
-        if (fp) {
-            std::fclose(fp);
-        }
-    }
-#else
-    // use FILE * so we don't have to re-open the file to mmap
-    FILE * fp;
-    size_t size;
-
-    llama_file(const char * fname, const char * mode) {
-        fp = ggml_fopen(fname, mode);
-        if (fp == NULL) {
-            throw std::runtime_error(format("failed to open %s: %s", fname, strerror(errno)));
-        }
-        seek(0, SEEK_END);
-        size = tell();
-        seek(0, SEEK_SET);
-    }
-
-    size_t tell() const {
-#ifdef _WIN32
-        __int64 ret = _ftelli64(fp);
-#else
-        long ret = std::ftell(fp);
-#endif
-        if (ret == -1) {
-            throw std::runtime_error(format("ftell error: %s", strerror(errno)));
-        }
-
-        return (size_t) ret;
-    }
-
-    void seek(size_t offset, int whence) const {
-#ifdef _WIN32
-        int ret = _fseeki64(fp, (__int64) offset, whence);
-#else
-        int ret = std::fseek(fp, (long) offset, whence);
-#endif
-        if (ret != 0) {
-            throw std::runtime_error(format("seek error: %s", strerror(errno)));
-        }
-    }
-
-    void read_raw(void * ptr, size_t len) const {
-        if (len == 0) {
-            return;
-        }
-        errno = 0;
-        std::size_t ret = std::fread(ptr, len, 1, fp);
-        if (ferror(fp)) {
-            throw std::runtime_error(format("read error: %s", strerror(errno)));
-        }
-        if (ret != 1) {
-            throw std::runtime_error("unexpectedly reached end of file");
-        }
-    }
-
-    uint32_t read_u32() const {
-        uint32_t ret;
-        read_raw(&ret, sizeof(ret));
-        return ret;
-    }
-
-    void write_raw(const void * ptr, size_t len) const {
-        if (len == 0) {
-            return;
-        }
-        errno = 0;
-        size_t ret = std::fwrite(ptr, len, 1, fp);
-        if (ret != 1) {
-            throw std::runtime_error(format("write error: %s", strerror(errno)));
-        }
-    }
-
-    void write_u32(std::uint32_t val) const {
-        write_raw(&val, sizeof(val));
-    }
-
-    ~llama_file() {
-        if (fp) {
-            std::fclose(fp);
-        }
-    }
-#endif
-};
-using llama_files = std::vector<std::unique_ptr<llama_file>>;
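For orientation while reviewing the removal: `llama_file` is being relocated, not dropped, and the call pattern the loader relies on is unchanged. A minimal usage sketch follows; the file name, the explicit GGUF magic check, and the helper name are illustrative assumptions, not code from this diff.

```cpp
#include <cstdint>
#include <stdexcept>
#include <vector>

// Sketch only: assumes the llama_file declaration shown above is in scope.
static void sketch_read_gguf_header(const char * fname /* hypothetical path */) {
    llama_file file(fname, "rb");            // throws std::runtime_error if the open fails
    const uint32_t magic = file.read_u32();  // first 4 bytes of every GGUF file
    if (magic != 0x46554747) {               // "GGUF" read as a little-endian u32 (illustrative check)
        throw std::runtime_error("not a GGUF file");
    }
    std::vector<uint8_t> rest(file.size - sizeof(uint32_t));
    file.read_raw(rest.data(), rest.size()); // on Win32 this internally splits reads into <=64 MiB chunks
}
```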
-
-struct llama_mmap {
-    void * addr;
-    size_t size;
-    size_t mapped_page_size = 0;
-
-    llama_mmap(const llama_mmap &) = delete;
-
-#ifdef _POSIX_MAPPED_FILES
-    static constexpr bool SUPPORTED = true;
-
-    // list of mapped fragments (first_offset, last_offset)
-    std::vector<std::pair<size_t, size_t>> mapped_fragments;
-
-    llama_mmap(struct llama_file * file, size_t prefetch = (size_t) -1 /* -1 = max value */, bool numa = false, [[maybe_unused]] bool use_thp = false) {
-        size = file->size;
-        int fd = fileno(file->fp);
-        int flags = MAP_SHARED;
-        // prefetch/readahead impairs performance on NUMA systems
-        if (numa)  { prefetch = 0; }
-#ifdef __linux__
-        // advise the kernel to read the file sequentially (increases readahead)
-        if (posix_fadvise(fd, 0, 0, POSIX_FADV_SEQUENTIAL)) {
-            LLAMA_LOG_WARN("warning: posix_fadvise(.., POSIX_FADV_SEQUENTIAL) failed: %s\n",
-                    strerror(errno));
-        }
-        if (prefetch) { flags |= MAP_POPULATE; }
-        if (use_thp) {
-            size_t huge = get_default_huge_page_size();
-            auto size = huge*((file->size + huge - 1)/huge);
-            addr = mmap(nullptr, size, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_ANONYMOUS | MAP_HUGETLB, -1, 0);
-            if (addr != MAP_FAILED) {
-                printf("%s: using THP with page size %zu MiB ", __func__, huge/(1024*1024));
-                fflush(stdout);
-                size_t tot = 0;
-                while (tot < file->size) {
-                    auto n_read = pread(fd, static_cast<char *>(addr) + tot, file->size - tot, tot);
-                    if (n_read < 0) throw std::runtime_error(format("Reading into mapped huge pages failed at %zu (%s)", tot, strerror(errno)));
-                    printf(".");  fflush(stdout);
-                    tot += n_read;
-                }
-                printf(" done\n");
-                mapped_fragments.emplace_back(0, file->size);
-                mapped_page_size = huge;
-                return;
-            }
-            else {
-                fprintf(stderr, "%s: mmap with huge page size %zu MiB failed (%s)\n", __func__, huge/(1024*1024), strerror(errno));
-            }
-        }
-#endif
-        addr = mmap(NULL, file->size, PROT_READ, flags, fd, 0);
-        if (addr == MAP_FAILED) { // NOLINT
-            throw std::runtime_error(format("mmap failed: %s", strerror(errno)));
-        }
-
-        if (prefetch > 0) {
-            // advise the kernel to preload the mapped memory
-            if (posix_madvise(addr, std::min(file->size, prefetch), POSIX_MADV_WILLNEED)) {
-                LLAMA_LOG_WARN("warning: posix_madvise(.., POSIX_MADV_WILLNEED) failed: %s\n",
-                        strerror(errno));
-            }
-        }
-        if (numa) {
-            // advise the kernel not to use readahead
-            // (because the next page might not belong on the same node)
-            if (posix_madvise(addr, file->size, POSIX_MADV_RANDOM)) {
-                LLAMA_LOG_WARN("warning: posix_madvise(.., POSIX_MADV_RANDOM) failed: %s\n",
-                        strerror(errno));
-            }
-        }
-
-        // initialize list of mapped_fragments
-        mapped_fragments.emplace_back(0, file->size);
-    }
-
-    static void align_range(size_t * first, size_t * last, size_t page_size) {
-        // align first to the next page
-        size_t offset_in_page = *first & (page_size - 1);
-        size_t offset_to_page = offset_in_page == 0 ? 0 : page_size - offset_in_page;
-        *first += offset_to_page;
-
-        // align last to the previous page
-        *last = *last & ~(page_size - 1);
-
-        if (*last <= *first) {
-            *last = *first;
-        }
-    }
-
-    // partially unmap the file in the range [first, last)
-    void unmap_fragment(size_t first, size_t last) {
-        // note: this function must not be called multiple times with overlapping ranges
-        // otherwise, there is a risk of invalidating addresses that have been repurposed for other mappings
-        int page_size = mapped_page_size > 0 ? mapped_page_size : sysconf(_SC_PAGESIZE);
-        align_range(&first, &last, page_size);
-        size_t len = last - first;
-
-        if (len == 0) {
-            return;
-        }
-
-        GGML_ASSERT(first % page_size == 0);
-        GGML_ASSERT(last % page_size == 0);
-        GGML_ASSERT(last > first);
-
-        void * next_page_start = (uint8_t *) addr + first;
-
-        // unmap the range
-        if (munmap(next_page_start, len)) {
-            LLAMA_LOG_WARN("warning: munmap failed: %s\n", strerror(errno));
-        }
-
-        // update the list of mapped fragments to avoid unmapping the same range again in the destructor
-        std::vector<std::pair<size_t, size_t>> new_mapped_fragments;
-        for (const auto & frag : mapped_fragments) {
-            if (frag.first < first && frag.second > last) {
-                // the range is in the middle of the fragment, split it
-                new_mapped_fragments.emplace_back(frag.first, first);
-                new_mapped_fragments.emplace_back(last, frag.second);
-            } else if (frag.first < first && frag.second > first) {
-                // the range starts in the middle of the fragment
-                new_mapped_fragments.emplace_back(frag.first, first);
-            } else if (frag.first < last && frag.second > last) {
-                // the range ends in the middle of the fragment
-                new_mapped_fragments.emplace_back(last, frag.second);
-            } else if (frag.first >= first && frag.second <= last) {
-                // the range covers the entire fragment
-            } else {
-                // the range is outside the fragment
-                new_mapped_fragments.push_back(frag);
-            }
-        }
-        mapped_fragments = std::move(new_mapped_fragments);
-    }
-
-#ifdef __linux__
-    static int get_default_huge_page_size() {
-        int pg_size = 2048;
-        std::ifstream in("/proc/meminfo");
-        if (in) {
-            std::string line;
-            while (true) {
-                std::getline(in, line);
-                if (in.fail()) break;
-                if (auto pos = line.find("Hugepagesize:"); pos != std::string::npos) {
-                    std::istringstream str(line.data() + pos + 13);
-                    int aux;
-                    str >> aux;
-                    if (!str.fail()) pg_size = aux;
-                    break;
-                }
-            }
-        }
-        return pg_size * 1024;
-    }
-#endif
-
-    ~llama_mmap() {
-        for (const auto & frag : mapped_fragments) {
-            if (munmap((char *) addr + frag.first, frag.second - frag.first)) {
-                LLAMA_LOG_WARN("warning: munmap failed: %s\n", strerror(errno));
-            }
-        }
-    }
-#elif defined(_WIN32)
-    static constexpr bool SUPPORTED = true;
-
-    llama_mmap(struct llama_file * file, size_t prefetch = (size_t) -1, bool numa = false, [[maybe_unused]] bool use_thp = false) {
-        GGML_UNUSED(numa);
-
-        size = file->size;
-
-        HANDLE hFile = (HANDLE) _get_osfhandle(_fileno(file->fp));
-
-        HANDLE hMapping = CreateFileMappingA(hFile, NULL, PAGE_READONLY, 0, 0, NULL);
-
-        if (hMapping == NULL) {
-            DWORD error = GetLastError();
-            throw std::runtime_error(format("CreateFileMappingA failed: %s", llama_format_win_err(error).c_str()));
-        }
-
-        addr = MapViewOfFile(hMapping, FILE_MAP_READ, 0, 0, 0);
-        DWORD error = GetLastError();
-        CloseHandle(hMapping);
-
-        if (addr == NULL) {
-            throw std::runtime_error(format("MapViewOfFile failed: %s", llama_format_win_err(error).c_str()));
-        }
-
-        if (prefetch > 0) {
-#if _WIN32_WINNT >= 0x602
-            // PrefetchVirtualMemory is only present on Windows 8 and above, so we dynamically load it
-            BOOL (WINAPI *pPrefetchVirtualMemory) (HANDLE, ULONG_PTR, PWIN32_MEMORY_RANGE_ENTRY, ULONG);
-            HMODULE hKernel32 = GetModuleHandleW(L"kernel32.dll");
-
-            // may fail on pre-Windows 8 systems
-            pPrefetchVirtualMemory = reinterpret_cast<decltype(pPrefetchVirtualMemory)> (GetProcAddress(hKernel32, "PrefetchVirtualMemory"));
-
-            if (pPrefetchVirtualMemory) {
-                // advise the kernel to preload the mapped memory
-                WIN32_MEMORY_RANGE_ENTRY range;
-                range.VirtualAddress = addr;
-                range.NumberOfBytes = (SIZE_T) std::min(size, prefetch);
-                if (!pPrefetchVirtualMemory(GetCurrentProcess(), 1, &range, 0)) {
-                    LLAMA_LOG_WARN("warning: PrefetchVirtualMemory failed: %s\n",
-                            llama_format_win_err(GetLastError()).c_str());
-                }
-            }
-#else
-            throw std::runtime_error("PrefetchVirtualMemory unavailable");
-#endif
-        }
-    }
-
-    void unmap_fragment(size_t first, size_t last) {
-        // not supported
-        GGML_UNUSED(first);
-        GGML_UNUSED(last);
-    }
-
-    ~llama_mmap() {
-        if (!UnmapViewOfFile(addr)) {
-            LLAMA_LOG_WARN("warning: UnmapViewOfFile failed: %s\n",
-                    llama_format_win_err(GetLastError()).c_str());
-        }
-    }
-#else
-    static constexpr bool SUPPORTED = false;
-
-    llama_mmap(struct llama_file * file, size_t prefetch = -1, bool numa = false, bool use_thp = false) {
-        GGML_UNUSED(file);
-        GGML_UNUSED(prefetch);
-        GGML_UNUSED(numa);
-        GGML_UNUSED(use_thp);
-
-        throw std::runtime_error("mmap not supported");
-    }
-
-    void unmap_fragment(size_t first, size_t last) {
-        GGML_UNUSED(first);
-        GGML_UNUSED(last);
-
-        throw std::runtime_error("mmap not supported");
-    }
-#endif
-};
-using llama_mmaps = std::vector<std::unique_ptr<llama_mmap>>;
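The fragment bookkeeping in `unmap_fragment` above only ever releases whole pages: `align_range` rounds the start of the requested range up to the next page boundary and the end down to the previous one. A small self-contained sketch of that arithmetic, with made-up offsets and a 4 KiB page size, illustrates the rounding:

```cpp
#include <cassert>
#include <cstddef>

// Mirrors the rounding done by the removed align_range(): start rounds up,
// end rounds down, and a range that collapses becomes empty (last == first).
static void align_range_sketch(size_t * first, size_t * last, size_t page_size) {
    const size_t offset_in_page = *first & (page_size - 1);
    *first += offset_in_page ? page_size - offset_in_page : 0;
    *last  &= ~(page_size - 1);
    if (*last < *first) {
        *last = *first;
    }
}

int main() {
    size_t first = 5000, last = 20000;       // illustrative byte offsets into the mapping
    align_range_sketch(&first, &last, 4096); // assume 4 KiB pages
    assert(first == 8192 && last == 16384);  // only the fully covered pages get munmap'ed
    return 0;
}
```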
-
-// Represents some region of memory being locked using mlock or VirtualLock;
-// will automatically unlock on destruction.
-struct llama_mlock {
-    void * addr = NULL;
-    size_t size = 0;
-
-    bool failed_already = false;
-
-    llama_mlock() {}
-    llama_mlock(const llama_mlock &) = delete;
-
-    ~llama_mlock() {
-        if (size) {
-            raw_unlock(addr, size);
-        }
-    }
-
-    void init(void * ptr) {
-        GGML_ASSERT(addr == NULL && size == 0); // NOLINT
-        addr = ptr;
-    }
-
-    void grow_to(size_t target_size) {
-        GGML_ASSERT(addr);
-        if (failed_already) {
-            return;
-        }
-        size_t granularity = lock_granularity();
-        target_size = (target_size + granularity - 1) & ~(granularity - 1);
-        if (target_size > size) {
-            if (raw_lock((uint8_t *) addr + size, target_size - size)) {
-                size = target_size;
-            } else {
-                failed_already = true;
-            }
-        }
-    }
-
-#ifdef _POSIX_MEMLOCK_RANGE
-    static constexpr bool SUPPORTED = true;
-
-    static size_t lock_granularity() {
-        return (size_t) sysconf(_SC_PAGESIZE);
-    }
-
-    #ifdef __APPLE__
-        #define MLOCK_SUGGESTION \
-            "Try increasing the sysctl values 'vm.user_wire_limit' and 'vm.global_user_wire_limit' and/or " \
-            "decreasing 'vm.global_no_user_wire_amount'.  Also try increasing RLIMIT_MEMLOCK (ulimit -l).\n"
-    #else
-        #define MLOCK_SUGGESTION \
-            "Try increasing RLIMIT_MEMLOCK ('ulimit -l' as root).\n"
-    #endif
-
-    bool raw_lock(const void * addr, size_t size) const {
-        if (!mlock(addr, size)) {
-            return true;
-        }
-
-        char* errmsg = std::strerror(errno);
-        bool suggest = (errno == ENOMEM);
-
-        // Check if the resource limit is fine after all
-        struct rlimit lock_limit;
-        if (suggest && getrlimit(RLIMIT_MEMLOCK, &lock_limit)) {
-            suggest = false;
-        }
-        if (suggest && (lock_limit.rlim_max > lock_limit.rlim_cur + size)) {
-            suggest = false;
-        }
-
-        LLAMA_LOG_WARN("warning: failed to mlock %zu-byte buffer (after previously locking %zu bytes): %s\n%s",
-                size, this->size, errmsg, suggest ? MLOCK_SUGGESTION : "");
-        return false;
-    }
-
-    #undef MLOCK_SUGGESTION
-
-    static void raw_unlock(void * addr, size_t size) {
-        if (munlock(addr, size)) {
-            LLAMA_LOG_WARN("warning: failed to munlock buffer: %s\n", std::strerror(errno));
-        }
-    }
-#elif defined(_WIN32)
-    static constexpr bool SUPPORTED = true;
-
-    static size_t lock_granularity() {
-        SYSTEM_INFO si;
-        GetSystemInfo(&si);
-        return (size_t) si.dwPageSize;
-    }
-
-    bool raw_lock(void * ptr, size_t len) const {
-        for (int tries = 1; ; tries++) {
-            if (VirtualLock(ptr, len)) {
-                return true;
-            }
-            if (tries == 2) {
-                LLAMA_LOG_WARN("warning: failed to VirtualLock %zu-byte buffer (after previously locking %zu bytes): %s\n",
-                    len, size, llama_format_win_err(GetLastError()).c_str());
-                return false;
-            }
-
-            // It failed but this was only the first try; increase the working
-            // set size and try again.
-            SIZE_T min_ws_size, max_ws_size;
-            if (!GetProcessWorkingSetSize(GetCurrentProcess(), &min_ws_size, &max_ws_size)) {
-                LLAMA_LOG_WARN("warning: GetProcessWorkingSetSize failed: %s\n",
-                        llama_format_win_err(GetLastError()).c_str());
-                return false;
-            }
-            // Per MSDN: "The maximum number of pages that a process can lock
-            // is equal to the number of pages in its minimum working set minus
-            // a small overhead."
-            // Hopefully a megabyte is enough overhead:
-            size_t increment = len + 1048576;
-            // The minimum must be <= the maximum, so we need to increase both:
-            min_ws_size += increment;
-            max_ws_size += increment;
-            if (!SetProcessWorkingSetSize(GetCurrentProcess(), min_ws_size, max_ws_size)) {
-                LLAMA_LOG_WARN("warning: SetProcessWorkingSetSize failed: %s\n",
-                        llama_format_win_err(GetLastError()).c_str());
-                return false;
-            }
-        }
-    }
-
-    static void raw_unlock(void * ptr, size_t len) {
-        if (!VirtualUnlock(ptr, len)) {
-            LLAMA_LOG_WARN("warning: failed to VirtualUnlock buffer: %s\n",
-                    llama_format_win_err(GetLastError()).c_str());
-        }
-    }
-#else
-    static constexpr bool SUPPORTED = false;
-
-    static size_t lock_granularity() {
-        return (size_t) 65536;
-    }
-
-    bool raw_lock(const void * addr, size_t len) const {
-        LLAMA_LOG_WARN("warning: mlock not supported on this system\n");
-        return false;
-    }
-
-    static void raw_unlock(const void * addr, size_t len) {}
-#endif
-};
-using llama_mlocks = std::vector<std::unique_ptr<llama_mlock>>;
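As with the other helpers, `llama_mlock` moves elsewhere; the loader drives it by registering a base address once and then growing the locked extent as data is populated. A hedged usage sketch (the buffer pointer and byte count are illustrative):

```cpp
// Sketch only: assumes the llama_mlock declaration shown above is in scope.
static void sketch_lock_loaded_region(void * buf, size_t n_loaded_bytes) {
    llama_mlock lock;
    lock.init(buf);               // remember the base address; nothing is locked yet
    lock.grow_to(n_loaded_bytes); // rounds up to the lock granularity and mlocks/VirtualLocks;
                                  // after the first failure it warns and stops retrying
}
```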
-
 // NOTE: avoid ever using this except for building the token_to_piece caches
 static std::string llama_token_to_piece(const struct llama_model * model, llama_token token, bool special) {
     std::string piece;
@@ -2597,7 +1705,7 @@ static std::string llama_token_to_piece(const struct llama_model * model, llama_
     return piece;
 }
 
-static ggml_backend_buffer_type_t llama_default_buffer_type_cpu(bool host_buffer) {
+ggml_backend_buffer_type_t llama_default_buffer_type_cpu(bool host_buffer) {
     ggml_backend_buffer_type_t buft = nullptr;
 
 #if defined(GGML_USE_CUDA)
@@ -2723,14 +1831,17 @@ static const size_t MiB = 1024*kiB;
 static const size_t GiB = 1024*MiB;
 
 enum llm_expert_gating_func_type {
-    LLM_EXPERT_GATING_FUNC_SOFTMAX = 1,
-    LLM_EXPERT_GATING_FUNC_SIGMOID = 2,
+    LLM_EXPERT_GATING_FUNC_TYPE_NONE             = 0,
+    LLM_EXPERT_GATING_FUNC_SOFTMAX               = 1,
+    LLM_EXPERT_GATING_FUNC_SIGMOID               = 2,
+    LLM_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT = 3,
 };
 
 static const char * llama_expert_gating_func_name(llm_expert_gating_func_type type) {
     switch (type) {
         case LLM_EXPERT_GATING_FUNC_SOFTMAX: return "softmax";
         case LLM_EXPERT_GATING_FUNC_SIGMOID: return "sigmoid";
+        case LLM_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT: return "softmax_weight";
         default:                             return "unknown";
     }
 }
@@ -2982,107 +2093,114 @@ struct llama_layer_nextn {
 // TODO: separate into "llama_layer_enc" and "llama_layer_dec"
 struct llama_layer {
     // normalization
-    struct ggml_tensor * attn_norm;
-    struct ggml_tensor * attn_norm_b;
-    struct ggml_tensor * attn_norm_2;
-    struct ggml_tensor * attn_norm_2_b;
-    struct ggml_tensor * attn_q_norm;
-    struct ggml_tensor * attn_q_norm_b;
-    struct ggml_tensor * attn_k_norm;
-    struct ggml_tensor * attn_k_norm_b;
-    struct ggml_tensor * attn_out_norm;
-    struct ggml_tensor * attn_out_norm_b;
-    struct ggml_tensor * attn_q_a_norm;
-    struct ggml_tensor * attn_kv_a_norm;
-    struct ggml_tensor * attn_sub_norm;
-    struct ggml_tensor * attn_post_norm;
-    struct ggml_tensor * ffn_sub_norm;
-    struct ggml_tensor * attn_norm_cross;
-    struct ggml_tensor * attn_norm_enc;
+    struct ggml_tensor * attn_norm = nullptr;
+    struct ggml_tensor * attn_norm_b = nullptr;
+    struct ggml_tensor * attn_norm_2 = nullptr;
+    struct ggml_tensor * attn_norm_2_b = nullptr;
+    struct ggml_tensor * attn_q_norm = nullptr;
+    struct ggml_tensor * attn_q_norm_b = nullptr;
+    struct ggml_tensor * attn_k_norm = nullptr;
+    struct ggml_tensor * attn_k_norm_b = nullptr;
+    struct ggml_tensor * attn_out_norm = nullptr;
+    struct ggml_tensor * attn_out_norm_b = nullptr;
+    struct ggml_tensor * attn_q_a_norm = nullptr;
+    struct ggml_tensor * attn_kv_a_norm = nullptr;
+    struct ggml_tensor * attn_sub_norm = nullptr;
+    struct ggml_tensor * attn_post_norm = nullptr;
+    struct ggml_tensor * ffn_sub_norm = nullptr;
+    struct ggml_tensor * attn_norm_cross = nullptr;
+    struct ggml_tensor * attn_norm_enc = nullptr;
 
     // attention
-    struct ggml_tensor * wq;
-    struct ggml_tensor * wk;
-    struct ggml_tensor * wv;
-    struct ggml_tensor * wo;
-    struct ggml_tensor * wqkv;
-    struct ggml_tensor * wq_a;
-    struct ggml_tensor * wq_b;
-    struct ggml_tensor * wkv_a_mqa;
-    struct ggml_tensor * wkv_b;
-    struct ggml_tensor * wk_b;
-    struct ggml_tensor * wv_b;
-    struct ggml_tensor * wq_cross;
-    struct ggml_tensor * wk_cross;
-    struct ggml_tensor * wv_cross;
-    struct ggml_tensor * wo_cross;
-    struct ggml_tensor * wq_enc;
-    struct ggml_tensor * wk_enc;
-    struct ggml_tensor * wv_enc;
-    struct ggml_tensor * wo_enc;
+    struct ggml_tensor * wq = nullptr;
+    struct ggml_tensor * wk = nullptr;
+    struct ggml_tensor * wv = nullptr;
+    struct ggml_tensor * wo = nullptr;
+    struct ggml_tensor * wqkv = nullptr;
+    struct ggml_tensor * wq_a = nullptr;
+    struct ggml_tensor * wq_b = nullptr;
+    struct ggml_tensor * wkv_a_mqa = nullptr;
+    struct ggml_tensor * wkv_b = nullptr;
+    struct ggml_tensor * wk_b = nullptr;
+    struct ggml_tensor * wv_b = nullptr;
+    struct ggml_tensor * wq_cross = nullptr;
+    struct ggml_tensor * wk_cross = nullptr;
+    struct ggml_tensor * wv_cross = nullptr;
+    struct ggml_tensor * wo_cross = nullptr;
+    struct ggml_tensor * wq_enc = nullptr;
+    struct ggml_tensor * wk_enc = nullptr;
+    struct ggml_tensor * wv_enc = nullptr;
+    struct ggml_tensor * wo_enc = nullptr;
+    struct ggml_tensor * attn_sinks = nullptr;
 
     // attention bias
-    struct ggml_tensor * bq;
-    struct ggml_tensor * bk;
-    struct ggml_tensor * bv;
-    struct ggml_tensor * bo;
-    struct ggml_tensor * bqkv;
+    struct ggml_tensor * bq = nullptr;
+    struct ggml_tensor * bk = nullptr;
+    struct ggml_tensor * bv = nullptr;
+    struct ggml_tensor * bo = nullptr;
+    struct ggml_tensor * bqkv = nullptr;
 
     // relative position bias
-    struct ggml_tensor * attn_rel_b;
-    struct ggml_tensor * attn_rel_b_enc;
-    struct ggml_tensor * attn_rel_b_cross;
+    struct ggml_tensor * attn_rel_b = nullptr;
+    struct ggml_tensor * attn_rel_b_enc = nullptr;
+    struct ggml_tensor * attn_rel_b_cross = nullptr;
 
     // normalization
-    struct ggml_tensor * ffn_norm;
-    struct ggml_tensor * ffn_norm_b;
-    struct ggml_tensor * ffn_post_norm;
-    struct ggml_tensor * layer_out_norm;
-    struct ggml_tensor * layer_out_norm_b;
-    struct ggml_tensor * ffn_norm_exps;
-    struct ggml_tensor * ffn_norm_enc;
+    struct ggml_tensor * ffn_norm = nullptr;
+    struct ggml_tensor * ffn_norm_b = nullptr;
+    struct ggml_tensor * ffn_post_norm = nullptr;
+    struct ggml_tensor * layer_out_norm = nullptr;
+    struct ggml_tensor * layer_out_norm_b = nullptr;
+    struct ggml_tensor * ffn_norm_exps = nullptr;
+    struct ggml_tensor * ffn_norm_enc = nullptr;
 
     // ff
-    struct ggml_tensor * ffn_gate; // w1
-    struct ggml_tensor * ffn_down; // w2
-    struct ggml_tensor * ffn_up;   // w3
-    struct ggml_tensor * ffn_gate_enc;
-    struct ggml_tensor * ffn_down_enc;
-    struct ggml_tensor * ffn_up_enc;
+    struct ggml_tensor * ffn_gate = nullptr; // w1
+    struct ggml_tensor * ffn_down = nullptr; // w2
+    struct ggml_tensor * ffn_up = nullptr;   // w3
+    struct ggml_tensor * ffn_gate_enc = nullptr;
+    struct ggml_tensor * ffn_down_enc = nullptr;
+    struct ggml_tensor * ffn_up_enc = nullptr;
 
     // ff MoE
-    struct ggml_tensor * ffn_gate_inp;
-    struct ggml_tensor * ffn_gate_exps;
-    struct ggml_tensor * ffn_down_exps;
-    struct ggml_tensor * ffn_up_exps ;
+    struct ggml_tensor * ffn_gate_inp = nullptr;
+    struct ggml_tensor * ffn_gate_exps = nullptr;
+    struct ggml_tensor * ffn_down_exps = nullptr;
+    struct ggml_tensor * ffn_up_exps  = nullptr;
+
+    // ff MoE bias
+    struct ggml_tensor * ffn_gate_inp_b = nullptr;
+    struct ggml_tensor * ffn_gate_exps_b = nullptr;
+    struct ggml_tensor * ffn_down_exps_b = nullptr;
+    struct ggml_tensor * ffn_up_exps_b = nullptr;
 
     // ff shared expert (shexp)
-    struct ggml_tensor * ffn_gate_inp_shexp;
-    struct ggml_tensor * ffn_gate_shexp;
-    struct ggml_tensor * ffn_down_shexp;
-    struct ggml_tensor * ffn_up_shexp;
+    struct ggml_tensor * ffn_gate_inp_shexp = nullptr;
+    struct ggml_tensor * ffn_gate_shexp = nullptr;
+    struct ggml_tensor * ffn_down_shexp = nullptr;
+    struct ggml_tensor * ffn_up_shexp = nullptr;
 
     // ff bias
     struct ggml_tensor * ffn_gate_b = nullptr;
     struct ggml_tensor * ffn_down_b = nullptr; // b2
     struct ggml_tensor * ffn_up_b   = nullptr; // b3
-    struct ggml_tensor * ffn_act;
+    struct ggml_tensor * ffn_act = nullptr;
     struct ggml_tensor * ffn_exp_probs_b = nullptr;
 
     // mamba proj
-    struct ggml_tensor * ssm_in;
-    struct ggml_tensor * ssm_x;
-    struct ggml_tensor * ssm_dt;
-    struct ggml_tensor * ssm_out;
+    struct ggml_tensor * ssm_in = nullptr;
+    struct ggml_tensor * ssm_x = nullptr;
+    struct ggml_tensor * ssm_dt = nullptr;
+    struct ggml_tensor * ssm_out = nullptr;
 
     // mamba
-    struct ggml_tensor * ssm_conv1d;
-    struct ggml_tensor * ssm_a;
-    struct ggml_tensor * ssm_d;
+    struct ggml_tensor * ssm_conv1d = nullptr;
+    struct ggml_tensor * ssm_a = nullptr;
+    struct ggml_tensor * ssm_d = nullptr;
 
     // mamba bias
-    struct ggml_tensor * ssm_conv1d_b;
-    struct ggml_tensor * ssm_dt_b;
+    struct ggml_tensor * ssm_conv1d_b = nullptr;
+    struct ggml_tensor * ssm_dt_b = nullptr;
 
     // long rope factors
     struct ggml_tensor * rope_long  = nullptr;
@@ -3090,13 +2208,13 @@ struct llama_layer {
     struct ggml_tensor * rope_freqs = nullptr;
 
     // bitnet scale
-    struct ggml_tensor * wq_scale;
-    struct ggml_tensor * wk_scale;
-    struct ggml_tensor * wv_scale;
-    struct ggml_tensor * wo_scale;
-    struct ggml_tensor * ffn_gate_scale;
-    struct ggml_tensor * ffn_up_scale;
-    struct ggml_tensor * ffn_down_scale;
+    struct ggml_tensor * wq_scale = nullptr;
+    struct ggml_tensor * wk_scale = nullptr;
+    struct ggml_tensor * wv_scale = nullptr;
+    struct ggml_tensor * wo_scale = nullptr;
+    struct ggml_tensor * ffn_gate_scale = nullptr;
+    struct ggml_tensor * ffn_up_scale = nullptr;
+    struct ggml_tensor * ffn_down_scale = nullptr;
 
     struct llama_layer_nextn nextn;
 
@@ -4096,1158 +3214,15 @@ static uint32_t llama_kv_cache_get_padding(const struct llama_cparams & cparams)
 // model loading and saving
 //
 
-enum llama_fver {
-    GGUF_FILE_VERSION_V1 = 1,
-    GGUF_FILE_VERSION_V2 = 2,
-    GGUF_FILE_VERSION_V3 = 3,
-};
-
-static const char * llama_file_version_name(llama_fver version) {
-    switch (version) {
-        case GGUF_FILE_VERSION_V1: return "GGUF V1 (support until nov 2023)";
-        case GGUF_FILE_VERSION_V2: return "GGUF V2";
-        case GGUF_FILE_VERSION_V3: return "GGUF V3 (latest)";
-    }
-
-    return "unknown";
-}
-
-static std::string llama_format_tensor_shape(const std::vector<int64_t> & ne) {
-    char buf[256];
-    snprintf(buf, sizeof(buf), "%5" PRId64, ne.at(0));
-    for (size_t i = 1; i < ne.size(); i++) {
-        snprintf(buf + strlen(buf), sizeof(buf) - strlen(buf), ", %5" PRId64, ne.at(i));
-    }
-    return buf;
-}
-
-static std::string llama_format_tensor_shape(const struct ggml_tensor * t) {
-    char buf[256];
-    snprintf(buf, sizeof(buf), "%5" PRId64, t->ne[0]);
-    for (int i = 1; i < GGML_MAX_DIMS; i++) {
-        snprintf(buf + strlen(buf), sizeof(buf) - strlen(buf), ", %5" PRId64, t->ne[i]);
-    }
-    return buf;
-}
-
-namespace GGUFMeta {
-    template <typename T, gguf_type gt_, T (*gfun)(const gguf_context *, const int)>
-    struct GKV_Base_Type {
-        static constexpr gguf_type gt = gt_;
-
-        static T getter(const gguf_context * ctx, const int kid) {
-            return gfun(ctx, kid);
-        }
-    };
-
-    template<typename T> struct GKV_Base;
-
-    template<> struct GKV_Base<bool        >: GKV_Base_Type<bool,         GGUF_TYPE_BOOL,    gguf_get_val_bool> {};
-    template<> struct GKV_Base<uint8_t     >: GKV_Base_Type<uint8_t,      GGUF_TYPE_UINT8,   gguf_get_val_u8  > {};
-    template<> struct GKV_Base<uint16_t    >: GKV_Base_Type<uint16_t,     GGUF_TYPE_UINT16,  gguf_get_val_u16 > {};
-    template<> struct GKV_Base<uint32_t    >: GKV_Base_Type<uint32_t,     GGUF_TYPE_UINT32,  gguf_get_val_u32 > {};
-    template<> struct GKV_Base<uint64_t    >: GKV_Base_Type<uint64_t,     GGUF_TYPE_UINT64,  gguf_get_val_u64 > {};
-    template<> struct GKV_Base<int8_t      >: GKV_Base_Type<int8_t,       GGUF_TYPE_INT8,    gguf_get_val_i8  > {};
-    template<> struct GKV_Base<int16_t     >: GKV_Base_Type<int16_t,      GGUF_TYPE_INT16,   gguf_get_val_i16 > {};
-    template<> struct GKV_Base<int32_t     >: GKV_Base_Type<int32_t,      GGUF_TYPE_INT32,   gguf_get_val_i32 > {};
-    template<> struct GKV_Base<int64_t     >: GKV_Base_Type<int64_t,      GGUF_TYPE_INT64,   gguf_get_val_i64 > {};
-    template<> struct GKV_Base<float       >: GKV_Base_Type<float,        GGUF_TYPE_FLOAT32, gguf_get_val_f32 > {};
-    template<> struct GKV_Base<double      >: GKV_Base_Type<double,       GGUF_TYPE_FLOAT64, gguf_get_val_f64 > {};
-    template<> struct GKV_Base<const char *>: GKV_Base_Type<const char *, GGUF_TYPE_STRING,  gguf_get_val_str > {};
-
-    template<> struct GKV_Base<std::string> {
-        static constexpr gguf_type gt = GGUF_TYPE_STRING;
-
-        static std::string getter(const gguf_context * ctx, const int kid) {
-            return gguf_get_val_str(ctx, kid);
-        }
-    };
-
-    struct ArrayInfo {
-        const gguf_type gt;
-        const size_t length;
-        const void * data;
-    };
-
-    template<> struct GKV_Base<ArrayInfo> {
-        public:
-        static constexpr gguf_type gt = GGUF_TYPE_ARRAY;
-        static ArrayInfo getter(const gguf_context *ctx, const int k) {
-            return ArrayInfo {
-                gguf_get_arr_type(ctx, k),
-                size_t(gguf_get_arr_n(ctx, k)),
-                gguf_get_arr_data(ctx, k),
-            };
-        }
-    };
-
-    template<typename T>
-    class GKV : public GKV_Base<T> {
-        GKV() = delete;
-
-        public:
-        static T get_kv(const gguf_context * ctx, const int k) {
-            const enum gguf_type kt = gguf_get_kv_type(ctx, k);
-
-            if (kt != GKV::gt) {
-                throw std::runtime_error(format("key %s has wrong type %s but expected type %s",
-                    gguf_get_key(ctx, k), gguf_type_name(kt), gguf_type_name(GKV::gt)));
-            }
-            return GKV::getter(ctx, k);
-        }
-
-        static const char * override_type_to_str(const llama_model_kv_override_type ty) {
-            switch (ty) {
-                case LLAMA_KV_OVERRIDE_TYPE_BOOL:  return "bool";
-                case LLAMA_KV_OVERRIDE_TYPE_INT:   return "int";
-                case LLAMA_KV_OVERRIDE_TYPE_FLOAT: return "float";
-                case LLAMA_KV_OVERRIDE_TYPE_STR:   return "str";
-            }
-            return "unknown";
-        }
-
-        static bool validate_override(const llama_model_kv_override_type expected_type, const struct llama_model_kv_override * ovrd) {
-            if (!ovrd) { return false; }
-            if (ovrd->tag == expected_type) {
-                LLAMA_LOG_INFO("%s: Using metadata override (%5s) '%s' = ",
-                    __func__, override_type_to_str(ovrd->tag), ovrd->key);
-                switch (ovrd->tag) {
-                    case LLAMA_KV_OVERRIDE_TYPE_BOOL:  {
-                        LLAMA_LOG_INFO("%s\n", ovrd->val_bool ? "true" : "false");
-                    } break;
-                    case LLAMA_KV_OVERRIDE_TYPE_INT:   {
-                        LLAMA_LOG_INFO("%" PRId64 "\n", ovrd->val_i64);
-                    } break;
-                    case LLAMA_KV_OVERRIDE_TYPE_FLOAT: {
-                        LLAMA_LOG_INFO("%.6f\n", ovrd->val_f64);
-                    } break;
-                    case LLAMA_KV_OVERRIDE_TYPE_STR: {
-                        LLAMA_LOG_INFO("%s\n", ovrd->val_str);
-                    } break;
-                    default:
-                        // Shouldn't be possible to end up here, but just in case...
-                        throw std::runtime_error(
-                            format("Unsupported attempt to override %s type for metadata key %s\n",
-                                override_type_to_str(ovrd->tag), ovrd->key));
-                }
-                return true;
-            }
-            LLAMA_LOG_WARN("%s: Warning: Bad metadata override type for key '%s', expected %s but got %s\n",
-                __func__, ovrd->key, override_type_to_str(expected_type), override_type_to_str(ovrd->tag));
-            return false;
-        }
-
-        template<typename OT>
-        static typename std::enable_if<std::is_same<OT, bool>::value, bool>::type
-        try_override(OT & target, const struct llama_model_kv_override * ovrd) {
-            if (validate_override(LLAMA_KV_OVERRIDE_TYPE_BOOL, ovrd)) {
-                target = ovrd->val_bool;
-                return true;
-            }
-            return false;
-        }
-
-        template<typename OT>
-        static typename std::enable_if<!std::is_same<OT, bool>::value && std::is_integral<OT>::value, bool>::type
-        try_override(OT & target, const struct llama_model_kv_override * ovrd) {
-            if (validate_override(LLAMA_KV_OVERRIDE_TYPE_INT, ovrd)) {
-                target = ovrd->val_i64;
-                return true;
-            }
-            return false;
-        }
-
-        template<typename OT>
-        static typename std::enable_if<std::is_floating_point<OT>::value, bool>::type
-        try_override(T & target, const struct llama_model_kv_override * ovrd) {
-            if (validate_override(LLAMA_KV_OVERRIDE_TYPE_FLOAT, ovrd)) {
-                target = ovrd->val_f64;
-                return true;
-            }
-            return false;
-        }
-
-        template<typename OT>
-        static typename std::enable_if<std::is_same<OT, std::string>::value, bool>::type
-        try_override(T & target, const struct llama_model_kv_override * ovrd) {
-            if (validate_override(LLAMA_KV_OVERRIDE_TYPE_STR, ovrd)) {
-                target = ovrd->val_str;
-                return true;
-            }
-            return false;
-        }
-
-        static bool set(const gguf_context * ctx, const int k, T & target, const struct llama_model_kv_override * ovrd = nullptr) {
-            if (try_override(target, ovrd)) {
-                return true;
-            }
-            if (k < 0) { return false; }
-            target = get_kv(ctx, k);
-            return true;
-        }
-
-        static bool set(const gguf_context * ctx, const char * key, T & target, const struct llama_model_kv_override * ovrd = nullptr) {
-            return set(ctx, gguf_find_key(ctx, key), target, ovrd);
-        }
-
-        static bool set(const gguf_context * ctx, const std::string & key, T & target, const struct llama_model_kv_override * ovrd = nullptr) {
-            return set(ctx, key.c_str(), target, ovrd);
-        }
-    };
-}
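The `GKV` helper above is what the loader's `get_key` delegates to further down: it checks the stored GGUF type against the requested C++ type and consults any metadata override first. A hedged sketch of direct use (the key name and the `meta` context are illustrative assumptions):

```cpp
#include <cstdint>

// Sketch only: assumes a valid gguf_context * meta obtained from gguf_init_from_file().
static uint32_t sketch_read_context_length(const gguf_context * meta) {
    uint32_t n_ctx_train = 0;
    // Throws if "llama.context_length" exists with a non-uint32 type; leaves the
    // default and returns false if the key is absent and no KV override matches.
    GGUFMeta::GKV<uint32_t>::set(meta, "llama.context_length", n_ctx_train);
    return n_ctx_train;
}
```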
-
-using llama_buf_map = std::unordered_map<uint32_t, ggml_backend_buffer_t>;
-
-// TODO: update when needed or think of some clever automatic way to do this
-static size_t llama_model_max_nodes(const llama_model & /*model*/) {
-    //if (model.arch == LLM_ARCH_LLAMA && model.hparams.n_layer > ??) { // llama-3 405B
-    //    return 32768;
-    //}
-
-    return 65536;
-}
-
-struct llama_model_loader {
-    int n_kv      = 0;
-    int n_tensors = 0;
-    int n_created = 0;
-
-    int64_t n_elements = 0;
-    size_t  n_bytes    = 0;
-
-    bool use_mmap = false;
-    bool check_tensors;
-    bool repack_tensors = false;
-    bool use_thp = false;
-
-    llama_files files;
-    llama_ftype ftype;
-    llama_fver  fver;
-
-    llama_mmaps mappings;
-
-    // Holds information on a model weight
-    struct llama_tensor_weight {
-        uint16_t  idx; // source file index
-        size_t   offs; // tensor data offset in the original file
-
-        ggml_tensor * tensor;
-
-        llama_tensor_weight(const llama_file * file, uint16_t idx, const char * name, const struct gguf_context * gguf_ctx, ggml_tensor * tensor) : idx(idx), tensor(tensor) {
-            const int tensor_idx = gguf_find_tensor(gguf_ctx, name);
-            offs = gguf_get_data_offset(gguf_ctx) + gguf_get_tensor_offset(gguf_ctx, tensor_idx);
-
-            if (offs + ggml_nbytes(tensor) < offs || offs + ggml_nbytes(tensor) > file->size) {
-                throw std::runtime_error(format("tensor '%s' data is not within the file bounds, model is corrupted or incomplete", name));
-            }
-        }
-    };
-    std::vector<llama_tensor_weight> weights;
-
-    std::unordered_map<std::string, struct llama_model_kv_override> kv_overrides;
-    const llama_model_tensor_buft_override * tensor_buft_overrides;
-
-    struct gguf_context * meta = NULL;
-    std::vector<ggml_context *> contexts;
-
-    std::string arch_name;
-    LLM_KV      llm_kv    = LLM_KV(LLM_ARCH_UNKNOWN);
-
-    llama_model_loader(const std::string & fname, bool use_mmap, bool check_tensors, bool repack_tensors, bool use_thp,
-            const llama_model_kv_override * param_overrides_p,
-            const llama_model_tensor_buft_override * param_tensor_buft_overrides_p) {
-        int trace = 0;
-        if (getenv("LLAMA_TRACE")) {
-            trace = atoi(getenv("LLAMA_TRACE"));
-        }
-
-        #ifdef _WIN32
-            // Only bump maxstdio if the user really wants large contexts:
-            #if defined(GGML_MAX_CONTEXTS) && (GGML_MAX_CONTEXTS > 512)
-                // Cap at MSVC's hard limit of 8192 - https://learn.microsoft.com/en-us/cpp/c-runtime-library/reference/setmaxstdio?view=msvc-160
-                #if (GGML_MAX_CONTEXTS > 8192)
-                    #define _GGML_STDIO_TARGET 8192
-                #else
-                    #define _GGML_STDIO_TARGET GGML_MAX_CONTEXTS
-                #endif
-                int _setmaxstdio_ret = _setmaxstdio(_GGML_STDIO_TARGET);
-                if (_setmaxstdio_ret == -1) {
-                    LLAMA_LOG_INFO("%s: failed to set max stdio to %d. (setmaxstdio returned -1)\n", __func__, _GGML_STDIO_TARGET);
-                } else {
-                    LLAMA_LOG_INFO("%s: max stdio successfully set to %d\n", __func__, _setmaxstdio_ret);
-                }
-            #endif // GGML_MAX_CONTEXTS > 512
-        #endif // _WIN32
-
-        if (param_overrides_p != nullptr) {
-            for (const struct llama_model_kv_override * p = param_overrides_p; p->key[0] != 0; p++) {
-                kv_overrides.insert({std::string(p->key), *p});
-            }
-        }
-
-        tensor_buft_overrides = param_tensor_buft_overrides_p;
-
-        struct ggml_context * ctx = NULL;
-        struct gguf_init_params params = {
-            /*.no_alloc = */ true,
-            /*.ctx      = */ &ctx,
-        };
-
-        meta = gguf_init_from_file(fname.c_str(), params);
-        if (!meta) {
-            throw std::runtime_error(format("%s: failed to load model from %s\n", __func__, fname.c_str()));
-        }
-
-        get_key(llm_kv(LLM_KV_GENERAL_ARCHITECTURE), arch_name, false);
-        llm_kv = LLM_KV(llm_arch_from_string(arch_name));
-
-        files.emplace_back(new llama_file(fname.c_str(), "rb"));
-        contexts.emplace_back(ctx);
-
-        // Save tensors data offset of the main file.
-        // For subsidiary files, `meta` tensor data offset must not be used,
-        // so we build a unified tensors index for weights.
-        for (ggml_tensor * cur = ggml_get_first_tensor(ctx); cur; cur = ggml_get_next_tensor(ctx, cur)) {
-            weights.emplace_back(files.back().get(), 0, cur->name, meta, cur);
-        }
-        uint16_t n_split = 0;
-        get_key(llm_kv(LLM_KV_SPLIT_COUNT), n_split, false);
-
-        // Load additional GGML contexts
-        if (n_split > 1) {
-            uint16_t idx = 0;
-            get_key(llm_kv(LLM_KV_SPLIT_NO), idx);
-            if (idx != 0) {
-                throw std::runtime_error(format("illegal split file: %d, model must be loaded with the first split", idx));
-            }
-
-            char split_prefix[PATH_MAX] = {0};
-            if (!llama_split_prefix(split_prefix, sizeof(split_prefix), fname.c_str(), idx, n_split)) {
-                throw std::runtime_error(format("invalid split file: %s", fname.c_str()));
-            }
-
-            if (trace > 0) {
-                LLAMA_LOG_INFO("%s: loading additional %d GGUFs\n", __func__, n_split);
-            }
-
-            char split_path[PATH_MAX] = {0};
-            for (idx = 1; idx < n_split; idx++) {
-                llama_split_path(split_path, sizeof(split_path), split_prefix, idx, n_split);
-
-                struct gguf_init_params split_params = {
-                    /*.no_alloc = */ true,
-                    /*.ctx      = */ &ctx,
-                };
-                struct gguf_context * ctx_gguf = gguf_init_from_file(split_path, split_params);
-                if (!ctx_gguf) {
-                    throw std::runtime_error(format("%s: failed to load GGUF split from %s\n", __func__, split_path));
-                }
-
-                files.emplace_back(new llama_file(split_path, "rb"));
-                contexts.emplace_back(ctx);
-
-                // Save tensors data offset info of the shard.
-                for (ggml_tensor * cur = ggml_get_first_tensor(ctx); cur; cur = ggml_get_next_tensor(ctx, cur)) {
-                    weights.emplace_back(files.back().get(), idx, cur->name, ctx_gguf, cur);
-                }
-
-                gguf_free(ctx_gguf);
-            }
-
-            get_key(llm_kv(LLM_KV_SPLIT_TENSORS_COUNT), n_tensors);
-
-            // sanity check
-            {
-                const int n_tensors_loaded = (int) weights.size();
-                if (n_tensors != n_tensors_loaded) {
-                    throw std::runtime_error(format("corrupted model: %d tensors expected but %d found", n_tensors, n_tensors_loaded));
-                }
-            }
-
-            LLAMA_LOG_INFO("%s: additional %d GGUFs metadata loaded.\n",  __func__, n_split - 1);
-        }
-
-        n_kv      = gguf_get_n_kv(meta);
-        n_tensors = weights.size();
-
-        fver = (enum llama_fver) gguf_get_version(meta);
-
-        std::set<std::string> tensor_names;
-        for (auto & w : weights) {
-            n_elements += ggml_nelements(w.tensor);
-            n_bytes    += ggml_nbytes(w.tensor);
-            // make sure there is no duplicated tensor names
-            const std::string name(w.tensor->name);
-            auto found = tensor_names.find(name);
-            if (found != tensor_names.end()) {
-                throw std::runtime_error(format("invalid model: tensor '%s' is duplicated", w.tensor->name));
-            }
-            tensor_names.insert(name);
-        }
-
-        LLAMA_LOG_INFO("%s: loaded meta data with %d key-value pairs and %d tensors from %s (version %s)\n",
-                __func__, n_kv, n_tensors, fname.c_str(), llama_file_version_name(fver));
-
-        // determine file type based on the number of tensors for each quantization and print meta data
-        // TODO: make optional
-        {
-            std::map<enum ggml_type, uint32_t> n_type;
-
-            uint32_t n_type_max = 0;
-            enum ggml_type type_max = GGML_TYPE_F32;
-
-            for (int i = 0; i < n_tensors; i++) {
-                const ggml_tensor * tensor = weights.at(i).tensor;
-                enum ggml_type type = tensor->type;
-
-                n_type[type]++;
-
-                if (n_type_max < n_type[type]) {
-                    n_type_max = n_type[type];
-                    type_max   = type;
-                }
-
-                if (trace > 0) {
-                    const uint16_t sid = weights.at(i).idx;
-                    LLAMA_LOG_INFO("%s: - tensor %4d, split %2d: %32s %-8s [ %s ]\n", __func__, i, sid, ggml_get_name(tensor), ggml_type_name(type), llama_format_tensor_shape(tensor).c_str());
-                }
-            }
-
-            switch (type_max) {
-                case GGML_TYPE_F32:     ftype = LLAMA_FTYPE_ALL_F32;        break;
-                case GGML_TYPE_F16:     ftype = LLAMA_FTYPE_MOSTLY_F16;     break;
-                case GGML_TYPE_BF16:    ftype = LLAMA_FTYPE_MOSTLY_BF16;    break;
-                case GGML_TYPE_BF16_R16:ftype = LLAMA_FTYPE_MOSTLY_BF16_R16;break;
-                case GGML_TYPE_Q4_0:    ftype = LLAMA_FTYPE_MOSTLY_Q4_0;    break;
-                case GGML_TYPE_Q4_1:    ftype = LLAMA_FTYPE_MOSTLY_Q4_1;    break;
-                case GGML_TYPE_Q5_0:    ftype = LLAMA_FTYPE_MOSTLY_Q5_0;    break;
-                case GGML_TYPE_Q5_1:    ftype = LLAMA_FTYPE_MOSTLY_Q5_1;    break;
-                case GGML_TYPE_Q6_0:    ftype = LLAMA_FTYPE_MOSTLY_Q6_0;    break;
-                case GGML_TYPE_Q8_0:    ftype = LLAMA_FTYPE_MOSTLY_Q8_0;    break;
-                case GGML_TYPE_Q8_KV:   ftype = LLAMA_FTYPE_MOSTLY_Q8_KV;   break;
-                case GGML_TYPE_Q2_K:    ftype = LLAMA_FTYPE_MOSTLY_Q2_K;    break;
-                case GGML_TYPE_Q3_K:    ftype = LLAMA_FTYPE_MOSTLY_Q3_K_M;  break;
-                case GGML_TYPE_Q3_K_R4: ftype = LLAMA_FTYPE_MOSTLY_Q3_K_R4; break;
-                case GGML_TYPE_Q4_K:    ftype = LLAMA_FTYPE_MOSTLY_Q4_K_M;  break;
-                case GGML_TYPE_Q4_K_R4: ftype = LLAMA_FTYPE_MOSTLY_Q4_K_R4; break;
-                case GGML_TYPE_Q5_K:    ftype = LLAMA_FTYPE_MOSTLY_Q5_K_M;  break;
-                case GGML_TYPE_Q5_K_R4: ftype = LLAMA_FTYPE_MOSTLY_Q5_K_R4; break;
-                case GGML_TYPE_Q6_K:    ftype = LLAMA_FTYPE_MOSTLY_Q6_K;    break;
-                case GGML_TYPE_Q6_K_R4: ftype = LLAMA_FTYPE_MOSTLY_Q6_K_R4; break;
-                case GGML_TYPE_Q8_K_R8: ftype = LLAMA_FTYPE_MOSTLY_Q8_K_R8; break;
-                case GGML_TYPE_Q8_KV_R8: ftype = LLAMA_FTYPE_MOSTLY_Q8_KV_R8; break;
-                case GGML_TYPE_IQ2_XXS: ftype = LLAMA_FTYPE_MOSTLY_IQ2_XXS; break;
-                case GGML_TYPE_IQ2_XXS_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ2_XXS_R4; break;
-                case GGML_TYPE_IQ2_XS:  ftype = LLAMA_FTYPE_MOSTLY_IQ2_XS;  break;
-                case GGML_TYPE_IQ2_XS_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ2_XS_R4; break;
-                case GGML_TYPE_IQ2_KS:  ftype = LLAMA_FTYPE_MOSTLY_IQ2_KS;  break;
-                case GGML_TYPE_IQ2_S:   ftype = LLAMA_FTYPE_MOSTLY_IQ2_M;   break;
-                case GGML_TYPE_IQ2_S_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ2_M_R4;break;
-                case GGML_TYPE_IQ3_XXS: ftype = LLAMA_FTYPE_MOSTLY_IQ3_XXS; break;
-                case GGML_TYPE_IQ3_XXS_R4: ftype = LLAMA_FTYPE_MOSTLY_IQ3_XXS_R4; break;
-                case GGML_TYPE_IQ1_KT:  ftype = LLAMA_FTYPE_MOSTLY_IQ1_KT;  break;
-                case GGML_TYPE_IQ2_KT:  ftype = LLAMA_FTYPE_MOSTLY_IQ2_KT;  break;
-                case GGML_TYPE_IQ3_KT:  ftype = LLAMA_FTYPE_MOSTLY_IQ3_KT;  break;
-                case GGML_TYPE_IQ4_KT:  ftype = LLAMA_FTYPE_MOSTLY_IQ4_KT;  break;
-                case GGML_TYPE_IQ1_S:   ftype = LLAMA_FTYPE_MOSTLY_IQ1_S;   break;
-                case GGML_TYPE_IQ1_S_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ1_S_R4;break;
-                case GGML_TYPE_IQ1_M_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ1_M_R4;break;
-                case GGML_TYPE_IQ1_M:   ftype = LLAMA_FTYPE_MOSTLY_IQ1_M;   break;
-                case GGML_TYPE_IQ1_BN:  ftype = LLAMA_FTYPE_MOSTLY_IQ1_BN;  break;
-                case GGML_TYPE_IQ2_BN:  ftype = LLAMA_FTYPE_MOSTLY_IQ2_BN;  break;
-                case GGML_TYPE_IQ2_BN_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ2_BN_R4;break;
-                case GGML_TYPE_IQ4_NL:  ftype = LLAMA_FTYPE_MOSTLY_IQ4_NL;  break;
-                case GGML_TYPE_IQ4_NL_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ4_NL_R4;break;
-                case GGML_TYPE_IQ4_XS_R8:ftype = LLAMA_FTYPE_MOSTLY_IQ4_XS_R8;break;
-                case GGML_TYPE_Q4_0_R8: ftype = LLAMA_FTYPE_MOSTLY_Q4_0_R8; break;
-                case GGML_TYPE_Q5_0_R4: ftype = LLAMA_FTYPE_MOSTLY_Q5_0_R4; break;
-                case GGML_TYPE_Q6_0_R4: ftype = LLAMA_FTYPE_MOSTLY_Q6_0_R4; break;
-                case GGML_TYPE_Q8_0_R8: ftype = LLAMA_FTYPE_MOSTLY_Q8_0_R8; break;
-                case GGML_TYPE_MXFP4:   ftype = LLAMA_FTYPE_MOSTLY_MXFP4;   break;
-                case GGML_TYPE_IQ4_XS:  ftype = LLAMA_FTYPE_MOSTLY_IQ4_XS;  break;
-                case GGML_TYPE_IQ4_KS:  ftype = LLAMA_FTYPE_MOSTLY_IQ4_KS;  break;
-                case GGML_TYPE_IQ4_KS_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ4_KS_R4;  break;
-                case GGML_TYPE_IQ5_KS_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ5_KS_R4;  break;
-                case GGML_TYPE_IQ4_KSS: ftype = LLAMA_FTYPE_MOSTLY_IQ4_KSS; break;
-                case GGML_TYPE_IQ5_KS:  ftype = LLAMA_FTYPE_MOSTLY_IQ5_KS;  break;
-                case GGML_TYPE_IQ2_K:   ftype = LLAMA_FTYPE_MOSTLY_IQ2_K;   break;
-                case GGML_TYPE_IQ2_K_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ2_K_R4;break;
-                case GGML_TYPE_IQ3_KS:  ftype = LLAMA_FTYPE_MOSTLY_IQ3_KS;  break;
-                case GGML_TYPE_IQ2_KL:  ftype = LLAMA_FTYPE_MOSTLY_IQ2_KL;  break;
-                case GGML_TYPE_IQ3_K:   ftype = LLAMA_FTYPE_MOSTLY_IQ3_K;   break;
-                case GGML_TYPE_IQ3_K_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ3_K_R4;break;
-                case GGML_TYPE_IQ4_K:   ftype = LLAMA_FTYPE_MOSTLY_IQ4_K;   break;
-                case GGML_TYPE_IQ4_K_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ4_K_R4;break;
-                case GGML_TYPE_IQ5_K:   ftype = LLAMA_FTYPE_MOSTLY_IQ5_K;   break;
-                case GGML_TYPE_IQ5_K_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ5_K_R4;break;
-                case GGML_TYPE_IQ6_K:   ftype = LLAMA_FTYPE_MOSTLY_IQ6_K;   break;
-                case GGML_TYPE_IQ3_S:   ftype = LLAMA_FTYPE_MOSTLY_IQ3_S;   break;
-                case GGML_TYPE_IQ3_S_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ3_S_R4;break;
-                case GGML_TYPE_Q4_0_4_4: ftype = LLAMA_FTYPE_MOSTLY_Q4_0_4_4; break;
-                case GGML_TYPE_Q4_0_4_8: ftype = LLAMA_FTYPE_MOSTLY_Q4_0_4_8; break;
-                case GGML_TYPE_Q4_0_8_8: ftype = LLAMA_FTYPE_MOSTLY_Q4_0_8_8; break;
-                default:
-                    {
-                        LLAMA_LOG_WARN("%s: unknown type %s\n", __func__, ggml_type_name(type_max));
-                        ftype = LLAMA_FTYPE_ALL_F32;
-                    } break;
-            }
-
-            // this is a way to mark that we have "guessed" the file type
-            ftype = (llama_ftype) (ftype | LLAMA_FTYPE_GUESSED);
-
-            {
-                const int kid = gguf_find_key(meta, "general.file_type"); // TODO: use LLM_KV
-                if (kid >= 0) {
-                    ftype = (llama_ftype) gguf_get_val_u32(meta, kid);
-                }
-            }
-
-            LLAMA_LOG_INFO("%s: Dumping metadata keys/values. Note: KV overrides do not apply in this output.\n", __func__);
-
-            for (int i = 0; i < n_kv; i++) {
-                const char * name           = gguf_get_key(meta, i);
-                const enum gguf_type type   = gguf_get_kv_type(meta, i);
-                const std::string type_name =
-                    type == GGUF_TYPE_ARRAY
-                    ? format("%s[%s,%d]", gguf_type_name(type), gguf_type_name(gguf_get_arr_type(meta, i)), gguf_get_arr_n(meta, i))
-                    : gguf_type_name(type);
-
-                std::string value          = gguf_kv_to_str(meta, i);
-                const size_t MAX_VALUE_LEN = 40;
-                if (value.size() > MAX_VALUE_LEN) {
-                    value = format("%s...", value.substr(0, MAX_VALUE_LEN - 3).c_str());
-                }
-                replace_all(value, "\n", "\\n");
-
-                LLAMA_LOG_INFO("%s: - kv %3d: %42s %-16s = %s\n", __func__, i, name, type_name.c_str(), value.c_str());
-            }
-
-            // print type counts
-            for (auto & kv : n_type) {
-                if (kv.second == 0) {
-                    continue;
-                }
-
-                LLAMA_LOG_INFO("%s: - type %4s: %4d tensors\n", __func__, ggml_type_name(kv.first), kv.second);
-            }
-        }
-
-        if (!llama_mmap::SUPPORTED) {
-            LLAMA_LOG_WARN("%s: mmap is not supported on this platform\n", __func__);
-            use_mmap = false;
-        }
-        if (repack_tensors) {
-            use_mmap = false;
-        }
-
-        this->use_mmap = use_mmap;
-        this->check_tensors = check_tensors;
-        this->repack_tensors = repack_tensors;
-        this->use_thp = use_thp;
-    }
-
-    ~llama_model_loader() {
-        if (meta) {
-            gguf_free(meta);
-        }
-        for (auto * ctx : contexts) {
-            ggml_free(ctx);
-        }
-    }
-
-    template<typename T>
-    typename std::enable_if<std::is_integral<T>::value, bool>::type
-    get_arr_n(const std::string & key, T & result, const bool required = true) {
-        const int kid = gguf_find_key(meta, key.c_str());
-
-        if (kid < 0) {
-            if (required) {
-                throw std::runtime_error(format("key not found in model: %s", key.c_str()));
-            }
-            return false;
-        }
-
-        struct GGUFMeta::ArrayInfo arr_info =
-            GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(meta, kid);
-
-
-        result = arr_info.length;
-        return true;
-    }
-
-    template<typename T>
-    typename std::enable_if<std::is_integral<T>::value, bool>::type
-    get_arr_n(const enum llm_kv kid, T & result, const bool required = true) {
-        return get_arr_n(llm_kv(kid), result, required);
-    }
-
-    template<typename T>
-    bool get_arr(const std::string & key, std::vector<T> & result, const bool required = true) {
-        const int kid = gguf_find_key(meta, key.c_str());
-
-        if (kid < 0 || gguf_get_kv_type(meta, kid) != GGUF_TYPE_ARRAY) {
-            if (required) {
-                throw std::runtime_error(format("array key not found in model: %s", key.c_str()));
-            }
-            return false;
-        }
-
-        struct GGUFMeta::ArrayInfo arr_info =
-            GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(meta, kid);
-
-        switch (arr_info.gt) {
-            case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T, float>::value)); break;
-            case GGUF_TYPE_INT32:   GGML_ASSERT(
-                                            (std::is_same<T,  int32_t>::value) ||
-                                            (std::is_same<T, uint32_t>::value));  break;
-            default:
-                throw std::runtime_error(format("%s is not a float32, int32 array", key.c_str()));
-        }
-
-        result.resize(arr_info.length);
-        result.assign((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length);
-
-        return true;
-    }
-
-    template<typename T, size_t N_MAX>
-    bool get_arr(const std::string & key, std::array<T, N_MAX> & result, const bool required = true) {
-        const int kid = gguf_find_key(meta, key.c_str());
-
-        if (kid < 0 || gguf_get_kv_type(meta, kid) != GGUF_TYPE_ARRAY) {
-            if (required) {
-                throw std::runtime_error(format("array key not found in model: %s", key.c_str()));
-            }
-            return false;
-        }
-
-        struct GGUFMeta::ArrayInfo arr_info =
-            GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(meta, kid);
-
-        switch (arr_info.gt) {
-            case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T, float>::value)); break;
-            case GGUF_TYPE_INT32:   GGML_ASSERT(
-                                            (std::is_same<T,  int32_t>::value) ||
-                                            (std::is_same<T, uint32_t>::value));  break;
-            default:
-                throw std::runtime_error(format("%s is not a float32, int32 array", key.c_str()));
-        }
-
-        if (arr_info.length > N_MAX) {
-            throw std::runtime_error(format("array length %u for key %s exceeds max %u", (uint32_t) arr_info.length, key.c_str(), (uint32_t) N_MAX));
-        }
-
-        std::copy((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length, result.begin());
-
-        return true;
-    }
-
-    template<typename T>
-    bool get_arr(const enum llm_kv kid, T & result, const bool required = true) {
-        return get_arr(llm_kv(kid), result, required);
-    }
-
-    template<typename T>
-    bool get_key(const std::string & key, T & result, const bool required = true) {
-        auto it = kv_overrides.find(key);
-
-        const struct llama_model_kv_override * override =
-            it != kv_overrides.end() ? &it->second : nullptr;
-
-        const bool found = GGUFMeta::GKV<T>::set(meta, key, result, override);
-
-        if (required && !found) {
-            throw std::runtime_error(format("key not found in model: %s", key.c_str()));
-        }
-
-        return found;
-    }
-
-    template<typename T>
-    bool get_key(const enum llm_kv kid, T & result, const bool required = true) {
-        return get_key(llm_kv(kid), result, required);
-    }
-
-    // get array of n <= N_MAX elements, or a single element repeated n times
-    template<typename T, size_t N_MAX>
-    bool get_key_or_arr(const std::string & key, std::array<T, N_MAX> & result, uint32_t n, const bool required = true) {
-        const int kid = gguf_find_key(meta, key.c_str());
-
-        if (kid < 0) {
-            if (required) {
-                throw std::runtime_error(format("key not found in model: %s", key.c_str()));
-            }
-            return false;
-        }
-
-        if (n > N_MAX) {
-            throw std::runtime_error(format("n > N_MAX: %u > %u for key %s", (uint32_t) n, (uint32_t) N_MAX, key.c_str()));
-        }
-
-        if (gguf_get_kv_type(meta, kid) == GGUF_TYPE_ARRAY) {
-            struct GGUFMeta::ArrayInfo arr_info =
-                GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(meta, kid);
-
-            if (n != arr_info.length) {
-                throw std::runtime_error(format("key %s has wrong array length; expected %u, got %u", key.c_str(), n, (uint32_t) arr_info.length));
-            }
-
-            return get_arr(key, result, required);
-        } else {
-            T value;
-
-            bool ok = get_key(key, value, required);
-            if (!ok) {
-                return false;
-            }
-
-            for (uint32_t i = 0; i < n; i++) {
-                result[i] = value;
-            }
-
-            return true;
-        }
-    }
-
-    template<typename T>
-    bool get_key_or_arr(const enum llm_kv kid, T & result, uint32_t n, const bool required = true) {
-        return get_key_or_arr(llm_kv(kid), result, n, required);
-    }
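    // [editorial aside - not part of the patch] A minimal usage sketch for get_key_or_arr():
    // a scalar KV value is broadcast into all n slots, while an array KV must have exactly n
    // entries. The names LLAMA_MAX_LAYERS and LLM_KV_ATTENTION_HEAD_COUNT below are the
    // upstream llama.cpp identifiers and are assumptions here, not code from this PR:
    //
    //     std::array<uint32_t, LLAMA_MAX_LAYERS> n_head_arr = {};
    //     ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, n_head_arr, hparams.n_layer);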
-
-    std::string get_arch_name() const {
-        return arch_name;
-    }
-
-    enum llm_arch get_arch() const {
-        return llm_kv.arch;
-    }
-
-    const char * get_tensor_name(int i) const {
-        return weights.at(i).tensor->name;
-    }
-
-    const llama_tensor_weight * get_weight(const char * name) const {
-        for (const auto & weight : weights) {
-            if (strcmp(name, weight.tensor->name) == 0) {
-                return &weight;
-            }
-        }
-        return nullptr;
-    }
-
-    const llama_tensor_weight * get_weight(int i) const {
-        return get_weight(get_tensor_name(i));
-    }
-
-    const llama_tensor_weight & require_weight(const char * name) const {
-        const llama_tensor_weight * weight = get_weight(name);
-        if (!weight) {
-            throw std::runtime_error(format("%s: tensor '%s' not found", __func__, name));
-        }
-        return *weight;
-    }
-
-    struct ggml_tensor * get_tensor_meta(const char * name) const {
-        const auto * weight = get_weight(name);
-        if (!weight) {
-            return nullptr;
-        }
-        return weight->tensor;
-    }
-
-    struct ggml_tensor * require_tensor_meta(const char * name) const {
-        struct ggml_tensor * tensor = get_tensor_meta(name);
-        if (!tensor) {
-            throw std::runtime_error(format("%s: tensor '%s' not found", __func__, name));
-        }
-        return tensor;
-    }
-
-    struct ggml_tensor * get_tensor_meta(int i) const {
-        return get_tensor_meta(get_tensor_name(i));
-    }
-
-    struct ggml_tensor * create_tensor_for(struct ggml_context * ctx, const struct ggml_tensor * cur, bool duplicated) {
-        struct ggml_tensor * tensor = ggml_dup_tensor(ctx, cur);
-        ggml_set_name(tensor, ggml_get_name(cur));
-
-        if (duplicated) {
-            size_data += ggml_nbytes(cur);
-        } else {
-            n_created++;
-        }
-
-        return tensor;
-    }
-
-    const struct ggml_tensor * check_tensor_dims(const std::string & name, const std::vector<int64_t> & ne, bool required) const {
-        const struct ggml_tensor * cur = get_tensor_meta(name.c_str());
-
-        if (cur == NULL) {
-            if (!required) {
-                return NULL;
-            }
-            throw std::runtime_error(format("%s: tensor '%s' not found", __func__, name.c_str()));
-        }
-
-        {
-            bool is_ok = true;
-            for (size_t i = 0; i < GGML_MAX_DIMS; ++i) {
-                if ((i < ne.size() && ne[i] != cur->ne[i]) || (i >= ne.size() && cur->ne[i] != 1)) {
-                    is_ok = false;
-                    break;
-                }
-            }
-            if (!is_ok) {
-                throw std::runtime_error(
-                        format("%s: tensor '%s' has wrong shape; expected %s, got %s",
-                            __func__, name.c_str(),
-                            llama_format_tensor_shape(ne).c_str(),
-                            llama_format_tensor_shape(cur).c_str()));
-            }
-        }
-
-        return cur;
-    }
-
-    static const int TENSOR_NOT_REQUIRED = 1 << 0;
-    static const int TENSOR_DUPLICATED   = 1 << 1;
-    static const int TENSOR_SKIP         = 1 << 2;
-
-    struct ggml_tensor * create_tensor(struct ggml_context * ctx, const std::string & name, const std::vector<int64_t> & ne, int flags = 0) {
-        const struct ggml_tensor * cur = check_tensor_dims(name, ne, !(flags & TENSOR_NOT_REQUIRED));
-
-        if (cur == NULL) {
-            return NULL;
-        }
-
-        // skip unused tensors
-        if (flags & TENSOR_SKIP) {
-            const size_t nbytes = ggml_nbytes(cur);
-            LLAMA_LOG_WARN("model has unused tensor %s (size = %zu bytes) -- ignoring\n", name.c_str(), nbytes);
-
-            size_data -= nbytes;
-            n_created++;
-
-            return nullptr;
-        }
-
-        return create_tensor_for(ctx, cur, flags & TENSOR_DUPLICATED);
-    }
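    // [editorial aside - not part of the patch] The flag bits above are OR-ed together by
    // callers; the usual "tied output head" pattern (also visible near the end of this diff)
    // roughly asks for an optional LLM_TENSOR_OUTPUT and falls back to duplicating the token
    // embedding when it is absent:
    //
    //     model.output = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab},
    //                                  llama_model_loader::TENSOR_NOT_REQUIRED);
    //     if (model.output == NULL) {
    //         model.output = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab},
    //                                      llama_model_loader::TENSOR_DUPLICATED);
    //     }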
-
-    struct ggml_tensor * create_tensor_as_view(struct ggml_context * ctx, struct ggml_tensor * base, const std::string & name, const std::vector<int64_t> & ne, size_t offset, bool required = true) {
-        const struct ggml_tensor * cur = check_tensor_dims(name, ne, required);
-
-        if (cur == NULL) {
-            return NULL;
-        }
-
-        if (cur->type != base->type) {
-            throw std::runtime_error(format("%s: tensor '%s' has wrong type; expected %s, got %s", __func__, name.c_str(), ggml_type_name(base->type), ggml_type_name(cur->type)));
-        }
-
-        std::array<int64_t, GGML_MAX_DIMS> dims;
-        for (size_t i = 0; i < GGML_MAX_DIMS; ++i) {
-            dims[i] = i < ne.size() ? ne[i] : 1;
-        }
-
-        struct ggml_tensor * tensor = ggml_view_4d(ctx, base,
-                                        dims[0], dims[1], dims[2], dims[3],
-                                        cur->nb[1], cur->nb[2], cur->nb[3],
-                                        offset);
-
-        ggml_set_name(tensor, name.c_str());
-
-        n_created++;
-
-        return tensor;
-    }
-
-    void done_getting_tensors() const {
-        if (n_created != n_tensors) {
-            throw std::runtime_error(format("%s: wrong number of tensors; expected %d, got %d", __func__, n_tensors, n_created));
-        }
-    }
-
-    void init_mappings(bool prefetch = true, llama_mlocks * mlock_mmaps = nullptr, bool use_thp = false) {
-        if (use_mmap) {
-            mappings.reserve(files.size());
-            mmaps_used.reserve(files.size());
-            for (const auto & file : files) {
-                std::unique_ptr<llama_mmap> mapping(new llama_mmap(file.get(), prefetch ? -1 : 0, ggml_is_numa(), use_thp));
-                mmaps_used.emplace_back(mapping->size, 0);
-                if (mlock_mmaps) {
-                    std::unique_ptr<llama_mlock> mlock_mmap(new llama_mlock());
-                    mlock_mmap->init(mapping->addr);
-                    mlock_mmaps->emplace_back(std::move(mlock_mmap));
-                }
-                mappings.emplace_back(std::move(mapping));
-            }
-        }
-
-        // compute the total size of all tensors for progress reporting
-        for (auto & w : weights) {
-            size_data += ggml_nbytes(w.tensor);
-        }
-    }
-
-    void get_mapping_range(size_t * first, size_t * last, void ** addr, int idx, ggml_context * ctx) const {
-        GGML_ASSERT(!mappings.empty());
-        const auto & mapping = mappings.at(idx);
-
-        *first = mapping->size;
-        *last  = 0;
-        *addr = mapping->addr;
-        for (ggml_tensor * tensor = ggml_get_first_tensor(ctx); tensor; tensor = ggml_get_next_tensor(ctx, tensor)) {
-            try {
-                const auto * weight = get_weight(ggml_get_name(tensor));
-                if (!weight) {
-                    continue;
-                }
-                if (weight->idx != idx) {
-                    continue;
-                }
-                *first = std::min(*first, weight->offs);
-                *last  = std::max(*last,  weight->offs + ggml_nbytes(tensor));
-            } catch(...) {
-                // the tensor is not in the model
-            }
-        }
-    }
-
-    // for backwards compatibility, does not support ggml-backend
-    void load_data_for(struct ggml_tensor * cur) const {
-        const auto & w = require_weight(ggml_get_name(cur));
-
-        if (use_mmap) {
-            const auto & mapping = mappings.at(w.idx);
-            if (cur->data == nullptr) {
-                cur->data = (uint8_t *)mapping->addr + w.offs;
-            } else {
-                memcpy(cur->data, (uint8_t *)mapping->addr + w.offs, ggml_nbytes(cur));
-            }
-        } else {
-            GGML_ASSERT(cur->data != nullptr);
-            GGML_ASSERT(w.idx < files.size());
-            const auto & file = files.at(w.idx);
-            file->seek(w.offs, SEEK_SET);
-            file->read_raw(cur->data, ggml_nbytes(cur));
-        }
-
-        if (check_tensors && !ggml_validate_row_data(cur->type, cur->data, ggml_nbytes(cur))) {
-            throw std::runtime_error(format("tensor '%s' has invalid data", ggml_get_name(cur)));
-        }
-    }
-
-    size_t size_done = 0;
-    size_t size_data = 0;
-    std::vector<std::pair<size_t, size_t>> mmaps_used;
-
-    // Returns false if cancelled by progress_callback
-    bool load_all_data(
-            struct ggml_context * ctx,
-            llama_buf_map & bufs_mmap,
-            llama_mlocks * lmlocks,
-            llama_progress_callback progress_callback,
-            void * progress_callback_user_data) {
-        GGML_ASSERT(size_data != 0 && "call init_mappings() first");
-
-        std::vector<no_init<uint8_t>> read_buf;
-        std::vector<std::future<std::pair<ggml_tensor *, bool>>> validation_result;
-
-#if defined(GGML_USE_CUDA)
-        // 4 staging buffers for async uploads, each sized 1MB seems to be a good default for single NVMe drives.
-        // NVMe raid configurations might require more / larger buffers.
-        constexpr size_t n_buffers = 4;
-        constexpr size_t buffer_size = 1 * 1024 * 1024; // 1MB
-
-        std::vector<ggml_backend_buffer_t> host_buffers;
-        std::vector<void *> host_ptrs;
-        std::vector<ggml_backend_event_t> events;
-        size_t buffer_idx = 0; // buffer to use for async loads
-
-        ggml_backend_t cuda_backend = nullptr;
-        if (!use_mmap && !check_tensors) {
-            // When not using mmaped io use async uploads from pinned memory to GPU memory.
-            // First determine if the CUDA backend is active, and if so, determine the device ID.
-            ggml_backend_buffer_t buf = bufs_mmap.count(0) ? bufs_mmap.at(0) : nullptr;
-            if (buf) {
-                ggml_backend_buffer_type_t buffer_type = ggml_backend_buffer_get_type(buf);
-                for (int i = 0; i < ggml_backend_cuda_get_device_count(); ++i) {
-                    auto * cuda_buffer_type = ggml_backend_cuda_buffer_type(i);
-                    if (buffer_type == cuda_buffer_type) {
-                        cuda_backend = ggml_backend_cuda_init(i);
-                        break;
-                    }
-                }
-            }
-
-            // If the cuda backend is active create pinned memory buffers and events for synchronisation.
-            if (cuda_backend) {
-                for (size_t idx = 0; idx < n_buffers; ++idx) {
-                    host_buffers.emplace_back(ggml_backend_buft_alloc_buffer(llama_default_buffer_type_cpu(true), buffer_size));
-                    host_ptrs.emplace_back(ggml_backend_buffer_get_base(host_buffers[idx]));
-                    events.emplace_back(ggml_backend_event_new(cuda_backend));
-                }
-            }
-        }
-#endif
-
-        for (struct ggml_tensor * cur = ggml_get_first_tensor(ctx); cur != NULL; cur = ggml_get_next_tensor(ctx, cur)) {
-            const auto * weight = get_weight(ggml_get_name(cur));
-            if (weight == nullptr) {
-                // this can happen with split experts models
-                continue;
-            }
-
-            if (progress_callback) {
-                if (!progress_callback((float) size_done / size_data, progress_callback_user_data)) {
-                    return false;
-                }
-            }
-
-            size_t n_size = ggml_nbytes(cur);
-
-            if (use_mmap) {
-                const auto & mapping = mappings.at(weight->idx);
-                ggml_backend_buffer_t buf_mmap = nullptr;
-                if (bufs_mmap.count(weight->idx)) {
-                    buf_mmap = bufs_mmap.at(weight->idx);
-                }
-                uint8_t * data = (uint8_t *) mapping->addr + weight->offs;
-
-                if (check_tensors) {
-                    validation_result.emplace_back(std::async(std::launch::async, [cur, data, n_size] {
-                        return std::make_pair(cur, ggml_validate_row_data(cur->type, data, n_size));
-                    }));
-                }
-
-                GGML_ASSERT(buf_mmap || cur->data); // either we have a buffer to allocate the tensor in, or it is already allocated
-                if (buf_mmap && cur->data == nullptr) {
-                    ggml_backend_tensor_alloc(buf_mmap, cur, data);
-                    if (lmlocks) {
-                        const auto & lmlock = lmlocks->at(weight->idx);
-                        lmlock->grow_to(weight->offs + n_size);
-                    }
-
-                    auto & mmap_used = mmaps_used[weight->idx];
-                    mmap_used.first  = std::min(mmap_used.first,  weight->offs);
-                    mmap_used.second = std::max(mmap_used.second, weight->offs + n_size);
-                } else {
-                    ggml_backend_tensor_set(cur, data, 0, n_size);
-                }
-            } else {
-                GGML_ASSERT(weight->idx < files.size());
-                const auto & file = files.at(weight->idx);
-                if (ggml_backend_buffer_is_host(cur->buffer)) {
-                    file->seek(weight->offs, SEEK_SET);
-                    file->read_raw(cur->data, n_size);
-                    if (check_tensors) {
-                        validation_result.emplace_back(std::async(std::launch::async, [cur, n_size] {
-                            return std::make_pair(cur, ggml_validate_row_data(cur->type, cur->data, n_size));
-                        }));
-                    }
-                } else {
-#if defined(GGML_USE_CUDA)
-                    // If cuda_backend is valid load the tensor in chunks to pinned memory and upload the buffers asynchronously to the GPU.
-                    if (cuda_backend) {
-                        file->seek(weight->offs, SEEK_SET);
-
-                        size_t bytes_read = 0;
-
-                        while (bytes_read < n_size) {
-                            size_t read_iteration = std::min(buffer_size, n_size - bytes_read);
-
-                            ggml_backend_event_synchronize(events[buffer_idx]);
-                            file->read_raw(host_ptrs[buffer_idx], read_iteration);
-                            ggml_backend_tensor_set_async(cuda_backend, cur, host_ptrs[buffer_idx], bytes_read, read_iteration);
-                            ggml_backend_event_record(events[buffer_idx]);
-
-                            bytes_read += read_iteration;
-                            ++buffer_idx;
-                            buffer_idx %= n_buffers;
-                        }
-                    }
-                    else
-#endif
-                    {
-                        read_buf.resize(n_size);
-                        file->seek(weight->offs, SEEK_SET);
-                        file->read_raw(read_buf.data(), n_size);
-                        ggml_backend_tensor_set(cur, read_buf.data(), 0, n_size);
-                        if (check_tensors && !ggml_validate_row_data(cur->type, read_buf.data(), n_size)) {
-                            throw std::runtime_error(format("tensor '%s' has invalid data", ggml_get_name(cur)));
-                        }
-                    }
-                }
-            }
-
-            size_done += n_size;
-        }
-
-#if defined(GGML_USE_CUDA)
-        // free temporary resources used for async cuda uploads
-        if (cuda_backend) {
-            for (size_t idx = 0; idx < n_buffers;++idx) {
-                ggml_backend_event_synchronize(events[idx]);
-                ggml_backend_event_free(events[idx]);
-                ggml_backend_buffer_free(host_buffers[idx]);
-            }
-            ggml_backend_free(cuda_backend);
-        }
-#endif
-
-        // check validation results
-        bool validation_failed = false;
-        for (auto & future : validation_result) {
-            auto result = future.get();
-            if (!result.second) {
-                LLAMA_LOG_ERROR("%s: tensor '%s' has invalid data\n", __func__, ggml_get_name(result.first));
-                validation_failed = true;
-            }
-        }
-        if (validation_failed) {
-            throw std::runtime_error("found tensors with invalid data");
-        }
-
-        // check if this is the last call and do final cleanup
-        if (size_done >= size_data) {
-            // unmap offloaded tensors and metadata
-            if (use_mmap) {
-                for (uint32_t idx = 0; idx < mappings.size(); idx++) {
-                    const auto & mmap_used = mmaps_used.at(idx);
-                    auto & mapping = mappings.at(idx);
-                    mapping->unmap_fragment(0, mmap_used.first);
-                    if (mmap_used.second != 0) {
-                        mapping->unmap_fragment(mmap_used.second, mapping->size);
-                    }
-                }
-            }
-            if (progress_callback) {
-                // Even though the model is done loading, we still honor
-                // cancellation since we need to free allocations.
-                return progress_callback(1.0f, progress_callback_user_data);
-            }
-        }
-
-        return true;
-    }
-};
+// TODO: update when needed or think of some clever automatic way to do this
+static size_t llama_model_max_nodes(const llama_model & /*model*/) {
+    //if (model.arch == LLM_ARCH_LLAMA && model.hparams.n_layer > ??) { // llama-3 405B
+    //    return 32768;
+    //}
 
-template<>
-bool llama_model_loader::get_key(const enum llm_kv kid, enum llama_pooling_type & result, const bool required) {
-    uint32_t tmp;
-    const bool found = get_key(kid, tmp, required);
-    if (found) {
-        result = (enum llama_pooling_type) tmp;
-    } else {
-        result = LLAMA_POOLING_TYPE_UNSPECIFIED;
-    }
-    return found;
+    return 65536;
 }
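// [editorial aside - not part of the patch] One "automatic" variant the TODO above hints at
// would scale the node budget with the layer count instead of returning a constant; the
// floor and per-layer factor here are illustrative assumptions only:
//
//     static size_t llama_model_max_nodes(const llama_model & model) {
//         return std::max<size_t>(8192, 512*model.hparams.n_layer);
//     }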
 
-
 //
 // load LLaMA models
 //
@@ -6283,728 +4258,54 @@ static void llm_load_hparams(
                     default: model.type = e_model::MODEL_UNKNOWN;
                 }
             } break;
-	case LLM_ARCH_DOTS1:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
-                ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT,   hparams.n_layer_dense_lead);
-                ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH,  hparams.n_ff_exp);
-                ml.get_key(LLM_KV_EXPERT_SHARED_COUNT,         hparams.n_expert_shared);
-                ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE,        hparams.expert_weights_scale);
-                ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM,         hparams.expert_weights_norm, false);
-                ml.get_key(LLM_KV_EXPERT_GATING_FUNC,          hparams.expert_gating_func, false);
-                switch (hparams.n_layer) {
-                    case 62: model.type = e_model::MODEL_142B; break;
-                    default: model.type = e_model::MODEL_UNKNOWN;
-                }
-            } break;
-        case LLM_ARCH_HUNYUAN_MOE:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS,       hparams.f_norm_rms_eps);
-                ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH,        hparams.n_ff_exp);
-                ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp);
-
-                switch (hparams.n_layer) {
-                    case 32: model.type = e_model::MODEL_80B_A13B; break;
-                    default: model.type = e_model::MODEL_UNKNOWN;
-                }
-            } break;
-        default: (void)0;
-    }
-
-    model.ftype = ml.ftype;
-
-    if (hparams.f_max_alibi_bias > 0.0f) {
-        hparams.use_alibi = true;
-    }
-
-    hparams.rope_type = llama_rope_type(&model);
-}
-
-static void llm_load_vocab(
-        llama_model_loader & ml,
-        llama_model & model) {
-    auto & vocab = model.vocab;
-
-    struct gguf_context * ctx = ml.meta;
-
-    const auto kv = LLM_KV(model.arch);
-
-    // determine vocab type
-    {
-        std::string tokenizer_model;
-        std::string tokenizer_pre;
-
-        ml.get_key(LLM_KV_TOKENIZER_MODEL, tokenizer_model);
-        ml.get_key(LLM_KV_TOKENIZER_PRE,   tokenizer_pre, false);
-
-        if (tokenizer_model == "no_vocab") {
-            vocab.type = LLAMA_VOCAB_TYPE_NONE;
-
-            // default special tokens
-            vocab.special_bos_id  = -1;
-            vocab.special_eos_id  = -1;
-            vocab.special_unk_id  = -1;
-            vocab.special_sep_id  = -1;
-            vocab.special_pad_id  = -1;
-            vocab.special_cls_id  = -1;
-            vocab.special_mask_id = -1;
-            vocab.linefeed_id     = -1;
-
-            return;
-        } else if (tokenizer_model == "llama") {
-            vocab.type = LLAMA_VOCAB_TYPE_SPM;
-
-            // default special tokens
-            vocab.special_bos_id  = 1;
-            vocab.special_eos_id  = 2;
-            vocab.special_unk_id  = 0;
-            vocab.special_sep_id  = -1;
-            vocab.special_pad_id  = -1;
-            vocab.special_cls_id  = -1;
-            vocab.special_mask_id = -1;
-        } else if (tokenizer_model == "bert") {
-            vocab.type = LLAMA_VOCAB_TYPE_WPM;
-
-            // default special tokens
-            vocab.special_bos_id  = -1;
-            vocab.special_eos_id  = -1;
-            vocab.special_unk_id  = 100;
-            vocab.special_sep_id  = 102;
-            vocab.special_pad_id  = 0;
-            vocab.special_cls_id  = 101;
-            vocab.special_mask_id = 103;
-        } else if (tokenizer_model == "gpt2") {
-            vocab.type = LLAMA_VOCAB_TYPE_BPE;
-
-            // read bpe merges and populate bpe ranks
-            const int merges_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_MERGES).c_str());
-            if (merges_keyidx == -1) {
-                throw std::runtime_error("cannot find tokenizer merges in model file\n");
-            }
-
-            const int n_merges = gguf_get_arr_n(ctx, merges_keyidx);
-            for (int i = 0; i < n_merges; i++) {
-                const std::string word = gguf_get_arr_str(ctx, merges_keyidx, i);
-                GGML_ASSERT(unicode_cpts_from_utf8(word).size() > 0);
-
-                std::string first;
-                std::string second;
-
-                const size_t pos = word.find(' ', 1);
-
-                if (pos != std::string::npos) {
-                    first  = word.substr(0, pos);
-                    second = word.substr(pos + 1);
-                }
-
-                vocab.bpe_ranks.emplace(std::make_pair(first, second), i);
-            }
-
-            // default special tokens
-            if(model.arch == LLM_ARCH_DOTS1) {
-                vocab.special_bos_id = -1;
-            }
-            else {
-                vocab.special_bos_id  = 11;
-            }
-            vocab.special_eos_id  = 11;
-            vocab.special_unk_id  = -1;
-            vocab.special_sep_id  = -1;
-            vocab.special_pad_id  = -1;
-            vocab.special_cls_id  = -1;
-            vocab.special_mask_id = -1;
-        } else if (tokenizer_model == "t5") {
-            vocab.type = LLAMA_VOCAB_TYPE_UGM;
-
-            // default special tokens
-            vocab.special_bos_id  = -1;
-            vocab.special_eos_id  = 1;
-            vocab.special_unk_id  = 2;
-            vocab.special_sep_id  = -1;
-            vocab.special_pad_id  = 0;
-            vocab.special_cls_id  = -1;
-            vocab.special_mask_id = -1;
-
-            const int precompiled_charsmap_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP).c_str());
-            if (precompiled_charsmap_keyidx != -1) {
-                size_t n_precompiled_charsmap = gguf_get_arr_n(ctx, precompiled_charsmap_keyidx);
-                const char * precompiled_charsmap = (const char *) gguf_get_arr_data(ctx, precompiled_charsmap_keyidx);
-                vocab.precompiled_charsmap.assign(precompiled_charsmap, precompiled_charsmap + n_precompiled_charsmap);
-#ifdef IS_BIG_ENDIAN
-                // correct endiannes of data in precompiled_charsmap binary blob
-                uint32_t * xcda_blob_size = (uint32_t *) &vocab.precompiled_charsmap[0];
-                *xcda_blob_size = __builtin_bswap32(*xcda_blob_size);
-                assert(*xcda_blob_size + sizeof(uint32_t) < n_precompiled_charsmap);
-                size_t xcda_array_size = *xcda_blob_size / sizeof(uint32_t);
-                uint32_t * xcda_array = (uint32_t *) &vocab.precompiled_charsmap[sizeof(uint32_t)];
-                for (size_t i = 0; i < xcda_array_size; ++i) {
-                    xcda_array[i] = __builtin_bswap32(xcda_array[i]);
-                }
-#endif
-            }
-        } else {
-            throw std::runtime_error(format("unknown tokenizer: '%s'", tokenizer_model.c_str()));
-        }
-
-        // for now, only BPE models have pre-tokenizers
-        if (vocab.type == LLAMA_VOCAB_TYPE_BPE) {
-            vocab.tokenizer_add_space_prefix = false;
-            vocab.tokenizer_clean_spaces = true;
-            if (tokenizer_pre.empty()) {
-                //!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
-                // OK - I don't feel like recreating the LLaMA-v3 models. Considering that, at least for now,
-                // LLaMA-v3 is the only model where we end up here, let's just force the pre-tokenizer to be
-                // llama3.
-                tokenizer_pre = "llama3";
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_LLAMA3;
-                LLAMA_LOG_WARN("%s: missing pre-tokenizer type, using: 'llama3'\n", __func__);
-                LLAMA_LOG_WARN("%s:                                             \n", __func__);
-                LLAMA_LOG_WARN("%s: ************************************        \n", __func__);
-                LLAMA_LOG_WARN("%s: GENERATION QUALITY MAY BE DEGRADED!         \n", __func__);
-                LLAMA_LOG_WARN("%s: CONSIDER REGENERATING THE MODEL             \n", __func__);
-                LLAMA_LOG_WARN("%s: ************************************        \n", __func__);
-                LLAMA_LOG_WARN("%s:                                             \n", __func__);
-                //vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
-                //!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
-            } else if (tokenizer_pre == "default") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
-            } else if (
-                    tokenizer_pre == "llama3"   ||
-                    tokenizer_pre == "llama-v3" ||
-                    tokenizer_pre == "llama-bpe") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_LLAMA3;
-                vocab.tokenizer_ignore_merges = true;
-                vocab.tokenizer_add_bos = true;
-            } else if (
-                    tokenizer_pre == "deepseek-llm") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM;
-                vocab.tokenizer_clean_spaces = false;
-            } else if (
-                    tokenizer_pre == "deepseek-coder") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER;
-                vocab.tokenizer_clean_spaces = false;
-            } else if (
-                    tokenizer_pre == "deepseek-v3") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM;
-                vocab.tokenizer_clean_spaces = false;
-            } else if (
-                    tokenizer_pre == "falcon") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_FALCON;
-            } else if (tokenizer_pre == "falcon3") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_FALCON_3;
-            } else if (tokenizer_pre == "falcon_e") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_FALCON_E;
-            } else if (
-                    tokenizer_pre == "mpt") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_MPT;
-            } else if (
-                    tokenizer_pre == "starcoder") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_STARCODER;
-            } else if (
-                    tokenizer_pre == "gpt-2"   ||
-                    tokenizer_pre == "phi-2"   ||
-                    tokenizer_pre == "jina-es" ||
-                    tokenizer_pre == "jina-de" ||
-                    tokenizer_pre == "jina-v2-es" ||
-                    tokenizer_pre == "jina-v2-de" ||
-                    tokenizer_pre == "jina-v2-code") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_GPT2;
-            } else if (
-                    tokenizer_pre == "refact") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_REFACT;
-            } else if (
-                tokenizer_pre == "command-r") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_COMMAND_R;
-                vocab.tokenizer_clean_spaces = false;
-            } else if (
-                tokenizer_pre == "qwen2" || tokenizer_pre == "deepseek-r1-qwen") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_QWEN2;
-                vocab.tokenizer_clean_spaces = false;
-            } else if (
-                tokenizer_pre == "stablelm2") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_STABLELM2;
-            } else if (
-                tokenizer_pre == "olmo") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_OLMO;
-            } else if (
-                tokenizer_pre == "dbrx") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DBRX;
-            } else if (
-                tokenizer_pre == "smaug-bpe") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_SMAUG;
-            } else if (
-                tokenizer_pre == "poro-chat") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_PORO;
-                vocab.tokenizer_clean_spaces = false;
-            } else if (
-                tokenizer_pre == "glm4" ||
-                tokenizer_pre == "chatglm-bpe") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_CHATGLM4;
-                vocab.special_bos_id  = -1;
-            } else if (
-                tokenizer_pre == "viking") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_VIKING;
-                vocab.tokenizer_clean_spaces = false;
-            } else if (
-                tokenizer_pre == "jais") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_JAIS;
-            } else if (
-                tokenizer_pre == "tekken") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_TEKKEN;
-                vocab.tokenizer_clean_spaces = false;
-                vocab.tokenizer_ignore_merges = true;
-                vocab.tokenizer_add_bos = true;
-            } else if (
-                tokenizer_pre == "smollm") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_SMOLLM;
-                vocab.tokenizer_clean_spaces = false;
-            } else if (
-                tokenizer_pre == "codeshell") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_CODESHELL;
-            } else if (
-                tokenizer_pre == "gpt-4o" ||
-                tokenizer_pre == "llama4") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_GPT4O;
-                vocab.tokenizer_clean_spaces = false;
-            } else if (
-                tokenizer_pre == "superbpe") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_SUPERBPE;
-                vocab.tokenizer_clean_spaces = false;
-            } else if (
-                tokenizer_pre == "trillion") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_TRILLION;
-                vocab.tokenizer_clean_spaces = false;
-            } else if (
-                tokenizer_pre == "bailingmoe") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_BAILINGMOE;
-                vocab.tokenizer_clean_spaces = false;
-            } else if (
-                tokenizer_pre == "seed-coder") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_SEED_CODER;
-                vocab.tokenizer_clean_spaces = false;
-            } else if (
-                tokenizer_pre == "hunyuan") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_HUNYUAN;
-                vocab.tokenizer_clean_spaces = false;
-            } else if (
-                tokenizer_pre == "kimi-k2") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_KIMI_K2;
-                vocab.tokenizer_clean_spaces = false;
-            } else {
-                throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
-            }
-        } else if (vocab.type == LLAMA_VOCAB_TYPE_SPM) {
-            vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
-            vocab.tokenizer_add_space_prefix = true;
-            vocab.tokenizer_clean_spaces = false;
-            vocab.tokenizer_add_bos = true;
-            vocab.tokenizer_add_eos = false;
-        } else if (vocab.type == LLAMA_VOCAB_TYPE_WPM) {
-            vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
-            vocab.tokenizer_add_space_prefix = false;
-            vocab.tokenizer_clean_spaces = true;
-            vocab.tokenizer_add_bos = true;
-            vocab.tokenizer_add_eos = false;
-        } else if (vocab.type == LLAMA_VOCAB_TYPE_UGM) {
-            vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
-            vocab.tokenizer_add_bos = false;
-            vocab.tokenizer_add_eos = true;
-        } else {
-            vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
-        }
-
-        ml.get_key(LLM_KV_TOKENIZER_ADD_PREFIX,      vocab.tokenizer_add_space_prefix,         false);
-        ml.get_key(LLM_KV_TOKENIZER_REMOVE_EXTRA_WS, vocab.tokenizer_remove_extra_whitespaces, false);
-    }
-
-    const int token_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_LIST).c_str());
-    if (token_idx == -1) {
-        throw std::runtime_error("cannot find tokenizer vocab in model file\n");
-    }
-
-    const float * scores = nullptr;
-    const int score_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_SCORES).c_str());
-    if (score_idx != -1) {
-        scores = (const float * ) gguf_get_arr_data(ctx, score_idx);
-    }
-
-    const int * toktypes = nullptr;
-    const int toktype_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_TOKEN_TYPE).c_str());
-    if (toktype_idx != -1) {
-        toktypes = (const int * ) gguf_get_arr_data(ctx, toktype_idx);
-    }
-
-    const uint32_t n_vocab = gguf_get_arr_n(ctx, token_idx);
-
-    vocab.id_to_token.resize(n_vocab);
-
-    for (uint32_t i = 0; i < n_vocab; i++) {
-        std::string word = gguf_get_arr_str(ctx, token_idx, i);
-        GGML_ASSERT(unicode_cpts_from_utf8(word).size() > 0);
-
-        vocab.token_to_id[word] = i;
-        vocab.max_token_len = std::max(vocab.max_token_len, (int) word.size());
-
-        auto & token_data = vocab.id_to_token[i];
-        token_data.text  = std::move(word);
-        token_data.score = scores ? scores[i] : 0.0f;
-        token_data.attr  = LLAMA_TOKEN_ATTR_NORMAL;
-
-        if (toktypes) {  //TODO: remove, required until per token attributes are available from GGUF file
-            switch(toktypes[i]) {
-                case LLAMA_TOKEN_TYPE_UNKNOWN:      token_data.attr = LLAMA_TOKEN_ATTR_UNKNOWN;      break;
-                case LLAMA_TOKEN_TYPE_UNUSED:       token_data.attr = LLAMA_TOKEN_ATTR_UNUSED;       break;
-                case LLAMA_TOKEN_TYPE_NORMAL:       token_data.attr = LLAMA_TOKEN_ATTR_NORMAL;       break;
-                case LLAMA_TOKEN_TYPE_CONTROL:      token_data.attr = LLAMA_TOKEN_ATTR_CONTROL;      break;
-                case LLAMA_TOKEN_TYPE_USER_DEFINED: token_data.attr = LLAMA_TOKEN_ATTR_USER_DEFINED; break;
-                case LLAMA_TOKEN_TYPE_BYTE:         token_data.attr = LLAMA_TOKEN_ATTR_BYTE;         break;
-                case LLAMA_TOKEN_TYPE_UNDEFINED:    token_data.attr = LLAMA_TOKEN_ATTR_UNDEFINED;    break;
-                default:                            token_data.attr = LLAMA_TOKEN_ATTR_UNDEFINED;    break;
-            }
-        }
-    }
-    GGML_ASSERT(vocab.id_to_token.size() == vocab.token_to_id.size());
-
-    // determine the newline token: LLaMA "<0x0A>" == 10 == '\n', Falcon 193 == '\n'
-    if (vocab.type == LLAMA_VOCAB_TYPE_SPM) {
-        // For Fill-In-the-Middle (FIM)/infill models which were converted
-        // prior to support of FIM special tokens in GGUF, the following
-        // will allow those models to continue to work. The general names
-        // of the known models are currently CodeLlama (LLM_ARCH_LLAMA) and
-        // CodeGemma (LLM_ARCH_GEMMA). This can potentially be removed once
-        // new versions of these models have been published.
-        std::string gen_name;
-        ml.get_key(LLM_KV_GENERAL_NAME, gen_name, false);
-
-        std::transform(gen_name.begin(), gen_name.end(), gen_name.begin(),
-            [](unsigned char c){ return std::tolower(c); });
-
-        if (gen_name.find("code") != std::string::npos) {
-            if (model.arch == LLM_ARCH_LLAMA
-              && 32010 < vocab.id_to_token.size()
-              && vocab.id_to_token[32007].text.find("<PRE>") != std::string::npos
-              && vocab.id_to_token[32008].text.find("<SUF>") != std::string::npos
-              && vocab.id_to_token[32009].text.find("<MID>") != std::string::npos
-              && vocab.id_to_token[32010].text.find("<EOT>") != std::string::npos) {
-                vocab.special_prefix_id = 32007;
-                vocab.special_suffix_id = 32008;
-                vocab.special_middle_id = 32009;
-                vocab.special_eot_id    = 32010;
-            } else if (model.arch == LLM_ARCH_GEMMA
-              && 107 < vocab.id_to_token.size()
-              && vocab.id_to_token[67].text == "<|fim_prefix|>"
-              && vocab.id_to_token[69].text == "<|fim_suffix|>"
-              && vocab.id_to_token[68].text == "<|fim_middle|>"
-              && vocab.id_to_token[107].text == "<end_of_turn>") {
-                vocab.special_prefix_id = 67;
-                vocab.special_suffix_id = 69;
-                vocab.special_middle_id = 68;
-                // TODO: this is not EOT, it is "file separator" token, needs fix
-                //       https://huggingface.co/google/codegemma-7b-it/blob/9b1d9231388358c04d90bd003458f5070d97db44/tokenizer_config.json#L565-L572
-                //vocab.special_eot_id    = 70;
-                vocab.special_eot_id    = 107;
-            }
-        }
-        try {
-            vocab.linefeed_id = llama_byte_to_token_impl(vocab, '\n');
-        } catch (const std::exception & e) {
-            LLAMA_LOG_WARN("%s: SPM vocabulary, but newline token not found: %s! Using special_pad_id instead.", __func__, e.what());
-            vocab.linefeed_id = vocab.special_pad_id;
-        }
-    } else if (vocab.type == LLAMA_VOCAB_TYPE_WPM) {
-        vocab.linefeed_id = vocab.special_pad_id;
-    } else {
-        const std::vector<int> ids = llama_tokenize_internal(vocab, "\xC4\x8A", false); // U+010A
-        GGML_ASSERT(!ids.empty() && "model vocab missing newline token");
-        vocab.linefeed_id = ids[0];
-    }
-
-    // special tokens
-    {
-        const std::vector<std::pair<enum llm_kv, int32_t &>> special_token_types = {
-            { LLM_KV_TOKENIZER_BOS_ID,    vocab.special_bos_id    },
-            { LLM_KV_TOKENIZER_EOS_ID,    vocab.special_eos_id    },
-            { LLM_KV_TOKENIZER_EOT_ID,    vocab.special_eot_id    },
-            { LLM_KV_TOKENIZER_EOM_ID,    vocab.special_eom_id    },
-            { LLM_KV_TOKENIZER_UNK_ID,    vocab.special_unk_id    },
-            { LLM_KV_TOKENIZER_SEP_ID,    vocab.special_sep_id    },
-            { LLM_KV_TOKENIZER_PAD_ID,    vocab.special_pad_id    },
-            { LLM_KV_TOKENIZER_CLS_ID,    vocab.special_cls_id    },
-            { LLM_KV_TOKENIZER_MASK_ID,   vocab.special_mask_id   },
-
-            { LLM_KV_TOKENIZER_FIM_PRE_ID, vocab.special_fim_pre_id },
-            { LLM_KV_TOKENIZER_FIM_SUF_ID, vocab.special_fim_suf_id },
-            { LLM_KV_TOKENIZER_FIM_MID_ID, vocab.special_fim_mid_id },
-            { LLM_KV_TOKENIZER_FIM_PAD_ID, vocab.special_fim_pad_id },
-            { LLM_KV_TOKENIZER_FIM_REP_ID, vocab.special_fim_rep_id },
-            { LLM_KV_TOKENIZER_FIM_SEP_ID, vocab.special_fim_sep_id },
-
-            { LLM_KV_TOKENIZER_PREFIX_ID, vocab.special_prefix_id },
-            { LLM_KV_TOKENIZER_SUFFIX_ID, vocab.special_suffix_id },
-            { LLM_KV_TOKENIZER_MIDDLE_ID, vocab.special_middle_id },
-        };
-
-        for (const auto & it : special_token_types) {
-            const std::string & key = kv(std::get<0>(it));
-            int32_t & id = std::get<1>(it);
-
-            uint32_t new_id;
-            if (!ml.get_key(std::get<0>(it), new_id, false)) {
-                continue;
-            }
-            if (new_id >= vocab.id_to_token.size()) {
-                LLAMA_LOG_WARN("%s: bad special token: '%s' = %ud, using default id %d\n",
-                    __func__, key.c_str(), new_id, id);
-            } else {
-                id = new_id;
-            }
-        }
-
-        // Handle add_bos_token and add_eos_token
-        {
-            bool temp = true;
-
-            if (ml.get_key(LLM_KV_TOKENIZER_ADD_BOS, temp, false)) {
-                vocab.tokenizer_add_bos = temp;
-            }
-            if (ml.get_key(LLM_KV_TOKENIZER_ADD_EOS, temp, false)) {
-                vocab.tokenizer_add_eos = temp;
-            }
-        }
-
-        // find EOT token: "<|eot_id|>", "<|im_end|>", "<end_of_turn>", etc.
-        //
-        // TODO: convert scripts should provide this token through the KV metadata LLAMA_KV_TOKENIZER_EOT_ID
-        //       for now, we apply this workaround to find the EOT token based on its text
-        if (vocab.special_eot_id == -1) {
-            for (const auto & t : vocab.token_to_id) {
-                if (
-                        // TODO: gemma "" is exported as a normal token, so the following check does not work
-                        //       need to fix convert script
-                        //vocab.id_to_token[t.second].type == LLAMA_TOKEN_TYPE_CONTROL &&
-                        (t.first == "<|eot_id|>" ||
-                         t.first == "<|im_end|>" ||
-                         t.first == "<|end|>" ||
-                         t.first == "" ||
-                         t.first == "<|endoftext|>"
-                        )
-                   ) {
-                    vocab.special_eot_id = t.second;
-                    break;
-                }
-            }
-        }
-
-        // find EOM token: "<|eom_id|>"
-        //
-        // TODO: convert scripts should provide this token through the KV metadata LLAMA_KV_TOKENIZER_EOM_ID
-        //       for now, we apply this workaround to find the EOM token based on its text
-        if (vocab.special_eom_id == -1) {
-            const auto & t = vocab.token_to_id.find("<|eom_id|>");
-            if (t != vocab.token_to_id.end()) {
-                vocab.special_eom_id = t->second;
-            }
-        }
-
-        for (const auto & t : vocab.token_to_id) {
-            // find FIM_PRE token: "<|fim_prefix|>", "<fim-prefix>", "<PRE>", etc.
-            if (vocab.special_fim_pre_id == -1) {
-                if (false
-                        || t.first == "<|fim_prefix|>"  // Qwen
-                        || t.first == ""
-                        || t.first == ""    // Granite
-                        || t.first == "<|fim▁begin|>" // DeepSeek
-                        || t.first == "
"
-                        || t.first == "▁
"          // CodeLlama
-                        || t.first == "<|code_prefix|>" // GLM-4.5
-                        ) {
-                    vocab.special_fim_pre_id = t.second;
-                    if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
-                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
-                                __func__, t.second, t.first.c_str());
-                                vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
-                    }
-                }
-            }
-
-            // find FIM_SUF token: "<|fim_suffix|>", "<fim-suffix>", "<SUF>", etc.
-            if (vocab.special_fim_suf_id == -1) {
-                if (false
-                        || t.first == "<|fim_suffix|>" // Qwen
-                        || t.first == ""
-                        || t.first == ""   // Granite
-                        || t.first == "<|fim▁hole|>" // DeepSeek
-                        || t.first == ""
-                        || t.first == "▁"         // CodeLlama
-                        || t.first == "<|code_suffix|>" // GLM-4.5
-                        ) {
-                    vocab.special_fim_suf_id = t.second;
-                    if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
-                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
-                                __func__, t.second, t.first.c_str());
-                        vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
-                    }
-                }
-            }
-
-            // find FIM_MID token: "<|fim_middle|>", "<fim-middle>", "<MID>", etc.
-            if (vocab.special_fim_mid_id == -1) {
-                if (false
-                        || t.first == "<|fim_middle|>" // Qwen
-                        || t.first == ""
-                        || t.first == ""   // Granite
-                        || t.first == "<|fim▁end|>"  // DeepSeek
-                        || t.first == ""
-                        || t.first == "▁"         // CodeLlama
-                        || t.first == "<|code_middle|>" // GLM-4.5
-                        ) {
-                    vocab.special_fim_mid_id = t.second;
-                    if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
-                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
-                                __func__, t.second, t.first.c_str());
-                        vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
-                    }
-                }
-            }
-
-            // find FIM_PAD token: "<|fim_pad|>", "<fim-pad>", "<PAD>", etc.
-            if (vocab.special_fim_pad_id == -1) {
-                if (false
-                        || t.first == "<|fim_pad|>" // Qwen
-                        || t.first == ""
-                        || t.first == ""   // Granite
-                        || t.first == ""
-                        ) {
-                    vocab.special_fim_pad_id = t.second;
-                    if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
-                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
-                                __func__, t.second, t.first.c_str());
-                        vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
-                    }
-                }
-            }
-
-            // find FIM_REP token: "<|fim_repo|>", "<fim-repo>", "<REPO>", etc.
-            if (vocab.special_fim_rep_id == -1) {
-                if (false
-                        || t.first == "<|fim_repo|>"  // Qwen
-                        || t.first == "<|repo_name|>"
-                        || t.first == ""
-                        || t.first == ""
-                        || t.first == ""    // Granite
-                        ) {
-                    vocab.special_fim_rep_id = t.second;
-                    if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
-                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
-                                __func__, t.second, t.first.c_str());
-                        vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
-                    }
-                }
-            }
-
-            // find FIM_SEP token: "<|file_sep|>"
-            if (vocab.special_fim_sep_id == -1) {
-                if (false
-                        || t.first == "<|file_sep|>" // Qwen
-                        ) {
-                    vocab.special_fim_sep_id = t.second;
-                    if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
-                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
-                                __func__, t.second, t.first.c_str());
-                        vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
-                    }
-                }
-            }
-        }
-
-    }
-
-    // build special tokens cache
-    {
-        for (llama_vocab::id id = 0; id < (llama_vocab::id)n_vocab; ++id) {
-            if (vocab.id_to_token[id].attr & (LLAMA_TOKEN_ATTR_CONTROL | LLAMA_TOKEN_ATTR_USER_DEFINED | LLAMA_TOKEN_ATTR_UNKNOWN)) {
-                vocab.cache_special_tokens.push_back(id);
-            }
-        }
-
-        std::sort(vocab.cache_special_tokens.begin(), vocab.cache_special_tokens.end(),
-            [&] (const llama_vocab::id a, const llama_vocab::id b) {
-                return vocab.id_to_token[a].text.size() > vocab.id_to_token[b].text.size();
-            }
-        );
-
-        LLAMA_LOG_INFO("%s: special tokens cache size = %u\n", __func__, (uint32_t)vocab.cache_special_tokens.size());
-    }
-
-    // build token to piece cache
-    {
-        size_t size_cache = 0;
-
-        std::vector<std::string> cache_token_to_piece(n_vocab);
-
-        for (uint32_t id = 0; id < n_vocab; ++id) {
-            cache_token_to_piece[id] = llama_token_to_piece(&model, id, true);
-
-            size_cache += cache_token_to_piece[id].size();
-        }
-
-        std::swap(vocab.cache_token_to_piece, cache_token_to_piece);
-
-        LLAMA_LOG_INFO("%s: token to piece cache size = %.4f MB\n", __func__, size_cache / 1024.0 / 1024.0);
-    }
-
-    // Handle per token attributes
-    //NOTE: Each model customizes per token attributes.
-    //NOTE: Per token attributes are missing from the GGUF file.
-    //TODO: Extract attributes from GGUF file.
-    {
-        auto _contains_any = [] (const std::string &str, const std::vector<std::string> &substrs) -> bool {
-            for (auto substr : substrs) {
-                if (str.find(substr) < std::string::npos) {
-                    return true;
+	    case LLM_ARCH_DOTS1:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT,   hparams.n_layer_dense_lead);
+                ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH,  hparams.n_ff_exp);
+                ml.get_key(LLM_KV_EXPERT_SHARED_COUNT,         hparams.n_expert_shared);
+                ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE,        hparams.expert_weights_scale);
+                ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM,         hparams.expert_weights_norm, false);
+                ml.get_key(LLM_KV_EXPERT_GATING_FUNC,          hparams.expert_gating_func, false);
+                switch (hparams.n_layer) {
+                    case 62: model.type = e_model::MODEL_142B; break;
+                    default: model.type = e_model::MODEL_UNKNOWN;
                 }
-            }
-            return false;
-        };
+            } break;
+        case LLM_ARCH_HUNYUAN_MOE:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS,       hparams.f_norm_rms_eps);
+                ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH,        hparams.n_ff_exp);
+                ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp);
 
-        auto _set_tokenid_attr = [&] (const llama_vocab::id id, llama_token_attr attr, bool value) {
-            uint32_t current = vocab.id_to_token.at(id).attr;
-            current = value ? (current | attr) : (current & ~attr);
-            vocab.id_to_token[id].attr = (llama_token_attr) current;
-        };
+                switch (hparams.n_layer) {
+                    case 32: model.type = e_model::MODEL_80B_A13B; break;
+                    default: model.type = e_model::MODEL_UNKNOWN;
+                }
+            } break;
+        case LLM_ARCH_OPENAI_MOE:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH,  hparams.n_ff_exp);
+                ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW,    hparams.n_swa);
 
-        auto _set_token_attr = [&] (const std::string & token, llama_token_attr attr, bool value) {
-            _set_tokenid_attr(vocab.token_to_id.at(token), attr, value);
-        };
+                //TODO OAI_MOE: SWA
+                //hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
+                //hparams.set_swa_pattern(2);
 
-        std::string model_name;
-        std::string tokenizer_pre;
+                // TODO: switch (hparams.n_layer)
 
-        ml.get_key(LLM_KV_GENERAL_NAME, model_name, false);
-        ml.get_key(LLM_KV_TOKENIZER_PRE, tokenizer_pre, false);
+            } break;
+        default: (void)0;
+    }
 
-        // model name to lowercase
-        std::transform(model_name.begin(), model_name.end(), model_name.begin(),
-            [] (const std::string::value_type x) {
-                return std::tolower(x);
-            }
-        );
+    model.ftype = ml.ftype;
 
-        // set attributes by model/tokenizer name
-        if (_contains_any(tokenizer_pre, {"jina-v2-de", "jina-v2-es", "jina-v2-code"})) {
-            _set_token_attr("", LLAMA_TOKEN_ATTR_LSTRIP, true);
-        } else if (_contains_any(model_name, {"phi-3", "phi3"})) {
-            for (auto id : vocab.cache_special_tokens) {
-                _set_tokenid_attr(id, LLAMA_TOKEN_ATTR_RSTRIP, true);
-            }
-            for (auto token : {""}) {
-                _set_token_attr(token, LLAMA_TOKEN_ATTR_RSTRIP, true);
-            }
-            for (auto token : {"", "", "<|endoftext|>"}) {
-                _set_token_attr(token, LLAMA_TOKEN_ATTR_RSTRIP, false);
-            }
-        }
+    if (hparams.f_max_alibi_bias > 0.0f) {
+        hparams.use_alibi = true;
     }
+
+    hparams.rope_type = llama_rope_type(&model);
 }
 
 static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
@@ -7045,10 +4346,6 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
     // hparams
     LLAMA_LOG_INFO("%s: format           = %s\n",     __func__, llama_file_version_name(ml.fver));
     LLAMA_LOG_INFO("%s: arch             = %s\n",     __func__, LLM_ARCH_NAMES.at(model.arch));
-    LLAMA_LOG_INFO("%s: vocab type       = %s\n",     __func__, llama_model_vocab_type_name(vocab.type));
-    LLAMA_LOG_INFO("%s: n_vocab          = %u\n",     __func__, hparams.n_vocab);
-    LLAMA_LOG_INFO("%s: n_merges         = %u\n",     __func__, (int) vocab.bpe_ranks.size());
-    LLAMA_LOG_INFO("%s: vocab_only       = %d\n",     __func__, hparams.vocab_only);
 
     if (!hparams.vocab_only) {
         LLAMA_LOG_INFO("%s: n_ctx_train      = %u\n",     __func__, hparams.n_ctx_train);
@@ -7128,31 +4425,6 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
     // general kv
     LLAMA_LOG_INFO("%s: general.name     = %s\n",    __func__, model.name.c_str());
 
-    // special tokens
-    if (vocab.special_bos_id    != -1) { LLAMA_LOG_INFO( "%s: BOS token        = %d '%s'\n", __func__, vocab.special_bos_id,  vocab.id_to_token[vocab.special_bos_id].text.c_str() );  }
-    if (vocab.special_eos_id    != -1) { LLAMA_LOG_INFO( "%s: EOS token        = %d '%s'\n", __func__, vocab.special_eos_id,  vocab.id_to_token[vocab.special_eos_id].text.c_str() );  }
-    if (vocab.special_unk_id    != -1) { LLAMA_LOG_INFO( "%s: UNK token        = %d '%s'\n", __func__, vocab.special_unk_id,  vocab.id_to_token[vocab.special_unk_id].text.c_str() );  }
-    if (vocab.special_sep_id    != -1) { LLAMA_LOG_INFO( "%s: SEP token        = %d '%s'\n", __func__, vocab.special_sep_id,  vocab.id_to_token[vocab.special_sep_id].text.c_str() );  }
-    if (vocab.special_pad_id    != -1) { LLAMA_LOG_INFO( "%s: PAD token        = %d '%s'\n", __func__, vocab.special_pad_id,  vocab.id_to_token[vocab.special_pad_id].text.c_str() );  }
-    if (vocab.special_cls_id    != -1) { LLAMA_LOG_INFO( "%s: CLS token        = %d '%s'\n", __func__, vocab.special_cls_id,  vocab.id_to_token[vocab.special_cls_id].text.c_str() );  }
-    if (vocab.special_mask_id   != -1) { LLAMA_LOG_INFO( "%s: MASK token       = %d '%s'\n", __func__, vocab.special_mask_id, vocab.id_to_token[vocab.special_mask_id].text.c_str() ); }
-
-    if (vocab.linefeed_id       != -1) { LLAMA_LOG_INFO( "%s: LF token         = %d '%s'\n", __func__, vocab.linefeed_id,       vocab.id_to_token[vocab.linefeed_id].text.c_str() );       }
- 
-    if (vocab.special_fim_pre_id != -1) { LLAMA_LOG_INFO( "%s: FIM PRE token    = %d '%s'\n", __func__, vocab.special_fim_pre_id, vocab.id_to_token.at(vocab.special_fim_pre_id).text.c_str() ); }
-    if (vocab.special_fim_suf_id != -1) { LLAMA_LOG_INFO( "%s: FIM SUF token    = %d '%s'\n", __func__, vocab.special_fim_suf_id, vocab.id_to_token.at(vocab.special_fim_suf_id).text.c_str() ); }
-    if (vocab.special_fim_mid_id != -1) { LLAMA_LOG_INFO( "%s: FIM MID token    = %d '%s'\n", __func__, vocab.special_fim_mid_id, vocab.id_to_token.at(vocab.special_fim_mid_id).text.c_str() ); }
-    if (vocab.special_fim_pad_id != -1) { LLAMA_LOG_INFO( "%s: FIM PAD token    = %d '%s'\n", __func__, vocab.special_fim_pad_id, vocab.id_to_token.at(vocab.special_fim_pad_id).text.c_str() ); }
-    if (vocab.special_fim_rep_id != -1) { LLAMA_LOG_INFO( "%s: FIM REP token    = %d '%s'\n", __func__, vocab.special_fim_rep_id, vocab.id_to_token.at(vocab.special_fim_rep_id).text.c_str() ); }
-    if (vocab.special_fim_sep_id != -1) { LLAMA_LOG_INFO( "%s: FIM SEP token    = %d '%s'\n", __func__, vocab.special_fim_sep_id, vocab.id_to_token.at(vocab.special_fim_sep_id).text.c_str() ); }
-
-    if (vocab.special_prefix_id != -1) { LLAMA_LOG_INFO( "%s: PRE token        = %d '%s'\n", __func__, vocab.special_prefix_id, vocab.id_to_token[vocab.special_prefix_id].text.c_str() ); }
-    if (vocab.special_suffix_id != -1) { LLAMA_LOG_INFO( "%s: SUF token        = %d '%s'\n", __func__, vocab.special_suffix_id, vocab.id_to_token[vocab.special_suffix_id].text.c_str() ); }
-    if (vocab.special_middle_id != -1) { LLAMA_LOG_INFO( "%s: MID token        = %d '%s'\n", __func__, vocab.special_middle_id, vocab.id_to_token[vocab.special_middle_id].text.c_str() ); }
-    if (vocab.special_eot_id    != -1) { LLAMA_LOG_INFO( "%s: EOT token        = %d '%s'\n", __func__, vocab.special_eot_id,    vocab.id_to_token[vocab.special_eot_id].text.c_str() );    }
-
-    LLAMA_LOG_INFO("%s: max token length = %d\n", __func__, vocab.max_token_len);
-
     if (model.arch == LLM_ARCH_DEEPSEEK2) {
         LLAMA_LOG_INFO("%s: n_layer_dense_lead   = %d\n",     __func__, hparams.n_layer_dense_lead);
         LLAMA_LOG_INFO("%s: n_lora_q             = %d\n",     __func__, hparams.n_lora_q);
@@ -7170,7 +4442,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
         LLAMA_LOG_INFO("%s: n_ff_shexp       = %d\n",     __func__, hparams.n_ff_shexp);
     }
 
-    if (model.arch == LLM_ARCH_QWEN3MOE) {
+    if (model.arch == LLM_ARCH_QWEN3MOE || model.arch == LLM_ARCH_OPENAI_MOE) {
         LLAMA_LOG_INFO("%s: n_ff_exp         = %d\n",     __func__, hparams.n_ff_exp);
     }
 
@@ -7180,6 +4452,8 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
         LLAMA_LOG_INFO("%s: f_attention_scale = %f\n", __func__, hparams.f_attention_scale);
     }
 
+    vocab.print_info();
+
 }
 
 static void llm_prepare_mla(llama_model & model, int mla) {
@@ -9239,7 +6513,7 @@ static bool llm_load_tensors(
                     if (model.output == NULL) {
                         model.output = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
                     }
-                    
+
                     for (int i = 0; i < n_layer; ++i) {
                         ggml_context * ctx_layer = ctx_for_layer(i);
                         ggml_context * ctx_split = ctx_for_layer_split(i);
@@ -9265,9 +6539,9 @@ static bool llm_load_tensors(
                         layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, flags);
 
                         // K/Q norm tensors (optional for GLM-4.5 355B variant)
-                        layer.attn_q_norm = create_tensor(ctx_layer, 
+                        layer.attn_q_norm = create_tensor(ctx_layer,
                             tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, llama_model_loader::TENSOR_NOT_REQUIRED | flags);
-                        layer.attn_k_norm = create_tensor(ctx_layer, 
+                        layer.attn_k_norm = create_tensor(ctx_layer,
                             tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, llama_model_loader::TENSOR_NOT_REQUIRED | flags);
 
                         layer.attn_post_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, flags);
@@ -9285,21 +6559,21 @@ static bool llm_load_tensors(
                             // MoE branch
                             const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used;
 
-                            layer.ffn_gate_exps = create_tensor(ctx_split, 
+                            layer.ffn_gate_exps = create_tensor(ctx_split,
                                 tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, flags);
-                            layer.ffn_down_exps = create_tensor(ctx_split, 
+                            layer.ffn_down_exps = create_tensor(ctx_split,
                                 tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, flags);
-                            layer.ffn_up_exps = create_tensor(ctx_split, 
+                            layer.ffn_up_exps = create_tensor(ctx_split,
                                 tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, flags);
 
                             // Shared expert
                             if (n_expert_shared > 0) {
                                 const int64_t n_ff_shexp = n_ff_exp * n_expert_shared;
-                                layer.ffn_gate_shexp     = create_tensor(ctx_split, 
+                                layer.ffn_gate_shexp     = create_tensor(ctx_split,
                                     tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp }, flags);
-                                layer.ffn_down_shexp = create_tensor(ctx_split, 
+                                layer.ffn_down_shexp = create_tensor(ctx_split,
                                     tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_shexp, n_embd }, flags);
-                                layer.ffn_up_shexp = create_tensor(ctx_split, 
+                                layer.ffn_up_shexp = create_tensor(ctx_split,
                                     tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp }, flags);
                             }
                         } else {
@@ -9794,6 +7068,48 @@ static bool llm_load_tensors(
                         layer.ffn_down_shexp = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, 0);
                     }
                 } break;
+            case LLM_ARCH_OPENAI_MOE:
+                {
+                    const int64_t n_ff_exp = hparams.n_ff_exp;
+
+                    model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    model.output      = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        ggml_context * ctx_layer = ctx_for_layer(i);
+                        ggml_context * ctx_split = ctx_for_layer_split(i);
+                        auto & layer = model.layers[i];
+
+                        layer.attn_norm      = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM,      "weight", i), {n_embd}, 0);
+                        layer.attn_post_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_head * n_rot}, 0);
+                        layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_head_kv * n_rot}, 0);
+                        layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_head_kv * n_rot}, 0);
+                        layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_rot, n_embd}, 0);
+
+                        layer.attn_sinks = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_SINKS, "weight", i), {n_head}, 0);
+
+                        layer.ffn_gate_inp  = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP,  "weight", i), {  n_embd, n_expert}, 0);
+                        layer.ffn_gate_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {  n_embd, n_ff_exp, n_expert}, 0);
+                        layer.ffn_down_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp,   n_embd, n_expert}, 0);
+                        layer.ffn_up_exps   = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {  n_embd, n_ff_exp, n_expert}, 0);
+
+                        // bias
+                        layer.bq = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q,   "bias", i), {n_head * n_rot}, 0);
+                        layer.bk = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K,   "bias", i), {n_head_kv * n_rot}, 0);
+                        layer.bv = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V,   "bias", i), {n_head_kv * n_rot}, 0);
+                        layer.bo = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0);
+
+                        layer.ffn_gate_inp_b  = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP,  "bias", i), {n_expert}, 0);
+                        layer.ffn_gate_exps_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_EXPS, "bias", i), {n_ff_exp, n_expert}, 0);
+                        layer.ffn_down_exps_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN_EXPS, "bias", i), {  n_embd, n_expert}, 0);
+                        layer.ffn_up_exps_b   = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP_EXPS,   "bias", i), {n_ff_exp, n_expert}, 0);
+                    }
+                } break;
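
// Illustrative sketch (an assumption about the semantics of ggml_add_id, which the MoE
// changes further down use to consume the per-expert bias tensors created above; the
// names and memory layout here are made up for illustration): for every routed expert
// slot, the bias column of the selected expert is added to that slot's activations.

#include <cstddef>
#include <cstdint>

static void add_id_ref(float * dst, const float * bias, const int32_t * ids,
                       int n_cols, int n_used, int n_tokens) {
    // dst  : [n_cols, n_used, n_tokens] routed expert activations (contiguous)
    // bias : [n_cols, n_expert]         one bias column per expert
    // ids  : [n_used, n_tokens]         selected expert indices
    for (int t = 0; t < n_tokens; ++t) {
        for (int e = 0; e < n_used; ++e) {
            const float * b = bias + (size_t) ids[t*n_used + e] * n_cols;
            float       * d = dst  + ((size_t) t*n_used + e) * n_cols;
            for (int c = 0; c < n_cols; ++c) {
                d[c] += b[c];
            }
        }
    }
}
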
             default:
                 throw std::runtime_error("unknown architecture");
         }
@@ -10011,15 +7327,16 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam
             throw std::runtime_error("error loading model hyperparameters: " + std::string(e.what()));
         }
         try {
-            llm_load_vocab(ml, model);
+            LLM_KV kv(model.arch);
+            model.vocab.load(ml, kv);
         } catch(const std::exception & e) {
             throw std::runtime_error("error loading model vocabulary: " + std::string(e.what()));
         }
 
         llm_load_print_meta(ml, model);
 
-        if (model.vocab.type != LLAMA_VOCAB_TYPE_NONE &&
-            model.hparams.n_vocab != model.vocab.id_to_token.size()) {
+        if (model.vocab.get_type() != LLAMA_VOCAB_TYPE_NONE &&
+            model.hparams.n_vocab != model.vocab.n_tokens()) {
             throw std::runtime_error("vocab size mismatch");
         }
 
@@ -10071,6 +7388,7 @@ enum llm_ffn_op_type {
     LLM_FFN_RELU,
     LLM_FFN_RELU_SQR,
     LLM_FFN_SWIGLU,
+    LLM_FFN_SWIGLU_OAI_MOE,
 };
 
 enum llm_ffn_gate_type {
@@ -10362,6 +7680,8 @@ static struct ggml_tensor * llm_build_ffn(
                 cur = ggml_swiglu(ctx, cur);
                 cb(cur, "ffn_swiglu", il);
             } break;
+        default:
+            GGML_ABORT("fatal error");
     }
 
     if (type_gate == LLM_FFN_PAR) {
@@ -10394,15 +7714,15 @@ static struct ggml_tensor * llm_build_ffn(
     return cur;
 }
 
-static struct ggml_tensor * llm_build_moe_ffn(
-        struct ggml_context * ctx,
-       struct llama_context & lctx,
-         struct ggml_tensor * cur,
-         struct ggml_tensor * gate_inp,
-         struct ggml_tensor * up_exps,
-         struct ggml_tensor * gate_exps,
-         struct ggml_tensor * down_exps,
-         struct ggml_tensor * exp_probs_b,
+static ggml_tensor * llm_build_moe_ffn(
+        ggml_context * ctx,
+       llama_context & lctx,
+         ggml_tensor * cur,
+         ggml_tensor * gate_inp,   ggml_tensor * gate_inp_b,
+         ggml_tensor * up_exps,    ggml_tensor * up_exps_b,
+         ggml_tensor * gate_exps,  ggml_tensor * gate_exps_b,
+         ggml_tensor * down_exps,  ggml_tensor * down_exps_b,
+         ggml_tensor * exp_probs_b,
                     int64_t   n_expert,
                     int64_t   n_expert_used,
             llm_ffn_op_type   type_op,
@@ -10411,7 +7731,7 @@ static struct ggml_tensor * llm_build_moe_ffn(
                       float   w_scale,
 llm_expert_gating_func_type   gating_op,
          const llm_build_cb & cb,
-                        int   il) {
+                        int   il, struct ggml_cgraph * graph = nullptr) {
     int64_t n_embd = cur->ne[0];
     int64_t n_tokens = cur->ne[1];
     bool weight_before_ffn = lctx.model.arch == LLM_ARCH_LLAMA4; // for llama4, we apply the sigmoid-ed weights before the FFN
@@ -10419,6 +7739,12 @@ llm_expert_gating_func_type   gating_op,
     ggml_tensor * logits = llm_build_lora_mm(lctx, ctx, gate_inp, cur); // [n_expert, n_tokens]
     cb(logits, "ffn_moe_logits", il);
 
+    if (gate_inp_b) {
+        logits = ggml_add(ctx, logits, gate_inp_b);
+        cb(logits, "ffn_moe_logits_biased", il);
+    }
+
+
     //ggml_tensor * probs = ggml_soft_max(ctx, logits); // [n_expert, n_tokens]
     ggml_tensor * probs = nullptr;
     switch (gating_op) {
@@ -10430,6 +7756,10 @@ llm_expert_gating_func_type   gating_op,
             {
                 probs = ggml_sigmoid(ctx, logits); // [n_expert, n_tokens]
             } break;
+        case LLM_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT:
+            {
+                probs = logits; // [n_expert, n_tokens]
+            } break;
         default:
             GGML_ABORT("fatal error");
     }
@@ -10459,6 +7789,13 @@ llm_expert_gating_func_type   gating_op,
             ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
     cb(weights, "ffn_moe_weights", il);
 
+    if (gating_op == LLM_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT) {
+        weights = ggml_reshape_2d(ctx, weights, n_expert_used, n_tokens);
+        weights = ggml_soft_max(ctx, weights); // [n_expert_used, n_tokens]
+        weights = ggml_reshape_3d(ctx, weights, 1, n_expert_used, n_tokens);
+        cb(weights, "ffn_moe_weights_softmax", il);
+    }
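
// Illustrative sketch (helper name is made up, not from the patch): SOFTMAX_WEIGHT
// selects the top-k experts on the raw router logits and only then normalizes the k
// selected logits, whereas plain SOFTMAX gating normalizes over all n_expert logits
// before the selection.

#include <algorithm>
#include <cmath>
#include <vector>

static std::vector<float> softmax_over_selected(const std::vector<float> & sel_logits) {
    const float mx = *std::max_element(sel_logits.begin(), sel_logits.end());
    std::vector<float> w(sel_logits.size());
    float sum = 0.0f;
    for (size_t i = 0; i < w.size(); ++i) {
        w[i] = std::exp(sel_logits[i] - mx);
        sum += w[i];
    }
    for (float & v : w) {
        v /= sum; // weights for the n_expert_used selected experts of one token
    }
    return w;
}
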
+
     if (norm_w) {
         weights = ggml_reshape_2d(ctx, weights, n_expert_used, n_tokens);
 
@@ -10485,9 +7822,23 @@ llm_expert_gating_func_type   gating_op,
         cb(cur, "ffn_moe_weighted", il);
     }
 
+    // For now we don't modify the fused up/gate op to include biases.
+    // Hence, if we have biases, we cannot use fmoe.
+    //
+    //bool can_use_fmoe = !up_exps_b && !gate_exps_b && (type_op == LLM_FFN_SILU || type_op == LLM_FFN_GELU);
+    bool can_use_fmoe = type_op == LLM_FFN_SILU || type_op == LLM_FFN_GELU || type_op == LLM_FFN_SWIGLU_OAI_MOE;
+
     ggml_tensor * par;
-    if (lctx.cparams.fused_moe_up_gate && up_exps->type == gate_exps->type) {
-        par = ggml_moe_up_gate(ctx, up_exps, gate_exps, cur, selected_experts, type_op == LLM_FFN_SILU ? GGML_UNARY_OP_SILU : GGML_UNARY_OP_GELU);
+    if (can_use_fmoe && lctx.cparams.fused_moe_up_gate && up_exps->type == gate_exps->type) {
+        if (up_exps_b || gate_exps_b) {
+            par = ggml_moe_up_gate_ext(ctx, up_exps, gate_exps, cur, selected_experts, up_exps_b, gate_exps_b,
+                    type_op == LLM_FFN_SILU ? GGML_UNARY_OP_SILU :
+                    type_op == LLM_FFN_GELU ? GGML_UNARY_OP_GELU : GGML_UNARY_OP_SWIGLU_OAI);
+        } else {
+            GGML_ASSERT(type_op != LLM_FFN_SWIGLU_OAI_MOE);
+            par = ggml_moe_up_gate(ctx, up_exps, gate_exps, cur, selected_experts,
+                    type_op == LLM_FFN_SILU ? GGML_UNARY_OP_SILU : GGML_UNARY_OP_GELU);
+        }
     } else {
         ggml_tensor * up = llm_build_lora_mm_id(lctx, ctx, up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
         cb(up, "ffn_moe_up", il);
@@ -10495,41 +7846,49 @@ llm_expert_gating_func_type   gating_op,
         ggml_tensor * gate = llm_build_lora_mm_id(lctx, ctx, gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
         cb(gate, "ffn_moe_gate", il);
 
-        // This is equivalent to the commented out code below
-        par = ggml_fused_mul_unary(ctx, gate, up, type_op == LLM_FFN_SILU ? GGML_UNARY_OP_SILU : GGML_UNARY_OP_GELU);
+        if (graph) {
+            // So we can potentially fuse the up and gate mul_mat_id
+            ggml_build_forward_expand(graph, up);
+            ggml_build_forward_expand(graph, gate);
+        }
+
+        if (up_exps_b) {
+            up = ggml_add_id(ctx, up, up_exps_b, selected_experts);
+            cb(up, "ffn_moe_up_biased", il);
+        }
+
+        if (gate_exps_b) {
+            gate = ggml_add_id(ctx, gate, gate_exps_b, selected_experts);
+            cb(gate, "ffn_moe_gate_biased", il);
+        }
+
+        if (type_op == LLM_FFN_SILU || type_op == LLM_FFN_GELU) {
+            par = ggml_fused_mul_unary(ctx, gate, up, type_op == LLM_FFN_SILU ? GGML_UNARY_OP_SILU : GGML_UNARY_OP_GELU);
+        } else if (type_op == LLM_FFN_SWIGLU_OAI_MOE) {
+            constexpr float alpha = 1.702f;
+            constexpr float limit = 7.0f;
+            par = ggml_swiglu_oai(ctx, gate, up, alpha, limit);
+        }
+        else {
+            GGML_ABORT("fatal error");
+        }
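
// Illustrative sketch of the activation used above (an assumption based on the gpt-oss
// reference formulation; ggml_swiglu_oai is expected to compute something equivalent
// element-wise, here with alpha = 1.702 and limit = 7.0):

#include <algorithm>
#include <cmath>

static inline float swiglu_oai_ref(float gate, float up, float alpha, float limit) {
    const float g = std::min(gate, limit);                  // clamp gate from above
    const float u = std::clamp(up, -limit, limit);          // clamp up to [-limit, limit]
    return (u + 1.0f) * g / (1.0f + std::exp(-alpha * g));  // (up + 1) * gate * sigmoid(alpha * gate)
}
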
+
     }
     cb(par, "ffn_moe_gate_par", il);
 
     ggml_tensor * experts = llm_build_lora_mm_id(lctx, ctx, down_exps, par, selected_experts); // [n_embd, n_expert_used, n_tokens]
     cb(experts, "ffn_moe_down", il);
 
+    if (down_exps_b) {
+        experts = ggml_add_id(ctx, experts, down_exps_b, selected_experts);
+        cb(experts, "ffn_moe_down_biased", il);
+    }
+
     if (!weight_before_ffn) {
         experts = ggml_mul(ctx, experts, weights);
         cb(cur, "ffn_moe_weighted", il);
     }
 
-//#ifdef GGML_USE_VULKAN
-//    // aggregate experts
-//    ggml_tensor * moe_out = nullptr;
-//    //ggml_tensor * first_expert = nullptr;
-//    for (int i = 0; i < n_expert_used; ++i) {
-//        ggml_tensor * cur_expert = ggml_view_2d(ctx, experts, n_embd, n_tokens,
-//                experts->nb[2], i*experts->nb[1]);
-//
-//        if (i == 0) {
-//            moe_out = cur_expert;
-//        } else {
-//            moe_out = ggml_add(ctx, moe_out, cur_expert);
-//        }
-//    }
-//
-//    if (n_expert_used == 1) {
-//        // avoid returning a non-contiguous tensor
-//        moe_out = ggml_cont(ctx, moe_out);
-//    }
-//
-//    return moe_out;
-//#else
     if (n_expert_used == 1) {
         return ggml_cont(ctx, ggml_view_2d(ctx, experts, n_embd, n_tokens, experts->nb[2], 0));
     }
@@ -10538,10 +7897,38 @@ llm_expert_gating_func_type   gating_op,
                              ggml_view_2d(ctx, experts, n_embd, n_tokens, experts->nb[2], experts->nb[1]));
     }
     return ggml_multi_add(ctx, ggml_view_2d(ctx, experts, n_embd, n_tokens, experts->nb[2], 0), n_expert_used);
-//#endif
 
 }
 
+static ggml_tensor * llm_build_moe_ffn(
+        struct ggml_context * ctx,
+       struct llama_context & lctx,
+         struct ggml_tensor * cur,
+         struct ggml_tensor * gate_inp,
+         struct ggml_tensor * up_exps,
+         struct ggml_tensor * gate_exps,
+         struct ggml_tensor * down_exps,
+         struct ggml_tensor * exp_probs_b,
+                    int64_t   n_expert,
+                    int64_t   n_expert_used,
+            llm_ffn_op_type   type_op,
+                       bool   norm_w,
+                       bool   scale_w,
+                      float   w_scale,
+llm_expert_gating_func_type   gating_op,
+         const llm_build_cb & cb,
+                        int   il, struct ggml_cgraph * graph = nullptr) {
+    return llm_build_moe_ffn(ctx, lctx, cur,
+            gate_inp,   nullptr,
+            up_exps,    nullptr,
+            gate_exps,  nullptr,
+            down_exps,  nullptr,
+            exp_probs_b,
+            n_expert, n_expert_used,
+            type_op, norm_w, scale_w, w_scale,
+            gating_op, cb, il, graph);
+}
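
// Usage note (illustrative, not from the patch): existing call sites keep the old
// signature and implicitly pass nullptr for all per-expert bias tensors, e.g.
//
//   cur = llm_build_moe_ffn(ctx0, lctx, cur,
//           model.layers[il].ffn_gate_inp,
//           model.layers[il].ffn_up_exps,
//           model.layers[il].ffn_gate_exps,
//           model.layers[il].ffn_down_exps,
//           nullptr, n_expert, n_expert_used,
//           LLM_FFN_SILU, true, false, 0.0f,
//           LLM_EXPERT_GATING_FUNC_TYPE_SOFTMAX, cb, il);
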
+
 static struct ggml_tensor * llm_build_kqv(
         struct ggml_context * ctx,
        struct llama_context & lctx,
@@ -10555,7 +7942,8 @@ static struct ggml_tensor * llm_build_kqv(
                     int32_t   n_kv,
                     float     kq_scale,
          const llm_build_cb & cb,
-                    int       il) {
+                    int       il,
+                ggml_tensor * sinks = nullptr) {
     const llama_model   & model   = lctx.model;
     const llama_hparams & hparams = lctx.model.hparams;
     const llama_cparams & cparams = lctx.cparams;
@@ -10602,6 +7990,7 @@ static struct ggml_tensor * llm_build_kqv(
 
         cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
                                   hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
+        ggml_flash_attn_ext_add_sinks(cur, sinks);
 
         // Some models produced NaNs/gibberish when FA is computed with f16 precision on CUDA
         // For DeepSeek-2, it is perfectly fine with fp16 for PP, but I get gibberish when using fp16 for TG.
@@ -10625,7 +8014,7 @@ static struct ggml_tensor * llm_build_kqv(
         cb(v, "v", il);
 
         auto kq_size = k->ne[1]*q->ne[1]*q->ne[2]*sizeof(float)/(1024*1024);
-        if (cparams.attn_max_batch == 0 || cparams.attn_max_batch >= kq_size || k->ne[2] != q->ne[2] || v->ne[2] != q->ne[2]) {
+        if (cparams.attn_max_batch == 0 || cparams.attn_max_batch >= kq_size || k->ne[2] != q->ne[2] || v->ne[2] != q->ne[2] || sinks) {
             struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
             cb(kq, "kq", il);
 
@@ -10660,6 +8049,7 @@ static struct ggml_tensor * llm_build_kqv(
                         1.0f / hparams.f_attn_logit_softcapping, hparams.f_attn_logit_softcapping);
             } else {
                 kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
+                ggml_soft_max_add_sinks(kq, sinks);
             }
             cb(kq, "kq_soft_max_ext", il);
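
// Illustrative sketch (an assumption about the sink semantics, in line with gpt-oss):
// the per-head "attention sink" is an extra learned logit that joins the softmax
// normalization but contributes no value vector, damping all attention weights:

#include <algorithm>
#include <cmath>

static void softmax_row_with_sink(float * p, const float * kq, int n, float sink) {
    float mx = sink;
    for (int i = 0; i < n; ++i) mx = std::max(mx, kq[i]);
    float denom = std::exp(sink - mx);           // the sink only enters the denominator
    for (int i = 0; i < n; ++i) { p[i] = std::exp(kq[i] - mx); denom += p[i]; }
    for (int i = 0; i < n; ++i) p[i] /= denom;
}
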
 
@@ -10757,7 +8147,8 @@ static struct ggml_tensor * llm_build_kv(
                     int32_t   n_kv,
                     float     kq_scale,
          const llm_build_cb & cb,
-                    int       il) {
+                    int       il,
+                ggml_tensor * sinks = nullptr) {
     const llama_hparams & hparams = lctx.model.hparams;
     const llama_cparams & cparams = lctx.cparams;
 
@@ -10772,7 +8163,7 @@ static struct ggml_tensor * llm_build_kv(
     struct ggml_tensor * cur;
 
     cur  = llm_build_kqv(ctx, lctx, kv, graph, wo, wo_b,
-            q_cur, kq_mask, n_tokens, n_kv, kq_scale, cb, il);
+            q_cur, kq_mask, n_tokens, n_kv, kq_scale, cb, il, sinks);
     cb(cur, "kqv_out", il);
 
     return cur;
@@ -16406,24 +13797,24 @@ struct llm_build_context {
     struct ggml_cgraph * build_glm4_moe() {
         // create a new graph
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
-    
+
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-    
+
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
-    
+
         // input embeddings
         inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
-    
+
         // position embeddings
         struct ggml_tensor * inp_pos = build_inp_pos();
-    
+
         // attention KV cache input
         //auto * inp_attn = build_attn_inp_kv_unified();
 
         struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-    
+
         // output token IDs (for last layer cropping)
         struct ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -16432,13 +13823,13 @@ struct llm_build_context {
         const int n_transformer_layers = n_layer - hparams.nextn_predict_layers;
         for (int il = 0; il < n_transformer_layers; ++il) {
             struct ggml_tensor * inpSA = inpL;
-    
+
             // Pre-attention norm
             cur = llm_build_norm(ctx0, inpL, hparams,
                                  model.layers[il].attn_norm, NULL,
                                  LLM_NORM_RMS, cb, il);
             cb(cur, "attn_norm", il);
-    
+
             // self-attention
             {
                 // Q, K, V projections
@@ -16447,24 +13838,24 @@ struct llm_build_context {
                     Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
                 }
                 cb(Qcur, "Qcur", il);
-    
+
                 struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
                 if (model.layers[il].bk) {
                     Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
                 }
                 cb(Kcur, "Kcur", il);
-    
+
                 struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
                 if (model.layers[il].bv) {
                     Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
                 }
                 cb(Vcur, "Vcur", il);
-    
+
                 // reshape for multi-head
                 Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
                 Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
                 // Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
-    
+
                 // Apply Q/K norm if available (GLM-4.5 355B variant)
                 if (model.layers[il].attn_q_norm) {
                     Qcur = llm_build_norm(ctx0, Qcur, hparams,
@@ -16478,7 +13869,7 @@ struct llm_build_context {
                                          LLM_NORM_RMS, cb, il);
                     cb(Kcur, "Kcur_normed", il);
                 }
-    
+
                 // apply RoPE
                 Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr,
                                      n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
@@ -16489,7 +13880,7 @@ struct llm_build_context {
                 cb(Qcur, "Qcur", il);
                 cb(Kcur, "Kcur", il);
                 cb(Vcur, "Vcur", il);
-    
+
                 // build attention KV (no unified cache)
                 cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                                    model.layers[il].wo, NULL,
@@ -16497,7 +13888,7 @@ struct llm_build_context {
                                    n_tokens, kv_head, n_kv,
                                    1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
-    
+
             // crop output on last layer
             if (il == n_transformer_layers - 1 && inp_out_ids) {
                 // skip computing output for unused tokens
@@ -16505,17 +13896,17 @@ struct llm_build_context {
                 cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
                 inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
             }
-    
+
             // residual connection for attention output
             struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
             cb(ffn_inp, "ffn_inp", il);
-    
+
             // Post-attention norm
             cur = llm_build_norm(ctx0, ffn_inp, hparams,
                                  model.layers[il].attn_post_norm, NULL,
                                  LLM_NORM_RMS, cb, il);
             cb(cur, "post_attn_norm", il);
-    
+
             if ((uint32_t) il < hparams.n_layer_dense_lead) {
                 // dense FFN
                 cur = llm_build_ffn(ctx0, lctx, cur,
@@ -16548,29 +13939,29 @@ struct llm_build_context {
                                                 NULL,
                                                 LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
                     cb(shared_out, "ffn_shexp_out", il);
-        
+
                     cur = ggml_add(ctx0, routed_out, shared_out);
                     cb(cur, "ffn_out", il);
                 }
             }
-    
+
             // residual and context vector
             cur = ggml_add(ctx0, cur, ffn_inp);
             cur = lctx.cvec.apply_to(ctx0, cur, il);
             cb(cur, "l_out", il);
-    
+
             // prepare next layer input
             inpL = cur;
         }
 
         cur = inpL;
-    
+
         // final norm
         cur = llm_build_norm(ctx0, cur, hparams,
                              model.output_norm, NULL,
                              LLM_NORM_RMS, cb, -1);
         cb(cur, "result_norm", -1);
-    
+
         // lm head
         cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output", -1);
@@ -17990,6 +15381,126 @@ struct llm_build_context {
 
         return gf;
     }
+
+    struct ggml_cgraph * build_openai_moe() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        struct ggml_tensor * KQ_mask     = build_inp_KQ_mask();
+        struct ggml_tensor * KQ_mask_swa = build_inp_KQ_mask_swa();
+        //const int64_t n_embd_head = hparams.n_embd_head_v;
+        const float kq_scale = 1.0f / sqrtf(float(n_rot)); //float(n_embd_head));
+
+        //auto * inp_attn = build_attn_inp_kv_unified_iswa();
+
+        const int sliding_window_pattern = 2;
+
+        for (int il = 0; il < n_layer; ++il) {
+            const bool is_sliding = il % sliding_window_pattern < (sliding_window_pattern - 1);
+            ggml_tensor * inpSA = inpL;
+
+            struct ggml_tensor * KQ_mask_l = is_sliding ? KQ_mask_swa : KQ_mask;
+
+            // norm
+            cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, cb, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                if (model.layers[il].bq) {
+                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                    cb(Qcur, "Qcur", il);
+                }
+
+                ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                if (model.layers[il].bk) {
+                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                if (model.layers[il].bv) {
+                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                    cb(Vcur, "Vcur", il);
+                }
+
+                Qcur = ggml_rope_ext(ctx0, ggml_reshape_3d(ctx0, Qcur, n_rot, n_head, n_tokens), inp_pos, nullptr,
+                                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor,
+                                    beta_fast, beta_slow);
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_ext(ctx0, ggml_reshape_3d(ctx0, Kcur, n_rot, n_head_kv, n_tokens), inp_pos, nullptr,
+                                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor,
+                                    attn_factor, beta_fast, beta_slow);
+                cb(Kcur, "Kcur", il);
+
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo,
+                        Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, kq_scale, cb, il, model.layers[il].attn_sinks);
+
+                cb(cur, "attn_out", il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            cur = ffn_inp;
+            cur = llm_build_norm(ctx0, cur,  hparams, model.layers[il].attn_post_norm, nullptr, LLM_NORM_RMS, cb, il);
+            cb(cur, "attn_post_norm", il);
+
+            // MoE branch
+            cur = llm_build_moe_ffn(ctx0, lctx, cur,
+                    model.layers[il].ffn_gate_inp,  model.layers[il].ffn_gate_inp_b,
+                    model.layers[il].ffn_up_exps,   model.layers[il].ffn_up_exps_b,
+                    model.layers[il].ffn_gate_exps, model.layers[il].ffn_gate_exps_b,
+                    model.layers[il].ffn_down_exps, model.layers[il].ffn_down_exps_b,
+                    nullptr,
+                    n_expert, n_expert_used,
+                    LLM_FFN_SWIGLU_OAI_MOE, false,
+                    false, 0.0,
+                    LLM_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT,
+                    cb, il, gf);
+            cb(cur, "ffn_moe_out", il);
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, nullptr, LLM_NORM_RMS, cb, -1);
+        cb(cur, "result_norm", -1);
+
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
+
+        cb(cur, "result_output", -1);
+
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
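
// Illustrative note (not from the patch): with sliding_window_pattern == 2 the check
// `il % 2 < 1` marks even-numbered layers (0, 2, 4, ...) as sliding-window layers
// (KQ_mask_swa) and odd-numbered layers as full-context layers (KQ_mask).

static constexpr bool oai_moe_layer_is_sliding(int il, int pattern = 2) {
    return il % pattern < (pattern - 1);
}
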
 };
 
 static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) {
@@ -18086,9 +15597,9 @@ static struct ggml_cgraph * llama_build_graph(
 
     struct ggml_cgraph * result = NULL;
 
-    const llama_vocab * vocab = llama_get_vocab(&lctx);
-    llama_token bos = llama_token_bos_impl(*vocab);
-    llama_token eos = llama_token_eos_impl(*vocab);
+    const llama_vocab * vocab = &lctx.model.vocab; //llama_get_vocab(&lctx);
+    llama_token bos = vocab->token_bos();
+    llama_token eos = vocab->token_eos();
     bool is_warming_up = lctx.n_eval == 0 && (batch.n_tokens == 1 && (batch.token[0] == ((bos != -1) ? bos : eos)));
     struct llm_build_context llm(lctx, batch, cb, worst_case, is_warming_up);
 
@@ -18297,6 +15808,10 @@ static struct ggml_cgraph * llama_build_graph(
             {
                 result = llm.build_hunyuan_moe();
             } break;
+        case LLM_ARCH_OPENAI_MOE:
+            {
+                result = llm.build_openai_moe();
+            } break;
         default:
             GGML_ABORT("fatal error");
     }
@@ -22018,7 +19533,7 @@ uint32_t llama_n_seq_max(const struct llama_context * ctx) {
 }
 
 enum llama_vocab_type llama_vocab_type(const struct llama_model * model) {
-    return model->vocab.type;
+    return model->vocab.get_type();
 }
 
 const struct llama_vocab* llama_get_model_vocab(const struct llama_model* model) {
@@ -22089,6 +19604,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
         case LLM_ARCH_CODESHELL:
         case LLM_ARCH_DOTS1:
         case LLM_ARCH_HUNYUAN_MOE:
+        case LLM_ARCH_OPENAI_MOE:
             return LLAMA_ROPE_TYPE_NEOX;
 
         // all model arches should be listed explicitly here
@@ -22189,7 +19705,7 @@ const char* llama_model_chat_template(const struct llama_model* model, const cha
         // one-off fix for very popular models (so we are not flooded with issues)
         // do not extend this list unless absolutely necessary
         // Mistral-Small-2503 does not have built-in chat template
-        llama_vocab_pre_type pre_type = model->vocab.type_pre;
+        llama_vocab_pre_type pre_type = model->vocab.get_pre_type();
         if (!name && pre_type == LLAMA_VOCAB_PRE_TYPE_TEKKEN && model->layers.size() == 40) {
             return "mistral-v7-tekken";
         }
@@ -23307,7 +20823,7 @@ static bool llama_state_load_file_internal(struct llama_context * ctx, const cha
 
     // restore the context state
     {
-        const size_t n_state_size_cur = file.size - file.tell();
+        const size_t n_state_size_cur = file.size() - file.tell();
 
         llama_data_read_file data_ctx(&file);
         const size_t n_read = llama_state_set_data_internal(ctx, data_ctx);
@@ -23444,7 +20960,7 @@ static size_t llama_state_seq_load_file_internal(struct llama_context * ctx, con
 
     // restore the context state
     {
-        const size_t state_size = file.size - file.tell();
+        const size_t state_size = file.size() - file.tell();
         llama_data_read_file data_ctx(&file);
         const size_t nread = llama_state_seq_set_data_internal(ctx, data_ctx, dest_seq_id);
         if (!nread) {
@@ -23710,101 +21226,102 @@ float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id
 //
 
 const char * llama_token_get_text(const struct llama_model * model, llama_token token) {
-    return llama_token_get_text_impl(model->vocab, token);
+    return model->vocab.token_get_text(token);
 }
 
 float llama_token_get_score(const struct llama_model * model, llama_token token) {
-    return llama_token_get_score_impl(model->vocab, token);
+    return model->vocab.token_get_score(token);
 }
 
 enum llama_token_attr llama_token_get_attr(const struct llama_model * model, llama_token token) {
-    return llama_token_get_attr_impl(model->vocab, token);
+    return model->vocab.token_get_attr(token);
 }
 
 bool llama_token_is_eog(const struct llama_model * model, llama_token token) {
-    return llama_token_is_eog_impl(model->vocab, token);
+    return model->vocab.is_eog(token);
 }
 
 bool llama_token_is_control(const struct llama_model * model, llama_token token) {
-    return llama_token_is_control_impl(model->vocab, token);
+    return model->vocab.is_control(token);
 }
 
 llama_token llama_token_bos(const struct llama_model * model) {
-    return llama_token_bos_impl(model->vocab);
+    return model->vocab.token_bos();
 }
 
 llama_token llama_token_eos(const struct llama_model * model) {
-    return llama_token_eos_impl(model->vocab);
+    return model->vocab.token_eos();
 }
 
-llama_token llama_token_cls(const struct llama_model * model) {
-    return llama_token_cls_impl(model->vocab);
-}
+// What is cls?
+//llama_token llama_token_cls(const struct llama_model * model) {
+//    return llama_token_cls_impl(model->vocab);
+//}
 
 llama_token llama_token_sep(const struct llama_model * model) {
-    return llama_token_sep_impl(model->vocab);
+    return model->vocab.token_sep();
 }
 
 llama_token llama_token_nl (const struct llama_model * model) {
-    return llama_token_nl_impl(model->vocab);
+    return model->vocab.token_nl();
 }
 
 llama_token llama_token_pad(const struct llama_model * model) {
-    return llama_token_pad_impl(model->vocab);
+    return model->vocab.token_pad();
 }
 
 int32_t llama_add_bos_token(const struct llama_model * model) {
-    return llama_add_bos_token_impl(model->vocab);
+    return model->vocab.get_add_bos();
 }
 
 int32_t llama_add_eos_token(const struct llama_model * model) {
-    return llama_add_eos_token_impl(model->vocab);
+    return model->vocab.get_add_eos();
 }
 
 llama_token llama_token_prefix(const struct llama_model * model) {
-    return llama_token_prefix_impl(model->vocab);
+    return model->vocab.token_prefix();
 }
 
 llama_token llama_token_middle(const struct llama_model * model) {
-    return llama_token_middle_impl(model->vocab);
+    return model->vocab.token_middle();
 }
 
 llama_token llama_token_suffix(const struct llama_model * model) {
-    return llama_token_suffix_impl(model->vocab);
+    return model->vocab.token_suffix();
 }
 
 llama_token llama_token_eot(const struct llama_model * model) {
-    return llama_token_eot_impl(model->vocab);
+    return model->vocab.token_eot();
 }
 
 // deprecated
 llama_token llama_token_fim_pre(const struct llama_model * model) {
-    return llama_token_fim_pre_impl(model->vocab);
+    return model->vocab.token_fim_pre();
 }
 
 // deprecated
 llama_token llama_token_fim_suf(const struct llama_model * model) {
-    return llama_token_fim_suf_impl(model->vocab);
+    return model->vocab.token_fim_suf();
 }
 
 // deprecated
 llama_token llama_token_fim_mid(const struct llama_model * model) {
-    return llama_token_fim_mid_impl(model->vocab);
+    return model->vocab.token_fim_mid();
 }
 
 // deprecated
 llama_token llama_token_fim_pad(const struct llama_model * model) {
-    return llama_token_fim_pad_impl(model->vocab);
+    return model->vocab.token_fim_pad();
 }
 
 // deprecated
 llama_token llama_token_fim_rep(const struct llama_model * model) {
-    return llama_token_fim_rep_impl(model->vocab);
+    return model->vocab.token_fim_rep();
 }
 
 // deprecated
 llama_token llama_token_fim_sep(const struct llama_model * model) {
-    return llama_token_fim_sep_impl(model->vocab);
+    return model->vocab.token_fim_sep();
 }
 
 //
@@ -23819,7 +21336,7 @@ int32_t llama_tokenize(
                      int32_t   n_tokens_max,
                         bool   add_special,
                         bool   parse_special) {
-    return llama_tokenize_impl(model->vocab, text, text_len, tokens, n_tokens_max, add_special, parse_special);
+    return model->vocab.tokenize(text, text_len, tokens, n_tokens_max, add_special, parse_special);
 }
 
 int32_t llama_token_to_piece(
@@ -23829,7 +21346,7 @@ int32_t llama_token_to_piece(
                      int32_t   length,
                      int32_t   lstrip,
                         bool   special) {
-    return llama_token_to_piece_impl(model->vocab, token, buf, length, lstrip, special);
+    return model->vocab.token_to_piece(token, buf, length, lstrip, special);
 }
 
 int32_t llama_detokenize(
@@ -23840,7 +21357,7 @@ int32_t llama_detokenize(
                      int32_t   text_len_max,
                         bool   remove_special,
                         bool   unparse_special) {
-    return llama_detokenize_impl(model->vocab, tokens, n_tokens, text, text_len_max, remove_special, unparse_special);
+    return model->vocab.detokenize(tokens, n_tokens, text, text_len_max, remove_special, unparse_special);
 }
 
 //
@@ -23956,6 +21473,8 @@ static llm_chat_template llama_chat_detect_template(const std::string & tmpl) {
         return LLM_CHAT_TEMPLATE_HUNYUAN_MOE;
     } else if (tmpl_contains("<|im_middle|>") && tmpl_contains("<|im_end|>")) {
         return LLM_CHAT_TEMPLATE_KIMI_K2;
+    } else if (tmpl_contains("<|start|>") && tmpl_contains("<|channel|>")) {
+        return LLM_CHAT_TEMPLATE_OPENAI_MOE;
     }
     return LLM_CHAT_TEMPLATE_UNKNOWN;
 }
@@ -24416,6 +21935,16 @@ static int32_t llama_chat_apply_template_internal(
         if (add_ass) {
             ss << "<|im_assistant|>assistant<|im_middle|>";
         }
+    } else if (tmpl == LLM_CHAT_TEMPLATE_OPENAI_MOE) {
+        // OpenAI MoE (based on Harmony chat template)
+        for (auto message : chat) {
+            std::string role(message->role);
+            ss << "<|start|>" << role << "<|message|>" << message->content;
+            ss << (role == "assistant" ? "<|return|>" : "<|end|>");
+        }
+        if (add_ass) {
+            ss << "<|start|>assistant";
+        }
     } else {
         // template not supported
         return -1;
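
// Illustrative trace (not from the patch) of the OPENAI_MOE (Harmony-style) branch above,
// for chat = [user:"Hello", assistant:"Hi!", user:"Bye"] with add_ass = true; the pieces
// are concatenated without newlines, shown one message per line here for readability:
//
//   <|start|>user<|message|>Hello<|end|>
//   <|start|>assistant<|message|>Hi!<|return|>
//   <|start|>user<|message|>Bye<|end|>
//   <|start|>assistant
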
diff --git a/src/unicode.cpp b/src/unicode.cpp
index c911fd262..65f366517 100644
--- a/src/unicode.cpp
+++ b/src/unicode.cpp
@@ -5,20 +5,19 @@
 #include "unicode.h"
 #include "unicode-data.h"
 
+#include 
 #include 
+#include 
 #include 
 #include 
+#include 
 #include 
 #include 
 #include 
 #include 
 #include 
-#include 
 #include 
 #include 
-#include 
-#include 
-#include 
 
 size_t unicode_len_utf8(char src) {
     const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
@@ -26,7 +25,7 @@ size_t unicode_len_utf8(char src) {
     return lookup[highbits];
 }
 
-static std::string unicode_cpts_to_utf8(const std::vector<uint32_t>& cps) {
+static std::string unicode_cpts_to_utf8(const std::vector<uint32_t> & cps) {
     std::string result;
     for (size_t i = 0; i < cps.size(); ++i) {
         result.append(unicode_cpt_to_utf8(cps[i]));
@@ -34,7 +33,7 @@ static std::string unicode_cpts_to_utf8(const std::vector<uint32_t>& cps) {
     return result;
 }
 
-uint32_t unicode_cpt_from_utf8(const std::string& utf8, size_t& offset) {
+uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset) {
     assert(offset < utf8.size());
     if (!(utf8[offset + 0] & 0x80)) {
         auto result = utf8[offset + 0];
@@ -45,7 +44,7 @@ uint32_t unicode_cpt_from_utf8(const std::string& utf8, size_t& offset) {
         throw std::invalid_argument("invalid character");
     }
     if (!(utf8[offset + 0] & 0x20)) {
-        if (offset + 1 >= utf8.size() || !((utf8[offset + 1] & 0xc0) == 0x80)) {
+        if (offset + 1 >= utf8.size() || ! ((utf8[offset + 1] & 0xc0) == 0x80)) {
             throw std::invalid_argument("invalid character");
         }
         auto result = ((utf8[offset + 0] & 0x1f) << 6) | (utf8[offset + 1] & 0x3f);
@@ -53,7 +52,7 @@ uint32_t unicode_cpt_from_utf8(const std::string& utf8, size_t& offset) {
         return result;
     }
     if (!(utf8[offset + 0] & 0x10)) {
-        if (offset + 2 >= utf8.size() || !((utf8[offset + 1] & 0xc0) == 0x80) || !((utf8[offset + 2] & 0xc0) == 0x80)) {
+        if (offset + 2 >= utf8.size() || ! ((utf8[offset + 1] & 0xc0) == 0x80) || ! ((utf8[offset + 2] & 0xc0) == 0x80)) {
             throw std::invalid_argument("invalid character");
         }
         auto result = ((utf8[offset + 0] & 0x0f) << 12) | ((utf8[offset + 1] & 0x3f) << 6) | (utf8[offset + 2] & 0x3f);
@@ -61,7 +60,7 @@ uint32_t unicode_cpt_from_utf8(const std::string& utf8, size_t& offset) {
         return result;
     }
     if (!(utf8[offset + 0] & 0x08)) {
-        if (offset + 3 >= utf8.size() || !((utf8[offset + 1] & 0xc0) == 0x80) || !((utf8[offset + 2] & 0xc0) == 0x80) || !((utf8[offset + 3] & 0xc0) == 0x80)) {
+        if (offset + 3 >= utf8.size() || ! ((utf8[offset + 1] & 0xc0) == 0x80) || ! ((utf8[offset + 2] & 0xc0) == 0x80) || !((utf8[offset + 3] & 0xc0) == 0x80)) {
             throw std::invalid_argument("invalid character");
         }
         auto result = ((utf8[offset + 0] & 0x07) << 18) | ((utf8[offset + 1] & 0x3f) << 12) | ((utf8[offset + 2] & 0x3f) << 6) | (utf8[offset + 3] & 0x3f);
@@ -71,15 +70,15 @@ uint32_t unicode_cpt_from_utf8(const std::string& utf8, size_t& offset) {
     throw std::invalid_argument("failed to convert utf8 to codepoint");
 }
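
// Usage sketch (not from the patch): decoding a UTF-8 string codepoint by codepoint.
//
//   std::string s = "a\xC3\xA9";                      // "aé"
//   size_t off = 0;
//   uint32_t c1 = unicode_cpt_from_utf8(s, off);      // 0x61 ('a'), off becomes 1
//   uint32_t c2 = unicode_cpt_from_utf8(s, off);      // 0xE9 ('é'), off becomes 3
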
 
-//static std::vector<uint16_t> unicode_cpt_to_utf16(uint32_t cp) {
+//static std::vector<uint16_t> unicode_cpt_to_utf16(uint32_t cpt) {
 //    std::vector<uint16_t> result;
-//    if (/* 0x0000 <= cp && */ cp <= 0xffff) {
-//        result.emplace_back(cp);
+//    if (/* 0x0000 <= cpt && */ cpt <= 0xffff) {
+//        result.emplace_back(cpt);
 //        return result;
 //    }
-//    if (0x10000 <= cp && cp <= 0x10ffff) {
-//        result.emplace_back(0xd800 | ((cp - 0x10000) >> 10));
-//        result.emplace_back(0xdc00 | ((cp - 0x10000) & 0x03ff));
+//    if (0x10000 <= cpt && cpt <= 0x10ffff) {
+//        result.emplace_back(0xd800 | ((cpt - 0x10000) >> 10));
+//        result.emplace_back(0xdc00 | ((cpt - 0x10000) & 0x03ff));
 //        return result;
 //    }
 //    throw std::invalid_argument("failed to convert codepoint to utf16");
@@ -120,14 +119,14 @@ uint32_t unicode_cpt_from_utf8(const std::string& utf8, size_t& offset) {
 //    return result;
 //}
 
-static std::vector<codepoint_flags> unicode_cpt_flags_array() {
-    std::vector<codepoint_flags> cpt_flags(MAX_CODEPOINTS, codepoint_flags::UNDEFINED);
+static std::vector<unicode_cpt_flags> unicode_cpt_flags_array() {
+    std::vector<unicode_cpt_flags> cpt_flags(MAX_CODEPOINTS, unicode_cpt_flags::UNDEFINED);
 
-    assert(unicode_ranges_flags.front().first == 0);
-    assert(unicode_ranges_flags.back().first == MAX_CODEPOINTS);
+    assert (unicode_ranges_flags.begin()[0].first == 0);
+    assert (unicode_ranges_flags.begin()[unicode_ranges_flags.size()-1].first == MAX_CODEPOINTS);
     for (size_t i = 1; i < unicode_ranges_flags.size(); ++i) {
-        const auto range_ini = unicode_ranges_flags[i - 1];  // codepoint_ini, flags
-        const auto range_end = unicode_ranges_flags[i];    // codepoint_end, flags
+        const auto range_ini = unicode_ranges_flags.begin()[i-1];  // codepoint_ini, flags
+        const auto range_end = unicode_ranges_flags.begin()[i];    // codepoint_end, flags
         for (uint32_t cpt = range_ini.first; cpt < range_end.first; ++cpt) {
             cpt_flags[cpt] = range_ini.second;
         }
@@ -145,7 +144,7 @@ static std::vector<codepoint_flags> unicode_cpt_flags_array() {
         cpt_flags[p.second].is_uppercase = true;
     }
 
-    for (auto& range : unicode_ranges_nfd) {  // start, last, nfd
+    for (auto &range : unicode_ranges_nfd) {  // start, last, nfd
         cpt_flags[range.nfd].is_nfd = true;
     }
 
@@ -200,55 +199,38 @@ static std::unordered_map<std::string, uint8_t> unicode_utf8_to_byte_map() {
     return map;
 }
 
-static inline bool  is_valid_utf8(const std::string& str) {
-    int remaining_bytes = 0; // number of bytes still expected for the current multi-byte character
-    for (unsigned char c : str) {
-        if (remaining_bytes == 0) {
-            if ((c & 0x80) == 0x00) continue;          // 1-byte character
-            else if ((c & 0xE0) == 0xC0) remaining_bytes = 1; // 2 bytes
-            else if ((c & 0xF0) == 0xE0) remaining_bytes = 2; // 3 bytes
-            else if ((c & 0xF8) == 0xF0) remaining_bytes = 3; // 4 bytes
-            else return false; // invalid leading byte
-        }
-        else {
-            // check that each continuation byte is 10xxxxxx
-            if ((c & 0xC0) != 0x80)
-            {
-                return false;
-            }
-            remaining_bytes--;
-        }
-    }
-    return (remaining_bytes == 0); // make sure the last multi-byte character is complete
-}
-
-static inline std::wstring unicode_wstring_from_utf8(const std::string& s) {
+static inline std::wstring unicode_wstring_from_utf8(const std::string & s) {
 #if defined(__clang__)
     // disable C++17 deprecation warning for std::codecvt_utf8
 #    pragma clang diagnostic push
 #    pragma clang diagnostic ignored "-Wdeprecated-declarations"
+#elif defined(__GNUC__)
+#    pragma GCC diagnostic push
+#    pragma GCC diagnostic ignored "-Wdeprecated-declarations"
 #endif
-    bool isvalid = is_valid_utf8(s);
+
     std::wstring_convert<std::codecvt_utf8<wchar_t>> conv;
 
 #if defined(__clang__)
 #    pragma clang diagnostic pop
+#elif defined(__GNUC__)
+#    pragma GCC diagnostic pop
 #endif
 
     return conv.from_bytes(s);
 }
 
-static std::vector<std::string> unicode_byte_encoding_process(const std::vector<std::string>& bpe_words) {
+static std::vector<std::string> unicode_byte_encoding_process(const std::vector<std::string> & bpe_words) {
     std::vector<std::string> bpe_encoded_words;
-    for (const auto& word : bpe_words) {
+    for (const auto & word : bpe_words) {
         std::string text_utf;
-        auto utf_word = unicode_cpts_from_utf8(word);
+        auto utf_word =  unicode_cpts_from_utf8(word);
         for (size_t i = 0; i < utf_word.size(); ++i) {
             text_utf += unicode_cpt_to_utf8(utf_word[i]);
         }
 
         std::string encoded_token;
-        for (char& c : text_utf) {
+        for (char & c : text_utf) {
             encoded_token += unicode_byte_to_utf8(c);
         }
         bpe_encoded_words.emplace_back(encoded_token);
@@ -257,7 +239,7 @@ static std::vector<std::string> unicode_byte_encoding_process(const std::vector<
 }
 
 // GPT2 system regex:  's|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+
-static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string& text, const std::vector<size_t>& offsets) {
+static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string & text, const std::vector<size_t> & offsets) {
     std::vector<size_t> bpe_offsets; // store the offset of each word
     bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
 
@@ -271,16 +253,16 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string& te
         start = offset_end;
 
         static const uint32_t OUT_OF_RANGE = 0xFFFFFFFF;
-        auto _get_cpt = [&](const size_t pos) -> uint32_t {
+        auto _get_cpt = [&] (const size_t pos) -> uint32_t {
             return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE;
         };
 
-        auto _get_flags = [&](const size_t pos) -> codepoint_flags {
-            return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : codepoint_flags{};
+        auto _get_flags = [&] (const size_t pos) -> unicode_cpt_flags {
+            return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags_from_cpt(cpts[pos]) : unicode_cpt_flags{};
         };
 
         size_t _prev_end = offset_ini;
-        auto _add_token = [&](const size_t end) -> size_t {
+        auto _add_token = [&] (const size_t end) -> size_t {
             assert(_prev_end <= end && end <= offset_end);
             size_t len = end - _prev_end;
             if (len > 0) {
@@ -296,29 +278,29 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string& te
             return len;
         };
 
-        for (size_t pos = offset_ini; pos < offset_end; /*pos++*/) {
+        for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) {
             const uint32_t cpt = _get_cpt(pos);
             const auto flags = _get_flags(pos);
 
             // regex: 's|'t|'re|'ve|'m|'ll|'d
-            if (cpt == '\'' && pos + 1 < offset_end) {
-                uint32_t cpt_next = _get_cpt(pos + 1);
+            if (cpt == '\'' && pos+1 < offset_end) {
+                uint32_t cpt_next = _get_cpt(pos+1);
                 if (cpt_next == 's' || cpt_next == 't' || cpt_next == 'm' || cpt_next == 'd') {
-                    pos += _add_token(pos + 2);
+                    pos += _add_token(pos+2);
                     continue;
                 }
-                if (pos + 2 < offset_end) {
-                    uint32_t cpt_next_next = _get_cpt(pos + 2);
+                if (pos+2 < offset_end) {
+                    uint32_t cpt_next_next = _get_cpt(pos+2);
                     if ((cpt_next == 'r' && cpt_next_next == 'e') ||
                         (cpt_next == 'v' && cpt_next_next == 'e') ||
                         (cpt_next == 'l' && cpt_next_next == 'l')) {
-                        pos += _add_token(pos + 3);
+                        pos += _add_token(pos+3);
                         continue;
                     }
                 }
             }
 
-            auto flags2 = (cpt == ' ' ? _get_flags(pos + 1) : flags);
+            auto flags2 = (cpt == ' ' ? _get_flags(pos+1) : flags);
             // regex: ?\p{L}+
             if (flags2.is_letter) {
                 pos += (cpt == ' ');
@@ -348,12 +330,12 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string& te
             }
 
             size_t num_whitespaces = 0;
-            while (_get_flags(pos + num_whitespaces).is_whitespace) {
+            while (_get_flags(pos+num_whitespaces).is_whitespace) {
                 num_whitespaces++;
             }
 
             // regex: \s+(?!\S)
-            if (num_whitespaces > 1 && _get_cpt(pos + num_whitespaces) != OUT_OF_RANGE) {
+            if (num_whitespaces > 1 && _get_cpt(pos+num_whitespaces) != OUT_OF_RANGE) {
                 pos += num_whitespaces - 1;
                 _add_token(pos);
                 continue;
@@ -374,11 +356,10 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string& te
     return bpe_offsets;
 }
 
-// K2 system regex patterns (from tokenization_kimi.py):
-// [\p{Han}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+
-static std::vector<size_t> unicode_regex_split_custom_kimi_k2(const std::string & text, const std::vector<size_t> & offsets) {
-    std::vector<size_t> bpe_offsets;
-    bpe_offsets.reserve(offsets.size());
+// LLAMA3 system regex: "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"
+static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string & text, const std::vector<size_t> & offsets) {
+    std::vector<size_t> bpe_offsets; // store the offset of each word
+    bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
 
     const auto cpts = unicode_cpts_from_utf8(text);
 
@@ -394,8 +375,8 @@ static std::vector<size_t> unicode_regex_split_custom_kimi_k2(const std::string
             return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE;
         };
 
-        auto _get_flags = [&] (const size_t pos) -> codepoint_flags {
-            return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : codepoint_flags{};
+        auto _get_flags = [&] (const size_t pos) -> unicode_cpt_flags {
+            return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags_from_cpt(cpts[pos]) : unicode_cpt_flags{};
         };
 
         size_t _prev_end = offset_ini;
@@ -406,6 +387,12 @@ static std::vector<size_t> unicode_regex_split_custom_kimi_k2(const std::string
                 bpe_offsets.push_back(len);
             }
             _prev_end = end;
+            //if (len > 0) {
+            //    std::string s = "";
+            //    for(size_t p = end-len; p < end; p++)
+            //        s += unicode_cpt_to_utf8(cpts[p]);
+            //    printf(">>> '%s'\n", s.c_str());
+            //}
             return len;
         };
 
@@ -413,75 +400,41 @@ static std::vector<size_t> unicode_regex_split_custom_kimi_k2(const std::string
             const uint32_t cpt = _get_cpt(pos);
             const auto flags = _get_flags(pos);
 
-            // Pattern 1: [\p{Han}]+ (Chinese characters)
-            if (unicode_cpt_is_han(cpt)) {
-                while (unicode_cpt_is_han(_get_cpt(pos))) {
-                    pos++;
+            // regex: (?i:'s|'t|'re|'ve|'m|'ll|'d) // case insensitive
+            if (cpt == '\'' && pos+1 < offset_end) {
+                uint32_t cpt_next = unicode_tolower(_get_cpt(pos+1));
+                if (cpt_next == 's' || cpt_next == 't' || cpt_next == 'm' || cpt_next == 'd') {
+                    pos += _add_token(pos+2);
+                    continue;
                 }
-                _add_token(pos);
-                continue;
-            }
-
-            // Pattern 2 & 3: Letter words excluding Han characters with optional contractions
-            // [^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+(?:'s|'t|'re|'ve|'m|'ll|'d)?
-            // [^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*(?:'s|'t|'re|'ve|'m|'ll|'d)?
-            // Check if current char is a letter OR if current char could be a leading char and next char is a letter
-            bool is_letter_pattern = (flags.is_letter && !unicode_cpt_is_han(cpt)) ||
-                                     (!(cpt == '\r' || cpt == '\n' || flags.is_letter || flags.is_number) &&
-                                      _get_flags(pos + 1).is_letter && !unicode_cpt_is_han(_get_cpt(pos + 1)));
-
-            if (is_letter_pattern) {
-                // Handle optional leading non-letter/non-number character
-                bool has_leading_char = false;
-                if (!(cpt == '\r' || cpt == '\n' || flags.is_letter || flags.is_number)) {
-                    has_leading_char = true;
-                    pos++;
+                if (pos+2 < offset_end) {
+                    uint32_t cpt_next_next = unicode_tolower(_get_cpt(pos+2));
+                    if ((cpt_next == 'r' && cpt_next_next == 'e') ||
+                        (cpt_next == 'v' && cpt_next_next == 'e') ||
+                        (cpt_next == 'l' && cpt_next_next == 'l')) {
+                        pos += _add_token(pos+3);
+                        continue;
+                    }
                 }
+            }
 
-                // Match letter sequence (excluding Han characters)
-                bool has_letters = false;
-                while (_get_flags(pos).is_letter && !unicode_cpt_is_han(_get_cpt(pos))) {
-                    has_letters = true;
+            // regex: [^\r\n\p{L}\p{N}]?\p{L}+
+            if (!(cpt == '\r' || cpt == '\n' || flags.is_number)) {
+                if (flags.is_letter || _get_flags(pos+1).is_letter) {  // one or more letters
                     pos++;
-                }
-
-                // Only proceed if we found letters (after potentially skipping leading char)
-                if (has_letters || (!has_leading_char && _get_flags(pos).is_letter && !unicode_cpt_is_han(_get_cpt(pos)))) {
-                    if (!has_letters) pos++; // consume the first letter if we didn't already
-
-                    // Continue consuming letters
-                    while (_get_flags(pos).is_letter && !unicode_cpt_is_han(_get_cpt(pos))) {
+                    while (_get_flags(pos).is_letter) {
                         pos++;
                     }
-
-                    // Check for optional contractions (?:'s|'t|'re|'ve|'m|'ll|'d)
-                    if (_get_cpt(pos) == '\'' && pos + 1 < offset_end) {
-                        uint32_t cpt_next = unicode_tolower(_get_cpt(pos + 1));
-                        if (cpt_next == 's' || cpt_next == 't' || cpt_next == 'm' || cpt_next == 'd') {
-                            pos += 2;
-                        } else if (pos + 2 < offset_end) {
-                            uint32_t cpt_next_next = unicode_tolower(_get_cpt(pos + 2));
-                            if ((cpt_next == 'r' && cpt_next_next == 'e') ||
-                                (cpt_next == 'v' && cpt_next_next == 'e') ||
-                                (cpt_next == 'l' && cpt_next_next == 'l')) {
-                                pos += 3;
-                            }
-                        }
-                    }
-
                     _add_token(pos);
                     continue;
-                } else if (has_leading_char) {
-                    // We consumed a leading char but found no letters, backtrack
-                    pos--;
                 }
             }
 
-            // Pattern 4: \p{N}{1,3} (numbers 1-3 digits)
+            // regex: \p{N}{1,3}
             if (flags.is_number) {
                 size_t ini = pos;
                 while (_get_flags(pos).is_number) {
-                    if (++pos - ini >= 3) {
+                    if (++pos - ini >= 3 ) {
                         _add_token(pos);
                         ini = pos;
                     }
@@ -490,14 +443,13 @@ static std::vector<size_t> unicode_regex_split_custom_kimi_k2(const std::string
                 continue;
             }
 
-            // Pattern 5:  ?[^\s\p{L}\p{N}]+[\r\n]* (optional space + non-word chars + optional newlines)
-            auto flags2 = (cpt == ' ' ? _get_flags(pos + 1) : flags);
-            if (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number) && flags2.as_uint()) {
+            // regex: ?[^\s\p{L}\p{N}]+[\r\n]*
+            auto flags2 = (cpt == ' ' ? _get_flags(pos+1) : flags);
+            if (!(flags2.is_whitespace | flags2.is_letter | flags2.is_number) && flags.as_uint()) {
                 pos += (cpt == ' ');
-                while (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number) && flags2.as_uint()) {
+                while (!(flags2.is_whitespace | flags2.is_letter | flags2.is_number) && flags2.as_uint()) {
                     flags2 = _get_flags(++pos);
                 }
-                // Match optional [\r\n]*
                 uint32_t cpt2 = _get_cpt(pos);
                 while (cpt2 == '\r' || cpt2 == '\n') {
                     cpt2 = _get_cpt(++pos);
@@ -506,39 +458,38 @@ static std::vector<size_t> unicode_regex_split_custom_kimi_k2(const std::string
                 continue;
             }
 
-            // Count whitespace characters
             size_t num_whitespaces = 0;
             size_t last_end_r_or_n = 0;
-            while (_get_flags(pos + num_whitespaces).is_whitespace) {
-                uint32_t cpt2 = _get_cpt(pos + num_whitespaces);
+            while (_get_flags(pos+num_whitespaces).is_whitespace) {
+                uint32_t cpt2 = _get_cpt(pos+num_whitespaces);
                 if (cpt2 == '\r' || cpt2 == '\n') {
                     last_end_r_or_n = pos + num_whitespaces + 1;
                 }
                 num_whitespaces++;
             }
 
-            // Pattern 6: \s*[\r\n]+ (whitespace with newlines)
+            // regex: \s*[\r\n]+
             if (last_end_r_or_n > 0) {
                 pos = last_end_r_or_n;
                 _add_token(pos);
                 continue;
             }
 
-            // Pattern 7: \s+(?!\S) (trailing whitespace)
-            if (num_whitespaces > 1 && _get_cpt(pos + num_whitespaces) != OUT_OF_RANGE) {
+            // regex: \s+(?!\S)
+            if (num_whitespaces > 1 && _get_cpt(pos+num_whitespaces) != OUT_OF_RANGE) {
                 pos += num_whitespaces - 1;
                 _add_token(pos);
                 continue;
             }
 
-            // Pattern 8: \s+ (general whitespace)
+            // regex: \s+
             if (num_whitespaces > 0) {
                 pos += num_whitespaces;
                 _add_token(pos);
                 continue;
             }
 
-            // No matches - consume single character
+            // no matches
             _add_token(++pos);
         }
     }
@@ -546,10 +497,71 @@ static std::vector<size_t> unicode_regex_split_custom_kimi_k2(const std::string
     return bpe_offsets;
 }
 
-// LLAMA3 system regex: "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"
-static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string& text, const std::vector<size_t>& offsets) {
+// use std::wregex to split the text
+static std::vector<size_t> unicode_regex_split_stl(const std::wstring & wtext, const std::wstring & regex_expr, const std::vector<size_t> & offsets) {
+    std::wregex expr(regex_expr);
+    std::vector<size_t> bpe_offsets; // store the offset of each word
+    bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
+    size_t start = 0;
+    for (auto offset : offsets) {
+        std::wcregex_iterator it(wtext.data() + start, wtext.data() + start + offset, expr);
+        std::wcregex_iterator end;
+
+        int64_t start_idx = 0;
+        while (it != end) {
+            std::wcmatch match = *it;
+            if (match.position() > start_idx) {
+                bpe_offsets.emplace_back(match.position() - start_idx);
+            }
+            bpe_offsets.emplace_back(match.length());
+            start_idx = match.position() + match.length();
+            ++it;
+        }
+
+        if (start_idx < (int64_t) offset) {
+            bpe_offsets.emplace_back(offset - start_idx);
+        }
+        start += offset;
+    }
+
+    return bpe_offsets;
+}
+
+// use std::regex to split the text
+static std::vector<size_t> unicode_regex_split_stl(const std::string & text, const std::string & regex_expr, const std::vector<size_t> & offsets) {
+    std::regex expr(regex_expr);
     std::vector<size_t> bpe_offsets; // store the offset of each word
     bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
+    size_t start = 0;
+    for (auto offset : offsets) {
+        std::cregex_iterator it(text.data() + start, text.data() + start + offset, expr);
+        std::cregex_iterator end;
+
+        int64_t start_idx = 0;
+        while (it != end) {
+            std::cmatch match = *it;
+            if (match.position() > start_idx) {
+                bpe_offsets.emplace_back(match.position() - start_idx);
+            }
+            bpe_offsets.emplace_back(match.length());
+            start_idx = match.position() + match.length();
+            ++it;
+        }
+
+        if (start_idx < (int64_t) offset) {
+            bpe_offsets.emplace_back(offset - start_idx);
+        }
+        start += offset;
+    }
+
+    return bpe_offsets;
+}
+
+// K2 system regex patterns (from tokenization_kimi.py):
+// [\p{Han}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+
+static std::vector<size_t> unicode_regex_split_custom_kimi_k2(const std::string & text, const std::vector<size_t> & offsets) {
+    std::vector<size_t> bpe_offsets;
+    bpe_offsets.reserve(offsets.size());
 
     const auto cpts = unicode_cpts_from_utf8(text);
 
@@ -561,66 +573,94 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string&
         start = offset_end;
 
         static const uint32_t OUT_OF_RANGE = 0xFFFFFFFF;
-        auto _get_cpt = [&](const size_t pos) -> uint32_t {
+        auto _get_cpt = [&] (const size_t pos) -> uint32_t {
             return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE;
         };
 
-        auto _get_flags = [&](const size_t pos) -> codepoint_flags {
-            return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : codepoint_flags{};
+        auto _get_flags = [&] (const size_t pos) -> unicode_cpt_flags {
+            return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags_from_cpt(cpts[pos]) : unicode_cpt_flags{};
         };
 
         size_t _prev_end = offset_ini;
-        auto _add_token = [&](const size_t end) -> size_t {
+        auto _add_token = [&] (const size_t end) -> size_t {
             assert(_prev_end <= end && end <= offset_end);
             size_t len = end - _prev_end;
             if (len > 0) {
                 bpe_offsets.push_back(len);
             }
             _prev_end = end;
-            //if (len > 0) {
-            //    std::string s = "";
-            //    for(size_t p = end-len; p < end; p++)
-            //        s += unicode_cpt_to_utf8(cpts[p]);
-            //    printf(">>> '%s'\n", s.c_str());
-            //}
             return len;
         };
 
-        for (size_t pos = offset_ini; pos < offset_end; /*pos++*/) {
+        for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) {
             const uint32_t cpt = _get_cpt(pos);
             const auto flags = _get_flags(pos);
 
-            // regex: (?i:'s|'t|'re|'ve|'m|'ll|'d) // case insensitive
-            if (cpt == '\'' && pos + 1 < offset_end) {
-                uint32_t cpt_next = unicode_tolower(_get_cpt(pos + 1));
-                if (cpt_next == 's' || cpt_next == 't' || cpt_next == 'm' || cpt_next == 'd') {
-                    pos += _add_token(pos + 2);
-                    continue;
-                }
-                if (pos + 2 < offset_end) {
-                    uint32_t cpt_next_next = unicode_tolower(_get_cpt(pos + 2));
-                    if ((cpt_next == 'r' && cpt_next_next == 'e') ||
-                        (cpt_next == 'v' && cpt_next_next == 'e') ||
-                        (cpt_next == 'l' && cpt_next_next == 'l')) {
-                        pos += _add_token(pos + 3);
-                        continue;
-                    }
+            // Pattern 1: [\p{Han}]+ (Chinese characters)
+            if (unicode_cpt_is_han(cpt)) {
+                while (unicode_cpt_is_han(_get_cpt(pos))) {
+                    pos++;
                 }
+                _add_token(pos);
+                continue;
             }
 
-            // regex: [^\r\n\p{L}\p{N}]?\p{L}+
-            if (!(cpt == '\r' || cpt == '\n' || flags.is_number)) {
-                if (flags.is_letter || _get_flags(pos + 1).is_letter) {  // one or more letters
+            // Pattern 2 & 3: Letter words excluding Han characters with optional contractions
+            // [^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+(?:'s|'t|'re|'ve|'m|'ll|'d)?
+            // [^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*(?:'s|'t|'re|'ve|'m|'ll|'d)?
+            // Check if current char is a letter OR if current char could be a leading char and next char is a letter
+            bool is_letter_pattern = (flags.is_letter && !unicode_cpt_is_han(cpt)) ||
+                                     (!(cpt == '\r' || cpt == '\n' || flags.is_letter || flags.is_number) &&
+                                      _get_flags(pos + 1).is_letter && !unicode_cpt_is_han(_get_cpt(pos + 1)));
+
+            if (is_letter_pattern) {
+                // Handle optional leading non-letter/non-number character
+                bool has_leading_char = false;
+                if (!(cpt == '\r' || cpt == '\n' || flags.is_letter || flags.is_number)) {
+                    has_leading_char = true;
                     pos++;
-                    while (_get_flags(pos).is_letter) {
+                }
+
+                // Match letter sequence (excluding Han characters)
+                bool has_letters = false;
+                while (_get_flags(pos).is_letter && !unicode_cpt_is_han(_get_cpt(pos))) {
+                    has_letters = true;
+                    pos++;
+                }
+
+                // Only proceed if we found letters (after potentially skipping leading char)
+                if (has_letters || (!has_leading_char && _get_flags(pos).is_letter && !unicode_cpt_is_han(_get_cpt(pos)))) {
+                    if (!has_letters) pos++; // consume the first letter if we didn't already
+
+                    // Continue consuming letters
+                    while (_get_flags(pos).is_letter && !unicode_cpt_is_han(_get_cpt(pos))) {
                         pos++;
                     }
+
+                    // Check for optional contractions (?:'s|'t|'re|'ve|'m|'ll|'d)
+                    if (_get_cpt(pos) == '\'' && pos + 1 < offset_end) {
+                        uint32_t cpt_next = unicode_tolower(_get_cpt(pos + 1));
+                        if (cpt_next == 's' || cpt_next == 't' || cpt_next == 'm' || cpt_next == 'd') {
+                            pos += 2;
+                        } else if (pos + 2 < offset_end) {
+                            uint32_t cpt_next_next = unicode_tolower(_get_cpt(pos + 2));
+                            if ((cpt_next == 'r' && cpt_next_next == 'e') ||
+                                (cpt_next == 'v' && cpt_next_next == 'e') ||
+                                (cpt_next == 'l' && cpt_next_next == 'l')) {
+                                pos += 3;
+                            }
+                        }
+                    }
+
                     _add_token(pos);
                     continue;
+                } else if (has_leading_char) {
+                    // We consumed a leading char but found no letters, backtrack
+                    pos--;
                 }
             }
 
-            // regex: \p{N}{1,3}
+            // Pattern 4: \p{N}{1,3} (numbers 1-3 digits)
             if (flags.is_number) {
                 size_t ini = pos;
                 while (_get_flags(pos).is_number) {
@@ -633,13 +673,14 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string&
                 continue;
             }
 
-            // regex: ?[^\s\p{L}\p{N}]+[\r\n]*
+            // Pattern 5:  ?[^\s\p{L}\p{N}]+[\r\n]* (optional space + non-word chars + optional newlines)
             auto flags2 = (cpt == ' ' ? _get_flags(pos + 1) : flags);
-            if (!(flags2.is_whitespace | flags2.is_letter | flags2.is_number) && flags.as_uint()) {
+            if (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number) && flags2.as_uint()) {
                 pos += (cpt == ' ');
-                while (!(flags2.is_whitespace | flags2.is_letter | flags2.is_number) && flags2.as_uint()) {
+                while (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number) && flags2.as_uint()) {
                     flags2 = _get_flags(++pos);
                 }
+                // Match optional [\r\n]*
                 uint32_t cpt2 = _get_cpt(pos);
                 while (cpt2 == '\r' || cpt2 == '\n') {
                     cpt2 = _get_cpt(++pos);
@@ -648,6 +689,7 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string&
                 continue;
             }
 
+            // Count whitespace characters
             size_t num_whitespaces = 0;
             size_t last_end_r_or_n = 0;
             while (_get_flags(pos + num_whitespaces).is_whitespace) {
@@ -658,28 +700,28 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string&
                 num_whitespaces++;
             }
 
-            // regex: \s*[\r\n]+
+            // Pattern 6: \s*[\r\n]+ (whitespace with newlines)
             if (last_end_r_or_n > 0) {
                 pos = last_end_r_or_n;
                 _add_token(pos);
                 continue;
             }
 
-            // regex: \s+(?!\S)
+            // Pattern 7: \s+(?!\S) (trailing whitespace)
             if (num_whitespaces > 1 && _get_cpt(pos + num_whitespaces) != OUT_OF_RANGE) {
                 pos += num_whitespaces - 1;
                 _add_token(pos);
                 continue;
             }
 
-            // regex: \s+
+            // Pattern 8: \s+ (general whitespace)
             if (num_whitespaces > 0) {
                 pos += num_whitespaces;
                 _add_token(pos);
                 continue;
             }
 
-            // no matches
+            // No matches - consume single character
             _add_token(++pos);
         }
     }
@@ -687,79 +729,17 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string&
     return bpe_offsets;
 }
 
-// use std::wregex to split the text
-static std::vector<size_t> unicode_regex_split_stl(const std::wstring& wtext, const std::wstring& regex_expr, const std::vector<size_t>& offsets) {
-    std::wregex expr(regex_expr);
-    std::vector<size_t> bpe_offsets; // store the offset of each word
-    bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
-    size_t start = 0;
-    for (auto offset : offsets) {
-        std::wcregex_iterator it(wtext.data() + start, wtext.data() + start + offset, expr);
-        std::wcregex_iterator end;
-
-        int64_t start_idx = 0;
-        while (it != end) {
-            std::wcmatch match = *it;
-            if (match.position() > start_idx) {
-                bpe_offsets.emplace_back(match.position() - start_idx);
-            }
-            bpe_offsets.emplace_back(match.length());
-            start_idx = match.position() + match.length();
-            ++it;
-        }
-
-        if (start_idx < (int64_t)offset) {
-            bpe_offsets.emplace_back(offset - start_idx);
-        }
-        start += offset;
-    }
-
-    return bpe_offsets;
-}
-
-// use std::regex to split the text
-static std::vector<size_t> unicode_regex_split_stl(const std::string& text, const std::string& regex_expr, const std::vector<size_t>& offsets) {
-    std::regex expr(regex_expr);
-    std::vector<size_t> bpe_offsets; // store the offset of each word
-    bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
-    size_t start = 0;
-    for (auto offset : offsets) {
-        std::cregex_iterator it(text.data() + start, text.data() + start + offset, expr);
-        std::cregex_iterator end;
-
-        int64_t start_idx = 0;
-        while (it != end) {
-            std::cmatch match = *it;
-            if (match.position() > start_idx) {
-                bpe_offsets.emplace_back(match.position() - start_idx);
-            }
-            bpe_offsets.emplace_back(match.length());
-            start_idx = match.position() + match.length();
-            ++it;
-        }
-
-        if (start_idx < (int64_t)offset) {
-            bpe_offsets.emplace_back(offset - start_idx);
-        }
-        start += offset;
-    }
-
-    return bpe_offsets;
-}
-
-static std::vector<size_t> unicode_regex_split_custom(const std::string& text, const std::string& regex_expr, const std::vector<size_t>& offsets) {
+static std::vector<size_t> unicode_regex_split_custom(const std::string & text, const std::string & regex_expr, const std::vector<size_t> & offsets) {
     std::vector<size_t> bpe_offsets;
 
     if (regex_expr == "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)") {
         bpe_offsets = unicode_regex_split_custom_gpt2(text, offsets);
-    }
-    else if (
-        regex_expr == "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" ||
-        regex_expr == "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+") {
+    } else if (
+            regex_expr == "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" ||
+            regex_expr == "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+") {
 
         bpe_offsets = unicode_regex_split_custom_llama3(text, offsets);
-    }
-    else if (regex_expr == "\\p{Han}+") {
+    } else if (regex_expr == "\\p{Han}+") {
         // K2's first pattern - handle all K2 patterns together
         bpe_offsets = unicode_regex_split_custom_kimi_k2(text, offsets);
     }
@@ -771,71 +751,100 @@ static std::vector<size_t> unicode_regex_split_custom(const std::string& text, c
 // interface
 //
 
-std::string unicode_cpt_to_utf8(uint32_t cp) {
+std::string unicode_cpt_to_utf8(uint32_t cpt) {
     std::string result;
 
-    if (/* 0x00 <= cp && */ cp <= 0x7f) {
-        result.push_back(cp);
+    if (/* 0x00 <= cpt && */ cpt <= 0x7f) {
+        result.push_back(cpt);
         return result;
     }
-    if (0x80 <= cp && cp <= 0x7ff) {
-        result.push_back(0xc0 | ((cp >> 6) & 0x1f));
-        result.push_back(0x80 | (cp & 0x3f));
+    if (0x80 <= cpt && cpt <= 0x7ff) {
+        result.push_back(0xc0 | ((cpt >> 6) & 0x1f));
+        result.push_back(0x80 | (cpt & 0x3f));
         return result;
     }
-    if (0x800 <= cp && cp <= 0xffff) {
-        result.push_back(0xe0 | ((cp >> 12) & 0x0f));
-        result.push_back(0x80 | ((cp >> 6) & 0x3f));
-        result.push_back(0x80 | (cp & 0x3f));
+    if (0x800 <= cpt && cpt <= 0xffff) {
+        result.push_back(0xe0 | ((cpt >> 12) & 0x0f));
+        result.push_back(0x80 | ((cpt >> 6) & 0x3f));
+        result.push_back(0x80 | (cpt & 0x3f));
         return result;
     }
-    if (0x10000 <= cp && cp <= 0x10ffff) {
-        result.push_back(0xf0 | ((cp >> 18) & 0x07));
-        result.push_back(0x80 | ((cp >> 12) & 0x3f));
-        result.push_back(0x80 | ((cp >> 6) & 0x3f));
-        result.push_back(0x80 | (cp & 0x3f));
+    if (0x10000 <= cpt && cpt <= 0x10ffff) {
+        result.push_back(0xf0 | ((cpt >> 18) & 0x07));
+        result.push_back(0x80 | ((cpt >> 12) & 0x3f));
+        result.push_back(0x80 | ((cpt >> 6) & 0x3f));
+        result.push_back(0x80 | (cpt & 0x3f));
         return result;
     }
 
     throw std::invalid_argument("invalid codepoint");
 }
 
-std::vector<uint32_t> unicode_cpts_normalize_nfd(const std::vector<uint32_t>& cpts) {
-    auto comp = [](const uint32_t cpt, const range_nfd& range) {
+std::vector<uint32_t> unicode_cpts_normalize_nfd(const std::vector<uint32_t> & cpts) {
+    auto comp = [] (const uint32_t cpt, const range_nfd & range) {
         return cpt < range.first;
     };
     std::vector<uint32_t> result(cpts.size());
     for (size_t i = 0; i < cpts.size(); ++i) {
         const uint32_t cpt = cpts[i];
-        auto it = std::upper_bound(unicode_ranges_nfd.cbegin(), unicode_ranges_nfd.cend(), cpt, comp) - 1;
+        auto it = std::upper_bound(unicode_ranges_nfd.begin(), unicode_ranges_nfd.end(), cpt, comp) - 1;
         result[i] = (it->first <= cpt && cpt <= it->last) ? it->nfd : cpt;
     }
     return result;
 }
 
-std::vector<uint32_t> unicode_cpts_from_utf8(const std::string& utf8) {
+std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8) {
     std::vector<uint32_t> result;
     result.reserve(utf8.size());
     size_t offset = 0;
     while (offset < utf8.size()) {
-        result.push_back(unicode_cpt_from_utf8(utf8, offset));
+        try {
+            result.push_back(unicode_cpt_from_utf8(utf8, offset));
+        }
+        catch (const std::invalid_argument & /*ex*/) {
+            // Silently ignore invalid UTF-8 input to avoid leaking the exception beyond llama_tokenize
+            ++offset;
+            result.emplace_back(0xFFFD); // replacement character
+        }
     }
     return result;
 }
 
-codepoint_flags unicode_cpt_flags(const uint32_t cp) {
-    static const codepoint_flags undef(codepoint_flags::UNDEFINED);
+unicode_cpt_flags unicode_cpt_flags_from_cpt(const uint32_t cpt) {
+    static const unicode_cpt_flags undef(unicode_cpt_flags::UNDEFINED);
     static const auto cpt_flags = unicode_cpt_flags_array();
-    return cp < cpt_flags.size() ? cpt_flags[cp] : undef;
+    return cpt < cpt_flags.size() ? cpt_flags[cpt] : undef;
 }
 
-codepoint_flags unicode_cpt_flags(const std::string& utf8) {
-    static const codepoint_flags undef(codepoint_flags::UNDEFINED);
+unicode_cpt_flags unicode_cpt_flags_from_utf8(const std::string & utf8) {
+    static const unicode_cpt_flags undef(unicode_cpt_flags::UNDEFINED);
     if (utf8.empty()) {
         return undef;  // undefined
     }
     size_t offset = 0;
-    return unicode_cpt_flags(unicode_cpt_from_utf8(utf8, offset));
+    return unicode_cpt_flags_from_cpt(unicode_cpt_from_utf8(utf8, offset));
+}
+
+std::string unicode_byte_to_utf8(uint8_t byte) {
+    static std::unordered_map<uint8_t, std::string> map = unicode_byte_to_utf8_map();
+    return map.at(byte);
+}
+
+uint8_t unicode_utf8_to_byte(const std::string & utf8) {
+    static std::unordered_map<std::string, uint8_t> map = unicode_utf8_to_byte_map();
+    return map.at(utf8);
+}
+
+uint32_t unicode_tolower(uint32_t cpt) {
+    // binary search
+    auto it = std::lower_bound(unicode_map_lowercase.begin(), unicode_map_lowercase.end(), cpt,
+        [](const std::pair<uint32_t, uint32_t> & pair, uint32_t value) {
+            return pair.first < value;
+        });
+    if (it != unicode_map_lowercase.end() && it->first == cpt) {
+        return it->second;
+    }
+    return cpt;  // Return the original code point if no lowercase mapping is found
 }
 
 bool unicode_cpt_is_han(uint32_t cpt) {
@@ -870,53 +879,37 @@ bool unicode_cpt_is_han(uint32_t cpt) {
     return false;
 }
 
-std::string unicode_byte_to_utf8(uint8_t byte) {
-    static std::unordered_map<uint8_t, std::string> map = unicode_byte_to_utf8_map();
-    return map.at(byte);
-}
-
-uint8_t unicode_utf8_to_byte(const std::string& utf8) {
-    static std::unordered_map<std::string, uint8_t> map = unicode_utf8_to_byte_map();
-    return map.at(utf8);
-}
-
-uint32_t unicode_tolower(uint32_t cp) {
-    auto it = unicode_map_lowercase.find(cp);
-    return it == unicode_map_lowercase.end() ? cp : it->second;
-}
-
-std::vector<std::string> unicode_regex_split(const std::string& text, const std::vector<std::string>& regex_exprs) {
+std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs) {
     // unicode categories
     static const std::map<std::string, int> k_ucat_enum = {
-        { "\\p{N}", codepoint_flags::NUMBER },
-        { "\\p{L}", codepoint_flags::LETTER },
-        { "\\p{P}", codepoint_flags::PUNCTUATION },
-        { "\\p{M}", codepoint_flags::ACCENT_MARK },
-        { "\\p{S}", codepoint_flags::SYMBOL },
+        { "\\p{N}", unicode_cpt_flags::NUMBER },
+        { "\\p{L}", unicode_cpt_flags::LETTER },
+        { "\\p{P}", unicode_cpt_flags::PUNCTUATION },
+        { "\\p{M}", unicode_cpt_flags::ACCENT_MARK },
+        { "\\p{S}", unicode_cpt_flags::SYMBOL },
     };
 
     static const std::map<int, int> k_ucat_cpt = {
-        { codepoint_flags::NUMBER,        0xD1 },
-        { codepoint_flags::LETTER,        0xD2 },
-        { codepoint_flags::PUNCTUATION,   0xD3 },
-        { codepoint_flags::ACCENT_MARK,   0xD4 },
-        { codepoint_flags::SYMBOL,        0xD5 },
-
+        { unicode_cpt_flags::NUMBER,      0xD1 },
+        { unicode_cpt_flags::LETTER,      0xD2 },
+        { unicode_cpt_flags::PUNCTUATION, 0xD3 },
+        { unicode_cpt_flags::ACCENT_MARK, 0xD4 },
+        { unicode_cpt_flags::SYMBOL,      0xD5 },
     };
 
     static const std::map<int, std::string> k_ucat_map = {
-        { codepoint_flags::NUMBER,        "\x30-\x39" }, // 0-9
-        { codepoint_flags::LETTER,        "\x41-\x5A\x61-\x7A" }, // A-Za-z
-        { codepoint_flags::PUNCTUATION,   "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\}i
-        { codepoint_flags::ACCENT_MARK, "" }, // no sub-128 codepoints
-        { codepoint_flags::SYMBOL,      "\\\x24\\\x2B\x3C-\x3E\x5E\x60\\\x7C" }, // $+<=>^`|
+        { unicode_cpt_flags::NUMBER,      "\x30-\x39" }, // 0-9
+        { unicode_cpt_flags::LETTER,      "\x41-\x5A\x61-\x7A" }, // A-Za-z
+        { unicode_cpt_flags::PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\}
+        { unicode_cpt_flags::ACCENT_MARK, "" }, // no sub-128 codepoints
+        { unicode_cpt_flags::SYMBOL,      "\\\x24\\\x2B\x3C-\x3E\x5E\x60\\\x7C" }, // $+<=>^`|
     };
 
     // compute collapsed codepoints only if needed by at least one regex
     bool need_collapse = false;
-    for (auto& regex_expr : regex_exprs) {
+    for (const auto & regex_expr : regex_exprs) {
         // search for unicode categories
-        for (const auto& ucat : k_ucat_enum) {
+        for (const auto & ucat : k_ucat_enum) {
             if (std::string::npos != regex_expr.find(ucat.first)) {
                 need_collapse = true;
                 break;
@@ -927,7 +920,7 @@ std::vector<std::string> unicode_regex_split(const std::string& text, const std:
     const auto cpts = unicode_cpts_from_utf8(text);
 
     // generate a "collapsed" representation of the text, where all codepoints are replaced by a single byte
-    // ref: https://github.com/ggerganov/llama.cpp/pull/6920#issuecomment-2081479935
+    // ref: https://github.com/ggml-org/llama.cpp/pull/6920#issuecomment-2081479935
     std::string text_collapsed;
     if (need_collapse) {
         // collapse all unicode categories
@@ -940,25 +933,23 @@ std::vector<std::string> unicode_regex_split(const std::string& text, const std:
                 continue;
             }
 
-            const auto flags = unicode_cpt_flags(cpts[i]);
+            const auto flags = unicode_cpt_flags_from_cpt(cpts[i]);
 
             if (flags.is_whitespace) {
                 //NOTE: C++ std::regex \s does not mach 0x85, Rust and Python regex does.
                 //text_collapsed[i] = (char) 0x85;  // <Next Line> as whitespace fallback
-                text_collapsed[i] = (char)0x0B;    // <vertical tab> as whitespace fallback
-            }
-            else if (k_ucat_cpt.find(flags.category_flag()) != k_ucat_cpt.end()) {
+                text_collapsed[i] = (char) 0x0B;    // <vertical tab> as whitespace fallback
+            } else if (k_ucat_cpt.find(flags.category_flag()) != k_ucat_cpt.end()) {
                 text_collapsed[i] = k_ucat_cpt.at(flags.category_flag());
-            }
-            else {
-                text_collapsed[i] = (char)0xD0; // fallback
+            } else {
+                text_collapsed[i] = (char) 0xD0; // fallback
             }
         }
     }
 
     std::vector<size_t> bpe_offsets = { cpts.size() };
 
-    for (auto& regex_expr : regex_exprs) {
+    for (const auto & regex_expr : regex_exprs) {
         // first, see if we have an efficient custom regex implementation
         auto tmp = unicode_regex_split_custom(text, regex_expr, bpe_offsets);
 
@@ -972,7 +963,7 @@ std::vector<std::string> unicode_regex_split(const std::string& text, const std:
             // if a unicode category is used in the regex, we use the collapsed text and replace the unicode category
             // with the corresponding collapsed representation
             bool use_collapsed = false;
-            for (auto& ucat : k_ucat_enum) {
+            for (const auto & ucat : k_ucat_enum) {
                 if (std::string::npos != regex_expr.find(ucat.first)) {
                     use_collapsed = true;
                     break;
@@ -1031,15 +1022,14 @@ std::vector<std::string> unicode_regex_split(const std::string& text, const std:
                 //printf("text_collapsed: %s\n", text_collapsed.c_str());
                 //printf("regex_expr_collapsed: %s\n", regex_expr_collapsed.c_str());
                 bpe_offsets = unicode_regex_split_stl(text_collapsed, regex_expr_collapsed, bpe_offsets);
-            }
-            else {
+            } else {
                 // no unicode category used, we can use std::wregex directly
                 const std::wstring wregex_expr = unicode_wstring_from_utf8(regex_expr);
 
                 // std::wregex \s does not mach non-ASCII whitespaces, using 0x0B as fallback
                 std::wstring wtext(cpts.begin(), cpts.end());
                 for (size_t i = 0; i < wtext.size(); ++i) {
-                    if (wtext[i] > 0x7F && unicode_cpt_flags(wtext[i]).is_whitespace) {
+                    if (wtext[i] > 0x7F && unicode_cpt_flags_from_cpt(wtext[i]).is_whitespace) {
                         wtext[i] = 0x0B;
                     }
                 }
@@ -1048,8 +1038,7 @@ std::vector<std::string> unicode_regex_split(const std::string& text, const std:
                 //printf("regex_expr: %s\n", regex_expr.c_str());
                 bpe_offsets = unicode_regex_split_stl(wtext, wregex_expr, bpe_offsets);
             }
-        }
-        catch (std::regex_error& e) {
+        } catch (std::regex_error & e) {
             fprintf(stderr, "Failed to process regex: '%s'\n", regex_expr.c_str());
             fprintf(stderr, "Regex error: %s\n", e.what());
             throw std::runtime_error("Failed to process regex");
@@ -1060,7 +1049,7 @@ std::vector<std::string> unicode_regex_split(const std::string& text, const std:
     bpe_words.reserve(bpe_offsets.size()); // reserve memory for the approximate size
 
     size_t start = 0;
-    for (size_t& offset : bpe_offsets) {
+    for (size_t & offset : bpe_offsets) {
         bpe_words.emplace_back();
         for (size_t i = start; i < start + offset; ++i) {
             bpe_words.back() += unicode_cpt_to_utf8(cpts[i]);
diff --git a/src/unicode.h b/src/unicode.h
index 48940239c..0a5fa2a78 100644
--- a/src/unicode.h
+++ b/src/unicode.h
@@ -4,9 +4,7 @@
 #include <string>
 #include <vector>
 
-// TODO: prefix all symbols with "llama_"
-
-struct codepoint_flags {
+struct unicode_cpt_flags {
     enum {
         UNDEFINED       = 0x0001,
         NUMBER          = 0x0002,  // regex: \p{N}
@@ -35,7 +33,7 @@ struct codepoint_flags {
     uint16_t is_nfd         : 1;
 
     // decode from uint16
-    inline codepoint_flags(const uint16_t flags=0) {
+    inline unicode_cpt_flags(const uint16_t flags = 0) {
         *reinterpret_cast<uint16_t*>(this) = flags;
     }
 
@@ -50,19 +48,20 @@ struct codepoint_flags {
 
 size_t unicode_len_utf8(char src);
 
-std::string unicode_cpt_to_utf8(uint32_t cp);
-uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset);
+std::string unicode_cpt_to_utf8  (uint32_t cpt);
+uint32_t    unicode_cpt_from_utf8(const std::string & utf8, size_t & offset);
+
 std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8);
 
 std::vector<uint32_t> unicode_cpts_normalize_nfd(const std::vector<uint32_t> & cpts);
 
-codepoint_flags unicode_cpt_flags(const uint32_t cp);
-codepoint_flags unicode_cpt_flags(const std::string & utf8);
+unicode_cpt_flags unicode_cpt_flags_from_cpt (uint32_t cpt);
+unicode_cpt_flags unicode_cpt_flags_from_utf8(const std::string & utf8);
 
 std::string unicode_byte_to_utf8(uint8_t byte);
-uint8_t unicode_utf8_to_byte(const std::string & utf8);
+uint8_t     unicode_utf8_to_byte(const std::string & utf8);
 
-uint32_t unicode_tolower(uint32_t cp);
+uint32_t unicode_tolower(uint32_t cpt);
 
 bool unicode_cpt_is_han(uint32_t cpt);