diff --git a/CMakeLists.txt b/CMakeLists.txt index 645ebb0f..7363940d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -21,6 +21,11 @@ endif () if (ARM) set(EXECUTABLE_OUTPUT_PATH ${PROJECT_BINARY_DIR}/../bin-arm) + add_compile_definitions(__ARM_FEATURE_DOTPROD) + # 检查是否使用的是 GCC 或 Clang 编译器 + if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" OR CMAKE_CXX_COMPILER_ID STREQUAL "Clang") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=armv8.2-a+dotprod") + endif() else () set(EXECUTABLE_OUTPUT_PATH ${PROJECT_BINARY_DIR}/../bin) endif () @@ -96,7 +101,8 @@ endif () if (QUANT) include_directories(${PROJECT_SOURCE_DIR}/src/quantizer) file(GLOB_RECURSE MLLM_QUANT - + ${PROJECT_SOURCE_DIR}/src/backends/cpu/compute/GEMM_AArch64.hpp + ${PROJECT_SOURCE_DIR}/src/backends/cpu/compute/GEMM_AArch64.cpp ${PROJECT_SOURCE_DIR}/src/backends/cpu/quantize/*.hpp ${PROJECT_SOURCE_DIR}/src/backends/cpu/quantize/*.cpp ) diff --git a/examples/demo_imagebind_1mod.cpp b/examples/demo_imagebind_1mod.cpp index 7243b9c7..4c4d782d 100644 --- a/examples/demo_imagebind_1mod.cpp +++ b/examples/demo_imagebind_1mod.cpp @@ -13,53 +13,57 @@ int main(int argc, char **argv) { cmdParser.add("model", 'm', "specify mllm model path", false, "../models/imagebind_huge-q4_k.mllm"); cmdParser.add("merges", 'f', "specify mllm tokenizer merges.txt path", false, "../vocab/clip_merges.txt"); cmdParser.add("thread", 't', "num of threads", false, 4); + cmdParser.add("loop_times", 'l', "number of inference loops", false, 10); + cmdParser.add("modality", 'o', "inference modality (text/vision/audio/all)", false, "all"); cmdParser.parse_check(argc, argv); string vocab_path = cmdParser.get("vocab"); string model_path = cmdParser.get("model"); string merges_path = cmdParser.get("merges"); + int loop_times = cmdParser.get("loop_times"); + string modality = cmdParser.get("modality"); CPUBackend::cpu_threads = cmdParser.get("thread"); auto processor = ImagebindProcessor(vocab_path, merges_path); - ImagebindConfig config("huge"); - int loop_times = 10; - - // auto input_tensors = processor.process( - // {"a dog.", "A car", "A bird"},config.max_position_embeddings, - // {"../assets/dog_image.jpg", "../assets/car_image.jpg", "../assets/bird_image.jpg"}, config.img_hw, - // {"../assets/dog_audio.wav", "../assets/car_audio.wav", "../assets/bird_audio.wav"}); - auto input_tensors = processor.process( - {"a dog."},config.max_position_embeddings, + {"a dog."}, config.max_position_embeddings, {"../assets/dog_image.jpg"}, config.img_hw, {"../assets/dog_audio.wav"}); - - std::cout<<"Text| input_shape:["<("vocab", 'v', "specify mllm tokenizer model path", false, "../vocab/qwen_vocab.mllm"); cmdParser.add("merge", 'e', "specify mllm merge file path", false, "../vocab/qwen_merges.txt"); - cmdParser.add("model", 'm', "specify mllm model path", false, "../models/qwen-1.5-0.5b-q4_k.mllm"); + cmdParser.add("model", 'm', "specify mllm model path", false, "../models/qwen-1.5-1.8b-q8_0.mllm"); cmdParser.add("limits", 'l', "max KV cache size", false, 400); cmdParser.add("thread", 't', "num of threads", false, 4); cmdParser.parse_check(argc, argv); @@ -31,7 +31,7 @@ int main(int argc, char **argv) { CPUBackend::cpu_threads = cmdParser.get("thread"); auto tokenizer = QWenTokenizer(vocab_path, merge_path); - QWenConfig config(tokens_limit, "0.5B", RoPEType::HFHUBROPE); + QWenConfig config(tokens_limit, "1.8B", RoPEType::HFHUBROPE); auto model = QWenForCausalLM(config); model.load(model_path); diff --git a/examples/demo_yi.cpp b/examples/demo_yi.cpp index 7de3a784..4aa3e945 100644 --- 
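The CMake change above turns on the Armv8.2 dot-product extension for GCC/Clang builds, which is what the AArch64 GEMM kernels added later in this patch rely on. As an illustration only (this helper is not part of the patch), a minimal sketch of the kind of int8 kernel that __ARM_FEATURE_DOTPROD together with -march=armv8.2-a+dotprod enables:

```cpp
// Hypothetical example: an SDOT-based int8 dot product, guarded the same way
// the new GEMM code is. vdotq_s32 multiplies 16 int8 pairs and accumulates
// them into 4 int32 lanes in a single instruction.
#include <arm_neon.h>
#include <cstdint>

#if defined(__ARM_FEATURE_DOTPROD)
static int32_t dot_i8(const int8_t *a, const int8_t *b, int n) {
    int32x4_t acc = vdupq_n_s32(0);
    int i = 0;
    for (; i + 16 <= n; i += 16)
        acc = vdotq_s32(acc, vld1q_s8(a + i), vld1q_s8(b + i));
    int32_t sum = vaddvq_s32(acc);
    for (; i < n; ++i) sum += (int32_t)a[i] * b[i];  // scalar tail
    return sum;
}
#endif
```

Without the -march flag the compiler would reject vdotq_s32, hence the explicit compiler-id check before appending it to CMAKE_CXX_FLAGS.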
a/examples/demo_yi.cpp +++ b/examples/demo_yi.cpp @@ -9,9 +9,9 @@ * */ #include "cmdline.h" -#include "models/yi/configuration_yi.hpp" -#include "models/yi/modeling_yi.hpp" -#include "models/yi/tokenization_yi.hpp" +#include "models/llama/configuration_llama.hpp" +#include "models/llama/modeling_llama.hpp" +#include "models/llama/tokenization_llama.hpp" #include "processor/PostProcess.hpp" using namespace mllm; @@ -29,9 +29,9 @@ int main(int argc, char **argv) { int tokens_limit = cmdParser.get("limits"); CPUBackend::cpu_threads = cmdParser.get("thread"); - auto tokenizer = YiTokenizer(vocab_path); - YiConfig config(tokens_limit, "6B", RoPEType::HFHUBROPE); - auto model = YiForCausalLM(config); + auto tokenizer = LLaMATokenizer(vocab_path, false); + LLaMAConfig config(tokens_limit, "6B", RoPEType::HFHUBROPE, 64000); + auto model = LLaMAModel(config); model.load(model_path); vector in_strs = { diff --git a/examples/main_alpaca.cpp b/examples/main_alpaca.cpp index ed53e66a..e598f895 100644 --- a/examples/main_alpaca.cpp +++ b/examples/main_alpaca.cpp @@ -51,8 +51,8 @@ NetTensor *Attention( NetTensor * x, int embedding_size, int hidden_size, int he v = _KVCache( {v}, cache_max, name + ".v_cache"); auto *qk = _Matmul( {q, k}, false, true, name + ".qk"); qk = *qk/std::sqrt(hidden_size); - qk = _Causalmask( {qk}, name + ".mask"); - qk = _Softmax( {qk}, DIMENSION, name + ".softmax"); + // qk = _Causalmask( {qk}, name + ".mask"); + qk = _Softmax( {qk}, DIMENSION, true, name + ".softmax"); auto *o = _Matmul( {qk, v}, false, false, name + ".qkv"); o = o->view(-1, 1, -1, hidden_size * head_size); o = _Linear( {o}, hidden_size * head_size, embedding_size, false, name + ".o_proj"); diff --git a/examples/main_clip.cpp b/examples/main_clip.cpp index 37766831..c5f37bec 100644 --- a/examples/main_clip.cpp +++ b/examples/main_clip.cpp @@ -45,9 +45,11 @@ NetTensor *Attention(NetTensor *x, int embedding_size, int hidden_size, int head auto *qk = _Matmul( {q, k}, false, true, name + ".qk"); qk = _Scale( {qk}, 1.0F / std::sqrt(hidden_size), 0.0F, false, name + ".scale"); if(name.find("text_model") != std::string::npos){ - qk = _Causalmask( {qk}, name + ".mask"); + // qk = _Causalmask( {qk}, name + ".mask"); + qk = _Softmax( {qk}, DIMENSION, true, name + ".softmax"); + } else{ + qk = _Softmax( {qk}, DIMENSION, false, name + ".softmax"); } - qk = _Softmax( {qk}, DIMENSION, name + ".softmax"); auto *o = _Matmul( {qk, v}, false, false, name + ".qkv"); o = o->view(-1, 1, -1, hidden_size * head_size); o = _Linear( {o}, hidden_size * head_size, embedding_size, true, name + ".out_proj"); diff --git a/examples/main_fuyu.cpp b/examples/main_fuyu.cpp index 96b390dd..762e46a1 100644 --- a/examples/main_fuyu.cpp +++ b/examples/main_fuyu.cpp @@ -102,8 +102,8 @@ NetTensor *Attention(NetTensor *x, int embedding_size, int hidden_size, int head v = _KVCache({v}, cache_max, name + ".v_cache"); auto *qk = _Matmul({q, k}, false, true, name + ".qk"); qk = _Scale({qk}, 1.0F / std::sqrt(head_size), 0.0F, false, name + ".scale"); - qk = _Causalmask({qk}, name + ".mask"); - qk = _Softmax({qk}, DIMENSION, name + ".softmax"); + // qk = _Causalmask({qk}, name + ".mask"); + qk = _Softmax({qk}, DIMENSION, true, name + ".softmax"); auto *o = _Matmul({qk, v}, false, false, name + ".qkv"); o = o->view(-1, 1, -1, hidden_size * head_size); o = _Linear({o}, hidden_size * head_size, embedding_size, true, name + ".dense"); diff --git a/examples/main_imagebind.cpp b/examples/main_imagebind.cpp index 13600629..cde79f4a 100644 --- 
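The example graphs above drop the standalone _Causalmask op and instead pass true as the new do_causal_mask argument of _Softmax, so masking happens inside the softmax kernel. A scalar sketch of what that fused operation computes for one attention row (hypothetical helper, written only to make the semantics concrete; it mirrors the masked_num_classes logic in CPUSoftMax.cpp further down):

```cpp
// Sketch: fused causal softmax for row s of an attention score matrix.
// Instead of writing -INFINITY into the future positions and exponentiating
// them away, the kernel only normalises columns 0..s and zero-fills the rest.
#include <algorithm>
#include <cmath>

static void causal_softmax_row(const float *scores, float *out, int s, int dim) {
    const int valid = s + 1;                         // columns row s may attend to
    float maxv = -INFINITY;
    for (int j = 0; j < valid; ++j) maxv = std::max(maxv, scores[j]);
    double sum = 0.0;
    for (int j = 0; j < valid; ++j) { out[j] = std::exp(scores[j] - maxv); sum += out[j]; }
    for (int j = 0; j < valid; ++j) out[j] = (float)(out[j] / sum);
    for (int j = valid; j < dim; ++j) out[j] = 0.0f; // masked-out (future) positions
}
```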
a/examples/main_imagebind.cpp +++ b/examples/main_imagebind.cpp @@ -118,9 +118,10 @@ NetTensor *Attention(Context *c,NetTensor *x, int embedding_size, int hidden_siz auto *qk = _Matmul( {q, k}, false, true, name + ".qk"); qk = *qk/std::sqrt(hidden_size); if(name.find("text") != std::string::npos){ - qk = _Causalmask( {qk}, name + ".mask"); + qk = _Softmax( {qk}, DIMENSION, true, name + ".softmax"); + } else{ + qk = _Softmax( {qk}, DIMENSION, false, name + ".softmax"); } - qk = _Softmax( {qk}, DIMENSION, name + ".softmax"); auto *o = _Matmul( {qk, v}, false, false, name + ".qkv"); o = o->view(-1, 1, -1, hidden_size * head_size); o = _Linear( {o}, hidden_size * head_size, embedding_size, true, name + ".out_proj"); @@ -227,10 +228,10 @@ void ImageBind(Context* c) { a = a->transpose(BATCH, SEQUENCE); auto *j1 = _Matmul( {p, i}, false, true, "final.vision@text"); - j1 = _Softmax( {j1}, DIMENSION, "final.vision@text.softmax"); + j1 = _Softmax( {j1}, DIMENSION, false, "final.vision@text.softmax"); auto *j2 = _Matmul( {p, a}, false, true, "final.vision@audio"); - j2 = _Softmax( {j2}, DIMENSION, "final.vision@audio.softmax"); + j2 = _Softmax( {j2}, DIMENSION, false, "final.vision@audio.softmax"); i = _Cat( {j1, j2}, BATCH, "final.cat"); } diff --git a/examples/main_llama.cpp b/examples/main_llama.cpp index e35c0564..2d847fb8 100644 --- a/examples/main_llama.cpp +++ b/examples/main_llama.cpp @@ -50,8 +50,8 @@ NetTensor *Attention(NetTensor *x, int embedding_size, int hidden_size, int head v = _KVCache({v}, cache_max, name + ".v_cache"); auto *qk = _Matmul({q, k}, false, true, name + ".qk"); qk = *qk / std::sqrt(hidden_size); - qk = _Causalmask({qk}, name + ".mask"); - qk = _Softmax({qk}, DIMENSION, name + ".softmax"); + // qk = _Causalmask({qk}, name + ".mask"); + qk = _Softmax({qk}, DIMENSION, true, name + ".softmax"); auto *o = _Matmul({qk, v}, false, false, name + ".qkv"); o = o->view(-1, 1, -1, hidden_size * head_size); o = _Linear({o}, hidden_size * head_size, embedding_size, false, name + ".wo"); diff --git a/examples/main_llava.cpp b/examples/main_llava.cpp index bec76616..0cad04f9 100644 --- a/examples/main_llava.cpp +++ b/examples/main_llava.cpp @@ -72,8 +72,8 @@ NetTensor *Attention(NetTensor *x, int embedding_size, int hidden_size, int head v = _KVCache({v}, cache_max, name + ".v_cache"); auto *qk = _Matmul({q, k}, false, true, name + ".qk"); qk = *qk / std::sqrt(hidden_size); - qk = _Causalmask({qk}, name + ".mask"); - qk = _Softmax({qk}, DIMENSION, name + ".softmax"); + // qk = _Causalmask({qk}, name + ".mask"); + qk = _Softmax({qk}, DIMENSION, true, name + ".softmax"); auto *o = _Matmul({qk, v}, false, false, name + ".qkv"); o = o->view(-1, 1, -1, hidden_size * head_size); o = _Linear({o}, hidden_size * head_size, embedding_size, false, name + ".o_proj"); @@ -117,9 +117,10 @@ NetTensor *VisionAttention(NetTensor *x, int embedding_size, int hidden_size, in auto *qk = _Matmul({q, k}, false, true, name + ".qk"); qk = _Scale({qk}, 1.0F / std::sqrt(hidden_size), 0.0F, false, name + ".scale"); if (name.find("text_model") != std::string::npos) { - qk = _Causalmask({qk}, name + ".mask"); + qk = _Softmax( {qk}, DIMENSION, true, name + ".softmax"); + } else{ + qk = _Softmax( {qk}, DIMENSION, false, name + ".softmax"); } - qk = _Softmax({qk}, DIMENSION, name + ".softmax"); auto *o = _Matmul({qk, v}, false, false, name + ".qkv"); o = o->view(-1, 1, -1, hidden_size * head_size); o = _Linear({o}, hidden_size * head_size, embedding_size, true, name + ".out_proj"); diff --git 
a/examples/main_tinyllama.cpp b/examples/main_tinyllama.cpp index bd59eb20..8d9e7cd4 100644 --- a/examples/main_tinyllama.cpp +++ b/examples/main_tinyllama.cpp @@ -51,8 +51,8 @@ NetTensor *Attention( NetTensor * x, int embedding_size, int hidden_size, int he v = _KVCache( {v},head_size/mutil_key_value_head, cache_max, name + ".v_cache"); auto *qk = _Matmul( {q, k}, false, true, name + ".qk"); qk = *qk/std::sqrt(hidden_size); - qk = _Causalmask( {qk}, name + ".mask"); - qk = _Softmax( {qk}, DIMENSION, name + ".softmax"); + // qk = _Causalmask( {qk}, name + ".mask"); + qk = _Softmax( {qk}, DIMENSION, true, name + ".softmax"); auto *o = _Matmul( {qk, v}, false, false, name + ".qkv"); o = o->view(-1, 1, -1, hidden_size * head_size); o = _Linear( {o}, hidden_size * head_size, embedding_size, false, name + ".o_proj"); diff --git a/examples/main_vit.cpp b/examples/main_vit.cpp index b095564b..ba7b0ac2 100644 --- a/examples/main_vit.cpp +++ b/examples/main_vit.cpp @@ -1089,7 +1089,7 @@ NetTensor *Attention(NetTensor * x, int embedded_size, int hidden_size, int head qk = *qk/std::sqrt(hidden_size); // qk = _Scale( {qk}, 1.0F / std::sqrt(hidden_size), 0.0F, false, name + ".scale"); // qk = _Causalmask( {qk}, name + ".mask"); - qk = _Softmax( {qk}, DIMENSION, name + ".softmax"); + qk = _Softmax( {qk}, DIMENSION, false, name + ".softmax"); auto *o = _Matmul( {qk, v}, false, false, name + ".qkv"); o = o->view(-1, 1, -1, hidden_size * head_size); o = _Linear( {o}, hidden_size * head_size, embedded_size, true, name + ".output.dense"); diff --git a/include/Types.hpp b/include/Types.hpp index 4d0fc0ab..a6851f10 100644 --- a/include/Types.hpp +++ b/include/Types.hpp @@ -56,6 +56,10 @@ enum DataType { MLLM_TYPE_I8, MLLM_TYPE_I16, MLLM_TYPE_I32, + MLLM_TYPE_Q4_0_4_4=19, + MLLM_TYPE_Q4_0_4_8=20, + MLLM_TYPE_Q4_0_8_8=21, + MLLM_TYPE_Q8_0_4_4, MLLM_TYPE_COUNT, }; enum ChlType { @@ -147,6 +151,8 @@ enum RoPEType { * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. 
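The new MLLM_TYPE_Q4_0_4_4 / Q4_0_4_8 / Q4_0_8_8 / Q8_0_4_4 entries name repacked variants of q4_0 and q8_0 in which several blocks are interleaved for the AArch64 GEMM kernels; the block structs and DataTypeSize cases below spell out the layout. A quick size check (sketch only, assuming QK4_0 == 32 as elsewhere in this header) shows the repacking keeps bytes per weight unchanged and only rearranges them:

```cpp
// Sketch: q4_0 stores one fp16 delta + 16 nibble bytes per 32 weights.
// block_q4_0x4 packs 4 such blocks (4 deltas, then 64 interleaved bytes),
// so 72 bytes cover 128 weights, the same 0.5625 bytes/weight as plain q4_0.
#include <cstdio>

int main() {
    const int    QK4_0 = 32;
    const size_t fp16  = 2;
    const size_t q4_0   = fp16 + QK4_0 / 2;           // 18 bytes per 32 weights
    const size_t q4_0x4 = 4 * fp16 + 4 * (QK4_0 / 2); // 72 bytes per 128 weights
    std::printf("q4_0: %zu B / 32 w, q4_0x4: %zu B / 128 w\n", q4_0, q4_0x4);
    return 0;
}
```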
*/ +// #define LLAMAFILE_SGEMM + #if defined(__ARM_NEON) && !defined(_MSC_VER) typedef __fp16 mllm_fp16_t; #else @@ -223,6 +229,39 @@ typedef struct { #pragma pack() static_assert(sizeof(block_q8_K) == sizeof(float) + QK_K + QK_K / 16 * sizeof(int16_t), "wrong q8_K block size/padding"); + +#pragma pack(1) +typedef struct { + mllm_fp16_t d[4]; // deltas for 4 q4_0 blocks + uint8_t qs[QK4_0 * 2]; // nibbles / quants for 4 q4_0 blocks +} block_q4_0x4; +#pragma pack() +static_assert(sizeof(block_q4_0x4) == 4 * sizeof(mllm_fp16_t) + QK4_0 * 2, "wrong q4_0x4 block size/padding"); + +#pragma pack(1) +typedef struct { + mllm_fp16_t d[8]; // deltas for 8 q4_0 blocks + uint8_t qs[QK4_0 * 4]; // nibbles / quants for 8 q4_0 blocks +} block_q4_0x8; +#pragma pack() +static_assert(sizeof(block_q4_0x8) == 8 * sizeof(mllm_fp16_t) + QK4_0 * 4, "wrong q4_0x8 block size/padding"); + +#pragma pack(1) +typedef struct { + mllm_fp16_t d[4]; // deltas for 4 q8_0 blocks + int8_t qs[QK8_0 * 4]; // quants for 4 q8_0 blocks +} block_q8_0x4; +#pragma pack() +static_assert(sizeof(block_q8_0x4) == 4 * sizeof(mllm_fp16_t) + QK8_0 * 4, "wrong q8_0x4 block size/padding"); + +#pragma pack(1) +typedef struct { + mllm_fp16_t d[8]; // deltas for 8 q8_0 blocks + int8_t qs[QK8_0 * 8]; // quants for 8 q8_0 blocks +} block_q8_0x8; +#pragma pack() +static_assert(sizeof(block_q8_0x8) == 8 * sizeof(mllm_fp16_t) + QK8_0 * 8, "wrong q8_0x8 block size/padding"); + // static string DataTypeName(DataType dataType) { @@ -251,6 +290,14 @@ static string DataTypeName(DataType dataType) { return "Q4_1"; case MLLM_TYPE_Q8_1: return "Q8_1"; + case MLLM_TYPE_Q4_0_4_4: + return "Q4_0_4_4"; + case MLLM_TYPE_Q4_0_4_8: + return "Q4_0_4_8"; + case MLLM_TYPE_Q4_0_8_8: + return "Q4_0_8_8"; + case MLLM_TYPE_Q8_0_4_4: + return "Q8_0_4_4"; case MLLM_TYPE_COUNT: return "COUNT"; default: @@ -281,6 +328,15 @@ static size_t DataTypeSize(DataType dtype, int count = 1) { return (sizeof(block_q8_K)) * count / (QK_K); case MLLM_TYPE_Q4_1: case MLLM_TYPE_Q8_1: + return -1; + case MLLM_TYPE_Q4_0_4_4: + return (sizeof(block_q4_0x4)) * count / (QK4_0 * 4); + case MLLM_TYPE_Q4_0_4_8: + return (sizeof(block_q4_0x8)) * count / (QK4_0 * 8); + case MLLM_TYPE_Q4_0_8_8: + return (sizeof(block_q4_0x8)) * count / (QK4_0 * 8); + case MLLM_TYPE_Q8_0_4_4: + return (sizeof(block_q8_0x4)) * count / (QK8_0 * 4); case MLLM_TYPE_COUNT: return 0; default: diff --git a/src/Layer.hpp b/src/Layer.hpp index 52f6a6b9..98986a72 100644 --- a/src/Layer.hpp +++ b/src/Layer.hpp @@ -5,6 +5,7 @@ #ifndef OPERATION_H #define OPERATION_H +#include #include #include #include @@ -608,13 +609,24 @@ inline std::map ACT_FN = { class Softmax final : public Layer { public: + Softmax() = default; explicit Softmax(Chl axis, std::string name) { param_["axis"] = axis; init(std::move(name), OpType::SOFTMAX); } + explicit Softmax(Chl axis, bool do_causal_mask, std::string name) { + param_["axis"] = axis; + param_["do_causal_mask"] = do_causal_mask; + init(std::move(name), OpType::SOFTMAX); + } Tensor &operator()(Tensor &input) { return _1I1O_OP(input); } + Tensor &operator()(Tensor &input, int axis_classes) { + auto axis_classes_tensor = Tensor(1, 1, 1, 1, backend_, true); + axis_classes_tensor.setDataAt(0,0,0,0,(float)axis_classes); + return _3I1OO1_OP(input, axis_classes_tensor, axis_classes_tensor); + } }; class Embedding final : public Layer { @@ -631,12 +643,18 @@ class Embedding final : public Layer { class Causalmask final : public Layer { public: + Causalmask() = default; explicit Causalmask(std::string 
name) { init(std::move(name), OpType::CAUSALMASK); } Tensor &operator()(Tensor &input) { return _1I1O_OP(input); } + Tensor &operator()(Tensor &input0, int kvcache_seq) { + auto kvcache_seq_tensor = Tensor(1, 1, 1, 1, backend_, true); + kvcache_seq_tensor.setDataAt(0,0,0,0,(float)kvcache_seq); + return _3I1OO1_OP(input0, kvcache_seq_tensor, kvcache_seq_tensor); + } }; class SlidingWindowMask final : public Layer { @@ -676,6 +694,7 @@ class RoPE final : public Layer { class KVCache final : public Layer { public: + KVCache() = default; explicit KVCache(int cache_max, std::string name) { param_["n_rep"] = 1; param_["cache_max"] = cache_max; @@ -689,6 +708,9 @@ class KVCache final : public Layer { Tensor &operator()(Tensor &input) { return _1I1O_OP(input); } + int getCacheSeqLen(){ + return op_->getCacheSeqLen(); + } }; class LayerNorm final : public Layer { diff --git a/src/Module.hpp b/src/Module.hpp index ea4e955d..4ad6ff6b 100644 --- a/src/Module.hpp +++ b/src/Module.hpp @@ -84,10 +84,12 @@ class Module { operator()(tmps, args); break; } catch (const std::exception& e) { - if("bad any_cast" != e.what()) { +#if not defined(__ARM_NEON) + if(std::string("bad any_cast") != e.what()) { std::cerr << e.what() << std::endl; exit(0); } +#endif } catch (...) { std::cerr << "load error" << std::endl; exit(0); diff --git a/src/Op.hpp b/src/Op.hpp index fcd6c4ed..73ac11c6 100644 --- a/src/Op.hpp +++ b/src/Op.hpp @@ -3,7 +3,9 @@ // #define DEBUGPRINT #include "Tensor.hpp" #include "Types.hpp" +#include #include +#include #include "ParamLoader.hpp" #include "Timing.hpp" using std::function; @@ -113,6 +115,12 @@ class Op { type_ = type; } + virtual int getCacheSeqLen(){ + assert(type_ == OpType::KVCACHE); + std::cout << "only for KVCache" << std::endl; + return -1; + } + private: Backend *backend_; vector inputs_; diff --git a/src/Tensor.hpp b/src/Tensor.hpp index 98b845e7..2df29717 100644 --- a/src/Tensor.hpp +++ b/src/Tensor.hpp @@ -343,6 +343,59 @@ class Tensor { } } + int memshape(int index) const { + if(master_tensor_) { + if(master_tensor_->master_tensor_) { + return master_tensor_->master_tensor_->shape_[index]; + }else { + return master_tensor_->shape_[index]; + } + }else{ + return shape(index); + } + } + int sequence_skip_dim() const { + if(master_tensor_) { + if(master_tensor_->master_tensor_) { + auto shape = master_tensor_->master_tensor_->shape_; + if (master_tensor_->master_tensor_->ctype_ == BSHD) { + return shape[3]*shape[2]; + } else if (master_tensor_->master_tensor_->ctype_ == BHDS) { + return shape[3]; + } else { + std::cout << "sequence_skip_dim() only support for BSHD and BHDS" << std::endl; + return -1; + } + }else { + auto shape = master_tensor_->shape_; + if (master_tensor_->ctype_ == BSHD) { + return shape[3]*shape[2]; + } else if (master_tensor_->ctype_ == BHDS) { + return shape[3]; + } else { + std::cout << "sequence_skip_dim() only support for BSHD and BHDS" << std::endl; + return -1; + } + } + }else{ + if (ctype_ == BSHD) { + return shape_[3]*shape_[2]; + } else if (ctype_ == BHDS) { + return shape_[3]; + } else if (ctype_ == BDHS) { + return shape_[3]*shape_[2]; + } else if (ctype_ == DBHS) { + return shape_[3]*shape_[2]; + } else if (ctype_ == SBHD) { + return shape_[3]*shape_[2]; + } else { + std::cout << "sequence_skip_dim() only support for BSHD and BHDS" << std::endl; + return -1; + } + // return shape_[3]*shape_[2]; + } + } + /** * \brief obtain the raw pointer to the first address where tensor stores data. 
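KVCache layers now expose getCacheSeqLen(), implemented as a virtual on Op with a warning default so that only the KV-cache op returns something meaningful; callers use it to tell the fused softmax how many cached positions precede the current chunk. A stripped-down sketch of the pattern (class names simplified for illustration, not the real headers):

```cpp
// Sketch of the accessor pattern: a base-class virtual with a loud default,
// overridden only by the KV-cache op, and forwarded by the front-end layer.
#include <iostream>

struct OpBase {
    virtual ~OpBase() = default;
    virtual int getCacheSeqLen() {                 // default: not a KV cache
        std::cout << "only for KVCache" << std::endl;
        return -1;
    }
};

struct KVCacheOp final : OpBase {
    int cache_seq_len_ = 0;                        // advanced as tokens are appended
    int getCacheSeqLen() override { return cache_seq_len_; }
};

struct KVCacheLayer {                              // stand-in for Layer::KVCache
    OpBase *op_ = nullptr;
    int getCacheSeqLen() { return op_->getCacheSeqLen(); }
};
```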
* \return the pointer(void *) to the first address where tensor stores data. diff --git a/src/backends/cpu/CPUCausalMask.cpp b/src/backends/cpu/CPUCausalMask.cpp index 2c7c5468..db4cea74 100644 --- a/src/backends/cpu/CPUCausalMask.cpp +++ b/src/backends/cpu/CPUCausalMask.cpp @@ -11,7 +11,7 @@ CPUCausalMask::CPUCausalMask(Backend *bn, string opName, int threadCount) : thre ErrorCode CPUCausalMask::reshape(vector> inputs, vector> outputs) { //std::cout << "CPUMask reshape" << std::endl; - assert(inputs.size() == 1); + // assert(inputs.size() == 1); assert(outputs.size() == 1); outputs[0]->reshape(inputs[0]->batch(), inputs[0]->head(), inputs[0]->sequence(), inputs[0]->dimension()); return Op::reshape(inputs, outputs); @@ -23,12 +23,20 @@ ErrorCode CPUCausalMask::execute(vector> inputs, vectorhead(); int sequence = inputs[0]->sequence(); int dimension = inputs[0]->dimension(); - int old_dim = dimension - sequence; + // memset(outputs[0]->hostPtr(),-INFINITY,outputs[0]->count() * sizeof(float)); + int old_dim = 0; + if (inputs.size()>1) { + old_dim = (int)inputs[1]->dataAt(0,0,0,0)-sequence; + }else{ +#ifndef LLAMAFILE_SGEMM + old_dim = dimension - sequence; +#endif + } +#pragma omp parallel for collapse(4) num_threads(thread_count) for (int n = 0; n < batch_size; ++n) { for (int h = 0; h < head_num; ++h) { for (int s = 0; s < sequence; ++s) { - #pragma omp parallel for num_threads(thread_count) - for (int d = 0; d < dimension; ++d) { + for (int d = 0; d < inputs[0]->dimension(); ++d) { if (d > s + old_dim) { outputs[0]->setDataAt({n, h, s, d}, -INFINITY); } @@ -47,7 +55,7 @@ ErrorCode CPUCausalMask::execute(vector> inputs, vector> inputs, vector> outputs) { - assert(inputs.size() == 1); + // assert(inputs.size() == 1); assert(outputs.size() == 1); if(inputs[0]->masterTensor() == nullptr) { inputs[0]->free(); // TODO remove diff --git a/src/backends/cpu/CPUElasticLinear.cpp b/src/backends/cpu/CPUElasticLinear.cpp index cfd877ba..b2e9e4e5 100644 --- a/src/backends/cpu/CPUElasticLinear.cpp +++ b/src/backends/cpu/CPUElasticLinear.cpp @@ -1,6 +1,6 @@ #include "CPUElasticLinear.hpp" -#include "compute/MatmulElastic.hpp" +#include "compute/Matmul.hpp" namespace mllm { @@ -67,6 +67,8 @@ ErrorCode CPUElasticLinear::execute(vector> inputs, vectorcount() == 0) { return Op::execute(inputs, outputs); } + mat_mul_elastic(inputs[0].get(), &weight_, outputs[0].get(), support_bias_, &bias_, activate_input_dim,activate_output_dim, false, true, thread_count); + /* // std::cout << name() << " CPUElasticLinear()" << std::endl; switch (weight_.dtype()) { case MLLM_TYPE_F32: { @@ -89,6 +91,7 @@ ErrorCode CPUElasticLinear::execute(vector> inputs, vector> inputs, vectorbatch(), inputs[0]->head()*n_rep_, cache_limit_, inputs[0]->dimension()); cache_.setName(name() + ".Cache"); cache_.alloc(); +#ifdef KVCache_TYPE_16 + memset(cache_.hostPtr(),0,cache_.count() * sizeof(mllm_fp16_t)); +#else + memset(cache_.hostPtr(),0,cache_.count() * sizeof(float)); +#endif cache_seq_len_ = 0; } - - outputs[0]->reshape(inputs[0]->batch(), inputs[0]->head()*n_rep_, inputs[0]->sequence() + cache_seq_len_, inputs[0]->dimension()); - if(inputs[0]->sequence() + cache_seq_len_ >cache_limit_){ - std::cerr<<"\n[ERROR]: Current tokens exceed cache limit: "<sequence() + cache_seq_len_<<">"<sequence() + cache_seq_len_; +#ifdef LLAMAFILE_SGEMM + if(sequence%n_pack != 0) + sequence = ((sequence + (n_pack-1)) / n_pack) * n_pack; +#endif + outputs[0]->reshape(inputs[0]->batch(), inputs[0]->head()*n_rep_, sequence, inputs[0]->dimension()); + 
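When LLAMAFILE_SGEMM is enabled, the KV-cache reshape above rounds the exposed sequence length up to a multiple of n_pack so the packed GEMM kernels always see full tiles; the cache buffer is zero-initialised beforehand, and the softmax is handed the real position count, so the padded tail appears intended to stay inert. The rounding is the usual ceiling-to-multiple idiom:

```cpp
// Sketch of the padding arithmetic used in CPUKVCache::reshape above.
inline int round_up_to_multiple(int sequence, int n_pack) {
    return ((sequence + n_pack - 1) / n_pack) * n_pack;
}
// e.g. round_up_to_multiple(37, 16) == 48, round_up_to_multiple(64, 16) == 64
```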
if(sequence >cache_limit_){ + std::cerr<<"\n[ERROR]: Current tokens exceed cache limit: "<"<"< namespace mllm { @@ -70,7 +71,9 @@ ErrorCode CPULinear::execute(vector> inputs, vectorcount() == 0) { return Op::execute(inputs, outputs); } + mat_mul(inputs[0].get(), &weight_, outputs[0].get(), support_bias_, &bias_, false, true, thread_count); // std::cout << name() << " CPULinear()" << std::endl; + /* switch (weight_.dtype()) { case MLLM_TYPE_F32: { mat_mul_fp32(inputs[0].get(), &weight_, outputs[0].get(), support_bias_, &bias_, false, true, thread_count); @@ -92,6 +95,7 @@ ErrorCode CPULinear::execute(vector> inputs, vector> inputs, vector> inputs, vector> outputs) { assert(inputs[0]->dtype() == MLLM_TYPE_F32); + mat_mul(inputs[0].get(), inputs[1].get(), outputs[0].get(), false, nullptr, transpose0_, transpose1_, thread_count); // assert(inputs[1]->dtype() == MLLM_TYPE_F32); + /* switch (inputs[1]->dtype()) { case MLLM_TYPE_F32: { mat_mul_fp32(inputs[0].get(), inputs[1].get(), outputs[0].get(), false, nullptr, transpose0_, transpose1_, thread_count); @@ -77,6 +79,7 @@ ErrorCode CPUMatmul::execute(vector> inputs, vector> inputs, vectordimension(); int seq = input->sequence(); int head = input->head(); +#pragma omp parallel for collapse(3) num_threads(thread_count) for (int h = 0; h < head; h++) { for (int n = 0; n < batch; n++) { for (int s = 0; s < seq; s++) { diff --git a/src/backends/cpu/CPUPredictor.cpp b/src/backends/cpu/CPUPredictor.cpp index c852fe1a..abede558 100644 --- a/src/backends/cpu/CPUPredictor.cpp +++ b/src/backends/cpu/CPUPredictor.cpp @@ -62,6 +62,9 @@ ErrorCode CPUPredictor::execute(vector> inputs, vectorcount() == 0){ return Op::execute(inputs, outputs); } + mat_mul(x.get(), &up_, &hidden_, false, nullptr, false, true, thread_count); + mat_mul(&hidden_, &down_, o.get(), false, nullptr, false, true, thread_count); + /* switch (up_.dtype()) { // TODO: add support to more type case MLLM_TYPE_F32: { mat_mul_fp32(x.get(), &up_, &hidden_, false, nullptr, false, true, thread_count); @@ -77,6 +80,7 @@ ErrorCode CPUPredictor::execute(vector> inputs, vector #include "CPURMSNorm.hpp" #include "Tensor.hpp" +#include "Timing.hpp" +#include "compute/VecDot.hpp" namespace mllm { @@ -31,39 +33,42 @@ ErrorCode CPURMSNorm::execute(vector> inputs, vectordimension(); int seq = input->sequence(); int head = input->head(); +#pragma omp parallel for collapse(3) num_threads(thread_count) for (int h = 0; h < head; h++) { for (int n = 0; n < batch; n++) { for (int s = 0; s < seq; s++) { double sum_squares = 0.0F; // sum - // #pragma omp parallel for reduction(+ : sum_squares) num_threads(thread_count) for (int d = 0; d < dim; d++) { float value = input->dataAt(n, h, s, d); sum_squares += (double)value * value; } const float mean = sum_squares / dim; const float rms = 1.0f / sqrtf(mean + epsilon_); -// use memset to set the value of the memory block -#pragma omp parallel for num_threads(thread_count) + + memcpy( outputs[0]->ptrAt(n, h, s, 0), + inputs[0]->ptrAt(n, h, s, 0), + dim * sizeof(float)); + vec_scale_f32(dim, outputs[0]->ptrAt(n, h, s, 0), rms); + } + } + } + +#pragma omp parallel for collapse(4) num_threads(thread_count) + for (int h = 0; h < head; h++) { + for (int n = 0; n < batch; n++) { + for (int s = 0; s < seq; s++) { for (int d = 0; d < dim; d++) { - float value = input->dataAt(n, h, s, d); float weight = weight_.dataAt(0, 0, 0, d); - float output = value * rms; if (add_unit_offset_) { - output = output * (1 + weight); + *outputs[0]->ptrAt(n, h, s,d) *= (1 + weight); } else { - 
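The RMSNorm rewrite above first copies the row, scales it by the reciprocal RMS with vec_scale_f32, and only then multiplies by the (optionally unit-offset) weight in a second parallel pass. For reference, a plain scalar version of the same computation (sketch only; the real code vectorises the scale step):

```cpp
// Scalar reference for the reorganised CPURMSNorm:
//   y = x * rsqrt(mean(x^2) + eps) * w        (or * (1 + w) with the unit offset)
#include <cmath>

static void rms_norm_row(const float *x, const float *w, float *y, int dim,
                         float eps, bool add_unit_offset) {
    double sum_squares = 0.0;
    for (int d = 0; d < dim; ++d) sum_squares += (double)x[d] * x[d];
    const float rms = 1.0f / sqrtf((float)(sum_squares / dim) + eps);
    for (int d = 0; d < dim; ++d) {
        const float scaled = x[d] * rms;
        y[d] = add_unit_offset ? scaled * (1.0f + w[d]) : scaled * w[d];
    }
}
```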
output = output * weight; + *outputs[0]->ptrAt(n, h, s,d) *= (weight); } - outputs[0]->setDataAt(n, h, s, d, output); } } } } - // input->printData(); - // weight_.printData(); - // outputs[0]->printData(); - - // std::cout << name() << " CPURMSNorm()" << std::endl; return Op::execute(inputs, outputs); } ErrorCode CPURMSNorm::load(AbstructLoader &loader) { diff --git a/src/backends/cpu/CPURoPE.cpp b/src/backends/cpu/CPURoPE.cpp index 1277c282..2b9e2932 100644 --- a/src/backends/cpu/CPURoPE.cpp +++ b/src/backends/cpu/CPURoPE.cpp @@ -1,7 +1,10 @@ #include "CPURoPE.hpp" +#include "Timing.hpp" #include "Types.hpp" +#include #include +#include namespace mllm { @@ -107,15 +110,204 @@ ErrorCode CPURoPE::reshape(vector> inputs, vector input, shared_ptr output){ + auto out_dtype = output->dtype(); + int partial_dimension = (input->dimension()) * partial_rotary_factor_; +#pragma omp parallel for collapse(4) num_threads(thread_count) + for (int n = 0; n < input->batch(); ++n) { + for (int h = 0; h < input->head(); ++h) { + for (int s = 0; s < input->sequence(); ++s) { // sequance + for (int d = 0; d < partial_dimension; d+=2) { + float in_value = input->dataAt(n, h, s, d); + float in_value_2 = input->dataAt(n, h, s, d + 1); + float sin_value = sin_[s + h_cnt_][d]; + float cos_value = cos_[s + h_cnt_][d]; + auto value = in_value * cos_value - in_value_2 * sin_value; + auto value2 = in_value * sin_value + in_value_2 * cos_value; + if (out_dtype == MLLM_TYPE_F32) { + output->setDataAt(n, h, s, d, value); + output->setDataAt(n, h, s, d+1, value2); + } else if (out_dtype == MLLM_TYPE_F16) { + output->setDataAt(n, h, s, d, MLLM_FP32_TO_FP16(value)); + output->setDataAt(n, h, s, d+1, MLLM_FP32_TO_FP16(value2)); + } + } + } + } + } +} +void CPURoPE::rope_hf(shared_ptr input, shared_ptr output){ + auto out_dtype = output->dtype(); + int partial_dimension = (input->dimension()) * partial_rotary_factor_; + int half = (int)(partial_dimension / 2); + assert(partial_dimension%2==0); + if(output->ctype() == BSHD){ + if (out_dtype == MLLM_TYPE_F32){ +#pragma omp parallel for collapse(4) num_threads(thread_count) + for (int n = 0; n < input->batch(); ++n) { + for (int h = 0; h < input->head(); ++h) { + for (int s = 0; s < input->sequence(); ++s) { // sequance + for (int d = 0; d < partial_dimension/2; ++d) { + auto v = input->ptrAt(n, h, s, d); + auto o = output->ptrAt(n, h, s, d); + float in_value = v[0]; + float in_value_2 = v[half]; + float sin_value = sin_[s + h_cnt_][d]; + float cos_value = cos_[s + h_cnt_][d]; + auto value = in_value * cos_value - in_value_2 * sin_value; + auto value2 = in_value * sin_value + in_value_2 * cos_value; + o[0] = value; + o[half] = value2; + } + } + } + } + }else if(out_dtype == MLLM_TYPE_F16){ +#pragma omp parallel for collapse(4) num_threads(thread_count) + for (int n = 0; n < input->batch(); ++n) { + for (int h = 0; h < input->head(); ++h) { + for (int s = 0; s < input->sequence(); ++s) { // sequance + for (int d = 0; d < partial_dimension/2; ++d) { + auto v = input->ptrAt(n, h, s, d); + auto o = output->ptrAt(n, h, s, d); + float in_value = v[0]; + float in_value_2 = v[half]; + float sin_value = sin_[s + h_cnt_][d]; + float cos_value = cos_[s + h_cnt_][d]; + auto value = in_value * cos_value - in_value_2 * sin_value; + auto value2 = in_value * sin_value + in_value_2 * cos_value; + o[0] = MLLM_FP32_TO_FP16(value); + o[half] = MLLM_FP32_TO_FP16(value2); + } + } + } + } + } + return; + } +#pragma omp parallel for collapse(4) num_threads(thread_count) + for (int n = 0; n < 
input->batch(); ++n) { + for (int h = 0; h < input->head(); ++h) { + for (int s = 0; s < input->sequence(); ++s) { // sequance + for (int d = 0; d < partial_dimension/2; ++d) { + float in_value = input->dataAt(n, h, s, d); + float in_value_2 = input->dataAt(n, h, s, d + partial_dimension / 2); + float sin_value = sin_[s + h_cnt_][d]; + float cos_value = cos_[s + h_cnt_][d]; + auto value = in_value * cos_value - in_value_2 * sin_value; + auto value2 = in_value * sin_value + in_value_2 * cos_value; + if (out_dtype == MLLM_TYPE_F32) { + output->setDataAt(n, h, s, d, value); + output->setDataAt(n, h, s, d+ partial_dimension / 2, value2); + } else if (out_dtype == MLLM_TYPE_F16) { + output->setDataAt(n, h, s, d, MLLM_FP32_TO_FP16(value)); + output->setDataAt(n, h, s, d+ partial_dimension / 2, MLLM_FP32_TO_FP16(value2)); + } + } + } + } + } +} +void CPURoPE::rope_permission(shared_ptr input, shared_ptr output){ + auto out_dtype = output->dtype(); + int partial_dimension = (input->dimension()) * partial_rotary_factor_; +#pragma omp parallel for collapse(4) num_threads(thread_count) + for (int n = 0; n < input->batch(); ++n) { + for (int h = 0; h < input->head(); ++h) { + for (int s = 0; s < input->sequence(); ++s) { // sequance + for (int d = 0; d < partial_dimension; ++d) { + float in_value = input->dataAt(n, h, s, d); + float in_value_2; + float sin_value = sin_[s + h_cnt_][d]; + float cos_value = cos_[s + h_cnt_][d]; + if (d < partial_dimension / 4) { + in_value_2 = -input->dataAt(n, h, s, d + partial_dimension / 4); + auto value = in_value * cos_value + in_value_2 * sin_value; + if (out_dtype == MLLM_TYPE_F32) { + output->setDataAt(n, h, s, d, value); + } else if (out_dtype == MLLM_TYPE_F16) { + output->setDataAt(n, h, s, d, MLLM_FP32_TO_FP16(value)); + } + } else if (d < (partial_dimension / 2)) { + in_value_2 = input->dataAt(n, h, s, d - partial_dimension / 4); + auto value = in_value * cos_value + in_value_2 * sin_value; + if (out_dtype == MLLM_TYPE_F32) { + output->setDataAt(n, h, s, d, value); + } else if (out_dtype == MLLM_TYPE_F16) { + output->setDataAt(n, h, s, d, MLLM_FP32_TO_FP16(value)); + } + } else { + if (out_dtype == MLLM_TYPE_F32) { + output->setDataAt(n, h, s, d, in_value); + } else if (out_dtype == MLLM_TYPE_F16) { + output->setDataAt(n, h, s, d, MLLM_FP32_TO_FP16(in_value)); + } + } + } + } + } + } +} +void CPURoPE::rope_mla(shared_ptr input, shared_ptr output){ + auto out_dtype = output->dtype(); + int partial_dimension = (input->dimension()) * partial_rotary_factor_; +#pragma omp parallel for collapse(4) num_threads(thread_count) + for (int n = 0; n < input->batch(); ++n) { + for (int h = 0; h < input->head(); ++h) { + for (int s = 0; s < input->sequence(); ++s) { // sequance + for (int d = 0; d < partial_dimension; ++d) { + int half_dim = input->dimension() / 2; + float in_value = input->dataAt(n, h, s, d); + if (d < half_dim) { + in_value = input->dataAt(n, h, s, d * 2); + } else { + in_value = input->dataAt(n, h, s, 2 *(d - half_dim)+1); + } + float in_value_2; + if (d < half_dim) { + in_value_2 = -input->dataAt(n, h, s, 2 *d+1); + } else { + in_value_2 = input->dataAt(n, h, s, 2 *(d - half_dim)); + } + // no change + float sin_value = sin_[s + h_cnt_][d]; + float cos_value = cos_[s + h_cnt_][d]; + auto value = in_value * cos_value + in_value_2 * sin_value; + if (out_dtype == MLLM_TYPE_F32) { + output->setDataAt(n, h, s, d, value); + } else if (out_dtype == MLLM_TYPE_F16) { + output->setDataAt(n, h, s, d, MLLM_FP32_TO_FP16(value)); + } + } + } + } + } +} + ErrorCode 
CPURoPE::execute(vector> inputs, vector> outputs) { auto &input = inputs[0]; auto &output = outputs[0]; auto out_dtype = output->dtype(); int partial_dimension = (input->dimension()) * partial_rotary_factor_; + // auto start_t = mllm_time_us(); + if (pose_type_ == LLAMAROPE) { + rope_llama(input, output); + } else if (pose_type_ == HFHUBROPE) { + rope_hf(input, output); + } else if (pose_type_ == PERSIMMONROPE) { + rope_permission(input, output); + } else if (pose_type_ == MLAROPE) { + rope_mla(input, output); + } else { + std::cerr << "RoPE type error" << std::endl; + + } + /* +#pragma omp parallel for collapse(4) num_threads(thread_count) for (int n = 0; n < input->batch(); ++n) { for (int h = 0; h < input->head(); ++h) { for (int s = 0; s < input->sequence(); ++s) { // sequance -#pragma omp parallel for num_threads(thread_count) for (int d = 0; d < partial_dimension; ++d) { if (pose_type_ == LLAMAROPE) { float in_value = input->dataAt(n, h, s, d); @@ -141,23 +333,23 @@ ErrorCode CPURoPE::execute(vector> inputs, vectordataAt(n, h, s, d + partial_dimension / 4); auto value = in_value * cos_value + in_value_2 * sin_value; - if (output->dtypeAt(n, h, s, d) == MLLM_TYPE_F32) { + if (out_dtype == MLLM_TYPE_F32) { output->setDataAt(n, h, s, d, value); - } else if (output->dtypeAt(n, h, s, d) == MLLM_TYPE_F16) { + } else if (out_dtype == MLLM_TYPE_F16) { output->setDataAt(n, h, s, d, MLLM_FP32_TO_FP16(value)); } } else if (d < (partial_dimension / 2)) { in_value_2 = input->dataAt(n, h, s, d - partial_dimension / 4); auto value = in_value * cos_value + in_value_2 * sin_value; - if (output->dtypeAt(n, h, s, d) == MLLM_TYPE_F32) { + if (out_dtype == MLLM_TYPE_F32) { output->setDataAt(n, h, s, d, value); - } else if (output->dtypeAt(n, h, s, d) == MLLM_TYPE_F16) { + } else if (out_dtype == MLLM_TYPE_F16) { output->setDataAt(n, h, s, d, MLLM_FP32_TO_FP16(value)); } } else { - if (output->dtypeAt(n, h, s, d) == MLLM_TYPE_F32) { + if (out_dtype == MLLM_TYPE_F32) { output->setDataAt(n, h, s, d, in_value); - } else if (output->dtypeAt(n, h, s, d) == MLLM_TYPE_F16) { + } else if (out_dtype == MLLM_TYPE_F16) { output->setDataAt(n, h, s, d, MLLM_FP32_TO_FP16(in_value)); } } @@ -166,7 +358,7 @@ ErrorCode CPURoPE::execute(vector> inputs, vectordataAt(n, h, s, d + partial_dimension / 2); - } else { + // } else { in_value_2 = input->dataAt(n, h, s, d - partial_dimension / 2); } float sin_value = sin_[s + h_cnt_][d]; @@ -195,9 +387,9 @@ ErrorCode CPURoPE::execute(vector> inputs, vectordtypeAt(n, h, s, d) == MLLM_TYPE_F32) { + if (out_dtype == MLLM_TYPE_F32) { output->setDataAt(n, h, s, d, value); - } else if (output->dtypeAt(n, h, s, d) == MLLM_TYPE_F16) { + } else if (out_dtype == MLLM_TYPE_F16) { output->setDataAt(n, h, s, d, MLLM_FP32_TO_FP16(value)); } } else { @@ -207,19 +399,22 @@ ErrorCode CPURoPE::execute(vector> inputs, vectorsequence(); if (h_cnt_ > pos_max_) { h_cnt_ = 0; } +#pragma omp parallel for collapse(4) num_threads(thread_count) for (int n = 0; n < input->batch(); ++n) { for (int h = 0; h < input->head(); ++h) { for (int s = 0; s < input->sequence(); ++s) { -#pragma omp parallel for num_threads(thread_count) for (int d = partial_dimension; d < input->dimension(); ++d) { - if (output->dtypeAt(n, h, s, d) == MLLM_TYPE_F32) { + if (out_dtype == MLLM_TYPE_F32) { output->setDataAt(n, h, s, d, input->dataAt(n, h, s, d)); - } else if (output->dtypeAt(n, h, s, d) == MLLM_TYPE_F16) { + } else if (out_dtype == MLLM_TYPE_F16) { output->setDataAt(n, h, s, d, MLLM_FP32_TO_FP16(input->dataAt(n, h, s, d))); 
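execute() now simply dispatches on pose_type_ to the per-style helpers added above. The main difference between them is how dimensions are paired before the rotation: rope_llama rotates adjacent pairs (d, d+1) while rope_hf rotates (d, d+half), the HuggingFace rotate-half layout. A scalar sketch of the two pairings (helper names invented for illustration; the sin/cos index bookkeeping is simplified):

```cpp
// Sketch: both RoPE variants apply the same 2-D rotation; they only differ
// in which two elements of the row form a pair.
static void rotate_pair(float &a, float &b, float c, float s) {
    const float a2 = a * c - b * s;
    const float b2 = a * s + b * c;
    a = a2; b = b2;
}

static void rope_llama_row(float *x, const float *cosv, const float *sinv, int dim) {
    for (int d = 0; d < dim; d += 2)
        rotate_pair(x[d], x[d + 1], cosv[d], sinv[d]);        // adjacent pairing
}

static void rope_hf_row(float *x, const float *cosv, const float *sinv, int dim) {
    const int half = dim / 2;
    for (int d = 0; d < half; ++d)
        rotate_pair(x[d], x[d + half], cosv[d], sinv[d]);      // rotate-half pairing
}
```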
} } diff --git a/src/backends/cpu/CPURoPE.hpp b/src/backends/cpu/CPURoPE.hpp index 75813358..058a2746 100644 --- a/src/backends/cpu/CPURoPE.hpp +++ b/src/backends/cpu/CPURoPE.hpp @@ -32,6 +32,12 @@ class CPURoPE final : public Op { int ishape; int thread_count = 4; float partial_rotary_factor_ = 1; + + + void rope_llama(shared_ptr input, shared_ptr output); + void rope_hf(shared_ptr input, shared_ptr output); + void rope_permission(shared_ptr input, shared_ptr output); + void rope_mla(shared_ptr input, shared_ptr output); }; class CPURoPECreator : public CPUBackend::Creator { diff --git a/src/backends/cpu/CPUSiLU.cpp b/src/backends/cpu/CPUSiLU.cpp index 881b6e40..729121e7 100644 --- a/src/backends/cpu/CPUSiLU.cpp +++ b/src/backends/cpu/CPUSiLU.cpp @@ -1,6 +1,7 @@ #include "CPUSiLU.hpp" #include +#include "compute/ActivationFunction.hpp" namespace mllm { diff --git a/src/backends/cpu/CPUSoftMax.cpp b/src/backends/cpu/CPUSoftMax.cpp index 5a4269a7..7dcfcb20 100644 --- a/src/backends/cpu/CPUSoftMax.cpp +++ b/src/backends/cpu/CPUSoftMax.cpp @@ -1,29 +1,16 @@ #include "CPUSoftMax.hpp" #include +#include "Tensor.hpp" #include "quantize/Quantize.hpp" -#include "compute/VecDot.hpp" +#include "compute/ActivationFunction.hpp" namespace mllm { -//static mllm_fp16_t table_exp_f16[1 << 16]; -//static bool init_table_exp_f16_flag = false; -//void init_table_exp_f16() { -// mllm_fp16_t ii; -// for (int i = 0; i < (1 << 16); ++i) { -// uint16_t ui = i; -// memcpy(&ii, &ui, sizeof(ii)); -// const float f = MLLM_COMPUTE_FP16_TO_FP32(ii); -// table_exp_f16[i] = MLLM_FP32_TO_FP16(expf(f)); -// // float val = MLLM_FP16_TO_FP32(expf(f)); -// // std::cout<> inputs, vector> outputs) { // std::cout << name() << " CPUSoftMax reshape" << std::endl; - assert(inputs.size() == 1); + // assert(inputs.size() == 1); outputs[0]->reshape(inputs[0]->batch(), inputs[0]->head(), inputs[0]->sequence(), inputs[0]->dimension()); // outputs[0]->setDtype(activationDtype()); return Op::reshape(inputs, outputs); } -inline static void vec_scale_f32(const int n, float *y, const float v) { - const int np = (n & ~(MLLM_F32_STEP - 1)); - - MLLM_F32_VEC vx = MLLM_F32_VEC_SET1(v); - - MLLM_F32_VEC ay[MLLM_F32_ARR]; - - for (int i = 0; i < np; i += MLLM_F32_STEP) { - for (int j = 0; j < MLLM_F32_ARR; j++) { - ay[j] = MLLM_F32_VEC_LOAD(y + i + j * MLLM_F32_EPR); - ay[j] = MLLM_F32_VEC_MUL(ay[j], vx); - - MLLM_F32_VEC_STORE(y + i + j * MLLM_F32_EPR, ay[j]); - } - } - - // leftovers - for (int i = np; i < n; ++i) { - y[i] *= v; - } - // for (int i = 0; i < n; ++i) { - // y[i] *= v; - // } -} ErrorCode CPUSoftMax::execute(vector> inputs, vector> outputs) { // std::cout << name() << " CPUSoftMax()" << std::endl; auto &input = inputs[0]; auto &output = outputs[0]; - + int num_classes_in = -1; + int old_dim = 0; + if (inputs.size()>1) { + num_classes_in = (int)inputs[1]->dataAt(0,0,0,0); + old_dim = num_classes_in -input->sequence(); + }else{ +#ifndef LLAMAFILE_SGEMM + old_dim = input->dimension() - input->sequence(); +#endif + } + memset(output->hostPtr(),0,output->count() * sizeof(float)); if (axis_ == DIMENSION) { + int num_classes = num_classes_in>0? 
num_classes_in:input->dimension(); // 获取类别数量 +#pragma omp parallel for collapse(3) num_threads(thread_count) for (int n = 0; n < input->batch(); ++n) { - #pragma omp parallel for num_threads(thread_count) for (int h = 0; h < input->head(); ++h) { for (int s = 0; s < input->sequence(); ++s) { - int num_classes = input->dimension(); // 获取类别数量 + int masked_num_classes = num_classes; + if(do_causal_mask_ && input->sequence()>1){ + masked_num_classes = s+1+old_dim; + } float max = -INFINITY; - // #pragma omp parallel for num_threads(thread_count) - for (int j = 0; j < num_classes; ++j) { + for (int j = 0; j < masked_num_classes; ++j) { max = MAX(max, input->dataAt(n, h, s, j)); } float *dp = output->ptrAt(n, h, s, 0); - double sum = 0.0; - uint16_t scvt; - for (int i = 0; i < num_classes; i++) { - if (input->dataAt(n, h, s, i) == -INFINITY) { - dp[i] = 0.0F; - } else { - mllm_fp16_t tmp = MLLM_FP32_TO_FP16(input->dataAt(n, h, s, i) - max); - memcpy(&scvt, &tmp, sizeof(scvt)); - const float val = MLLM_FP16_TO_FP32(table_exp_f16[scvt]); - sum += (double)val; - dp[i] = val; - } - } - + float sum = mllm_vec_soft_max_f32(masked_num_classes, dp, input->ptrAt(n, h, s, 0), max); sum = 1.0 / sum; - vec_scale_f32(num_classes, dp, sum); + vec_scale_f32(masked_num_classes, dp, sum); } } } } else { +#pragma omp parallel for collapse(4) num_threads(thread_count) for (int n = 0; n < input->batch(); ++n) { for (int c = 0; c < input->head(); ++c) { for (int h = 0; h < input->sequence(); ++h) { - // #pragma omp parallel for num_threads(thread_count) for (int w = 0; w < input->dimension(); ++w) { std::vector index = {n, c, h, w}; int num_classes = 0; //input->shape(axis_); // 获取类别数量 @@ -119,8 +82,8 @@ ErrorCode CPUSoftMax::execute(vector> inputs, vectordimension(); break; - } + num_classes = num_classes_in>0? 
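In the DIMENSION branch above, the number of columns each row normalises over is s + 1 + old_dim, where old_dim is the number of positions already sitting in the KV cache ahead of the current chunk (taken from the optional second input, or from dimension minus sequence on the legacy path). A tiny sketch of that bookkeeping (hypothetical helper, just to spell the rule out):

```cpp
// Sketch: how many attention columns row s of the current chunk may use.
// Masking only applies in prefill (chunk_len > 1); single-token decode
// always sees the full width handed in by the caller.
inline int visible_columns(int s, int old_dim, int chunk_len,
                           int full_width, bool do_causal_mask) {
    if (!do_causal_mask || chunk_len <= 1) return full_width;
    return old_dim + s + 1;   // cached prefix plus the causal part of the chunk
}
```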
num_classes_in:num_classes; float max = -INFINITY; for (int j = 0; j < num_classes; ++j) { index[axis_] = j; @@ -146,12 +109,14 @@ ErrorCode CPUSoftMax::execute(vector> inputs, vectorsetDataAt(index, softmax_value); } + // for (int i = num_classes; i < input->dimension(); i++) { + // output->setDataAt(index, 0); + // } } } } } } - return Op::execute(inputs, outputs); } diff --git a/src/backends/cpu/CPUSoftMax.hpp b/src/backends/cpu/CPUSoftMax.hpp index af564ced..7a35e508 100644 --- a/src/backends/cpu/CPUSoftMax.hpp +++ b/src/backends/cpu/CPUSoftMax.hpp @@ -8,7 +8,7 @@ namespace mllm { class CPUSoftMax final : public Op { public: - CPUSoftMax(Backend *bn, string opName, int axis, int threadCount); + CPUSoftMax(Backend *bn, string opName, int axis, bool do_causal_mask, int threadCount); virtual ~CPUSoftMax() = default; virtual ErrorCode reshape(vector> inputs, vector> outputs) override; virtual ErrorCode execute(vector> inputs, vector> outputs) override; @@ -16,13 +16,15 @@ class CPUSoftMax final : public Op { private: int axis_ = 0; int thread_count = 4; + bool do_causal_mask_=false; }; class CPUSoftMaxCreator : public CPUBackend::Creator { public: virtual Op *create(OpParam op_param, Backend *bn, string name, int threadCount) const { int axis = op_param["axis"]; - return new CPUSoftMax(bn, name, axis, threadCount); + bool do_causal_mask = op_param["do_causal_mask"]; + return new CPUSoftMax(bn, name, axis, do_causal_mask, threadCount); } }; } // namespace mllm diff --git a/src/backends/cpu/CPUSwaKVCache.cpp b/src/backends/cpu/CPUSwaKVCache.cpp deleted file mode 100644 index 5f962b27..00000000 --- a/src/backends/cpu/CPUSwaKVCache.cpp +++ /dev/null @@ -1,53 +0,0 @@ -/** - * @file CPUSwaKVCache.cpp - * @author Chenghua Wang (chenghua.wang.edu@gmail.com) - * @version 0.1 - * @date 2024-05-01 - * - * @copyright Copyright (c) 2024 - * - */ -#include "CPUSwaKVCache.hpp" - -namespace mllm { - -CPUSwaKVCache::CPUSwaKVCache(Backend *bn, string opName, int n_rep, int window_size, int threadCount) : - n_rep(n_rep), - window_size(window_size), - thread_count(threadCount), - Op(bn, opName) { - cache.setBackend(bn); - cache.setDtype(MLLM_TYPE_F16); -} - -ErrorCode CPUSwaKVCache::reshape(vector> inputs, vector> outputs) { - assert(inputs.size() == 1); - assert(outputs.size() == 1); - if (cache_seq_len < 0) { - cache.reshape(inputs[0]->batch(), inputs[0]->head() * n_rep, window_size, inputs[0]->dimension()); - cache.setName(name() + ".Cache"); - cache.alloc(); - cache_seq_len = 0; - } - - int sequence_len = (inputs[0]->sequence() + cache_seq_len) > window_size ? 
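CPUSoftMax now takes the extra do_causal_mask flag through its constructor, and the creator reads it back out of the op's parameter map, keeping the front-end Softmax layer in Layer.hpp and the backend op in sync. A sketch of that round trip (assuming OpParam maps parameter names to floats, as the existing axis parameter suggests; graphs that never set the key read back 0, i.e. the old unmasked behaviour):

```cpp
// Sketch: the front end stores the bool as a float in the OpParam map,
// and the CPU creator converts it back when constructing the op.
#include <map>
#include <string>

using OpParam = std::map<std::string, float>;    // assumption for this sketch

OpParam make_softmax_param(int axis, bool do_causal_mask) {
    OpParam p;
    p["axis"] = static_cast<float>(axis);
    p["do_causal_mask"] = do_causal_mask ? 1.0f : 0.0f;
    return p;
}

// creator side, mirroring CPUSoftMaxCreator above:
//   int  axis           = op_param["axis"];
//   bool do_causal_mask = op_param["do_causal_mask"];   // absent key -> 0 -> false
//   return new CPUSoftMax(bn, name, axis, do_causal_mask, threadCount);
```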
window_size : (inputs[0]->sequence() + cache_seq_len); - outputs[0]->reshape(inputs[0]->batch(), inputs[0]->head() * n_rep, sequence_len, inputs[0]->dimension()); - return Op::reshape(inputs, outputs); -} - -ErrorCode CPUSwaKVCache::execute(vector> inputs, vector> outputs) { - return Op::execute(inputs, outputs); -} - -ErrorCode CPUSwaKVCache::load(AbstructLoader &loader) { - return Op::load(loader); -} - -ErrorCode CPUSwaKVCache::free(vector> inputs, vector> outputs) { - return Op::free(inputs, outputs); -} - -ErrorCode CPUSwaKVCache::setUp(vector> inputs, vector> outputs) { - return Op::setUp(inputs, outputs); -} -} // namespace mllm diff --git a/src/backends/cpu/CPUSwaKVCache.hpp b/src/backends/cpu/CPUSwaKVCache.hpp deleted file mode 100644 index 7c2bba0b..00000000 --- a/src/backends/cpu/CPUSwaKVCache.hpp +++ /dev/null @@ -1,49 +0,0 @@ -/** - * @file CPUSwaKVCache.hpp - * @author Chenghua Wang (chenghua.wang.edu@gmail.com) - * @brief KV Cache for sliding window attention. - * @version 0.1 - * @date 2024-05-01 - * - * @copyright Copyright (c) 2024 - * - */ -#ifndef MLLM_CPUSWAKVCACHE_H -#define MLLM_CPUSWAKVCACHE_H - -#include "Op.hpp" -#include "CPUBackend.hpp" - -namespace mllm { - -class CPUSwaKVCache final : public Op { -public: - CPUSwaKVCache(Backend *bn, string opName, int n_rep, int window_size, int threadCount); - ~CPUSwaKVCache() override = default; - ErrorCode reshape(vector> inputs, vector> outputs) override; - ErrorCode execute(vector> inputs, vector> outputs) override; - ErrorCode load(AbstructLoader &loader) override; - ErrorCode free(vector> inputs, vector> outputs) override; - ErrorCode setUp(vector> inputs, vector> outputs) override; - -private: - int n_rep = 1; - int window_size; - int thread_count = 4; - int cache_seq_len = -1; - int cur_cache_pos = -1; - Tensor cache; -}; - -class CPUSwaKVCacheCreator : public CPUBackend::Creator { -public: - Op *create(OpParam op_param, Backend *bn, string name, int threadCount) const override { - int n_rep = (int)op_param["n_rep"]; - int window_size = (int)op_param["window_size"]; - return new CPUSwaKVCache(bn, name, n_rep, window_size, threadCount); - } -}; - -} // namespace mllm - -#endif // MLLM_CPUSWAKVCACHE_H diff --git a/src/backends/cpu/CPUTensorFunction.hpp b/src/backends/cpu/CPUTensorFunction.hpp index ecfdd1ac..6b021b5e 100644 --- a/src/backends/cpu/CPUTensorFunction.hpp +++ b/src/backends/cpu/CPUTensorFunction.hpp @@ -67,6 +67,8 @@ class CPUmmFunction: public TensorFunction { void execute(vector outputs, vector inputs, vector args) override { bool isSame = std::equal(inputs[0]->chls().begin(), inputs[0]->chls().end(), inputs[1]->chls().begin()); assert(inputs[0]->dtype() == MLLM_TYPE_F32); + mat_mul(inputs[0], inputs[1], outputs[0], false, nullptr, false, isSame, CPUBackend::cpu_threads); + /* switch (inputs[1]->dtype()) { case MLLM_TYPE_F32: { mat_mul_fp32(inputs[0], inputs[1], outputs[0], false, nullptr, false, isSame, CPUBackend::cpu_threads); @@ -79,6 +81,7 @@ class CPUmmFunction: public TensorFunction { default: break; } + */ } }; diff --git a/src/backends/cpu/compute/ActivationFunction.cpp b/src/backends/cpu/compute/ActivationFunction.cpp new file mode 100644 index 00000000..10482883 --- /dev/null +++ b/src/backends/cpu/compute/ActivationFunction.cpp @@ -0,0 +1,82 @@ +#include "ActivationFunction.hpp" + +namespace mllm { + +void mllm_vec_silu_f32(const int n, float * y, const float * x) { + int i = 0; +#if defined(__AVX512F__) && defined(__AVX512DQ__) + for (; i + 15 < n; i += 16) { + _mm512_storeu_ps(y + i, 
mllm_v_silu(_mm512_loadu_ps(x + i))); + } +#elif defined(__AVX2__) && defined(__FMA__) + for (; i + 7 < n; i += 8) { + _mm256_storeu_ps(y + i, mllm_v_silu(_mm256_loadu_ps(x + i))); + } +#elif defined(__SSE2__) + for (; i + 3 < n; i += 4) { + _mm_storeu_ps(y + i, mllm_v_silu(_mm_loadu_ps(x + i))); + } +#elif defined(__ARM_NEON) && defined(__aarch64__) + for (; i + 3 < n; i += 4) { + vst1q_f32(y + i, mllm_v_silu(vld1q_f32(x + i))); + } +#endif + for (; i < n; ++i) { + y[i] = mllm_silu_f32(x[i]); + } +} + +float mllm_vec_soft_max_f32(const int n, float * y, const float * x, float max) { + int i = 0; + float sum = 0; +#if defined(__AVX512F__) && defined(__AVX512DQ__) + for (; i + 15 < n; i += 16) { + __m512 val = mllm_v_expf(_mm512_sub_ps(_mm512_loadu_ps(x + i), + _mm512_set1_ps(max))); + _mm512_storeu_ps(y + i, val); + sum += (float)_mm512_reduce_add_ps(val); + } +#elif defined(__AVX2__) && defined(__FMA__) + for (; i + 7 < n; i += 8) { + __m256 val = mllm_v_expf(_mm256_sub_ps(_mm256_loadu_ps(x + i), + _mm256_set1_ps(max))); + _mm256_storeu_ps(y + i, val); + __m128 val2 = _mm_add_ps(_mm256_extractf128_ps(val, 1), + _mm256_castps256_ps128(val)); + val2 = _mm_add_ps(val2, _mm_movehl_ps(val2, val2)); + val2 = _mm_add_ss(val2, _mm_movehdup_ps(val2)); + sum += (float)_mm_cvtss_f32(val2); + } +#elif defined(__SSE2__) + for (; i + 3 < n; i += 4) { + __m128 val = mllm_v_expf(_mm_sub_ps(_mm_loadu_ps(x + i), + _mm_set1_ps(max))); + _mm_storeu_ps(y + i, val); +#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) + val = _mm_add_ps(val, _mm_movehl_ps(val, val)); + val = _mm_add_ss(val, _mm_movehdup_ps(val)); +#else + __m128 tmp = _mm_shuffle_ps(val, val, _MM_SHUFFLE(2, 3, 0, 1)); + val = _mm_add_ps(val, tmp); + tmp = _mm_movehl_ps(tmp, val); + val = _mm_add_ss(val, tmp); +#endif + sum += (float)_mm_cvtss_f32(val); + } +#elif defined(__ARM_NEON) && defined(__aarch64__) + for (; i + 3 < n; i += 4) { + float32x4_t val = mllm_v_expf(vsubq_f32(vld1q_f32(x + i), + vdupq_n_f32(max))); + vst1q_f32(y + i, val); + sum += (float)vaddvq_f32(val); + } +#endif + for (; i < n; ++i) { + float val = expf(x[i] - max); + sum += (float)val; + y[i] = val; + } + return sum; +} + +} // namespace mllm \ No newline at end of file diff --git a/src/backends/cpu/compute/ActivationFunction.hpp b/src/backends/cpu/compute/ActivationFunction.hpp new file mode 100644 index 00000000..be748f48 --- /dev/null +++ b/src/backends/cpu/compute/ActivationFunction.hpp @@ -0,0 +1,207 @@ + +#ifndef ACTFUNC_HPP +#define ACTFUNC_HPP + +#include "quantize/Quantize.hpp" +#include "compute/VecDot.hpp" +namespace mllm { + +#if defined(__ARM_NEON) && defined(__aarch64__) + +// adapted from arm limited optimized routine +// the maximum error is 1.45358 plus 0.5 ulps +// numbers above 88.38 will flush to infinity +// numbers beneath -103.97 will flush to zero +inline static float32x4_t mllm_v_expf(float32x4_t x) { + const float32x4_t r = vdupq_n_f32(0x1.8p23f); + const float32x4_t z = vfmaq_f32(r, x, vdupq_n_f32(0x1.715476p+0f)); + const float32x4_t n = vsubq_f32(z, r); + const float32x4_t b = vfmsq_f32(vfmsq_f32(x, n, vdupq_n_f32(0x1.62e4p-1f)), n, + vdupq_n_f32(0x1.7f7d1cp-20f)); + const uint32x4_t e = vshlq_n_u32(vreinterpretq_u32_f32(z), 23); + const float32x4_t k = vreinterpretq_f32_u32(vaddq_u32(e, vreinterpretq_u32_f32(vdupq_n_f32(1)))); + const uint32x4_t c = vcagtq_f32(n, vdupq_n_f32(126)); + const float32x4_t u = vmulq_f32(b, b); + const float32x4_t j = vfmaq_f32( + vmulq_f32(vdupq_n_f32(0x1.ffffecp-1f), b), + 
vfmaq_f32(vfmaq_f32(vdupq_n_f32(0x1.fffdb6p-2f), vdupq_n_f32(0x1.555e66p-3f), b), + vfmaq_f32(vdupq_n_f32(0x1.573e2ep-5f), vdupq_n_f32(0x1.0e4020p-7f), b), u), u); + if (!vpaddd_u64(vreinterpretq_u64_u32(c))) + return vfmaq_f32(k, j, k); + const uint32x4_t d = vandq_u32(vclezq_f32(n), vdupq_n_u32(0x82000000)); + const float32x4_t s1 = vreinterpretq_f32_u32(vaddq_u32(d, vdupq_n_u32(0x7f000000))); + const float32x4_t s2 = vreinterpretq_f32_u32(vsubq_u32(e, d)); + return vbslq_f32(vcagtq_f32(n, vdupq_n_f32(192)), vmulq_f32(s1, s1), + vbslq_f32(c, vmulq_f32(vfmaq_f32(s2, s2, j), s1), vfmaq_f32(k, k, j))); +} + +// computes silu x/(1+exp(-x)) in single precision vector +inline static float32x4_t mllm_v_silu(float32x4_t x) { + const float32x4_t one = vdupq_n_f32(1.0f); + const float32x4_t zero = vdupq_n_f32(0.0f); + const float32x4_t neg_x = vsubq_f32(zero, x); + const float32x4_t exp_neg_x = mllm_v_expf(neg_x); + const float32x4_t one_plus_exp_neg_x = vaddq_f32(one, exp_neg_x); + return vdivq_f32(x, one_plus_exp_neg_x); +} + +#elif defined(__AVX512F__) && defined(__AVX512DQ__) + +// adapted from arm limited optimized routine +// the maximum error is 1.45358 plus 0.5 ulps +// numbers above 88.38 will flush to infinity +// numbers beneath -103.97 will flush to zero +inline static __m512 mllm_v_expf(__m512 x) { + const __m512 r = _mm512_set1_ps(0x1.8p23f); + const __m512 z = _mm512_fmadd_ps(x, _mm512_set1_ps(0x1.715476p+0f), r); + const __m512 n = _mm512_sub_ps(z, r); + const __m512 b = + _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.7f7d1cp-20f), + _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.62e4p-1f), x)); + const __mmask16 d = + _mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(192), _CMP_GT_OQ); + const __m512 u = _mm512_mul_ps(b, b); + const __m512 j = _mm512_fmadd_ps( + _mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_set1_ps(0x1.0e4020p-7f), b, + _mm512_set1_ps(0x1.573e2ep-5f)), + u, + _mm512_fmadd_ps(_mm512_set1_ps(0x1.555e66p-3f), b, + _mm512_set1_ps(0x1.fffdb6p-2f))), + u, + _mm512_fmadd_ps(_mm512_set1_ps(0x1.ffffecp-1f), b, _mm512_set1_ps(1.0F))); + const __m512 res = _mm512_scalef_ps(j, n); + if (_mm512_kortestz(d, d)) + return res; + const __m512 zero = _mm512_setzero_ps(); + const __m512 alt = _mm512_mask_blend_ps( + _mm512_cmp_ps_mask(n, zero, _CMP_LE_OQ), _mm512_set1_ps(INFINITY), zero); + return _mm512_mask_blend_ps(d, res, alt); +} + +// computes silu x/(1+exp(-x)) in single precision vector +inline static __m512 mllm_v_silu(__m512 x) { + const __m512 one = _mm512_set1_ps(1); + const __m512 zero = _mm512_setzero_ps(); + const __m512 neg_x = _mm512_sub_ps(zero, x); + const __m512 exp_neg_x = mllm_v_expf(neg_x); + const __m512 one_plus_exp_neg_x = _mm512_add_ps(one, exp_neg_x); + return _mm512_div_ps(x, one_plus_exp_neg_x); +} + +#elif defined(__AVX2__) && defined(__FMA__) + +// adapted from arm limited optimized routine +// the maximum error is 1.45358 plus 0.5 ulps +// numbers above 88.38 will flush to infinity +// numbers beneath -103.97 will flush to zero +inline static __m256 mllm_v_expf(__m256 x) { + const __m256 r = _mm256_set1_ps(0x1.8p23f); + const __m256 z = _mm256_fmadd_ps(x, _mm256_set1_ps(0x1.715476p+0f), r); + const __m256 n = _mm256_sub_ps(z, r); + const __m256 b = _mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.7f7d1cp-20f), + _mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.62e4p-1f), x)); + const __m256i e = _mm256_slli_epi32(_mm256_castps_si256(z), 23); + const __m256 k = _mm256_castsi256_ps( + _mm256_add_epi32(e, _mm256_castps_si256(_mm256_set1_ps(1)))); + const __m256i c = _mm256_castps_si256( + 
_mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n), + _mm256_set1_ps(126), _CMP_GT_OQ)); + const __m256 u = _mm256_mul_ps(b, b); + const __m256 j = _mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_set1_ps(0x1.0e4020p-7f), b, + _mm256_set1_ps(0x1.573e2ep-5f)), u, + _mm256_fmadd_ps(_mm256_set1_ps(0x1.555e66p-3f), b, + _mm256_set1_ps(0x1.fffdb6p-2f))), + u, _mm256_mul_ps(_mm256_set1_ps(0x1.ffffecp-1f), b)); + if (!_mm256_movemask_ps(_mm256_castsi256_ps(c))) + return _mm256_fmadd_ps(j, k, k); + const __m256i g = _mm256_and_si256( + _mm256_castps_si256(_mm256_cmp_ps(n, _mm256_setzero_ps(), _CMP_LE_OQ)), + _mm256_set1_epi32(0x82000000u)); + const __m256 s1 = + _mm256_castsi256_ps(_mm256_add_epi32(g, _mm256_set1_epi32(0x7f000000u))); + const __m256 s2 = _mm256_castsi256_ps(_mm256_sub_epi32(e, g)); + const __m256i d = _mm256_castps_si256( + _mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n), + _mm256_set1_ps(192), _CMP_GT_OQ)); + return _mm256_or_ps( + _mm256_and_ps(_mm256_castsi256_ps(d), _mm256_mul_ps(s1, s1)), + _mm256_andnot_ps( + _mm256_castsi256_ps(d), + _mm256_or_ps( + _mm256_and_ps(_mm256_castsi256_ps(c), + _mm256_mul_ps(_mm256_fmadd_ps(s2, j, s2), s1)), + _mm256_andnot_ps(_mm256_castsi256_ps(c), _mm256_fmadd_ps(k, j, k))))); +} + +// computes silu x/(1+exp(-x)) in single precision vector +inline static __m256 mllm_v_silu(__m256 x) { + const __m256 one = _mm256_set1_ps(1); + const __m256 zero = _mm256_setzero_ps(); + const __m256 neg_x = _mm256_sub_ps(zero, x); + const __m256 exp_neg_x = mllm_v_expf(neg_x); + const __m256 one_plus_exp_neg_x = _mm256_add_ps(one, exp_neg_x); + return _mm256_div_ps(x, one_plus_exp_neg_x); +} + +#elif defined(__SSE2__) // __AVX2__ / __ARM_NEON + +#if defined(__FMA__) +#define MADD128(x, y, z) _mm_fmadd_ps(x, y, z) +#define NMADD128(x, y, z) _mm_fnmadd_ps(x, y, z) +#else +#define MADD128(x, y, z) _mm_add_ps(_mm_mul_ps(x, y), z) +#define NMADD128(x, y, z) _mm_sub_ps(z, _mm_mul_ps(x, y)) +#endif + +// adapted from arm limited optimized routine +// the maximum error is 1.45358 plus 0.5 ulps +// numbers above 88.38 will flush to infinity +// numbers beneath -103.97 will flush to zero +inline static __m128 mllm_v_expf(__m128 x) { + const __m128 r = _mm_set1_ps(0x1.8p23f); + const __m128 z = MADD128(x, _mm_set1_ps(0x1.715476p+0f), r); + const __m128 n = _mm_sub_ps(z, r); + const __m128 b = + NMADD128(n, _mm_set1_ps(0x1.7f7d1cp-20f), NMADD128(n, _mm_set1_ps(0x1.62e4p-1f), x)); + const __m128i e = _mm_slli_epi32(_mm_castps_si128(z), 23); + const __m128 k = _mm_castsi128_ps(_mm_add_epi32(e, _mm_castps_si128(_mm_set1_ps(1)))); + const __m128i c = + _mm_castps_si128(_mm_cmpgt_ps(_mm_andnot_ps(_mm_set1_ps(-0.f), n), _mm_set1_ps(126))); + const __m128 u = _mm_mul_ps(b, b); + const __m128 j = + MADD128(MADD128(MADD128(_mm_set1_ps(0x1.0e4020p-7f), b, _mm_set1_ps(0x1.573e2ep-5f)), u, + MADD128(_mm_set1_ps(0x1.555e66p-3f), b, _mm_set1_ps(0x1.fffdb6p-2f))), + u, _mm_mul_ps(_mm_set1_ps(0x1.ffffecp-1f), b)); + if (!_mm_movemask_epi8(c)) + return MADD128(j, k, k); + const __m128i g = _mm_and_si128(_mm_castps_si128(_mm_cmple_ps(n, _mm_setzero_ps())), + _mm_set1_epi32(0x82000000u)); + const __m128 s1 = _mm_castsi128_ps(_mm_add_epi32(g, _mm_set1_epi32(0x7f000000u))); + const __m128 s2 = _mm_castsi128_ps(_mm_sub_epi32(e, g)); + const __m128i d = + _mm_castps_si128(_mm_cmpgt_ps(_mm_andnot_ps(_mm_set1_ps(-0.f), n), _mm_set1_ps(192))); + return _mm_or_ps( + _mm_and_ps(_mm_castsi128_ps(d), _mm_mul_ps(s1, s1)), + _mm_andnot_ps(_mm_castsi128_ps(d), + 
_mm_or_ps(_mm_and_ps(_mm_castsi128_ps(c), _mm_mul_ps(MADD128(s2, j, s2), s1)), + _mm_andnot_ps(_mm_castsi128_ps(c), MADD128(k, j, k))))); +} + +// computes silu x/(1+exp(-x)) in single precision vector +inline static __m128 mllm_v_silu(__m128 x) { + const __m128 one = _mm_set1_ps(1); + const __m128 zero = _mm_setzero_ps(); + const __m128 neg_x = _mm_sub_ps(zero, x); + const __m128 exp_neg_x = mllm_v_expf(neg_x); + const __m128 one_plus_exp_neg_x = _mm_add_ps(one, exp_neg_x); + return _mm_div_ps(x, one_plus_exp_neg_x); +} + +#endif // __ARM_NEON / __AVX2__ / __SSE2__ + +void mllm_vec_silu_f32(const int n, float * y, const float * x); + +float mllm_vec_soft_max_f32(const int n, float * y, const float * x, float max); + +} // namespace mllm +#endif //ACTFUNC_HPP \ No newline at end of file diff --git a/src/backends/cpu/compute/GEMM_AArch64.cpp b/src/backends/cpu/compute/GEMM_AArch64.cpp new file mode 100644 index 00000000..c51696d7 --- /dev/null +++ b/src/backends/cpu/compute/GEMM_AArch64.cpp @@ -0,0 +1,2211 @@ +#include "GEMM_AArch64.hpp" +#include "Types.hpp" +#include <math.h> +#include +#include +#include <stdlib.h> // for qsort +#include <assert.h> // for assert + + +int mllm_cpu_has_sve(void) { +#if defined(__ARM_FEATURE_SVE) + return 1; +#else + return 0; +#endif +} + +int mllm_cpu_has_matmul_int8(void) { +#if defined(__ARM_FEATURE_MATMUL_INT8) + return 1; +#else + return 0; +#endif +} + +// Functions to create the interleaved data layout formats + +// interleave 4 block_q4_0s in blocks of blck_size_interleave +// returns an interleaved block_q4_0x4 +// in the interleaved block_q4_0x4, place deltas for 4 block_q4_0 blocks +// first, then interleave quants from 4 block_q4_0s in blocks of blck_size_interleave +// +// - in : an array of block_q4_0 pointers +// - blck_size_interleave : the block_q4_0 quants bytes are interleaved in blocks of +// blck_size_interleave bytes +// - xor_mask : the mask to convert the nibbles in block_q4_0 quants bytes +// from bias offset form to pure sign form (this saves subtract +// operations during unpacking) +// +static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int blck_size_interleave, unsigned int xor_mask) { + block_q4_0x4 out; + + for (int i = 0; i < 4; i++) { + out.d[i] = in[i].d; + } + + for (int i = 0; i < QK4_0 * 2; i++) { + int src_offset = (i / (4 * blck_size_interleave)) * blck_size_interleave; + int src_id = (i % (4 * blck_size_interleave)) / blck_size_interleave; + src_offset += (i % blck_size_interleave); + + out.qs[i] = in[src_id].qs[src_offset] ^ xor_mask; + } + + return out; +} + +// interleave 8 block_q4_0s in blocks of blck_size_interleave +// returns an interleaved block_q4_0x8 +// in the interleaved block_q4_0x8, place deltas for 8 block_q4_0 blocks +// first, then interleave quants from 8 block_q4_0s in blocks of blck_size_interleave +static block_q4_0x8 make_block_q4_0x8(block_q4_0 * in, unsigned int blck_size_interleave, unsigned int xor_mask) { + block_q4_0x8 out; + + for (int i = 0; i < 8; i++) { + out.d[i] = in[i].d; + } + + for (int i = 0; i < QK4_0 * 4; i++) { + int src_offset = (i / (8 * blck_size_interleave)) * blck_size_interleave; + int src_id = (i % (8 * blck_size_interleave)) / blck_size_interleave; + src_offset += (i % blck_size_interleave); + + out.qs[i] = in[src_id].qs[src_offset] ^ xor_mask; + } + + return out; +} + +void quantize_q8_0_4x4(const float * __restrict x, void * __restrict vy, int64_t k) { + assert(QK8_0 == 32); + assert(k % QK8_0 == 0); + const int nb = k / QK8_0; + + block_q8_0x4 * __restrict y = (block_q8_0x4 
*) vy; + +#if defined(__ARM_NEON) + float32x4_t srcv[4][8]; + float id[4]; + + for (int i = 0; i < nb; i++) { + float32x4_t asrcv[8]; + float32x4_t amaxv[8]; + + for (int row_iter = 0; row_iter < 4; row_iter++) { + for (int j = 0; j < 8; j++) srcv[row_iter][j] = vld1q_f32(x + row_iter * k + i * 32 + 4 * j); + for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[row_iter][j]); + + for (int j = 0; j < 4; j++) amaxv[2 * j] = vmaxq_f32(asrcv[2 * j], asrcv[2 * j + 1]); + for (int j = 0; j < 2; j++) amaxv[4 * j] = vmaxq_f32(amaxv[4 * j], amaxv[4 * j + 2]); + for (int j = 0; j < 1; j++) amaxv[8 * j] = vmaxq_f32(amaxv[8 * j], amaxv[8 * j + 4]); + + const float amax = vmaxvq_f32(amaxv[0]); + + const float d = amax / ((1 << 7) - 1); + id[row_iter] = d ? 1.0f / d : 0.0f; + + y[i].d[row_iter] = MLLM_FP32_TO_FP16(d); + } + + for (int j = 0; j < 8; j++) { + float32x4_t v = vmulq_n_f32(srcv[0][j], id[0]); + int32x4_t vi = vcvtnq_s32_f32(v); + y[i].qs[16 * j + 0] = vgetq_lane_s32(vi, 0); + y[i].qs[16 * j + 1] = vgetq_lane_s32(vi, 1); + y[i].qs[16 * j + 2] = vgetq_lane_s32(vi, 2); + y[i].qs[16 * j + 3] = vgetq_lane_s32(vi, 3); + + v = vmulq_n_f32(srcv[1][j], id[1]); + vi = vcvtnq_s32_f32(v); + y[i].qs[16 * j + 4] = vgetq_lane_s32(vi, 0); + y[i].qs[16 * j + 5] = vgetq_lane_s32(vi, 1); + y[i].qs[16 * j + 6] = vgetq_lane_s32(vi, 2); + y[i].qs[16 * j + 7] = vgetq_lane_s32(vi, 3); + + v = vmulq_n_f32(srcv[2][j], id[2]); + vi = vcvtnq_s32_f32(v); + y[i].qs[16 * j + 8] = vgetq_lane_s32(vi, 0); + y[i].qs[16 * j + 9] = vgetq_lane_s32(vi, 1); + y[i].qs[16 * j + 10] = vgetq_lane_s32(vi, 2); + y[i].qs[16 * j + 11] = vgetq_lane_s32(vi, 3); + + v = vmulq_n_f32(srcv[3][j], id[3]); + vi = vcvtnq_s32_f32(v); + y[i].qs[16 * j + 12] = vgetq_lane_s32(vi, 0); + y[i].qs[16 * j + 13] = vgetq_lane_s32(vi, 1); + y[i].qs[16 * j + 14] = vgetq_lane_s32(vi, 2); + y[i].qs[16 * j + 15] = vgetq_lane_s32(vi, 3); + } + } +#else + // scalar + const int blck_size_interleave = 4; + float srcv[4][QK8_0]; + float id[4]; + + for (int i = 0; i < nb; i++) { + for (int row_iter = 0; row_iter < 4; row_iter++) { + float amax = 0.0f; // absolute max + + for (int j = 0; j < QK8_0; j++) { + srcv[row_iter][j] = x[row_iter * k + i * QK8_0 + j]; + amax = MAX(amax, fabsf(srcv[row_iter][j])); + } + + const float d = amax / ((1 << 7) - 1); + id[row_iter] = d ? 
1.0f / d : 0.0f; + + y[i].d[row_iter] = MLLM_FP32_TO_FP16(d); + } + + for (int j = 0; j < QK8_0 * 4; j++) { + int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave; + int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave; + src_offset += (j % blck_size_interleave); + + float x0 = srcv[src_id][src_offset] * id[src_id]; + y[i].qs[j] = roundf(x0); + } + } +#endif +} + +void quantize_q8_0_4x8(const float * __restrict x, void * __restrict vy, int64_t k) { + assert(QK8_0 == 32); + assert(k % QK8_0 == 0); + const int nb = k / QK8_0; + + block_q8_0x4 * __restrict y = (block_q8_0x4 *) vy; + +#if defined(__ARM_NEON) + float32x4_t srcv[4][8]; + float id[4]; + + for (int i = 0; i < nb; i++) { + float32x4_t asrcv[8]; + float32x4_t amaxv[8]; + + for (int row_iter = 0; row_iter < 4; row_iter++) { + for (int j = 0; j < 8; j++) srcv[row_iter][j] = vld1q_f32(x + row_iter * k + i * 32 + 4 * j); + for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[row_iter][j]); + + for (int j = 0; j < 4; j++) amaxv[2 * j] = vmaxq_f32(asrcv[2 * j], asrcv[2 * j + 1]); + for (int j = 0; j < 2; j++) amaxv[4 * j] = vmaxq_f32(amaxv[4 * j], amaxv[4 * j + 2]); + for (int j = 0; j < 1; j++) amaxv[8 * j] = vmaxq_f32(amaxv[8 * j], amaxv[8 * j + 4]); + + const float amax = vmaxvq_f32(amaxv[0]); + + const float d = amax / ((1 << 7) - 1); + id[row_iter] = d ? 1.0f / d : 0.0f; + + y[i].d[row_iter] = MLLM_FP32_TO_FP16(d); + } + + for (int j = 0; j < 4; j++) { + float32x4_t v = vmulq_n_f32(srcv[0][2 * j], id[0]); + int32x4_t vi = vcvtnq_s32_f32(v); + y[i].qs[32 * j + 0] = vgetq_lane_s32(vi, 0); + y[i].qs[32 * j + 1] = vgetq_lane_s32(vi, 1); + y[i].qs[32 * j + 2] = vgetq_lane_s32(vi, 2); + y[i].qs[32 * j + 3] = vgetq_lane_s32(vi, 3); + v = vmulq_n_f32(srcv[0][2 * j + 1], id[0]); + vi = vcvtnq_s32_f32(v); + y[i].qs[32 * j + 4] = vgetq_lane_s32(vi, 0); + y[i].qs[32 * j + 5] = vgetq_lane_s32(vi, 1); + y[i].qs[32 * j + 6] = vgetq_lane_s32(vi, 2); + y[i].qs[32 * j + 7] = vgetq_lane_s32(vi, 3); + + v = vmulq_n_f32(srcv[1][2 * j], id[1]); + vi = vcvtnq_s32_f32(v); + y[i].qs[32 * j + 8] = vgetq_lane_s32(vi, 0); + y[i].qs[32 * j + 9] = vgetq_lane_s32(vi, 1); + y[i].qs[32 * j + 10] = vgetq_lane_s32(vi, 2); + y[i].qs[32 * j + 11] = vgetq_lane_s32(vi, 3); + v = vmulq_n_f32(srcv[1][2 * j + 1], id[1]); + vi = vcvtnq_s32_f32(v); + y[i].qs[32 * j + 12] = vgetq_lane_s32(vi, 0); + y[i].qs[32 * j + 13] = vgetq_lane_s32(vi, 1); + y[i].qs[32 * j + 14] = vgetq_lane_s32(vi, 2); + y[i].qs[32 * j + 15] = vgetq_lane_s32(vi, 3); + + v = vmulq_n_f32(srcv[2][2 * j], id[2]); + vi = vcvtnq_s32_f32(v); + y[i].qs[32 * j + 16] = vgetq_lane_s32(vi, 0); + y[i].qs[32 * j + 17] = vgetq_lane_s32(vi, 1); + y[i].qs[32 * j + 18] = vgetq_lane_s32(vi, 2); + y[i].qs[32 * j + 19] = vgetq_lane_s32(vi, 3); + v = vmulq_n_f32(srcv[2][2 * j + 1], id[2]); + vi = vcvtnq_s32_f32(v); + y[i].qs[32 * j + 20] = vgetq_lane_s32(vi, 0); + y[i].qs[32 * j + 21] = vgetq_lane_s32(vi, 1); + y[i].qs[32 * j + 22] = vgetq_lane_s32(vi, 2); + y[i].qs[32 * j + 23] = vgetq_lane_s32(vi, 3); + + v = vmulq_n_f32(srcv[3][2 * j], id[3]); + vi = vcvtnq_s32_f32(v); + y[i].qs[32 * j + 24] = vgetq_lane_s32(vi, 0); + y[i].qs[32 * j + 25] = vgetq_lane_s32(vi, 1); + y[i].qs[32 * j + 26] = vgetq_lane_s32(vi, 2); + y[i].qs[32 * j + 27] = vgetq_lane_s32(vi, 3); + v = vmulq_n_f32(srcv[3][2 * j + 1], id[3]); + vi = vcvtnq_s32_f32(v); + y[i].qs[32 * j + 28] = vgetq_lane_s32(vi, 0); + y[i].qs[32 * j + 29] = vgetq_lane_s32(vi, 1); + y[i].qs[32 * j + 30] = vgetq_lane_s32(vi, 2); + y[i].qs[32 * j + 
31] = vgetq_lane_s32(vi, 3); + } + } +#else + // scalar + const int blck_size_interleave = 8; + float srcv[4][QK8_0]; + float id[4]; + + for (int i = 0; i < nb; i++) { + for (int row_iter = 0; row_iter < 4; row_iter++) { + float amax = 0.0f; // absolute max + + for (int j = 0; j < QK8_0; j++) { + srcv[row_iter][j] = x[row_iter * k + i * QK8_0 + j]; + amax = MAX(amax, fabsf(srcv[row_iter][j])); + } + + const float d = amax / ((1 << 7) - 1); + id[row_iter] = d ? 1.0f / d : 0.0f; + + y[i].d[row_iter] = MLLM_FP32_TO_FP16(d); + } + + for (int j = 0; j < QK8_0 * 4; j++) { + int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave; + int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave; + src_offset += (j % blck_size_interleave); + + float x0 = srcv[src_id][src_offset] * id[src_id]; + y[i].qs[j] = roundf(x0); + } + } +#endif +} + +void quantize_mat_q8_0(const float * __restrict x, void * __restrict vy, int64_t nrow, int64_t n_per_row, int64_t blck_size_interleave) { + assert(nrow == 4); + (void)nrow; + if (blck_size_interleave == 4) { + quantize_q8_0_4x4(x, vy, n_per_row); + } else if (blck_size_interleave == 8) { + quantize_q8_0_4x8(x, vy, n_per_row); + } else { + assert(false); + } +} + +static size_t quantize_q4_0_nr_bl(const float * __restrict src, void * __restrict dst, int64_t nrow, int64_t n_per_row, int nrows_interleaved, int blck_size_interleave) { + assert(n_per_row % QK4_0 == 0); + const int nb = n_per_row / QK4_0; + + void * out_ptr = NULL; + if (nrows_interleaved == 8) { + out_ptr = (block_q4_0x8 *) dst; + } + else if (nrows_interleaved == 4) { + out_ptr = (block_q4_0x4 *) dst; + } + assert(nrows_interleaved <= 8); + block_q4_0 dst_tmp[8]; + + for (int b = 0; b < (nrow * n_per_row); b += nrows_interleaved * n_per_row) { + + for (int64_t x = 0; x < nb; x++) { + + for (int i = 0; i < nrows_interleaved; i++ ) { + quantize_row_q4_0(src + b + i * n_per_row + x * QK4_0, (block_q4_0 *) dst_tmp + i, QK4_0); + } + + if (nrows_interleaved == 8) { + *(block_q4_0x8 *) out_ptr = make_block_q4_0x8(dst_tmp, blck_size_interleave, 0x88); + out_ptr = (block_q4_0x8 *) out_ptr + 1; + } + else if (nrows_interleaved == 4) { + *(block_q4_0x4 *) out_ptr = make_block_q4_0x4(dst_tmp, blck_size_interleave, 0x88); + out_ptr = (block_q4_0x4 *) out_ptr + 1; + } + } + } + + return ((nrow * n_per_row) / QK4_0 * sizeof(block_q4_0)); +} + +size_t quantize_q4_0_4x4(const float * __restrict src, void * __restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + if (!quant_weights) { + return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, 4, 4); + } + else { + assert(false); + return 0; + } +} + +size_t quantize_q4_0_4x8(const float * __restrict src, void * __restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + if (!quant_weights) { + return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, 4, 8); + } + else { + assert(false); + return 0; + } +} + +size_t quantize_q4_0_8x8(const float * __restrict src, void * __restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + if (!quant_weights) { + return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, 8, 8); + } + else { + assert(false); + return 0; + } +} + +void mllm_gemv_q4_0_4x4_q8_0(int n, float * __restrict s, size_t bs, const void * __restrict vx, const void * __restrict vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 4; + + assert (n % qk == 0); + assert (nc % ncols_interleaved == 0); + + 
(void)s; + (void)bs; + (void)vx; + (void)vy; + (void)nr; + (void)nc; + (void)nb; + (void)ncols_interleaved; + (void)blocklen; + +#if defined(__ARM_FEATURE_SVE) + if (svcntw() == 8) { + assert(!(mllm_cpu_has_sve() && (svcntw() == 8)) && + "__ARM_FEATURE_SVE defined, use the Q4_0_8_8 quantization format for optimal performance"); + } +#endif +#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) + assert(!(mllm_cpu_has_neon() && mllm_cpu_has_matmul_int8()) && + "__ARM_NEON and __ARM_FEATURE_MATMUL_INT8 defined, use the Q4_0_4_8 quantization format for optimal performance"); +#elif defined(__ARM_NEON) && defined(__aarch64__) + const void * b_ptr = vx; + const void * a_ptr = vy; + float * res_ptr = s; + + __asm__ __volatile__( + "movi v31.16b, #0x4\n" + "movi v30.16b, #0xf0\n" + "add %x[b_ptr], %x[b_ptr], #0x8\n" + "1:" // Column loop + "add x22, %x[a_ptr], #0x2\n" + "movi v29.16b, #0x0\n" + "mov x21, %x[nb]\n" + "2:" // Block loop + "ldr q28, [%x[b_ptr], #0x0]\n" + "ldr q27, [x22, #0x0]\n" + "movi v26.4s, #0x0\n" + "sub x20, x22, #0x2\n" + "ldr q25, [x22, #0x10]\n" + "ldr q24, [%x[b_ptr], #0x10]\n" + "sub x21, x21, #0x1\n" + "add x22, x22, #0x22\n" + "ldr q23, [%x[b_ptr], #0x20]\n" + "ldr q22, [%x[b_ptr], #0x30]\n" + "ld1r { v21.8h }, [x20]\n" + "ldr q20, [%x[b_ptr], #-0x8]\n" + "sshl v16.16b, v28.16b, v31.16b\n" + "and v28.16b, v28.16b, v30.16b\n" + "sshl v19.16b, v24.16b, v31.16b\n" + "and v24.16b, v24.16b, v30.16b\n" + "add %x[b_ptr], %x[b_ptr], #0x48\n" + "sshl v18.16b, v23.16b, v31.16b\n" + "and v23.16b, v23.16b, v30.16b\n" + ".inst 0x4f9be21a // sdot v26.4s, v16.16b, v27.4b[0]\n" + "sshl v17.16b, v22.16b, v31.16b\n" + "and v22.16b, v22.16b, v30.16b\n" + "fcvtl v21.4s, v21.4h\n" + "fcvtl v16.4s, v20.4h\n" + ".inst 0x4f99e39a // sdot v26.4s, v28.16b, v25.4b[0]\n" + "fmul v16.4s, v16.4s, v21.4s\n" + ".inst 0x4fbbe27a // sdot v26.4s, v19.16b, v27.4b[1]\n" + ".inst 0x4fb9e31a // sdot v26.4s, v24.16b, v25.4b[1]\n" + ".inst 0x4f9bea5a // sdot v26.4s, v18.16b, v27.4b[2]\n" + ".inst 0x4f99eafa // sdot v26.4s, v23.16b, v25.4b[2]\n" + ".inst 0x4fbbea3a // sdot v26.4s, v17.16b, v27.4b[3]\n" + ".inst 0x4fb9eada // sdot v26.4s, v22.16b, v25.4b[3]\n" + "scvtf v26.4s, v26.4s, #0x4\n" + "fmla v29.4s, v26.4s, v16.4s\n" + "cbnz x21, 2b\n" + "sub %x[nc], %x[nc], #0x4\n" + "str q29, [%x[res_ptr], #0x0]\n" + "add %x[res_ptr], %x[res_ptr], #0x10\n" + "cbnz %x[nc], 1b\n" + : [b_ptr] "+&r" (b_ptr), [res_ptr] "+&r" (res_ptr), [nc] "+&r" (nc) + : [a_ptr] "r" (a_ptr), [nb] "r" (nb) + : "memory", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22" + ); +#else + float sumf[4]; + int sumi; + + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); + sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4; + } + sumf[j] += sumi * MLLM_FP16_TO_FP32(b_ptr[l].d[j]) * MLLM_FP16_TO_FP32(a_ptr[l].d); + } + } + } + for (int j = 0; j < 
ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; + } +#endif +} + +void mllm_gemv_q4_0_4x8_q8_0(int n, float * __restrict s, size_t bs, const void * __restrict vx, const void * __restrict vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 8; + + assert (n % qk == 0); + assert (nc % ncols_interleaved == 0); + + (void)s; + (void)bs; + (void)vx; + (void)vy; + (void)nr; + (void)nc; + (void)nb; + (void)ncols_interleaved; + (void)blocklen; + +#if defined(__ARM_FEATURE_SVE) + if (svcntw() == 8) { + assert(!(mllm_cpu_has_sve() && (svcntw() == 8)) && + "__ARM_FEATURE_SVE defined, use the Q4_0_8_8 quantization format for optimal performance"); + } +#endif +#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) + const void * b_ptr = vx; + const void * a_ptr = vy; + float * res_ptr = s; + + __asm__ __volatile__( + "movi v2.16b, #0x4\n" + "movi v1.16b, #0xf0\n" + "add %x[b_ptr], %x[b_ptr], #0x8\n" + "1:" // Column loop + "add x23, %x[a_ptr], #0x2\n" + "movi v0.16b, #0x0\n" + "mov x22, %x[nb]\n" + "2:" // Block loop + "ldr q31, [%x[b_ptr], #0x0]\n" + "ldr q30, [%x[b_ptr], #0x10]\n" + "mov x21, x23\n" + "movi v29.4s, #0x0\n" + "ldr q28, [%x[b_ptr], #0x20]\n" + "ldr q27, [%x[b_ptr], #0x30]\n" + "movi v26.4s, #0x0\n" + "sub x20, x23, #0x2\n" + "ld1r { v25.8h }, [x20]\n" + "ldr q24, [%x[b_ptr], #-0x8]\n" + "sub x22, x22, #0x1\n" + "add x23, x23, #0x22\n" + "ld1r { v23.2d }, [x21], #0x8\n" + "sshl v22.16b, v31.16b, v2.16b\n" + "sshl v16.16b, v30.16b, v2.16b\n" + "add %x[b_ptr], %x[b_ptr], #0x48\n" + "ld1r { v21.2d }, [x21], #0x8\n" + "sshl v20.16b, v28.16b, v2.16b\n" + "sshl v19.16b, v27.16b, v2.16b\n" + "ld1r { v18.2d }, [x21], #0x8\n" + "ld1r { v17.2d }, [x21], #0x8\n" + "and v31.16b, v31.16b, v1.16b\n" + "and v30.16b, v30.16b, v1.16b\n" + ".inst 0x4e9796dd // sdot v29.4s, v22.16b, v23.16b\n" + ".inst 0x4e97961a // sdot v26.4s, v16.16b, v23.16b\n" + "and v28.16b, v28.16b, v1.16b\n" + "and v27.16b, v27.16b, v1.16b\n" + "fcvtl v25.4s, v25.4h\n" + "fcvtl v16.4s, v24.4h\n" + ".inst 0x4e95969d // sdot v29.4s, v20.16b, v21.16b\n" + ".inst 0x4e95967a // sdot v26.4s, v19.16b, v21.16b\n" + "fmul v16.4s, v16.4s, v25.4s\n" + ".inst 0x4e9297fd // sdot v29.4s, v31.16b, v18.16b\n" + ".inst 0x4e9297da // sdot v26.4s, v30.16b, v18.16b\n" + ".inst 0x4e91979d // sdot v29.4s, v28.16b, v17.16b\n" + ".inst 0x4e91977a // sdot v26.4s, v27.16b, v17.16b\n" + "addp v29.4s, v29.4s, v26.4s\n" + "scvtf v29.4s, v29.4s, #0x4\n" + "fmla v0.4s, v29.4s, v16.4s\n" + "cbnz x22, 2b\n" + "sub %x[nc], %x[nc], #0x4\n" + "str q0, [%x[res_ptr], #0x0]\n" + "add %x[res_ptr], %x[res_ptr], #0x10\n" + "cbnz %x[nc], 1b\n" + : [b_ptr] "+&r" (b_ptr), [res_ptr] "+&r" (res_ptr), [nc] "+&r" (nc) + : [a_ptr] "r" (a_ptr), [nb] "r" (nb) + : "memory", "v0", "v1", "v2", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22", "x23" + ); +#elif defined(__ARM_NEON) && defined(__aarch64__) + assert((mllm_cpu_has_sve() || mllm_cpu_has_matmul_int8()) && + "__ARM_FEATURE_SVE and __ARM_FEATURE_MATMUL_INT8 not defined, use the Q4_0_4_4 quantization format for optimal " + "performance"); +#else + float sumf[4]; + int sumi; + + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; + for (int l = 0; l < nb; l++) { + for 
(int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); + sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4; + } + sumf[j] += sumi * MLLM_FP16_TO_FP32(b_ptr[l].d[j]) * MLLM_FP16_TO_FP32(a_ptr[l].d); + } + } + } + for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; + } +#endif +} + +void mllm_gemv_q4_0_8x8_q8_0(int n, float * __restrict s, size_t bs, const void * __restrict vx, const void * __restrict vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 8; + const int blocklen = 8; + + assert (n % qk == 0); + assert (nc % ncols_interleaved == 0); + + (void)s; + (void)bs; + (void)vx; + (void)vy; + (void)nr; + (void)nc; + (void)nb; + (void)ncols_interleaved; + (void)blocklen; + +#if defined(__ARM_FEATURE_SVE) + if (svcntw() == 8) { + const void * b_ptr = vx; + const void * a_ptr = vy; + float * res_ptr = s; + + __asm__ __volatile__( + "ptrue p0.b\n" + "add %x[b_ptr], %x[b_ptr], #0x10\n" + "1:" // Column loop + "add x22, %x[a_ptr], #0x2\n" + "mov z31.b, #0x0\n" + "mov x21, %x[nb]\n" + "2:" // Block loop + "ld1b { z30.b }, p0/Z, [%x[b_ptr]]\n" + "ld1b { z29.b }, p0/Z, [%x[b_ptr], #1, MUL VL]\n" + "mov z28.s, #0x0\n" + "mov z27.s, #0x0\n" + "ld1rd { z26.d }, p0/Z, [x22]\n" + "ld1b { z25.b }, p0/Z, [%x[b_ptr], #2, MUL VL]\n" + "sub x20, x22, #0x2\n" + "sub x21, x21, #0x1\n" + "ld1b { z24.b }, p0/Z, [%x[b_ptr], #3, MUL VL]\n" + "ld1rd { z23.d }, p0/Z, [x22, #8]\n" + "lsl z22.b, z30.b, #0x4\n" + "lsl z16.b, z29.b, #0x4\n" + "and z30.b, z30.b, #0xf0\n" + "and z29.b, z29.b, #0xf0\n" + "ld1rd { z21.d }, p0/Z, [x22, #16]\n" + "ld1rd { z20.d }, p0/Z, [x22, #24]\n" + "lsl z19.b, z25.b, #0x4\n" + "and z25.b, z25.b, #0xf0\n" + "ld1rh { z17.h }, p0/Z, [x20]\n" + "ld1h { z18.s }, p0/Z, [%x[b_ptr], #-1, MUL VL]\n" + "sdot z28.s, z22.b, z26.b\n" + "sdot z27.s, z16.b, z26.b\n" + "lsl z16.b, z24.b, #0x4\n" + "add x22, x22, #0x22\n" + "and z24.b, z24.b, #0xf0\n" + "add %x[b_ptr], %x[b_ptr], #0x90\n" + "fcvt z17.s, p0/m, z17.h\n" + "fcvt z18.s, p0/m, z18.h\n" + "sdot z28.s, z19.b, z23.b\n" + "sdot z27.s, z16.b, z23.b\n" + "fmul z18.s, z18.s, z17.s\n" + "sdot z28.s, z30.b, z21.b\n" + "sdot z27.s, z29.b, z21.b\n" + "sdot z28.s, z25.b, z20.b\n" + "sdot z27.s, z24.b, z20.b\n" + "uzp1 z17.s, z28.s, z27.s\n" + "uzp2 z16.s, z28.s, z27.s\n" + "add z17.s, z17.s, z16.s\n" + "asr z17.s, z17.s, #0x4\n" + "scvtf z17.s, p0/m, z17.s\n" + "fmla z31.s, p0/M, z17.s, z18.s\n" + "cbnz x21, 2b\n" + "sub %x[nc], %x[nc], #0x8\n" + "st1w { z31.s }, p0, [%x[res_ptr]]\n" + "add %x[res_ptr], %x[res_ptr], #0x20\n" + "cbnz %x[nc], 1b\n" + : [b_ptr] "+&r" (b_ptr), [res_ptr] "+&r" (res_ptr), [nc] "+&r" (nc) + : [a_ptr] "r" (a_ptr), [nb] "r" (nb) + : "memory", "p0", "x20", "x21", "x22", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" + ); + return; + } + else if (mllm_cpu_has_neon() && mllm_cpu_has_matmul_int8()) { + assert((mllm_cpu_has_sve() && (svcntw() == 8)) && + "__ARM_FEATURE_SVE for vector size of 256-bits not defined, use the Q4_0_4_8 quantization format for optimal " + "performance"); + } + else if (mllm_cpu_has_neon()) { + assert(((mllm_cpu_has_sve() && 
(svcntw() == 8)) || mllm_cpu_has_matmul_int8()) && + "__ARM_FEATURE_SVE for vector size of 256-bits and __ARM_FEATURE_MATMUL_INT8 not defined, use the Q4_0_4_4 " + "quantization format for optimal performance"); + } +#endif +#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) + assert(mllm_cpu_has_sve() && + "__ARM_FEATURE_SVE not defined, use the Q4_0_4_8 quantization format for optimal performance"); +#elif defined(__ARM_NEON) && defined(__aarch64__) + assert((mllm_cpu_has_sve() || mllm_cpu_has_matmul_int8()) && + "__ARM_FEATURE_SVE and __ARM_FEATURE_MATMUL_INT8 not defined, use the Q4_0_4_4 quantization format for optimal " + "performance"); +#else + float sumf[8]; + int sumi; + + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); + sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4; + } + sumf[j] += sumi * MLLM_FP16_TO_FP32(b_ptr[l].d[j]) * MLLM_FP16_TO_FP32(a_ptr[l].d); + } + } + } + for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; + } +#endif +} + +void mllm_gemm_q4_0_4x4_q8_0(int n, float * __restrict s, size_t bs, const void * __restrict vx, const void * __restrict vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 4; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + (void)s; + (void)bs; + (void)vx; + (void)vy; + (void)nr; + (void)nc; + (void)nb; + (void)ncols_interleaved; + (void)blocklen; + +#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8) + if (svcntw() == 8) { + assert(!(mllm_cpu_has_sve() && (svcntw() == 8)) && + "__ARM_FEATURE_SVE defined, use the Q4_0_8_8 quantization format for optimal performance"); + } +#endif +#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) + assert(!(mllm_cpu_has_neon() && mllm_cpu_has_matmul_int8()) && + "__ARM_NEON and __ARM_FEATURE_MATMUL_INT8 defined, use the Q4_0_4_8 quantization format for optimal performance"); +#elif defined(__ARM_NEON) && defined(__aarch64__) + const void * b_ptr = vx; + const void * a_ptr = vy; + float * res_ptr = s; + size_t res_stride = bs * sizeof(float); + + __asm__ __volatile__( + "mov x10, %x[nr]\n" + "mov x9, #0x88\n" + "cmp x10, #0x10\n" + "mul x9, %x[nb], x9\n" + "blt 4f\n" + "1:" // Row loop + "add x28, %x[b_ptr], #0x8\n" + "mov x27, %x[nc]\n" + "add x26, %x[res_ptr], %x[res_stride], LSL #4\n" + "2:" // Column loop + "add x25, %x[a_ptr], #0x8\n" + "movi v15.16b, #0x0\n" + "movi v19.16b, #0x0\n" + "mov x24, %x[nb]\n" + "add x23, x25, x9\n" + "movi v18.16b, #0x0\n" + "movi v14.16b, #0x0\n" + "add x22, x23, x9\n" + "movi v11.16b, #0x0\n" + "movi v13.16b, #0x0\n" + "add x21, x22, x9\n" + "movi v23.16b, #0x0\n" + "movi v16.16b, #0x0\n" + "movi v25.16b, #0x0\n" + "movi v7.16b, #0x0\n" + "movi v0.16b, #0x0\n" + "movi v4.16b, #0x0\n" + "movi v5.16b, #0x0\n" + "movi v21.16b, #0x0\n" + "movi v8.16b, #0x0\n" + "movi v1.16b, 
#0x0\n" + "3:" // Block loop + "ldr q3, [x28, #0x0]\n" + "ldr q31, [x25, #0x0]\n" + "movi v28.16b, #0x4\n" + "movi v10.4s, #0x0\n" + "ldr q22, [x28, #0x10]\n" + "ldr q6, [x25, #0x10]\n" + "movi v29.4s, #0x0\n" + "movi v9.4s, #0x0\n" + "ldr q27, [x28, #0x20]\n" + "ldr q30, [x28, #0x30]\n" + "movi v20.4s, #0x0\n" + "movi v24.16b, #0xf0\n" + "ldr d2, [x25, #-0x8]\n" + "ldr d26, [x23, #-0x8]\n" + "sshl v12.16b, v3.16b, v28.16b\n" + "sub x20, x28, #0x8\n" + "ldr d17, [x20, #0x0]\n" + "and v3.16b, v3.16b, v24.16b\n" + "subs x24, x24, #0x1\n" + "add x28, x28, #0x48\n" + ".inst 0x4f9fe18a // sdot v10.4s, v12.16b, v31.4b[0]\n" + ".inst 0x4fbfe19d // sdot v29.4s, v12.16b, v31.4b[1]\n" + ".inst 0x4f9fe989 // sdot v9.4s, v12.16b, v31.4b[2]\n" + ".inst 0x4fbfe994 // sdot v20.4s, v12.16b, v31.4b[3]\n" + "sshl v31.16b, v22.16b, v28.16b\n" + "and v22.16b, v22.16b, v24.16b\n" + "fcvtl v17.4s, v17.4h\n" + "fcvtl v2.4s, v2.4h\n" + "fcvtl v26.4s, v26.4h\n" + ".inst 0x4f86e3ea // sdot v10.4s, v31.16b, v6.4b[0]\n" + ".inst 0x4fa6e3fd // sdot v29.4s, v31.16b, v6.4b[1]\n" + ".inst 0x4f86ebe9 // sdot v9.4s, v31.16b, v6.4b[2]\n" + ".inst 0x4fa6ebf4 // sdot v20.4s, v31.16b, v6.4b[3]\n" + "sshl v6.16b, v27.16b, v28.16b\n" + "sshl v28.16b, v30.16b, v28.16b\n" + "and v27.16b, v27.16b, v24.16b\n" + "and v30.16b, v30.16b, v24.16b\n" + "ldr q24, [x25, #0x20]\n" + ".inst 0x4f98e0ca // sdot v10.4s, v6.16b, v24.4b[0]\n" + ".inst 0x4fb8e0dd // sdot v29.4s, v6.16b, v24.4b[1]\n" + ".inst 0x4f98e8c9 // sdot v9.4s, v6.16b, v24.4b[2]\n" + ".inst 0x4fb8e8d4 // sdot v20.4s, v6.16b, v24.4b[3]\n" + "ldr q24, [x25, #0x30]\n" + ".inst 0x4f98e38a // sdot v10.4s, v28.16b, v24.4b[0]\n" + ".inst 0x4fb8e39d // sdot v29.4s, v28.16b, v24.4b[1]\n" + ".inst 0x4f98eb89 // sdot v9.4s, v28.16b, v24.4b[2]\n" + ".inst 0x4fb8eb94 // sdot v20.4s, v28.16b, v24.4b[3]\n" + "ldr q24, [x25, #0x40]\n" + ".inst 0x4f98e06a // sdot v10.4s, v3.16b, v24.4b[0]\n" + ".inst 0x4fb8e07d // sdot v29.4s, v3.16b, v24.4b[1]\n" + ".inst 0x4f98e869 // sdot v9.4s, v3.16b, v24.4b[2]\n" + ".inst 0x4fb8e874 // sdot v20.4s, v3.16b, v24.4b[3]\n" + "ldr q24, [x25, #0x50]\n" + ".inst 0x4f98e2ca // sdot v10.4s, v22.16b, v24.4b[0]\n" + ".inst 0x4fb8e2dd // sdot v29.4s, v22.16b, v24.4b[1]\n" + ".inst 0x4f98eac9 // sdot v9.4s, v22.16b, v24.4b[2]\n" + ".inst 0x4fb8ead4 // sdot v20.4s, v22.16b, v24.4b[3]\n" + "ldr q24, [x25, #0x60]\n" + ".inst 0x4f98e36a // sdot v10.4s, v27.16b, v24.4b[0]\n" + ".inst 0x4fb8e37d // sdot v29.4s, v27.16b, v24.4b[1]\n" + ".inst 0x4f98eb69 // sdot v9.4s, v27.16b, v24.4b[2]\n" + ".inst 0x4fb8eb74 // sdot v20.4s, v27.16b, v24.4b[3]\n" + "ldr q24, [x25, #0x70]\n" + "add x25, x25, #0x88\n" + ".inst 0x4f98e3ca // sdot v10.4s, v30.16b, v24.4b[0]\n" + ".inst 0x4fb8e3dd // sdot v29.4s, v30.16b, v24.4b[1]\n" + ".inst 0x4f98ebc9 // sdot v9.4s, v30.16b, v24.4b[2]\n" + ".inst 0x4fb8ebd4 // sdot v20.4s, v30.16b, v24.4b[3]\n" + "fmul v24.4s, v17.4s, v2.s[0]\n" + "scvtf v10.4s, v10.4s, #0x4\n" + "scvtf v29.4s, v29.4s, #0x4\n" + "scvtf v9.4s, v9.4s, #0x4\n" + "scvtf v20.4s, v20.4s, #0x4\n" + "fmla v15.4s, v10.4s, v24.4s\n" + "ldr q24, [x23, #0x0]\n" + "fmul v10.4s, v17.4s, v2.s[1]\n" + "fmla v19.4s, v29.4s, v10.4s\n" + "ldr q10, [x23, #0x10]\n" + "fmul v29.4s, v17.4s, v2.s[2]\n" + "fmul v2.4s, v17.4s, v2.s[3]\n" + "fmla v18.4s, v9.4s, v29.4s\n" + "movi v9.4s, #0x0\n" + "movi v29.4s, #0x0\n" + ".inst 0x4f98e189 // sdot v9.4s, v12.16b, v24.4b[0]\n" + ".inst 0x4fb8e19d // sdot v29.4s, v12.16b, v24.4b[1]\n" + "fmla v14.4s, v20.4s, v2.4s\n" + "movi v20.4s, #0x0\n" + "movi v2.4s, 
#0x0\n" + ".inst 0x4f98e994 // sdot v20.4s, v12.16b, v24.4b[2]\n" + ".inst 0x4fb8e982 // sdot v2.4s, v12.16b, v24.4b[3]\n" + "ldr q24, [x23, #0x20]\n" + ".inst 0x4f8ae3e9 // sdot v9.4s, v31.16b, v10.4b[0]\n" + ".inst 0x4faae3fd // sdot v29.4s, v31.16b, v10.4b[1]\n" + ".inst 0x4f8aebf4 // sdot v20.4s, v31.16b, v10.4b[2]\n" + ".inst 0x4faaebe2 // sdot v2.4s, v31.16b, v10.4b[3]\n" + "ldr q10, [x23, #0x30]\n" + ".inst 0x4f98e0c9 // sdot v9.4s, v6.16b, v24.4b[0]\n" + ".inst 0x4fb8e0dd // sdot v29.4s, v6.16b, v24.4b[1]\n" + ".inst 0x4f98e8d4 // sdot v20.4s, v6.16b, v24.4b[2]\n" + ".inst 0x4fb8e8c2 // sdot v2.4s, v6.16b, v24.4b[3]\n" + "ldr q24, [x23, #0x40]\n" + ".inst 0x4f8ae389 // sdot v9.4s, v28.16b, v10.4b[0]\n" + ".inst 0x4faae39d // sdot v29.4s, v28.16b, v10.4b[1]\n" + ".inst 0x4f8aeb94 // sdot v20.4s, v28.16b, v10.4b[2]\n" + ".inst 0x4faaeb82 // sdot v2.4s, v28.16b, v10.4b[3]\n" + "ldr q10, [x23, #0x50]\n" + ".inst 0x4f98e069 // sdot v9.4s, v3.16b, v24.4b[0]\n" + ".inst 0x4fb8e07d // sdot v29.4s, v3.16b, v24.4b[1]\n" + ".inst 0x4f98e874 // sdot v20.4s, v3.16b, v24.4b[2]\n" + ".inst 0x4fb8e862 // sdot v2.4s, v3.16b, v24.4b[3]\n" + "ldr q24, [x23, #0x60]\n" + ".inst 0x4f8ae2c9 // sdot v9.4s, v22.16b, v10.4b[0]\n" + ".inst 0x4faae2dd // sdot v29.4s, v22.16b, v10.4b[1]\n" + ".inst 0x4f8aead4 // sdot v20.4s, v22.16b, v10.4b[2]\n" + ".inst 0x4faaeac2 // sdot v2.4s, v22.16b, v10.4b[3]\n" + "ldr q10, [x23, #0x70]\n" + "add x23, x23, #0x88\n" + ".inst 0x4f98e369 // sdot v9.4s, v27.16b, v24.4b[0]\n" + ".inst 0x4fb8e37d // sdot v29.4s, v27.16b, v24.4b[1]\n" + ".inst 0x4f98eb74 // sdot v20.4s, v27.16b, v24.4b[2]\n" + ".inst 0x4fb8eb62 // sdot v2.4s, v27.16b, v24.4b[3]\n" + "ldr q24, [x22, #0x0]\n" + ".inst 0x4f8ae3c9 // sdot v9.4s, v30.16b, v10.4b[0]\n" + ".inst 0x4faae3dd // sdot v29.4s, v30.16b, v10.4b[1]\n" + ".inst 0x4f8aebd4 // sdot v20.4s, v30.16b, v10.4b[2]\n" + ".inst 0x4faaebc2 // sdot v2.4s, v30.16b, v10.4b[3]\n" + "fmul v10.4s, v17.4s, v26.s[0]\n" + "scvtf v9.4s, v9.4s, #0x4\n" + "scvtf v29.4s, v29.4s, #0x4\n" + "scvtf v20.4s, v20.4s, #0x4\n" + "scvtf v2.4s, v2.4s, #0x4\n" + "fmla v11.4s, v9.4s, v10.4s\n" + "ldr q9, [x22, #0x10]\n" + "fmul v10.4s, v17.4s, v26.s[1]\n" + "fmla v13.4s, v29.4s, v10.4s\n" + "ldr d29, [x22, #-0x8]\n" + "fmul v10.4s, v17.4s, v26.s[2]\n" + "fmul v26.4s, v17.4s, v26.s[3]\n" + "fcvtl v29.4s, v29.4h\n" + "fmla v23.4s, v20.4s, v10.4s\n" + "movi v20.4s, #0x0\n" + "movi v10.4s, #0x0\n" + "fmla v16.4s, v2.4s, v26.4s\n" + "movi v26.4s, #0x0\n" + "movi v2.4s, #0x0\n" + ".inst 0x4f98e194 // sdot v20.4s, v12.16b, v24.4b[0]\n" + ".inst 0x4fb8e18a // sdot v10.4s, v12.16b, v24.4b[1]\n" + ".inst 0x4f98e99a // sdot v26.4s, v12.16b, v24.4b[2]\n" + ".inst 0x4fb8e982 // sdot v2.4s, v12.16b, v24.4b[3]\n" + "ldr q24, [x22, #0x20]\n" + ".inst 0x4f89e3f4 // sdot v20.4s, v31.16b, v9.4b[0]\n" + ".inst 0x4fa9e3ea // sdot v10.4s, v31.16b, v9.4b[1]\n" + ".inst 0x4f89ebfa // sdot v26.4s, v31.16b, v9.4b[2]\n" + ".inst 0x4fa9ebe2 // sdot v2.4s, v31.16b, v9.4b[3]\n" + "ldr q9, [x22, #0x30]\n" + ".inst 0x4f98e0d4 // sdot v20.4s, v6.16b, v24.4b[0]\n" + ".inst 0x4fb8e0ca // sdot v10.4s, v6.16b, v24.4b[1]\n" + ".inst 0x4f98e8da // sdot v26.4s, v6.16b, v24.4b[2]\n" + ".inst 0x4fb8e8c2 // sdot v2.4s, v6.16b, v24.4b[3]\n" + "ldr q24, [x22, #0x40]\n" + ".inst 0x4f89e394 // sdot v20.4s, v28.16b, v9.4b[0]\n" + ".inst 0x4fa9e38a // sdot v10.4s, v28.16b, v9.4b[1]\n" + ".inst 0x4f89eb9a // sdot v26.4s, v28.16b, v9.4b[2]\n" + ".inst 0x4fa9eb82 // sdot v2.4s, v28.16b, v9.4b[3]\n" + "ldr q9, [x22, #0x50]\n" + 
".inst 0x4f98e074 // sdot v20.4s, v3.16b, v24.4b[0]\n" + ".inst 0x4fb8e06a // sdot v10.4s, v3.16b, v24.4b[1]\n" + ".inst 0x4f98e87a // sdot v26.4s, v3.16b, v24.4b[2]\n" + ".inst 0x4fb8e862 // sdot v2.4s, v3.16b, v24.4b[3]\n" + "ldr q24, [x22, #0x60]\n" + ".inst 0x4f89e2d4 // sdot v20.4s, v22.16b, v9.4b[0]\n" + ".inst 0x4fa9e2ca // sdot v10.4s, v22.16b, v9.4b[1]\n" + ".inst 0x4f89eada // sdot v26.4s, v22.16b, v9.4b[2]\n" + ".inst 0x4fa9eac2 // sdot v2.4s, v22.16b, v9.4b[3]\n" + "ldr q9, [x22, #0x70]\n" + "add x22, x22, #0x88\n" + ".inst 0x4f98e374 // sdot v20.4s, v27.16b, v24.4b[0]\n" + ".inst 0x4fb8e36a // sdot v10.4s, v27.16b, v24.4b[1]\n" + ".inst 0x4f98eb7a // sdot v26.4s, v27.16b, v24.4b[2]\n" + ".inst 0x4fb8eb62 // sdot v2.4s, v27.16b, v24.4b[3]\n" + "ldr q24, [x21, #0x0]\n" + ".inst 0x4f89e3d4 // sdot v20.4s, v30.16b, v9.4b[0]\n" + ".inst 0x4fa9e3ca // sdot v10.4s, v30.16b, v9.4b[1]\n" + ".inst 0x4f89ebda // sdot v26.4s, v30.16b, v9.4b[2]\n" + ".inst 0x4fa9ebc2 // sdot v2.4s, v30.16b, v9.4b[3]\n" + "fmul v9.4s, v17.4s, v29.s[0]\n" + "scvtf v20.4s, v20.4s, #0x4\n" + "scvtf v10.4s, v10.4s, #0x4\n" + "scvtf v26.4s, v26.4s, #0x4\n" + "scvtf v2.4s, v2.4s, #0x4\n" + "fmla v25.4s, v20.4s, v9.4s\n" + "ldr q9, [x21, #0x10]\n" + "fmul v20.4s, v17.4s, v29.s[1]\n" + "fmla v7.4s, v10.4s, v20.4s\n" + "ldr d20, [x21, #-0x8]\n" + "fmul v10.4s, v17.4s, v29.s[2]\n" + "fmul v29.4s, v17.4s, v29.s[3]\n" + "fcvtl v20.4s, v20.4h\n" + "fmla v0.4s, v26.4s, v10.4s\n" + "movi v26.4s, #0x0\n" + "movi v10.4s, #0x0\n" + "fmla v4.4s, v2.4s, v29.4s\n" + "movi v2.4s, #0x0\n" + "movi v29.4s, #0x0\n" + ".inst 0x4f98e19a // sdot v26.4s, v12.16b, v24.4b[0]\n" + ".inst 0x4fb8e18a // sdot v10.4s, v12.16b, v24.4b[1]\n" + ".inst 0x4f98e982 // sdot v2.4s, v12.16b, v24.4b[2]\n" + ".inst 0x4fb8e99d // sdot v29.4s, v12.16b, v24.4b[3]\n" + "ldr q12, [x21, #0x20]\n" + "fmul v24.4s, v17.4s, v20.s[0]\n" + ".inst 0x4f89e3fa // sdot v26.4s, v31.16b, v9.4b[0]\n" + ".inst 0x4fa9e3ea // sdot v10.4s, v31.16b, v9.4b[1]\n" + ".inst 0x4f89ebe2 // sdot v2.4s, v31.16b, v9.4b[2]\n" + ".inst 0x4fa9ebfd // sdot v29.4s, v31.16b, v9.4b[3]\n" + "ldr q9, [x21, #0x30]\n" + "fmul v31.4s, v17.4s, v20.s[1]\n" + ".inst 0x4f8ce0da // sdot v26.4s, v6.16b, v12.4b[0]\n" + ".inst 0x4face0ca // sdot v10.4s, v6.16b, v12.4b[1]\n" + ".inst 0x4f8ce8c2 // sdot v2.4s, v6.16b, v12.4b[2]\n" + ".inst 0x4face8dd // sdot v29.4s, v6.16b, v12.4b[3]\n" + "ldr q12, [x21, #0x40]\n" + "fmul v6.4s, v17.4s, v20.s[2]\n" + "fmul v20.4s, v17.4s, v20.s[3]\n" + ".inst 0x4f89e39a // sdot v26.4s, v28.16b, v9.4b[0]\n" + ".inst 0x4fa9e38a // sdot v10.4s, v28.16b, v9.4b[1]\n" + ".inst 0x4f89eb82 // sdot v2.4s, v28.16b, v9.4b[2]\n" + ".inst 0x4fa9eb9d // sdot v29.4s, v28.16b, v9.4b[3]\n" + "ldr q9, [x21, #0x50]\n" + ".inst 0x4f8ce07a // sdot v26.4s, v3.16b, v12.4b[0]\n" + ".inst 0x4face06a // sdot v10.4s, v3.16b, v12.4b[1]\n" + ".inst 0x4f8ce862 // sdot v2.4s, v3.16b, v12.4b[2]\n" + ".inst 0x4face87d // sdot v29.4s, v3.16b, v12.4b[3]\n" + "ldr q12, [x21, #0x60]\n" + ".inst 0x4f89e2da // sdot v26.4s, v22.16b, v9.4b[0]\n" + ".inst 0x4fa9e2ca // sdot v10.4s, v22.16b, v9.4b[1]\n" + ".inst 0x4f89eac2 // sdot v2.4s, v22.16b, v9.4b[2]\n" + ".inst 0x4fa9eadd // sdot v29.4s, v22.16b, v9.4b[3]\n" + "ldr q17, [x21, #0x70]\n" + "add x21, x21, #0x88\n" + ".inst 0x4f8ce37a // sdot v26.4s, v27.16b, v12.4b[0]\n" + ".inst 0x4face36a // sdot v10.4s, v27.16b, v12.4b[1]\n" + ".inst 0x4f8ceb62 // sdot v2.4s, v27.16b, v12.4b[2]\n" + ".inst 0x4faceb7d // sdot v29.4s, v27.16b, v12.4b[3]\n" + ".inst 0x4f91e3da // 
sdot v26.4s, v30.16b, v17.4b[0]\n" + ".inst 0x4fb1e3ca // sdot v10.4s, v30.16b, v17.4b[1]\n" + ".inst 0x4f91ebc2 // sdot v2.4s, v30.16b, v17.4b[2]\n" + ".inst 0x4fb1ebdd // sdot v29.4s, v30.16b, v17.4b[3]\n" + "scvtf v26.4s, v26.4s, #0x4\n" + "scvtf v10.4s, v10.4s, #0x4\n" + "fmla v5.4s, v26.4s, v24.4s\n" + "scvtf v2.4s, v2.4s, #0x4\n" + "scvtf v29.4s, v29.4s, #0x4\n" + "fmla v21.4s, v10.4s, v31.4s\n" + "fmla v8.4s, v2.4s, v6.4s\n" + "fmla v1.4s, v29.4s, v20.4s\n" + "bgt 3b\n" + "mov x20, %x[res_ptr]\n" + "subs x27, x27, #0x4\n" + "add %x[res_ptr], %x[res_ptr], #0x10\n" + "str q15, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q19, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q18, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q14, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q11, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q13, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q23, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q16, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q25, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q7, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q0, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q4, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q5, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q21, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q8, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q1, [x20, #0x0]\n" + "bne 2b\n" + "mov x20, #0x4\n" + "sub x10, x10, #0x10\n" + "cmp x10, #0x10\n" + "mov %x[res_ptr], x26\n" + "madd %x[a_ptr], x20, x9, %x[a_ptr]\n" + "bge 1b\n" + "4:" // Row loop skip + "cbz x10, 9f\n" + "5:" // Row tail: Row loop + "add x24, %x[b_ptr], #0x8\n" + "mov x23, %x[nc]\n" + "add x22, %x[res_ptr], %x[res_stride], LSL #2\n" + "6:" // Row tail: Column loop + "movi v15.16b, #0x0\n" + "movi v19.16b, #0x0\n" + "add x25, %x[a_ptr], #0x8\n" + "mov x21, %x[nb]\n" + "movi v18.16b, #0x0\n" + "movi v14.16b, #0x0\n" + "7:" // Row tail: Block loop + "ldr q7, [x24, #0x0]\n" + "ldr q5, [x25, #0x0]\n" + "movi v9.16b, #0x4\n" + "movi v4.4s, #0x0\n" + "ldr q3, [x24, #0x10]\n" + "ldr q2, [x25, #0x10]\n" + "movi v1.4s, #0x0\n" + "movi v0.4s, #0x0\n" + "ldr q13, [x24, #0x20]\n" + "ldr q31, [x25, #0x20]\n" + "movi v30.4s, #0x0\n" + "movi v29.16b, #0xf0\n" + "ldr q28, [x24, #0x30]\n" + "ldr q27, [x25, #0x30]\n" + "sshl v20.16b, v7.16b, v9.16b\n" + "sub x20, x24, #0x8\n" + "ldr q26, [x25, #0x40]\n" + "ldr q25, [x25, #0x50]\n" + "sshl v17.16b, v3.16b, v9.16b\n" + "and v7.16b, v7.16b, v29.16b\n" + "ldr q24, [x25, #0x60]\n" + "ldr q16, [x25, #0x70]\n" + "sshl v22.16b, v13.16b, v9.16b\n" + "and v3.16b, v3.16b, v29.16b\n" + "ldr d21, [x20, #0x0]\n" + "ldr d12, [x25, #-0x8]\n" + ".inst 0x4f85e284 // sdot v4.4s, v20.16b, v5.4b[0]\n" + ".inst 0x4fa5e281 // sdot v1.4s, v20.16b, v5.4b[1]\n" + ".inst 0x4f85ea80 // sdot v0.4s, v20.16b, v5.4b[2]\n" + ".inst 0x4fa5ea9e // sdot v30.4s, v20.16b, v5.4b[3]\n" + "sshl v9.16b, v28.16b, v9.16b\n" + "subs x21, x21, #0x1\n" + "and v13.16b, v13.16b, v29.16b\n" + "and v28.16b, v28.16b, v29.16b\n" + "add x25, x25, #0x88\n" + "add x24, x24, #0x48\n" + "fcvtl v21.4s, v21.4h\n" + "fcvtl v12.4s, v12.4h\n" + ".inst 0x4f82e224 // sdot v4.4s, v17.16b, v2.4b[0]\n" + ".inst 0x4fa2e221 // sdot v1.4s, v17.16b, v2.4b[1]\n" + ".inst 0x4f82ea20 // sdot v0.4s, v17.16b, v2.4b[2]\n" + ".inst 0x4fa2ea3e // sdot v30.4s, v17.16b, v2.4b[3]\n" + "fmul v11.4s, v21.4s, v12.s[0]\n" + "fmul v23.4s, 
v21.4s, v12.s[1]\n" + "fmul v17.4s, v21.4s, v12.s[2]\n" + ".inst 0x4f9fe2c4 // sdot v4.4s, v22.16b, v31.4b[0]\n" + "fmul v6.4s, v21.4s, v12.s[3]\n" + ".inst 0x4fbfe2c1 // sdot v1.4s, v22.16b, v31.4b[1]\n" + ".inst 0x4f9feac0 // sdot v0.4s, v22.16b, v31.4b[2]\n" + ".inst 0x4fbfeade // sdot v30.4s, v22.16b, v31.4b[3]\n" + ".inst 0x4f9be124 // sdot v4.4s, v9.16b, v27.4b[0]\n" + ".inst 0x4fbbe121 // sdot v1.4s, v9.16b, v27.4b[1]\n" + ".inst 0x4f9be920 // sdot v0.4s, v9.16b, v27.4b[2]\n" + ".inst 0x4fbbe93e // sdot v30.4s, v9.16b, v27.4b[3]\n" + ".inst 0x4f9ae0e4 // sdot v4.4s, v7.16b, v26.4b[0]\n" + ".inst 0x4fbae0e1 // sdot v1.4s, v7.16b, v26.4b[1]\n" + ".inst 0x4f9ae8e0 // sdot v0.4s, v7.16b, v26.4b[2]\n" + ".inst 0x4fbae8fe // sdot v30.4s, v7.16b, v26.4b[3]\n" + ".inst 0x4f99e064 // sdot v4.4s, v3.16b, v25.4b[0]\n" + ".inst 0x4fb9e061 // sdot v1.4s, v3.16b, v25.4b[1]\n" + ".inst 0x4f99e860 // sdot v0.4s, v3.16b, v25.4b[2]\n" + ".inst 0x4fb9e87e // sdot v30.4s, v3.16b, v25.4b[3]\n" + ".inst 0x4f98e1a4 // sdot v4.4s, v13.16b, v24.4b[0]\n" + ".inst 0x4fb8e1a1 // sdot v1.4s, v13.16b, v24.4b[1]\n" + ".inst 0x4f98e9a0 // sdot v0.4s, v13.16b, v24.4b[2]\n" + ".inst 0x4fb8e9be // sdot v30.4s, v13.16b, v24.4b[3]\n" + ".inst 0x4f90e384 // sdot v4.4s, v28.16b, v16.4b[0]\n" + ".inst 0x4fb0e381 // sdot v1.4s, v28.16b, v16.4b[1]\n" + ".inst 0x4f90eb80 // sdot v0.4s, v28.16b, v16.4b[2]\n" + ".inst 0x4fb0eb9e // sdot v30.4s, v28.16b, v16.4b[3]\n" + "scvtf v4.4s, v4.4s, #0x4\n" + "scvtf v1.4s, v1.4s, #0x4\n" + "scvtf v0.4s, v0.4s, #0x4\n" + "fmla v15.4s, v4.4s, v11.4s\n" + "scvtf v30.4s, v30.4s, #0x4\n" + "fmla v19.4s, v1.4s, v23.4s\n" + "fmla v18.4s, v0.4s, v17.4s\n" + "fmla v14.4s, v30.4s, v6.4s\n" + "bgt 7b\n" + "mov x20, %x[res_ptr]\n" + "cmp x10, #0x1\n" + "str q15, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "ble 8f\n" + "cmp x10, #0x2\n" + "str q19, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "ble 8f\n" + "cmp x10, #0x3\n" + "str q18, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "ble 8f\n" + "str q14, [x20, #0x0]\n" + "8:" // Row tail: Accumulator store skip + "subs x23, x23, #0x4\n" + "add %x[res_ptr], %x[res_ptr], #0x10\n" + "bne 6b\n" + "subs x10, x10, #0x4\n" + "add %x[a_ptr], %x[a_ptr], x9\n" + "mov %x[res_ptr], x22\n" + "bgt 5b\n" + "9:" // Row tail: Row loop skip + : [a_ptr] "+&r" (a_ptr), [res_ptr] "+&r" (res_ptr) + : [b_ptr] "r" (b_ptr), [nr] "r" (nr), [nb] "r" (nb), [res_stride] "r" (res_stride), [nc] "r" (nc) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28" + ); +#else + float sumf[4][4]; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; + } + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * 
blocklen + i] & 0xF0); + sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + + (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4; + } + sumf[m][j] += sumi * MLLM_FP16_TO_FP32(b_ptr[l].d[j]) * MLLM_FP16_TO_FP32(a_ptr[l].d[m]); + } + } + } + } + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; + } + } + } +#endif +} + +void mllm_gemm_q4_0_4x8_q8_0(int n, float * __restrict s, size_t bs, const void * __restrict vx, const void * __restrict vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 8; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + (void)s; + (void)bs; + (void)vx; + (void)vy; + (void)nr; + (void)nc; + (void)nb; + (void)ncols_interleaved; + (void)blocklen; + +#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8) + if (svcntw() == 8) { + assert(!(mllm_cpu_has_sve() && (svcntw() == 8)) && + "__ARM_FEATURE_SVE defined, use the Q4_0_8_8 quantization format for optimal performance"); + } +#endif +#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) + const void * b_ptr = vx; + const void * a_ptr = vy; + float * res_ptr = s; + size_t res_stride = bs * sizeof(float); + + __asm__ __volatile__( + "mov x10, %x[nr]\n" + "mov x9, #0x88\n" + "cmp x10, #0x10\n" + "mul x9, %x[nb], x9\n" + "blt 4f\n" + "1:" // Row loop + "add x28, %x[b_ptr], #0x8\n" + "mov x27, %x[nc]\n" + "add x26, %x[res_ptr], %x[res_stride], LSL #4\n" + "2:" // Column loop + "add x25, %x[a_ptr], #0x8\n" + "movi v2.16b, #0x0\n" + "movi v10.16b, #0x0\n" + "mov x24, %x[nb]\n" + "add x23, x25, x9\n" + "movi v12.16b, #0x0\n" + "movi v28.16b, #0x0\n" + "add x22, x23, x9\n" + "movi v11.16b, #0x0\n" + "movi v13.16b, #0x0\n" + "add x21, x22, x9\n" + "movi v22.16b, #0x0\n" + "movi v23.16b, #0x0\n" + "movi v25.16b, #0x0\n" + "movi v5.16b, #0x0\n" + "movi v7.16b, #0x0\n" + "movi v4.16b, #0x0\n" + "movi v6.16b, #0x0\n" + "movi v30.16b, #0x0\n" + "movi v24.16b, #0x0\n" + "movi v14.16b, #0x0\n" + "3:" // Block loop + "ldr q21, [x28, #0x0]\n" + "ldr q16, [x28, #0x10]\n" + "movi v1.16b, #0x4\n" + "movi v19.4s, #0x0\n" + "ldr q27, [x25, #0x0]\n" + "ldr q15, [x25, #0x10]\n" + "movi v26.4s, #0x0\n" + "movi v18.4s, #0x0\n" + "ldr q29, [x28, #0x20]\n" + "ldr q3, [x28, #0x30]\n" + "movi v17.4s, #0x0\n" + "movi v0.16b, #0xf0\n" + "ldr d20, [x25, #-0x8]\n" + "ldr d9, [x23, #-0x8]\n" + "sshl v8.16b, v21.16b, v1.16b\n" + "sshl v31.16b, v16.16b, v1.16b\n" + "and v21.16b, v21.16b, v0.16b\n" + "and v16.16b, v16.16b, v0.16b\n" + "sub x20, x28, #0x8\n" + "subs x24, x24, #0x1\n" + "add x28, x28, #0x48\n" + ".inst 0x4e88a773 // smmla v19.4s, v27.16b, v8.16b\n" + ".inst 0x4e9fa77a // smmla v26.4s, v27.16b, v31.16b\n" + "ldr q27, [x25, #0x20]\n" + ".inst 0x4e88a5f2 // smmla v18.4s, v15.16b, v8.16b\n" + ".inst 0x4e9fa5f1 // smmla v17.4s, v15.16b, v31.16b\n" + "sshl v15.16b, v29.16b, v1.16b\n" + "sshl v1.16b, v3.16b, v1.16b\n" + "and v29.16b, v29.16b, v0.16b\n" + "and v3.16b, v3.16b, v0.16b\n" + "ldr q0, [x25, #0x30]\n" + "fcvtl v20.4s, v20.4h\n" + ".inst 0x4e8fa773 // smmla v19.4s, v27.16b, v15.16b\n" + "fcvtl v9.4s, v9.4h\n" + ".inst 0x4e81a77a // smmla v26.4s, v27.16b, v1.16b\n" + "ldr q27, [x25, #0x40]\n" + ".inst 0x4e8fa412 // smmla v18.4s, v0.16b, v15.16b\n" + ".inst 0x4e81a411 // smmla v17.4s, v0.16b, v1.16b\n" + "ldr q0, [x25, #0x50]\n" + ".inst 0x4e95a773 // smmla v19.4s, v27.16b, 
v21.16b\n" + ".inst 0x4e90a77a // smmla v26.4s, v27.16b, v16.16b\n" + "ldr q27, [x25, #0x60]\n" + ".inst 0x4e95a412 // smmla v18.4s, v0.16b, v21.16b\n" + ".inst 0x4e90a411 // smmla v17.4s, v0.16b, v16.16b\n" + "ldr q0, [x25, #0x70]\n" + "add x25, x25, #0x88\n" + ".inst 0x4e9da773 // smmla v19.4s, v27.16b, v29.16b\n" + ".inst 0x4e83a77a // smmla v26.4s, v27.16b, v3.16b\n" + "ldr d27, [x20, #0x0]\n" + ".inst 0x4e9da412 // smmla v18.4s, v0.16b, v29.16b\n" + ".inst 0x4e83a411 // smmla v17.4s, v0.16b, v3.16b\n" + "fcvtl v27.4s, v27.4h\n" + "uzp1 v0.2d, v19.2d, v26.2d\n" + "uzp2 v26.2d, v19.2d, v26.2d\n" + "fmul v19.4s, v27.4s, v20.s[0]\n" + "scvtf v0.4s, v0.4s, #0x4\n" + "scvtf v26.4s, v26.4s, #0x4\n" + "fmla v2.4s, v0.4s, v19.4s\n" + "ldr q19, [x23, #0x0]\n" + "uzp1 v0.2d, v18.2d, v17.2d\n" + "uzp2 v18.2d, v18.2d, v17.2d\n" + "fmul v17.4s, v27.4s, v20.s[1]\n" + "scvtf v0.4s, v0.4s, #0x4\n" + "scvtf v18.4s, v18.4s, #0x4\n" + "fmla v10.4s, v26.4s, v17.4s\n" + "ldr q17, [x23, #0x10]\n" + "fmul v26.4s, v27.4s, v20.s[2]\n" + "fmul v20.4s, v27.4s, v20.s[3]\n" + "fmla v12.4s, v0.4s, v26.4s\n" + "ldr d0, [x22, #-0x8]\n" + "ldr d26, [x21, #-0x8]\n" + "fcvtl v0.4s, v0.4h\n" + "fmla v28.4s, v18.4s, v20.4s\n" + "movi v20.4s, #0x0\n" + "movi v18.4s, #0x0\n" + ".inst 0x4e88a674 // smmla v20.4s, v19.16b, v8.16b\n" + ".inst 0x4e9fa672 // smmla v18.4s, v19.16b, v31.16b\n" + "ldr q19, [x23, #0x20]\n" + "fcvtl v26.4s, v26.4h\n" + ".inst 0x4e8fa674 // smmla v20.4s, v19.16b, v15.16b\n" + ".inst 0x4e81a672 // smmla v18.4s, v19.16b, v1.16b\n" + "ldr q19, [x23, #0x40]\n" + ".inst 0x4e95a674 // smmla v20.4s, v19.16b, v21.16b\n" + ".inst 0x4e90a672 // smmla v18.4s, v19.16b, v16.16b\n" + "ldr q19, [x23, #0x60]\n" + ".inst 0x4e9da674 // smmla v20.4s, v19.16b, v29.16b\n" + ".inst 0x4e83a672 // smmla v18.4s, v19.16b, v3.16b\n" + "uzp1 v19.2d, v20.2d, v18.2d\n" + "scvtf v19.4s, v19.4s, #0x4\n" + "uzp2 v20.2d, v20.2d, v18.2d\n" + "fmul v18.4s, v27.4s, v9.s[0]\n" + "scvtf v20.4s, v20.4s, #0x4\n" + "fmla v11.4s, v19.4s, v18.4s\n" + "ldr q18, [x22, #0x0]\n" + "fmul v19.4s, v27.4s, v9.s[1]\n" + "fmla v13.4s, v20.4s, v19.4s\n" + "movi v19.4s, #0x0\n" + "movi v20.4s, #0x0\n" + ".inst 0x4e88a633 // smmla v19.4s, v17.16b, v8.16b\n" + ".inst 0x4e9fa634 // smmla v20.4s, v17.16b, v31.16b\n" + "ldr q17, [x23, #0x30]\n" + ".inst 0x4e8fa633 // smmla v19.4s, v17.16b, v15.16b\n" + ".inst 0x4e81a634 // smmla v20.4s, v17.16b, v1.16b\n" + "ldr q17, [x23, #0x50]\n" + ".inst 0x4e95a633 // smmla v19.4s, v17.16b, v21.16b\n" + ".inst 0x4e90a634 // smmla v20.4s, v17.16b, v16.16b\n" + "ldr q17, [x23, #0x70]\n" + "add x23, x23, #0x88\n" + ".inst 0x4e9da633 // smmla v19.4s, v17.16b, v29.16b\n" + ".inst 0x4e83a634 // smmla v20.4s, v17.16b, v3.16b\n" + "uzp1 v17.2d, v19.2d, v20.2d\n" + "scvtf v17.4s, v17.4s, #0x4\n" + "uzp2 v20.2d, v19.2d, v20.2d\n" + "fmul v19.4s, v27.4s, v9.s[2]\n" + "fmul v9.4s, v27.4s, v9.s[3]\n" + "scvtf v20.4s, v20.4s, #0x4\n" + "fmla v22.4s, v17.4s, v19.4s\n" + "ldr q17, [x22, #0x10]\n" + "movi v19.4s, #0x0\n" + ".inst 0x4e88a653 // smmla v19.4s, v18.16b, v8.16b\n" + "fmla v23.4s, v20.4s, v9.4s\n" + "movi v20.4s, #0x0\n" + "movi v9.4s, #0x0\n" + ".inst 0x4e9fa654 // smmla v20.4s, v18.16b, v31.16b\n" + "ldr q18, [x22, #0x20]\n" + ".inst 0x4e88a629 // smmla v9.4s, v17.16b, v8.16b\n" + ".inst 0x4e8fa653 // smmla v19.4s, v18.16b, v15.16b\n" + ".inst 0x4e81a654 // smmla v20.4s, v18.16b, v1.16b\n" + "ldr q18, [x22, #0x40]\n" + ".inst 0x4e95a653 // smmla v19.4s, v18.16b, v21.16b\n" + ".inst 0x4e90a654 // smmla v20.4s, v18.16b, v16.16b\n" 
+ "ldr q18, [x22, #0x60]\n" + ".inst 0x4e9da653 // smmla v19.4s, v18.16b, v29.16b\n" + ".inst 0x4e83a654 // smmla v20.4s, v18.16b, v3.16b\n" + "movi v18.4s, #0x0\n" + ".inst 0x4e9fa632 // smmla v18.4s, v17.16b, v31.16b\n" + "ldr q17, [x22, #0x30]\n" + ".inst 0x4e8fa629 // smmla v9.4s, v17.16b, v15.16b\n" + ".inst 0x4e81a632 // smmla v18.4s, v17.16b, v1.16b\n" + "ldr q17, [x22, #0x50]\n" + ".inst 0x4e95a629 // smmla v9.4s, v17.16b, v21.16b\n" + ".inst 0x4e90a632 // smmla v18.4s, v17.16b, v16.16b\n" + "ldr q17, [x22, #0x70]\n" + "add x22, x22, #0x88\n" + ".inst 0x4e9da629 // smmla v9.4s, v17.16b, v29.16b\n" + ".inst 0x4e83a632 // smmla v18.4s, v17.16b, v3.16b\n" + "uzp1 v17.2d, v19.2d, v20.2d\n" + "uzp2 v20.2d, v19.2d, v20.2d\n" + "fmul v19.4s, v27.4s, v0.s[0]\n" + "scvtf v17.4s, v17.4s, #0x4\n" + "scvtf v20.4s, v20.4s, #0x4\n" + "fmla v25.4s, v17.4s, v19.4s\n" + "ldr q19, [x21, #0x0]\n" + "fmul v17.4s, v27.4s, v0.s[1]\n" + "fmla v5.4s, v20.4s, v17.4s\n" + "ldr q17, [x21, #0x10]\n" + "uzp1 v20.2d, v9.2d, v18.2d\n" + "uzp2 v9.2d, v9.2d, v18.2d\n" + "fmul v18.4s, v27.4s, v0.s[2]\n" + "fmul v0.4s, v27.4s, v0.s[3]\n" + "scvtf v20.4s, v20.4s, #0x4\n" + "scvtf v9.4s, v9.4s, #0x4\n" + "fmla v7.4s, v20.4s, v18.4s\n" + "movi v20.4s, #0x0\n" + "movi v18.4s, #0x0\n" + ".inst 0x4e88a674 // smmla v20.4s, v19.16b, v8.16b\n" + ".inst 0x4e9fa672 // smmla v18.4s, v19.16b, v31.16b\n" + "ldr q19, [x21, #0x20]\n" + "fmla v4.4s, v9.4s, v0.4s\n" + "movi v9.4s, #0x0\n" + "movi v0.4s, #0x0\n" + ".inst 0x4e88a629 // smmla v9.4s, v17.16b, v8.16b\n" + "fmul v8.4s, v27.4s, v26.s[0]\n" + ".inst 0x4e9fa620 // smmla v0.4s, v17.16b, v31.16b\n" + "ldr q17, [x21, #0x30]\n" + ".inst 0x4e8fa674 // smmla v20.4s, v19.16b, v15.16b\n" + "fmul v31.4s, v27.4s, v26.s[1]\n" + ".inst 0x4e81a672 // smmla v18.4s, v19.16b, v1.16b\n" + "ldr q19, [x21, #0x40]\n" + ".inst 0x4e8fa629 // smmla v9.4s, v17.16b, v15.16b\n" + "fmul v15.4s, v27.4s, v26.s[2]\n" + "fmul v27.4s, v27.4s, v26.s[3]\n" + ".inst 0x4e81a620 // smmla v0.4s, v17.16b, v1.16b\n" + "ldr q1, [x21, #0x50]\n" + ".inst 0x4e95a674 // smmla v20.4s, v19.16b, v21.16b\n" + ".inst 0x4e90a672 // smmla v18.4s, v19.16b, v16.16b\n" + "ldr q26, [x21, #0x60]\n" + ".inst 0x4e95a429 // smmla v9.4s, v1.16b, v21.16b\n" + ".inst 0x4e90a420 // smmla v0.4s, v1.16b, v16.16b\n" + "ldr q21, [x21, #0x70]\n" + "add x21, x21, #0x88\n" + ".inst 0x4e9da754 // smmla v20.4s, v26.16b, v29.16b\n" + ".inst 0x4e83a752 // smmla v18.4s, v26.16b, v3.16b\n" + ".inst 0x4e9da6a9 // smmla v9.4s, v21.16b, v29.16b\n" + ".inst 0x4e83a6a0 // smmla v0.4s, v21.16b, v3.16b\n" + "uzp1 v29.2d, v20.2d, v18.2d\n" + "uzp2 v21.2d, v20.2d, v18.2d\n" + "scvtf v29.4s, v29.4s, #0x4\n" + "uzp1 v18.2d, v9.2d, v0.2d\n" + "uzp2 v16.2d, v9.2d, v0.2d\n" + "scvtf v21.4s, v21.4s, #0x4\n" + "fmla v6.4s, v29.4s, v8.4s\n" + "scvtf v18.4s, v18.4s, #0x4\n" + "scvtf v16.4s, v16.4s, #0x4\n" + "fmla v30.4s, v21.4s, v31.4s\n" + "fmla v24.4s, v18.4s, v15.4s\n" + "fmla v14.4s, v16.4s, v27.4s\n" + "bgt 3b\n" + "mov x20, %x[res_ptr]\n" + "subs x27, x27, #0x4\n" + "add %x[res_ptr], %x[res_ptr], #0x10\n" + "str q2, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q10, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q12, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q28, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q11, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q13, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q22, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q23, [x20, #0x0]\n" + 
"add x20, x20, %x[res_stride]\n" + "str q25, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q5, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q7, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q4, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q6, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q30, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q24, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q14, [x20, #0x0]\n" + "bne 2b\n" + "mov x20, #0x4\n" + "sub x10, x10, #0x10\n" + "cmp x10, #0x10\n" + "mov %x[res_ptr], x26\n" + "madd %x[a_ptr], x20, x9, %x[a_ptr]\n" + "bge 1b\n" + "4:" // Row loop skip + "cbz x10, 9f\n" + "5:" // Row tail: Row loop + "add x24, %x[b_ptr], #0x8\n" + "mov x23, %x[nc]\n" + "add x22, %x[res_ptr], %x[res_stride], LSL #2\n" + "6:" // Row tail: Column loop + "movi v2.16b, #0x0\n" + "movi v10.16b, #0x0\n" + "add x25, %x[a_ptr], #0x8\n" + "mov x21, %x[nb]\n" + "movi v12.16b, #0x0\n" + "movi v28.16b, #0x0\n" + "7:" // Row tail: Block loop + "ldr q6, [x24, #0x0]\n" + "ldr q5, [x24, #0x10]\n" + "movi v17.16b, #0x4\n" + "movi v8.4s, #0x0\n" + "ldr q4, [x25, #0x0]\n" + "ldr q13, [x25, #0x10]\n" + "movi v27.4s, #0x0\n" + "movi v0.4s, #0x0\n" + "ldr q31, [x24, #0x20]\n" + "ldr q14, [x24, #0x30]\n" + "movi v29.4s, #0x0\n" + "movi v22.16b, #0xf0\n" + "ldr q11, [x25, #0x20]\n" + "ldr q23, [x25, #0x30]\n" + "sshl v21.16b, v6.16b, v17.16b\n" + "sshl v16.16b, v5.16b, v17.16b\n" + "ldr q20, [x25, #0x40]\n" + "ldr q26, [x25, #0x50]\n" + "and v6.16b, v6.16b, v22.16b\n" + "and v5.16b, v5.16b, v22.16b\n" + "ldr q25, [x25, #0x60]\n" + "ldr q3, [x25, #0x70]\n" + "sshl v19.16b, v31.16b, v17.16b\n" + "sshl v18.16b, v14.16b, v17.16b\n" + "ldr d17, [x25, #-0x8]\n" + ".inst 0x4e95a488 // smmla v8.4s, v4.16b, v21.16b\n" + ".inst 0x4e90a49b // smmla v27.4s, v4.16b, v16.16b\n" + "and v31.16b, v31.16b, v22.16b\n" + ".inst 0x4e95a5a0 // smmla v0.4s, v13.16b, v21.16b\n" + ".inst 0x4e90a5bd // smmla v29.4s, v13.16b, v16.16b\n" + "and v14.16b, v14.16b, v22.16b\n" + "sub x20, x24, #0x8\n" + "ldr d16, [x20, #0x0]\n" + "subs x21, x21, #0x1\n" + "add x25, x25, #0x88\n" + "fcvtl v17.4s, v17.4h\n" + "add x24, x24, #0x48\n" + ".inst 0x4e93a568 // smmla v8.4s, v11.16b, v19.16b\n" + ".inst 0x4e92a57b // smmla v27.4s, v11.16b, v18.16b\n" + ".inst 0x4e93a6e0 // smmla v0.4s, v23.16b, v19.16b\n" + ".inst 0x4e92a6fd // smmla v29.4s, v23.16b, v18.16b\n" + "fcvtl v16.4s, v16.4h\n" + ".inst 0x4e86a688 // smmla v8.4s, v20.16b, v6.16b\n" + ".inst 0x4e85a69b // smmla v27.4s, v20.16b, v5.16b\n" + "fmul v23.4s, v16.4s, v17.s[0]\n" + "fmul v21.4s, v16.4s, v17.s[1]\n" + "fmul v1.4s, v16.4s, v17.s[2]\n" + "fmul v20.4s, v16.4s, v17.s[3]\n" + ".inst 0x4e86a740 // smmla v0.4s, v26.16b, v6.16b\n" + ".inst 0x4e85a75d // smmla v29.4s, v26.16b, v5.16b\n" + ".inst 0x4e9fa728 // smmla v8.4s, v25.16b, v31.16b\n" + ".inst 0x4e8ea73b // smmla v27.4s, v25.16b, v14.16b\n" + ".inst 0x4e9fa460 // smmla v0.4s, v3.16b, v31.16b\n" + ".inst 0x4e8ea47d // smmla v29.4s, v3.16b, v14.16b\n" + "uzp1 v19.2d, v8.2d, v27.2d\n" + "uzp2 v18.2d, v8.2d, v27.2d\n" + "scvtf v19.4s, v19.4s, #0x4\n" + "uzp1 v17.2d, v0.2d, v29.2d\n" + "uzp2 v16.2d, v0.2d, v29.2d\n" + "scvtf v18.4s, v18.4s, #0x4\n" + "fmla v2.4s, v19.4s, v23.4s\n" + "scvtf v17.4s, v17.4s, #0x4\n" + "scvtf v16.4s, v16.4s, #0x4\n" + "fmla v10.4s, v18.4s, v21.4s\n" + "fmla v12.4s, v17.4s, v1.4s\n" + "fmla v28.4s, v16.4s, v20.4s\n" + "bgt 7b\n" + "mov x20, %x[res_ptr]\n" + "cmp x10, #0x1\n" + "str q2, [x20, #0x0]\n" + "add x20, x20, 
%x[res_stride]\n" + "ble 8f\n" + "cmp x10, #0x2\n" + "str q10, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "ble 8f\n" + "cmp x10, #0x3\n" + "str q12, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "ble 8f\n" + "str q28, [x20, #0x0]\n" + "8:" // Row tail: Accumulator store skip + "subs x23, x23, #0x4\n" + "add %x[res_ptr], %x[res_ptr], #0x10\n" + "bne 6b\n" + "subs x10, x10, #0x4\n" + "add %x[a_ptr], %x[a_ptr], x9\n" + "mov %x[res_ptr], x22\n" + "bgt 5b\n" + "9:" // Row tail: Row loop skip + : [a_ptr] "+&r" (a_ptr), [res_ptr] "+&r" (res_ptr) + : [b_ptr] "r" (b_ptr), [nr] "r" (nr), [nb] "r" (nb), [res_stride] "r" (res_stride), [nc] "r" (nc) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28" + ); +#elif defined(__ARM_NEON) && defined(__aarch64__) + assert((mllm_cpu_has_sve() || mllm_cpu_has_matmul_int8()) && + "__ARM_FEATURE_SVE and __ARM_FEATURE_MATMUL_INT8 not defined, use the Q4_0_4_4 quantization format for optimal " + "performance"); +#else + float sumf[4][4]; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; + } + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); + sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + + (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4; + } + sumf[m][j] += sumi * MLLM_FP16_TO_FP32(b_ptr[l].d[j]) * MLLM_FP16_TO_FP32(a_ptr[l].d[m]); + } + } + } + } + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; + } + } + } +#endif +} + +void mllm_gemm_q4_0_8x8_q8_0(int n, float * __restrict s, size_t bs, const void * __restrict vx, const void * __restrict vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 8; + const int blocklen = 8; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + (void)s; + (void)bs; + (void)vx; + (void)vy; + (void)nr; + (void)nc; + (void)nb; + (void)ncols_interleaved; + (void)blocklen; + +#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8) + if (svcntw() == 8) { + const void * b_ptr = vx; + const void * a_ptr = vy; + float * res_ptr = s; + size_t res_stride = bs * sizeof(float); + + __asm__ __volatile__( + "mov x20, #0x4\n" + "mov x13, %x[nr]\n" + "mov z28.s, #-0x4\n" + "mov x12, #0x88\n" + "ptrue p1.b\n" + "whilelt p0.s, XZR, x20\n" + "cmp x13, #0x10\n" + "mul x12, %x[nb], x12\n" + "blt 4f\n" + "1:" // Row loop + "add x11, %x[b_ptr], #0x10\n" + "mov x10, %x[nc]\n" + "add x9, %x[res_ptr], %x[res_stride], LSL #4\n" + "2:" // Column loop + "add x28, %x[a_ptr], #0x8\n" + "mov z24.b, #0x0\n" + "mov z15.b, 
#0x0\n" + "mov x27, %x[nb]\n" + "add x26, x28, x12\n" + "mov z12.b, #0x0\n" + "mov z0.b, #0x0\n" + "add x25, x26, x12\n" + "mov z13.b, #0x0\n" + "mov z1.b, #0x0\n" + "add x24, x25, x12\n" + "mov z20.b, #0x0\n" + "mov z25.b, #0x0\n" + "mov z11.b, #0x0\n" + "mov z16.b, #0x0\n" + "mov z19.b, #0x0\n" + "mov z26.b, #0x0\n" + "mov z8.b, #0x0\n" + "mov z29.b, #0x0\n" + "mov z27.b, #0x0\n" + "mov z10.b, #0x0\n" + "3:" // Block loop + "ld1b { z30.b }, p1/Z, [x11]\n" + "ld1b { z21.b }, p1/Z, [x11, #1, MUL VL]\n" + "mov z18.s, #0x0\n" + "mov z7.s, #0x0\n" + "ld1rqb { z3.b }, p1/Z, [x28]\n" + "ld1rqb { z5.b }, p1/Z, [x28, #16]\n" + "mov z9.s, #0x0\n" + "mov z22.s, #0x0\n" + "ld1b { z4.b }, p1/Z, [x11, #2, MUL VL]\n" + "ld1b { z17.b }, p1/Z, [x11, #3, MUL VL]\n" + "sub x20, x11, #0x10\n" + "sub x23, x28, #0x8\n" + "lsl z31.b, z30.b, #0x4\n" + "lsl z6.b, z21.b, #0x4\n" + "ld1h { z23.s }, p1/Z, [x20]\n" + "sub x22, x26, #0x8\n" + "and z30.b, z30.b, #0xf0\n" + "and z21.b, z21.b, #0xf0\n" + "sub x21, x25, #0x8\n" + "sub x20, x24, #0x8\n" + "lsl z14.b, z4.b, #0x4\n" + "lsl z2.b, z17.b, #0x4\n" + "subs x27, x27, #0x1\n" + "add x11, x11, #0x90\n" + ".inst 0x451f9872 // smmla z18.s, z3.b, z31.b\n" + ".inst 0x45069867 // smmla z7.s, z3.b, z6.b\n" + "ld1rqb { z3.b }, p1/Z, [x28, #32]\n" + "and z4.b, z4.b, #0xf0\n" + ".inst 0x451f98a9 // smmla z9.s, z5.b, z31.b\n" + ".inst 0x450698b6 // smmla z22.s, z5.b, z6.b\n" + "ld1rqb { z5.b }, p1/Z, [x28, #48]\n" + "and z17.b, z17.b, #0xf0\n" + "fcvt z23.s, p1/m, z23.h\n" + ".inst 0x450e9872 // smmla z18.s, z3.b, z14.b\n" + ".inst 0x45029867 // smmla z7.s, z3.b, z2.b\n" + "ld1rqb { z3.b }, p1/Z, [x28, #64]\n" + ".inst 0x450e98a9 // smmla z9.s, z5.b, z14.b\n" + ".inst 0x450298b6 // smmla z22.s, z5.b, z2.b\n" + "ld1rqb { z5.b }, p1/Z, [x28, #80]\n" + "fscale z23.s, p1/m, z23.s, z28.s\n" + ".inst 0x451e9872 // smmla z18.s, z3.b, z30.b\n" + ".inst 0x45159867 // smmla z7.s, z3.b, z21.b\n" + "ld1rqb { z3.b }, p1/Z, [x28, #96]\n" + ".inst 0x451e98a9 // smmla z9.s, z5.b, z30.b\n" + ".inst 0x451598b6 // smmla z22.s, z5.b, z21.b\n" + "ld1rqb { z5.b }, p1/Z, [x28, #112]\n" + "add x28, x28, #0x88\n" + ".inst 0x45049872 // smmla z18.s, z3.b, z4.b\n" + ".inst 0x45119867 // smmla z7.s, z3.b, z17.b\n" + "ld1h { z3.s }, p0/Z, [x23]\n" + ".inst 0x450498a9 // smmla z9.s, z5.b, z4.b\n" + ".inst 0x451198b6 // smmla z22.s, z5.b, z17.b\n" + "fcvt z3.s, p1/m, z3.h\n" + "uzp1 z5.d, z18.d, z7.d\n" + "uzp2 z18.d, z18.d, z7.d\n" + "mov z3.q, z3.q[0]\n" + "uzp1 z7.d, z9.d, z22.d\n" + "uzp2 z22.d, z9.d, z22.d\n" + "fmul z9.s, z23.s, z3.s[0]\n" + "scvtf z5.s, p1/m, z5.s\n" + "scvtf z18.s, p1/m, z18.s\n" + "scvtf z7.s, p1/m, z7.s\n" + "scvtf z22.s, p1/m, z22.s\n" + "fmla z24.s, p1/M, z5.s, z9.s\n" + "ld1rqb { z5.b }, p1/Z, [x26]\n" + "fmul z9.s, z23.s, z3.s[1]\n" + "fmla z15.s, p1/M, z18.s, z9.s\n" + "ld1rqb { z18.b }, p1/Z, [x26, #16]\n" + "fmul z9.s, z23.s, z3.s[2]\n" + "fmul z3.s, z23.s, z3.s[3]\n" + "fmla z12.s, p1/M, z7.s, z9.s\n" + "mov z9.s, #0x0\n" + "ld1h { z7.s }, p0/Z, [x22]\n" + ".inst 0x451f98a9 // smmla z9.s, z5.b, z31.b\n" + "fmla z0.s, p1/M, z22.s, z3.s\n" + "mov z22.s, #0x0\n" + "ld1h { z3.s }, p0/Z, [x21]\n" + ".inst 0x450698b6 // smmla z22.s, z5.b, z6.b\n" + "ld1rqb { z5.b }, p1/Z, [x26, #32]\n" + "fcvt z7.s, p1/m, z7.h\n" + "fcvt z3.s, p1/m, z3.h\n" + ".inst 0x450e98a9 // smmla z9.s, z5.b, z14.b\n" + ".inst 0x450298b6 // smmla z22.s, z5.b, z2.b\n" + "ld1rqb { z5.b }, p1/Z, [x26, #64]\n" + "mov z7.q, z7.q[0]\n" + "mov z3.q, z3.q[0]\n" + ".inst 0x451e98a9 // smmla z9.s, z5.b, z30.b\n" + 
".inst 0x451598b6 // smmla z22.s, z5.b, z21.b\n" + "ld1rqb { z5.b }, p1/Z, [x26, #96]\n" + ".inst 0x450498a9 // smmla z9.s, z5.b, z4.b\n" + ".inst 0x451198b6 // smmla z22.s, z5.b, z17.b\n" + "uzp1 z5.d, z9.d, z22.d\n" + "scvtf z5.s, p1/m, z5.s\n" + "uzp2 z22.d, z9.d, z22.d\n" + "fmul z9.s, z23.s, z7.s[0]\n" + "scvtf z22.s, p1/m, z22.s\n" + "fmla z13.s, p1/M, z5.s, z9.s\n" + "ld1rqb { z9.b }, p1/Z, [x25]\n" + "fmul z5.s, z23.s, z7.s[1]\n" + "fmla z1.s, p1/M, z22.s, z5.s\n" + "mov z5.s, #0x0\n" + "mov z22.s, #0x0\n" + ".inst 0x451f9a45 // smmla z5.s, z18.b, z31.b\n" + ".inst 0x45069a56 // smmla z22.s, z18.b, z6.b\n" + "ld1rqb { z18.b }, p1/Z, [x26, #48]\n" + ".inst 0x450e9a45 // smmla z5.s, z18.b, z14.b\n" + ".inst 0x45029a56 // smmla z22.s, z18.b, z2.b\n" + "ld1rqb { z18.b }, p1/Z, [x26, #80]\n" + ".inst 0x451e9a45 // smmla z5.s, z18.b, z30.b\n" + ".inst 0x45159a56 // smmla z22.s, z18.b, z21.b\n" + "ld1rqb { z18.b }, p1/Z, [x26, #112]\n" + "add x26, x26, #0x88\n" + ".inst 0x45049a45 // smmla z5.s, z18.b, z4.b\n" + ".inst 0x45119a56 // smmla z22.s, z18.b, z17.b\n" + "uzp1 z18.d, z5.d, z22.d\n" + "scvtf z18.s, p1/m, z18.s\n" + "uzp2 z22.d, z5.d, z22.d\n" + "fmul z5.s, z23.s, z7.s[2]\n" + "fmul z7.s, z23.s, z7.s[3]\n" + "scvtf z22.s, p1/m, z22.s\n" + "fmla z20.s, p1/M, z18.s, z5.s\n" + "ld1rqb { z18.b }, p1/Z, [x25, #16]\n" + "ld1h { z5.s }, p0/Z, [x20]\n" + "fcvt z5.s, p1/m, z5.h\n" + "fmla z25.s, p1/M, z22.s, z7.s\n" + "mov z22.s, #0x0\n" + "mov z7.s, #0x0\n" + ".inst 0x451f9936 // smmla z22.s, z9.b, z31.b\n" + ".inst 0x45069927 // smmla z7.s, z9.b, z6.b\n" + "ld1rqb { z9.b }, p1/Z, [x25, #32]\n" + "mov z5.q, z5.q[0]\n" + ".inst 0x450e9936 // smmla z22.s, z9.b, z14.b\n" + ".inst 0x45029927 // smmla z7.s, z9.b, z2.b\n" + "ld1rqb { z9.b }, p1/Z, [x25, #64]\n" + ".inst 0x451e9936 // smmla z22.s, z9.b, z30.b\n" + ".inst 0x45159927 // smmla z7.s, z9.b, z21.b\n" + "ld1rqb { z9.b }, p1/Z, [x25, #96]\n" + ".inst 0x45049936 // smmla z22.s, z9.b, z4.b\n" + ".inst 0x45119927 // smmla z7.s, z9.b, z17.b\n" + "uzp1 z9.d, z22.d, z7.d\n" + "scvtf z9.s, p1/m, z9.s\n" + "uzp2 z22.d, z22.d, z7.d\n" + "fmul z7.s, z23.s, z3.s[0]\n" + "scvtf z22.s, p1/m, z22.s\n" + "fmla z11.s, p1/M, z9.s, z7.s\n" + "ld1rqb { z9.b }, p1/Z, [x24]\n" + "fmul z7.s, z23.s, z3.s[1]\n" + "fmla z16.s, p1/M, z22.s, z7.s\n" + "mov z22.s, #0x0\n" + "mov z7.s, #0x0\n" + ".inst 0x451f9a56 // smmla z22.s, z18.b, z31.b\n" + ".inst 0x45069a47 // smmla z7.s, z18.b, z6.b\n" + "ld1rqb { z18.b }, p1/Z, [x25, #48]\n" + ".inst 0x450e9a56 // smmla z22.s, z18.b, z14.b\n" + ".inst 0x45029a47 // smmla z7.s, z18.b, z2.b\n" + "ld1rqb { z18.b }, p1/Z, [x25, #80]\n" + ".inst 0x451e9a56 // smmla z22.s, z18.b, z30.b\n" + ".inst 0x45159a47 // smmla z7.s, z18.b, z21.b\n" + "ld1rqb { z18.b }, p1/Z, [x25, #112]\n" + "add x25, x25, #0x88\n" + ".inst 0x45049a56 // smmla z22.s, z18.b, z4.b\n" + ".inst 0x45119a47 // smmla z7.s, z18.b, z17.b\n" + "uzp1 z18.d, z22.d, z7.d\n" + "scvtf z18.s, p1/m, z18.s\n" + "uzp2 z7.d, z22.d, z7.d\n" + "fmul z22.s, z23.s, z3.s[2]\n" + "fmul z3.s, z23.s, z3.s[3]\n" + "scvtf z7.s, p1/m, z7.s\n" + "fmla z19.s, p1/M, z18.s, z22.s\n" + "ld1rqb { z18.b }, p1/Z, [x24, #16]\n" + "fmul z22.s, z23.s, z5.s[0]\n" + "fmla z26.s, p1/M, z7.s, z3.s\n" + "mov z3.s, #0x0\n" + "mov z7.s, #0x0\n" + ".inst 0x451f9923 // smmla z3.s, z9.b, z31.b\n" + ".inst 0x45069927 // smmla z7.s, z9.b, z6.b\n" + "ld1rqb { z9.b }, p1/Z, [x24, #32]\n" + ".inst 0x450e9923 // smmla z3.s, z9.b, z14.b\n" + ".inst 0x45029927 // smmla z7.s, z9.b, z2.b\n" + "mov z9.s, #0x0\n" + 
".inst 0x451f9a49 // smmla z9.s, z18.b, z31.b\n" + "mov z31.s, #0x0\n" + ".inst 0x45069a5f // smmla z31.s, z18.b, z6.b\n" + "ld1rqb { z6.b }, p1/Z, [x24, #48]\n" + "ld1rqb { z18.b }, p1/Z, [x24, #64]\n" + ".inst 0x450e98c9 // smmla z9.s, z6.b, z14.b\n" + "fmul z14.s, z23.s, z5.s[1]\n" + ".inst 0x450298df // smmla z31.s, z6.b, z2.b\n" + "ld1rqb { z6.b }, p1/Z, [x24, #80]\n" + "fmul z2.s, z23.s, z5.s[2]\n" + "fmul z23.s, z23.s, z5.s[3]\n" + ".inst 0x451e9a43 // smmla z3.s, z18.b, z30.b\n" + ".inst 0x45159a47 // smmla z7.s, z18.b, z21.b\n" + "ld1rqb { z5.b }, p1/Z, [x24, #96]\n" + ".inst 0x451e98c9 // smmla z9.s, z6.b, z30.b\n" + ".inst 0x451598df // smmla z31.s, z6.b, z21.b\n" + "ld1rqb { z18.b }, p1/Z, [x24, #112]\n" + "add x24, x24, #0x88\n" + ".inst 0x450498a3 // smmla z3.s, z5.b, z4.b\n" + ".inst 0x451198a7 // smmla z7.s, z5.b, z17.b\n" + ".inst 0x45049a49 // smmla z9.s, z18.b, z4.b\n" + ".inst 0x45119a5f // smmla z31.s, z18.b, z17.b\n" + "uzp1 z18.d, z3.d, z7.d\n" + "uzp2 z5.d, z3.d, z7.d\n" + "scvtf z18.s, p1/m, z18.s\n" + "uzp1 z6.d, z9.d, z31.d\n" + "uzp2 z9.d, z9.d, z31.d\n" + "scvtf z5.s, p1/m, z5.s\n" + "fmla z8.s, p1/M, z18.s, z22.s\n" + "scvtf z6.s, p1/m, z6.s\n" + "scvtf z9.s, p1/m, z9.s\n" + "fmla z29.s, p1/M, z5.s, z14.s\n" + "fmla z27.s, p1/M, z6.s, z2.s\n" + "fmla z10.s, p1/M, z9.s, z23.s\n" + "bgt 3b\n" + "mov x20, %x[res_ptr]\n" + "subs x10, x10, #0x8\n" + "add %x[res_ptr], %x[res_ptr], #0x20\n" + "st1w { z24.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z15.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z12.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z0.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z13.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z1.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z20.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z25.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z11.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z16.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z19.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z26.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z8.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z29.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z27.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z10.s }, p1, [x20]\n" + "bne 2b\n" + "mov x20, #0x4\n" + "sub x13, x13, #0x10\n" + "cmp x13, #0x10\n" + "mov %x[res_ptr], x9\n" + "madd %x[a_ptr], x20, x12, %x[a_ptr]\n" + "bge 1b\n" + "4:" // Row loop skip + "cbz x13, 9f\n" + "5:" // Row tail: Row loop + "add x25, %x[b_ptr], #0x10\n" + "mov x24, %x[nc]\n" + "add x23, %x[res_ptr], %x[res_stride], LSL #2\n" + "6:" // Row tail: Column loop + "mov z24.b, #0x0\n" + "mov z15.b, #0x0\n" + "add x28, %x[a_ptr], #0x8\n" + "mov x22, %x[nb]\n" + "mov z12.b, #0x0\n" + "mov z0.b, #0x0\n" + "7:" // Row tail: Block loop + "ld1b { z3.b }, p1/Z, [x25]\n" + "ld1b { z6.b }, p1/Z, [x25, #1, MUL VL]\n" + "mov z2.s, #0x0\n" + "mov z25.s, #0x0\n" + "ld1rqb { z26.b }, p1/Z, [x28]\n" + "ld1rqb { z21.b }, p1/Z, [x28, #16]\n" + "mov z27.s, #0x0\n" + "mov z19.s, #0x0\n" + "ld1b { z29.b }, p1/Z, [x25, #2, MUL VL]\n" + "ld1b { z16.b }, p1/Z, [x25, #3, MUL VL]\n" + "sub x21, x25, #0x10\n" + "sub x20, x28, #0x8\n" + "lsl z20.b, z3.b, #0x4\n" + "lsl z4.b, z6.b, #0x4\n" + "ld1rqb { z10.b }, p1/Z, [x28, #32]\n" + "ld1rqb { z23.b }, p1/Z, [x28, #48]\n" + "and z3.b, z3.b, 
#0xf0\n" + "and z6.b, z6.b, #0xf0\n" + "ld1rqb { z11.b }, p1/Z, [x28, #64]\n" + "ld1rqb { z7.b }, p1/Z, [x28, #80]\n" + "lsl z8.b, z29.b, #0x4\n" + "lsl z14.b, z16.b, #0x4\n" + "ld1rqb { z18.b }, p1/Z, [x28, #96]\n" + "ld1rqb { z30.b }, p1/Z, [x28, #112]\n" + ".inst 0x45149b42 // smmla z2.s, z26.b, z20.b\n" + ".inst 0x45049b59 // smmla z25.s, z26.b, z4.b\n" + "and z29.b, z29.b, #0xf0\n" + "ld1h { z17.s }, p1/Z, [x21]\n" + ".inst 0x45149abb // smmla z27.s, z21.b, z20.b\n" + ".inst 0x45049ab3 // smmla z19.s, z21.b, z4.b\n" + "and z16.b, z16.b, #0xf0\n" + "ld1h { z4.s }, p0/Z, [x20]\n" + "subs x22, x22, #0x1\n" + "add x28, x28, #0x88\n" + "fcvt z17.s, p1/m, z17.h\n" + "add x25, x25, #0x90\n" + ".inst 0x45089942 // smmla z2.s, z10.b, z8.b\n" + ".inst 0x450e9959 // smmla z25.s, z10.b, z14.b\n" + "fcvt z4.s, p1/m, z4.h\n" + ".inst 0x45089afb // smmla z27.s, z23.b, z8.b\n" + ".inst 0x450e9af3 // smmla z19.s, z23.b, z14.b\n" + "fscale z17.s, p1/m, z17.s, z28.s\n" + "mov z4.q, z4.q[0]\n" + ".inst 0x45039962 // smmla z2.s, z11.b, z3.b\n" + ".inst 0x45069979 // smmla z25.s, z11.b, z6.b\n" + "fmul z23.s, z17.s, z4.s[0]\n" + "fmul z9.s, z17.s, z4.s[1]\n" + "fmul z21.s, z17.s, z4.s[2]\n" + "fmul z4.s, z17.s, z4.s[3]\n" + ".inst 0x450398fb // smmla z27.s, z7.b, z3.b\n" + ".inst 0x450698f3 // smmla z19.s, z7.b, z6.b\n" + ".inst 0x451d9a42 // smmla z2.s, z18.b, z29.b\n" + ".inst 0x45109a59 // smmla z25.s, z18.b, z16.b\n" + ".inst 0x451d9bdb // smmla z27.s, z30.b, z29.b\n" + ".inst 0x45109bd3 // smmla z19.s, z30.b, z16.b\n" + "uzp1 z31.d, z2.d, z25.d\n" + "uzp2 z13.d, z2.d, z25.d\n" + "scvtf z31.s, p1/m, z31.s\n" + "uzp1 z17.d, z27.d, z19.d\n" + "uzp2 z18.d, z27.d, z19.d\n" + "scvtf z13.s, p1/m, z13.s\n" + "fmla z24.s, p1/M, z31.s, z23.s\n" + "scvtf z17.s, p1/m, z17.s\n" + "scvtf z18.s, p1/m, z18.s\n" + "fmla z15.s, p1/M, z13.s, z9.s\n" + "fmla z12.s, p1/M, z17.s, z21.s\n" + "fmla z0.s, p1/M, z18.s, z4.s\n" + "bgt 7b\n" + "mov x20, %x[res_ptr]\n" + "cmp x13, #0x1\n" + "st1w { z24.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "ble 8f\n" + "cmp x13, #0x2\n" + "st1w { z15.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "ble 8f\n" + "cmp x13, #0x3\n" + "st1w { z12.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "ble 8f\n" + "st1w { z0.s }, p1, [x20]\n" + "8:" // Row tail: Accumulator store skip + "subs x24, x24, #0x8\n" + "add %x[res_ptr], %x[res_ptr], #0x20\n" + "bne 6b\n" + "subs x13, x13, #0x4\n" + "add %x[a_ptr], %x[a_ptr], x12\n" + "mov %x[res_ptr], x23\n" + "bgt 5b\n" + "9:" // Row tail: Row loop skip + : [a_ptr] "+&r" (a_ptr), [res_ptr] "+&r" (res_ptr) + : [b_ptr] "r" (b_ptr), [nr] "r" (nr), [nb] "r" (nb), [res_stride] "r" (res_stride), [nc] "r" (nc) + : "cc", "memory", "p0", "p1", "x9", "x10", "x11", "x12", "x13", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" + ); + return; + } + else if (mllm_cpu_has_neon() && mllm_cpu_has_matmul_int8()) { + assert((mllm_cpu_has_sve() && (svcntw() == 8)) && + "__ARM_FEATURE_SVE for vector size of 256-bits not defined, use the Q4_0_4_8 quantization format for optimal " + "performance"); + } + else if (mllm_cpu_has_neon()) { + assert(((mllm_cpu_has_sve() && (svcntw() == 8)) || mllm_cpu_has_matmul_int8()) && + "__ARM_FEATURE_SVE for vector size of 256-bits and __ARM_FEATURE_MATMUL_INT8 not defined, use the Q4_0_4_4 " + 
"quantization format for optimal performance"); + } +#endif +#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) + assert(mllm_cpu_has_sve() && + "__ARM_FEATURE_SVE not defined, use the Q4_0_4_8 quantization format for optimal performance"); +#elif defined(__ARM_NEON) && defined(__aarch64__) + assert((mllm_cpu_has_sve() || mllm_cpu_has_matmul_int8()) && + "__ARM_FEATURE_SVE and __ARM_FEATURE_MATMUL_INT8 not defined, use the Q4_0_4_4 quantization format for optimal " + "performance"); +#else + float sumf[4][8]; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; + } + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); + sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + + (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4; + } + sumf[m][j] += sumi * MLLM_FP16_TO_FP32(b_ptr[l].d[j]) * MLLM_FP16_TO_FP32(a_ptr[l].d[m]); + } + } + } + } + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; + } + } + } +#endif +} + + +void quantize_row_q4_0_4x4(const float * __restrict x, void * __restrict y, int k){ + assert(k%QK4_0 == 0); + std::cout<<"Quantize 4x4:"< +#include "SGEMM.hpp" #define ASSERT(x) \ do { \ @@ -169,6 +171,199 @@ ErrorCode mat_mul_sparse(Tensor *x, Tensor *W, Tensor *dst, int thread_count){ return MLLM_NO_ERROR; } +ErrorCode mat_mul(Tensor *src0, Tensor *src1, Tensor *dst, bool support_bias, Tensor *bias, bool transpose0, bool transpose1, int thread_count) { + // src1 = W src0 = x + // transpose0=false transpose1=true + const int M = transpose0 ? src0->dimension() : src0->sequence(); + const int K = transpose0 ? src0->sequence() : src0->dimension(); + const int N = transpose1 ? src1->sequence() : src1->dimension(); + + auto src0_dtype = src0->dtype(); + auto src1_dtype = src1->dtype(); + auto vec_dot_type = type_traits[src1_dtype].vec_dot_type; + auto vec_dot = type_traits[src1_dtype].vec_dot; + auto x_to_vec_dot_type = type_traits[vec_dot_type].from_float; + auto from_float_to_mat = type_traits[vec_dot_type].from_float_to_mat; + mllm_gemv_func const gemv = type_traits[src1_dtype].gemv; + mllm_gemm_func const gemm = type_traits[src1_dtype].gemm; + auto blck_size_interleave = type_traits[src1_dtype].blck_size_interleave; + + auto src1_type_size = type_size(src1_dtype); + auto src1_blck_size = blck_size(src1_dtype); + auto src0_type_size = type_size(src0->dtype()); + auto src0_blck_size = blck_size(src0->dtype()); +#ifdef LLAMAFILE_SGEMM + if (check_llamafile_sgemm(N, M, K/blck_size(src0->dtype()),src1->dtype(),src0->dtype(),dst->dtype())&&!support_bias){ + const int ld_src1 = src1->sequence_skip_dim(); + const int ld_src0 = src0->sequence_skip_dim(); + const int ld_dst = dst->sequence_skip_dim(); + int is_0 = (src1->batch() == 1 && src1->head() == 1&&src1->batch()!=src0->batch()) ? 
0 : 1;
+#pragma omp parallel for collapse(3) num_threads(thread_count)
+        for (int64_t b = 0; b < dst->batch(); b++){
+            for (int64_t h = 0; h < dst->head(); h++){
+                for (int id = 0; id < thread_count; id++){
+                    llamafile_sgemm(N, M, K/blck_size(src0->dtype()),
+                                    (char *)src1->rawHostPtr() + src1->offset(b*is_0, h*is_0, 0, 0) * src1_type_size / src1_blck_size,
+                                    ld_src1 / src1_blck_size,
+                                    (char *)src0->rawHostPtr() + src0->offset(b, h, 0, 0) * src0_type_size / src0_blck_size,
+                                    ld_src0/ src0_blck_size,
+                                    (char *)dst->rawHostPtr() + dst->offset(b, h, 0, 0) * type_size(dst->dtype()) / blck_size(dst->dtype()),
+                                    ld_dst/blck_size(dst->dtype()),
+                                    id, thread_count,
+                                    src1->dtype(),
+                                    src0->dtype(),
+                                    dst->dtype());
+                }
+            }
+        }
+        return MLLM_NO_ERROR;
+    }
+#endif
+    auto not_vec_dot_type = src0_dtype != vec_dot_type;
+    std::unique_ptr<Tensor> to; // later this tensor will be freed by ~Tensor
+    if(not_vec_dot_type){
+        // convert x.dtype to vec_dot_type
+        // so that we can use vec_dot to calculate dot product
+        ASSERT(src0_dtype == MLLM_TYPE_F32); // x should be fp32
+        to = std::make_unique<Tensor>(src0->shape());
+        to->setBackend(src0->backend());
+        to->setDtype(vec_dot_type);
+        to->alloc();
+//        void *row_src = src0->rawHostPtr();
+//        void *row_dst = to->rawHostPtr();
+//        auto row_size_src = row_size(src0_dtype, src0->dimension());
+//        auto row_size_dst = row_size(vec_dot_type, to->dimension());
+//        auto n_row = src0->batch() * src0->head() * src0->sequence();
+//        auto n_ele = src0->dimension();
+// #pragma omp parallel for num_threads(thread_count)
+//        for(int i = 0;i < n_row;i++){ // copy row by row
+//            auto row1 = (char *)row_src + i * row_size_src;
+//            auto row2 = (char *)row_dst + i * row_size_dst;
+//            x_to_vec_dot_type(reinterpret_cast<float *>(row1), row2, n_ele);
+//        }
+        int64_t i_processed = 0;
+        if (from_float_to_mat && gemv && dst->masterTensor()==nullptr){
+            for (int b = 0; b < src0->batch(); b++) {
+                for (int h = 0; h < src0->head(); h++) {
+#pragma omp parallel for collapse(1) num_threads(thread_count)
+                    for (int64_t s = 0; s < src0->sequence() - src0->sequence() % 4; s += 4) {
+                        from_float_to_mat(src0->hostPtr<float>() + src0->offset(b, h, s, 0),
+                                          (char *)to->rawHostPtr() + to->offset(b, h, s, 0) * type_size(to->dtype()) / blck_size(to->dtype()),
+                                          4, src0->dimension(), blck_size_interleave);
+                    }
+                    i_processed = src0->sequence() - src0->sequence() % 4;
+                }
+            }
+        }
+#pragma omp parallel for collapse(3) num_threads(thread_count)
+        for (int b = 0; b < src0->batch(); b++) {
+            for (int h = 0; h < src0->head(); h++) {
+                for (int s = i_processed; s < src0->sequence(); s++) {
+                    x_to_vec_dot_type(src0->hostPtr<float>() + src0->offset(b, h, s, 0),
+                                      (char *)to->rawHostPtr() + to->offset(b, h, s, 0) * type_size(to->dtype()) / blck_size(to->dtype()),
+                                      src0->dimension());
+                }
+            }
+        }
+        src0 = to.get();
+        src0_dtype = src0->dtype();
+        src0_type_size = type_size(src0->dtype());
+        src0_blck_size = blck_size(src0->dtype());
+    }
+
+#ifdef LLAMAFILE_SGEMM
+    if (check_llamafile_sgemm(N, M, K/blck_size(src1->dtype()),src1->dtype(),src0->dtype(),dst->dtype())&&!support_bias){
+        const int ld_src1 = src1->sequence_skip_dim();
+        const int ld_src0 = src0->sequence_skip_dim();
+        const int ld_dst = dst->sequence_skip_dim();
+#pragma omp parallel for collapse(3) num_threads(thread_count)
+        for (int64_t b = 0; b < dst->batch(); b++){
+            for (int64_t h = 0; h < dst->head(); h++){
+                for (int id = 0; id < thread_count; id++){
+                    llamafile_sgemm(N, M, K/blck_size(src1->dtype()),
+                                    (char *)src1->rawHostPtr() + src1->offset(b, h, 0, 0) * src1_type_size / src1_blck_size,
+                                    ld_src1 / src1_blck_size,
+                                    (char *)src0->rawHostPtr() + src0->offset(b, h, 0, 0) * src0_type_size / src0_blck_size,
+                                    ld_src0/ src0_blck_size,
+                                    (char *)dst->rawHostPtr() + dst->offset(b, h, 0, 0) * type_size(dst->dtype()) / blck_size(dst->dtype()),
+                                    ld_dst/blck_size(dst->dtype()),
+                                    id, thread_count,
+                                    src1->dtype(),
+                                    src0->dtype(),
+                                    dst->dtype());
+                }
+            }
+        }
+        return MLLM_NO_ERROR;
+    }
+#endif
+
+    if(gemv&&!support_bias){
+        int nth=thread_count;
+#pragma omp parallel for collapse(1) num_threads(thread_count)
+        for (int ith = 0; ith < nth; ith++){
+            int64_t i_processed = 0;
+            int64_t seq_start = (ith * N) / nth;
+            int64_t seq_end = ((ith + 1) * N) / nth;
+            if (gemm && (M > 3) && dst->masterTensor()==nullptr) {
+                gemm(K, dst->hostPtr<float>() + dst->offset(0, 0, 0, seq_start),
+                     N, (char *)src1->rawHostPtr()+ src1->offset(0, 0, seq_start, 0) * src1_type_size / src1_blck_size,
+                     (char *)src0->rawHostPtr(), M - M % 4, N/nth);
+                i_processed = M - M % 4;
+            }
+            for (int iter = i_processed; iter < M; iter++) { //M-M%4
+                gemv(K, dst->hostPtr<float>() + dst->offset(0, 0, iter, seq_start),
+                     N, (char *)src1->rawHostPtr()+ src1->offset(0, 0, seq_start, 0) * src1_type_size / src1_blck_size,
+                     (char *)src0->rawHostPtr() + src0->offset(0, 0, iter, 0) * src0_type_size / src0_blck_size,
+                     1, N/nth);
+            }
+        }
+        return MLLM_NO_ERROR;
+    }
+
+    Tensor *src0_cal = src0;
+    Tensor *src1_cal = src1;
+    const int64_t blck_0 = 16;
+    int is_0 = (src1->batch() == 1 && src1->head() == 1&&src1->batch()!=src0->batch()) ? 0 : 1;
+#pragma omp parallel for collapse(4) num_threads(thread_count)
+    for (int b = 0; b < src0->batch(); b++) {
+        for (int h = 0; h < src0->head(); h++) {
+            for (int m = 0; m < M; m++) {
+                for (int block = 0; block < N / blck_0 + 1; block++) {
+                    for (int n = block * blck_0; n < (block + 1) * blck_0 & n < N; n++) {
+                        int s_1, d_1;
+                        int s_0, d_0;
+                        if (!transpose0 && transpose1) {
+                            s_1 = n; d_1 = 0; s_0 = m; d_0 = 0;
+                        } else if (!transpose0 && !transpose1) {
+                            s_1 = 0; d_1 = n; s_0 = m; d_0 = 0;
+                        } else {
+                            s_1 = 0; d_1 = n; s_0 = 0; d_0 = m;
+                        }
+                        float tmp = 0;
+                        vec_dot(K, &tmp,
+                                (char *)src1_cal->rawHostPtr() + src1_cal->offset(b*is_0, h*is_0, s_1, d_1) * src1_type_size / src1_blck_size,
+                                (char *)src0_cal->rawHostPtr() + src0_cal->offset(b, h, s_0, d_0) * src0_type_size / src0_blck_size);
+                        if(dst->dtypeAt(b,h,m,n) == MLLM_TYPE_F32) {
+                            dst->setDataAt<float>(b, h, m, n, tmp);
+                            if (support_bias) {
+                                *dst->ptrAt<float>(b, h, m, n) += bias->dataAt<float>(0, 0, 0, n);
+                            }
+                        }else if(dst->dtypeAt(b,h,m,n) == MLLM_TYPE_F16) {
+                            if (support_bias) {
+                                *dst->ptrAt<mllm_fp16_t>(b, h, m, n) = MLLM_FP32_TO_FP16(tmp + bias->dataAt<float>(0, 0, 0, n));
+                            } else {
+                                *dst->ptrAt<mllm_fp16_t>(b, h, m, n) = MLLM_FP32_TO_FP16(tmp);
+                            }
+                        }else{std::cout<<"Not support type [Matmul]"<dimension() : src0->sequence(); const int K = transpose0 ? src0->sequence() : src0->dimension();
@@ -451,4 +646,126 @@
     }
     return MLLM_NO_ERROR;
 }
+*/
+
+
+ErrorCode mat_mul_elastic(Tensor *src0, Tensor *src1, Tensor *dst, bool support_bias, Tensor *bias,
+                          int activate_input_dim, int activate_output_dim,
+                          bool transpose0, bool transpose1, int thread_count) {
+    // src1 = W src0 = x
+    // transpose0=false transpose1=true
+    const int M = transpose0 ? src0->dimension() : src0->sequence();
+    const int K = transpose0 ? src0->sequence() : src0->dimension();
+    const int N = transpose1 ? src1->sequence() : src1->dimension();
+
+    auto src0_dtype = src0->dtype();
+    auto src1_dtype = src1->dtype();
+    auto vec_dot_type = type_traits[src1_dtype].vec_dot_type;
+    auto vec_dot = type_traits[src1_dtype].vec_dot;
+    auto x_to_vec_dot_type = type_traits[vec_dot_type].from_float;
+
+    auto src1_type_size = type_size(src1_dtype);
+    auto src1_blck_size = blck_size(src1_dtype);
+    auto src0_type_size = type_size(src0->dtype());
+    auto src0_blck_size = blck_size(src0->dtype());
+
+    int use_N = (activate_output_dim == -1) ? N : activate_output_dim;
+    int use_K = (activate_input_dim == -1) ? K : activate_input_dim;
+
+    if (check_llamafile_sgemm(use_N, M, use_K/blck_size(src0->dtype()),src1->dtype(),src0->dtype(),dst->dtype())){
+        const int ld_src1 = src1->sequence_skip_dim();
+        const int ld_src0 = src0->sequence_skip_dim();
+        const int ld_dst = dst->sequence_skip_dim();
+        int is_0 = (src1->batch() == 1 && src1->head() == 1) ? 0 : 1;
+#pragma omp parallel for collapse(3) num_threads(thread_count)
+        for (int64_t b = 0; b < dst->batch(); b++){
+            for (int64_t h = 0; h < dst->head(); h++){
+                for (int id = 0; id < thread_count; id++){
+                    llamafile_sgemm(use_N, M, use_K/blck_size(src0->dtype()),
+                                    (char *)src1->rawHostPtr() + src1->offset(b*is_0, h*is_0, 0, 0) * src1_type_size / src1_blck_size,
+                                    ld_src1,
+                                    (char *)src0->rawHostPtr() + src0->offset(b, h, 0, 0) * src0_type_size / src0_blck_size,
+                                    ld_src0,
+                                    (char *)dst->rawHostPtr() + dst->offset(b, h, 0, 0) * type_size(dst->dtype()) / blck_size(dst->dtype()),
+                                    ld_dst,
+                                    id, thread_count,
+                                    src1->dtype(),
+                                    src0->dtype(),
+                                    dst->dtype());
+                }
+            }
+        }
+        return MLLM_NO_ERROR;
+    }
+
+    auto not_vec_dot_type = src0_dtype != vec_dot_type;
+    std::unique_ptr<Tensor> to; // later this tensor will be freed by ~Tensor
+    if(not_vec_dot_type){
+        // convert x.dtype to vec_dot_type
+        // so that we can use vec_dot to calculate dot product
+        ASSERT(src0_dtype == MLLM_TYPE_F32); // x should be fp32
+        to = std::make_unique<Tensor>(src0->shape());
+        to->setBackend(src0->backend());
+        to->setDtype(vec_dot_type);
+        to->alloc();
+        void *row_src = src0->rawHostPtr();
+        void *row_dst = to->rawHostPtr();
+        auto row_size_src = row_size(src0_dtype, src0->dimension());
+        auto row_size_dst = row_size(vec_dot_type, to->dimension());
+        auto n_row = src0->batch() * src0->head() * src0->sequence();
+        auto n_ele = src0->dimension();
+#pragma omp parallel for num_threads(thread_count)
+        for(int i = 0;i < n_row;i++){ // copy row by row
+            auto row1 = (char *)row_src + i * row_size_src;
+            auto row2 = (char *)row_dst + i * row_size_dst;
+            x_to_vec_dot_type(reinterpret_cast<float *>(row1), row2, n_ele);
+        }
+        src0 = to.get();
+        src0_dtype = src0->dtype();
+        src0_type_size = type_size(src0->dtype());
+        src0_blck_size = blck_size(src0->dtype());
+    }
+
+    Tensor *src0_cal = src0;
+    Tensor *src1_cal = src1;
+    const int64_t blck_0 = 16;
+    int is_0 = (src1->batch() == 1 && src1->head() == 1) ? 0 : 1;
+#pragma omp parallel for collapse(4) num_threads(thread_count)
+    for (int b = 0; b < src0->batch(); b++) {
+        for (int h = 0; h < src0->head(); h++) {
+            for (int m = 0; m < M; m++) {
+                for (int block = 0; block < use_N / blck_0 + 1; block++) {
+                    for (int n = block * blck_0; n < (block + 1) * blck_0 & n < use_N; n++) {
+                        int s_1, d_1;
+                        int s_0, d_0;
+                        if (!transpose0 && transpose1) {
+                            s_1 = n; d_1 = 0; s_0 = m; d_0 = 0;
+                        } else if (!transpose0 && !transpose1) {
+                            s_1 = 0; d_1 = n; s_0 = m; d_0 = 0;
+                        } else {
+                            s_1 = 0; d_1 = n; s_0 = 0; d_0 = m;
+                        }
+                        float tmp = 0;
+                        vec_dot(use_K, &tmp,
+                                (char *)src1_cal->rawHostPtr() + src1_cal->offset(b*is_0, h*is_0, s_1, d_1) * src1_type_size / src1_blck_size,
+                                (char *)src0_cal->rawHostPtr() + src0_cal->offset(b, h, s_0, d_0) * src0_type_size / src0_blck_size);
+                        if(dst->dtypeAt(b,h,m,n) == MLLM_TYPE_F32) {
+                            dst->setDataAt<float>(b, h, m, n, tmp);
+                            if (support_bias) {
+                                *dst->ptrAt<float>(b, h, m, n) += bias->dataAt<float>(0, 0, 0, n);
+                            }
+                        }else if(dst->dtypeAt(b,h,m,n) == MLLM_TYPE_F16) {
+                            if (support_bias) {
+                                *dst->ptrAt<mllm_fp16_t>(b, h, m, n) = MLLM_FP32_TO_FP16(tmp + bias->dataAt<float>(0, 0, 0, n));
+                            } else {
+                                *dst->ptrAt<mllm_fp16_t>(b, h, m, n) = MLLM_FP32_TO_FP16(tmp);
+                            }
+                        }else{std::cout<<"Not support type [Matmul]"< - -
-ErrorCode mat_mul_elastic_fp32(Tensor *src0, Tensor *src1, Tensor *dst, bool support_bias, Tensor *bias,
-                               int activate_input_dim, int activate_output_dim,
-                               bool transpose0, bool transpose1, int thread_count) {
-    const int M = transpose0 ? src0->dimension() : src0->sequence();
-    const int K = transpose0 ? src0->sequence() : src0->dimension();
-    const int N = transpose1 ? src1->sequence() : src1->dimension();
-    int use_N = (activate_output_dim == -1) ? N : activate_output_dim;
-    int use_K = (activate_input_dim == -1) ? K : activate_input_dim;
-    Tensor *src0_cal = src0;
-    Tensor *src1_cal = src1;
-    const int64_t blck_0 = 16;
-    for (int b = 0; b < src0->batch(); b++) {
-        for (int h = 0; h < src0->head(); h++) {
-            const int b_1 = (src1->batch() == 1 && src1->head() == 1) ? 0 : b;
-            const int h_1 = (src1->batch() == 1 && src1->head() == 1) ? 
0 : h; - for (int m = 0; m < M; m++) { - const int num_blocks = use_N / blck_0; - const int remainder = use_N % blck_0; -#pragma omp parallel for num_threads(thread_count) - for (int block = 0; block < num_blocks + 1; block++) { - for (int n = block * blck_0; n < (block + 1) * blck_0 & n < num_blocks * blck_0 + remainder; n++) { - int s_1, d_1; - int s_0, d_0; - if (!transpose0 && transpose1) { - s_1 = n; d_1 = 0; s_0 = m; d_0 = 0; - } else if (!transpose0 && !transpose1) { - s_1 = 0; d_1 = n; s_0 = m; d_0 = 0; - } else { - s_1 = 0; d_1 = n; s_0 = 0; d_0 = m; - } - if(dst->dtypeAt(b,h,m,n) == MLLM_TYPE_F32) { - vec_dot_fp32(use_K, dst->ptrAt(b, h, m, n), - src1_cal->hostPtr() + src1_cal->offset(b_1, h_1, s_1, d_1), - src0_cal->hostPtr() + src0_cal->offset(b, h, s_0, d_0)); - if (support_bias) { - *dst->ptrAt(b, h, m, n) += bias->dataAt(0, 0, 0, n); - } - }else if (dst->dtypeAt(b,h,m,n) == MLLM_TYPE_F16) { - float tmp = 0; - vec_dot_fp32(use_K, &tmp, - src1_cal->hostPtr() + src1_cal->offset(b_1, h_1, s_1, d_1), - src0_cal->hostPtr() + src0_cal->offset(b, h, s_0, d_0)); - if (support_bias) { - *dst->ptrAt(b, h, m, n) = MLLM_FP32_TO_FP16(tmp + bias->dataAt(0, 0, 0, n)); - } else { - *dst->ptrAt(b, h, m, n) = MLLM_FP32_TO_FP16(tmp); - } - }else{std::cout<<"Not support type [Matmul]"<dtype() == MLLM_TYPE_F16); - assert(src0_->dtype() == MLLM_TYPE_F32); - Tensor src0_qf16(src0_->shape()); - src0_qf16.setBackend(src0_->backend()); - src0_qf16.setDtype(MLLM_TYPE_F16); - src0_qf16.alloc(); - for (int b = 0; b < src0_->batch(); b++) { - for (int h = 0; h < src0_->head(); h++) { -#pragma omp parallel for num_threads(thread_count) - for (int s = 0; s < src0_->sequence(); s++) { - mllm_fp32_to_fp16_row(src0_->hostPtr() + src0_->offset(b, h, s, 0), - src0_qf16.hostPtr() + src0_qf16.offset(b, h, s, 0), - src0_->dimension()); - } - } - } - auto *src0 = &src0_qf16; - const int M = transpose0 ? src0->dimension() : src0->sequence(); - const int K = transpose0 ? src0->sequence() : src0->dimension(); - const int N = transpose1 ? src1->sequence() : src1->dimension(); - int use_N = (activate_output_dim == -1) ? N : activate_output_dim; - int use_K = (activate_input_dim == -1) ? K : activate_input_dim; - Tensor *src0_cal = src0; - Tensor *src1_cal = src1; - const int64_t blck_0 = 16; - for (int b = 0; b < src0->batch(); b++) { - for (int h = 0; h < src0->head(); h++) { - const int b_1 = (src1->batch() == 1 && src1->head() == 1) ? 0 : b; - const int h_1 = (src1->batch() == 1 && src1->head() == 1) ? 
0 : h; - for (int m = 0; m < M; m++) { - const int num_blocks = use_N / blck_0; - const int remainder = use_N % blck_0; -#pragma omp parallel for num_threads(thread_count) - for (int block = 0; block < num_blocks + 1; block++) { - for (int n = block * blck_0; n < (block + 1) * blck_0 & n < num_blocks * blck_0 + remainder; n++) { - int s_1, d_1; - int s_0, d_0; - if (!transpose0 && transpose1) { - s_1 = n; d_1 = 0; s_0 = m; d_0 = 0; - } else if (!transpose0 && !transpose1) { - s_1 = 0; d_1 = n; s_0 = m; d_0 = 0; - } else { - s_1 = 0; d_1 = n; s_0 = 0; d_0 = m; - } - vec_dot_fp16(use_K, dst->ptrAt(b, h, m, n), - src1_cal->hostPtr() + src1_cal->offset(b_1, h_1, s_1, d_1), - src0_cal->hostPtr() + src0_cal->offset(b, h, s_0, d_0)); - if (support_bias) { - *dst->ptrAt(b, h, m, n) += bias->dataAt(0, 0, 0, n); - } - } - } - } - } - } - return MLLM_NO_ERROR; -} - -ErrorCode mat_mul_elastic_fp32_q4_0(Tensor *src0_, Tensor *src1, Tensor *dst, bool support_bias, Tensor *bias, - int activate_input_dim, int activate_output_dim, int thread_count) { - assert(src1->dtype() == MLLM_TYPE_Q4_0); - assert(src0_->dtype() == MLLM_TYPE_F32); - Tensor src0_q8(src0_->shape()); - src0_q8.setBackend(src0_->backend()); - src0_q8.setDtype(MLLM_TYPE_Q8_0); - src0_q8.alloc(); - if (src0_->dimension() % QK8_0 == 0) { - for (int b = 0; b < src0_->batch(); b++) { - for (int h = 0; h < src0_->head(); h++) { -#pragma omp parallel for num_threads(thread_count) - for (int s = 0; s < src0_->sequence(); s++) { - quantize_row_q8_0(src0_->hostPtr() + src0_->offset(b, h, s, 0), - src0_q8.hostPtr() + src0_q8.offset(b, h, s, 0) / QK8_0, - src0_->dimension()); - } - } - } - } else { - std::cout << "[ERROR]: " << src0_->dimension() << "%" << QK8_0 << "!=0" << std::endl; - assert(src0_->dimension() % QK8_0 == 0); - } - auto *src0 = &src0_q8; - assert(src0->dtype() == MLLM_TYPE_Q8_0); - int M = src0->sequence(); - int K = src0->dimension(); - int N = src1->sequence(); - int use_N = (activate_output_dim == -1) ? N : activate_output_dim; - int use_K = (activate_input_dim == -1) ? K : activate_input_dim; - Tensor *src0_cal = src0; - Tensor *src1_cal = src1; - const int64_t blck_0 = 16; - for (int b = 0; b < src0->batch(); b++) { - for (int h = 0; h < src0->head(); h++) { - const int b_1 = (src1->batch() == 1 && src1->head() == 1) ? 0 : b; - const int h_1 = (src1->batch() == 1 && src1->head() == 1) ? 
0 : h; - for (int m = 0; m < M; m++) { - const int num_blocks = use_N / blck_0; - const int remainder = use_N % blck_0; -#pragma omp parallel for num_threads(thread_count) - for (int block = 0; block < num_blocks + 1; block++) { - for (int n = block * blck_0; n < (block + 1) * blck_0 & n < num_blocks * blck_0 + remainder; n++) { - vec_dot_q4_0_q8_0(use_K, dst->ptrAt(b, h, m, n), - src1_cal->hostPtr() + src1_cal->offset(b_1, h_1, n, 0) / QK4_0, - src0_cal->hostPtr() + src0_cal->offset(b, h, m, 0) / QK8_0); - if (support_bias) { - *dst->ptrAt(b, h, m, n) += bias->dataAt(0, 0, 0, n); - } - } - } - } - } - } - return MLLM_NO_ERROR; -} - -ErrorCode mat_mul_elastic_fp32_q4_K(Tensor *src0_, Tensor *src1, Tensor *dst, bool support_bias, Tensor *bias, - int activate_input_dim, int activate_output_dim, int thread_count) { - assert(src1->dtype() == MLLM_TYPE_Q4_K); - assert(src0_->dtype() == MLLM_TYPE_F32); - Tensor src0_q8(src0_->shape()); - src0_q8.setBackend(src0_->backend()); - src0_q8.setDtype(MLLM_TYPE_Q8_K); - src0_q8.alloc(); - if (src0_->dimension() % QK_K == 0) { - for (int b = 0; b < src0_->batch(); b++) { - for (int h = 0; h < src0_->head(); h++) { -#pragma omp parallel for num_threads(thread_count) - for (int s = 0; s < src0_->sequence(); s++) { - quantize_row_q8_K(src0_->hostPtr() + src0_->offset(b, h, s, 0), - src0_q8.hostPtr() + src0_q8.offset(b, h, s, 0) / QK_K, - src0_->dimension()); - } - } - } - } else { - std::cout << "[ERROR]: " << src0_->dimension() << "%" << QK_K << "!=0" << std::endl; - assert(src0_->dimension() % QK_K == 0); - } - auto *src0 = &src0_q8; - assert(src0->dtype() == MLLM_TYPE_Q8_K); - int M = src0->sequence(); - int K = src0->dimension(); - int N = src1->sequence(); - int use_N = (activate_output_dim == -1) ? N : activate_output_dim; - int use_K = (activate_input_dim == -1) ? K : activate_input_dim; - Tensor *src0_cal = src0; - Tensor *src1_cal = src1; - const int64_t blck_0 = 16; - - for (int b = 0; b < src0->batch(); b++) { - for (int h = 0; h < src0->head(); h++) { - const int b_1 = (src1->batch() == 1 && src1->head() == 1) ? 0 : b; - const int h_1 = (src1->batch() == 1 && src1->head() == 1) ? 
0 : h; - for (int m = 0; m < M; m++) { - const int num_blocks = use_N / blck_0; - const int remainder = use_N % blck_0; -#pragma omp parallel for num_threads(thread_count) - for (int block = 0; block < num_blocks + 1; block++) { - for (int n = block * blck_0; n < (block + 1) * blck_0 & n < num_blocks * blck_0 + remainder; n++) { - if(dst->dtypeAt(b,h,m,n) == MLLM_TYPE_F32) { - vec_dot_q4_K_q8_K(use_K, dst->ptrAt(b, h, m, n), - src1_cal->hostPtr() + src1_cal->offset(b_1, h_1, n, 0) / QK_K, - src0_cal->hostPtr() + src0_cal->offset(b, h, m, 0) / QK_K); - if (support_bias) { - *dst->ptrAt(b, h, m, n) += bias->dataAt(0, 0, 0, n); - } - } else if (dst->dtypeAt(b,h,m,n) == MLLM_TYPE_F16) { - float tmp = 0; - vec_dot_q4_K_q8_K(use_K, &tmp, - src1_cal->hostPtr() + src1_cal->offset(b_1, h_1, n, 0) / QK_K, - src0_cal->hostPtr() + src0_cal->offset(b, h, m, 0) / QK_K); - if (support_bias) { - *dst->ptrAt(b, h, m, n) = MLLM_FP32_TO_FP16(tmp + bias->dataAt(0, 0, 0, n)); - } else { - *dst->ptrAt(b, h, m, n) = MLLM_FP32_TO_FP16(tmp); - } - }else{std::cout<<"Not support type [Matmul]"<dtype() == MLLM_TYPE_Q6_K); - assert(src0_->dtype() == MLLM_TYPE_F32); - Tensor src0_q8(src0_->shape()); - src0_q8.setBackend(src0_->backend()); - src0_q8.setDtype(MLLM_TYPE_Q8_K); - src0_q8.alloc(); - if (src0_->dimension() % QK_K == 0) { - for (int b = 0; b < src0_->batch(); b++) { - for (int h = 0; h < src0_->head(); h++) { -#pragma omp parallel for num_threads(thread_count) - for (int s = 0; s < src0_->sequence(); s++) { - quantize_row_q8_K(src0_->hostPtr() + src0_->offset(b, h, s, 0), - src0_q8.hostPtr() + src0_q8.offset(b, h, s, 0) / QK_K, - src0_->dimension()); - } - } - } - } else { - std::cout << "[ERROR]: " << src0_->dimension() << "%" << QK_K << "!=0" << std::endl; - assert(src0_->dimension() % QK_K == 0); - } - auto *src0 = &src0_q8; - assert(src0->dtype() == MLLM_TYPE_Q8_K); - int M = src0->sequence(); - int K = src0->dimension(); - int N = src1->sequence(); - int use_N = (activate_output_dim == -1) ? N : activate_output_dim; - int use_K = (activate_input_dim == -1) ? K : activate_input_dim; - Tensor *src0_cal = src0; - Tensor *src1_cal = src1; - const int64_t blck_0 = 16; - for (int b = 0; b < src0->batch(); b++) { - for (int h = 0; h < src0->head(); h++) { - const int b_1 = (src1->batch() == 1 && src1->head() == 1) ? 0 : b; - const int h_1 = (src1->batch() == 1 && src1->head() == 1) ? 
0 : h; - for (int m = 0; m < M; m++) { - const int num_blocks = use_N / blck_0; - const int remainder = use_N % blck_0; -#pragma omp parallel for num_threads(thread_count) - for (int block = 0; block < num_blocks + 1; block++) { - for (int n = block * blck_0; n < (block + 1) * blck_0 & n < num_blocks * blck_0 + remainder; n++) { - if (dst->dtypeAt(n, h, m, n) == MLLM_TYPE_F32) { - vec_dot_q6_K_q8_K(use_K, dst->ptrAt(b, h, m, n), - src1_cal->hostPtr() + src1_cal->offset(b_1, h_1, n, 0) / QK_K, - src0_cal->hostPtr() + src0_cal->offset(b, h, m, 0) / QK_K); - if (support_bias) { - *dst->ptrAt(b, h, m, n) += bias->dataAt(0, 0, 0, n); - } - } else if (dst->dtypeAt(n, h, m, n) == MLLM_TYPE_F16) { - float tmp = 0; - vec_dot_q6_K_q8_K(use_K, &tmp, - src1_cal->hostPtr() + src1_cal->offset(b_1, h_1, n, 0) / QK_K, - src0_cal->hostPtr() + src0_cal->offset(b, h, m, 0) / QK_K); - - if (support_bias) { - *dst->ptrAt(b, h, m, n) = MLLM_FP32_TO_FP16(tmp + bias->dataAt(0, 0, 0, n)); - } else { - *dst->ptrAt(b, h, m, n) = MLLM_FP32_TO_FP16(tmp); - } - } else { - std::cout << "Not support tupe [Matmul]" << std::endl; - } - } - } - } - } - } - return MLLM_NO_ERROR; -} - diff --git a/src/backends/cpu/compute/MatmulElastic.hpp b/src/backends/cpu/compute/MatmulElastic.hpp deleted file mode 100644 index cd4bef9d..00000000 --- a/src/backends/cpu/compute/MatmulElastic.hpp +++ /dev/null @@ -1,18 +0,0 @@ -// -// Created by Rongjie Yi on 23-10-24. -// - -#ifndef MLLM_MATMUL_HPP -#define MLLM_MATMUL_HPP - - -#include "VecDot.hpp" -using namespace mllm; - -ErrorCode mat_mul_elastic_fp32(Tensor *src0, Tensor *src1, Tensor *dst, bool support_bias, Tensor *bias = nullptr, int activate_input_dim=-1, int activate_output_dim=-1, bool transpose0 = false, bool transpose1 = false, int thread_count=4); -ErrorCode mat_mul_elastic_fp32_fp16(Tensor *src0_, Tensor *src1, Tensor *dst, bool support_bias, Tensor *bias = nullptr, int activate_input_dim=-1, int activate_output_dim=-1,bool transpose0 = false, bool transpose1 = false, int thread_count=4); -ErrorCode mat_mul_elastic_fp32_q4_0(Tensor *src0_, Tensor *src1, Tensor *dst, bool support_bias, Tensor *bias = nullptr, int activate_input_dim=-1, int activate_output_dim=-1,int thread_count=4); -ErrorCode mat_mul_elastic_fp32_q4_K(Tensor *src0_, Tensor *src1, Tensor *dst, bool support_bias, Tensor *bias = nullptr, int activate_input_dim=-1, int activate_output_dim=-1,int thread_count=4); -ErrorCode mat_mul_elastic_fp32_q6_K(Tensor *src0_, Tensor *src1, Tensor *dst, bool support_bias, Tensor *bias = nullptr, int activate_input_dim=-1, int activate_output_dim=-1,int thread_count=4); - -#endif // MLLM_MATMUL_HPP diff --git a/src/backends/cpu/compute/SGEMM.cpp b/src/backends/cpu/compute/SGEMM.cpp new file mode 100644 index 00000000..238d9621 --- /dev/null +++ b/src/backends/cpu/compute/SGEMM.cpp @@ -0,0 +1,1127 @@ +// Copyright 2024 Mozilla Foundation +// +// Permission is hereby granted, free of charge, to any person obtaining +// a copy of this software and associated documentation files (the +// "Software"), to deal in the Software without restriction, including +// without limitation the rights to use, copy, modify, merge, publish, +// distribute, sublicense, and/or sell copies of the Software, and to +// permit persons to whom the Software is furnished to do so, subject to +// the following conditions: +// +// The above copyright notice and this permission notice shall be +// included in all copies or substantial portions of the Software. 
+// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS +// BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN +// ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +// +// _ _ ___ _ _ ___ +// | |_(_)_ _ _ _| _ ) | /_\ / __| +// | _| | ' \ || | _ \ |__ / _ \\__ \. +// \__|_|_||_\_, |___/____/_/ \_\___/ +// |__/ +// +// BASIC LINEAR ALGEBRA SUBPROGRAMS +// +// +// This file implements multithreaded CPU matrix multiplication for the +// common contiguous use case C = Aᵀ * B. These kernels are designed to +// have excellent performance[1] for matrices that fit in the CPU cache +// without imposing any overhead such as cache filling or malloc calls. +// +// This implementation does not guarantee any upper bound with rounding +// errors, which grow along with k. Our goal's to maximally exploit the +// hardware for performance, and then use whatever resources remain for +// improving numerical accuracy. +// +// [1] J. Tunney, ‘LLaMA Now Goes Faster on CPUs’, Mar. 2024. [Online]. +// Available: https://justine.lol/matmul/. [Accessed: 29-Mar-2024]. + +#if defined(__GNUC__) +#pragma GCC diagnostic ignored "-Wpedantic" +#pragma GCC diagnostic ignored "-Wignored-attributes" +#endif + + +#include "SGEMM.hpp" + +#ifdef _MSC_VER +#define NOINLINE __declspec(noinline) +#else +#define NOINLINE __attribute__((__noinline__)) +#endif + +#if defined(__ARM_NEON) || defined(__AVX512F__) +#define VECTOR_REGISTERS 32 +#else +#define VECTOR_REGISTERS 16 +#endif + +#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1) + +namespace { + +inline float unhalf(mllm_fp16_t d) { + return MLLM_FP16_TO_FP32(d); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// VECTORIZED ARITHMETIC OPERATIONS + +#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) +inline __m128 add(__m128 x, __m128 y) { return _mm_add_ps(x, y); } +inline __m128 sub(__m128 x, __m128 y) { return _mm_sub_ps(x, y); } +inline __m128 mul(__m128 x, __m128 y) { return _mm_mul_ps(x, y); } +#endif // __SSE__ + +#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) +inline __m256 add(__m256 x, __m256 y) { return _mm256_add_ps(x, y); } +inline __m256 sub(__m256 x, __m256 y) { return _mm256_sub_ps(x, y); } +inline __m256 mul(__m256 x, __m256 y) { return _mm256_mul_ps(x, y); } +#endif // __AVX__ + +#if defined(__AVX512F__) +inline __m512 add(__m512 x, __m512 y) { return _mm512_add_ps(x, y); } +inline __m512 sub(__m512 x, __m512 y) { return _mm512_sub_ps(x, y); } +inline __m512 mul(__m512 x, __m512 y) { return _mm512_mul_ps(x, y); } +#endif // __AVX512F__ + +#if defined(__ARM_NEON) +inline float32x4_t add(float32x4_t x, float32x4_t y) { return vaddq_f32(x, y); } +inline float32x4_t sub(float32x4_t x, float32x4_t y) { return vsubq_f32(x, y); } +inline float32x4_t mul(float32x4_t x, float32x4_t y) { return vmulq_f32(x, y); } +#endif // __ARM_NEON + +#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) +inline float16x8_t add(float16x8_t x, float16x8_t y) { return vaddq_f16(x, y); } +inline float16x8_t sub(float16x8_t x, float16x8_t y) { return vsubq_f16(x, y); } +inline float16x8_t mul(float16x8_t x, 
float16x8_t y) { return vmulq_f16(x, y); } +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// VECTORIZED FUSED MULTIPLY ADD + +/** + * Computes a * b + c. + */ +template +inline U madd(T a, T b, U c) { + return add(mul(a, b), c); +} + +#if defined(__FMA__) +#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) +template <> +inline __m256 madd(__m256 a, __m256 b, __m256 c) { + return _mm256_fmadd_ps(a, b, c); +} +#endif +#if defined(__AVX512F__) +template <> +inline __m512 madd(__m512 a, __m512 b, __m512 c) { + return _mm512_fmadd_ps(a, b, c); +} +#endif +#endif + +#if defined(__ARM_FEATURE_FMA) +template <> +inline float32x4_t madd(float32x4_t a, float32x4_t b, float32x4_t c) { + return vfmaq_f32(c, b, a); +} +#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER) +template <> +inline float16x8_t madd(float16x8_t a, float16x8_t b, float16x8_t c) { + return vfmaq_f16(c, b, a); +} +#endif +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// VECTORIZED HORIZONTAL SUM + +#if defined(__ARM_NEON) +inline float hsum(float32x4_t x) { + return vaddvq_f32(x); +} +#endif // __ARM_NEON + +#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER) +inline float hsum(float16x8_t x) { + return vaddvq_f32(vaddq_f32(vcvt_f32_f16(vget_low_f16(x)), + vcvt_f32_f16(vget_high_f16(x)))); +} +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + +#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) +inline float hsum(__m128 x) { +#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) + x = _mm_add_ps(x, _mm_movehl_ps(x, x)); + x = _mm_add_ss(x, _mm_movehdup_ps(x)); +#else + __m128 t; + t = _mm_shuffle_ps(x, x, _MM_SHUFFLE(2, 3, 0, 1)); + x = _mm_add_ps(x, t); + t = _mm_movehl_ps(t, x); + x = _mm_add_ss(x, t); +#endif + return _mm_cvtss_f32(x); +} +#endif + +#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) +inline float hsum(__m256 x) { + return hsum(_mm_add_ps(_mm256_extractf128_ps(x, 1), + _mm256_castps256_ps128(x))); +} +#endif // __AVX__ + +#if defined(__AVX512F__) +inline float hsum(__m512 x) { + return _mm512_reduce_add_ps(x); +} +#endif // __AVX512F__ + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// VECTORIZED MEMORY LOADING + +template T load(const U *); + +#if defined(__ARM_NEON) +template <> inline float32x4_t load(const float *p) { + return vld1q_f32(p); +} +#if !defined(_MSC_VER) +template <> inline float16x8_t load(const mllm_fp16_t *p) { + return vld1q_f16((const float16_t *)p); +} +template <> inline float32x4_t load(const mllm_fp16_t *p) { + return vcvt_f32_f16(vld1_f16((const float16_t *)p)); +} +#endif // _MSC_VER +#endif // __ARM_NEON + +#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) +template <> inline __m128 load(const float *p) { + return _mm_loadu_ps(p); +} +#endif // __SSE__ + +#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) +template <> inline __m256 load(const float *p) { + return _mm256_loadu_ps(p); +} +#endif // __AVX__ + +#if defined(__F16C__) +template <> inline __m256 load(const mllm_fp16_t *p) { + return _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)p)); +} +#endif // __F16C__ + +#if defined(__AVX512F__) +template <> inline __m512 load(const float *p) { + return _mm512_loadu_ps(p); +} +template <> inline 
__m512 load(const mllm_fp16_t *p) { + return _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)p)); +} +#endif // __AVX512F__ + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// FLOATING POINT MATRIX MULTIPLICATION + +template +class tinyBLAS { + public: + tinyBLAS(int64_t k, + const TA *A, int64_t lda, + const TB *B, int64_t ldb, + TC *C, int64_t ldc, + int ith, int nth) + : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) { + } + + void matmul(int64_t m, int64_t n) { + mnpack(0, m, 0, n); + } + + private: + NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) { + int64_t mc, nc, mp, np; + switch ((MIN(m - m0, 5) << 4) | MIN(n - n0, 5)) { +#if VECTOR_REGISTERS == 32 + case 0x55: + mc = 5; + nc = 5; + gemm<5, 5>(m0, m, n0, n); + break; + case 0x45: + mc = 4; + nc = 5; + gemm<4, 5>(m0, m, n0, n); + break; + case 0x54: + mc = 5; + nc = 4; + gemm<5, 4>(m0, m, n0, n); + break; + case 0x44: + mc = 4; + nc = 4; + gemm<4, 4>(m0, m, n0, n); + break; + case 0x53: + mc = 5; + nc = 3; + gemm<5, 3>(m0, m, n0, n); + break; + case 0x35: + mc = 3; + nc = 5; + gemm<3, 5>(m0, m, n0, n); + break; + case 0x43: + mc = 4; + nc = 3; + gemm<4, 3>(m0, m, n0, n); + break; +#else + case 0x55: + case 0x54: + case 0x53: + case 0x45: + case 0x44: + case 0x43: + mc = 4; + nc = 3; + gemm<4, 3>(m0, m, n0, n); + break; + case 0x35: +#endif + case 0x34: + mc = 3; + nc = 4; + gemm<3, 4>(m0, m, n0, n); + break; + case 0x52: + mc = 5; + nc = 2; + gemm<5, 2>(m0, m, n0, n); + break; + case 0x33: + mc = 3; + nc = 3; + gemm<3, 3>(m0, m, n0, n); + break; + case 0x25: + mc = 2; + nc = 5; + gemm<2, 5>(m0, m, n0, n); + break; + case 0x42: + mc = 4; + nc = 2; + gemm<4, 2>(m0, m, n0, n); + break; + case 0x24: + mc = 2; + nc = 4; + gemm<2, 4>(m0, m, n0, n); + break; + case 0x32: + mc = 3; + nc = 2; + gemm<3, 2>(m0, m, n0, n); + break; + case 0x23: + mc = 2; + nc = 3; + gemm<2, 3>(m0, m, n0, n); + break; + case 0x51: + mc = 5; + nc = 1; + gemm<5, 1>(m0, m, n0, n); + break; + case 0x41: + mc = 4; + nc = 1; + gemm<4, 1>(m0, m, n0, n); + break; + case 0x22: + mc = 2; + nc = 2; + gemm<2, 2>(m0, m, n0, n); + break; + case 0x15: + mc = 1; + nc = 5; + gemm<1, 5>(m0, m, n0, n); + break; + case 0x14: + mc = 1; + nc = 4; + gemm<1, 4>(m0, m, n0, n); + break; + case 0x31: + mc = 3; + nc = 1; + gemm<3, 1>(m0, m, n0, n); + break; + case 0x13: + mc = 1; + nc = 3; + gemm<1, 3>(m0, m, n0, n); + break; + case 0x21: + mc = 2; + nc = 1; + gemm<2, 1>(m0, m, n0, n); + break; + case 0x12: + mc = 1; + nc = 2; + gemm<1, 2>(m0, m, n0, n); + break; + case 0x11: + mc = 1; + nc = 1; + gemm<1, 1>(m0, m, n0, n); + break; + default: + return; + } + mp = m0 + (m - m0) / mc * mc; + np = n0 + (n - n0) / nc * nc; + mnpack(mp, m, n0, np); + mnpack(m0, m, np, n); + } + + template + NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) { + int64_t ytiles = (m - m0) / RM; + int64_t xtiles = (n - n0) / RN; + int64_t tiles = xtiles * ytiles; + int64_t duty = (tiles + nth - 1) / nth; + int64_t start = duty * ith; + int64_t end = start + duty; + if (end > tiles) + end = tiles; + for (int64_t job = start; job < end; ++job) { + int64_t ii = m0 + job / xtiles * RM; + int64_t jj = n0 + job % xtiles * RN; + D Cv[RN][RM] = {}; + for (int64_t l = 0; l < k; l += KN) + for (int64_t j = 0; j < RN; ++j) + for (int64_t i = 0; i < RM; ++i) + Cv[j][i] = madd(load(A + lda * (ii + i) + l), + load(B + ldb * (jj + j) + l), + Cv[j][i]); + for (int64_t j = 0; j < RN; ++j) + for (int64_t i = 0; i 
< RM; ++i) + C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]); + } + } + + const TA *const A; + const TB *const B; + TC *const C; + const int64_t k; + const int64_t lda; + const int64_t ldb; + const int64_t ldc; + const int ith; + const int nth; +}; + +////////////////////////////////////////////////////////////////////////////////////////// +// QUANT ZERO MATRIX MULTIPLICATION + +#if defined(__ARM_FEATURE_DOTPROD) +template +class tinyBLAS_Q0_ARM { + public: + tinyBLAS_Q0_ARM(int64_t k, + const TA *A, int64_t lda, + const block_q8_0 *B, int64_t ldb, + float *C, int64_t ldc, + int ith, int nth) + : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) { + } + + void matmul(int64_t m, int64_t n) { + mnpack(0, m, 0, n); + } + + private: + NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) { + int64_t mc, nc, mp, np; + switch ((MIN(m - m0, 3) << 4) | MIN(n - n0, 3ll)) { + case 0x33: + mc = 3; + nc = 3; + gemm<3, 3>(m0, m, n0, n); + break; + case 0x32: + mc = 3; + nc = 2; + gemm<3, 2>(m0, m, n0, n); + break; + case 0x23: + mc = 2; + nc = 3; + gemm<2, 3>(m0, m, n0, n); + break; + case 0x22: + mc = 2; + nc = 2; + gemm<2, 2>(m0, m, n0, n); + break; + case 0x31: + mc = 3; + nc = 1; + gemm<3, 1>(m0, m, n0, n); + break; + case 0x13: + mc = 1; + nc = 3; + gemm<1, 3>(m0, m, n0, n); + break; + case 0x21: + mc = 2; + nc = 1; + gemm<2, 1>(m0, m, n0, n); + break; + case 0x12: + mc = 1; + nc = 2; + gemm<1, 2>(m0, m, n0, n); + break; + case 0x11: + mc = 1; + nc = 1; + gemm<1, 1>(m0, m, n0, n); + break; + default: + return; + } + mp = m0 + (m - m0) / mc * mc; + np = n0 + (n - n0) / nc * nc; + mnpack(mp, m, n0, np); + mnpack(m0, m, np, n); + } + + template + NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) { + int64_t ytiles = (m - m0) / RM; + int64_t xtiles = (n - n0) / RN; + int64_t tiles = xtiles * ytiles; + int64_t duty = (tiles + nth - 1) / nth; + int64_t start = duty * ith; + int64_t end = start + duty; + if (end > tiles) + end = tiles; + for (int64_t job = start; job < end; ++job) { + int64_t ii = m0 + job / xtiles * RM; + int64_t jj = n0 + job % xtiles * RN; + float32x4_t Cv[RN][RM] = {}; + for (int64_t l = 0; l < k; ++l) + for (int64_t j = 0; j < RN; ++j) + for (int64_t i = 0; i < RM; ++i) + Cv[j][i] = vmlaq_n_f32(Cv[j][i], + vcvtq_f32_s32(vdotq_s32( + vdotq_s32(vdupq_n_s32(0), + load_lo(A + lda * (ii + i) + l), + load_lo(B + ldb * (jj + j) + l)), + load_hi(A + lda * (ii + i) + l), + load_hi(B + ldb * (jj + j) + l))), + unhalf(A[lda * (ii + i) + l].d) * + unhalf(B[ldb * (jj + j) + l].d)); + for (int64_t j = 0; j < RN; ++j) + for (int64_t i = 0; i < RM; ++i) + C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]); + } + } + + inline int8x16_t load_lo(const block_q8_0 *b) { + return vld1q_s8(b->qs); + } + + inline int8x16_t load_hi(const block_q8_0 *b) { + return vld1q_s8(b->qs + 16); + } + + inline int8x16_t load_lo(const block_q4_0 *b) { + return vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vld1q_u8(b->qs), + vdupq_n_u8(0x0f))), + vdupq_n_s8(0x8)); + } + + inline int8x16_t load_hi(const block_q4_0 *b) { + return vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(vld1q_u8(b->qs), 4)), + vdupq_n_s8(0x8)); + } + + const TA *const A; + const block_q8_0 *const B; + float *const C; + const int64_t k; + const int64_t lda; + const int64_t ldb; + const int64_t ldc; + const int ith; + const int nth; +}; +#endif // __ARM_FEATURE_DOTPROD + +#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__) +template +class tinyBLAS_Q0_AVX { + public: + tinyBLAS_Q0_AVX(int64_t k, + const TA 
*A, int64_t lda, + const TB *B, int64_t ldb, + TC *C, int64_t ldc, + int ith, int nth) + : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) { + } + + void matmul(int64_t m, int64_t n) { + mnpack(0, m, 0, n); + } + + private: + void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) { + int64_t mc, nc, mp, np; + switch ((MIN(m - m0, 4) << 4) | MIN(n - n0, 4)) { +#if VECTOR_REGISTERS == 32 + case 0x44: + mc = 4; + nc = 4; + gemm<4, 4>(m0, m, n0, n); + break; + case 0x43: + mc = 4; + nc = 3; + gemm<4, 3>(m0, m, n0, n); + break; + case 0x34: + mc = 3; + nc = 4; + gemm<3, 4>(m0, m, n0, n); + break; + case 0x33: + mc = 3; + nc = 3; + gemm<3, 3>(m0, m, n0, n); + break; + case 0x42: + mc = 4; + nc = 2; + gemm<4, 2>(m0, m, n0, n); + break; + case 0x24: + mc = 2; + nc = 4; + gemm<2, 4>(m0, m, n0, n); + break; +#else + case 0x44: + case 0x43: + case 0x42: + mc = 4; + nc = 2; + gemm<4, 2>(m0, m, n0, n); + break; + case 0x34: + case 0x24: + mc = 2; + nc = 4; + gemm<2, 4>(m0, m, n0, n); + break; + case 0x33: +#endif + case 0x32: + mc = 3; + nc = 2; + gemm<3, 2>(m0, m, n0, n); + break; + case 0x23: + mc = 2; + nc = 3; + gemm<2, 3>(m0, m, n0, n); + break; + case 0x41: + mc = 4; + nc = 1; + gemm<4, 1>(m0, m, n0, n); + break; + case 0x22: + mc = 2; + nc = 2; + gemm<2, 2>(m0, m, n0, n); + break; + case 0x14: + mc = 1; + nc = 4; + gemm<1, 4>(m0, m, n0, n); + break; + case 0x31: + mc = 3; + nc = 1; + gemm<3, 1>(m0, m, n0, n); + break; + case 0x13: + mc = 1; + nc = 3; + gemm<1, 3>(m0, m, n0, n); + break; + case 0x21: + mc = 2; + nc = 1; + gemm<2, 1>(m0, m, n0, n); + break; + case 0x12: + mc = 1; + nc = 2; + gemm<1, 2>(m0, m, n0, n); + break; + case 0x11: + mc = 1; + nc = 1; + gemm<1, 1>(m0, m, n0, n); + break; + default: + return; + } + mp = m0 + (m - m0) / mc * mc; + np = n0 + (n - n0) / nc * nc; + mnpack(mp, m, n0, np); + mnpack(m0, m, np, n); + } + + template + NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) { + int64_t ytiles = (m - m0) / RM; + int64_t xtiles = (n - n0) / RN; + int64_t tiles = xtiles * ytiles; + int64_t duty = (tiles + nth - 1) / nth; + int64_t start = duty * ith; + int64_t end = start + duty; + if (end > tiles) + end = tiles; + for (int64_t job = start; job < end; ++job) { + int64_t ii = m0 + job / xtiles * RM; + int64_t jj = n0 + job % xtiles * RN; + __m256 Cv[RN][RM] = {}; + for (int64_t l = 0; l < k; ++l) + for (int64_t j = 0; j < RN; ++j) + for (int64_t i = 0; i < RM; ++i) { +#if defined(__AVX2__) + __m256 udTmp = updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l), + load(A + lda * (ii + i) + l)), + _mm256_sign_epi8(load(B + ldb * (jj + j) + l), + load(A + lda * (ii + i) + l))); +#else + __m128i ali0 = load0(A + lda * (ii + i) + l); + __m128i ali1 = load1(A + lda * (ii + i) + l); + __m128i blj0 = load0(B + ldb * (jj + j) + l); + __m128i blj1 = load1(B + ldb * (jj + j) + l); + + __m128i sepAA0 = _mm_sign_epi8(ali0, ali0); + __m128i sepAA1 = _mm_sign_epi8(ali1, ali1); + __m128i sepBA0 = _mm_sign_epi8(blj0, ali0); + __m128i sepBA1 = _mm_sign_epi8(blj1, ali1); + + // updot + const __m128i oneFill = _mm_set1_epi16(1); + __m128i mad0 = _mm_maddubs_epi16(sepAA0, sepBA0); + __m128i mad1 = _mm_maddubs_epi16(sepAA1, sepBA1); + __m256 udTmp = _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_madd_epi16(oneFill, mad1), _mm_madd_epi16(oneFill, mad0))); +#endif + Cv[j][i] = madd(_mm256_set1_ps(unhalf(A[lda * (ii + i) + l].d) * + unhalf(B[ldb * (jj + j) + l].d)), + udTmp, + Cv[j][i]); + } + for (int64_t j = 0; j < RN; ++j) + for (int64_t i = 0; i < RM; ++i) + C[ldc * 
(jj + j) + (ii + i)] = hsum(Cv[j][i]); + } + } + + inline __m256i load(const block_q8_0 *b) { + return _mm256_loadu_si256((const __m256i *)b->qs); + } + + inline __m128i load0(const block_q8_0 *b) { + return _mm_loadu_si128((const __m128i *)b->qs); + } + + inline __m128i load1(const block_q8_0 *b) { + return _mm_loadu_si128(((const __m128i *)b->qs) + 1); + } + + inline __m256i load(const block_q4_0 *b) { + return _mm256_sub_epi8(denibble(b->qs), _mm256_set1_epi8(8)); + } + + inline __m128i load0(const block_q4_0 *b) { + const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs)); + return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), x), _mm_set1_epi8(8)); + } + + inline __m128i load1(const block_q4_0 *b) { + const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs)); + return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)), _mm_set1_epi8(8)); + } + + inline __m256 updot(__m256i u, __m256i s) { + __m256i res; +#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__)) + res = _mm256_dpbusd_epi32(_mm256_setzero_si256(), u, s); +#else + res = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(u, s)); +#endif + return _mm256_cvtepi32_ps(res); + } + + static inline __m256i denibble(const uint8_t *p) { + __m128i x = _mm_loadu_si128((const __m128i *)p); + return _mm256_and_si256(_mm256_set1_epi8(15), + _mm256_insertf128_si256(_mm256_castsi128_si256(x), + _mm_srli_epi16(x, 4), 1)); + } + + const TA *const A; + const TB *const B; + TC *const C; + const int64_t k; + const int64_t lda; + const int64_t ldb; + const int64_t ldc; + const int ith; + const int nth; +}; +#endif // __AVX__ + +} // namespace + +/** + * Performs optimized matrix multiplication on CPU. + * + * This subroutine may compute C = Aᵀ * B with column major ordering. + * Despite its name, this isn't a generalized implementation. Work is + * only performed when a handwritten kernel is written and available. + * Otherwise the caller should fall back to a general matmul routine. 
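When llamafile_sgemm declines a request it does no work at all, so the caller is expected to keep a generic matmul around as the fallback the comment above describes. As a point of reference for the F32 case, the result the kernels produce is C = Aᵀ·B with C stored column-major: a minimal scalar equivalent, matching the indexing used by tinyBLAS (row strides lda/ldb for the inputs, leading dimension ldc for the output), is sketched below. It is illustrative only and not part of the patch.

    #include <cstdint>

    // Scalar reference for the F32 path: C = A^T * B, with C column-major.
    static void sgemm_ref_f32(int64_t m, int64_t n, int64_t k,
                              const float *A, int64_t lda,
                              const float *B, int64_t ldb,
                              float *C, int64_t ldc) {
        for (int64_t j = 0; j < n; ++j)
            for (int64_t i = 0; i < m; ++i) {
                float sum = 0.0f;
                for (int64_t l = 0; l < k; ++l)
                    sum += A[lda * i + l] * B[ldb * j + l];
                C[ldc * j + i] = sum;
            }
    }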
+ * + * For example, for single-threaded single-precision GEMM you can say + * + * llamafile_sgemm(m, n, k, A, lda, B, ldb, C, ldc, + * 0, 1, + * MLLM_TYPE_F32, MLLM_TYPE_F32, MLLM_TYPE_F32); + * + * @param m is rows in `A` and `C` + * @param n is cols in `B` and `C` + * @param k is cols in `A` and rows in `B` + * @param A is first input matrix (always transposed) + * @param lda is row stride of `A` + * @param B is second input matrix (never transposed) + * @param ldb is row stride of `B` + * @param C is input/output array of output matrices + * @param ldc is row stride of `C` + * @param ith is thread id (must be less than `nth`) + * @param nth is number of threads (must be greater than zero) + * @param Atype is GGML data type of `A` + * @param Btype is GGML data type of `B` + * @param Ctype is GGML data type of `C` + * @return true if this function was able to service the matmul request + */ + //TODOYRJ +bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda, const void *B, int64_t ldb, void *C,int64_t ldc, + int ith, int nth, + DataType Atype, DataType Btype, DataType Ctype) { + assert(m >= 0); + assert(n >= 0); + assert(k >= 0); + assert(lda >= k); + assert(ldb >= k); + assert(ldc >= m); + assert(nth > 0); + assert(ith < nth); + + if (Ctype != MLLM_TYPE_F32) + return false; + + switch (Atype) { + + case MLLM_TYPE_F32: { + if (Btype != MLLM_TYPE_F32) + return false; +#if defined(__AVX512F__) + if (k % 16) + return false; + tinyBLAS<16, __m512, __m512, float, float, float> tb{ + k, (const float *)A, lda, + (const float *)B, ldb, + (float *)C, ldc, + ith, nth}; + tb.matmul(m, n); + return true; +#elif defined(__AVX__) || defined(__AVX2__) + if (k % 8) + return false; + tinyBLAS<8, __m256, __m256, float, float, float> tb{ + k, (const float *)A, lda, + (const float *)B, ldb, + (float *)C, ldc, + ith, nth}; + tb.matmul(m, n); + return true; +#elif defined(__ARM_NEON) + if (n < 4) + return false; + if (k % 4) + return false; + tinyBLAS<4, float32x4_t, float32x4_t, float, float, float> tb{ + k, (const float *)A, lda, + (const float *)B, ldb, + (float *)C, ldc, + ith, nth}; + tb.matmul(m, n); + return true; +#else + return false; +#endif + } + + case MLLM_TYPE_F16: { +#if defined(__AVX512F__) + if (k % 16) + return false; + if (Btype != MLLM_TYPE_F32) + return false; + tinyBLAS<16, __m512, __m512, mllm_fp16_t, float, float> tb{ + k, (const mllm_fp16_t *)A, lda, + (const float *)B, ldb, + (float *)C, ldc, + ith, nth}; + tb.matmul(m, n); + return true; +#elif (defined(__AVX__) || defined(__AVX2__)) && defined(__F16C__) + if (k % 8) + return false; + if (Btype != MLLM_TYPE_F32) + return false; + tinyBLAS<8, __m256, __m256, mllm_fp16_t, float, float> tb{ + k, (const mllm_fp16_t *)A, lda, + (const float *)B, ldb, + (float *)C, ldc, + ith, nth}; + tb.matmul(m, n); + return true; +#elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER) + if (n < 8) + return false; + if (k % 8) + return false; + if (Btype != MLLM_TYPE_F16) + return false; + tinyBLAS<8, float16x8_t, float16x8_t, mllm_fp16_t, mllm_fp16_t, float> tb{ + k, (const mllm_fp16_t *)A, lda, + (const mllm_fp16_t *)B, ldb, + (float *)C, ldc, + ith, nth}; + tb.matmul(m, n); + return true; +#elif defined(__ARM_NEON) && !defined(_MSC_VER) + if (k % 4) + return false; + if (Btype != MLLM_TYPE_F32) + return false; + tinyBLAS<4, float32x4_t, float32x4_t, mllm_fp16_t, float, float> tb{ + k, (const mllm_fp16_t *)A, lda, + (const float *)B, ldb, + (float *)C, ldc, + ith, nth}; + tb.matmul(m, n); + return true; 
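The ith/nth pair in the signature is the whole threading contract: every participating thread calls llamafile_sgemm with identical arguments except its own ith, and the gemm<RM, RN> loop hands each thread a contiguous range of duty = (tiles + nth - 1) / nth output tiles, so writes to C never overlap. A hedged sketch of driving it from std::thread follows; the wrapper itself is not in the patch, and the include path is an assumption (check_llamafile_sgemm is declared further down in SGEMM.hpp).

    #include <cstdint>
    #include <thread>
    #include <vector>
    #include "backends/cpu/compute/SGEMM.hpp"   // include path is an assumption

    // Illustrative only: run the same F32 sgemm on nth threads, one ith per thread.
    static bool sgemm_parallel_f32(int64_t m, int64_t n, int64_t k,
                                   const float *A, int64_t lda,
                                   const float *B, int64_t ldb,
                                   float *C, int64_t ldc, int nth) {
        if (!check_llamafile_sgemm(m, n, k, MLLM_TYPE_F32, MLLM_TYPE_F32, MLLM_TYPE_F32))
            return false;                        // caller falls back to its generic matmul
        std::vector<std::thread> pool;
        for (int ith = 0; ith < nth; ++ith)
            pool.emplace_back([=] {
                llamafile_sgemm(m, n, k, A, lda, B, ldb, C, ldc,
                                ith, nth, MLLM_TYPE_F32, MLLM_TYPE_F32, MLLM_TYPE_F32);
            });
        for (auto &t : pool) t.join();
        return true;
    }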
+#else + return false; +#endif + } + + case MLLM_TYPE_Q8_0: { + if (Btype != MLLM_TYPE_Q8_0) + return false; +#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__) + tinyBLAS_Q0_AVX tb{ + k, (const block_q8_0 *)A, lda, + (const block_q8_0 *)B, ldb, + (float *)C, ldc, + ith, nth}; + tb.matmul(m, n); + return true; +#elif defined(__ARM_FEATURE_DOTPROD) + tinyBLAS_Q0_ARM tb{ + k, (const block_q8_0 *)A, lda, + (const block_q8_0 *)B, ldb, + (float *)C, ldc, + ith, nth}; + tb.matmul(m, n); + return true; +#else + return false; +#endif + } + + case MLLM_TYPE_Q4_0: { + if (Btype != MLLM_TYPE_Q8_0) + return false; +#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__) + tinyBLAS_Q0_AVX tb{ + k, (const block_q4_0 *)A, lda, + (const block_q8_0 *)B, ldb, + (float *)C, ldc, + ith, nth}; + tb.matmul(m, n); + return true; +#elif defined(__ARM_FEATURE_DOTPROD) + tinyBLAS_Q0_ARM tb{ + k, (const block_q4_0 *)A, lda, + (const block_q8_0 *)B, ldb, + (float *)C, ldc, + ith, nth}; + tb.matmul(m, n); + return true; +#else + return false; +#endif + } + + default: + return false; + } + + (void)m; + (void)n; + (void)k; + (void)A; + (void)lda; + (void)B; + (void)ldb; + (void)C; + (void)ldc; + (void)ith; + (void)nth; + (void)Atype; + (void)Btype; + (void)Ctype; +} + + +bool check_llamafile_sgemm(int64_t m, int64_t n, int64_t k, DataType Atype, DataType Btype, DataType Ctype) { + + int ith=0; + int nth =1; + assert(m >= 0); + assert(n >= 0); + assert(k >= 0); + assert(nth > 0); + assert(ith < nth); + + if (Ctype != MLLM_TYPE_F32) + return false; + + switch (Atype) { + + case MLLM_TYPE_F32: { + // return false; //TODO CHECK THIS CALUATE + if (Btype != MLLM_TYPE_F32) + return false; +#if defined(__AVX512F__) + if (k % 16) + return false; + return true; +#elif defined(__AVX__) || defined(__AVX2__) + if (k % 8) + return false; + return true; +#elif defined(__ARM_NEON) + if (n < 4) + return false; + if (k % 4) + return false; + return true; +#else + return false; +#endif + } + + case MLLM_TYPE_F16: { +#if defined(__AVX512F__) + if (k % 16) + return false; + if (Btype != MLLM_TYPE_F32) + return false; + return true; +#elif (defined(__AVX__) || defined(__AVX2__)) && defined(__F16C__) + if (k % 8) + return false; + if (Btype != MLLM_TYPE_F32) + return false; + return true; +#elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER) + if (n < 8) + return false; + if (k % 8) + return false; + if (Btype != MLLM_TYPE_F16) + return false; + return true; +#elif defined(__ARM_NEON) && !defined(_MSC_VER) + if (k % 4) + return false; + if (Btype != MLLM_TYPE_F32) + return false; + return true; +#else + return false; +#endif + } + + case MLLM_TYPE_Q8_0: { + if (Btype != MLLM_TYPE_Q8_0) + return false; +#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__) + return true; +#elif defined(__ARM_FEATURE_DOTPROD) + return true; +#else + return false; +#endif + } + + case MLLM_TYPE_Q4_0: { + if (Btype != MLLM_TYPE_Q8_0) + return false; +#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__) + return true; +#elif defined(__ARM_FEATURE_DOTPROD) + return true; +#else + return false; +#endif + } + + default: + return false; + } +} diff --git a/src/backends/cpu/compute/SGEMM.hpp b/src/backends/cpu/compute/SGEMM.hpp new file mode 100644 index 00000000..f4e12c18 --- /dev/null +++ b/src/backends/cpu/compute/SGEMM.hpp @@ -0,0 +1,20 @@ +// +// Created by Rongjie Yi on 24-07-23. 
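Both quantized branches above insist on Btype == MLLM_TYPE_Q8_0, so a caller holding float activations has to quantize them row by row before the call. Note that in the quantized kernels k, lda and ldb count blocks rather than scalars (the inner loops step one block_q8_0 / block_q4_0 at a time), so a row of K floats becomes K/QK8_0 blocks. A rough sketch using the existing quantize_row_q8_0 routine; the wrapper, buffer names and include path are illustrative assumptions.

    #include <cstdint>
    #include <vector>
    #include "backends/cpu/quantize/Quantize.hpp"   // include path is an assumption

    // Illustrative: pack an n x K float activation matrix into Q8_0 row blocks
    // for use as the B operand of llamafile_sgemm (Btype = MLLM_TYPE_Q8_0).
    static std::vector<block_q8_0> quantize_activations(const float *B_f32, int64_t n, int64_t K) {
        const int64_t k_blocks = K / QK8_0;          // assumes K % QK8_0 == 0
        std::vector<block_q8_0> B_q8(n * k_blocks);
        for (int64_t j = 0; j < n; ++j)
            quantize_row_q8_0(B_f32 + j * K, B_q8.data() + j * k_blocks, K);
        return B_q8;                                 // then pass k_blocks as both k and ldb
    }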
+// + +#ifndef MLLM_GEMM_HPP +#define MLLM_GEMM_HPP + + +#include "VecDot.hpp" +using namespace mllm; + +bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda, const void *B, int64_t ldb, void *C, int64_t ldc, + int ith, int nth, + DataType Atype, DataType Btype, DataType Ctype); + +bool check_llamafile_sgemm(int64_t, int64_t, int64_t, + DataType, DataType, DataType); + + +#endif // MLLM_GEMM_HPP diff --git a/src/backends/cpu/compute/VecDot.hpp b/src/backends/cpu/compute/VecDot.hpp index 1dad4968..ac09d753 100644 --- a/src/backends/cpu/compute/VecDot.hpp +++ b/src/backends/cpu/compute/VecDot.hpp @@ -356,6 +356,32 @@ inline static int32x4_t mllm_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) using namespace mllm; +inline static void vec_scale_f32(const int n, float *y, const float v) { + const int np = (n & ~(MLLM_F32_STEP - 1)); + + MLLM_F32_VEC vx = MLLM_F32_VEC_SET1(v); + + MLLM_F32_VEC ay[MLLM_F32_ARR]; + + for (int i = 0; i < np; i += MLLM_F32_STEP) { + for (int j = 0; j < MLLM_F32_ARR; j++) { + ay[j] = MLLM_F32_VEC_LOAD(y + i + j * MLLM_F32_EPR); + ay[j] = MLLM_F32_VEC_MUL(ay[j], vx); + + MLLM_F32_VEC_STORE(y + i + j * MLLM_F32_EPR, ay[j]); + } + } + + // leftovers + for (int i = np; i < n; ++i) { + y[i] *= v; + } + + // for (int i = 0; i < n; ++i) { + // y[i] *= v; + // } +} + // void vec_dot_fp32(const float * __restrict src0, const float * __restrict src1, Tensor *dst, bool support_bias, Tensor *bias, int hid_len, int batch, int head, int src0_inf, int sec1_outf); void vec_dot_q4_0_q8_0(const void * __restrict src0, const void * __restrict src1, Tensor *dst, bool support_bias, Tensor *bias, int hid_len, int batch, int head, int src0_inf, int sec1_outf); void vec_dot_q4_K_q8_K(const void * __restrict src0, const void * __restrict src1, Tensor *dst, bool support_bias, Tensor *bias, int hid_len, int batch, int head, int src0_inf, int sec1_outf); diff --git a/src/backends/cpu/compute/VecDotType.cpp b/src/backends/cpu/compute/VecDotType.cpp index 843c1a0f..3fedbd07 100644 --- a/src/backends/cpu/compute/VecDotType.cpp +++ b/src/backends/cpu/compute/VecDotType.cpp @@ -32,6 +32,7 @@ #include "quantize/Quantize.hpp" #include "quantize/QuantizeQ6.hpp" #include "compute/VecDot.hpp" +#include "compute/GEMM_AArch64.hpp" void fp32_add_row_to(int n, const float * MLLM_RESTRICT src, float * MLLM_RESTRICT dst, float alpha){ @@ -258,6 +259,7 @@ type_traits_t type_traits[] = { .blck_size = QK8_0, .to_float = (mllm_to_float_func) dequantize_row_q8_0, .from_float = (mllm_from_float_func) quantize_row_q8_0, + .from_float_to_mat = (mllm_from_float_to_mat_func)quantize_mat_q8_0, .vec_dot = (mllm_vec_dot_func) vec_dot_q8_0_q8_0, .vec_dot_type = MLLM_TYPE_Q8_0, .add_row_to = (mllm_vec_add_row_func)q8_0_add_row_to, @@ -296,6 +298,47 @@ type_traits_t type_traits[] = { /*[MLLM_TYPE_I_8] = */{}, /*[MLLM_TYPE_I_16] = */{}, /*[MLLM_TYPE_I_32] = */{}, + /*[MLLM_TYPE_Q4_0_4_4] = */{ + .size = sizeof(block_q4_0), + .blck_size = QK4_0, + .blck_size_interleave = 4, + .to_float = NULL, + .from_float = NULL, + .vec_dot = NULL, + .vec_dot_type = MLLM_TYPE_Q8_0, + // .nrows = 1, + // .ncols = 4, + .gemv = (mllm_gemv_func)mllm_gemv_q4_0_4x4_q8_0, + .gemm = (mllm_gemm_func)mllm_gemm_q4_0_4x4_q8_0, + }, + /*[MLLM_TYPE_Q4_0_4_8] = */{ + .size = sizeof(block_q4_0), + .blck_size = QK4_0, + .blck_size_interleave = 8, + // .is_quantized = true, + .to_float = NULL, + .from_float = NULL, + .vec_dot = NULL, + .vec_dot_type = MLLM_TYPE_Q8_0, + // .nrows = 1, + // .ncols = 4, + .gemv = 
(mllm_gemv_func)mllm_gemv_q4_0_4x8_q8_0, + .gemm = (mllm_gemv_func)mllm_gemm_q4_0_4x8_q8_0, + }, + /*[MLLM_TYPE_Q4_0_8_8] = */{ + .size = sizeof(block_q4_0), + .blck_size = QK4_0, + .blck_size_interleave = 8, + // .is_quantized = true, + .to_float = NULL, + .from_float = NULL, + .vec_dot = NULL, + .vec_dot_type = MLLM_TYPE_Q8_0, + // .nrows = 1, + // .ncols = 8, + .gemv = (mllm_gemv_func)mllm_gemv_q4_0_8x8_q8_0, + .gemm = (mllm_gemv_func)mllm_gemm_q4_0_8x8_q8_0, + }, {}, // TODO: add support to more type }; diff --git a/src/backends/cpu/compute/VecDotType.hpp b/src/backends/cpu/compute/VecDotType.hpp index e87918a3..f37dc82a 100644 --- a/src/backends/cpu/compute/VecDotType.hpp +++ b/src/backends/cpu/compute/VecDotType.hpp @@ -34,16 +34,25 @@ typedef void (*mllm_to_float_func)(const void *src, float *dst, const int n); // from src type to float(stored in dst) n is the number of element in src typedef void (*mllm_from_float_func)(const float *src, void *dst, const int n); typedef void (*mllm_vec_dot_func) (const int n, float * MLLM_RESTRICT dst, const void * MLLM_RESTRICT x, const void * MLLM_RESTRICT y); +typedef void (*mllm_from_float_to_mat_func)(const float * MLLM_RESTRICT x, void * MLLM_RESTRICT y, int64_t nr, int64_t k, int64_t bs); typedef void (*mllm_vec_add_row_func) (const int n, const void * MLLM_RESTRICT src, float * MLLM_RESTRICT dst, const float alpha); +typedef void (*mllm_gemv_func) (int n, float * MLLM_RESTRICT s, size_t bs, const void * MLLM_RESTRICT x, + const void * MLLM_RESTRICT y, int nr, int nc); +typedef void (*mllm_gemm_func) (int n, float * MLLM_RESTRICT s, size_t bs, const void * MLLM_RESTRICT x, + const void * MLLM_RESTRICT y, int nr, int nc); typedef struct type_traits_t{ size_t size; // type size int blck_size; // number of element in a block (quantization block) + int blck_size_interleave; mllm_to_float_func to_float; mllm_from_float_func from_float; + mllm_from_float_to_mat_func from_float_to_mat; mllm_vec_dot_func vec_dot; DataType vec_dot_type; // vec_dot do dot product between two DataType, this is the other type mllm_vec_add_row_func add_row_to; // add alpha * row to a row of float + mllm_gemv_func gemv; + mllm_gemm_func gemm; }type_traits_t; extern type_traits_t type_traits[]; diff --git a/src/backends/cpu/quantize/Quantize.hpp b/src/backends/cpu/quantize/Quantize.hpp index 3f2a413a..7c995686 100644 --- a/src/backends/cpu/quantize/Quantize.hpp +++ b/src/backends/cpu/quantize/Quantize.hpp @@ -226,15 +226,15 @@ inline void init_table_silu_f16() { mllm_table_silu_f16[i] = MLLM_FP32_TO_FP16(mllm_silu_f32(f)); } } -inline static void mllm_vec_silu_f32(const int n, float * y, const float * x) { - uint16_t t; -//#pragma omp parallel for num_threads(thread_count) - for (int i = 0; i < n; ++i) { - mllm_fp16_t fp16 = MLLM_FP32_TO_FP16(x[i]); - memcpy(&t, &fp16, sizeof(uint16_t)); - y[i] = MLLM_FP16_TO_FP32(mllm_table_silu_f16[t]); - } -} +// inline static void mllm_vec_silu_f32(const int n, float * y, const float * x) { +// uint16_t t; +// //#pragma omp parallel for num_threads(thread_count) +// for (int i = 0; i < n; ++i) { +// mllm_fp16_t fp16 = MLLM_FP32_TO_FP16(x[i]); +// memcpy(&t, &fp16, sizeof(uint16_t)); +// y[i] = MLLM_FP16_TO_FP32(mllm_table_silu_f16[t]); +// } +// } diff --git a/src/express/Express.cpp b/src/express/Express.cpp index 01218a98..46527750 100644 --- a/src/express/Express.cpp +++ b/src/express/Express.cpp @@ -228,7 +228,7 @@ NetTensor *_SiLU(std::vector inputs, string name) { /** * \param axis The axis along which the softmax is performed. 
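These three trait entries expose the interleaved AArch64 Q4_0 formats to generic code: each records its interleave width, the activation type its kernels expect (vec_dot_type = MLLM_TYPE_Q8_0), a from_float_to_mat routine for packing a batch of activation rows, and matching gemv/gemm kernels from GEMM_AArch64. The call site is not part of this patch; the sketch below is a hedged illustration of how the table would be consulted, with argument order following the typedefs declared next, the nr > 1 heuristic and all buffer names assumed.

    #include "backends/cpu/compute/VecDotType.hpp"   // include path is an assumption

    // Illustrative dispatch through the trait table for an interleaved Q4_0 weight.
    static void run_q4_0_4x4(const float *act_f32, void *act_q8,
                             float *out, size_t out_stride,
                             const void *weights_q4_0_4x4,
                             int k, int nr, int nc) {
        const type_traits_t &tt = type_traits[MLLM_TYPE_Q4_0_4_4];
        // Pack nr activation rows of length k into the interleaved Q8_0 layout
        // the kernels expect (quantize_mat_q8_0 behind from_float_to_mat).
        tt.from_float_to_mat(act_f32, act_q8, nr, k, tt.blck_size_interleave);
        // Tiled GEMM kernel when a batch of rows is available,
        // GEMV kernel for the single-row (decode) case.
        if (nr > 1 && tt.gemm)
            tt.gemm(k, out, out_stride, weights_q4_0_4x4, act_q8, nr, nc);
        else if (tt.gemv)
            tt.gemv(k, out, out_stride, weights_q4_0_4x4, act_q8, nr, nc);
    }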
e.g. DIMENSION. */ -NetTensor *_Softmax(std::vector inputs, int axis, string name) { +NetTensor *_Softmax(std::vector inputs, int axis, int do_causal_mask, string name) { Context *ctx = inputs[0]->ctx; NetTensor *out_tensor = new NetTensor(); if (name.empty()) { @@ -240,6 +240,7 @@ NetTensor *_Softmax(std::vector inputs, int axis, string name) { _STORE_OUT_TENSOR _NEW_OP(mllm::SOFTMAX) net_op_->param["axis"] = axis; + net_op_->param["do_causal_mask"] = do_causal_mask; _UPDATE_INPUT_TENSORS out_tensor->in = net_op_; out_tensor->ctx = ctx; diff --git a/src/express/Express.hpp b/src/express/Express.hpp index 6126c674..a28613bc 100644 --- a/src/express/Express.hpp +++ b/src/express/Express.hpp @@ -17,7 +17,7 @@ NetTensor *_Range(Context *ctx, std::vector inputs, int start, int NetTensor *_Add(std::vector inputs, string name = ""); NetTensor *_Causalmask(std::vector inputs, string name = ""); NetTensor *_SiLU(std::vector inputs, string name = ""); -NetTensor *_Softmax(std::vector inputs, int axis, string name = ""); +NetTensor *_Softmax(std::vector inputs, int axis, int do_causal_mask, string name = ""); NetTensor *_Matmul(std::vector inputs, bool transpose0, bool transpose1, string name = ""); NetTensor *_RMSNorm(std::vector inputs, int norm_size, float epsilon= 1e-6, string name = ""); NetTensor *_RoPE(std::vector inputs, int pose_type, string name = ""); diff --git a/src/models/clip/modeling_clip.hpp b/src/models/clip/modeling_clip.hpp index 4797586d..c2cb383b 100644 --- a/src/models/clip/modeling_clip.hpp +++ b/src/models/clip/modeling_clip.hpp @@ -88,7 +88,7 @@ class ClipTextBlock final : public Module { ClipTextBlock() = default; ClipTextBlock(int hidden_dim, int head_size, int ffn_hidden, const string &act_fn_type, const ClipTextNameConfig &names, const string &base_name) { attention = MultiHeadAttention(hidden_dim, head_size, head_size, hidden_dim / head_size, SPLIT_NONE, false, false, - RoPEType::NONE, 0, true, true, names, base_name + names._attn_base_name); + RoPEType::NONE, -1,-1, 0, true, true, names, base_name + names._attn_base_name); mlp = ClipTextMLP(hidden_dim, ffn_hidden, act_fn_type, names, base_name + names._ffn_base_name); down_proj = Linear(ffn_hidden, hidden_dim, true, base_name + names._down_proj_name); norm1 = LayerNorm(hidden_dim, true, 1e-6, base_name + names._attn_norm_name); diff --git a/src/models/deepseek/modeling_deepseek.hpp b/src/models/deepseek/modeling_deepseek.hpp index ca4fb034..3fbece36 100644 --- a/src/models/deepseek/modeling_deepseek.hpp +++ b/src/models/deepseek/modeling_deepseek.hpp @@ -18,10 +18,9 @@ class DeepseekMultiHeadLatentAttention final : public Module { Layer v_proj; Layer q_rope; Layer k_rope; - Layer k_cache; - Layer v_cache; - Layer mask; - Layer softmax; + KVCache k_cache; + KVCache v_cache; + Softmax softmax; Layer o_proj; int num_heads{}; int q_head_dim{}; @@ -69,10 +68,7 @@ class DeepseekMultiHeadLatentAttention final : public Module { k_cache = KVCache(num_heads/num_heads, config.cache_limit, base_name + "k_cache"); v_cache = KVCache(num_heads/num_heads, config.cache_limit, base_name + "v_cache"); } - if (config.do_mask) { - mask = Causalmask(base_name + "mask"); - } - softmax = Softmax(DIMENSION, base_name + "softmax"); + softmax = Softmax(DIMENSION, config.do_mask, base_name + "softmax"); softmax_scale = 1/std::sqrt(q_head_dim); } vector Forward(vector inputs, vector args) override { @@ -97,10 +93,7 @@ class DeepseekMultiHeadLatentAttention final : public Module { k = k.transpose(SEQUENCE, DIMENSION); auto qk = Tensor::mm(q, k); 
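The new do_causal_mask argument, stored as the op's "do_causal_mask" param, is the pattern that repeats through the model changes below: the standalone Causalmask layer goes away, masking becomes a construction-time property of Softmax, and at call time the softmax receives the current KV-cache length, evidently so the causal mask can account for tokens already in the cache. Condensed from the deepseek and transformer hunks, with the surrounding module code elided:

    // Construction: masking is now a property of the softmax itself.
    softmax = Softmax(DIMENSION, /*do_causal_mask=*/config.do_mask, base_name + "softmax");

    // Forward: pass the cached prefix length instead of applying a Causalmask layer.
    auto qk = Tensor::mm(q, k);
    qk = qk * softmax_scale;
    qk = softmax(qk, k_cache.getCacheSeqLen());
    auto o  = Tensor::mm(qk, v);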
qk = qk * softmax_scale; - if (mask.ready()) { - qk = mask(qk); - } - qk = softmax(qk); + qk = softmax(qk, k_cache.getCacheSeqLen()); auto o = Tensor::mm(qk, v); o = o.view(-1, 1, -1, v_head_dim * num_heads); o = o_proj(o); diff --git a/src/models/fuyu/configuration_fuyu.hpp b/src/models/fuyu/configuration_fuyu.hpp index fc5ed223..e240d2c4 100644 --- a/src/models/fuyu/configuration_fuyu.hpp +++ b/src/models/fuyu/configuration_fuyu.hpp @@ -40,6 +40,8 @@ class FuyuConfig { int patch_size{}; int chl_size{}; int cache_limit{}; + float rope_theta; + int max_position_embeddings; FuyuNameConfig name_config; @@ -53,6 +55,8 @@ class FuyuConfig { block_num = 36; patch_size = 30; chl_size = 3; + max_position_embeddings= 16384; + rope_theta = 25000; } else { throw std::runtime_error("Unsupported model size"); } diff --git a/src/models/fuyu/modeling_fuyu.hpp b/src/models/fuyu/modeling_fuyu.hpp index b06b7279..90e26442 100644 --- a/src/models/fuyu/modeling_fuyu.hpp +++ b/src/models/fuyu/modeling_fuyu.hpp @@ -21,9 +21,9 @@ class PersimmonBlock final : public Module { public: PersimmonBlock() = default; - PersimmonBlock(int hidden_dim, int head_size, int ffn_hidden, int cache_limit, const FuyuNameConfig &names, const string &base_name) { + PersimmonBlock(int hidden_dim, int head_size, int ffn_hidden, float rope_theta, int max_position_embeddings, int cache_limit, const FuyuNameConfig &names, const string &base_name) { attention = MultiHeadAttention(hidden_dim, head_size, head_size, hidden_dim / head_size, SPLIT_D_HD, true, false, - PERSIMMONROPE, cache_limit, true, true, names, base_name + names._attn_base_name); + PERSIMMONROPE, rope_theta, max_position_embeddings, cache_limit, true, true, names, base_name + names._attn_base_name); mlp = FeedForward(hidden_dim, ffn_hidden, "ReLU2", true, names, base_name + names._ffn_base_name); norm1 = LayerNorm(hidden_dim, true, 1e-6, base_name + names._attn_norm_name); @@ -47,8 +47,8 @@ class Persimmon final : public Module { public: Persimmon() = default; - Persimmon(int hidden_dim, int head_size, int ffn_hidden, int cache_limit, int block_num, int vocab_size, const FuyuNameConfig &names) { - blocks = List(block_num, hidden_dim, head_size, ffn_hidden, cache_limit, names, names.blk_name); + Persimmon(int hidden_dim, int head_size, int ffn_hidden, float rope_theta, int max_position_embeddings, int cache_limit, int block_num, int vocab_size, const FuyuNameConfig &names) { + blocks = List(block_num, hidden_dim, head_size, ffn_hidden, rope_theta, max_position_embeddings, cache_limit, names, names.blk_name); norm = LayerNorm(hidden_dim, true, 1e-6, names.post_norm_name); lm_head = Linear(hidden_dim, vocab_size, false, names.lm_head_name); } @@ -83,16 +83,18 @@ class FuyuModel final : public Module { public: explicit FuyuModel(const FuyuConfig &config) : FuyuModel(config.vocab_size, config.hidden_dim, config.head_size, config.ffn_hidden, config.block_num, + config.rope_theta, config.max_position_embeddings, config.cache_limit, config.patch_size, config.chl_size, config.name_config) { } FuyuModel(int vocab_size, int hidden_dim, int head_size, int ffn_hidden, int block_num, + float rope_theta, int max_position_embeddings, int cache_limit, int patch_size, int chl_size, const FuyuNameConfig &names) { embed_tokens = Embedding(vocab_size, hidden_dim, names.token_embd_name); vision_embed_tokens = Linear(patch_size * patch_size * chl_size, hidden_dim, true, names.vision_embed_tokens_name); fuyu_gather = FuyuGather("gather"); - persimmon = Persimmon(hidden_dim, head_size, 
ffn_hidden, cache_limit, block_num, vocab_size, names); + persimmon = Persimmon(hidden_dim, head_size, ffn_hidden, rope_theta, max_position_embeddings,cache_limit, block_num, vocab_size, names); } vector Forward(vector inputs, vector args) override { auto input_ids = embed_tokens(inputs[0]); diff --git a/src/models/gemma/configuration_gemma.hpp b/src/models/gemma/configuration_gemma.hpp index 6952688a..34d058ea 100644 --- a/src/models/gemma/configuration_gemma.hpp +++ b/src/models/gemma/configuration_gemma.hpp @@ -93,6 +93,7 @@ struct GemmaConfig { int intermediate_size = 16384; int head_dim = 256; float rms_norm_eps = 1e-6; + float rope_theta= 10000; int cache_limit; RoPEType RoPE_type = RoPEType::HFHUBROPE; diff --git a/src/models/gemma/modeling_gemma.hpp b/src/models/gemma/modeling_gemma.hpp index bf7a995d..0dcede29 100644 --- a/src/models/gemma/modeling_gemma.hpp +++ b/src/models/gemma/modeling_gemma.hpp @@ -17,6 +17,7 @@ #include "Module.hpp" #include "Tensor.hpp" #include "configuration_gemma.hpp" +#include "models/transformer/modeling_transformer.hpp" #include using namespace mllm; @@ -48,82 +49,13 @@ class GemmaMLP final : public Module { }; ///< gemma-2B use MQA while 7B use MHA -class GemmaAttention final : public Module { -public: - GemmaAttention() = default; - GemmaAttention(const GemmaConfig &config, const GemmaNameConfig &names, const string &base_name) { - hidden_size = config.hidden_size; - head_dim = config.head_dim; - num_heads = config.num_attention_heads; - num_key_value_heads = config.num_key_value_heads; - num_key_value_groups = num_heads / num_key_value_heads; - - // init layers - q_proj = Linear(hidden_size, num_heads * head_dim, false, base_name + names._q_proj_name); - k_proj = Linear(hidden_size, num_key_value_heads * head_dim, false, base_name + names._k_proj_name); - v_proj = Linear(hidden_size, num_key_value_heads * head_dim, false, base_name + names._v_proj_name); - o_proj = Linear(num_heads * head_dim, hidden_size, false, base_name + names._o_proj_name); - q_rope = RoPE(config.RoPE_type, base_name + "q_rope"); - k_rope = RoPE(config.RoPE_type, base_name + "k_rope"); - k_cache = KVCache(num_heads / num_key_value_heads, config.cache_limit, base_name + "k_cache"); - v_cache = KVCache(num_heads / num_key_value_heads, config.cache_limit, base_name + "v_cache"); - mask = Causalmask(base_name + "mask"); - softmax = Softmax(DIMENSION, base_name + "softmax"); - } - - std::vector Forward(std::vector inputs, std::vector args) override { - auto query_states = q_proj(inputs[0]); - auto key_states = k_proj(inputs[1]); - auto value_states = v_proj(inputs[2]); - - // [batch, heads, sequence, dims] - query_states = query_states.view(-1, num_heads, -1, head_dim); - key_states = key_states.view(-1, num_key_value_heads, -1, head_dim); - value_states = value_states.view(-1, num_key_value_heads, -1, head_dim); - - // embedding - query_states = q_rope(query_states); - key_states = k_rope(key_states); - - // kv cache - key_states = k_cache(key_states); - value_states = v_cache(value_states); - - // attention weight - auto atten_weight = Tensor::mm(query_states, key_states.transpose(Chl::SEQUENCE, Chl::DIMENSION)) / std::sqrt(head_dim); - atten_weight = mask(atten_weight); - atten_weight = softmax(atten_weight); - - // attention output - auto atten_output = Tensor::mm(atten_weight, value_states); - atten_output = atten_output.view(-1, 1, -1, head_dim * num_heads); - atten_output = o_proj(atten_output); - return {atten_output}; - } - -private: - int hidden_size; - int num_heads; - 
int head_dim; - int num_key_value_heads; - int num_key_value_groups; - Layer q_proj; - Layer k_proj; - Layer v_proj; - Layer o_proj; - Layer q_rope; - Layer k_rope; - Layer k_cache; - Layer v_cache; - Layer mask; - Layer softmax; -}; - class GemmaDecoder final : public Module { public: GemmaDecoder() = default; GemmaDecoder(const GemmaConfig &config, const GemmaNameConfig &names, const string &base_name) { - self_atten = GemmaAttention(config, names, base_name + names._attn_base_name); + self_atten = MultiHeadAttention(config.hidden_size, config.num_attention_heads, config.num_key_value_heads, + config.hidden_size / config.num_attention_heads, SPLIT_NONE, false, false, + config.RoPE_type, config.rope_theta, config.max_position_embeddings, config.cache_limit, true, false, names, base_name + names._attn_base_name); mlp = GemmaMLP(config.hidden_size, config.intermediate_size, names, base_name + names._ffn_base_name); input_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps, true, base_name + names._attn_norm_name); post_attention_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps, true, base_name + names._ffn_norm_name); @@ -140,7 +72,7 @@ class GemmaDecoder final : public Module { } private: - GemmaAttention self_atten; + MultiHeadAttention self_atten; GemmaMLP mlp; Layer input_layernorm; Layer post_attention_layernorm; diff --git a/src/models/imagebind/modeling_imagebind.hpp b/src/models/imagebind/modeling_imagebind.hpp index 02750a4e..5254a8ac 100644 --- a/src/models/imagebind/modeling_imagebind.hpp +++ b/src/models/imagebind/modeling_imagebind.hpp @@ -30,7 +30,7 @@ class EncoderBlock final : public Module { } attention = MultiHeadAttention(hidden_dim, head_size, head_size, hidden_dim / head_size, SPLIT_HD, false, bias_kv_cat, - RoPEType::NONE, 0, do_mask, true, + RoPEType::NONE, -1,-1,0, do_mask, true, names, base_name + names._attn_base_name); ffn = FeedForward(hidden_dim, ffn_hidden, "GELU", true, names, base_name + names._ffn_base_name); diff --git a/src/models/llama/configuration_llama.hpp b/src/models/llama/configuration_llama.hpp index 326ae815..2b9b1143 100644 --- a/src/models/llama/configuration_llama.hpp +++ b/src/models/llama/configuration_llama.hpp @@ -66,11 +66,14 @@ class LLaMAConfig { int vocab_size{}; int hidden_dim{}; int head_size{}; + int num_key_value_heads{}; int ffn_hidden{}; int block_num{}; RoPEType RoPE_type; int cache_limit{}; LLaMANameConfig names_config; + float rope_theta; + int max_position_embeddings; explicit LLaMAConfig(int token_limit, string billions = "7B", RoPEType type = LLAMAROPE, int vocab = 32000) { names_config.init(type); @@ -78,9 +81,22 @@ class LLaMAConfig { if (billions == "7B" || billions == "7b") { hidden_dim = 4096; head_size = 32; + num_key_value_heads = 32; ffn_hidden = 11008; block_num = 32; - } else { + max_position_embeddings= 16384; + rope_theta = 10000; + } else if (billions == "6B" || billions == "6b") { + // Yi @https://arxiv.org/abs/2403.04652 + hidden_dim = 4096; + head_size = 32; + num_key_value_heads = 4; + ffn_hidden = 11008; + block_num = 32; + max_position_embeddings= 4096; + rope_theta = 5000000.0; + vocab_size = 64000; + }else { throw std::runtime_error("Unsupported model size"); } RoPE_type = type; diff --git a/src/models/llama/modeling_elastic_llama.hpp b/src/models/llama/modeling_elastic_llama.hpp index cf5da5e0..f358bd89 100644 --- a/src/models/llama/modeling_elastic_llama.hpp +++ b/src/models/llama/modeling_elastic_llama.hpp @@ -19,10 +19,9 @@ class ElasticMultiHeadAttention final : public Module { 
ElasticLinear v_proj; Layer q_rope; Layer k_rope; - Layer k_cache; - Layer v_cache; - Layer mask; - Layer softmax; + KVCache k_cache; + KVCache v_cache; + Softmax softmax; ElasticLinear o_proj; int head_size_{}; int kv_head_size_{}; @@ -48,10 +47,7 @@ class ElasticMultiHeadAttention final : public Module { k_cache = KVCache(head_size/kv_head_size, cache_limit, base_name + "k_cache"); v_cache = KVCache(head_size/kv_head_size, cache_limit, base_name + "v_cache"); } - if (do_mask) { - mask = Causalmask(base_name + "mask"); - } - softmax = Softmax(DIMENSION, base_name + "softmax"); + softmax = Softmax(DIMENSION, do_mask, base_name + "softmax"); o_proj = ElasticLinear(head_size * attn_hidden_dim, hidden_dim, bias, base_name + names._o_proj_name); } vector Forward(vector inputs, vector args) override { @@ -76,10 +72,11 @@ class ElasticMultiHeadAttention final : public Module { k = k.transpose(SEQUENCE, DIMENSION); auto qk = Tensor::mm(q, k); qk = qk / std::sqrt(activate_hidden_dim);//attn_hidden_dim_ - if (mask.ready()) { - qk = mask(qk); + if (k_cache.ready() && v_cache.ready()) { + qk = softmax(qk, k_cache.getCacheSeqLen()); + }else{ + qk = softmax(qk); } - qk = softmax(qk); auto o = Tensor::mm(qk, v); o = o.view(-1, 1, -1, activate_hidden_dim * head_size_); o = o_proj(o, activate_dim, -1); diff --git a/src/models/llama/modeling_llama.hpp b/src/models/llama/modeling_llama.hpp index 75bf3edd..c53d5fb1 100644 --- a/src/models/llama/modeling_llama.hpp +++ b/src/models/llama/modeling_llama.hpp @@ -44,9 +44,9 @@ class LLaMABlock final : public Module { public: LLaMABlock() = default; - LLaMABlock(int hidden_dim, int head_size, int ffn_hidden, RoPEType RoPE_type, int cache_limit, const LLaMANameConfig &names, const string &base_name) { - attention = MultiHeadAttention(hidden_dim, head_size, head_size, hidden_dim / head_size, SPLIT_NONE, false, false, - RoPE_type, cache_limit, true, false, names, base_name + names._attn_base_name); + LLaMABlock(int hidden_dim, int head_size, int kv_head_size, int ffn_hidden, RoPEType RoPE_type, float rope_theta, int max_position_embeddings, int cache_limit, const LLaMANameConfig &names, const string &base_name) { + attention = MultiHeadAttention(hidden_dim, head_size, kv_head_size, hidden_dim / head_size, SPLIT_NONE, false, false, + RoPE_type, rope_theta, max_position_embeddings, cache_limit, true, false, names, base_name + names._attn_base_name); mlp = LLaMAMLP(hidden_dim, ffn_hidden, names, base_name + names._ffn_base_name); norm1 = RMSNorm(hidden_dim, 1e-6, base_name + names._attn_norm_name); norm2 = RMSNorm(hidden_dim, 1e-6, base_name + names._ffn_norm_name); @@ -70,13 +70,14 @@ class LLaMAModel final : public Module { public: explicit LLaMAModel(const LLaMAConfig &config) : - LLaMAModel(config.vocab_size, config.hidden_dim, config.head_size, config.ffn_hidden, config.block_num, config.RoPE_type, config.cache_limit, + LLaMAModel(config.vocab_size, config.hidden_dim, config.head_size, config.num_key_value_heads, config.ffn_hidden, config.block_num, + config.RoPE_type, config.rope_theta, config.max_position_embeddings, config.cache_limit, config.names_config, config.names_config.blk_name) { } - LLaMAModel(int vocab_size, int hidden_dim, int head_size, int ffn_hidden, int block_num, RoPEType RoPE_type, int cache_limit, + LLaMAModel(int vocab_size, int hidden_dim, int head_size, int kv_head_size, int ffn_hidden, int block_num, RoPEType RoPE_type, float rope_theta, int max_position_embeddings, int cache_limit, const LLaMANameConfig &names, const string &base_name) { 
embedding = Embedding(vocab_size, hidden_dim, names.token_embd_name); - blocks = List(block_num, hidden_dim, head_size, ffn_hidden, RoPE_type, cache_limit, names, base_name); + blocks = List(block_num, hidden_dim, head_size, kv_head_size, ffn_hidden, RoPE_type, rope_theta, max_position_embeddings, cache_limit, names, base_name); norm = RMSNorm(hidden_dim, 1e-6, names.post_norm_name); lm_head = Linear(hidden_dim, vocab_size, false, names.lm_head_name); } diff --git a/src/models/llama/modeling_sparse_llama.hpp b/src/models/llama/modeling_sparse_llama.hpp index afc18432..c496eb47 100644 --- a/src/models/llama/modeling_sparse_llama.hpp +++ b/src/models/llama/modeling_sparse_llama.hpp @@ -49,9 +49,9 @@ class SparseLLaMABlock final : public Module { public: SparseLLaMABlock() = default; - SparseLLaMABlock(bool is_down_sparse, int hidden_dim, int head_size, int ffn_hidden, RoPEType RoPE_type, int cache_limit, const LLaMANameConfig &names, const string &base_name) { + SparseLLaMABlock(bool is_down_sparse, int hidden_dim, int head_size, int ffn_hidden, RoPEType RoPE_type, float rope_theta, int max_position_embeddings, int cache_limit, const LLaMANameConfig &names, const string &base_name) { attention = MultiHeadAttention(hidden_dim, head_size, head_size, hidden_dim / head_size, SPLIT_NONE, false, false, - RoPE_type, cache_limit, true, false, names, base_name + names._attn_base_name); + RoPE_type, rope_theta, max_position_embeddings, cache_limit, true, false, names, base_name + names._attn_base_name); mlp = SparseLLaMAMLP(hidden_dim, ffn_hidden, names, base_name + names._ffn_base_name, is_down_sparse); norm1 = RMSNorm(hidden_dim, 1e-6, base_name + names._attn_norm_name); norm2 = RMSNorm(hidden_dim, 1e-6, base_name + names._ffn_norm_name); @@ -75,13 +75,15 @@ class SparseLLaMAModel final : public Module { public: explicit SparseLLaMAModel(const LLaMAConfig &config, bool is_down_sparse = false) : - SparseLLaMAModel(config.vocab_size, config.hidden_dim, config.head_size, config.ffn_hidden, config.block_num, config.RoPE_type, config.cache_limit, + SparseLLaMAModel(config.vocab_size, config.hidden_dim, config.head_size, config.ffn_hidden, config.block_num, config.RoPE_type, + config.rope_theta, config.max_position_embeddings, config.cache_limit, config.names_config, config.names_config.blk_name, is_down_sparse) { } - SparseLLaMAModel(int vocab_size, int hidden_dim, int head_size, int ffn_hidden, int block_num, RoPEType RoPE_type, int cache_limit, + SparseLLaMAModel(int vocab_size, int hidden_dim, int head_size, int ffn_hidden, int block_num, RoPEType RoPE_type, + float rope_theta, int max_position_embeddings, int cache_limit, const LLaMANameConfig &names, const string &base_name, bool is_down_sparse) { embedding = Embedding(vocab_size, hidden_dim, names.token_embd_name); - blocks = List(block_num, is_down_sparse, hidden_dim, head_size, ffn_hidden, RoPE_type, cache_limit, names, base_name); + blocks = List(block_num, is_down_sparse, hidden_dim, head_size, ffn_hidden, RoPE_type, rope_theta, max_position_embeddings, cache_limit, names, base_name); norm = RMSNorm(hidden_dim, 1e-6, names.post_norm_name); lm_head = Linear(hidden_dim, vocab_size, false, names.lm_head_name); } diff --git a/src/models/llama/tokenization_llama.hpp b/src/models/llama/tokenization_llama.hpp index a0f504ca..ab6bca3d 100644 --- a/src/models/llama/tokenization_llama.hpp +++ b/src/models/llama/tokenization_llama.hpp @@ -18,27 +18,21 @@ class LLaMATokenizer final { if (scores.empty()) { throw std::invalid_argument("Input vector is 
empty"); } - unsigned int maxIndex = 0; - float maxValue = scores[0]; - for (size_t i = 1; i < scores.size(); ++i) { - if (scores[i] > maxValue) { - maxIndex = i; - maxValue = scores[i]; - } - } - return maxIndex; + return std::max_element(scores.begin(), scores.end()) - scores.begin(); } + bool bos_=true; public: - explicit LLaMATokenizer(const std::string &vocab_file) { + explicit LLaMATokenizer(const std::string &vocab_file, bool bos=true) { Module::initBackend(MLLM_CPU); tokenizer = new BPETokenizer(vocab_file); + bos_ = bos; } Tensor tokenize(std::string &text, int str_i = 0) const { if (text[0] != ' ') { text = ' ' + text; } auto tokens_id = vector(); - tokenizer->tokenize(text, tokens_id, true); + tokenizer->tokenize(text, tokens_id, bos_); if (str_i > 0){ tokens_id[0] = 13; } diff --git a/src/models/llava/modeling_llava.hpp b/src/models/llava/modeling_llava.hpp index 1b7d924f..4ddd4937 100644 --- a/src/models/llava/modeling_llava.hpp +++ b/src/models/llava/modeling_llava.hpp @@ -20,9 +20,9 @@ class LLaMABodyModel final : public Module { public: LLaMABodyModel() = default; - LLaMABodyModel(int vocab_size, int hidden_dim, int head_size, int ffn_hidden, int block_num, RoPEType RoPE_type, int cache_limit, + LLaMABodyModel(int vocab_size, int hidden_dim, int head_size, int ffn_hidden, int block_num, RoPEType RoPE_type, float rope_theta, int max_position_embeddings, int cache_limit, const LLaMANameConfig &names, const string &base_name) { - blocks = List(block_num, hidden_dim, head_size, ffn_hidden, RoPE_type, cache_limit, names, base_name); + blocks = List(block_num, hidden_dim, head_size, head_size, ffn_hidden, RoPE_type, rope_theta, max_position_embeddings, cache_limit, names, base_name); norm = RMSNorm(hidden_dim, 1e-6, names.post_norm_name); lm_head = Linear(hidden_dim, vocab_size, false, names.lm_head_name); } @@ -116,17 +116,20 @@ class LLaVAModel final : public Module { public: explicit LLaVAModel(const LLaVAConfig &config) : - LLaVAModel(config.vocab_size, config.hidden_dim, config.head_size, config.ffn_hidden, config.block_num, config.RoPE_type, config.cache_limit, + LLaVAModel(config.vocab_size, config.hidden_dim, config.head_size, config.ffn_hidden, config.block_num, + config.RoPE_type, config.rope_theta, config.max_position_embeddings, config.cache_limit, config.names_config, config.vision_hidden_dim, config.vision_head_size, config.vision_ffn_hidden, config.patch, config.img_hw, config.vision_block_num, config.vit_names_config) { } - LLaVAModel(int vocab_size, int hidden_dim, int head_size, int ffn_hidden, int block_num, RoPEType RoPE_type, int cache_limit, + LLaVAModel(int vocab_size, int hidden_dim, int head_size, int ffn_hidden, int block_num, + RoPEType RoPE_type, float rope_theta, int max_position_embeddings, int cache_limit, const LLaMANameConfig &names_config, int vision_hidden_dim, int vision_head_size, int vision_ffn_hidden, int patch, int img_hw, int vision_block_num, const ViTNameConfig &vit_names_config) { text_embedding = Embedding(vocab_size, hidden_dim, names_config.token_embd_name); - llama_body = LLaMABodyModel(vocab_size, hidden_dim, head_size, ffn_hidden, block_num, RoPE_type, cache_limit, + llama_body = LLaMABodyModel(vocab_size, hidden_dim, head_size, ffn_hidden, block_num, + RoPE_type, rope_theta, max_position_embeddings, cache_limit, names_config, names_config.blk_name); vision_tower = LLaVAVisionModel(vision_hidden_dim, vision_head_size, vision_ffn_hidden, patch, img_hw, vision_block_num, vit_names_config, vit_names_config.vison_model_name); diff 
--git a/src/models/mistral/modeling_mistral.hpp b/src/models/mistral/modeling_mistral.hpp index f36921f6..e1a55e49 100644 --- a/src/models/mistral/modeling_mistral.hpp +++ b/src/models/mistral/modeling_mistral.hpp @@ -16,6 +16,7 @@ #include "Module.hpp" #include "Tensor.hpp" #include "configuration_mistral.hpp" +#include "models/transformer/modeling_transformer.hpp" #include using namespace mllm; @@ -46,82 +47,14 @@ class MistralMLP final : public Module { Layer silu; }; -class MistralAttention final : public Module { -public: - MistralAttention() = default; - MistralAttention(const MistralConfig &config, const MistralNameConfig &names, const string &base_name) { - hidden_size = config.hidden_size; - num_heads = config.num_attention_heads; - head_dim = config.hidden_size / num_heads; - num_key_value_heads = config.num_key_value_heads; - num_key_value_groups = num_heads / num_key_value_heads; - - // init layers - q_proj = Linear(hidden_size, num_heads * head_dim, false, base_name + names._q_proj_name); - k_proj = Linear(hidden_size, num_key_value_heads * head_dim, false, base_name + names._k_proj_name); - v_proj = Linear(hidden_size, num_key_value_heads * head_dim, false, base_name + names._v_proj_name); - o_proj = Linear(num_heads * head_dim, hidden_size, false, base_name + names._o_proj_name); - q_rope = RoPE(config.RoPE_type, config.rope_theta, config.max_position_embeddings, base_name + "q_rope"); - k_rope = RoPE(config.RoPE_type, config.rope_theta, config.max_position_embeddings, base_name + "k_rope"); - k_cache = KVCache(num_key_value_groups, config.cache_limit, base_name + "k_cache"); - v_cache = KVCache(num_key_value_groups, config.cache_limit, base_name + "v_cache"); - mask = Causalmask(base_name + "mask"); - softmax = Softmax(DIMENSION, base_name + "softmax"); - } - - std::vector Forward(std::vector inputs, std::vector args) override { - auto query_states = q_proj(inputs[0]); - auto key_states = k_proj(inputs[1]); - auto value_states = v_proj(inputs[2]); - - // [batch, heads, sequence, dims] - query_states = query_states.view(-1, num_heads, -1, head_dim); - key_states = key_states.view(-1, num_key_value_heads, -1, head_dim); - value_states = value_states.view(-1, num_key_value_heads, -1, head_dim); - - // embedding - query_states = q_rope(query_states); - key_states = k_rope(key_states); - - // kv cache - key_states = k_cache(key_states); - value_states = v_cache(value_states); - - // attention weight - auto atten_weight = Tensor::mm(query_states, key_states.transpose(Chl::SEQUENCE, Chl::DIMENSION)) / std::sqrt(head_dim); - atten_weight = mask(atten_weight); - atten_weight = softmax(atten_weight); - - // attention output - auto atten_output = Tensor::mm(atten_weight, value_states); - atten_output = atten_output.view(-1, 1, -1, head_dim * num_heads); - atten_output = o_proj(atten_output); - return {atten_output}; - } - -private: - int hidden_size; - int num_heads; - int head_dim; - int num_key_value_heads; - int num_key_value_groups; - Layer q_proj; - Layer k_proj; - Layer v_proj; - Layer o_proj; - Layer q_rope; - Layer k_rope; - Layer k_cache; - Layer v_cache; - Layer mask; - Layer softmax; -}; - class MistralDecoder final : public Module { public: MistralDecoder() = default; MistralDecoder(const MistralConfig &config, const MistralNameConfig &names, const string &base_name) { - self_atten = MistralAttention(config, names, base_name + names._attn_base_name); + self_atten = MultiHeadAttention(config.hidden_size, config.num_attention_heads, config.num_key_value_heads, + 
config.hidden_size / config.num_attention_heads, SPLIT_NONE, false, false, + config.RoPE_type, config.rope_theta, config.max_position_embeddings, config.cache_limit, + true, false, names, base_name + names._attn_base_name); mlp = MistralMLP(config.hidden_size, config.intermediate_size, names, base_name + names._ffn_base_name); input_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps, base_name + names._attn_norm_name); post_attention_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps, base_name + names._ffn_norm_name); @@ -138,7 +71,7 @@ class MistralDecoder final : public Module { } private: - MistralAttention self_atten; + MultiHeadAttention self_atten; MistralMLP mlp; Layer input_layernorm; Layer post_attention_layernorm; diff --git a/src/models/opt/modeling_opt.hpp b/src/models/opt/modeling_opt.hpp index 16c20c56..b4027092 100644 --- a/src/models/opt/modeling_opt.hpp +++ b/src/models/opt/modeling_opt.hpp @@ -18,7 +18,7 @@ class OPTBlock final : public Module { OPTBlock() = default; OPTBlock(int hidden_dim, int head_size, int ffn_hidden, int cache_limit, const optNameConfig &names, const string &base_name) { attention = MultiHeadAttention(hidden_dim, head_size, head_size, hidden_dim / head_size, SPLIT_NONE, false, false, - NONE, cache_limit, true, true, names, base_name + names._attn_base_name); + NONE, -1, -1, cache_limit, true, true, names, base_name + names._attn_base_name); mlp = FeedForward(hidden_dim, ffn_hidden, "ReLU", true, names, base_name + names._ffn_base_name); norm1 = LayerNorm(hidden_dim, true, 1e-05, base_name + names._attn_norm_name); diff --git a/src/models/qwen/modeling_qwen.hpp b/src/models/qwen/modeling_qwen.hpp index 370d1891..a4f100d5 100644 --- a/src/models/qwen/modeling_qwen.hpp +++ b/src/models/qwen/modeling_qwen.hpp @@ -92,8 +92,8 @@ class QWenAttention final : public Module { // attention weight auto atten_weight = Tensor::mm(query_states, key_states.transpose(Chl::SEQUENCE, Chl::DIMENSION)) / std::sqrt(head_dim); - atten_weight = mask(atten_weight); - atten_weight = softmax(atten_weight); + atten_weight = mask(atten_weight, k_cache.getCacheSeqLen()); + atten_weight = softmax(atten_weight, k_cache.getCacheSeqLen()); // attention output auto atten_output = Tensor::mm(atten_weight, value_states); @@ -114,10 +114,10 @@ class QWenAttention final : public Module { Layer o_proj; Layer q_rope; Layer k_rope; - Layer k_cache; - Layer v_cache; - Layer mask; - Layer softmax; + KVCache k_cache; + KVCache v_cache; + Causalmask mask; + Softmax softmax; }; // Copied from GemmaDecoder with Gemma->Qwen and set RmsNorm(without add_unit_offset) diff --git a/src/models/stablelm/modeling_stablelm.hpp b/src/models/stablelm/modeling_stablelm.hpp index 70f81324..2c834379 100644 --- a/src/models/stablelm/modeling_stablelm.hpp +++ b/src/models/stablelm/modeling_stablelm.hpp @@ -19,10 +19,9 @@ class StableLMMultiHeadAttention final : public Module { Layer k_rope; Layer q_norm; Layer k_norm; - Layer k_cache; - Layer v_cache; - Layer mask; - Layer softmax; + KVCache k_cache; + KVCache v_cache; + Softmax softmax; Layer o_proj; Parameter bias_k; Parameter bias_v; @@ -59,10 +58,7 @@ class StableLMMultiHeadAttention final : public Module { k_cache = KVCache(head_size / kv_head_size, cache_limit, base_name + "k_cache"); v_cache = KVCache(head_size / kv_head_size, cache_limit, base_name + "v_cache"); } - if (do_mask) { - mask = Causalmask(base_name + "mask"); - } - softmax = Softmax(DIMENSION, base_name + "softmax"); + softmax = Softmax(DIMENSION, do_mask, base_name + 
"softmax"); o_proj = Linear(head_size * attn_hidden_dim, hidden_dim, false, base_name + names._o_proj_name); if (bias_kv_cat) { bias_k = Parameter(1, 1, head_size, attn_hidden_dim, base_name + "bias_k"); @@ -104,10 +100,7 @@ class StableLMMultiHeadAttention final : public Module { k = k.transpose(SEQUENCE, DIMENSION); auto qk = Tensor::mm(q, k); qk = qk / std::sqrt(attn_hidden_dim_); - if (mask.ready()) { - qk = mask(qk); - } - qk = softmax(qk); + qk = softmax(qk, k_cache.getCacheSeqLen()); auto o = Tensor::mm(qk, v); o = o.view(-1, 1, -1, attn_hidden_dim_ * head_size_); o = o_proj(o); diff --git a/src/models/tinyllama/configuration_tinyllama.hpp b/src/models/tinyllama/configuration_tinyllama.hpp index eb4bf147..0ba07dd1 100644 --- a/src/models/tinyllama/configuration_tinyllama.hpp +++ b/src/models/tinyllama/configuration_tinyllama.hpp @@ -19,6 +19,8 @@ class TinyLLaMAConfig { RoPEType RoPE_type; int cache_limit{}; LLaMANameConfig names_config; + float rope_theta; + int max_position_embeddings; explicit TinyLLaMAConfig(int token_limit, string billions = "1.5B", RoPEType type = HFHUBROPE, int vocab = 32000) { names_config.init(type); @@ -29,6 +31,8 @@ class TinyLLaMAConfig { kv_head_size = 4; ffn_hidden = 5632; block_num = 22; + max_position_embeddings= 16384; + rope_theta = 10000; } else { throw std::runtime_error("Unsupported model size"); } diff --git a/src/models/tinyllama/modeling_tinyllama.hpp b/src/models/tinyllama/modeling_tinyllama.hpp index 9b2a3b98..3d54d6cf 100644 --- a/src/models/tinyllama/modeling_tinyllama.hpp +++ b/src/models/tinyllama/modeling_tinyllama.hpp @@ -20,9 +20,9 @@ class TinyLLaMABlock final : public Module { public: TinyLLaMABlock() = default; - TinyLLaMABlock(int hidden_dim, int head_size, int kv_head_size, int ffn_hidden, RoPEType RoPE_type, int cache_limit, const LLaMANameConfig &names, const string &base_name) { + TinyLLaMABlock(int hidden_dim, int head_size, int kv_head_size, int ffn_hidden, RoPEType RoPE_type, float rope_theta, int max_position_embeddings, int cache_limit, const LLaMANameConfig &names, const string &base_name) { attention = MultiHeadAttention(hidden_dim, head_size, kv_head_size, hidden_dim / head_size, SPLIT_NONE, false, false, - RoPE_type, cache_limit, true, false, names, base_name + names._attn_base_name); + RoPE_type, rope_theta, max_position_embeddings, cache_limit, true, false, names, base_name + names._attn_base_name); mlp = LLaMAMLP(hidden_dim, ffn_hidden, names, base_name + names._ffn_base_name); norm1 = RMSNorm(hidden_dim, 1e-6, base_name + names._attn_norm_name); norm2 = RMSNorm(hidden_dim, 1e-6, base_name + names._ffn_norm_name); @@ -46,13 +46,15 @@ class TinyLLaMAModel final : public Module { public: explicit TinyLLaMAModel(const TinyLLaMAConfig &config) : - TinyLLaMAModel(config.vocab_size, config.hidden_dim, config.head_size, config.kv_head_size, config.ffn_hidden, config.block_num, config.RoPE_type, config.cache_limit, + TinyLLaMAModel(config.vocab_size, config.hidden_dim, config.head_size, config.kv_head_size, config.ffn_hidden, config.block_num, + config.RoPE_type, config.rope_theta, config.max_position_embeddings, config.cache_limit, config.names_config, config.names_config.blk_name) { } - TinyLLaMAModel(int vocab_size, int hidden_dim, int head_size, int kv_head_size, int ffn_hidden, int block_num, RoPEType RoPE_type, int cache_limit, + TinyLLaMAModel(int vocab_size, int hidden_dim, int head_size, int kv_head_size, int ffn_hidden, int block_num, + RoPEType RoPE_type, float rope_theta, int max_position_embeddings, int 
cache_limit, const LLaMANameConfig &names, const string &base_name) { embedding = Embedding(vocab_size, hidden_dim, names.token_embd_name); - blocks = List(block_num, hidden_dim, head_size, kv_head_size, ffn_hidden, RoPE_type, cache_limit, names, base_name); + blocks = List(block_num, hidden_dim, head_size, kv_head_size, ffn_hidden, RoPE_type, rope_theta, max_position_embeddings, cache_limit, names, base_name); norm = RMSNorm(hidden_dim, 1e-6, names.post_norm_name); lm_head = Linear(hidden_dim, vocab_size, false, names.lm_head_name); } diff --git a/src/models/transformer/modeling_transformer.hpp b/src/models/transformer/modeling_transformer.hpp index 7489baa5..9a1152a7 100644 --- a/src/models/transformer/modeling_transformer.hpp +++ b/src/models/transformer/modeling_transformer.hpp @@ -5,6 +5,7 @@ #ifndef MODELING_TRANSFORMER_HPP #define MODELING_TRANSFORMER_HPP +#include "Layer.hpp" #include "configuration_transformer.hpp" using namespace mllm; @@ -27,10 +28,9 @@ class MultiHeadAttention final : public Module { Layer k_rope; Layer q_norm; Layer k_norm; - Layer k_cache; - Layer v_cache; - Layer mask; - Layer softmax; + KVCache k_cache; + KVCache v_cache; + Softmax softmax; Layer o_proj; Parameter bias_k; Parameter bias_v; @@ -42,7 +42,8 @@ class MultiHeadAttention final : public Module { MultiHeadAttention() = default; MultiHeadAttention(int hidden_dim, int head_size,int kv_head_size, int attn_hidden_dim, AttnQKVSplitType do_qkv_proj, bool post_qkv_norm, bool bias_kv_cat, - RoPEType RoPE_type, int cache_limit, bool do_mask, bool bias, + RoPEType RoPE_type, float rope_theta, int max_position_embeddings, + int cache_limit, bool do_mask, bool bias, const TransformerNameConfig &names, const string &base_name) { attn_hidden_dim_ = attn_hidden_dim; head_size_ = head_size; @@ -60,17 +61,14 @@ class MultiHeadAttention final : public Module { k_norm = LayerNorm(attn_hidden_dim, true, 1e-6, base_name + names._k_norm_name); } if (RoPE_type > 0) { - q_rope = RoPE(RoPE_type, base_name + "q_rope"); - k_rope = RoPE(RoPE_type, base_name + "k_rope"); + q_rope = RoPE(RoPE_type, rope_theta, max_position_embeddings, base_name + "q_rope"); + k_rope = RoPE(RoPE_type, rope_theta, max_position_embeddings, base_name + "k_rope"); } if (cache_limit > 0) { k_cache = KVCache(head_size/kv_head_size, cache_limit, base_name + "k_cache"); v_cache = KVCache(head_size/kv_head_size, cache_limit, base_name + "v_cache"); } - if (do_mask) { - mask = Causalmask(base_name + "mask"); - } - softmax = Softmax(DIMENSION, base_name + "softmax"); + softmax = Softmax(DIMENSION, do_mask, base_name + "softmax"); o_proj = Linear(head_size * attn_hidden_dim, hidden_dim, bias, base_name + names._o_proj_name); if (bias_kv_cat) { bias_k = Parameter(1, 1, head_size, attn_hidden_dim, base_name + "bias_k"); @@ -112,10 +110,11 @@ class MultiHeadAttention final : public Module { k = k.transpose(SEQUENCE, DIMENSION); auto qk = Tensor::mm(q, k); qk = qk / std::sqrt(attn_hidden_dim_); - if (mask.ready()) { - qk = mask(qk); + if (k_cache.ready() && v_cache.ready()) { + qk = softmax(qk, k_cache.getCacheSeqLen()); + }else{ + qk = softmax(qk); } - qk = softmax(qk); auto o = Tensor::mm(qk, v); o = o.view(-1, 1, -1, attn_hidden_dim_ * head_size_); o = o_proj(o); diff --git a/src/models/vit/modeling_vit.hpp b/src/models/vit/modeling_vit.hpp index 03c89f70..a2e145ff 100644 --- a/src/models/vit/modeling_vit.hpp +++ b/src/models/vit/modeling_vit.hpp @@ -39,7 +39,7 @@ class ViTBlock final : public Module { ViTBlock() = default; ViTBlock(int hidden_dim, int 
head_size, int ffn_hidden, const string &act_fn_type, const ViTNameConfig &names, const string &base_name) { attention = MultiHeadAttention(hidden_dim, head_size, head_size, hidden_dim / head_size, SPLIT_NONE, false, false, - RoPEType::NONE, 0, false, true, names, base_name + names._attn_base_name); + RoPEType::NONE, -1,-1,0, false, true, names, base_name + names._attn_base_name); mlp = ViTMLP(hidden_dim, ffn_hidden, act_fn_type, names, base_name + names._ffn_base_name); down_proj = Linear(ffn_hidden, hidden_dim, true, base_name + names._down_proj_name); norm1 = LayerNorm(hidden_dim, true, 1e-6, base_name + names._attn_norm_name); diff --git a/src/models/yi/configuration_yi.hpp b/src/models/yi/configuration_yi.hpp deleted file mode 100644 index a31bbc53..00000000 --- a/src/models/yi/configuration_yi.hpp +++ /dev/null @@ -1,104 +0,0 @@ -/** - * @file configuration_Yi.hpp - * @author Chenghua Wang (chenghua.wang.edu@gmail.com) - * @brief - * @version 0.1 - * @date 2024-07-02 - * - * @copyright Copyright (c) 2024 - * - */ -#ifndef CONFIG_YI_HPP -#define CONFIG_YI_HPP -#include "models/transformer/configuration_transformer.hpp" - -using namespace mllm; - -class YiNameConfig : public TransformerNameConfig { -public: - std::string blk_name; - std::string token_embd_name; - std::string post_norm_name; - std::string lm_head_name; - std::string _gate_proj_name; - - void init(RoPEType type = LLAMAROPE) { - switch (type) { - case LLAMAROPE: { - blk_name = "layers."; - _attn_base_name = "attention."; - _ffn_base_name = "feed_forward."; - _q_proj_name = "wq"; - _k_proj_name = "wk"; - _v_proj_name = "wv"; - _o_proj_name = "wo"; - _gate_proj_name = "w1"; - _up_proj_name = "w3"; - _down_proj_name = "w2"; - _attn_norm_name = "attention_norm"; - _ffn_norm_name = "ffn_norm"; - token_embd_name = "tok_embeddings"; - post_norm_name = "norm"; - lm_head_name = "output"; - break; - } - case HFHUBROPE: { - blk_name = "model.layers."; - _attn_base_name = "self_attn."; - _ffn_base_name = "mlp."; - _q_proj_name = "q_proj"; - _k_proj_name = "k_proj"; - _v_proj_name = "v_proj"; - _o_proj_name = "o_proj"; - _gate_proj_name = "gate_proj"; - _up_proj_name = "up_proj"; - _down_proj_name = "down_proj"; - _attn_norm_name = "input_layernorm"; - _ffn_norm_name = "post_attention_layernorm"; - token_embd_name = "model.embed_tokens"; - post_norm_name = "model.norm"; - lm_head_name = "lm_head"; - break; - } - default: { - throw std::runtime_error("Unsupported llama type"); - } - } - } -}; - -class YiConfig { -public: - explicit YiConfig(int token_limit, string billions = "6B", RoPEType type = LLAMAROPE, int vocab = 64000) { - names_config.init(type); - vocab_size = vocab; - if (!(billions == "6B" || billions == "6b")) { - throw std::runtime_error("Unsupported model size"); - } - RoPE_type = type; - cache_limit = token_limit; - } - -public: - bool attention_bias = false; - float attention_drop = 0.0; - int pad_token_id = 0; - int bos_token_id = 1; - int eos_token_id = 2; - int hidden_size = 4096; - float initializer_range = 0.02; - int intermediate_size = 11008; - int max_position_embeddings = 4096; - int num_attention_heads = 32; - int num_hidden_layers = 32; - int num_key_value_heads = 4; - int pretraining_tp = 1; - float rms_norm_eps = 1e-6; - float rope_theta = 5000000.0; - int vocab_size = 64000; - int cache_limit; - RoPEType RoPE_type; - YiNameConfig names_config; -}; - -#endif //! 
CONFIG_YI_HPP
\ No newline at end of file
diff --git a/src/models/yi/modeling_yi.hpp b/src/models/yi/modeling_yi.hpp
deleted file mode 100644
index 699955dd..00000000
--- a/src/models/yi/modeling_yi.hpp
+++ /dev/null
@@ -1,195 +0,0 @@
-/**
- * @file modeling_Yi.hpp
- * @author Chenghua Wang (chenghua.wang.edu@gmail.com)
- * @brief
- * @version 0.1
- * @date 2024-07-02
- *
- * @copyright Copyright (c) 2024
- *
- */
-#ifndef MODELING_YI_HPP
-#define MODELING_YI_HPP
-
-#include "Backend.hpp"
-#include "Layer.hpp"
-#include "Module.hpp"
-#include "Tensor.hpp"
-#include "configuration_yi.hpp"
-#include <cmath>
-using namespace mllm;
-
-class YiMLP final : public Module {
-public:
- YiMLP() = default;
- YiMLP(int hidden_size, int intermediate_size, const YiNameConfig &names, const std::string &base_name) {
- gate_proj = Linear(hidden_size, intermediate_size, false, base_name + names._gate_proj_name);
- silu = SiLU(base_name + "act");
- up_proj = Linear(hidden_size, intermediate_size, false, base_name + names._up_proj_name);
- down_proj = Linear(intermediate_size, hidden_size, false, base_name + names._down_proj_name);
- }
-
- std::vector<Tensor> Forward(std::vector<Tensor> inputs, std::vector<std::any> args) override {
- auto x = gate_proj(inputs[0]);
- x = silu(x);
- auto y = up_proj(inputs[0]);
- x = x * y;
- x = down_proj(x);
- return {x};
- }
-
-private:
- Layer gate_proj;
- Layer up_proj;
- Layer down_proj;
-
- Layer silu;
-};
-
-class YiAttention final : public Module {
-public:
- YiAttention() = default;
- YiAttention(const YiConfig &config, const YiNameConfig &names, const string &base_name) {
- hidden_size = config.hidden_size;
- num_heads = config.num_attention_heads;
- head_dim = config.hidden_size / num_heads;
- num_key_value_heads = config.num_key_value_heads;
- num_key_value_groups = num_heads / num_key_value_heads;
-
- // init layers
- q_proj = Linear(hidden_size, num_heads * head_dim, false, base_name + names._q_proj_name);
- k_proj = Linear(hidden_size, num_key_value_heads * head_dim, false, base_name + names._k_proj_name);
- v_proj = Linear(hidden_size, num_key_value_heads * head_dim, false, base_name + names._v_proj_name);
- o_proj = Linear(num_heads * head_dim, hidden_size, false, base_name + names._o_proj_name);
- q_rope = RoPE(config.RoPE_type, config.rope_theta, config.max_position_embeddings, base_name + "q_rope");
- k_rope = RoPE(config.RoPE_type, config.rope_theta, config.max_position_embeddings, base_name + "k_rope");
- k_cache = KVCache(num_key_value_groups, config.cache_limit, base_name + "k_cache");
- v_cache = KVCache(num_key_value_groups, config.cache_limit, base_name + "v_cache");
- mask = Causalmask(base_name + "mask");
- softmax = Softmax(DIMENSION, base_name + "softmax");
- }
-
- std::vector<Tensor> Forward(std::vector<Tensor> inputs, std::vector<std::any> args) override {
- auto query_states = q_proj(inputs[0]);
- auto key_states = k_proj(inputs[1]);
- auto value_states = v_proj(inputs[2]);
-
- // [batch, heads, sequence, dims]
- query_states = query_states.view(-1, num_heads, -1, head_dim);
- key_states = key_states.view(-1, num_key_value_heads, -1, head_dim);
- value_states = value_states.view(-1, num_key_value_heads, -1, head_dim);
-
- // embedding
- query_states = q_rope(query_states);
- key_states = k_rope(key_states);
-
- // kv cache
- key_states = k_cache(key_states);
- value_states = v_cache(value_states);
-
- // attention weight
- auto atten_weight = Tensor::mm(query_states, key_states.transpose(Chl::SEQUENCE, Chl::DIMENSION)) / std::sqrt(head_dim);
- atten_weight = mask(atten_weight);
- atten_weight = softmax(atten_weight);
-
- // attention output
- auto atten_output = Tensor::mm(atten_weight, value_states);
- atten_output = atten_output.view(-1, 1, -1, head_dim * num_heads);
- atten_output = o_proj(atten_output);
- return {atten_output};
- }
-
-private:
- int hidden_size;
- int num_heads;
- int head_dim;
- int num_key_value_heads;
- int num_key_value_groups;
- Layer q_proj;
- Layer k_proj;
- Layer v_proj;
- Layer o_proj;
- Layer q_rope;
- Layer k_rope;
- Layer k_cache;
- Layer v_cache;
- Layer mask;
- Layer softmax;
-};
-
-class YiDecoder final : public Module {
-public:
- YiDecoder() = default;
- YiDecoder(const YiConfig &config, const YiNameConfig &names, const string &base_name) {
- self_atten = YiAttention(config, names, base_name + names._attn_base_name);
- mlp = YiMLP(config.hidden_size, config.intermediate_size, names, base_name + names._ffn_base_name);
- input_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps, base_name + names._attn_norm_name);
- post_attention_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps, base_name + names._ffn_norm_name);
- }
-
- std::vector<Tensor> Forward(std::vector<Tensor> inputs, std::vector<std::any> args) override {
- auto x = input_layernorm(inputs[0]);
- x = self_atten({x, x, x})[0];
- auto tmp = x + inputs[0];
- x = post_attention_layernorm(tmp);
- x = mlp({x})[0];
- x = x + tmp;
- return {x};
- }
-
-private:
- YiAttention self_atten;
- YiMLP mlp;
- Layer input_layernorm;
- Layer post_attention_layernorm;
-};
-
-class YiModel final : public Module {
-public:
- YiModel() = default;
- YiModel(const YiConfig &config, const YiNameConfig &names, const string &base_name) {
- blocks = List<YiDecoder>(config.num_hidden_layers, config, names, base_name);
- norm = RMSNorm(config.hidden_size, config.rms_norm_eps, names.post_norm_name);
- }
-
- std::vector<Tensor> Forward(std::vector<Tensor> inputs, std::vector<std::any> args) override {
- auto x = inputs[0];
- for (auto &block : blocks) {
- x = block({x})[0];
- }
- x = norm(x);
- return {x};
- }
-
-private:
- std::vector<YiDecoder> blocks;
- Layer norm;
-};
-
-class YiForCausalLM final : public Module {
-public:
- YiForCausalLM(YiConfig &config) {
- auto names = config.names_config;
- hidden_size = config.hidden_size;
- embedding = Embedding(config.vocab_size, config.hidden_size, names.token_embd_name);
- model = YiModel(config, names, names.blk_name);
- lm_head = Linear(hidden_size, config.vocab_size, false, names.lm_head_name);
- }
-
- std::vector<Tensor> Forward(std::vector<Tensor> inputs, std::vector<std::any> args) override {
- auto x = embedding(inputs[0]);
-
- // go through model
- auto outputs = model({x})[0];
- outputs = lm_head(outputs);
- return {outputs};
- }
-
-private:
- int hidden_size;
- Layer embedding;
- Layer lm_head;
- YiModel model;
-};
-
-#endif //! MODELING_YI_HPP
\ No newline at end of file
diff --git a/src/models/yi/tokenization_yi.hpp b/src/models/yi/tokenization_yi.hpp
deleted file mode 100644
index 5fa657c1..00000000
--- a/src/models/yi/tokenization_yi.hpp
+++ /dev/null
@@ -1,63 +0,0 @@
-/**
- * @file tokenization_Yi.hpp
- * @author Chenghua Wang (chenghua.wang.edu@gmail.com)
- * @brief
- * @version 0.1
- * @date 2024-07-02
- *
- * @copyright Copyright (c) 2024
- *
- */
-#ifndef TOKENIZATION_YI_HPP
-#define TOKENIZATION_YI_HPP
-
-#include "tokenizers/BPE/Bpe.hpp"
-
-using namespace mllm;
-
-class YiTokenizer final {
- BPETokenizer *tokenizer;
-
- unsigned int argmax(const std::vector<float> &scores) {
- if (scores.empty()) {
- throw std::invalid_argument("Input vector is empty");
- }
- return std::max_element(scores.begin(), scores.end()) - scores.begin();
- }
-
-public:
- explicit YiTokenizer(const std::string &vocab_file) {
- Module::initBackend(MLLM_CPU);
- tokenizer = new BPETokenizer(vocab_file);
- }
-
- Tensor tokenize(std::string &text, int str_i = 0) const {
- if (text[0] != ' ') {
- text = ' ' + text;
- }
- auto tokens_id = vector<token_id_t>();
- tokenizer->tokenize(text, tokens_id, false);
- if (str_i > 0) {
- tokens_id[0] = 13;
- }
- return BPETokenizer::tokens2Input(tokens_id);
- }
-
- std::string detokenize(const std::vector<token_id_t> &tokens) {
- return tokenizer->detokenize(tokens);
- }
-
- std::pair<std::string, unsigned> detokenize(Tensor &result) {
- assert(result.batch() == 1);
- assert(result.head() == 1);
- vector<float> scores;
- for (int i = 0; i < result.dimension(); ++i) {
- auto value = result.dataAt<float>(0, 0, result.sequence() - 1, i);
- scores.push_back(value);
- }
- auto token_idx = this->argmax(scores);
- return {tokenizer->detokenize({token_idx}), token_idx};
- }
-};
-
-#endif // !TOKENIZATION_YI_HPP
diff --git a/src/quantizer/QuantWriter.cpp b/src/quantizer/QuantWriter.cpp
index 6383861a..2b01d09d 100644
--- a/src/quantizer/QuantWriter.cpp
+++ b/src/quantizer/QuantWriter.cpp
@@ -1,10 +1,12 @@
 #include "ParamWriter.hpp"
 #include "ParamLoader.hpp"
+#include "Types.hpp"
 #include "backends/cpu/quantize/QuantizeQ4.hpp"
 #include "backends/cpu/quantize/QuantizeQ8.hpp"
 #include <iostream>
 #include "QuantWriter.hpp"
 #include "backends/cpu/quantize/QuantizeQ6.hpp"
+#include "backends/cpu/compute/GEMM_AArch64.hpp"
 namespace mllm {
 QuantWriter::QuantWriter(std::string output_path, std::string input_path) :
 ParamWriter(output_path), output_path_(output_path) {
@@ -39,6 +41,8 @@ vector<string> fp32_layers = {"norm", "rope", "bias","rotary_emb", "embed_tokens
 "embeddings", "logit_scale", //, "tok_embeddings"};
 "modality_preprocessors", "modality_heads", "modality_postprocessors", "pre_transformer_layer"};
 vector<string> q6_layers = {"w2", "wv", "dense_h_to_4h", "v_proj", "down_proj"};
+vector<string> q4x4_2_q4_layers = {"w2", "wv", "dense_h_to_4h", "v_proj", "down_proj"};
+vector<string> q4x4_2_q4_layers_ = {"wv", "v_proj"};
 int tmp_hidden_dim = -1;
@@ -65,7 +69,7 @@ void QuantWriter::quantParams(DataType dataType) {
 }
 void *quant_ptr = nullptr;
 std::pair block_t;
- if (find_names(name, q6_layers)) {
+ if (find_names(name, q6_layers) && (dataType== MLLM_TYPE_Q6_K ||dataType == MLLM_TYPE_Q4_K)) {
 if(tmp_hidden_dim>0 && (size/tmp_hidden_dim)%256!=0){
 std::cout << "Quantize param " << name << " to " << DataTypeName(MLLM_TYPE_F32) << "\t";
 const auto s = param_loader_->offsets_[name].second / sizeof(float);
@@ -81,13 +85,80 @@ void QuantWriter::quantParams(DataType dataType) {
 const auto tsize = alloc_quant_block(s, MLLM_TYPE_F32).second;
 writeParam(name, MLLM_TYPE_F32, param, tsize);
 std::cout << " size:" << tsize <<
std::endl; - }else if (find_names(name, q6_layers)) { + }else if (find_names(name, q4x4_2_q4_layers) && dataType != MLLM_TYPE_Q6_K) { + // std::cout<<"q4x4_2_q4_layers"<offsets_[name].second / sizeof(float); + if(find_names(name, {"norm"})) { + tmp_hidden_dim = size; + } + } + quant_type_ = dataType; + for (const auto &name : param_names_) { + // int force_quant_type = -1; + auto *param = getParam(name); + if (param == nullptr) { + __exit(-1); + } + auto size = param_loader_->offsets_[name].second / sizeof(float); + if(find_names(name, {"norm"})) { + tmp_hidden_dim = size; + } + void *quant_ptr = nullptr; + std::pair block_t; + if(find_names(name, fp32_layers)) { + std::cout << "Quantize param " << name << " to " << DataTypeName(MLLM_TYPE_F32) << "\t"; + const auto s = param_loader_->offsets_[name].second / sizeof(float); + const auto tsize = alloc_quant_block(s, MLLM_TYPE_F32).second; + writeParam(name, MLLM_TYPE_F32, param, tsize); + std::cout << " size:" << tsize << std::endl; + }else if (find_names(name, q4x4_2_q4_layers_)) { + // std::cout<<"q4x4_2_q4_layers"< data_; diff --git a/src/quantizer/main.cpp b/src/quantizer/main.cpp index 448d2872..3889c3d7 100644 --- a/src/quantizer/main.cpp +++ b/src/quantizer/main.cpp @@ -34,6 +34,8 @@ int main(int argc, char **argv) { quant_writer.quantParams(MLLM_TYPE_Q6_K); } else if (quant_type == "Q8_K") { quant_writer.quantParams(MLLM_TYPE_Q8_K); + } else if (quant_type == "Q4_0_4_4") { + quant_writer.quantParams_q4_(MLLM_TYPE_Q4_0_4_4); } else { std::cout << "Quant type " << quant_type << " is not supported\n"; return -1; diff --git a/test/cpu/CPUSoftMaxTest.cpp b/test/cpu/CPUSoftMaxTest.cpp index 4cebf2bd..24b0cffd 100644 --- a/test/cpu/CPUSoftMaxTest.cpp +++ b/test/cpu/CPUSoftMaxTest.cpp @@ -4,7 +4,7 @@ #include "CPUTest.hpp" #include "backends/cpu/CPUSoftMax.hpp" TEST_F(CPUTest, CPUSoftMax1) { - SETUP_OP(CPUSoftMax, DIMENSION, 4); + SETUP_OP(CPUSoftMax, DIMENSION, false, 4); TENSOR(input0); TENSOR(output); TENSOR(c_output); diff --git a/tools/jni/modeling_fuyu.hpp b/tools/jni/modeling_fuyu.hpp index e2905061..cf6ac813 100644 --- a/tools/jni/modeling_fuyu.hpp +++ b/tools/jni/modeling_fuyu.hpp @@ -23,8 +23,8 @@ inline NetTensor *Attention_Fuyu(Context *ctx, NetTensor * x, int embedding_size v = _KVCache( {v}, 700, name + ".v_cache"); auto *qk = _Matmul( {q, k}, false, true, name + ".qk"); qk = _Scale( {qk}, 1.0F / std::sqrt(head_size), 0.0F, false, name + ".scale"); - qk = _Causalmask( {qk}, name + ".mask"); - qk = _Softmax( {qk}, DIMENSION, name + ".softmax"); + // qk = _Causalmask( {qk}, name + ".mask"); + qk = _Softmax( {qk}, DIMENSION, true, name + ".softmax"); auto *o = _Matmul( {qk, v}, false, false, name + ".qkv"); o = o->view(-1, 1, -1, hidden_size * head_size); o = _Linear( {o}, hidden_size * head_size, embedding_size, true, name + ".dense"); diff --git a/tools/jni/modeling_llama.hpp b/tools/jni/modeling_llama.hpp index 4df3a2b8..853892b9 100644 --- a/tools/jni/modeling_llama.hpp +++ b/tools/jni/modeling_llama.hpp @@ -19,8 +19,8 @@ inline NetTensor *Attention_LLAMA(Context *ctx, NetTensor *x, int embedding_size v = _KVCache( {v}, 500, name + ".v_cache"); auto *qk = _Matmul( {q, k}, false, true, name + ".qk"); qk = _Scale( {qk}, 1.0F / std::sqrt(hidden_size), 0.0F, false, name + ".scale"); - qk = _Causalmask( {qk}, name + ".mask"); - qk = _Softmax( {qk}, DIMENSION, name + ".softmax"); + // qk = _Causalmask( {qk}, name + ".mask"); + qk = _Softmax( {qk}, DIMENSION, true, name + ".softmax"); auto *o = _Matmul( {qk, v}, false, false, name + 
".qkv"); o = o->view(-1, 1, -1, hidden_size * head_size); o = _Linear( {o}, hidden_size * head_size, embedding_size, false, name + ".wo");