1 change: 1 addition & 0 deletions examples/CMakeLists.txt
@@ -70,6 +70,7 @@ func_llm_add_executable(demo_mistral)
func_llm_add_executable(demo_yi)
func_llm_add_executable(demo_opt)
func_llm_add_executable(demo_phi3)
+func_llm_add_executable(demo_phi4mini)
func_llm_add_executable(demo_minicpm)
func_llm_add_executable(demo_minicpm3)
func_llm_add_executable(demo_minicpm_moe)
64 changes: 64 additions & 0 deletions examples/demo_phi4mini.cpp
@@ -0,0 +1,64 @@
#include <iostream>
#include "cmdline.h"
#include "models/phi4mini/modeling_phi4.hpp"
#include "models/phi4mini/tokenization_phi4mini.hpp"
#include "processor/PostProcess.hpp"

using namespace mllm;

int main(int argc, char **argv) {
    cmdline::parser cmdParser;
    cmdParser.add<string>("vocab", 'v', "specify mllm tokenizer model path", false,
                          "/data/lyw/phi4-mini/phi4_vocab.mllm");
    cmdParser.add<string>("model", 'm', "specify mllm model path", false,
                          "/data/lyw/phi4-mini/phi4-mini.mllm");
    cmdParser.add<int>("limits", 'l', "max KV cache size", false, 6000);
    cmdParser.add<int>("thread", 't', "num of threads", false, 4);
    cmdParser.parse_check(argc, argv);

    string vocab_path = cmdParser.get<string>("vocab");
    string merges_path = "/data/lyw/phi4-mini/merges.txt";
    string model_path = cmdParser.get<string>("model");
    int tokens_limit = cmdParser.get<int>("limits");
    CPUBackend::cpu_threads = cmdParser.get<int>("thread");

    auto tokenizer = Phi4Tokenizer(vocab_path, merges_path, false);

    Phi4Config config(
        tokens_limit,
        "4-mini",
        HFHUBROPE,
        200064);
    auto model = Phi4Model(config);
    model.load(model_path);

    vector<string> in_strs = {
        "who are you?",
        "What can you do?",
        "Please introduce Beijing University of Posts and Telecommunications."};

    for (int i = 0; i < in_strs.size(); ++i) {
        auto in_str_origin = in_strs[i];
        auto in_str = tokenizer.apply_chat_template(in_str_origin);
        auto input_tensor = tokenizer.tokenize(in_str);

        std::cout << std::endl;
        std::cout << "[Q] " << in_str_origin << std::endl;
        std::cout << "[A] " << std::flush;

        for (int step = 0; step < 100; ++step) {
            auto result = model({input_tensor});
            auto [out_string, out_token] = tokenizer.detokenize(result[0]);
            auto [not_end, output_string] = tokenizer.postprocess(out_string);
            if (!not_end) { break; }
            std::cout << output_string << std::flush;
            chatPostProcessing(out_token, input_tensor, {});
        }
        printf("\n");
        model.clear_kvcache();
        model.profiling();
    }

    return 0;
}
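Note on the decode loop above: each iteration runs the model on the current input tensor, detokenizes the newest position, and chatPostProcessing prepares the next step's single-token input while the KV cache retains the prefix. A self-contained sketch of what one greedy step reduces to; the names below are illustrative assumptions, not identifiers from this PR:

// Conceptual sketch (not mllm API): one greedy decode step.
// The model yields logits for the last position; argmax picks the next
// token, and generation stops when it equals the EOS id.
#include <algorithm>
#include <cstdint>
#include <iterator>
#include <vector>

int64_t greedy_next_token(const std::vector<float> &last_logits) {
    auto it = std::max_element(last_logits.begin(), last_logits.end());
    return static_cast<int64_t>(std::distance(last_logits.begin(), it));
}

bool should_stop(int64_t token, int64_t eos_id) {
    return token == eos_id; // tokenizer.postprocess() plays this role above
}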
1 change: 1 addition & 0 deletions src/Layer.hpp
@@ -970,6 +970,7 @@ class NTKRoPE final : public Layer {
        for (int i = 0; i < short_factor.size(); i++) {
            param_["short_factor_" + std::to_string(i)] = short_factor[i];
        }
+       param_["partial_rotary_factor"] = partial_rotary_factor;
    }

    Tensor operator()(Tensor input) {
24 changes: 17 additions & 7 deletions src/backends/cpu/op/CPUNTKRoPE.cpp
@@ -30,10 +30,12 @@ void get_sin_cos_emb_hf(
    std::vector<float> &long_factor,
    std::vector<float> &short_factor,
    int original_max_position_embeddings,
+   float partial_rotary_factor,
    int max_position_embeddings = 2048) {
    auto scale = (float)max_position_embeddings / (float)original_max_position_embeddings;
    auto scaling_factor = (float)std::sqrt(1 + std::log(scale) / std::log(original_max_position_embeddings));

+   output_dim *= partial_rotary_factor;
    // compute sin and cos
    emb_sin.resize(seq_len);
    for (int i = 0; i < seq_len; ++i) {
@@ -54,7 +56,7 @@
    // calculate inv_freq
    std::vector<float> inv_freq(output_dim / 2, 0.f);
    for (int i = 0; i < output_dim / 2; ++i) {
-       inv_freq[i] = 1.f / (float)(std::pow(theta, (float)i / (float)output_dim));
+       inv_freq[i] = 1.f / (float)(std::pow(theta, (float)(i * 2) / (float)output_dim));
    }

    std::vector<float> t(seq_len, 0.f);
@@ -73,6 +75,9 @@
        }
    }

+   if (scale <= 1) {
+       scaling_factor = (float)1;
+   }
    for (int i = 0; i < seq_len; ++i) {
        for (int j = 0; j < output_dim / 2; ++j) {
            emb_sin[i][j] = std::sin(freqs[i][j]) * scaling_factor;
@@ -90,9 +95,10 @@ void apply_rope_hf(
    std::shared_ptr<Tensor> &output,
    std::vector<std::vector<float>> &emb_sin,
    std::vector<std::vector<float>> &emb_cos,
-   int h_cnt) {
+   int h_cnt,
+   int partial_dimension) {
    auto out_dtype = output->dtype();
-   int partial_dimension = (input->dimension()) * 1;
+   // int partial_dimension = (input->dimension()) * 1;
    int half = (int)(partial_dimension / 2);
    assert(partial_dimension % 2 == 0);
    if (output->ctype() == BSHD) {
@@ -213,25 +219,28 @@ CPUNTKRoPE::CPUNTKRoPE(Backend *bn, string op_name, int pose_type, float rope_th
                       const std::vector<float> &short_factor,
                       int original_max_position_embeddings,
                       int max_position_embeddings,
-                      int thread_count) :
+                      int thread_count,
+                      float partial_rotary_factor) :
    Op(bn, op_name),
    thread_count_(thread_count),
    pose_type_(pose_type),
    rope_theta_(rope_theta),
    long_factor_(long_factor),
    short_factor_(short_factor),
    original_max_position_embeddings_(original_max_position_embeddings),
-   max_position_embeddings_(max_position_embeddings) {
+   max_position_embeddings_(max_position_embeddings),
+   partial_rotary_factor_(partial_rotary_factor) {
}

ErrorCode CPUNTKRoPE::doExecute(std::vector<std::shared_ptr<Tensor>> inputs, std::vector<std::shared_ptr<Tensor>> outputs) {
    auto &input = inputs[0];
    auto &output = outputs[0];
    auto out_dtype = output->dtype();
-   int partial_dimension = (input->dimension()) * 1;
+   // int partial_dimension = (input->dimension()) * 1;
+   int partial_dimension = int(input->dimension() * partial_rotary_factor_);
    switch ((RoPEType)pose_type_) {
    case RoPEType::HFHUBROPE:
-       apply_rope_hf(input, output, emb_sin_, emb_cos_, h_cnt_);
+       apply_rope_hf(input, output, emb_sin_, emb_cos_, h_cnt_, partial_dimension);
        break;
    default:
        MLLM_LOG_ERROR("RoPEType={} is not supported yet. Currently, only support HFHUBROPE style NTKRoPE", pose_type_);
@@ -278,6 +287,7 @@ ErrorCode CPUNTKRoPE::reshape(std::vector<std::shared_ptr<Tensor>> inputs, std::
                           long_factor_,
                           short_factor_,
                           original_max_position_embeddings_,
+                          partial_rotary_factor_,
                           max_position_embeddings_);
        break;
    default:
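Taken together, these hunks make the rotation partial: with Phi-4-mini's partial_rotary_factor = 0.75 and head_dim = 128, only the first 96 channels of each head are rotated and the remaining 32 pass through, while the new scale <= 1 guard keeps the magnitude-scaling factor at 1 when no context extension is in effect. A minimal sketch of partial rotation on a single head vector, under the same rotate-half pairing apply_rope_hf uses; names here are illustrative, not repository identifiers:

#include <vector>

// Rotate the first partial_dim channels of one head; leave the rest as-is.
// sin_row/cos_row hold partial_dim / 2 per-frequency values for this position.
void partial_rope_one_head(std::vector<float> &x,
                           const std::vector<float> &sin_row,
                           const std::vector<float> &cos_row,
                           int partial_dim) { // e.g. int(128 * 0.75) = 96
    int half = partial_dim / 2;
    for (int j = 0; j < half; ++j) {
        float a = x[j];        // pairs with the channel half a slice away
        float b = x[j + half];
        x[j] = a * cos_row[j] - b * sin_row[j];
        x[j + half] = b * cos_row[j] + a * sin_row[j];
    }
    // channels [partial_dim, head_dim) are passed through untouched
}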
11 changes: 9 additions & 2 deletions src/backends/cpu/op/CPUNTKRoPE.hpp
@@ -51,7 +51,8 @@ class CPUNTKRoPE final : public Op {
               const std::vector<float> &short_factor,
               int original_max_position_embeddings,
               int max_position_embeddings,
-              int thread_count);
+              int thread_count,
+              float partial_rotary_factor = 1.0f);

    ~CPUNTKRoPE() override = default;
    ErrorCode reshape(std::vector<std::shared_ptr<Tensor>> inputs, std::vector<std::shared_ptr<Tensor>> outputs) override;
@@ -76,6 +77,7 @@
    int max_position_embeddings_ = 32768;
    int original_max_position_embeddings_ = 32768;
    int in_shape = -1;
+   float partial_rotary_factor_ = 1.0f;

    void clearCache() override {
        h_cnt_ = 0;
@@ -107,8 +109,13 @@ class CPUNTKRoPECreator : public CPUBackend::Creator {

        int original_max_position_embeddings = static_cast<int>(op_param["original_max_position_embeddings"]);

+       float partial_rotary_factor = 1.0f;
+       if (op_param.count("partial_rotary_factor")) {
+           partial_rotary_factor = op_param["partial_rotary_factor"];
+       }
+
        return new CPUNTKRoPE(bn, name, pose_type, rope_theta, long_factor, short_factor,
-                             original_max_position_embeddings, max_position_embeddings, thread_count);
+                             original_max_position_embeddings, max_position_embeddings, thread_count, partial_rotary_factor);
    }
};

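Reading the new key via op_param.count() with a 1.0f fallback keeps models serialized before this change, which lack a partial_rotary_factor entry, on the old full-rotation path. The same pattern as a generic helper, assuming an OpParam that behaves like std::map<std::string, float> (an assumption, not necessarily the repository's exact type):

#include <map>
#include <string>

// Fetch an optional op parameter, falling back to a default so older
// serialized models keep their previous behavior.
float param_or(const std::map<std::string, float> &op_param,
               const std::string &key, float fallback) {
    auto it = op_param.find(key);
    return it != op_param.end() ? it->second : fallback;
}
// usage: float prf = param_or(op_param, "partial_rotary_factor", 1.0f);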
122 changes: 122 additions & 0 deletions src/models/phi4mini/configuration_phi4.hpp
@@ -0,0 +1,122 @@
//
// Created by Lu Yiwen on 2025/6/3.
//
#ifndef CONFIG_PHI4_HPP
#define CONFIG_PHI4_HPP
#include <stdexcept>
#include <string>
#include <vector>

#include "models/transformer/configuration_transformer.hpp"

using namespace mllm;

class Phi4NameConfig : public TransformerNameConfig {
public:
    std::string blk_name;
    std::string token_embd_name;
    std::string post_norm_name;
    std::string lm_head_name;
    std::string _gate_up_proj_name;

    void init(RoPEType = HFHUBROPE) {
        blk_name = "model.layers.";
        _attn_base_name = "self_attn.";
        _ffn_base_name = "mlp.";
        _qkv_proj_name = "qkv_proj";
        _o_proj_name = "o_proj";
        _gate_up_proj_name = "gate_up_proj";
        _down_proj_name = "down_proj";
        _attn_norm_name = "input_layernorm";
        _ffn_norm_name = "post_attention_layernorm";
        token_embd_name = "model.embed_tokens";
        post_norm_name = "model.norm";
        lm_head_name = token_embd_name;
    }
};

class Phi4Config : public TransformerConfig {
public:
    int vocab_size{};
    int hidden_dim{};
    int head_size{};
    int num_key_value_heads{};
    int ffn_hidden{};
    int block_num{};
    int max_position_embeddings;
    // RoPE
    RoPEType RoPE_type;
    float rope_theta;
    int rope_original_max_position_embeddings;
    std::vector<float> rope_long_factor;
    std::vector<float> rope_short_factor;

    float attention_dropout;
    float rms_norm_eps;
    int num_attention_heads;

    int cache_limit{};
    Phi4NameConfig names_config;
    bool tie_embedding_words;
    bool attention_bias;
    float partial_rotary_factor;

    explicit Phi4Config(int token_limit, string billions = "4-mini", RoPEType type = HFHUBROPE, int vocab = 200064) {
        names_config.init(type);

        if (billions == "4-mini" || billions == "phi4-mini") {
            vocab_size = 200064;
            hidden_dim = 3072;                // config.hidden_size
            head_size = 3072 / 24;            // hidden_size / num_attention_heads
            num_key_value_heads = 8;          // config.num_key_value_heads
            ffn_hidden = 8192;                // config.intermediate_size
            block_num = 32;                   // config.num_hidden_layers
            max_position_embeddings = 131072; // config.max_position_embeddings
            rope_theta = 10000.0f;            // config.rope_theta

            // NEW
            num_attention_heads = 24; // config.json.num_attention_heads
            attention_dropout = 0.0f; // config.json.attention_dropout
            rms_norm_eps = 1e-5f;     // config.json.rms_norm_eps
            tie_embedding_words = true;
            attention_bias = false;
            partial_rotary_factor = 0.75f;
        } else {
            throw std::runtime_error("Unsupported model size");
        }
        RoPE_type = type;

        rope_original_max_position_embeddings = 4096;

        rope_long_factor = {
            1.0f, 1.118320672f, 1.250641126f, 1.398617824f,
            1.564103225f, 1.74916897f, 1.956131817f, 2.187582649f,
            2.446418898f, 2.735880826f, 3.059592084f, 3.421605075f,
            3.826451687f, 4.279200023f, 4.785517845f, 5.351743533f,
            5.984965424f, 6.693110555f, 7.485043894f, 8.370679318f,
            9.36110372f, 10.4687158f, 11.70738129f, 13.09260651f,
            14.64173252f, 16.37415215f, 18.31155283f, 20.47818807f,
            22.90118105f, 25.61086418f, 28.64115884f, 32.03f,
            32.1f, 32.13f, 32.23f, 32.6f,
            32.61f, 32.64f, 32.66f, 32.7f,
            32.71f, 32.93f, 32.97f, 33.28f,
            33.49f, 33.5f, 44.16f, 47.77f};

        rope_short_factor = rope_long_factor;

        cache_limit = token_limit;
    }

    void validate_rope_scaling() const {
        int head_dim = hidden_dim / num_attention_heads;                       // 3072 / 24 = 128
        int rotary_ndims = static_cast<int>(head_dim * partial_rotary_factor); // 128 * 0.75 = 96
        int expect_len = rotary_ndims / 2;                                     // 48
        if ((int)rope_long_factor.size() != expect_len) {
            throw std::runtime_error(
                "`rope_long_factor` length must be " + std::to_string(expect_len) + ", but got " + std::to_string(rope_long_factor.size()));
        }
        if ((int)rope_short_factor.size() != expect_len) {
            throw std::runtime_error(
                "`rope_short_factor` length must be " + std::to_string(expect_len) + ", but got " + std::to_string(rope_short_factor.size()));
        }
    }
};

#endif // CONFIG_PHI4_HPP
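A quick check of the arithmetic encoded above: head_dim = 3072 / 24 = 128, the rotated slice is 128 * 0.75 = 96 channels, and since the scaling factors are per frequency (one per rotated channel pair), validate_rope_scaling expects 96 / 2 = 48 entries, matching the 48-value rope_long_factor table. The same computation as a standalone snippet, with the constants copied from this config rather than read from it:

#include <cassert>

int main() {
    const int hidden_dim = 3072;
    const int num_attention_heads = 24;
    const float partial_rotary_factor = 0.75f;

    const int head_dim = hidden_dim / num_attention_heads;                       // 128
    const int rotary_ndims = static_cast<int>(head_dim * partial_rotary_factor); // 96
    const int expect_len = rotary_ndims / 2;                                     // 48

    assert(expect_len == 48); // one factor per rotated frequency pair
    return 0;
}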