diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt
index 19887717..8ad02a07 100644
--- a/examples/CMakeLists.txt
+++ b/examples/CMakeLists.txt
@@ -70,6 +70,7 @@ func_llm_add_executable(demo_mistral)
 func_llm_add_executable(demo_yi)
 func_llm_add_executable(demo_opt)
 func_llm_add_executable(demo_phi3)
+func_llm_add_executable(demo_phi4mini)
 func_llm_add_executable(demo_minicpm)
 func_llm_add_executable(demo_minicpm3)
 func_llm_add_executable(demo_minicpm_moe)
diff --git a/examples/demo_phi4mini.cpp b/examples/demo_phi4mini.cpp
new file mode 100644
index 00000000..ea34e035
--- /dev/null
+++ b/examples/demo_phi4mini.cpp
@@ -0,0 +1,64 @@
+#include <iostream>
+#include "cmdline.h"
+#include "models/phi4mini/modeling_phi4.hpp"
+#include "models/phi4mini/tokenization_phi4mini.hpp"
+#include "processor/PostProcess.hpp"
+
+using namespace mllm;
+
+int main(int argc, char **argv) {
+    cmdline::parser cmdParser;
+    cmdParser.add<string>("vocab", 'v', "specify mllm tokenizer model path", false,
+                          "/data/lyw/phi4-mini/phi4_vocab.mllm");
+    cmdParser.add<string>("model", 'm', "specify mllm model path", false,
+                          "/data/lyw/phi4-mini/phi4-mini.mllm");
+    cmdParser.add<int>("limits", 'l', "max KV cache size", false, 6000);
+    cmdParser.add<int>("thread", 't', "num of threads", false, 4);
+    cmdParser.parse_check(argc, argv);
+
+    string vocab_path = cmdParser.get<string>("vocab");
+    string merges_path = "/data/lyw/phi4-mini/merges.txt";
+    string model_path = cmdParser.get<string>("model");
+    int tokens_limit = cmdParser.get<int>("limits");
+    CPUBackend::cpu_threads = cmdParser.get<int>("thread");
+
+    auto tokenizer = Phi4Tokenizer(vocab_path, merges_path, false);
+
+    Phi4Config config(
+        tokens_limit,
+        "4-mini",
+        HFHUBROPE,
+        200064
+    );
+    auto model = Phi4Model(config);
+    model.load(model_path);
+
+    vector<string> in_strs = {
+        "who are you?",
+        "What can you do?",
+        "Please introduce Beijing University of Posts and Telecommunications."};
+
+    for (int i = 0; i < in_strs.size(); ++i) {
+        auto in_str_origin = in_strs[i];
+        auto in_str = tokenizer.apply_chat_template(in_str_origin);
+        auto input_tensor = tokenizer.tokenize(in_str);
+
+        std::cout << std::endl;
+        std::cout << "[Q] " << in_str_origin << std::endl;
+        std::cout << "[A] " << std::flush;
+
+        for (int step = 0; step < 100; ++step) {
+            auto result = model({input_tensor});
+            auto [out_string, out_token] = tokenizer.detokenize(result[0]);
+            auto [not_end, output_string] = tokenizer.postprocess(out_string);
+            if (!not_end) { break; }
+            std::cout << output_string << std::flush;
+            chatPostProcessing(out_token, input_tensor, {});
+        }
+        printf("\n");
+        model.clear_kvcache();
+        model.profiling();
+    }
+
+    return 0;
+}
\ No newline at end of file
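Note for reviewers: `apply_chat_template` simply brackets the user prompt with the `chat_template_pre` / `chat_template_end` strings set in `Phi4Tokenizer`'s constructor (see the tokenizer file at the end of this patch). A minimal, self-contained sketch of the prompt string the demo actually feeds into `tokenize()`, assuming only that behavior:

```cpp
#include <iostream>
#include <string>

// Sketch only: mirrors chat_template_pre / chat_template_end from
// tokenization_phi4mini.hpp further down in this patch.
int main() {
    const std::string pre = "<|system|>Your name is Phi, an AI math expert developed by Microsoft.<|end|><|user|>";
    const std::string end = "<|end|><|assistant|>";
    std::string prompt = pre + "who are you?" + end;
    std::cout << prompt << "\n"; // what tokenize() sees for the first question
}
```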
diff --git a/src/Layer.hpp b/src/Layer.hpp
index 670ffcea..f59d1906 100644
--- a/src/Layer.hpp
+++ b/src/Layer.hpp
@@ -970,6 +970,7 @@ class NTKRoPE final : public Layer {
         for (int i = 0; i < short_factor.size(); i++) {
             param_["short_factor_" + std::to_string(i)] = short_factor[i];
         }
+        param_["partial_rotary_factor"] = partial_rotary_factor;
     }
 
     Tensor operator()(Tensor input) {
diff --git a/src/backends/cpu/op/CPUNTKRoPE.cpp b/src/backends/cpu/op/CPUNTKRoPE.cpp
index 60576177..bebbf688 100644
--- a/src/backends/cpu/op/CPUNTKRoPE.cpp
+++ b/src/backends/cpu/op/CPUNTKRoPE.cpp
@@ -30,10 +30,12 @@ void get_sin_cos_emb_hf(
     std::vector<float> &long_factor,
     std::vector<float> &short_factor,
     int original_max_position_embeddings,
+    float partial_rotary_factor,
     int max_position_embeddings = 2048) {
     auto scale = (float)max_position_embeddings / (float)original_max_position_embeddings;
     auto scaling_factor =
         (float)std::sqrt(1 + std::log(scale) / std::log(original_max_position_embeddings));
+    output_dim *= partial_rotary_factor;
     // compute sin and cos
     emb_sin.resize(seq_len);
     for (int i = 0; i < seq_len; ++i) {
@@ -54,7 +56,7 @@
     // calculate inv_freq
     std::vector<float> inv_freq(output_dim / 2, 0.f);
     for (int i = 0; i < output_dim / 2; ++i) {
-        inv_freq[i] = 1.f / (float)(std::pow(theta, (float)i / (float)output_dim));
+        inv_freq[i] = 1.f / (float)(std::pow(theta, (float)(i * 2) / (float)output_dim));
     }
 
     std::vector<float> t(seq_len, 0.f);
@@ -73,6 +75,9 @@
         }
     }
 
+    if (scale <= 1) {
+        scaling_factor = 1.0f;
+    }
     for (int i = 0; i < seq_len; ++i) {
         for (int j = 0; j < output_dim / 2; ++j) {
             emb_sin[i][j] = std::sin(freqs[i][j]) * scaling_factor;
@@ -90,9 +95,10 @@ void apply_rope_hf(
     std::shared_ptr<Tensor> &output,
     std::vector<std::vector<float>> &emb_sin,
     std::vector<std::vector<float>> &emb_cos,
-    int h_cnt) {
+    int h_cnt,
+    int partial_dimension) {
     auto out_dtype = output->dtype();
-    int partial_dimension = (input->dimension()) * 1;
+    // partial_dimension is now computed by the caller from partial_rotary_factor
     int half = (int)(partial_dimension / 2);
     assert(partial_dimension % 2 == 0);
     if (output->ctype() == BSHD) {
@@ -213,7 +219,8 @@ CPUNTKRoPE::CPUNTKRoPE(Backend *bn, string op_name, int pose_type, float rope_th
                        const std::vector<float> &short_factor,
                        int original_max_position_embeddings,
                        int max_position_embeddings,
-                       int thread_count) :
+                       int thread_count,
+                       float partial_rotary_factor) :
     Op(bn, op_name),
     thread_count_(thread_count),
     pose_type_(pose_type),
@@ -221,17 +228,19 @@
     long_factor_(long_factor),
     short_factor_(short_factor),
     original_max_position_embeddings_(original_max_position_embeddings),
-    max_position_embeddings_(max_position_embeddings) {
+    max_position_embeddings_(max_position_embeddings),
+    partial_rotary_factor_(partial_rotary_factor) {
 }
 
 ErrorCode CPUNTKRoPE::doExecute(std::vector<std::shared_ptr<Tensor>> inputs, std::vector<std::shared_ptr<Tensor>> outputs) {
     auto &input = inputs[0];
     auto &output = outputs[0];
     auto out_dtype = output->dtype();
-    int partial_dimension = (input->dimension()) * 1;
+    // rotate only the leading partial_rotary_factor fraction of each head
+    int partial_dimension = int(input->dimension() * partial_rotary_factor_);
     switch ((RoPEType)pose_type_) {
     case RoPEType::HFHUBROPE:
-        apply_rope_hf(input, output, emb_sin_, emb_cos_, h_cnt_);
+        apply_rope_hf(input, output, emb_sin_, emb_cos_, h_cnt_, partial_dimension);
         break;
     default:
         MLLM_LOG_ERROR("RoPEType={} is not supported yet. Currently, only support HFHUBROPE style NTKRoPE", pose_type_);
@@ -278,6 +287,7 @@ ErrorCode CPUNTKRoPE::reshape(std::vector<std::shared_ptr<Tensor>> inputs, std::
             long_factor_,
             short_factor_,
             original_max_position_embeddings_,
+            partial_rotary_factor_,
             max_position_embeddings_);
         break;
     default:
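Two functional changes here are worth calling out. First, the inverse-frequency exponent is now `2i/d`, matching the standard RoPE schedule `inv_freq[i] = theta^(-2i/d)`; the old `i/d` compressed the frequency band by half. Second, with the new guard the LongRoPE attention scaling `scaling_factor = sqrt(1 + ln(scale)/ln(L_orig))` is clamped to 1 whenever `scale <= 1`, i.e. when running at or below the original 4096-token window. A standalone sketch of the corrected schedule for Phi-4-mini's numbers (head_dim = 128, partial_rotary_factor = 0.75):

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    const float theta = 10000.0f, partial_rotary_factor = 0.75f;
    int output_dim = 128;                // head_dim
    output_dim *= partial_rotary_factor; // 96 rotary channels; 32 stay untouched
    std::vector<float> inv_freq(output_dim / 2); // 48 entries, one per channel pair
    for (int i = 0; i < output_dim / 2; ++i)
        inv_freq[i] = 1.f / std::pow(theta, (float)(i * 2) / (float)output_dim);
    std::printf("pairs=%zu first=%f last=%f\n",
                inv_freq.size(), inv_freq.front(), inv_freq.back());
}
```

The 48 channel pairs are exactly why the config's `rope_long_factor` table below has 48 entries.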
diff --git a/src/backends/cpu/op/CPUNTKRoPE.hpp b/src/backends/cpu/op/CPUNTKRoPE.hpp
index 5ee7767c..bfc9dc5e 100644
--- a/src/backends/cpu/op/CPUNTKRoPE.hpp
+++ b/src/backends/cpu/op/CPUNTKRoPE.hpp
@@ -51,7 +51,8 @@ class CPUNTKRoPE final : public Op {
                const std::vector<float> &short_factor,
                int original_max_position_embeddings,
                int max_position_embeddings,
-               int thread_count);
+               int thread_count,
+               float partial_rotary_factor = 1.0f);
     ~CPUNTKRoPE() override = default;
     ErrorCode reshape(std::vector<std::shared_ptr<Tensor>> inputs, std::vector<std::shared_ptr<Tensor>> outputs) override;
@@ -76,6 +77,7 @@ class CPUNTKRoPE final : public Op {
     int max_position_embeddings_ = 32768;
     int original_max_position_embeddings_ = 32768;
     int in_shape = -1;
+    float partial_rotary_factor_ = 1.0f;
 
     void clearCache() override {
         h_cnt_ = 0;
@@ -107,8 +109,13 @@ class CPUNTKRoPECreator : public CPUBackend::Creator {
         int original_max_position_embeddings =
             static_cast<int>(op_param["original_max_position_embeddings"]);
 
+        float partial_rotary_factor = 1.0f;
+        if (op_param.count("partial_rotary_factor")) {
+            partial_rotary_factor = op_param["partial_rotary_factor"];
+        }
+
         return new CPUNTKRoPE(bn, name, pose_type, rope_theta, long_factor, short_factor,
-                              original_max_position_embeddings, max_position_embeddings, thread_count);
+                              original_max_position_embeddings, max_position_embeddings, thread_count, partial_rotary_factor);
    }
 };
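Because the creator falls back to 1.0f when `partial_rotary_factor` is absent from `op_param`, graphs serialized before this patch keep full-width rotation; only models that explicitly store the key (the `param_` write in the Layer.hpp hunk above) get partial rotation. A minimal sketch of the same fallback pattern, assuming `op_param` behaves like mllm's usual string-to-float map:

```cpp
#include <cstdio>
#include <map>
#include <string>

// Same lookup-with-default pattern as CPUNTKRoPECreator: older serialized
// graphs simply lack the key, so existing NTK-RoPE models keep factor = 1.0f.
int main() {
    std::map<std::string, float> op_param = {{"rope_theta", 10000.0f}}; // no factor stored
    float partial_rotary_factor = 1.0f;
    if (op_param.count("partial_rotary_factor")) {
        partial_rotary_factor = op_param["partial_rotary_factor"];
    }
    std::printf("factor = %.2f\n", partial_rotary_factor); // prints 1.00
}
```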
diff --git a/src/models/phi4mini/configuration_phi4.hpp b/src/models/phi4mini/configuration_phi4.hpp
new file mode 100644
index 00000000..9597f30d
--- /dev/null
+++ b/src/models/phi4mini/configuration_phi4.hpp
@@ -0,0 +1,122 @@
+//
+// Created by Lu Yiwen on 2025/6/3.
+//
+#ifndef CONFIG_PHI4_HPP
+#define CONFIG_PHI4_HPP
+#include "models/transformer/configuration_transformer.hpp"
+
+using namespace mllm;
+
+class Phi4NameConfig : public TransformerNameConfig {
+public:
+    std::string blk_name;
+    std::string token_embd_name;
+    std::string post_norm_name;
+    std::string lm_head_name;
+    std::string _gate_up_proj_name;
+
+    void init(RoPEType = HFHUBROPE) {
+        blk_name = "model.layers.";
+        _attn_base_name = "self_attn.";
+        _ffn_base_name = "mlp.";
+        _qkv_proj_name = "qkv_proj";
+        _o_proj_name = "o_proj";
+        _gate_up_proj_name = "gate_up_proj";
+        _down_proj_name = "down_proj";
+        _attn_norm_name = "input_layernorm";
+        _ffn_norm_name = "post_attention_layernorm";
+        token_embd_name = "model.embed_tokens";
+        post_norm_name = "model.norm";
+        lm_head_name = token_embd_name; // tied with the input embedding
+    }
+};
+
+class Phi4Config : public TransformerConfig {
+public:
+    int vocab_size{};
+    int hidden_dim{};
+    int head_size{};
+    int num_key_value_heads{};
+    int ffn_hidden{};
+    int block_num{};
+    int max_position_embeddings;
+    // RoPE
+    RoPEType RoPE_type;
+    float rope_theta;
+    int rope_original_max_position_embeddings;
+    std::vector<float> rope_long_factor;
+    std::vector<float> rope_short_factor;
+
+    float attention_dropout;
+    float rms_norm_eps;
+    int num_attention_heads;
+
+    int cache_limit{};
+    Phi4NameConfig names_config;
+    bool tie_embedding_words;
+    bool attention_bias;
+    float partial_rotary_factor;
+
+    explicit Phi4Config(int token_limit, string billions = "4-mini", RoPEType type = HFHUBROPE, int vocab = 200064) {
+        names_config.init(type);
+
+        if (billions == "4-mini" || billions == "phi4-mini") {
+            vocab_size = 200064;
+            hidden_dim = 3072;                // config.hidden_size
+            head_size = 3072 / 24;            // hidden_size / num_attention_heads
+            num_key_value_heads = 8;          // config.num_key_value_heads
+            ffn_hidden = 8192;                // config.intermediate_size
+            block_num = 32;                   // config.num_hidden_layers
+            max_position_embeddings = 131072; // config.max_position_embeddings
+            rope_theta = 10000.0f;            // config.rope_theta
+
+            num_attention_heads = 24; // config.num_attention_heads
+            attention_dropout = 0.0f; // config.attention_dropout
+            rms_norm_eps = 1e-5f;     // config.rms_norm_eps
+            tie_embedding_words = true;
+            attention_bias = false;
+            partial_rotary_factor = 0.75f;
+        } else {
+            throw std::runtime_error("Unsupported model size");
+        }
+        RoPE_type = type;
+
+        rope_original_max_position_embeddings = 4096;
+
+        rope_long_factor = {
+            1.0f, 1.118320672f, 1.250641126f, 1.398617824f,
+            1.564103225f, 1.74916897f, 1.956131817f, 2.187582649f,
+            2.446418898f, 2.735880826f, 3.059592084f, 3.421605075f,
+            3.826451687f, 4.279200023f, 4.785517845f, 5.351743533f,
+            5.984965424f, 6.693110555f, 7.485043894f, 8.370679318f,
+            9.36110372f, 10.4687158f, 11.70738129f, 13.09260651f,
+            14.64173252f, 16.37415215f, 18.31155283f, 20.47818807f,
+            22.90118105f, 25.61086418f, 28.64115884f, 32.03f,
+            32.1f, 32.13f, 32.23f, 32.6f,
+            32.61f, 32.64f, 32.66f, 32.7f,
+            32.71f, 32.93f, 32.97f, 33.28f,
+            33.49f, 33.5f, 44.16f, 47.77f};
+
+        rope_short_factor = rope_long_factor;
+
+        cache_limit = token_limit;
+    }
+
+    void validate_rope_scaling() const {
+        int head_dim = hidden_dim / num_attention_heads;            // 3072 / 24 = 128
+        int rotary_ndims = (int)(head_dim * partial_rotary_factor); // 96
+        int expect_len = rotary_ndims / 2;                          // 48
+        if ((int)rope_long_factor.size() != expect_len) {
+            throw std::runtime_error(
+                "`rope_long_factor` length must be " + std::to_string(expect_len) + ", but got " + std::to_string(rope_long_factor.size()));
+        }
+        if ((int)rope_short_factor.size() != expect_len) {
+            throw std::runtime_error(
+                "`rope_short_factor` length must be " + std::to_string(expect_len) + ", but got " + std::to_string(rope_short_factor.size()));
+        }
+    }
+};
+
+#endif // CONFIG_PHI4_HPP
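The 48-entry check matches the factor table above (128 × 0.75 / 2 = 48), but nothing in this patch actually invokes `validate_rope_scaling()`. A minimal sketch of wiring it into a demo or test so a mis-sized table fails fast instead of silently misindexing the factors (assumes the mllm include paths):

```cpp
#include "models/phi4mini/configuration_phi4.hpp"

int main() {
    Phi4Config config(6000, "4-mini", HFHUBROPE, 200064);
    // Throws unless both factor tables have exactly 48 entries.
    config.validate_rope_scaling();
    return 0;
}
```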
diff --git a/src/models/phi4mini/modeling_phi4.hpp b/src/models/phi4mini/modeling_phi4.hpp
new file mode 100644
index 00000000..ec3c15eb
--- /dev/null
+++ b/src/models/phi4mini/modeling_phi4.hpp
@@ -0,0 +1,220 @@
+//
+// Created by Lu Yiwen on 2025/6/3.
+//
+#ifndef MODELING_PHI4_HPP
+#define MODELING_PHI4_HPP
+
+#include <cmath>
+#include "Layer.hpp"
+#include "Module.hpp"
+#include "Tensor.hpp"
+#include "configuration_phi4.hpp"
+// #include "models/transformer/modeling_transformer.hpp"
+
+using namespace mllm;
+
+class Phi4Attention final : public Module {
+public:
+    Phi4Attention() = default;
+    Phi4Attention(const Phi4Config &config, const Phi4NameConfig &names, const string &base_name) {
+        hidden_size = config.hidden_dim;
+        num_heads = config.num_attention_heads;
+        head_dim = config.hidden_dim / num_heads;
+        num_key_value_heads = config.num_key_value_heads;
+        num_key_value_groups = num_heads / num_key_value_heads;
+
+        // Only head_dim * partial_rotary_factor channels are rotated
+        // (128 * 0.75 = 96 for Phi-4-mini); NTKRoPE handles this internally.
+
+        qkv_proj = Linear(
+            hidden_size,
+            num_heads * head_dim + num_key_value_heads * head_dim * 2,
+            config.attention_bias,
+            base_name + names._qkv_proj_name);
+        o_proj = Linear(num_heads * head_dim, hidden_size, false, base_name + names._o_proj_name);
+
+        q_rope = NTKRoPE(
+            config.RoPE_type,
+            config.rope_theta,
+            config.max_position_embeddings,
+            config.rope_original_max_position_embeddings,
+            config.rope_long_factor,
+            config.rope_short_factor,
+            base_name + "q_ntkrope",
+            config.partial_rotary_factor);
+        k_rope = NTKRoPE(
+            config.RoPE_type,
+            config.rope_theta,
+            config.max_position_embeddings,
+            config.rope_original_max_position_embeddings,
+            config.rope_long_factor,
+            config.rope_short_factor,
+            base_name + "k_ntkrope",
+            config.partial_rotary_factor);
+        k_cache = KVCache(num_key_value_heads, head_dim, num_key_value_groups, config.cache_limit, base_name + "k_cache");
+        v_cache = KVCache(num_key_value_heads, head_dim, num_key_value_groups, config.cache_limit, base_name + "v_cache");
+        // mask = SlidingWindowMask(config.sliding_window, base_name + "mask");
+        // mask = Causalmask(base_name + "mask");
+        softmax = Softmax(DIMENSION, true, base_name + "softmax");
+    }
+
+    std::vector<Tensor> Forward(std::vector<Tensor> inputs, std::vector<std::any> args) override {
+        int head_dim = hidden_size / num_heads;
+        int Q_dim = num_heads * head_dim;            // 3072
+        int KV_dim = num_key_value_heads * head_dim; // 1024
+        int total_proj_dim = Q_dim + 2 * KV_dim;     // 5120
+
+        auto qkv = qkv_proj(inputs[0]);
+        auto qkv_sp = qkv.split({Q_dim, KV_dim, KV_dim}, Chl::DIMENSION);
+        auto query_raw = qkv_sp[0];
+        auto key_raw = qkv_sp[1];
+        auto value_raw = qkv_sp[2];
+
+        auto query_states = query_raw.view(-1, num_heads, -1, head_dim);
+        auto key_states = key_raw.view(-1, num_key_value_heads, -1, head_dim);
+        auto value_states = value_raw.view(-1, num_key_value_heads, -1, head_dim);
+
+        // rotary position embedding
+        query_states = q_rope(query_states);
+        key_states = k_rope(key_states);
+
+        // kv cache
+        key_states = k_cache(key_states);
+        value_states = v_cache(value_states);
+
+        // attention weight
+        auto atten_weight =
+            Tensor::mm(query_states, key_states.transpose(Chl::SEQUENCE, Chl::DIMENSION))
+            / std::sqrt(head_dim);
+        // atten_weight = mask(atten_weight, k_cache.getCacheSeqLen());
+        atten_weight = softmax(atten_weight, k_cache.getCacheSeqLen());
+
+        // attention output
+        auto atten_output = Tensor::mm(atten_weight, value_states);
+        atten_output = atten_output.view(-1, 1, -1, head_dim * num_heads);
+        atten_output = o_proj(atten_output);
+        return {atten_output};
+    }
+
+    vector<KVCache *> get_cache() {
+        return {&k_cache, &v_cache};
+    }
+    vector<NTKRoPE *> get_rope() {
+        return {&q_rope, &k_rope};
+    }
+
+private:
+    int hidden_size;
+    int num_heads;
+    int head_dim;
+    int num_key_value_heads;
+    int num_key_value_groups;
+    Layer qkv_proj;
+    Layer o_proj;
+    NTKRoPE q_rope;
+    NTKRoPE k_rope;
+    // RoPE q_rope;
+    // RoPE k_rope;
+    KVCache k_cache;
+    KVCache v_cache;
+    // Causalmask mask;
+    Softmax softmax;
+};
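The fused projection and split sizes above follow from grouped-query attention: 24 query heads but only 8 KV heads, so the 3072-wide input projects to 3072 + 2 × 1024 = 5120 channels, and each KV head serves 3 query heads (`num_key_value_groups`); the KVCache is constructed with that group count so cached K/V are expanded back to 24 heads at score time. A self-contained check of the arithmetic:

```cpp
#include <cassert>
#include <cstdio>

int main() {
    const int hidden = 3072, n_heads = 24, n_kv_heads = 8;
    const int head_dim = hidden / n_heads;    // 128
    const int q_dim = n_heads * head_dim;     // 3072
    const int kv_dim = n_kv_heads * head_dim; // 1024
    assert(q_dim + 2 * kv_dim == 5120);       // qkv_proj output width
    assert(n_heads % n_kv_heads == 0);        // groups must divide evenly
    std::printf("kv groups = %d\n", n_heads / n_kv_heads); // 3
}
```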
+
+class Phi4MLP final : public Module {
+    Layer gate_up_proj;
+    Layer silu;
+    Layer down_proj;
+    int ffn_hidden_;
+
+public:
+    Phi4MLP() = default;
+    Phi4MLP(int hidden_dim, int ffn_hidden, const Phi4NameConfig &names, const string &base_name) {
+        ffn_hidden_ = ffn_hidden;
+        gate_up_proj = Linear(hidden_dim, 2 * ffn_hidden, false, base_name + names._gate_up_proj_name);
+        silu = SiLU(base_name + "act");
+        down_proj = Linear(ffn_hidden, hidden_dim, false, base_name + names._down_proj_name);
+    }
+    vector<Tensor> Forward(vector<Tensor> inputs, vector<std::any> args) override {
+        auto x = gate_up_proj(inputs[0]);
+        auto splited_y_12 = x.split({ffn_hidden_, ffn_hidden_}, DIMENSION);
+        auto y_1 = splited_y_12[0]; // gate
+        Tensor y_2 = splited_y_12[1]; // up
+        x = y_2 * silu(y_1);
+        x = down_proj(x);
+        return {x};
+    }
+};
+
+class Phi4Block final : public Module {
+    Phi4Attention attention;
+    Phi4MLP mlp;
+    Layer norm1;
+    Layer norm2;
+
+public:
+    Phi4Block() = default;
+    Phi4Block(const Phi4Config &config,
+              const Phi4NameConfig &names,
+              const string &base_name) {
+        attention = Phi4Attention(config, names, base_name + names._attn_base_name);
+        mlp = Phi4MLP(config.hidden_dim, config.ffn_hidden, names, base_name + names._ffn_base_name);
+        norm1 = RMSNorm(config.hidden_dim, config.rms_norm_eps, base_name + names._attn_norm_name);
+        norm2 = RMSNorm(config.hidden_dim, config.rms_norm_eps, base_name + names._ffn_norm_name);
+    }
+    vector<Tensor> Forward(vector<Tensor> inputs, vector<std::any> args) override {
+        auto x = norm1(inputs[0]);
+        x = attention({x, x, x})[0];
+        auto tmp = x + inputs[0];
+        x = norm2(tmp);
+        x = mlp({x})[0];
+        x = x + tmp;
+        return {x};
+    }
+
+    Phi4Attention &get_attention() {
+        return attention;
+    }
+};
+
+class Phi4Model final : public Module {
+    Layer embedding;
+    vector<Phi4Block> blocks;
+    Layer norm;
+    Layer lm_head;
+    Parameter lm_head_weight; // shape (1, vocab_size, 1, hidden_dim), tied to embedding.weight
+
+public:
+    explicit Phi4Model(const Phi4Config &config) {
+        embedding = Embedding(config.vocab_size, config.hidden_dim, config.names_config.token_embd_name);
+        norm = RMSNorm(config.hidden_dim, config.rms_norm_eps, config.names_config.post_norm_name);
+        // lm_head = Linear(config.hidden_dim, config.vocab_size, false, config.names_config.lm_head_name);
+        lm_head_weight = Parameter{1, config.vocab_size, 1, config.hidden_dim, config.names_config.token_embd_name + ".weight"};
+        const auto &names = config.names_config;
+        const std::string base_name = names.blk_name;
+        blocks = List<Phi4Block>(config.block_num, config, names, base_name);
+    }
+    vector<Tensor> Forward(vector<Tensor> inputs, vector<std::any> args) override {
+        auto x = embedding(inputs[0]);
+        for (auto &block : blocks) {
+            x = block({x})[0];
+        }
+        x = norm(x);
+        x = Tensor::mm(x, lm_head_weight().transpose(Chl::SEQUENCE, Chl::DIMENSION));
+        return {x};
+    }
+
+    void clear_kvcache() override {
+        for (auto &block : blocks) {
+            auto kvcache = block.get_attention().get_cache();
+            for (auto &cache : kvcache) { cache->clearCache(); }
+            auto ropes = block.get_attention().get_rope();
+            for (auto &rope : ropes) { rope->clearCache(); }
+        }
+    }
+};
+
+#endif // MODELING_PHI4_HPP
\ No newline at end of file
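Since `tie_embedding_words` is true for Phi-4-mini, the output head reuses `model.embed_tokens.weight`: rather than loading a separate `lm_head` Linear, the model binds a Parameter to the embedding weight and computes the logits as a matrix product against its transpose. In math form:

```latex
% Tied-embedding head: the logits share the input embedding matrix.
% X: hidden states (s x d), E: embed_tokens.weight (V x d)
\[
  \mathrm{logits} = X E^{\top} \in \mathbb{R}^{s \times V},
  \qquad d = 3072,\; V = 200064
\]
```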
diff --git a/src/models/phi4mini/tokenization_phi4mini.hpp b/src/models/phi4mini/tokenization_phi4mini.hpp
new file mode 100644
index 00000000..f7cd719b
--- /dev/null
+++ b/src/models/phi4mini/tokenization_phi4mini.hpp
@@ -0,0 +1,224 @@
+//
+// Created by Lu Yiwen on 2025/6/3.
+//
+#ifndef TOKENIZATION_PHI4_HPP
+#define TOKENIZATION_PHI4_HPP
+
+#include "tokenizers/BPE/Bpe.hpp"
+#include "tokenizers/Tokenizer.hpp"
+#include "tokenizers/Unicode.hpp"
+#include <algorithm>
+#include <fstream>
+
+using namespace mllm;
+
+#define UTF8(x) any_to_utf8(x)
+#define CHR(x) __chr(x)
+#define ORD(x) __ord(x)
+class Phi4Tokenizer final : public BPETokenizer {
+public:
+    explicit Phi4Tokenizer(const std::string &vocab_file, const std::string &merge_file, bool split_special_tokens = false) :
+        BPETokenizer(vocab_file),
+        split_special_tokens_(split_special_tokens) {
+        Module::initBackend(MLLM_CPU);
+
+        // init byte encoder
+        std::vector<int> bs;
+        for (int i = 33 /*!*/; i < 127 /*~*/; ++i) bs.emplace_back(i);
+        for (int i = 161 /*¡*/; i < 173 /*¬*/; ++i) bs.emplace_back(i);
+        for (int i = 174 /*®*/; i < 256 /*ÿ*/; ++i) bs.emplace_back(i);
+        std::vector<int> cs = bs; // this is a deep copy
+        int n = 0;
+        for (int b = 0; b < 256; ++b) {
+            if (std::find(bs.begin(), bs.end(), b) == bs.end()) {
+                bs.push_back(b);
+                cs.push_back(256 + n);
+                n++;
+            }
+        }
+        assert(bs.size() == cs.size() && "In init byte encoder, the bs and cs size should be same.");
+        for (auto i = 0U; i < bs.size(); ++i) {
+            byte_encoder_[bs[i]] = CHR(cs[i]);
+            byte_decoder_[CHR(cs[i])] = bs[i];
+        }
+
+        // init bpe ranks
+        auto merge_file_stream = std::ifstream(merge_file);
+        if (!merge_file_stream.good()) {
+            std::cout << "merge file is broken\n";
+            exit(1);
+        }
+        std::string line;
+        unsigned rank = 0;
+        while (std::getline(merge_file_stream, line)) {
+            if (line.empty()) {
+                continue;
+            }
+            if (line[0] == '#') {
+                continue;
+            }
+            bpe_ranks_[line] = rank;
+            rank++;
+        }
+        BPETokenizer::setMergeRank(bpe_ranks_);
+        chat_template_pre = "<|system|>Your name is Phi, an AI math expert developed by Microsoft.<|end|><|user|>";
+        chat_template_end = "<|end|><|assistant|>";
+    }
+
+    std::vector<std::string> stringSplit(const std::string &str, char delim) {
+        std::size_t previous = 0;
+        std::size_t current = str.find(delim);
+        std::vector<std::string> elems;
+        while (current != std::string::npos) {
+            if (current > previous) {
+                elems.push_back(str.substr(previous, current - previous));
+            }
+            previous = current + 1;
+            current = str.find(delim, previous);
+        }
+        if (previous != str.size()) {
+            elems.push_back(str.substr(previous));
+        }
+        return elems;
+    }
+
+    std::vector<std::string> _splitWithDelimiters(const std::string &str, const std::vector<std::string> &delimiters) {
+        std::string s = str;
+        std::vector<std::string> result;
+        size_t pos = 0;
+        auto isDelimiter = [&](size_t currentPos) {
+            for (const auto &delimiter : delimiters) {
+                if (currentPos + delimiter.length() <= s.length() && s.substr(currentPos, delimiter.length()) == delimiter) {
+                    return true;
+                }
+            }
+            return false;
+        };
+
+        while (pos < s.length()) {
+            if (isDelimiter(pos)) {
+                if (pos != 0) {
+                    result.push_back(s.substr(0, pos));
+                }
+                size_t delimiterLength = delimiters.front().length();
+                for (const auto &delimiter : delimiters) {
+                    if (s.substr(pos, delimiter.length()) == delimiter) {
+                        delimiterLength = delimiter.length();
+                        result.push_back(delimiter);
+                        break;
+                    }
+                }
+                pos += delimiterLength;
+                s = s.substr(pos);
+                pos = 0;
+            } else {
+                ++pos;
+            }
+        }
+
+        if (!s.empty()) {
+            result.push_back(s);
+        }
+
+        return result;
+    }
+
+    Tensor tokenize(const std::string &text, string name = "input", BackendType type = MLLM_CPU) override {
+        std::vector<token_id_t> ret;
+        static const std::vector<std::string> FIXED_PAT_STRS = {
+            "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
+        };
+        if (split_special_tokens_) {
+            const auto word_collection = unicode_regex_split(text, FIXED_PAT_STRS);
+            for (auto &piece : word_collection) {
+                // using bpe
+                std::vector<token_id_t> tmp;
+                BPETokenizer::tokenize(piece, tmp, false, true, "");
+                ret.insert(ret.end(), tmp.begin(), tmp.end() - 1);
+            }
+        } else {
+            auto parts = _splitWithDelimiters(text, special_tokens);
+            for (auto &p : parts) {
+                if (std::find(special_tokens.begin(), special_tokens.end(), p) != special_tokens.end()) {
+                    // special tokens go through the byte encoder untouched
+                    std::string token;
+                    for (auto b : UTF8(p)) token += byte_encoder_[b];
+
+                    std::vector<token_id_t> tmp;
+                    BPETokenizer::tokenize(token, tmp, false, special_tokens, true);
+                    ret.insert(ret.end(), tmp.begin(), tmp.end() - 1);
+                } else {
+                    const auto word_collection = unicode_regex_split(p, FIXED_PAT_STRS);
+                    for (auto &piece : word_collection) {
+                        // using bpe
+                        std::vector<token_id_t> tmp;
+                        BPETokenizer::tokenize(piece, tmp, false, true, "");
+                        assert(!tmp.empty());
+                        ret.insert(ret.end(), tmp.begin(), tmp.end() - 1);
+                    }
+                }
+            }
+        }
+        return Tokenizer::tokens2Input(ret);
+    }
+
+    std::string _byte_decode_(const std::string &text) {
+        std::string ret;
+        auto _ = ORD(text);
+        for (auto i : _) ret += byte_decoder_[CHR(i)];
+        return ret;
+    }
+
+    std::string detokenize(const std::vector<token_id_t> &tokens) override {
+        return _byte_decode_(BPETokenizer::detokenize(tokens));
+    }
+
+    std::pair<std::string, unsigned> detokenize(Tensor &result) override {
+        assert(result.batch() == 1);
+        assert(result.head() == 1);
+        vector<float> scores;
+        for (int i = 0; i < result.dimension(); ++i) {
+            auto value = result.dataAt<float>(0, 0, result.sequence() - 1, i);
+            scores.push_back(value);
+        }
+        auto token_idx = this->argmax(scores);
+        return {_byte_decode_(BPETokenizer::detokenize({token_idx})), token_idx};
+    }
+    std::pair<bool, std::string> postprocess(std::string &text) override {
+        if (text == "<|end|>" || text == "<|endoftext|>") return {false, ""};
+        if (text.rfind("<|", 0) == 0) return {true, ""};
+        return {true, text};
+    }
+
+public:
+    bool split_special_tokens_ = false;
+    std::unordered_map<int, std::string> byte_encoder_;
+    std::unordered_map<std::string, int> byte_decoder_;
+    std::unordered_map<std::string, unsigned int> bpe_ranks_;
+    token_id_t eos_id_ = 199999, bos_id_ = 199999;
+    std::vector<std::string> special_tokens = {
+        "<|endoftext|>",
+        "<|system|>", "<|user|>", "<|assistant|>", "<|end|>",
+        "<|tool|>", "<|/tool|>", "<|tool_call|>", "<|/tool_call|>",
+        "<|tool_response|>", "<|tag|>"};
+};
+
+#undef UTF8
+#undef CHR
+#undef ORD
+
+#endif //! TOKENIZATION_PHI4_HPP
\ No newline at end of file
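The constructor builds the GPT-2-style byte-to-unicode table: the 188 printable bytes map to themselves, and the remaining 68 bytes are remapped to codepoints 256 and up, so any byte string becomes valid text for BPE. For example, a space (0x20) is not in `bs` and lands on U+0120 'Ġ', and a newline lands on U+010A 'Ċ'. A minimal sketch of the same construction with no mllm dependencies:

```cpp
#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
    std::vector<int> bs;
    for (int i = 33; i < 127; ++i) bs.push_back(i);
    for (int i = 161; i < 173; ++i) bs.push_back(i);
    for (int i = 174; i < 256; ++i) bs.push_back(i);
    int n = 0; // counts bytes that need remapping, in ascending byte order
    for (int b = 0; b < 256; ++b)
        if (std::find(bs.begin(), bs.end(), b) == bs.end()) {
            if (b == '\n') std::printf("newline -> U+%04X\n", 256 + n); // U+010A 'Ċ'
            if (b == ' ') std::printf("space   -> U+%04X\n", 256 + n);  // U+0120 'Ġ'
            ++n;
        }
    std::printf("remapped bytes: %d\n", n); // 68
}
```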
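Finally, note the contract of `postprocess`, which the demo's inner generation loop relies on: it returns a (continue, text) pair where `<|end|>` and `<|endoftext|>` stop decoding, any other `<|...|>` special token is swallowed but decoding continues, and ordinary text passes through. A usage sketch with a stand-in that mirrors the member function above:

```cpp
#include <cassert>
#include <string>
#include <utility>

// Stand-in with the same behavior as Phi4Tokenizer::postprocess above.
std::pair<bool, std::string> postprocess(const std::string &text) {
    if (text == "<|end|>" || text == "<|endoftext|>") return {false, ""};
    if (text.rfind("<|", 0) == 0) return {true, ""};
    return {true, text};
}

int main() {
    assert(postprocess("<|end|>").first == false);  // stop decoding
    assert(postprocess("<|tag|>").second.empty());  // swallowed, keep going
    assert(postprocess("hello").second == "hello"); // normal text passes through
}
```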