diff --git a/examples/talk-llama/llama-arch.cpp b/examples/talk-llama/llama-arch.cpp index f2bc8ca7685..abf436adac4 100644 --- a/examples/talk-llama/llama-arch.cpp +++ b/examples/talk-llama/llama-arch.cpp @@ -1481,6 +1481,9 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, + { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, + { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, }, }, { diff --git a/examples/talk-llama/llama-context.cpp b/examples/talk-llama/llama-context.cpp index 62246c10dab..a3b84a6a82e 100644 --- a/examples/talk-llama/llama-context.cpp +++ b/examples/talk-llama/llama-context.cpp @@ -1704,10 +1704,12 @@ size_t llama_context::state_write_data(llama_io_write_i & io) { } } - LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__); llama_kv_cache * kv_self = static_cast(memory.get()); - kv_self->state_write(io); + if (kv_self != nullptr) { + LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__); + kv_self->state_write(io); + } return io.n_bytes(); } diff --git a/examples/talk-llama/llama-kv-cache.cpp b/examples/talk-llama/llama-kv-cache.cpp index 3dcad65bb6a..265db2527c7 100644 --- a/examples/talk-llama/llama-kv-cache.cpp +++ b/examples/talk-llama/llama-kv-cache.cpp @@ -441,6 +441,13 @@ void llama_kv_cache_unified::defrag_sched(float thold) { void llama_kv_cache_unified::set_full() { n = size; + + // when simulating a full KV cache, the specific value of the "head" pointer is not important because it does not + // affect the shapes of the tensors in the compute graph - it only affects the offsets of the K/V views. 
+ // we should only guarantee that the head position won't cause out-of-bounds view of the K, V tensors, so + // setting it to 0 is the simplest way to achieve that + // ref: https://github.com/ggml-org/llama.cpp/issues/13359 + head = 0; } llama_sbatch llama_kv_cache_unified::sbatch_init( @@ -1712,6 +1719,7 @@ void llama_kv_cache_recurrent::defrag_sched(float thold) { void llama_kv_cache_recurrent::set_full() { n = size; + head = 0; } llama_sbatch llama_kv_cache_recurrent::sbatch_init( diff --git a/examples/talk-llama/llama-kv-cache.h b/examples/talk-llama/llama-kv-cache.h index bf3b4b6a443..e83e12c09f2 100644 --- a/examples/talk-llama/llama-kv-cache.h +++ b/examples/talk-llama/llama-kv-cache.h @@ -171,11 +171,8 @@ class llama_kv_cache_unified : public llama_kv_cache { void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override; void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override; - // Note: The value of head isn't only used to optimize searching - // for a free KV slot. llama_decode_impl also uses it, so it - // cannot be freely changed after a slot has been allocated. - uint32_t head = 0; - uint32_t size = 0; + uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot()) + uint32_t size = 0; // total number of cells, shared across all sequences uint32_t used = 0; // used cells (i.e. at least one seq_id) // computed before each graph build @@ -343,11 +340,8 @@ class llama_kv_cache_recurrent : public llama_kv_cache { void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override; void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override; - // Note: The value of head isn't only used to optimize searching - // for a free KV slot. llama_decode_impl also uses it, so it - // cannot be freely changed after a slot has been allocated. 
- uint32_t head = 0; - uint32_t size = 0; + uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot()) + uint32_t size = 0; // total number of cells, shared across all sequences uint32_t used = 0; // used cells (i.e. at least one seq_id) // computed before each graph build diff --git a/examples/talk-llama/llama-model-loader.cpp b/examples/talk-llama/llama-model-loader.cpp index 4cce51668b4..ddb1b03675b 100644 --- a/examples/talk-llama/llama-model-loader.cpp +++ b/examples/talk-llama/llama-model-loader.cpp @@ -469,7 +469,7 @@ llama_model_loader::llama_model_loader( meta.reset(gguf_init_from_file(fname.c_str(), params)); if (!meta) { - throw std::runtime_error(format("%s: failed to load model from %s\n", __func__, fname.c_str())); + throw std::runtime_error(format("%s: failed to load model from %s", __func__, fname.c_str())); } get_key(llm_kv(LLM_KV_GENERAL_ARCHITECTURE), arch_name, false); @@ -528,7 +528,7 @@ llama_model_loader::llama_model_loader( }; gguf_context_ptr ctx_gguf { gguf_init_from_file(fname_split, split_params) }; if (!ctx_gguf) { - throw std::runtime_error(format("%s: failed to load GGUF split from %s\n", __func__, fname_split)); + throw std::runtime_error(format("%s: failed to load GGUF split from %s", __func__, fname_split)); } // check idx @@ -822,13 +822,18 @@ void llama_model_loader::init_mappings(bool prefetch, llama_mlocks * mlock_mmaps mappings.reserve(files.size()); mmaps_used.reserve(files.size()); for (const auto & file : files) { - auto * reg = ggml_backend_dev_backend_reg(ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU)); - if (!reg) { - throw std::runtime_error(format("%s: no CPU backend found", __func__)); + bool is_numa = false; + + auto * dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + if (dev) { + auto * reg = ggml_backend_dev_backend_reg(dev); + auto * is_numa_fn = (decltype(ggml_is_numa) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_is_numa"); + if 
(is_numa_fn) { + is_numa = is_numa_fn(); + } } - auto * is_numa_fn = (decltype(ggml_is_numa) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_is_numa"); - std::unique_ptr mapping = std::make_unique(file.get(), prefetch ? -1 : 0, is_numa_fn()); + std::unique_ptr mapping = std::make_unique(file.get(), prefetch ? -1 : 0, is_numa); mmaps_used.emplace_back(mapping->size(), 0); if (mlock_mmaps) { std::unique_ptr mlock_mmap(new llama_mlock()); diff --git a/examples/talk-llama/llama-model.cpp b/examples/talk-llama/llama-model.cpp index 3a4e72a36b0..7fd094b63f2 100644 --- a/examples/talk-llama/llama-model.cpp +++ b/examples/talk-llama/llama-model.cpp @@ -1389,6 +1389,9 @@ void llama_model::load_hparams(llama_model_loader & ml) { // Add additional layer/vocab/etc checks here for other model sizes default: type = LLM_TYPE_UNKNOWN; } + + // For Granite MoE Shared + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, /* required */ false); } break; case LLM_ARCH_CHAMELEON: { @@ -1772,6 +1775,13 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, TENSOR_NOT_REQUIRED); layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + + // For Granite MoE Shared + if (hparams.n_ff_shexp > 0) { + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, 0); + } } } } break; @@ -4385,10 +4395,13 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); } - if 
(arch == LLM_ARCH_MINICPM || arch == LLM_ARCH_GRANITE || arch == LLM_ARCH_GRANITE_MOE) { + if (arch == LLM_ARCH_MINICPM || + arch == LLM_ARCH_GRANITE || + arch == LLM_ARCH_GRANITE_MOE) { LLAMA_LOG_INFO("%s: f_embedding_scale = %f\n", __func__, hparams.f_embedding_scale); LLAMA_LOG_INFO("%s: f_residual_scale = %f\n", __func__, hparams.f_residual_scale); LLAMA_LOG_INFO("%s: f_attention_scale = %f\n", __func__, hparams.f_attention_scale); + LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); } if (arch == LLM_ARCH_BAILINGMOE) { @@ -4598,11 +4611,6 @@ struct llm_build_llama : public llm_graph_context { inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } - // For Granite architecture - if (hparams.f_residual_scale) { - cur = ggml_scale(ctx0, cur, hparams.f_residual_scale); - } - ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); cb(ffn_inp, "ffn_inp", il); @@ -4674,11 +4682,6 @@ struct llm_build_llama : public llm_graph_context { cb(cur, "ffn_moe_out", il); } - // For Granite architecture - if (hparams.f_residual_scale) { - cur = ggml_scale(ctx0, cur, hparams.f_residual_scale); - } - cur = ggml_add(ctx0, cur, ffn_inp); cb(cur, "ffn_out", il); @@ -4701,11 +4704,6 @@ struct llm_build_llama : public llm_graph_context { // lm_head cur = build_lora_mm(model.output, cur); - // For Granite architecture - if (hparams.f_logit_scale) { - cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale); - } - cb(cur, "result_output", -1); res->t_logits = cur; @@ -4816,11 +4814,6 @@ struct llm_build_deci : public llm_graph_context { continue; } - // For Granite architecture - if (hparams.f_residual_scale) { - cur = ggml_scale(ctx0, cur, hparams.f_residual_scale); - } - // modified to support attention-free layer of Llama-3_1-Nemotron-51B ggml_tensor * ffn_inp = cur; if (n_head > 0) { @@ -4844,11 +4837,6 @@ struct llm_build_deci : public llm_graph_context { cb(cur, "ffn_out", il); } - // For Granite architecture - if (hparams.f_residual_scale) { - cur = ggml_scale(ctx0, 
cur, hparams.f_residual_scale); - } - cur = ggml_add(ctx0, cur, ffn_inp); cb(cur, "ffn_out", il); @@ -4871,11 +4859,6 @@ struct llm_build_deci : public llm_graph_context { // lm_head cur = build_lora_mm(model.output, cur); - // For Granite architecture - if (hparams.f_logit_scale) { - cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale); - } - cb(cur, "result_output", -1); res->t_logits = cur; @@ -12214,6 +12197,194 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base { } }; + +struct llm_build_granite : public llm_graph_context { + llm_build_granite( + const llama_model & model, + const llm_graph_params & params, + ggml_cgraph * gf, + const bool use_rope = true) + : llm_graph_context(params) { + + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - built only if rope enabled + ggml_tensor * inp_pos = nullptr; + if (use_rope) { + inp_pos = build_inp_pos(); + } + + auto * inp_attn = build_attn_inp_kv_unified(); + + const float kq_scale = hparams.f_attention_scale == 0.0f ? 
1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // compute Q and K and (optionally) RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + if (use_rope) { + ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il); + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + } + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); + cb(cur, "attn_out", il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = 
ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + // For Granite architectures - scale residual + cur = ggml_scale(ctx0, cur, hparams.f_residual_scale); + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network (non-MoE) + if (model.layers[il].ffn_gate_inp == nullptr) { + + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + + } else { + // MoE branch + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + ggml_tensor * moe_out = build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + nullptr, + n_expert, n_expert_used, + LLM_FFN_SILU, true, + false, 0.0, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, + il); + cb(moe_out, "ffn_moe_out", il); + + // For Granite MoE Shared + if (hparams.n_ff_shexp > 0) { + ggml_tensor * ffn_shexp = build_ffn(cur, + model.layers[il].ffn_up_shexp, NULL, NULL, + model.layers[il].ffn_gate_shexp, NULL, NULL, + model.layers[il].ffn_down_shexp, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(ffn_shexp, "ffn_shexp", il); + + cur = ggml_add(ctx0, moe_out, ffn_shexp); + cb(cur, "ffn_out", il); + } else { + cur = moe_out; + } + } + + // For Granite architectures - scale residual + cur = ggml_scale(ctx0, cur, hparams.f_residual_scale); + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, 
"result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + // For Granite architectures - scale logits + cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale); + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + // ref: https://github.com/facebookresearch/chameleon // based on the original build_llama() function, changes: // * qk-norm @@ -12921,8 +13092,6 @@ llm_graph_result_ptr llama_model::build_graph( case LLM_ARCH_LLAMA: case LLM_ARCH_LLAMA4: case LLM_ARCH_MINICPM: - case LLM_ARCH_GRANITE: - case LLM_ARCH_GRANITE_MOE: { llm = std::make_unique(*this, params, gf); } break; @@ -13153,6 +13322,11 @@ llm_graph_result_ptr llama_model::build_graph( { llm = std::make_unique(*this, params, gf); } break; + case LLM_ARCH_GRANITE: + case LLM_ARCH_GRANITE_MOE: + { + llm = std::make_unique(*this, params, gf); + } break; case LLM_ARCH_CHAMELEON: { llm = std::make_unique(*this, params, gf); diff --git a/examples/talk-llama/llama-quant.cpp b/examples/talk-llama/llama-quant.cpp index 820d5128e29..159b1307a4c 100644 --- a/examples/talk-llama/llama-quant.cpp +++ b/examples/talk-llama/llama-quant.cpp @@ -14,6 +14,12 @@ #include #include +// Quantization types. 
Changes to this struct must be replicated in quantize.cpp +struct tensor_quantization { + std::string name; + ggml_type quant = GGML_TYPE_COUNT; +}; + static void zeros(std::ofstream & file, size_t n) { char zero = 0; for (size_t i = 0; i < n; ++i) { @@ -48,12 +54,6 @@ struct quantize_state_impl { {} }; -// changes to this struct must be replicated in quantize.cpp -struct tensor_quantization { - std::string name; - ggml_type quant = GGML_TYPE_COUNT; -}; - static void llama_tensor_dequantize_impl( ggml_tensor * tensor, std::vector> & output, std::vector & workers, const size_t nelements, const int nthread @@ -796,17 +796,19 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: // unless the user specifies a type if (params->tensor_types) { const std::vector & tensor_types = *static_cast *>(params->tensor_types); + const std::string tensor_name(tensor->name); for (const auto & [tname, qtype] : tensor_types) { - if (std::regex pattern(tname); std::regex_search(tensor->name, pattern)) { - if (qtype != new_type) { - LLAMA_LOG_DEBUG("(overriding %s -> %s), ", ggml_type_name(new_type), ggml_type_name(qtype)); + if (std::regex pattern(tname); std::regex_search(tensor_name, pattern)) { + if (qtype != new_type) { + LLAMA_LOG_DEBUG("(overriding %s) ", ggml_type_name(new_type)); + new_type = qtype; + break; // if two or more types are specified for the tensor, first match wins } - new_type = qtype; - break; } } } } + if (params->token_embedding_type < GGML_TYPE_COUNT && strcmp(tensor->name, "token_embd.weight") == 0) { new_type = params->token_embedding_type; } diff --git a/examples/talk-llama/llama.cpp b/examples/talk-llama/llama.cpp index 9fdddf7b071..2f06e0f8ce1 100644 --- a/examples/talk-llama/llama.cpp +++ b/examples/talk-llama/llama.cpp @@ -140,6 +140,11 @@ static struct llama_model * llama_model_load_from_file_impl( struct llama_model_params params) { ggml_time_init(); + if (!params.vocab_only && ggml_backend_reg_count() == 0) { + 
LLAMA_LOG_ERROR("%s: no backends are loaded. hint: use ggml_backend_load() or ggml_backend_load_all() to load a backend before calling this function\n", __func__); + return nullptr; + } + unsigned cur_percentage = 0; if (params.progress_callback == NULL) { params.progress_callback_user_data = &cur_percentage; diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index a8300e16d87..4746d5cb76c 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -193,6 +193,7 @@ option(GGML_RPC "ggml: use RPC" option(GGML_SYCL "ggml: use SYCL" OFF) option(GGML_SYCL_F16 "ggml: use 16 bit floats for sycl calculations" OFF) option(GGML_SYCL_GRAPH "ggml: enable graphs in the SYCL backend" ON) +option(GGML_SYCL_DNN "ggml: enable oneDNN in the SYCL backend" ON) set (GGML_SYCL_TARGET "INTEL" CACHE STRING "ggml: sycl target device") set (GGML_SYCL_DEVICE_ARCH "" CACHE STRING diff --git a/ggml/include/ggml-opt.h b/ggml/include/ggml-opt.h index da0c24b46fe..74ec080a055 100644 --- a/ggml/include/ggml-opt.h +++ b/ggml/include/ggml-opt.h @@ -128,6 +128,8 @@ extern "C" { // set gradients to zero, initilize loss, and optionally reset the optimizer GGML_API void ggml_opt_reset(ggml_opt_context_t opt_ctx, bool optimizer); + GGML_API bool ggml_opt_static_graphs(ggml_opt_context_t opt_ctx); // whether the graphs are allocated_statically + // get underlying tensors that store data // if not using static graphs these pointers become invalid with the next call to ggml_opt_alloc GGML_API struct ggml_tensor * ggml_opt_inputs( ggml_opt_context_t opt_ctx); // forward graph input tensor diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index 67c0223c010..cbf9783b744 100644 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -65,6 +65,7 @@ #include #include #include +#include #include #include @@ -2587,3 +2588,149 @@ void ggml_cann_step(ggml_backend_cann_context& ctx, ggml_tensor* dst){ ggml_cann_release_resources(ctx, acl_src, acl_dst, 
alpha); } + +/** + * @brief Performs expert-specific matrix multiplication (MoE) with + * floating-point precision using the CANN backend. + * + * This function executes a matrix multiplication operation tailored for + * Mixture of Experts (MoE) models, where the input tensor is multiplied + * with expert-specific weight matrices. It uses the CANN backend for + * efficient computation and stores the result in the destination tensor `dst`. + * The operation may leverage identity-based optimizations or routing masks + * as part of sparse expert selection. + * + * @param ctx The context for executing CANN backend operations. + * @param dst The destination tensor where the MoE multiplication result + * will be stored. + * + * @note This function assumes floating-point data types and is designed for + * MoE architectures, possibly involving sparse expert routing. + */ +static void ggml_cann_mul_mat_id_fp(ggml_backend_cann_context& ctx, ggml_tensor* dst) { + //dst [M, K, N, 1] + ggml_tensor * src0 = dst->src[0]; //src0 [D, M, A, 1] + ggml_tensor * src1 = dst->src[1]; //src1 [D, B, N, 1], B = K or B = 1 + ggml_tensor * ids = dst->src[2]; //ids [K, N] + + GGML_TENSOR_BINARY_OP_LOCALS + + // copy index from npu to cpu + int64_t n_as = ne02; // A + int64_t n_ids = ids->ne[0]; // K + + std::vector ids_host(ggml_nbytes(ids)); + ggml_cann_async_memcpy(ctx, ids_host.data(), ids->data, ggml_nbytes(ids), + ACL_MEMCPY_DEVICE_TO_HOST); + ACL_CHECK(aclrtSynchronizeStream(ctx.stream())); + + char * src0_original = (char *) src0->data; + char * src1_original = (char *) src1->data; + char * dst_original = (char *) dst->data; + size_t ori_src0_nb[4] = {nb00, nb01, nb02, nb03}; + + // src0 is F16, src1 is F32, dst is F32 + ggml_cann_pool_alloc src0_cast_allocator; + if (src0->type == GGML_TYPE_F16) { + src0_cast_allocator.alloc(ctx.pool(), sizeof(float) * ggml_nelements(src0)); + void* src0_cast_buf = src0_cast_allocator.get(); + + size_t cast_nb[GGML_MAX_DIMS]; + cast_nb[0] = 
sizeof(float_t); + for (int i = 1; i < GGML_MAX_DIMS; i++) { + cast_nb[i] = cast_nb[i - 1] * src0->ne[i - 1]; + } + + aclTensor* acl_src0_f16 = ggml_cann_create_tensor(src0); + aclTensor* acl_cast = ggml_cann_create_tensor(src0_cast_buf, + ACL_FLOAT, sizeof(float), src0->ne, cast_nb, 4); + GGML_CANN_CALL_ACLNN_OP(ctx, Cast, acl_src0_f16, ACL_FLOAT, acl_cast); + ggml_cann_release_resources(ctx, acl_cast, acl_src0_f16); + + src0_original = (char *) src0_cast_buf; + memcpy(ori_src0_nb, cast_nb, sizeof(ori_src0_nb)); + } + + std::vector src0_tensor_vec; + std::vector src1_tensor_vec; + std::vector dst_tensor_vec; + for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) { + for (int64_t id = 0; id < n_ids; id++) { + // src0_row [M, D] -> weight && permute + int64_t src0_ne[2] = {ne01, ne00}; + size_t src0_nb[2] = {ori_src0_nb[1], ori_src0_nb[0]}; + // src1_row [D, 1] -> input + int64_t src1_ne[2] = {ne10, 1}; + size_t src1_nb[2] = {nb10, nb11}; + // dst_row [M, 1] -> out + int64_t dst_ne[2] = {ne0, 1}; + size_t dst_nb[2] = {nb0, nb1}; + + // expert index + int32_t i02 = *(int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]); + GGML_ASSERT(i02 >= 0 && i02 < n_as); + + // If B = 1 (broadcast), always use 0; otherwise, use id. + int64_t i11 = (ne11 == 1 ? 
0 : id); + int64_t i12 = iid1; + + int64_t i1 = id; + int64_t i2 = i12; + + void* src0_tmp_ptr = src0_original + i02*ori_src0_nb[2]; + void* src1_tmp_ptr = src1_original + i11*nb11 + i12*nb12; + void* dst_tmp_ptr = dst_original + i1*nb1 + i2*nb2; + + aclTensor* acl_src0 = ggml_cann_create_tensor(src0_tmp_ptr, + ACL_FLOAT, sizeof(float), + src0_ne, src0_nb, 2); + aclTensor* acl_src1 = ggml_cann_create_tensor(src1_tmp_ptr, + ACL_FLOAT, sizeof(float), + src1_ne, src1_nb, 2); + aclTensor* acl_dst = ggml_cann_create_tensor(dst_tmp_ptr, + ACL_FLOAT, sizeof(float), + dst_ne, dst_nb, 2); + + src0_tensor_vec.push_back(acl_src0); + src1_tensor_vec.push_back(acl_src1); + dst_tensor_vec.push_back(acl_dst); + } + } + + // GroupedMatmulV2 required tensor_list.size < 128 + size_t GROUP_SIZE = 128; + std::vector> src0_tensor_vec_vec; + std::vector> src1_tensor_vec_vec; + std::vector> dst_tensor_vec_vec; + + // split and call GroupedMatmulV2 + for (size_t i = 0; i < src0_tensor_vec.size(); i += GROUP_SIZE) { + size_t end = std::min(i + GROUP_SIZE, src0_tensor_vec.size()); + std::vector src0_tensor_vec_split(src0_tensor_vec.begin() + i, src0_tensor_vec.begin() + end); + std::vector src1_tensor_vec_split(src1_tensor_vec.begin() + i, src1_tensor_vec.begin() + end); + std::vector dst_tensor_vec_split(dst_tensor_vec.begin() + i, dst_tensor_vec.begin() + end); + + aclTensorList* src0_tensor_list = aclCreateTensorList(src0_tensor_vec_split.data(), src0_tensor_vec_split.size()); + aclTensorList* src1_tensor_list = aclCreateTensorList(src1_tensor_vec_split.data(), src1_tensor_vec_split.size()); + aclTensorList* dst_tensor_list = aclCreateTensorList(dst_tensor_vec_split.data(), dst_tensor_vec_split.size()); + + GGML_CANN_CALL_ACLNN_OP(ctx, GroupedMatmulV2, src1_tensor_list, src0_tensor_list, + nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, 0, -1, dst_tensor_list); + + ggml_cann_release_resources(ctx, src0_tensor_list, src1_tensor_list, dst_tensor_list); + } + return; +} + +void 
ggml_cann_mul_mat_id(ggml_backend_cann_context& ctx, ggml_tensor* dst) { + const enum ggml_type type = dst->src[0]->type; + switch (type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + ggml_cann_mul_mat_id_fp(ctx, dst); + break; + default: + GGML_ABORT("Unsupported type for mul_mat_id"); + break; + } +} diff --git a/ggml/src/ggml-cann/aclnn_ops.h b/ggml/src/ggml-cann/aclnn_ops.h index 462351542e5..15993cce66f 100644 --- a/ggml/src/ggml-cann/aclnn_ops.h +++ b/ggml/src/ggml-cann/aclnn_ops.h @@ -978,6 +978,33 @@ inline void ggml_cann_async_memset(ggml_backend_cann_context & ctx, void * buffe } } +/** + * @brief Performs sparse expert-based matrix multiplication using the CANN backend. + * + * @details This function implements a MoE-style batched matrix multiplication, where each input token + * is routed to one or more experts, and each expert corresponds to a specific [D, M] weight matrix + * in the source tensor `src0`. The routing indices are provided via the `ids` tensor. + * + * For each token (from `src1`), the function selects the corresponding expert(s) as specified by `ids`, + * performs the matrix multiplication with the selected expert's weight submatrix (from `src0`), + * and stores the results in `dst`. This operation is optimized and executed on the CANN backend. + * + * Dimensions: + * - src0: [D, M, A, 1], where A is the number of experts + * - src1: [D, B, N, 1], where N is batch size and B is the slot count per sample + * - ids : [K, N], where K is the number of experts each token is routed to + * - dst : [M, K, N, 1], output tensor storing the result of expert × token multiplication + * + * The function handles two main modes: + * - If `ne12 == 1`, a simpler per-token loop is used. + * - TODO: If `ne12 > 1`, grouped multiplication and memory copying is used for efficiency. + * + * @param ctx The CANN context used for operations. + * @param dst The destination tensor where the expert-weighted token outputs are stored. 
+ * Expected to be of shape [M, K, N, 1]. + */ +void ggml_cann_mul_mat_id(ggml_backend_cann_context& ctx, ggml_tensor* dst); + /** * @brief Applies a element-wise operation to two input tensors using the CANN * backend. diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index e2617b06e9c..0cb7bbf17cc 100644 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -1672,7 +1672,8 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx, ggml_cann_mul_mat(ctx, dst); break; case GGML_OP_MUL_MAT_ID: - return false; + ggml_cann_mul_mat_id(ctx, dst); + break; case GGML_OP_SCALE: ggml_cann_scale(ctx, dst); break; @@ -2030,7 +2031,13 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, } } case GGML_OP_MUL_MAT_ID: - return false; + switch (op->src[0]->type) { + case GGML_TYPE_F16: + case GGML_TYPE_F32: + return true; + default: + return false; + } // embedding case GGML_OP_GET_ROWS: { switch (op->src[0]->type) { diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt index bdaec2881dd..1d4259dae5b 100644 --- a/ggml/src/ggml-cpu/CMakeLists.txt +++ b/ggml/src/ggml-cpu/CMakeLists.txt @@ -385,9 +385,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name) # Fetch KleidiAI sources: include(FetchContent) - set(KLEIDIAI_COMMIT_TAG "v1.5.0") + set(KLEIDIAI_COMMIT_TAG "v1.6.0") set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz") - set(KLEIDIAI_ARCHIVE_MD5 "ea22e1aefb800e9bc8c74d91633cc58e") + set(KLEIDIAI_ARCHIVE_MD5 "75b4ad68f25ab673dcc01065e5a0b05f") if (POLICY CMP0135) cmake_policy(SET CMP0135 NEW) diff --git a/ggml/src/ggml-cpu/ggml-cpu-quants.c b/ggml/src/ggml-cpu/ggml-cpu-quants.c index ccd0651ebc7..a89ce9bb1e9 100644 --- a/ggml/src/ggml-cpu/ggml-cpu-quants.c +++ b/ggml/src/ggml-cpu/ggml-cpu-quants.c @@ -8519,7 +8519,11 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi 
void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); +#ifdef __ARM_FEATURE_MATMUL_INT8 + assert((nrc == 2) || (nrc == 1)); +#else assert(nrc == 1); +#endif UNUSED(nrc); UNUSED(bx); UNUSED(by); @@ -8530,6 +8534,197 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi const int nb = n / QK_K; +#if defined(__ARM_FEATURE_MATMUL_INT8) + if (nrc == 2) { + const block_q6_K * GGML_RESTRICT x0 = x; + const block_q6_K * GGML_RESTRICT x1 = (const block_q6_K *) ((const uint8_t *)vx + bx); + const block_q8_K * GGML_RESTRICT y0 = y; + const block_q8_K * GGML_RESTRICT y1 = (const block_q8_K *) ((const uint8_t *)vy + by); + + float32x4_t vfsum = vdupq_n_f32(0.0f); + + for (int i = 0; i < nb; ++i, ++x0, ++x1, ++y0, ++y1) { + const uint8_t * GGML_RESTRICT ql0 = x0->ql; + const uint8_t * GGML_RESTRICT ql1 = x1->ql; + const uint8_t * GGML_RESTRICT qh0 = x0->qh; + const uint8_t * GGML_RESTRICT qh1 = x1->qh; + const int8_t * GGML_RESTRICT qy0 = y0->qs; + const int8_t * GGML_RESTRICT qy1 = y1->qs; + + const uint8x16_t mone = vdupq_n_u8(0x30); + const uint8x16_t m4b = vdupq_n_u8(0x0f); + + int32x4_t visum = vdupq_n_s32(0); + + // process 8 blocks per iteration, totally 16 blocks + for (int j = 0; j < 2; ++j, qh0 += 32, ql0 += 64, qh1 += 32, ql1 += 64) { + int8x16_t vx0[8], vx1[8]; + + // de-quantize vx0[8] + { + const uint8x16x2_t qh_bits = vld1q_u8_x2(qh0); + const uint8x16x4_t ql_bits = vld1q_u8_x4(ql0); + + uint8x16_t q6h_0 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 4)); + uint8x16_t q6h_1 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 4)); + uint8x16_t q6h_2 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 2)); + uint8x16_t q6h_3 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 2)); + + vx0[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[0], m4b), q6h_0)); + vx0[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[1], m4b), 
q6h_1)); + vx0[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[2], m4b), q6h_2)); + vx0[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[3], m4b), q6h_3)); + + q6h_0 = vandq_u8(mone, qh_bits.val[0]); + q6h_1 = vandq_u8(mone, qh_bits.val[1]); + q6h_2 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[0], 2)); + q6h_3 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[1], 2)); + + vx0[4] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[0], 4), q6h_0)); + vx0[5] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[1], 4), q6h_1)); + vx0[6] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[2], 4), q6h_2)); + vx0[7] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[3], 4), q6h_3)); + } + + // de-quantize vx1[8] + { + const uint8x16x2_t qh_bits = vld1q_u8_x2(qh1); + const uint8x16x4_t ql_bits = vld1q_u8_x4(ql1); + + uint8x16_t q6h_0 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 4)); + uint8x16_t q6h_1 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 4)); + uint8x16_t q6h_2 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 2)); + uint8x16_t q6h_3 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 2)); + + vx1[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[0], m4b), q6h_0)); + vx1[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[1], m4b), q6h_1)); + vx1[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[2], m4b), q6h_2)); + vx1[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[3], m4b), q6h_3)); + + q6h_0 = vandq_u8(mone, qh_bits.val[0]); + q6h_1 = vandq_u8(mone, qh_bits.val[1]); + q6h_2 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[0], 2)); + q6h_3 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[1], 2)); + + vx1[4] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[0], 4), q6h_0)); + vx1[5] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[1], 4), q6h_1)); + vx1[6] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[2], 4), q6h_2)); + vx1[7] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[3], 4), q6h_3)); + } + + // process 16 elements (one block with same scale) 
per iteration + // - vx = concat(ql, qh) - 32 + // - r1,r2,r3,r4 = smmla(vx, vy) + for (int k = 0; k < 8; ++k) { + const int blk = j * 8 + k; + + const int8x16_t vy0 = vld1q_s8(qy0); + const int8x16_t vy1 = vld1q_s8(qy1); + qy0 += 16; + qy1 += 16; + + const int32x4_t block_scale = { + x0->scales[blk], + x0->scales[blk], + x1->scales[blk], + x1->scales[blk], + }; + + // calculate four results at once with outer product + const int8x16_t vx_l = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(vx0[k]), vreinterpretq_s64_s8(vx1[k]))); + const int8x16_t vx_h = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(vx0[k]), vreinterpretq_s64_s8(vx1[k]))); + const int8x16_t vy_l = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(vy0), vreinterpretq_s64_s8(vy1))); + const int8x16_t vy_h = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(vy0), vreinterpretq_s64_s8(vy1))); + int32x4_t vr = vdupq_n_s32(0); + vr = vmmlaq_s32(vr, vx_l, vy_l); + vr = vmmlaq_s32(vr, vx_h, vy_h); + + // apply block scale, will NOT overflow + // block_scale * sum_256(int6*int8) <= 2^(8+8+6+8) = 30 bits + visum = vmlaq_s32(visum, vr, block_scale); + } + } + + // adjust bias, apply superblock scale + { + int32_t bias[4]; +#ifdef __ARM_FEATURE_SVE + const svbool_t pg16_8 = svptrue_pat_b16(SV_VL8); + const svbool_t pg8_8 = svptrue_pat_b8(SV_VL8); + const svint16_t y0_q8sums_0 = svld1_s16(pg16_8, y0->bsums); + const svint16_t y0_q8sums_1 = svld1_s16(pg16_8, y0->bsums + 8); + const svint16_t y1_q8sums_0 = svld1_s16(pg16_8, y1->bsums); + const svint16_t y1_q8sums_1 = svld1_s16(pg16_8, y1->bsums + 8); + const svint16_t x0_q6scales_0 = svunpklo_s16(svld1_s8(pg8_8, x0->scales)); + const svint16_t x0_q6scales_1 = svunpklo_s16(svld1_s8(pg8_8, x0->scales + 8)); + const svint16_t x1_q6scales_0 = svunpklo_s16(svld1_s8(pg8_8, x1->scales)); + const svint16_t x1_q6scales_1 = svunpklo_s16(svld1_s8(pg8_8, x1->scales + 8)); + const svint64_t zero = svdup_n_s64(0); + bias[0] = svaddv_s64(svptrue_b64(), 
svadd_s64_x(svptrue_b64(), svdot_s64(zero, y0_q8sums_0, x0_q6scales_0), + svdot_s64(zero, y0_q8sums_1, x0_q6scales_1))); + bias[1] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y1_q8sums_0, x0_q6scales_0), + svdot_s64(zero, y1_q8sums_1, x0_q6scales_1))); + bias[2] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y0_q8sums_0, x1_q6scales_0), + svdot_s64(zero, y0_q8sums_1, x1_q6scales_1))); + bias[3] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y1_q8sums_0, x1_q6scales_0), + svdot_s64(zero, y1_q8sums_1, x1_q6scales_1))); +#else + // NEON doesn't support int16 dot product, fallback to separated mul and add + const int16x8x2_t q8sums0 = vld1q_s16_x2(y0->bsums); + const int16x8x2_t q8sums1 = vld1q_s16_x2(y1->bsums); + + int8x16_t scales_s8 = vld1q_s8(x0->scales); + const int16x8x2_t q6scales0 = {{vmovl_s8(vget_low_s8(scales_s8)), vmovl_s8(vget_high_s8(scales_s8))}}; + scales_s8 = vld1q_s8(x1->scales); + const int16x8x2_t q6scales1 = {{vmovl_s8(vget_low_s8(scales_s8)), vmovl_s8(vget_high_s8(scales_s8))}}; + + int32x4_t prod; + prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[0]), vget_low_s16 (q6scales0.val[0])), + vmull_s16(vget_high_s16(q8sums0.val[0]), vget_high_s16(q6scales0.val[0]))), + vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[1]), vget_low_s16 (q6scales0.val[1])), + vmull_s16(vget_high_s16(q8sums0.val[1]), vget_high_s16(q6scales0.val[1])))); + bias[0] = vaddvq_s32(prod); + prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[0]), vget_low_s16 (q6scales0.val[0])), + vmull_s16(vget_high_s16(q8sums1.val[0]), vget_high_s16(q6scales0.val[0]))), + vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[1]), vget_low_s16 (q6scales0.val[1])), + vmull_s16(vget_high_s16(q8sums1.val[1]), vget_high_s16(q6scales0.val[1])))); + bias[1] = vaddvq_s32(prod); + prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[0]), vget_low_s16 (q6scales1.val[0])), + 
vmull_s16(vget_high_s16(q8sums0.val[0]), vget_high_s16(q6scales1.val[0]))), + vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[1]), vget_low_s16 (q6scales1.val[1])), + vmull_s16(vget_high_s16(q8sums0.val[1]), vget_high_s16(q6scales1.val[1])))); + bias[2] = vaddvq_s32(prod); + prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[0]), vget_low_s16 (q6scales1.val[0])), + vmull_s16(vget_high_s16(q8sums1.val[0]), vget_high_s16(q6scales1.val[0]))), + vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[1]), vget_low_s16 (q6scales1.val[1])), + vmull_s16(vget_high_s16(q8sums1.val[1]), vget_high_s16(q6scales1.val[1])))); + bias[3] = vaddvq_s32(prod); + +#endif + const int32x4_t vibias = vmulq_n_s32(vld1q_s32(bias), 32); + + const float32x4_t superblock_scale = { + GGML_FP16_TO_FP32(x0->d) * y0->d, + GGML_FP16_TO_FP32(x0->d) * y1->d, + GGML_FP16_TO_FP32(x1->d) * y0->d, + GGML_FP16_TO_FP32(x1->d) * y1->d, + }; + + visum = vsubq_s32(visum, vibias); + vfsum = vmlaq_f32(vfsum, vcvtq_f32_s32(visum), superblock_scale); + } + } + + // vfsum = ABCD -> ACBD + // AC -> s, BD -> (s+bs) + vfsum = vzip1q_f32(vfsum, vextq_f32(vfsum, vfsum, 2)); + vst1_f32(s, vget_low_f32 (vfsum)); + vst1_f32(s + bs, vget_high_f32(vfsum)); + + return; + } +#endif + #ifdef __ARM_FEATURE_SVE const int vector_length = ggml_cpu_get_sve_cnt()*8; float sum = 0; diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index a30e67f2279..133b50606bc 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -282,7 +282,11 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = { .from_float = quantize_row_q6_K, .vec_dot = ggml_vec_dot_q6_K_q8_K, .vec_dot_type = GGML_TYPE_Q8_K, +#if defined (__ARM_FEATURE_MATMUL_INT8) + .nrows = 2, +#else .nrows = 1, +#endif }, [GGML_TYPE_IQ2_XXS] = { .from_float = NULL, diff --git a/ggml/src/ggml-cpu/kleidiai/kernels.h b/ggml/src/ggml-cpu/kleidiai/kernels.h index 5ac02bda7c0..3b268d4a22a 100644 --- 
a/ggml/src/ggml-cpu/kleidiai/kernels.h +++ b/ggml/src/ggml-cpu/kleidiai/kernels.h @@ -5,6 +5,7 @@ #pragma once #include +#include #include "ggml.h" enum cpu_feature { diff --git a/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp b/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp index f3dffdd6bf5..15f0cd15406 100644 --- a/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +++ b/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp @@ -3,7 +3,9 @@ // #include #include +#include #include +#include #include #include #if defined(__linux__) diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index b7180d5955c..a4fbd823638 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -678,10 +678,14 @@ void launch_fattn( ) { constexpr int ncols = ncols1 * ncols2; + const bool is_mla = DV == 512; // TODO better parameterization + const ggml_tensor * Q = dst->src[0]; const ggml_tensor * K = dst->src[1]; const ggml_tensor * V = dst->src[2]; + GGML_ASSERT(V || is_mla); + const ggml_tensor * mask = dst->src[3]; ggml_tensor * KQV = dst; @@ -689,6 +693,10 @@ void launch_fattn( GGML_ASSERT(Q->type == GGML_TYPE_F32); GGML_ASSERT(KQV->type == GGML_TYPE_F32); + GGML_ASSERT( Q->nb[0] == ggml_element_size(Q)); + GGML_ASSERT( K->nb[0] == ggml_element_size(K)); + GGML_ASSERT(!V || V->nb[0] == ggml_element_size(V)); + GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16); GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) && "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big"); @@ -713,10 +721,10 @@ void launch_fattn( size_t nb12 = K->nb[2]; size_t nb13 = K->nb[3]; - const char * V_data = (const char *) V->data; - size_t nb21 = V->nb[1]; - size_t nb22 = V->nb[2]; - size_t nb23 = V->nb[3]; + const char * V_data = V ? (const char *) V->data : nullptr; + size_t nb21 = V ? V->nb[1] : nb11; + size_t nb22 = V ? V->nb[2] : nb12; + size_t nb23 = V ? 
V->nb[3] : nb13; if (need_f16_K && K->type != GGML_TYPE_F16) { GGML_ASSERT(ggml_is_contiguously_allocated(K)); @@ -733,7 +741,7 @@ void launch_fattn( nb13 = nb13*bs*sizeof(half)/ts; } - if (need_f16_V && V->type != GGML_TYPE_F16) { + if (V && need_f16_V && V->type != GGML_TYPE_F16) { GGML_ASSERT(ggml_is_contiguously_allocated(V)); V_f16.alloc(ggml_nelements(V)); to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type); diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index 491780abd40..be0329d0e0c 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -33,9 +33,30 @@ struct fattn_mma_f16_config< 64, 64> { static constexpr int nwarps_max = 4; static constexpr bool Q_in_reg = true; static constexpr int nstages_target = 2; - static constexpr int nbatch_K2 = 32; - static constexpr int nbatch_V2 = 32; - static constexpr int nbatch_combine = 32; + + static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { + return 32; + } + + static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) { + return 32; + } + + static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) { + return 32; + } + + static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) { + return 32; + } + + static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) { + return 32; + } + + static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) { + return 32; + } }; template <> @@ -44,9 +65,30 @@ struct fattn_mma_f16_config< 80, 80> { static constexpr int nwarps_max = 4; static constexpr bool Q_in_reg = true; static constexpr int nstages_target = 2; - static constexpr int nbatch_K2 = 40; - static constexpr int nbatch_V2 = 40; - static constexpr int nbatch_combine = 40; + + static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { + return 40; + } + + static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) { + return 40; + } + + static 
int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) { + return 40; + } + + static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) { + return 40; + } + + static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) { + return 40; + } + + static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) { + return 40; + } }; template <> @@ -55,9 +97,30 @@ struct fattn_mma_f16_config< 96, 96> { static constexpr int nwarps_max = 4; static constexpr bool Q_in_reg = true; static constexpr int nstages_target = 2; - static constexpr int nbatch_K2 = 48; - static constexpr int nbatch_V2 = 48; - static constexpr int nbatch_combine = 48; + + static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { + return 48; + } + + static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) { + return 48; + } + + static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) { + return 48; + } + + static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) { + return 48; + } + + static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) { + return 48; + } + + static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) { + return 48; + } }; template <> @@ -66,9 +129,30 @@ struct fattn_mma_f16_config<112, 112> { static constexpr int nwarps_max = 4; static constexpr bool Q_in_reg = true; static constexpr int nstages_target = 2; - static constexpr int nbatch_K2 = 56; - static constexpr int nbatch_V2 = 56; - static constexpr int nbatch_combine = 56; + + static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { + return 56; + } + + static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) { + return 56; + } + + static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) { + return 56; + } + + static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) { + return 56; + } + + static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) { + 
return 56; + } + + static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) { + return 56; + } }; template <> @@ -77,9 +161,30 @@ struct fattn_mma_f16_config<128, 128> { static constexpr int nwarps_max = 4; static constexpr bool Q_in_reg = true; static constexpr int nstages_target = 2; - static constexpr int nbatch_K2 = 64; - static constexpr int nbatch_V2 = 64; - static constexpr int nbatch_combine = 64; + + static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { + return 64; + } + + static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) { + return 64; + } + + static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) { + return 64; + } + + static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) { + return 64; + } + + static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) { + return 64; + } + + static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) { + return 64; + } }; template <> @@ -88,9 +193,38 @@ struct fattn_mma_f16_config<256, 256> { static constexpr int nwarps_max = 4; static constexpr bool Q_in_reg = true; static constexpr int nstages_target = 2; - static constexpr int nbatch_K2 = 128; - static constexpr int nbatch_V2 = 128; - static constexpr int nbatch_combine = 128; + + static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { + return 128; + } + + static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) { + return 128; + } + + static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) { + return 128; + } + + static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) { + return 128; + } + + static int get_nbatch_combine_host(const int cc, const int ncols) { + if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) { + return ncols <= 16 ? 
128 : 64; + } + return 64; + } + + static constexpr __device__ int get_nbatch_combine_device(int ncols) { +#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING + return ncols <= 16 ? 128 : 64; +#else + GGML_UNUSED(ncols); + return 128; +#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING + } }; template <> @@ -99,9 +233,44 @@ struct fattn_mma_f16_config<576, 512> { static constexpr int nwarps_max = 8; static constexpr bool Q_in_reg = false; static constexpr int nstages_target = 1; - static constexpr int nbatch_K2 = 160; - static constexpr int nbatch_V2 = 128; - static constexpr int nbatch_combine = 128; + + static int get_nbatch_K2_host(const int cc, const int ncols) { + if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) { + return ncols <= 16 ? 96 : 160; + } + return ncols <= 16 ? 288 : 160; + } + + static constexpr __device__ int get_nbatch_K2_device(int ncols) { +#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING + return ncols <= 16 ? 96 : 160; +#else + return ncols <= 16 ? 288 : 160; +#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING + } + + static int get_nbatch_V2_host(const int cc, const int ncols) { + if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) { + return ncols <= 16 ? 64 : 128; + } + return ncols <= 16 ? 256 : 128; + } + + static constexpr __device__ int get_nbatch_V2_device(int ncols) { +#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING + return ncols <= 16 ? 64 : 128; +#else + return ncols <= 16 ? 
256 : 128; +#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING + } + + static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) { + return 128; + } + + static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) { + return 128; + } }; // ------------------------------------------------------------------------------------------------------------------ @@ -120,7 +289,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( const unsigned int tile_KV_32 = ggml_cuda_cvta_generic_to_shared(tile_KV); - auto load = [&] __device__ (const int n) { + auto load = [&] __device__ (auto n) { const int stride_k = WARP_SIZE >> n; const int k0_start = stride_k == WARP_SIZE ? 0 : chunks_per_row - chunks_per_row % (2*stride_k); const int k0_stop = chunks_per_row - chunks_per_row % (1*stride_k); @@ -223,7 +392,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask( } } -template +template static __device__ __forceinline__ void flash_attn_ext_f16_iter( const float2 * const __restrict__ Q_f2, const half2 * const __restrict__ K_h2, @@ -261,10 +430,15 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( constexpr int cols_per_warp = ntiles * tile_B::I; constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles; constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column. + constexpr int ncols = ncols1 * ncols2; + constexpr int nbatch_K2 = c::get_nbatch_K2_device(ncols); + constexpr int nbatch_V2 = c::get_nbatch_V2_device(ncols); - constexpr int stride_tile_Q = DKQ/2 + 4; - constexpr int stride_tile_K = c::nbatch_K2 + 4; - constexpr int stride_tile_V = c::nbatch_V2 + 4; + constexpr int stride_tile_Q = DKQ/2 + 4; + constexpr int stride_tile_K = nbatch_K2 + 4; + + static_assert(!mla || nbatch_K2 >= nbatch_V2, "bad nbatch_K2, nbatch_V2 for MLA"); + constexpr int stride_tile_V = mla ? 
stride_tile_K : nbatch_V2 + 4; const int k_VKQ_0 = kb0 * c::nbatch_fa; tile_C_KQ KQ_C[c::nbatch_fa/(np*tile_C_KQ::I) * ntiles]; @@ -275,12 +449,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( tile_C_KQ_16 * KQ_C_16 = (tile_C_KQ_16 *) KQ_C; if constexpr (nstages > 1) { - static_assert(c::nbatch_K2 == DKQ/2, "batching not implemented for multi stage loading"); + static_assert(!mla, "multi-stage loading not implemented for MLA"); + static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi stage loading"); constexpr bool use_cp_async = true; cp_async_wait_all(); __syncthreads(); flash_attn_ext_f16_load_tile - (V_h2 + k_VKQ_0*stride_V, tile_V, c::nbatch_V2, stride_V); + (V_h2 + k_VKQ_0*stride_V, tile_V, nbatch_V2, stride_V); } else { constexpr bool use_cp_async = nstages == 1; if (ncols2 > 1 || mask_h2) { @@ -289,8 +464,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } #pragma unroll - for (int k0_start = 0; k0_start < DKQ/2; k0_start += c::nbatch_K2) { - const int k0_stop = k0_start + c::nbatch_K2 < DKQ/2 ? k0_start + c::nbatch_K2 : DKQ/2; + for (int k0_start = 0; k0_start < DKQ/2; k0_start += nbatch_K2) { + const int k0_stop = k0_start + nbatch_K2 < DKQ/2 ? k0_start + nbatch_K2 : DKQ/2; const int k0_diff = k0_stop - k0_start; if (nstages <= 1) { @@ -537,16 +712,21 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( (mask_h2 + (k_VKQ_0 + c::nbatch_fa)/2, tile_mask, stride_mask); } flash_attn_ext_f16_load_tile - (K_h2 + (k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, c::nbatch_K2, stride_K); + (K_h2 + (k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K); } } + + // For MLA K and V have the same data. + // Therefore, iterate over V in reverse and re-use the data if possible. + static_assert(!mla || nstages <= 1, "combination of MLA and multi-stage loading not implemented"); + constexpr int reusable_cutoff = mla ? 
(DKQ - 1) - (DKQ - 1) % (2*nbatch_K2) - (DKQ - DV) : DV; #pragma unroll - for (int i0_start = 0; i0_start < DV; i0_start += 2*c::nbatch_V2) { - const int i0_stop = i0_start + 2*c::nbatch_V2 < DV ? i0_start + 2*c::nbatch_V2 : DV; - const int i0_diff = i0_stop - i0_start; + for (int i0_stop = DV; i0_stop > 0; i0_stop -= 2*nbatch_V2) { + const int i0_start = i0_stop - 2*nbatch_V2 > 0 ? i0_stop - 2*nbatch_V2 : 0; + const int i0_diff = i0_stop - i0_start; - if (nstages <= 1) { + if (nstages <= 1 && i0_start < reusable_cutoff) { constexpr bool use_cp_async = nstages == 1; flash_attn_ext_f16_load_tile (V_h2 + k_VKQ_0*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V); @@ -555,6 +735,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } __syncthreads(); } + const half2 * tile_V_i = i0_start < reusable_cutoff ? tile_V : tile_V + (i0_start - reusable_cutoff)/2; // Calculate VKQ tile: #pragma unroll @@ -565,7 +746,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( const int k0 = k00 + (threadIdx.y % np)*tile_A::J; tile_A A; - load_ldmatrix_trans(A, tile_V + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V); + load_ldmatrix_trans(A, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V); if (ntiles == 1) { mma(VKQ_C[i_VKQ_0/tile_C_VKQ::I], A, B[k00/(np*tile_A::J)]); } else { @@ -596,7 +777,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( #endif // NEW_MMA_AVAILABLE } -template +template static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const float2 * const __restrict__ Q_f2, const half2 * const __restrict__ K_h2, @@ -632,13 +813,16 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( constexpr int cols_per_warp = ntiles * tile_B::I; constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles; constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column. 
+ constexpr int nbatch_K2 = c::get_nbatch_K2_device(ncols); + constexpr int nbatch_V2 = c::get_nbatch_V2_device(ncols); static_assert(nwarps * (cols_per_warp/ncols2) % ncols1 == 0, "bad nwarps"); - constexpr int stride_tile_Q = DKQ/2 + 4; - constexpr int stride_tile_K = c::nbatch_K2 + 4; - constexpr int stride_tile_V = c::nbatch_V2 + 4; + constexpr int stride_tile_Q = DKQ/2 + 4; + constexpr int stride_tile_K = nbatch_K2 + 4; + static_assert(!mla || nbatch_K2 >= nbatch_V2, "bad nbatch_K2, nbatch_V2 for MLA"); + constexpr int stride_tile_V = mla ? stride_tile_K : nbatch_V2 + 4; constexpr int stride_tile_KV_max = stride_tile_K > stride_tile_V ? stride_tile_K : stride_tile_V; extern __shared__ half2 tile_Q[]; @@ -726,26 +910,26 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( // Preload mask and K data for first iteration when using cp_async with multiple stages: if constexpr (nstages > 1) { - static_assert(c::nbatch_K2 == DKQ/2, "batching not implemented for multi-stage pipeline"); + static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi-stage pipeline"); constexpr bool use_cp_async = true; if (ncols2 > 1 || mask_h2) { flash_attn_ext_f16_load_mask (mask_h2 + kb0_start*c::nbatch_fa/2, tile_mask, stride_mask); } flash_attn_ext_f16_load_tile - (K_h2 + kb0_start*c::nbatch_fa*stride_K, tile_K, c::nbatch_K2, stride_K); + (K_h2 + kb0_start*c::nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K); } // Iterate over ne11 == previous tokens: for (int kb0 = kb0_start; kb0 < kb0_stop-1; ++kb0) { constexpr bool last_iter = false; - flash_attn_ext_f16_iter + flash_attn_ext_f16_iter (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap, ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0); } { // kb0_start is always < kb0_stop so the last iter can be executed unconditionally. 
constexpr bool last_iter = true; - flash_attn_ext_f16_iter + flash_attn_ext_f16_iter (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap, ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1); } @@ -774,7 +958,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( // It's also faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM. // So also write VKQ accumulators to shared memory in column-major format if np == 1. - constexpr int nbatch_combine = c::Q_in_reg ? DV/2 : DV/4; + constexpr int nbatch_combine = c::get_nbatch_combine_device(ncols); constexpr int tile_stride = nbatch_combine + 4; static_assert((DV/2) % nbatch_combine == 0, "bad nbatch_combine"); @@ -1012,7 +1196,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( #endif // NEW_MMA_AVAILABLE } -template +template __launch_bounds__(nwarps*WARP_SIZE, 1) static __global__ void flash_attn_ext_f16( const char * __restrict__ Q, @@ -1057,6 +1241,14 @@ static __global__ void flash_attn_ext_f16( NO_DEVICE_CODE; return; } +#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING + if (ncols1*ncols2 > 32) { + NO_DEVICE_CODE; + return; + } +#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING + + static_assert(!mla || DKQ >= DV, "MLA needs DKQ >= DV"); typedef fattn_mma_f16_config c; @@ -1067,9 +1259,10 @@ static __global__ void flash_attn_ext_f16( const int stride_Q1 = nb01 / sizeof(float2); const int stride_Q2 = nb02 / sizeof(float2); const int stride_K = nb11 / sizeof(half2); - const int stride_V = nb21 / sizeof(half2); const int stride_mask = nb31 / sizeof(half2); + const int stride_V = mla ?
stride_K : nb21 / sizeof(half2); + const int iter_k = ne11 / FATTN_KQ_STRIDE; const int iter_j = (ne01 + (ncols1 - 1)) / ncols1; @@ -1092,10 +1285,11 @@ static __global__ void flash_attn_ext_f16( const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2); const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio)); - const half2 * V_h2 = (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio)); const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr; float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2); + const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio)); + const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f; const int kb0_start_kernel = kb0_start * kb_niter; @@ -1104,12 +1298,12 @@ static __global__ void flash_attn_ext_f16( constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer. if (kb0_start == 0) { constexpr bool needs_fixup = false; // CUDA block is working on an entire tile. - flash_attn_ext_f16_process_tile + flash_attn_ext_f16_process_tile (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap, ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); } else { constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile. 
- flash_attn_ext_f16_process_tile + flash_attn_ext_f16_process_tile (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap, ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); } @@ -1130,10 +1324,11 @@ static __global__ void flash_attn_ext_f16( const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2); const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio)); - const half2 * V_h2 = (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio)); // K and V have same shape const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr; float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2); + const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio)); + const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f; const int kb0_start_kernel = kb0_start * kb_niter; @@ -1141,7 +1336,7 @@ static __global__ void flash_attn_ext_f16( constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks. constexpr bool needs_fixup = false; - flash_attn_ext_f16_process_tile + flash_attn_ext_f16_process_tile (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap, ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); #else @@ -1167,10 +1362,6 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml typedef fattn_mma_f16_config c; - constexpr int nbatch_K2 = c::nbatch_K2 < 1 ? DKQ/2 : c::nbatch_K2; - constexpr int nbatch_V2 = c::nbatch_V2 < 1 ? DV /2 : c::nbatch_V2; - constexpr int nbatch_combine = c::nbatch_combine < 1 ? DV /2 : c::nbatch_combine; - const int nstages = cp_async_available(cc) ? 
c::nstages_target : 0; constexpr int ncols = ncols1 * ncols2; @@ -1180,15 +1371,21 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml constexpr int nwarps_max_y = c::nbatch_fa / tile_A::I; constexpr int nwarps = nwarps_max_x*nwarps_max_y <= c::nwarps_max ? nwarps_max_x*nwarps_max_y : c::nwarps_max; + constexpr bool mla = DKQ == 576; + + const int nbatch_K2 = c::get_nbatch_K2_host (cc, ncols); + const int nbatch_V2 = c::get_nbatch_V2_host (cc, ncols); + const int nbatch_combine = c::get_nbatch_combine_host(cc, ncols); + static_assert(DKQ % tile_B::J == 0, "bad DKQ"); static_assert(DV % tile_A::J == 0, "bad DV"); static_assert(ncols % cols_per_warp == 0, "bad ncols"); - const size_t nbytes_shared_KV_1stage = c::nbatch_fa * std::max(c::nbatch_K2 + 4, c::nbatch_V2 + 4) * sizeof(half2); - const size_t nbytes_shared_KV_2stage = c::nbatch_fa * (c::nbatch_K2 + 4 + c::nbatch_V2 + 4) * sizeof(half2); - const size_t nbytes_shared_Q = ncols * (DKQ/2 + 4) * sizeof(half2); - const size_t nbytes_shared_mask = ncols1 * (c::nbatch_fa/2 + 4) * sizeof(half2); - const size_t nbytes_shared_combine = nwarps*cols_per_warp * (nbatch_combine + 4) * sizeof(half2); + const size_t nbytes_shared_KV_1stage = c::nbatch_fa * std::max(nbatch_K2 + 4, nbatch_V2 + 4) * sizeof(half2); + const size_t nbytes_shared_KV_2stage = c::nbatch_fa * (nbatch_K2 + 4 + nbatch_V2 + 4) * sizeof(half2); + const size_t nbytes_shared_Q = ncols * (DKQ/2 + 4) * sizeof(half2); + const size_t nbytes_shared_mask = ncols1 * (c::nbatch_fa/2 + 4) * sizeof(half2); + const size_t nbytes_shared_combine = nwarps*cols_per_warp * (nbatch_combine + 4) * sizeof(half2); const size_t nbytes_shared_KV = nstages <= 1 ?
nbytes_shared_KV_1stage : nbytes_shared_KV_2stage; @@ -1202,7 +1399,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml fattn_kernel_t fattn_kernel; if (logit_softcap == 0.0f) { constexpr bool use_logit_softcap = false; - fattn_kernel = flash_attn_ext_f16; + fattn_kernel = flash_attn_ext_f16; #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; @@ -1213,7 +1410,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) } else { constexpr bool use_logit_softcap = true; - fattn_kernel = flash_attn_ext_f16; + fattn_kernel = flash_attn_ext_f16; #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 9c5c803d02b..6bc0096cc65 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -10,6 +10,7 @@ template static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; const ggml_tensor * Q = dst->src[0]; if constexpr (ncols2 <= 8) { @@ -24,7 +25,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_con return; } - if (Q->ne[1] <= 32/ncols2) { + if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING || Q->ne[1] <= 32/ncols2) { ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); return; } diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index b4b85abcda9..02dc8c12dbd 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3222,7 +3222,7 @@ static bool 
ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g #endif // FLASH_ATTN_AVAILABLE if (op->src[1]->ne[0] != op->src[2]->ne[0]) { const int cc = ggml_cuda_info().devices[dev_ctx->device].cc; - if (!new_mma_available(cc) || cc < GGML_CUDA_CC_AMPERE) { + if (!new_mma_available(cc)) { return false; } const int gqa_ratio = op->src[0]->ne[2] / op->src[1]->ne[2]; diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu index e1cf843de1a..2db5b4ab0f0 100644 --- a/ggml/src/ggml-cuda/mmq.cu +++ b/ggml/src/ggml-cuda/mmq.cu @@ -122,6 +122,7 @@ void ggml_cuda_mul_mat_q( const int64_t s13 = src1->nb[3] / ts_src1; quantize_mmq_q8_1_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded, ne11, ne12, ne13, stream); + CUDA_CHECK(cudaGetLastError()); } const int64_t s12 = ne11*ne10_padded * sizeof(block_q8_1)/(QK8_1*sizeof(int)); @@ -205,6 +206,7 @@ void ggml_cuda_mul_mat_q( const int64_t s13 = src1->nb[2] / ts_src1; quantize_mmq_q8_1_cuda(src1_d, ids_src1_dev, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream); + CUDA_CHECK(cudaGetLastError()); } const int64_t s12 = ne11*ne10_padded * sizeof(block_q8_1)/(QK8_1*sizeof(int)); diff --git a/ggml/src/ggml-cuda/quantize.cu b/ggml/src/ggml-cuda/quantize.cu index cb93181455d..a0b03a740d7 100644 --- a/ggml/src/ggml-cuda/quantize.cu +++ b/ggml/src/ggml-cuda/quantize.cu @@ -56,13 +56,13 @@ static __global__ void quantize_mmq_q8_1( constexpr int vals_per_scale = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 64 : 32; constexpr int vals_per_sum = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 
16 : 32; - const int64_t i0 = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*4; + const int64_t i0 = ((int64_t)blockDim.x*blockIdx.y + threadIdx.x)*4; if (i0 >= ne0) { return; } - const int64_t i1 = blockIdx.y; + const int64_t i1 = blockIdx.x; const int64_t i2 = blockIdx.z % ne2; const int64_t i3 = blockIdx.z / ne2; @@ -75,8 +75,8 @@ static __global__ void quantize_mmq_q8_1( block_q8_1_mmq * y = (block_q8_1_mmq *) vy; - const int64_t ib0 = blockIdx.z*((int64_t)gridDim.y*gridDim.x*blockDim.x/QK8_1); // first block of channel - const int64_t ib = ib0 + (i0 / (4*QK8_1))*ne1 + blockIdx.y; // block index in channel + const int64_t ib0 = blockIdx.z*((int64_t)gridDim.x*gridDim.y*blockDim.x/QK8_1); // first block of channel + const int64_t ib = ib0 + (i0 / (4*QK8_1))*ne1 + blockIdx.x; // block index in channel const int64_t iqs = i0 % (4*QK8_1); // quant index in block // Load 4 floats per thread and calculate max. abs. value between them: @@ -166,8 +166,9 @@ void quantize_mmq_q8_1_cuda( GGML_ASSERT(ne00 % 4 == 0); GGML_ASSERT(ne0 % (4*QK8_1) == 0); - const int64_t block_num_x = (ne0 + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ); - const dim3 num_blocks(block_num_x, ne1, ne2*ne3); + // ne1 tends to assume the highest values, therefore use it as the "x" dimension of the CUDA grid: + const int64_t block_num_y = (ne0 + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ); + const dim3 num_blocks(ne1, block_num_y, ne2*ne3); const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE_MMQ, 1, 1); switch (mmq_get_q8_1_ds_layout(type_src0)) { case MMQ_Q8_1_DS_LAYOUT_D4: diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index 576f9581bda..85dbbcd5d7f 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -415,6 +415,13 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128, 
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H64, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H64, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H64, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H64, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96, @@ -1362,6 +1369,13 @@ @implementation GGMLMetalClass GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128, flash_attn_ext_q8_0_hk192_hv128, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512, flash_attn_ext_q8_0_hk576_hv512, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64, flash_attn_ext_vec_f16_h64, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64, flash_attn_ext_vec_bf16_h64, has_simdgroup_reduction && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64, flash_attn_ext_vec_q4_0_h64, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H64, flash_attn_ext_vec_q4_1_h64, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H64, flash_attn_ext_vec_q5_0_h64, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H64, flash_attn_ext_vec_q5_1_h64, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H64, flash_attn_ext_vec_q8_0_h64, has_simdgroup_reduction); 
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96, flash_attn_ext_vec_f16_h96, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96, flash_attn_ext_vec_bf16_h96, has_simdgroup_reduction && use_bfloat); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96, flash_attn_ext_vec_q4_0_h96, has_simdgroup_reduction); @@ -4358,7 +4372,7 @@ static bool ggml_metal_encode_node( // TODO: add vec kernels for (ne00%64 == 0) and maybe also for (ne00%32 == 0) // for now avoiding mainly to keep the number of templates/kernels a bit lower // these are now trivial to add after: https://github.com/ggml-org/llama.cpp/pull/12612 - if (ne01 >= 4 || (ne00%128 != 0 && ne00 != 96 && ne00 != 192 && ne00 != 576)) { + if (ne01 >= 20 || (ne00%128 != 0 && ne00 != 64 && ne00 != 96 && ne00 != 192 && ne00 != 576)) { switch (src1->type) { case GGML_TYPE_F16: { @@ -4539,6 +4553,24 @@ static bool ggml_metal_encode_node( use_vec_kernel = true; switch (ne00) { + case 64: + { + switch (src1->type) { + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64].pipeline; break; + case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64].pipeline; break; + case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64].pipeline; break; + case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H64].pipeline; break; + case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H64].pipeline; break; + case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H64].pipeline; break; + case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H64].pipeline; break; + default: + { + GGML_LOG_ERROR("unsupported type: %d\n", src1->type); + GGML_LOG_ERROR("add template specialization for this 
type\n"); + GGML_ABORT("add template specialization for this type"); + } + } + } break; case 96: { switch (src1->type) { diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 9cfddf4503a..e94b6cd7564 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -3887,6 +3887,11 @@ kernel void kernel_flash_attn_ext_vec( sm[tiisg] = pm[ic + tiisg]; } + // skip -INF blocks + if (simd_max(sm[tiisg]) == -INFINITY) { + continue; + } + // Q*K^T { // each simdgroup processes 1 query and NE (NW/NL) head elements @@ -4119,6 +4124,16 @@ kernel void kernel_flash_attn_ext_vec( typedef decltype(kernel_flash_attn_ext_vec) flash_attn_ext_vec_t; +template [[host_name("kernel_flash_attn_ext_vec_f16_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#if defined(GGML_METAL_USE_BF16) +template [[host_name("kernel_flash_attn_ext_vec_bf16_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#endif +template [[host_name("kernel_flash_attn_ext_vec_q4_0_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; + template [[host_name("kernel_flash_attn_ext_vec_f16_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #if defined(GGML_METAL_USE_BF16) template [[host_name("kernel_flash_attn_ext_vec_bf16_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; diff --git a/ggml/src/ggml-opt.cpp b/ggml/src/ggml-opt.cpp index 58d77578f45..a3c82d67577 100644 --- a/ggml/src/ggml-opt.cpp +++ b/ggml/src/ggml-opt.cpp @@ 
-576,6 +576,10 @@ void ggml_opt_reset(ggml_opt_context_t opt_ctx, bool optimizer) { } } +bool ggml_opt_static_graphs(ggml_opt_context_t opt_ctx) { + return opt_ctx->static_graphs; +} + struct ggml_tensor * ggml_opt_inputs(ggml_opt_context_t opt_ctx) { return opt_ctx->inputs; } @@ -842,6 +846,7 @@ void ggml_opt_epoch( int64_t idata_split, ggml_opt_epoch_callback callback_train, ggml_opt_epoch_callback callback_eval) { + GGML_ASSERT(ggml_opt_static_graphs(opt_ctx) && "ggml_opt_epoch requires static graphs"); struct ggml_tensor * inputs = ggml_opt_inputs(opt_ctx); struct ggml_tensor * labels = ggml_opt_labels(opt_ctx); struct ggml_tensor * data = ggml_opt_dataset_data(dataset); diff --git a/ggml/src/ggml-sycl/CMakeLists.txt b/ggml/src/ggml-sycl/CMakeLists.txt index 231fb71dab5..a2e26124802 100644 --- a/ggml/src/ggml-sycl/CMakeLists.txt +++ b/ggml/src/ggml-sycl/CMakeLists.txt @@ -49,34 +49,38 @@ endif() target_compile_options(ggml-sycl PRIVATE "-Wno-narrowing") # Link against oneDNN -find_package(DNNL) set(GGML_SYCL_DNNL 0) -if(DNNL_FOUND) - if (NOT DEFINED DNNL_GPU_VENDOR) - # default to intel target - set(DNNL_GPU_VENDOR "INTEL") - if(NOT "${GGML_SYCL_TARGET}" STREQUAL "INTEL") - message(WARNING "oneDNN builds bundled with oneapi release only support INTEL target") +if(GGML_SYCL_DNN) + find_package(DNNL) + if(DNNL_FOUND) + if (NOT DEFINED DNNL_GPU_VENDOR) + # default to intel target + set(DNNL_GPU_VENDOR "INTEL") + if(NOT "${GGML_SYCL_TARGET}" STREQUAL "INTEL") + message(WARNING "oneDNN builds bundled with oneapi release only support INTEL target") + endif() endif() - endif() - # Verify oneDNN was compiled for the same target as llama - if("${GGML_SYCL_TARGET}" STREQUAL "${DNNL_GPU_VENDOR}") - target_link_libraries(ggml-sycl PRIVATE DNNL::dnnl) - set(GGML_SYCL_DNNL 1) - get_target_property(CONFIGS DNNL::dnnl IMPORTED_CONFIGURATIONS) - foreach(CONFIG ${CONFIGS}) - get_target_property(DNNL_LIB DNNL::dnnl IMPORTED_LOCATION_${CONFIG}) - message(STATUS "Found oneDNN: 
${DNNL_LIB}") - endforeach() + # Verify oneDNN was compiled for the same target as llama + if("${GGML_SYCL_TARGET}" STREQUAL "${DNNL_GPU_VENDOR}") + target_link_libraries(ggml-sycl PRIVATE DNNL::dnnl) + set(GGML_SYCL_DNNL 1) + get_target_property(CONFIGS DNNL::dnnl IMPORTED_CONFIGURATIONS) + foreach(CONFIG ${CONFIGS}) + get_target_property(DNNL_LIB DNNL::dnnl IMPORTED_LOCATION_${CONFIG}) + message(STATUS "Found oneDNN: ${DNNL_LIB}") + endforeach() + else() + message(WARNING + "oneDNN must be compiled for the same target as llama.cpp. + llama.cpp: ${GGML_SYCL_TARGET}, oneDNN: ${DNNL_GPU_VENDOR}. + Disabling oneDNN support.") + endif() else() - message(WARNING - "oneDNN must be compiled for the same target as llama.cpp. - llama.cpp: ${GGML_SYCL_TARGET}, oneDNN: ${DNNL_GPU_VENDOR}. - Disabling oneDNN support.") + message(STATUS "oneDNN not found, disabling oneDNN support") endif() else() - message(STATUS "oneDNN not found, disabling oneDNN support") + message(STATUS "oneDNN support disabled by the user") endif() target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_DNNL=${GGML_SYCL_DNNL}) diff --git a/ggml/src/ggml-sycl/binbcast.cpp b/ggml/src/ggml-sycl/binbcast.cpp index 0a9d3a927c2..aaa94176f16 100644 --- a/ggml/src/ggml-sycl/binbcast.cpp +++ b/ggml/src/ggml-sycl/binbcast.cpp @@ -1,93 +1,74 @@ #include "binbcast.hpp" +#include #include #include #include +#include "dpct/helper.hpp" #include "ggml.h" -template -static void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst, - int ne0, int ne1, int ne2, int ne3, - int ne10, int ne11, int ne12, int ne13, - /*int s0, */ int s1, int s2, int s3, - /*int s00,*/ int s01, int s02, int s03, - /*int s10,*/ int s11, int s12, int s13, - const sycl::nd_item<3> &item_ct1) { - const int i0s = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - const int i1 = (item_ct1.get_local_range(1) * item_ct1.get_group(1) + - item_ct1.get_local_id(1)); - const int i2 = (item_ct1.get_local_range(0) 
* item_ct1.get_group(0) + - item_ct1.get_local_id(0)) / - ne3; - const int i3 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) + - item_ct1.get_local_id(0)) % - ne3; - - if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) { - return; - } - - const int i11 = i1 % ne11; - const int i12 = i2 % ne12; - const int i13 = i3 % ne13; - - const size_t i_src0 = i3*s03 + i2*s02 + i1*s01; - const size_t i_src1 = i13*s13 + i12*s12 + i11*s11; - const size_t i_dst = i3*s3 + i2*s2 + i1*s1; - - const src0_t * src0_row = src0 + i_src0; - const src1_t * src1_row = src1 + i_src1; - dst_t * dst_row = dst + i_dst; - - for (int i0 = i0s; i0 < ne0; - i0 += item_ct1.get_local_range(2) * item_ct1.get_group_range(2)) { - const int i10 = i0 % ne10; - dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]); +template +static __dpct_inline__ void k_bin_bcast_contiguous(const src0_t * __restrict__ src0, const src1_t * __restrict__ src1, + dst_t * dst, std::size_t num_elements, const sycl::nd_item<1> & it) { + auto element_id = it.get_global_id(0); + auto global_range = it.get_global_range(0); + for (; element_id < num_elements; element_id += global_range) { + auto src0_float_val = sycl::vec(src0[element_id]).template convert(); + auto src1_float_val = sycl::vec(src1[element_id]).template convert(); + float dst_val = bin_op(src0_float_val[0], src1_float_val[0]); + auto val_to_store = sycl::vec(dst_val).template convert(); + dst[element_id] = val_to_store; } } -template -static void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst, - int ne0, int ne1, int ne2, int ne3, - int ne10, int ne11, int ne12, int ne13, - /*int s0, */ int s1, int s2, int s3, - /*int s00,*/ int s01, int s02, int s03, - /*int s10,*/ int s11, int s12, int s13, - const sycl::nd_item<3> &item_ct1) { - - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - - const int i3 = i/(ne2*ne1*ne0); - const int i2 = (i/(ne1*ne0)) % 
ne2; - const int i1 = (i/ne0) % ne1; - const int i0 = i % ne0; - - if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) { - return; +template +static __dpct_inline__ void k_bin_bcast(const src0_t * __restrict__ src0, const src1_t * __restrict__ src1, dst_t * dst, + int ne0, int ne1, int ne2, int ne3, int ne10, int ne11, int ne12, int ne13, + int s0, int s1, int s2, int s3, int s00, int s01, int s02, int s03, int s10, + int s11, int s12, int s13, std::size_t num_dst_elements, + const sycl::nd_item<1> & item_ct1) { + auto calculate_logical_index = + [](const std::array & dims, std::size_t element_id) __attribute__((always_inline))->std::array { + std::array logical_index; +#pragma unroll(4) + for (int i = 3; i >= 0; i--) { + logical_index[i] = element_id % dims[i]; + element_id /= dims[i]; + } + return logical_index; + }; + + auto calculate_index = [](const std::array & dims, const std::array & strides, + const std::array & indices) __attribute__((always_inline)) + ->std::size_t { + std::size_t index = 0; +#pragma unroll(4) + for (int i = 0; i < 4; i++) { + auto index_i = indices[i]; + if (indices[i] >= dims[i]) { + index_i = indices[i] % dims[i]; + } + index += strides[i] * index_i; + } + return index; + }; + + auto element_id = item_ct1.get_global_id(0); + for (; element_id < num_dst_elements; element_id += item_ct1.get_global_range(0)) { + auto logical_index = calculate_logical_index({ ne3, ne2, ne1, ne0 }, element_id); + auto src_0_index = calculate_index({ ne3, ne2, ne1, ne0 }, { s03, s02, s01, s00 }, logical_index); + auto src_1_index = calculate_index({ ne13, ne12, ne11, ne10 }, { s13, s12, s11, s10 }, logical_index); + auto dst_index = calculate_index({ ne3, ne2, ne1, ne0 }, { s3, s2, s1, s0 }, logical_index); + auto src0_float_val = sycl::vec(src0[src_0_index]).template convert(); + auto src1_float_val = sycl::vec(src1[src_1_index]).template convert(); + float dst_val = bin_op(src0_float_val[0], src1_float_val[0]); + auto val_to_store = 
sycl::vec(dst_val).template convert(); + dst[dst_index] = val_to_store; } - - const int i11 = i1 % ne11; - const int i12 = i2 % ne12; - const int i13 = i3 % ne13; - - const size_t i_src0 = i3*s03 + i2*s02 + i1*s01; - const size_t i_src1 = i13*s13 + i12*s12 + i11*s11; - const size_t i_dst = i3*s3 + i2*s2 + i1*s1; - - const src0_t * src0_row = src0 + i_src0; - const src1_t * src1_row = src1 + i_src1; - dst_t * dst_row = dst + i_dst; - - const int i10 = i0 % ne10; - dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]); } - -template -struct bin_bcast_sycl { +template struct bin_bcast_sycl { template void operator()(const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, const int64_t ne10, const int64_t ne11, @@ -96,165 +77,73 @@ struct bin_bcast_sycl { const size_t nb10, const size_t nb11, const size_t nb12, const size_t nb13, const size_t nb0, const size_t nb1, const size_t nb2, const size_t nb3, const bool src0_is_contiguous, const bool src1_is_contiguous, const bool dst_is_contiguous, queue_ptr stream) { - int nr0 = ne10 / ne0; - int nr1 = ne11/ne1; - int nr2 = ne12/ne2; - int nr3 = ne13/ne3; - - int nr[4] = { nr0, nr1, nr2, nr3 }; - - // collapse dimensions until first broadcast dimension - int64_t cne[] = {ne0, ne1, ne2, ne3}; - int64_t cne0[] = {ne00, ne01, ne02, ne03}; - int64_t cne1[] = {ne10, ne11, ne12, ne13}; - size_t cnb[] = {nb0, nb1, nb2, nb3}; - size_t cnb0[] = {nb00, nb01, nb02, nb03}; - size_t cnb1[] = {nb10, nb11, nb12, nb13}; - auto collapse = [](int64_t cne[]) { - cne[0] *= cne[1]; - cne[1] = cne[2]; - cne[2] = cne[3]; - cne[3] = 1; - }; - - auto collapse_nb = [](size_t cnb[], int64_t cne[]) { - cnb[1] *= cne[1]; - cnb[2] *= cne[2]; - cnb[3] *= cne[3]; - }; - - if (src0_is_contiguous && src1_is_contiguous && dst_is_contiguous) { + auto check_bcast_required = [](const std::array & src_dims, + const std::array & dst_dims) -> bool { 
for (int i = 0; i < 4; i++) { - if (nr[i] != 1) { - break; - } - if (i > 0) { - collapse_nb(cnb, cne); - collapse_nb(cnb0, cne0); - collapse_nb(cnb1, cne1); - collapse(cne); - collapse(cne0); - collapse(cne1); + if (dst_dims[i] > src_dims[i]) { + return true; } } - } - { - int64_t ne0 = cne[0]; - int64_t ne1 = cne[1]; - int64_t ne2 = cne[2]; - int64_t ne3 = cne[3]; - - int64_t ne10 = cne1[0]; - int64_t ne11 = cne1[1]; - int64_t ne12 = cne1[2]; - int64_t ne13 = cne1[3]; - - size_t nb0 = cnb[0]; - size_t nb1 = cnb[1]; - size_t nb2 = cnb[2]; - size_t nb3 = cnb[3]; - - size_t nb00 = cnb0[0]; - size_t nb01 = cnb0[1]; - size_t nb02 = cnb0[2]; - size_t nb03 = cnb0[3]; - - size_t nb10 = cnb1[0]; - size_t nb11 = cnb1[1]; - size_t nb12 = cnb1[2]; - size_t nb13 = cnb1[3]; - - size_t s0 = nb0 / sizeof(dst_t); - size_t s1 = nb1 / sizeof(dst_t); - size_t s2 = nb2 / sizeof(dst_t); - size_t s3 = nb3 / sizeof(dst_t); - - size_t s10 = nb10 / sizeof(src1_t); - size_t s11 = nb11 / sizeof(src1_t); - size_t s12 = nb12 / sizeof(src1_t); - size_t s13 = nb13 / sizeof(src1_t); - - size_t s00 = nb00 / sizeof(src0_t); - size_t s01 = nb01 / sizeof(src0_t); - size_t s02 = nb02 / sizeof(src0_t); - size_t s03 = nb03 / sizeof(src0_t); - - GGML_UNUSED(s00); - - GGML_ASSERT(nb0 % sizeof(dst_t) == 0); - GGML_ASSERT(nb1 % sizeof(dst_t) == 0); - GGML_ASSERT(nb2 % sizeof(dst_t) == 0); - GGML_ASSERT(nb3 % sizeof(dst_t) == 0); - - GGML_ASSERT(nb00 % sizeof(src0_t) == 0); - GGML_ASSERT(nb01 % sizeof(src0_t) == 0); - GGML_ASSERT(nb02 % sizeof(src0_t) == 0); - GGML_ASSERT(nb03 % sizeof(src0_t) == 0); - - GGML_ASSERT(nb10 % sizeof(src1_t) == 0); - GGML_ASSERT(nb11 % sizeof(src1_t) == 0); - GGML_ASSERT(nb12 % sizeof(src1_t) == 0); - GGML_ASSERT(nb13 % sizeof(src1_t) == 0); - - GGML_ASSERT(s0 == 1); - GGML_ASSERT(s10 == 1); - - const int block_size = 128; - - int64_t hne0 = std::max(ne0/2LL, 1LL); - - sycl::range<3> block_dims(1, 1, 1); - block_dims[2] = std::min(hne0, block_size); - block_dims[1] = std::min( - 
ne1, block_size / (unsigned int)block_dims[2]); - block_dims[0] = std::min( - std::min( - ne2 * ne3, block_size / (unsigned int)block_dims[2] / - (unsigned int)block_dims[1]), - 64U); - - sycl::range<3> block_nums( - (ne2 * ne3 + block_dims[0] - 1) / block_dims[0], - (ne1 + block_dims[1] - 1) / block_dims[1], - (hne0 + block_dims[2] - 1) / block_dims[2]); - - if (block_nums[0] > 65535) { - // this is the maximum number of blocks in z direction, fallback to 1D grid kernel - int block_num = (ne0*ne1*ne2*ne3 + block_size - 1) / block_size; - { - dpct::has_capability_or_fail(stream->get_device(), - {sycl::aspect::fp16}); - - stream->parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, block_num) * - sycl::range<3>(1, 1, block_size), - sycl::range<3>(1, 1, block_size)), - [=](sycl::nd_item<3> item_ct1) { - k_bin_bcast_unravel( - src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3, - ne10, ne11, ne12, ne13, s1, s2, s3, s01, s02, - s03, s11, s12, s13, item_ct1); - }); - } - } else { - /* - DPCT1049:16: The work-group size passed to the SYCL kernel may - exceed the limit. To get the device limit, query - info::device::max_work_group_size. Adjust the work-group size if - needed. 
- */ - dpct::has_capability_or_fail(stream->get_device(), - {sycl::aspect::fp16}); - - stream->parallel_for( - sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) { - k_bin_bcast(src0_dd, src1_dd, dst_dd, ne0, ne1, - ne2, ne3, ne10, ne11, ne12, ne13, - s1, s2, s3, s01, s02, s03, s11, s12, s13, - item_ct1); - }); - } + return false; + }; + + dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 }); + + GGML_ASSERT(nb0 % sizeof(dst_t) == 0); + GGML_ASSERT(nb1 % sizeof(dst_t) == 0); + GGML_ASSERT(nb2 % sizeof(dst_t) == 0); + GGML_ASSERT(nb3 % sizeof(dst_t) == 0); + + GGML_ASSERT(nb00 % sizeof(src0_t) == 0); + GGML_ASSERT(nb01 % sizeof(src0_t) == 0); + GGML_ASSERT(nb02 % sizeof(src0_t) == 0); + GGML_ASSERT(nb03 % sizeof(src0_t) == 0); + + GGML_ASSERT(nb10 % sizeof(src1_t) == 0); + GGML_ASSERT(nb11 % sizeof(src1_t) == 0); + GGML_ASSERT(nb12 % sizeof(src1_t) == 0); + GGML_ASSERT(nb13 % sizeof(src1_t) == 0); + + // dst strides in number of elements + size_t s0 = nb0 / sizeof(dst_t); + size_t s1 = nb1 / sizeof(dst_t); + size_t s2 = nb2 / sizeof(dst_t); + size_t s3 = nb3 / sizeof(dst_t); + + // src1 strides in number of elements + size_t s10 = nb10 / sizeof(src0_t); + size_t s11 = nb11 / sizeof(src1_t); + size_t s12 = nb12 / sizeof(src1_t); + size_t s13 = nb13 / sizeof(src1_t); + + // src0 strides in number of elements + size_t s00 = nb00 / sizeof(src0_t); + size_t s01 = nb01 / sizeof(src0_t); + size_t s02 = nb02 / sizeof(src0_t); + size_t s03 = nb03 / sizeof(src0_t); + + std::size_t num_dst_elements = static_cast(ne0) * static_cast(ne1) * + static_cast(ne2) * static_cast(ne3); + std::size_t local_range = 256; + std::size_t global_range = ceil_div(num_dst_elements, local_range) * local_range; + + bool needs_broadcasting = check_bcast_required({ ne00, ne01, ne02, ne03 }, { ne0, ne1, ne2, ne3 }) || + check_bcast_required({ ne10, ne11, ne12, ne13 }, { ne0, ne1, ne2, ne3 }); + bool all_contiguous = src0_is_contiguous && 
src1_is_contiguous && dst_is_contiguous; + + if (! needs_broadcasting && all_contiguous) { + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for(sycl::nd_range<1>({ global_range }, { local_range }), [=](sycl::nd_item<1> it) { + k_bin_bcast_contiguous(src0_dd, src1_dd, dst_dd, num_dst_elements, it); + }); + }); + } else { + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for(sycl::nd_range<1>({ global_range }, { local_range }), [=](sycl::nd_item<1> it) { + k_bin_bcast(src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3, ne10, ne11, ne12, ne13, s0, s1, + s2, s3, s00, s01, s02, s03, s10, s11, s12, s13, num_dst_elements, it); + }); + }); } } }; diff --git a/ggml/src/ggml-sycl/convert.cpp b/ggml/src/ggml-sycl/convert.cpp index b2f8a656933..75bac98e5fb 100644 --- a/ggml/src/ggml-sycl/convert.cpp +++ b/ggml/src/ggml-sycl/convert.cpp @@ -183,6 +183,24 @@ static void dequantize_row_q4_K_sycl(const void *vx, dst_t *y, const int64_t k, } } +template +static void dequantize_row_q4_K_sycl_reorder(const void * vx, dst_t * y, const int64_t k, dpct::queue_ptr stream) { + const int64_t nb = k / QK_K; + const size_t local_size = 32; + const size_t global_size = nb * local_size; + + dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 }); + + stream->submit([&](sycl::handler & cgh) { + sycl::local_accessor scale_local_acc(sycl::range<1>(12), cgh); + + cgh.parallel_for(sycl::nd_range<1>(sycl::range<1>(global_size), sycl::range<1>(local_size)), + [=](sycl::nd_item<1> item_ct1) { + dequantize_block_q4_K_reorder(vx, y, get_pointer(scale_local_acc), item_ct1, nb); + }); + }); +} + template static void dequantize_row_q5_K_sycl(const void *vx, dst_t *y, const int64_t k, dpct::queue_ptr stream) { @@ -504,7 +522,11 @@ to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst) { case GGML_TYPE_Q3_K: return dequantize_row_q3_K_sycl; case GGML_TYPE_Q4_K: - return dequantize_row_q4_K_sycl; + if (dst->src[0]->extra && ((ggml_tensor_extra_gpu *) 
dst->src[0]->extra)->optimized_feature.reorder) { + return dequantize_row_q4_K_sycl_reorder; + } else { + return dequantize_row_q4_K_sycl; + } case GGML_TYPE_Q5_K: return dequantize_row_q5_K_sycl; case GGML_TYPE_Q6_K: @@ -556,7 +578,12 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) { case GGML_TYPE_Q3_K: return dequantize_row_q3_K_sycl; case GGML_TYPE_Q4_K: - return dequantize_row_q4_K_sycl; + if (dst->src[0]->extra && + ((ggml_tensor_extra_gpu*)dst->src[0]->extra)->optimized_feature.reorder) { + return dequantize_row_q4_K_sycl_reorder; + } else { + return dequantize_row_q4_K_sycl; + } case GGML_TYPE_Q5_K: return dequantize_row_q5_K_sycl; case GGML_TYPE_Q6_K: diff --git a/ggml/src/ggml-sycl/dequantize.hpp b/ggml/src/ggml-sycl/dequantize.hpp index 651c2160d24..64e92f73f26 100644 --- a/ggml/src/ggml-sycl/dequantize.hpp +++ b/ggml/src/ggml-sycl/dequantize.hpp @@ -357,6 +357,28 @@ static inline void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8 } #endif +template +inline void dequantize_q4_K_common(dst_t * __restrict__ y, const uint8_t * __restrict__ qs_ptr, const float dall, + const float dmin, uint8_t * __restrict__ scales_local, int il, int ir) { + const int is = 2 * il; + constexpr int n = 4; + + uint8_t sc, m; + get_scale_min_k4(is + 0, scales_local, sc, m); + const float d1 = dall * sc; + const float m1 = dmin * m; + + get_scale_min_k4(is + 1, scales_local, sc, m); + const float d2 = dall * sc; + const float m2 = dmin * m; + + sycl::vec q_vec = vec_aligned_load(qs_ptr + 32 * il + n * ir); + for (int l = 0; l < n; ++l) { + y[l + 0] = d1 * (q_vec[l] & 0xF) - m1; + y[l + 32] = d2 * (q_vec[l] >> 4) - m2; + } +} + template static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy, uint8_t* scales_local, const sycl::nd_item<3> &item_ct1) { @@ -365,36 +387,22 @@ static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restri const int64_t i = item_ct1.get_group(2); #if QK_K == 256 
- // assume 32 threads const int64_t tid = item_ct1.get_local_id(2); - const int64_t il = tid/8; - const int64_t ir = tid%8; - const int64_t is = 2*il; - const int64_t n = 4; + const int64_t il = tid / 8; + const int64_t ir = tid % 8; - dst_t * y = yy + i*QK_K + 64*il + n*ir; + dst_t * y = yy + i * QK_K + 64 * il + 4 * ir; const sycl::half2 dm = x[i].dm; const float dall = dm[0]; const float dmin = dm[1]; - if (tid < 12) + if (tid < 12) { scales_local[tid] = x[i].scales[tid]; - item_ct1.barrier(sycl::access::fence_space::local_space); - - uint8_t sc, m; - get_scale_min_k4(is + 0, scales_local, sc, m); - const float d1 = dall * sc; - const float m1 = dmin * m; - get_scale_min_k4(is + 1, scales_local, sc, m); - const float d2 = dall * sc; - const float m2 = dmin * m; - - sycl::vec q_vec = vec_aligned_load(x[i].qs + 32*il + n*ir); - for (int l = 0; l < n; ++l) { - y[l + 0] = d1 * (q_vec[l] & 0xF) - m1; - y[l +32] = d2 * (q_vec[l] >> 4) - m2; } + + item_ct1.barrier(sycl::access::fence_space::local_space); + dequantize_q4_K_common(y, x[i].qs, dall, dmin, scales_local, il, ir); #else const int64_t tid = item_ct1.get_local_id(2); const uint8_t * q = x[i].qs; @@ -406,6 +414,36 @@ static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restri #endif } +template +static void dequantize_block_q4_K_reorder(const void * __restrict__ vx, dst_t * __restrict__ yy, uint8_t * scales_local, + const sycl::nd_item<1> & item_ct1, int64_t nb) { + const int64_t i = item_ct1.get_group(0); // block index + const int64_t tid = item_ct1.get_local_id(0); // thread index within block + const int64_t il = tid / 8; + const int64_t ir = tid % 8; + + dst_t * y = yy + i * QK_K + 64 * il + 4 * ir; + + const uint8_t * base = static_cast(vx); + const size_t qs_offset = i * (QK_K / 2); + const size_t scales_offset = nb * (QK_K / 2) + i * K_SCALE_SIZE; + const size_t dm_offset = nb * (QK_K / 2) + nb * K_SCALE_SIZE + i * sizeof(ggml_half2); + + const uint8_t * qs_ptr = base + qs_offset; 
+ const uint8_t * scales_ptr = base + scales_offset; + ggml_half2 dm_values = *reinterpret_cast(base + dm_offset); + + const float dall = dm_values.x(); + const float dmin = dm_values.y(); + + if (tid < 12) { + scales_local[tid] = scales_ptr[tid]; + } + + item_ct1.barrier(sycl::access::fence_space::local_space); + dequantize_q4_K_common(y, qs_ptr, dall, dmin, scales_local, il, ir); +} + template static void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restrict__ yy, const sycl::nd_item<3> &item_ct1) { diff --git a/ggml/src/ggml-sycl/dmmv.cpp b/ggml/src/ggml-sycl/dmmv.cpp index 04a85fa35ff..b58150c687b 100644 --- a/ggml/src/ggml-sycl/dmmv.cpp +++ b/ggml/src/ggml-sycl/dmmv.cpp @@ -1129,7 +1129,13 @@ void ggml_sycl_op_dequantize_mul_mat_vec( dequantize_mul_mat_vec_q3_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); break; case GGML_TYPE_Q4_K: - dequantize_mul_mat_vec_q4_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); + if ((ggml_tensor_extra_gpu *) dst->src[0]->extra && + ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { + // reorder is currently not supported for dmmv + GGML_ABORT("Unimplemented dequantize case case for q4_k reorder"); + } else { + dequantize_mul_mat_vec_q4_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); + } break; case GGML_TYPE_Q5_K: dequantize_mul_mat_vec_q5_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); diff --git a/ggml/src/ggml-sycl/element_wise.cpp b/ggml/src/ggml-sycl/element_wise.cpp index dcc6ec809a7..becaac4048a 100644 --- a/ggml/src/ggml-sycl/element_wise.cpp +++ b/ggml/src/ggml-sycl/element_wise.cpp @@ -655,7 +655,6 @@ inline void ggml_sycl_op_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) } default: GGML_ABORT("GGML tensor type not supported!\n"); - break; } } @@ -688,7 +687,6 @@ inline void ggml_sycl_op_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst) } default: GGML_ABORT("GGML tensor type not supported!\n"); 
- break; } } @@ -722,7 +720,6 @@ inline void ggml_sycl_op_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) } default: GGML_ABORT("GGML tensor type not supported!\n"); - break; } } @@ -754,7 +751,6 @@ inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst } default: GGML_ABORT("GGML tensor type not supported!\n"); - break; } } @@ -786,7 +782,6 @@ inline void ggml_sycl_op_gelu(ggml_backend_sycl_context & ctx, ggml_tensor * dst } default: GGML_ABORT("GGML tensor type not supported!\n"); - break; } } @@ -818,7 +813,6 @@ inline void ggml_sycl_op_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor } default: GGML_ABORT("GGML tensor type not supported!\n"); - break; } } @@ -850,7 +844,6 @@ inline void ggml_sycl_op_tanh(ggml_backend_sycl_context & ctx, ggml_tensor * dst } default: GGML_ABORT("GGML tensor type not supported!\n"); - break; } } @@ -883,7 +876,6 @@ inline void ggml_sycl_op_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst } default: GGML_ABORT("GGML tensor type not supported!\n"); - break; } } @@ -917,7 +909,6 @@ inline void ggml_sycl_op_hardsigmoid(ggml_backend_sycl_context & ctx, ggml_tenso } default: GGML_ABORT("GGML tensor type not supported!\n"); - break; } } @@ -949,7 +940,6 @@ inline void ggml_sycl_op_hardswish(ggml_backend_sycl_context & ctx, ggml_tensor } default: GGML_ABORT("GGML tensor type not supported!\n"); - break; } } @@ -981,7 +971,6 @@ inline void ggml_sycl_op_exp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) } default: GGML_ABORT("GGML tensor type not supported!\n"); - break; } } @@ -1013,7 +1002,6 @@ inline void ggml_sycl_op_log(ggml_backend_sycl_context & ctx, ggml_tensor * dst) } default: GGML_ABORT("GGML tensor type not supported!\n"); - break; } } @@ -1045,7 +1033,6 @@ inline void ggml_sycl_op_sigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * } default: GGML_ABORT("GGML tensor type not supported!\n"); - break; } } @@ -1078,7 +1065,6 @@ inline void 
ggml_sycl_op_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst } default: GGML_ABORT("GGML tensor type not supported!\n"); - break; } } @@ -1110,7 +1096,6 @@ inline void ggml_sycl_op_sin(ggml_backend_sycl_context & ctx, ggml_tensor * dst) } default: GGML_ABORT("GGML tensor type not supported!\n"); - break; } } @@ -1142,7 +1127,6 @@ inline void ggml_sycl_op_cos(ggml_backend_sycl_context & ctx, ggml_tensor * dst) } default: GGML_ABORT("GGML tensor type not supported!\n"); - break; } } @@ -1174,7 +1158,6 @@ inline void ggml_sycl_op_step(ggml_backend_sycl_context & ctx, ggml_tensor * dst } default: GGML_ABORT("GGML tensor type not supported!\n"); - break; } } @@ -1206,7 +1189,6 @@ inline void ggml_sycl_op_neg(ggml_backend_sycl_context & ctx, ggml_tensor * dst) } default: GGML_ABORT("GGML tensor type not supported!\n"); - break; } } @@ -1241,7 +1223,6 @@ inline void ggml_sycl_op_leaky_relu(ggml_backend_sycl_context & ctx, ggml_tensor } default: GGML_ABORT("GGML tensor type not supported!\n"); - break; } } @@ -1273,7 +1254,6 @@ inline void ggml_sycl_op_sqr(ggml_backend_sycl_context & ctx, ggml_tensor * dst) } default: GGML_ABORT("GGML tensor type not supported!\n"); - break; } } @@ -1315,7 +1295,6 @@ inline void ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * } default: GGML_ABORT("GGML tensor type not supported!\n"); - break; } } @@ -1350,7 +1329,6 @@ inline void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) } default: GGML_ABORT("GGML tensor type not supported!\n"); - break; } } @@ -1388,7 +1366,6 @@ inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * ds } default: GGML_ABORT("GGML tensor type not supported!\n"); - break; } } diff --git a/ggml/src/ggml-sycl/gemm.hpp b/ggml/src/ggml-sycl/gemm.hpp index 4ebbb5b66fb..6cbc7e0f693 100644 --- a/ggml/src/ggml-sycl/gemm.hpp +++ b/ggml/src/ggml-sycl/gemm.hpp @@ -32,16 +32,36 @@ class DnnlGemmWrapper { else static_assert(0); } - static inline void 
row_gemm(ggml_backend_sycl_context & ctx, bool a_trans, bool b_trans, int m, int n, int k, - const void * a, dt at, const void * b, dt bt, void * c, dt ct, const queue_ptr & q) { + // matrix A has m rows, k columns + // matrix B has k rows, n columns + // nra - number of elements to skip when moving into next row in A + // nrb - number of elements to skip when moving into next row in B + // nca - number of elements to skip when moving into next column in A + // ncb - number of elements to skip when moving into next column in B + // stride_a - number of elements to skip when moving to next A matrix + // stride_b - number of elements to skip when moving to next B matrix + // batches_a - number of A matrices + // batches_b - number of B matrices + static void gemm(ggml_backend_sycl_context & ctx, int m, int n, int k, + const void * a, dt at, dnnl_dim_t nra, dnnl_dim_t nca, dnnl_dim_t stride_a, + const void * b, dt bt, dnnl_dim_t nrb, dnnl_dim_t ncb, dnnl_dim_t stride_b, + void * c, dt ct, const queue_ptr & q, dnnl_dim_t batches_a, dnnl_dim_t batches_b) { + auto stream = ctx.stream_dnnl(q); auto eng = ctx.engine_dnnl(q); - dnnl::memory::dims a_dims = { m, k }; - dnnl::memory::dims b_dims = { k, n }; - dnnl::memory::dims c_dims = { m, n }; - const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab); - const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? 
tag::ba : tag::ab); - const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab); + + // { # strides, # rows, # columns } + dnnl::memory::dims a_dims = { batches_a, m, k }; + dnnl::memory::dims b_dims = { batches_b, k, n }; + dnnl::memory::dims c_dims = { std::max(batches_a, batches_b), m, n }; + + // { # elements to skip to next stride, # elements to skip to next row, # elements to skip to next column } + dnnl::memory::dims a_strides = { stride_a, nra, nca }; + dnnl::memory::dims b_strides = { stride_b, nrb, ncb }; + + const auto a_in_md = dnnl::memory::desc(a_dims, at, a_strides); + const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_strides); + const auto c_md = dnnl::memory::desc(c_dims, ct, tag::abc); dnnl::primitive_attr primitive_attr; primitive_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); @@ -63,6 +83,15 @@ class DnnlGemmWrapper { matmul_prim.execute(stream, matmul_args); } + + // matrices A and B are column major, both having k rows + // matrix A has m columns, matrix B has n columns + // output: column major matrix C = A transposed * B + static void row_gemm(ggml_backend_sycl_context & ctx, int m, int n, int k, + const void * a, dt at, const void * b, dt bt, void * c, dt ct, const queue_ptr & q) { + + gemm(ctx, m, n, k, a, at, k, 1, k * m, b, bt, 1, k, n * k, c, ct, q, 1, 1); + } }; #endif diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 0ea729948ec..5ff7fa13db0 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -49,6 +49,7 @@ static bool g_sycl_loaded = false; int g_ggml_sycl_debug = 0; int g_ggml_sycl_disable_optimize = 0; int g_ggml_sycl_disable_graph = 0; +int g_ggml_sycl_disable_dnn = 0; int g_ggml_sycl_prioritize_dmmv = 0; static ggml_sycl_device_info ggml_sycl_init() { @@ -196,12 +197,22 @@ static void ggml_check_sycl() try { g_ggml_sycl_debug = get_sycl_env("GGML_SYCL_DEBUG", 0); g_ggml_sycl_disable_optimize= get_sycl_env("GGML_SYCL_DISABLE_OPT", 1);
g_ggml_sycl_disable_graph = get_sycl_env("GGML_SYCL_DISABLE_GRAPH", 1); + g_ggml_sycl_disable_dnn = get_sycl_env("GGML_SYCL_DISABLE_DNN", 0); g_ggml_sycl_prioritize_dmmv = get_sycl_env("GGML_SYCL_PRIORITIZE_DMMV", 0); GGML_SYCL_DEBUG("[SYCL] call ggml_check_sycl\n"); GGML_LOG_INFO("Running with Environment Variables:\n"); GGML_LOG_INFO(" GGML_SYCL_DEBUG: %d\n", g_ggml_sycl_debug); GGML_LOG_INFO(" GGML_SYCL_DISABLE_OPT: %d\n", g_ggml_sycl_disable_optimize); +#ifdef GGML_SYCL_GRAPH GGML_LOG_INFO(" GGML_SYCL_DISABLE_GRAPH: %d\n", g_ggml_sycl_disable_graph); +#else + GGML_LOG_INFO(" GGML_SYCL_DISABLE_GRAPH: graph disabled by compile flag\n"); +#endif +#if GGML_SYCL_DNNL + GGML_LOG_INFO(" GGML_SYCL_DISABLE_DNN: %d\n", g_ggml_sycl_disable_dnn); +#else + GGML_LOG_INFO(" GGML_SYCL_DISABLE_DNN: DNN disabled by compile flag\n"); +#endif GGML_LOG_INFO(" GGML_SYCL_PRIORITIZE_DMMV: %d\n", g_ggml_sycl_prioritize_dmmv); GGML_LOG_INFO("Build with Macros:\n"); #if defined(GGML_SYCL_FORCE_MMQ) @@ -341,7 +352,7 @@ ggml_backend_sycl_buffer_init_tensor(ggml_backend_buffer_t buffer, assert(tensor->view_src->buffer->buft == buffer->buft); return GGML_STATUS_SUCCESS; } - if (tensor->type == GGML_TYPE_Q4_0 && !g_ggml_sycl_disable_optimize) { + if ((tensor->type == GGML_TYPE_Q4_0 || tensor->type == GGML_TYPE_Q4_K) && !g_ggml_sycl_disable_optimize) { ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{}; tensor->extra = extra; ctx->tensor_extras.push_back(extra); //used to release it when destroy ctx. 
@@ -1985,19 +1996,18 @@ inline void ggml_sycl_op_mul_mat_sycl( const int64_t ne00 = src0->ne[0]; const int64_t ne10 = src1->ne[0]; - + GGML_ASSERT(ne00 == ne10); const int64_t row_diff = row_high - row_low; int id; SYCL_CHECK( CHECK_TRY_ERROR(id = get_current_device_id())); -#if !GGML_SYCL_DNNL - const int64_t ne0 = dst->ne[0]; + + const int64_t ne0 = dst->ne[0]; // used by MKL only // the main device has a larger memory buffer to hold the results from all GPUs // ldc == nrows of the matrix that cuBLAS writes into - int ldc = id == ctx.device ? ne0 : row_diff; -#endif + int ldc = id == ctx.device ? ne0 : row_diff; // used by MKL only #ifdef GGML_SYCL_F16 bool use_fp16 = true; // TODO(Yu) SYCL capability check @@ -2033,25 +2043,29 @@ inline void ggml_sycl_op_mul_mat_sycl( : src1_as_f16.get(); ggml_sycl_pool_alloc dst_f16(ctx.pool(), row_diff * src1_ncols); -#if !GGML_SYCL_DNNL - const sycl::half alpha_f16 = 1.0f; - const sycl::half beta_f16 = 0.0f; - SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm( - *stream, oneapi::math::transpose::trans, - oneapi::math::transpose::nontrans, row_diff, src1_ncols, ne10, - &alpha_f16, src0_ptr, dpct::library_data_t::real_half, ne00, - src1_ptr, dpct::library_data_t::real_half, ne10, &beta_f16, - dst_f16.get(), dpct::library_data_t::real_half, ldc, - dpct::library_data_t::real_half))); - const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst); - to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream); -#else - DnnlGemmWrapper::row_gemm(ctx, false, true, src1_ncols, row_diff, ne10, src1_ptr, - DnnlGemmWrapper::to_dt(), src0_ptr, DnnlGemmWrapper::to_dt(), - dst_f16.get(), DnnlGemmWrapper::to_dt(), stream); - const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst); - to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream); +#if GGML_SYCL_DNNL + if (!g_ggml_sycl_disable_dnn) { + DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ptr, + DnnlGemmWrapper::to_dt(), 
src0_ptr, DnnlGemmWrapper::to_dt(), + dst_f16.get(), DnnlGemmWrapper::to_dt(), stream); + const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst); + to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream); + } + else #endif + { + const sycl::half alpha_f16 = 1.0f; + const sycl::half beta_f16 = 0.0f; + SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm( + *stream, oneapi::math::transpose::trans, + oneapi::math::transpose::nontrans, row_diff, src1_ncols, ne10, + &alpha_f16, src0_ptr, dpct::library_data_t::real_half, ne00, + src1_ptr, dpct::library_data_t::real_half, ne10, &beta_f16, + dst_f16.get(), dpct::library_data_t::real_half, ldc, + dpct::library_data_t::real_half))); + const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst); + to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream); + } } else { // GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat_sycl - fp32 path\n"); @@ -2072,18 +2086,22 @@ inline void ggml_sycl_op_mul_mat_sycl( const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32.get(); const float * src1_ddf1_i = src1->type == GGML_TYPE_F32 ? 
(const float *) src1_ddf_i : src1_ddq_as_f32.get(); -#if !GGML_SYCL_DNNL - const float alpha = 1.0f; - const float beta = 0.0f; - SYCL_CHECK(CHECK_TRY_ERROR(oneapi::math::blas::column_major::gemm( - get_onemath_backend(*stream), oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, row_diff, - src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10, - dpct::get_value(&beta, *stream), dst_dd_i, ldc))); -#else - DnnlGemmWrapper::row_gemm(ctx, false, true, src1_ncols, row_diff, ne10, src1_ddf1_i, - DnnlGemmWrapper::to_dt(), src0_ddf_i, DnnlGemmWrapper::to_dt(), - dst_dd_i, DnnlGemmWrapper::to_dt(), stream); +#if GGML_SYCL_DNNL + if (!g_ggml_sycl_disable_dnn) { + DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ddf1_i, + DnnlGemmWrapper::to_dt(), src0_ddf_i, DnnlGemmWrapper::to_dt(), + dst_dd_i, DnnlGemmWrapper::to_dt(), stream); + } + else #endif + { + const float alpha = 1.0f; + const float beta = 0.0f; + SYCL_CHECK(CHECK_TRY_ERROR(oneapi::math::blas::column_major::gemm( + get_onemath_backend(*stream), oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, row_diff, + src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10, + dpct::get_value(&beta, *stream), dst_dd_i, ldc))); + } } GGML_UNUSED(dst); GGML_UNUSED(src1_ddq_i); @@ -2697,7 +2715,7 @@ catch (sycl::exception const &exc) { std::exit(1); } -static void k_compute_batched_ptrs(const sycl::half * src0_as_f16, const sycl::half * src1_as_f16, char * dst, +static void k_compute_batched_ptrs(const sycl::half * src0_as_f16, const sycl::half * src1_as_f16, void * dst, const void ** ptrs_src, void ** ptrs_dst, int64_t ne12, int64_t ne13, int64_t ne23, size_t nb02, size_t nb03, size_t nb12, size_t nb13, size_t nbd2, size_t nbd3, int64_t r2, int64_t r3, const sycl::nd_item<3> & item_ct1) { @@ -2713,7 +2731,7 @@ static void k_compute_batched_ptrs(const sycl::half * src0_as_f16, const sycl::h const uint8_t * src0_bytes = 
reinterpret_cast(src0_as_f16); const uint8_t * src1_bytes = reinterpret_cast(src1_as_f16); - uint8_t * dst_bytes = reinterpret_cast(dst); + uint8_t * dst_bytes = static_cast(dst); ptrs_src[0 * ne23 + i12 + i13 * ne12] = src0_bytes + i02 * nb02 + i03 * nb03; ptrs_src[1 * ne23 + i12 + i13 * ne12] = src1_bytes + i12 * nb12 + i13 * nb13; @@ -2726,6 +2744,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons GGML_ASSERT(!ggml_is_transposed(src1)); GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer)); GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32); GGML_TENSOR_BINARY_OP_LOCALS @@ -2766,7 +2785,6 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons } ggml_sycl_pool_alloc dst_f16(ctx.pool()); - char * dst_t = reinterpret_cast(dst_ddf); dpct::library_data_t mkl_compute_type = dpct::library_data_t::real_float; dpct::library_data_t mkl_data_type = dpct::library_data_t::real_float; @@ -2783,42 +2801,83 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons GGML_ASSERT(ne12 % ne02 == 0); GGML_ASSERT(ne13 % ne03 == 0); + GGML_ASSERT(ne01 == static_cast(nb1/nb0)); + GGML_ASSERT(ne10 == ne00); // broadcast factors const int64_t r2 = ne12 / ne02; const int64_t r3 = ne13 / ne03; - if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) { - // there is no broadcast and src0, src1 are contiguous across dims 2, 3 - SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(*queue, oneapi::math::transpose::trans, - oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha, - src0_f16, dpct::library_data_t::real_half, nb01 / nb00, nb02 / nb00, - src1_f16, dpct::library_data_t::real_half, s11, s12, beta, dst_t, - mkl_data_type, ne0, ne1 * ne0, ne12 * ne13, mkl_compute_type))); - } else { - const int ne23 = ne12 * ne13; - - ggml_sycl_pool_alloc ptrs_src(ctx.pool(), 2 * ne23); - ggml_sycl_pool_alloc ptrs_dst(ctx.pool(), 1 * ne23); - 
ggml_sycl_pool_alloc> matrix_info(ctx.host_pool(), 1); - - sycl::range<3> block_dims(1, ne12, ne13); - queue->submit([&](sycl::handler & cgh) { - const void ** ptrs_src_get = ptrs_src.get(); - void ** ptrs_dst_get = ptrs_dst.get(); - size_t nb12_scaled = src1->type == GGML_TYPE_F16 ? nb12 : s12 * sizeof(sycl::half); - size_t nb13_scaled = src1->type == GGML_TYPE_F16 ? nb13 : s13 * sizeof(sycl::half); - cgh.parallel_for(sycl::nd_range<3>(block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { - k_compute_batched_ptrs(src0_f16, src1_f16, dst_t, ptrs_src_get, ptrs_dst_get, ne12, ne13, ne23, nb02, - nb03, nb12_scaled, nb13_scaled, nbd2, nbd3, r2, r3, item_ct1); +#if GGML_SYCL_DNNL + if (!g_ggml_sycl_disable_dnn) { + auto dnn_gemm = [&ctx, queue, ne11, ne01, ne10, nb00, nb01, nb02, s11, s12] + (const sycl::half* src1, const sycl::half* src0, float* dst, const dnnl_dim_t batches_a, const dnnl_dim_t batches_b) { + + DnnlGemmWrapper::gemm(ctx, ne11,ne01, ne10, + src1, DnnlGemmWrapper::to_dt(), s11, 1, s12, + src0, DnnlGemmWrapper::to_dt(), 1, nb01/nb00, nb02/nb00, + dst, DnnlGemmWrapper::to_dt(), queue, batches_a, batches_b); + }; + + if (r2 == 1 && r3 == 1) { + if (ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) { + dnn_gemm(src1_f16, src0_f16, dst_ddf, ne12*ne13, ne02 * ne03); + } + else { + for (int64_t ie03 = 0; ie03 < ne03; ++ie03) { + const sycl::half* src0_f16_shifted = src0_f16 + ((ie03*nb03)/sizeof(sycl::half)); // nb is in bytes + const sycl::half* src1_f16_shifted = src1_f16 + ie03*s13; + float* dst_shifted = dst_ddf + ((ie03*nb3)/sizeof(float)); + dnn_gemm(src1_f16_shifted, src0_f16_shifted, dst_shifted, ne12, ne02); + } + } + } else { + // iterate over batches from smaller set of matrices (matrix 0) + for (int64_t ie02 = 0; ie02 < ne02; ++ie02) { + for (int64_t ie03 = 0; ie03 < ne03; ++ie03) { + const sycl::half* src0_f16_shifted = src0_f16 + ((ie02*nb02 + ie03*nb03)/sizeof(sycl::half)); + const sycl::half* src1_f16_shifted = src1_f16 + ie02*s12*r2 
+ ie03*s13*r3; + float* dst_shifted = dst_ddf + ((ie02*nb2*r2 + ie03*nb3*r3)/sizeof(float)); + dnn_gemm(src1_f16_shifted, src0_f16_shifted, dst_shifted, r2*r3, 1); + } + } + } + } + else +#endif + { + if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) { + // there is no broadcast and src0, src1 are contiguous across dims 2, 3 + SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(*queue, oneapi::math::transpose::trans, + oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha, + src0_f16, dpct::library_data_t::real_half, nb01 / nb00, nb02 / nb00, + src1_f16, dpct::library_data_t::real_half, s11, s12, beta, dst_ddf, + mkl_data_type, ne0, ne1 * ne0, ne12 * ne13, mkl_compute_type))); + } else { + const int ne23 = ne12 * ne13; + + ggml_sycl_pool_alloc ptrs_src(ctx.pool(), 2 * ne23); + ggml_sycl_pool_alloc ptrs_dst(ctx.pool(), 1 * ne23); + ggml_sycl_pool_alloc> matrix_info(ctx.host_pool(), 1); + + sycl::range<3> block_dims(1, ne12, ne13); + queue->submit([&](sycl::handler & cgh) { + const void ** ptrs_src_get = ptrs_src.get(); + void ** ptrs_dst_get = ptrs_dst.get(); + size_t nb12_scaled = src1->type == GGML_TYPE_F16 ? nb12 : s12 * sizeof(sycl::half); + size_t nb13_scaled = src1->type == GGML_TYPE_F16 ? 
nb13 : s13 * sizeof(sycl::half); + cgh.parallel_for(sycl::nd_range<3>(block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { + k_compute_batched_ptrs(src0_f16, src1_f16, dst_ddf, ptrs_src_get, ptrs_dst_get, ne12, ne13, ne23, nb02, + nb03, nb12_scaled, nb13_scaled, nbd2, nbd3, r2, r3, item_ct1); + }); }); - }); - SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch( - *queue, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha, - (const void **) (ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00, - (const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, s11, beta, - (void **) (ptrs_dst.get() + 0 * ne23), mkl_data_type, ne0, ne23, mkl_compute_type, matrix_info.get()))); + SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch( + *queue, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha, + (const void **) (ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00, + (const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, s11, beta, + (void **) (ptrs_dst.get() + 0 * ne23), mkl_data_type, ne0, ne23, mkl_compute_type, matrix_info.get()))); + } } } catch (const sycl::exception & exc) { std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl; @@ -2841,6 +2900,8 @@ inline bool ggml_sycl_supports_reorder_mul_mat_sycl(enum ggml_type type) { switch (type) { case GGML_TYPE_Q4_0: return true; + case GGML_TYPE_Q4_K: + return !g_ggml_sycl_prioritize_dmmv; default: return false; } @@ -2858,6 +2919,7 @@ inline bool ggml_sycl_supports_reorder_dmmv(enum ggml_type type) { inline bool ggml_sycl_supports_reorder_mmvq(enum ggml_type type) { switch (type) { case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_K: return true; default: return false; @@ -2883,16 +2945,16 @@ static bool ggml_sycl_supports_dmmv(enum ggml_type type) { } } -static void reorder_qw(char *data_device, const int ncols, const int nrows, - 
size_t size, size_t offset, dpct::queue_ptr stream) { - auto tmp_buf = sycl::malloc_shared(size, *stream); +static void reorder_qw_q4_0(uint8_t * data_device, const int ncols, const int nrows, size_t size, size_t offset, + dpct::queue_ptr stream) { + auto * tmp_buf = sycl::malloc_shared(size, *stream); SYCL_CHECK( CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size) .wait())); GGML_ASSERT((size % sizeof(block_q4_0) == 0)); GGML_ASSERT((offset % sizeof(block_q4_0) == 0)); int offset_blks = offset / sizeof(block_q4_0); - auto qs_ptr = (uint8_t*)data_device + offset_blks * QK4_0 / 2; + auto qs_ptr = data_device + offset_blks * QK4_0 / 2; auto d_ptr = (sycl::half*)(qs_ptr + ncols * nrows / 2) + offset_blks; stream->parallel_for( @@ -2906,18 +2968,59 @@ static void reorder_qw(char *data_device, const int ncols, const int nrows, *(qs_ptr + ib * QK4_0 / 2 + j) = x[ib].qs[j]; } *(d_ptr + ib) = x[ib].d; - }); + }).wait_and_throw(); + + sycl::free(tmp_buf, *stream); +} + +static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) { + GGML_ASSERT(size % sizeof(block_q4_K) == 0); + GGML_ASSERT(offset % sizeof(block_q4_K) == 0); + + const int nblocks = size / sizeof(block_q4_K); + + auto * tmp_buf = sycl::malloc_shared(size, *stream); + SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size).wait())); + + auto * qs_ptr = data_device; + auto * scales_ptr = qs_ptr + QK_K / 2 * nblocks; + auto * dm_ptr = (sycl::half2 *) (scales_ptr + K_SCALE_SIZE * nblocks); + + stream->parallel_for(nblocks, [=](auto i) { + const block_q4_K * x = (const block_q4_K *) tmp_buf; + const int ib = i; + + for (int j = 0; j < QK_K / 2; ++j) { + qs_ptr[ib * (QK_K / 2) + j] = x[ib].qs[j]; + } + + for (int j = 0; j < K_SCALE_SIZE; ++j) { + scales_ptr[ib * K_SCALE_SIZE + j] = x[ib].scales[j]; + } + + dm_ptr[ib] = x[ib].dm; + }).wait_and_throw(); sycl::free(tmp_buf, *stream); } static void reorder_qw(const ggml_tensor * src0, dpct::queue_ptr 
stream) { - char*data_device = (char*)src0->data; + uint8_t * data_device = (uint8_t *) src0->data; size_t ncols = src0->ne[0]; size_t nrows = src0->ne[1]; size_t size = ggml_nbytes(src0); - reorder_qw(data_device, ncols, nrows, size, 0, stream); + switch (src0->type) { + case GGML_TYPE_Q4_0: + reorder_qw_q4_0(data_device, ncols, nrows, size, 0, stream); + break; + case GGML_TYPE_Q4_K: + reorder_qw_q4_k(data_device, size, 0, stream); + break; + default: + GGML_ABORT("reorder_qw() called with unsupported type"); + break; + } } static bool should_reorder_tensor(ggml_backend_sycl_context& ctx, const ggml_tensor * dst) { @@ -2960,8 +3063,18 @@ static void opt_for_reorder(ggml_backend_sycl_context * ctx, const ggml_tensor * extra->optimized_feature.reorder = true; // Used to decode/dequan in next steps and avoid re-reordering } -static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static bool can_use_dequantize_mul_mat_vec(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + return ggml_sycl_supports_dmmv(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && + src0->ne[0] % GGML_SYCL_DMMV_X == 0 && src1->ne[1] == 1; +} + +static bool can_use_mul_mat_vec_q(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + return ggml_is_quantized(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && + src1->ne[1] <= MMVQ_MAX_BATCH_SIZE; +} + +static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { const bool split = ggml_backend_buffer_is_sycl_split(src0->buffer); int64_t min_compute_capability = INT_MAX; @@ -2984,13 +3097,9 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor } // check data types and tensor shapes for custom matrix multiplication kernels: - bool use_dequantize_mul_mat_vec = 
ggml_sycl_supports_dmmv(src0->type) - && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 - && src0->ne[0] % GGML_SYCL_DMMV_X == 0 && src1->ne[1] == 1; + bool use_dequantize_mul_mat_vec = can_use_dequantize_mul_mat_vec(src0, src1, dst); - bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) - && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 - && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE; + bool use_mul_mat_vec_q = can_use_mul_mat_vec_q(src0, src1, dst); bool use_mul_mat_q = ggml_sycl_supports_mmq(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32; @@ -3713,7 +3822,8 @@ static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_ return GGML_STATUS_SUCCESS; } - sycl_ex::command_graph model_sycl_graph(*(sycl_ctx->stream())); + sycl_ex::command_graph model_sycl_graph(*(sycl_ctx->stream()), {sycl_ex::property::graph::assume_buffer_outlives_graph{}}); + model_sycl_graph.begin_recording(*(sycl_ctx->stream())); ggml_backend_sycl_graph_compute_impl(sycl_ctx, cgraph); model_sycl_graph.end_recording(); diff --git a/ggml/src/ggml-sycl/mmvq.cpp b/ggml/src/ggml-sycl/mmvq.cpp index 3cade1a42a6..23eeb74da0d 100644 --- a/ggml/src/ggml-sycl/mmvq.cpp +++ b/ggml/src/ggml-sycl/mmvq.cpp @@ -24,6 +24,7 @@ static void mul_mat_vec_q_reorder(const void * __restrict__ vx, const void * __r const int blocks_per_row = ncols / block_traits::qk; constexpr int blocks_per_subgroup = ceil_div(block_traits::vdr_mmvq * WARP_SIZE, block_traits::qi); constexpr int block_elements_per_subgroup = block_traits::qi / block_traits::vdr_mmvq; + const int nblocks = nrows * (ncols / block_traits::qk); static_assert(blocks_per_subgroup > 0); static_assert(block_elements_per_subgroup > 0); @@ -45,7 +46,7 @@ static void mul_mat_vec_q_reorder(const void * __restrict__ vx, const void * __r // x block quant index when casting the quants to int const int iqs = elem + block_traits::vdr_mmvq * (sg.get_local_linear_id() % block_elements_per_subgroup); - 
partial_sum += reorder_vec_dot_q_sycl()(vx, bx_offset, d_offset, &y[iby], iqs); + partial_sum += reorder_vec_dot_q_sycl()(vx, bx_offset, d_offset, &y[iby], iqs, nblocks); } } @@ -739,6 +740,27 @@ static void mul_mat_vec_q4_K_q8_1_sycl(const void *vx, const void *vy, } } +static void reorder_mul_mat_vec_q4_k_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, + const int nrows, dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + + const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y); + constexpr size_t num_subgroups = 16; + GGML_ASSERT(block_num_y % num_subgroups == 0); + + const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE); + const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE); + + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size), + [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_reorder>(vx, vy, dst, ncols, + nrows, nd_item); + }); + }); +} + + static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy, float *dst, const int ncols, const int nrows, @@ -1035,7 +1057,12 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens mul_mat_vec_q3_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); break; case GGML_TYPE_Q4_K: - mul_mat_vec_q4_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + if ((ggml_tensor_extra_gpu *) dst->src[0]->extra && + ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { + reorder_mul_mat_vec_q4_k_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } else { + mul_mat_vec_q4_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } break; case GGML_TYPE_Q5_K: mul_mat_vec_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); diff --git a/ggml/src/ggml-sycl/quants.hpp 
b/ggml/src/ggml-sycl/quants.hpp index a74e30526c1..88ec13ea269 100644 --- a/ggml/src/ggml-sycl/quants.hpp +++ b/ggml/src/ggml-sycl/quants.hpp @@ -56,6 +56,28 @@ template <> struct block_q_t { static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; } }; +template <> struct block_q_t { + struct traits { + static constexpr uint32_t qk = QK_K; + static constexpr uint32_t qi = QI4_K; + static constexpr uint32_t qr = QR4_K; + static constexpr uint32_t vdr_mmvq = 2; + }; + + static constexpr int get_block_offset(const int block_index) { return block_index * (traits::qk / traits::qr); } + + static constexpr int get_d_offset(int nrows, int ncols, const int block_index) { + auto nblocks = (nrows * (ncols / traits::qk)); + return (nblocks * QK_K / 2) + (nblocks * K_SCALE_SIZE) + (block_index * sizeof(ggml_half2)); + } + + static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; } + + constexpr size_t get_total_qs_bytes(int nblocks) { return nblocks * QK_K / 2; } + + constexpr size_t get_dm_offset(int nblocks) { return get_total_qs_bytes(nblocks) + nblocks * K_SCALE_SIZE; } +}; + } // namespace ggml_sycl_reordered #endif // GGML_SYCL_QUANTS_HPP diff --git a/ggml/src/ggml-sycl/vecdotq.hpp b/ggml/src/ggml-sycl/vecdotq.hpp index cbf664fcf28..ed369931346 100644 --- a/ggml/src/ggml-sycl/vecdotq.hpp +++ b/ggml/src/ggml-sycl/vecdotq.hpp @@ -285,7 +285,7 @@ template <> struct reorder_vec_dot_q_sycl { } __dpct_inline__ float operator()(const void * __restrict__ vbq, const int ibx_offset, const int d_offset, - const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + const block_q8_1 * __restrict__ bq8_1, const int & iqs, int /* nblocks */) { const uint8_t * bq4_0 = static_cast(vbq) + ibx_offset; const ggml_half d = *(reinterpret_cast(static_cast(vbq) + d_offset)); int v[q4_0_traits::vdr_mmvq]; @@ -303,6 +303,67 @@ template <> struct reorder_vec_dot_q_sycl { }; }; +static inline float vec_dot_q4_K_q8_1_common(const int * __restrict__ q4, const uint16_t * 
__restrict__ scales, + const ggml_half2 & dm, const block_q8_1 * __restrict__ bq8_1, + const int & iqs) { + int v[2]; + int u[2 * QR4_K]; + float d8[QR4_K]; + + v[0] = q4[0]; + v[1] = q4[4]; + + uint16_t aux[2]; + const int j = (QR4_K * ((iqs / 2) / (QI8_1 / 2))) / 2; + if (j < 2) { + aux[0] = scales[j + 0] & 0x3f3f; + aux[1] = scales[j + 2] & 0x3f3f; + } else { + aux[0] = ((scales[j + 2] >> 0) & 0x0f0f) | ((scales[j - 2] & 0xc0c0) >> 2); + aux[1] = ((scales[j + 2] >> 4) & 0x0f0f) | ((scales[j - 0] & 0xc0c0) >> 2); + } + + const uint8_t * sc = (const uint8_t *) aux; + const uint8_t * m = sc + 2; + + const int bq8_offset = QR4_K * ((iqs / 2) / (QI8_1 / 2)); + + for (int i = 0; i < QR4_K; ++i) { + const block_q8_1 * bq8i = bq8_1 + bq8_offset + i; + d8[i] = bq8i->ds[0]; + + const int * q8 = (const int *) bq8i->qs + ((iqs / 2) % 4); + u[2 * i + 0] = q8[0]; + u[2 * i + 1] = q8[4]; + } + + return vec_dot_q4_K_q8_1_impl_vmmq(v, u, sc, m, dm, d8); +} + +template <> struct reorder_vec_dot_q_sycl { + static constexpr ggml_type gtype = GGML_TYPE_Q4_K; + + using q4_k_block = ggml_sycl_reordered::block_q_t; + using q4_k_traits = typename q4_k_block::traits; + + float operator()(const void * __restrict__ vbq, const int ibx_offset, const int d_offset, + const block_q8_1 * __restrict__ bq8_1, const int & iqs, int nblocks) { + const int ib = ibx_offset / (QK_K / 2); + + const uint8_t * base = static_cast(vbq); + const uint8_t * qs = base + ibx_offset; + const int total_qs_bytes = nblocks * (QK_K / 2); + const uint8_t * scs = base + total_qs_bytes + ib * K_SCALE_SIZE; + const ggml_half2 * dms = reinterpret_cast(base + d_offset); + + const int bq8_offset = QR4_K * ((iqs / 2) / (QI8_1 / 2)); + const int * q4 = (const int *) (qs + 16 * bq8_offset + 4 * ((iqs / 2) % 4)); + const uint16_t * scales = (const uint16_t *) scs; + + return vec_dot_q4_K_q8_1_common(q4, scales, *dms, bq8_1, iqs); + } +}; + #define VDR_Q4_0_Q8_1_MMVQ 2 #define VDR_Q4_0_Q8_1_MMQ 4 @@ -649,52 +710,17 @@ 
vec_dot_q3_K_q8_1(const void *__restrict__ vbq, return vec_dot_q3_K_q8_1_impl_mmvq(vl, vh, u, bq3_K->scales, scale_offset, d, d8); } -static __dpct_inline__ float -vec_dot_q4_K_q8_1(const void *__restrict__ vbq, - const block_q8_1 *__restrict__ bq8_1, const int &iqs) { - +static __dpct_inline__ float vec_dot_q4_K_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, + const int & iqs) { #ifndef GGML_QKK_64 - const block_q4_K * bq4_K = (const block_q4_K *) vbq; - - int v[2]; - int u[2*QR4_K]; - float d8[QR4_K]; - // iqs is in 0,2..30. bq8_offset = iqs/4 -> bq8_offset = 0, 2, 4, 6 - const int bq8_offset = QR4_K * ((iqs/2) / (QI8_1/2)); - - // iqs = 0....3 -> bq8_offset = 0, want q4_offset = 0, 4, 8, 12 - // iqs = 4....7 -> bq8_offset = 2, want q4_offset = 32, 36, 40, 44 - // iqs = 8...11 -> bq8_offset = 4, want q4_offset = 64, 68, 72, 76 - // iqs = 12..15 -> bq8_offset = 6, want q4_offset = 96, 100, 104, 108 - - const int * q4 = (const int *)(bq4_K->qs + 16 * bq8_offset + 4 * ((iqs/2)%4)); - v[0] = q4[0]; - v[1] = q4[4]; - - const uint16_t * scales = (const uint16_t *)bq4_K->scales; - uint16_t aux[2]; - const int j = bq8_offset/2; - if (j < 2) { - aux[0] = scales[j+0] & 0x3f3f; - aux[1] = scales[j+2] & 0x3f3f; - } else { - aux[0] = ((scales[j+2] >> 0) & 0x0f0f) | ((scales[j-2] & 0xc0c0) >> 2); - aux[1] = ((scales[j+2] >> 4) & 0x0f0f) | ((scales[j-0] & 0xc0c0) >> 2); - } - const uint8_t * sc = (const uint8_t *)aux; - const uint8_t * m = sc + 2; - - for (int i = 0; i < QR4_K; ++i) { - const block_q8_1 * bq8i = bq8_1 + bq8_offset + i; - d8[i] = bq8i->ds[0]; + const block_q4_K * bq4_K = (const block_q4_K *) vbq; - const int * q8 = (const int *)bq8i->qs + ((iqs/2)%4); - u[2*i+0] = q8[0]; - u[2*i+1] = q8[4]; - } + const int bq8_offset = QR4_K * ((iqs / 2) / (QI8_1 / 2)); + const int * q4 = (const int *) (bq4_K->qs + 16 * bq8_offset + 4 * ((iqs / 2) % 4)); + const uint16_t * scales = (const uint16_t *) bq4_K->scales; - return 
vec_dot_q4_K_q8_1_impl_vmmq(v, u, sc, m, bq4_K->dm, d8); + return vec_dot_q4_K_q8_1_common(q4, scales, bq4_K->dm, bq8_1, iqs); #else diff --git a/ggml/src/ggml-vulkan/CMakeLists.txt b/ggml/src/ggml-vulkan/CMakeLists.txt index 31816219c06..662f1377107 100644 --- a/ggml/src/ggml-vulkan/CMakeLists.txt +++ b/ggml/src/ggml-vulkan/CMakeLists.txt @@ -15,6 +15,32 @@ function(detect_host_compiler) set(HOST_CXX_COMPILER "${HOST_CXX_COMPILER}" PARENT_SCOPE) endfunction() +# Function to test shader extension support +# Parameters: +# EXTENSION_NAME - Name of the extension to test (e.g., "GL_EXT_integer_dot_product") +# TEST_SHADER_FILE - Path to the test shader file +# RESULT_VARIABLE - Name of the variable to set (ON/OFF) based on test result +function(test_shader_extension_support EXTENSION_NAME TEST_SHADER_FILE RESULT_VARIABLE) + execute_process( + COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${TEST_SHADER_FILE}" + OUTPUT_VARIABLE glslc_output + ERROR_VARIABLE glslc_error + ) + + if (${glslc_error} MATCHES ".*extension not supported: ${EXTENSION_NAME}.*") + message(STATUS "${EXTENSION_NAME} not supported by glslc") + set(${RESULT_VARIABLE} OFF PARENT_SCOPE) + else() + message(STATUS "${EXTENSION_NAME} supported by glslc") + set(${RESULT_VARIABLE} ON PARENT_SCOPE) + add_compile_definitions(${RESULT_VARIABLE}) + + # Ensure the extension support is forwarded to vulkan-shaders-gen + list(APPEND VULKAN_SHADER_GEN_CMAKE_ARGS -D${RESULT_VARIABLE}=ON) + set(VULKAN_SHADER_GEN_CMAKE_ARGS "${VULKAN_SHADER_GEN_CMAKE_ARGS}" PARENT_SCOPE) + endif() +endfunction() + if (Vulkan_FOUND) message(STATUS "Vulkan found") @@ -23,69 +49,40 @@ if (Vulkan_FOUND) ../../include/ggml-vulkan.h ) - # Compile a test shader to determine whether GL_KHR_cooperative_matrix is supported. - # If it's not, there will be an error to stderr. 
- # If it's supported, set a define to indicate that we should compile those shaders - execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_coopmat_support.comp" - OUTPUT_VARIABLE glslc_output - ERROR_VARIABLE glslc_error) - - if (${glslc_error} MATCHES ".*extension not supported: GL_KHR_cooperative_matrix.*") - message(STATUS "GL_KHR_cooperative_matrix not supported by glslc") - set(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT OFF) - else() - message(STATUS "GL_KHR_cooperative_matrix supported by glslc") - set(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT ON) - add_compile_definitions(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) - endif() - - # Compile a test shader to determine whether GL_NV_cooperative_matrix2 is supported. - # If it's not, there will be an error to stderr. - # If it's supported, set a define to indicate that we should compile those shaders - execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_coopmat2_support.comp" - OUTPUT_VARIABLE glslc_output - ERROR_VARIABLE glslc_error) + set(VULKAN_SHADER_GEN_CMAKE_ARGS + -DCMAKE_INSTALL_PREFIX=${CMAKE_BINARY_DIR} + -DCMAKE_RUNTIME_OUTPUT_DIRECTORY=${CMAKE_RUNTIME_OUTPUT_DIRECTORY} + ) - if (${glslc_error} MATCHES ".*extension not supported: GL_NV_cooperative_matrix2.*") - message(STATUS "GL_NV_cooperative_matrix2 not supported by glslc") - set(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT OFF) - else() - message(STATUS "GL_NV_cooperative_matrix2 supported by glslc") - set(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT ON) - add_compile_definitions(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + set(VULKAN_SHADER_GEN_CMAKE_BUILD_ARGS "") + if (CMAKE_BUILD_TYPE AND CMAKE_BUILD_TYPE MATCHES "Debug|Release|MinSizeRel|RelWithDebInfo") + list(APPEND VULKAN_SHADER_GEN_CMAKE_BUILD_ARGS --config=${CMAKE_BUILD_TYPE}) endif() - # Compile a test shader to determine whether 
GL_EXT_integer_dot_product is supported. - # If it's not, there will be an error to stderr. - # If it's supported, set a define to indicate that we should compile those shaders - execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_integer_dot_support.comp" - OUTPUT_VARIABLE glslc_output - ERROR_VARIABLE glslc_error) + # Test all shader extensions + test_shader_extension_support( + "GL_KHR_cooperative_matrix" + "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_coopmat_support.comp" + "GGML_VULKAN_COOPMAT_GLSLC_SUPPORT" + ) - if (${glslc_error} MATCHES ".*extension not supported: GL_EXT_integer_dot_product.*") - message(STATUS "GL_EXT_integer_dot_product not supported by glslc") - set(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT OFF) - else() - message(STATUS "GL_EXT_integer_dot_product supported by glslc") - set(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT ON) - add_compile_definitions(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) - endif() + test_shader_extension_support( + "GL_NV_cooperative_matrix2" + "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_coopmat2_support.comp" + "GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT" + ) - # Compile a test shader to determine whether GL_EXT_bfloat16 is supported. - # If it's not, there will be an error to stderr. 
- # If it's supported, set a define to indicate that we should compile those shaders - execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_bfloat16_support.comp" - OUTPUT_VARIABLE glslc_output - ERROR_VARIABLE glslc_error) + test_shader_extension_support( + "GL_EXT_integer_dot_product" + "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_integer_dot_support.comp" + "GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT" + ) - if (${glslc_error} MATCHES ".*extension not supported: GL_EXT_bfloat16.*") - message(STATUS "GL_EXT_bfloat16 not supported by glslc") - set(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT OFF) - else() - message(STATUS "GL_EXT_bfloat16 supported by glslc") - set(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT ON) - add_compile_definitions(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) - endif() + test_shader_extension_support( + "GL_EXT_bfloat16" + "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_bfloat16_support.comp" + "GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT" + ) target_link_libraries(ggml-vulkan PRIVATE Vulkan::Vulkan) target_include_directories(ggml-vulkan PRIVATE ${CMAKE_CURRENT_BINARY_DIR}) @@ -124,16 +121,8 @@ if (Vulkan_FOUND) add_compile_definitions(GGML_VULKAN_RUN_TESTS) endif() - if (NOT CMAKE_CROSSCOMPILING) - add_subdirectory(vulkan-shaders) - if (MSVC) - foreach(CONFIG ${CMAKE_CONFIGURATION_TYPES}) - string(TOUPPER ${CONFIG} CONFIG) - set_target_properties(vulkan-shaders-gen PROPERTIES - RUNTIME_OUTPUT_DIRECTORY_${CONFIG} ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}) - endforeach() - endif() - else() + # Set up toolchain for host compilation whether cross-compiling or not + if (CMAKE_CROSSCOMPILING) if (GGML_VULKAN_SHADERS_GEN_TOOLCHAIN) set(HOST_CMAKE_TOOLCHAIN_FILE ${GGML_VULKAN_SHADERS_GEN_TOOLCHAIN}) else() @@ -146,25 +135,31 @@ if (Vulkan_FOUND) configure_file(${CMAKE_CURRENT_SOURCE_DIR}/cmake/host-toolchain.cmake.in ${CMAKE_BINARY_DIR}/host-toolchain.cmake @ONLY) set(HOST_CMAKE_TOOLCHAIN_FILE 
${CMAKE_BINARY_DIR}/host-toolchain.cmake) endif() - message(STATUS "vulkan-shaders-gen toolchain file: ${HOST_CMAKE_TOOLCHAIN_FILE}") + else() + # For non-cross-compiling, use empty toolchain (use host compiler) + set(HOST_CMAKE_TOOLCHAIN_FILE "") + endif() - include(ExternalProject) - # Native build through ExternalProject_Add - ExternalProject_Add( - vulkan-shaders-gen - SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders - CMAKE_ARGS -DCMAKE_TOOLCHAIN_FILE=${HOST_CMAKE_TOOLCHAIN_FILE} - -DCMAKE_INSTALL_PREFIX=${CMAKE_BINARY_DIR} - -DGGML_VULKAN_COOPMAT_GLSLC_SUPPORT=${GGML_VULKAN_COOPMAT_GLSLC_SUPPORT} - -DGGML_VULKAN_COOPMAT2_GLSLC_SUPPORT=${GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT} - -DGGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT=${GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT} - -DGGML_VULKAN_BFLOAT16_GLSLC_SUPPORT=${GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT} - BUILD_COMMAND ${CMAKE_COMMAND} --build . - INSTALL_COMMAND ${CMAKE_COMMAND} --install . - INSTALL_DIR ${CMAKE_BINARY_DIR} - ) - ExternalProject_Add_StepTargets(vulkan-shaders-gen build install) + # Always use ExternalProject_Add approach + include(ExternalProject) + + # Add toolchain file if cross-compiling + if (CMAKE_CROSSCOMPILING) + list(APPEND VULKAN_SHADER_GEN_CMAKE_ARGS -DCMAKE_TOOLCHAIN_FILE=${HOST_CMAKE_TOOLCHAIN_FILE}) + message(STATUS "vulkan-shaders-gen toolchain file: ${HOST_CMAKE_TOOLCHAIN_FILE}") endif() + + # Native build through ExternalProject_Add + ExternalProject_Add( + vulkan-shaders-gen + SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders + CMAKE_ARGS ${VULKAN_SHADER_GEN_CMAKE_ARGS} + BUILD_COMMAND ${CMAKE_COMMAND} --build . ${VULKAN_SHADER_GEN_CMAKE_BUILD_ARGS} + INSTALL_COMMAND ${CMAKE_COMMAND} --install . 
+ INSTALL_DIR ${CMAKE_BINARY_DIR} + ) + ExternalProject_Add_StepTargets(vulkan-shaders-gen build install) + set (_ggml_vk_host_suffix $,.exe,>) set (_ggml_vk_genshaders_cmd ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/vulkan-shaders-gen${_ggml_vk_host_suffix}) set (_ggml_vk_header ${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.hpp) @@ -175,9 +170,8 @@ if (Vulkan_FOUND) file(GLOB _ggml_vk_shader_deps "${_ggml_vk_input_dir}/*.comp") set (_ggml_vk_shader_deps ${_ggml_vk_shader_deps} vulkan-shaders-gen) - if (CMAKE_CROSSCOMPILING) - set(_ggml_vk_shader_deps ${_ggml_vk_shader_deps} vulkan-shaders-gen-build vulkan-shaders-gen-install) - endif() + # Add build and install dependencies for all builds + set(_ggml_vk_shader_deps ${_ggml_vk_shader_deps} vulkan-shaders-gen-build vulkan-shaders-gen-install) add_custom_command( OUTPUT ${_ggml_vk_header} diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index e2b357fdc15..fe3669b462c 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -288,6 +288,9 @@ struct vk_device_struct { bool coopmat_acc_f32_support {}; bool coopmat_acc_f16_support {}; bool coopmat_bf16_support {}; + bool coopmat_support_16x16x16_f16acc {}; + bool coopmat_support_16x16x16_f32acc {}; + bool coopmat1_fa_support {}; uint32_t coopmat_m; uint32_t coopmat_n; uint32_t coopmat_k; @@ -410,6 +413,13 @@ struct vk_device_struct { vk_pipeline pipeline_flash_attn_f32_f16_D128_cm2[GGML_TYPE_COUNT][2][2][2]; vk_pipeline pipeline_flash_attn_f32_f16_D256_cm2[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_D64_cm1[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_D80_cm1[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_D96_cm1[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_D112_cm1[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_D128_cm1[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline 
pipeline_flash_attn_f32_f16_D256_cm1[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2]; vk_pipeline pipeline_flash_attn_f32_f16_D80[GGML_TYPE_COUNT][2][2][2]; vk_pipeline pipeline_flash_attn_f32_f16_D96[GGML_TYPE_COUNT][2][2][2]; @@ -1588,19 +1598,36 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector&& events ); } +enum FaCodePath { + FA_SCALAR, + FA_COOPMAT1, + FA_COOPMAT2, +}; + // number of rows/cols for flash attention shader static constexpr uint32_t flash_attention_num_small_rows = 32; static constexpr uint32_t scalar_flash_attention_num_small_rows = 1; static constexpr uint32_t scalar_flash_attention_num_large_rows = 8; -static uint32_t get_fa_num_small_rows(bool scalar) { - return scalar ? scalar_flash_attention_num_small_rows : flash_attention_num_small_rows; +// The FA coopmat1 shader assumes 16x16x16 matrix multiply support. +// 128 threads split into four subgroups, each subgroup does 1/4 +// of the Bc dimension. +static constexpr uint32_t coopmat1_flash_attention_num_large_rows = 16; +static constexpr uint32_t scalar_flash_attention_Bc = 64; +static constexpr uint32_t scalar_flash_attention_workgroup_size = 128; + +static uint32_t get_fa_num_small_rows(FaCodePath path) { + if (path == FA_COOPMAT2) { + return flash_attention_num_small_rows; + } else { + return scalar_flash_attention_num_small_rows; + } } -static std::array fa_rows_cols(bool scalar, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) { +static std::array fa_rows_cols(FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) { GGML_UNUSED(clamp); - if (scalar) { + if (path == FA_SCALAR) { if (small_rows) { return {scalar_flash_attention_num_small_rows, 64}; } else { @@ -1608,9 +1635,17 @@ static std::array fa_rows_cols(bool scalar, uint32_t D, uint32_t cl } } + if (path == FA_COOPMAT1) { + if (small_rows) { + return {scalar_flash_attention_num_small_rows, scalar_flash_attention_Bc}; + } else { + 
return {coopmat1_flash_attention_num_large_rows, scalar_flash_attention_Bc}; + } + } + // small rows, large cols if (small_rows) { - return {get_fa_num_small_rows(scalar), 32}; + return {get_fa_num_small_rows(FA_COOPMAT2), 32}; } // small cols to reduce register count @@ -1907,17 +1942,19 @@ static void ggml_vk_load_shaders(vk_device& device) { parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size)); }; - auto const &fa_wg_denoms = [&](bool scalar, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::array { - return {fa_rows_cols(scalar, D, clamp, type, small_rows)[0], 1, 1}; + auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::array { + return {fa_rows_cols(path, D, clamp, type, small_rows)[0], 1, 1}; }; - auto const &fa_spec_constants = [&](bool scalar, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector { + auto const &fa_spec_constants = [&](FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector { // For large number of rows, 128 invocations seems to work best. // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we // can't use 256 for D==80. // For scalar, use 128 (arbitrary) - uint32_t wg_size = scalar ? 128 : ((small_rows && (D % 32) == 0) ? 256 : 128); - auto rows_cols = fa_rows_cols(scalar, D, clamp, type, small_rows); + uint32_t wg_size = (path == FA_SCALAR || path == FA_COOPMAT1) + ? scalar_flash_attention_workgroup_size + : ((small_rows && (D % 32) == 0) ? 256 : 128); + auto rows_cols = fa_rows_cols(path, D, clamp, type, small_rows); // D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it. // D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader. 
@@ -1929,36 +1966,43 @@ static void ggml_vk_load_shaders(vk_device& device) { return {wg_size, rows_cols[0], rows_cols[1], (D), clamp, D_split}; }; -#define CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, D) \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][0], "flash_attn_f32_f16_D" #D "_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,1,TYPE,false), fa_spec_constants(SCALAR, D,1,TYPE,false), 1, true); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,0,TYPE,false), fa_spec_constants(SCALAR, D,0,TYPE,false), fa_rows_cols(SCALAR,D,0,TYPE,false)[1], true); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][0], "flash_attn_f32_f16_D" #D "_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,1,TYPE,false), fa_spec_constants(SCALAR, D,1,TYPE,false), 1, true); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,0,TYPE,false), fa_spec_constants(SCALAR, D,0,TYPE,false), fa_rows_cols(SCALAR,D,0,TYPE,false)[1], true); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## 
D ## SUFFIX[TYPE][0][1][0], "flash_attn_f32_f16_D" #D "_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,1,TYPE,true), fa_spec_constants(SCALAR, D,1,TYPE,true), 1, true); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,0,TYPE,true), fa_spec_constants(SCALAR, D,0,TYPE,true), fa_rows_cols(SCALAR,D,0,TYPE,true)[1], true); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][0], "flash_attn_f32_f16_D" #D "_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,1,TYPE,true), fa_spec_constants(SCALAR, D,1,TYPE,true), 1, true); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,0,TYPE,true), fa_spec_constants(SCALAR, D,0,TYPE,true), fa_rows_cols(SCALAR,D,0,TYPE,true)[1], true); \ - -#define CREATE_FA(TYPE, NAMELC, SCALAR, SUFFIX) \ - CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 64) \ - CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 80) \ - CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 96) \ - CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 112) \ - CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 128) \ - CREATE_FA2(TYPE, 
NAMELC, SCALAR, SUFFIX, 256) - - CREATE_FA(GGML_TYPE_F16, f16, true, ) - CREATE_FA(GGML_TYPE_Q4_0, q4_0, true, ) - CREATE_FA(GGML_TYPE_Q8_0, q8_0, true, ) +#define CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, D) \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][0], "flash_attn_f32_f16_D" #D "_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,false), fa_spec_constants(FAPATH, D,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,false), fa_spec_constants(FAPATH, D,0,TYPE,false), fa_rows_cols(FAPATH,D,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][0], "flash_attn_f32_f16_D" #D "_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,false), fa_spec_constants(FAPATH, D,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 
32 : 0)); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,false), fa_spec_constants(FAPATH, D,0,TYPE,false), fa_rows_cols(FAPATH,D,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][0], "flash_attn_f32_f16_D" #D "_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,true), fa_spec_constants(FAPATH, D,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,true), fa_spec_constants(FAPATH, D,0,TYPE,true), fa_rows_cols(FAPATH,D,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][0], "flash_attn_f32_f16_D" #D "_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,true), fa_spec_constants(FAPATH, D,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 
32 : 0)); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,true), fa_spec_constants(FAPATH, D,0,TYPE,true), fa_rows_cols(FAPATH,D,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + +#define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \ + CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 64) \ + CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 80) \ + CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 96) \ + CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 112) \ + CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 128) \ + CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 256) + + CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, ) + CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, ) + CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, ) +#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + if (device->coopmat1_fa_support) { + CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT1, _cm1) + CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT1, _cm1) + CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT1, _cm1) + } +#endif #if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) if (device->coopmat2) { - CREATE_FA(GGML_TYPE_F16, f16, false, _cm2) - CREATE_FA(GGML_TYPE_Q4_0, q4_0, false, _cm2) - CREATE_FA(GGML_TYPE_Q4_1, q4_1, false, _cm2) - CREATE_FA(GGML_TYPE_Q5_0, q5_0, false, _cm2) - CREATE_FA(GGML_TYPE_Q5_1, q5_1, false, _cm2) - CREATE_FA(GGML_TYPE_Q8_0, q8_0, false, _cm2) - CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, false, _cm2) + CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT2, _cm2) + CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT2, _cm2) + CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_COOPMAT2, _cm2) + CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_COOPMAT2, _cm2) + CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_COOPMAT2, _cm2) 
+ CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT2, _cm2) + CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_COOPMAT2, _cm2) } #endif #undef CREATE_FA2 @@ -2041,17 +2085,17 @@ static void ggml_vk_load_shaders(vk_device& device) { // Create 6 variants, {s,m,l}x{unaligned,aligned} #define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ if (device->mul_mat ## ID ## _l[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, true); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, true); \ if (device->mul_mat ## ID ## _m[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, true); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, true); \ if (device->mul_mat ## ID ## _s[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, true); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, true); \ if (device->mul_mat ## ID ## _l[TYPE]) \ - 
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, true); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, true); \ if (device->mul_mat ## ID ## _m[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, true); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, true); \ if (device->mul_mat ## ID ## _s[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, true); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, true); \ // Create 2 variants, {f16,f32} accumulator #define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ @@ -3009,6 +3053,11 @@ static vk_device ggml_vk_get_device(size_t idx) { #if 
defined(VK_KHR_cooperative_matrix) device->coopmat_support = device->coopmat_support && coopmat_features.cooperativeMatrix; + + // coopmat1 fa shader currently assumes 32 invocations per subgroup + device->coopmat1_fa_support = device->coopmat_support && device->subgroup_require_full_support && + device->subgroup_size_control && device->subgroup_min_size <= 32 && + device->subgroup_max_size >= 32; #endif if (coopmat2_support) { @@ -3143,6 +3192,9 @@ static vk_device ggml_vk_get_device(size_t idx) { // Only enable if shape is identical device->coopmat_acc_f32_support = true; } + if (prop.MSize == 16 && prop.NSize == 16 && prop.KSize == 16) { + device->coopmat_support_16x16x16_f32acc = true; + } } else if ((vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eFloat16 && (vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eFloat16) { // coopmat sizes not set yet @@ -3155,6 +3207,9 @@ static vk_device ggml_vk_get_device(size_t idx) { // Only enable if shape is identical device->coopmat_acc_f16_support = true; } + if (prop.MSize == 16 && prop.NSize == 16 && prop.KSize == 16) { + device->coopmat_support_16x16x16_f16acc = true; + } } } else if ((vk::ComponentTypeKHR)prop.AType == vk::ComponentTypeKHR::eSint8 && (vk::ComponentTypeKHR)prop.BType == vk::ComponentTypeKHR::eSint8 && @@ -5688,6 +5743,36 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx } } +static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t D, bool f32acc) { + // Needs to be kept up to date on shader changes + const uint32_t wg_size = scalar_flash_attention_workgroup_size; + const uint32_t Br = scalar_flash_attention_num_large_rows; + const uint32_t Bc = scalar_flash_attention_Bc; + + const uint32_t acctype = f32acc ? 
4 : 2; + const uint32_t f16vec4 = 8; + + const uint32_t tmpsh = wg_size * sizeof(float); + const uint32_t tmpshv4 = wg_size * 4 * acctype; + + const uint32_t Qf = Br * (D / 4 + 2) * f16vec4; + + const uint32_t sfshstride = (D <= 128) ? (Br + 8) : Br; + const uint32_t sfsh = Bc * sfshstride * acctype; + + const uint32_t kshstride = D / 4 + 2; + const uint32_t ksh = Bc * kshstride * f16vec4; + + const uint32_t slope = Br * sizeof(float); + + const uint32_t total_size = tmpsh + tmpshv4 + Qf + sfsh + ksh + slope; + const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize; + + VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(D=" << D << ", f32acc=" << f32acc << ", total_size=" << total_size << ", supported=" << supported); + + return supported; +} + static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * q, const ggml_tensor * k, const ggml_tensor * v, const ggml_tensor * mask, ggml_tensor * dst, bool dryrun = false) { VK_LOG_DEBUG("ggml_vk_flash_attn((" << q << ", name=" << q->name << ", type=" << q->type << ", ne0=" << q->ne[0] << ", ne1=" << q->ne[1] << ", ne2=" << q->ne[2] << ", ne3=" << q->ne[3] << ", nb0=" << q->nb[0] << ", nb1=" << q->nb[1] << ", nb2=" << q->nb[2] << ", nb3=" << q->nb[3]; std::cerr << "), (" << k << ", name=" << k->name << ", type=" << k->type << ", ne0=" << k->ne[0] << ", ne1=" << k->ne[1] << ", ne2=" << k->ne[2] << ", ne3=" << k->ne[3] << ", nb0=" << k->nb[0] << ", nb1=" << k->nb[1] << ", nb2=" << k->nb[2] << ", nb3=" << k->nb[3]; @@ -5738,7 +5823,19 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx assert(q->type == GGML_TYPE_F32); assert(k->type == v->type); - bool scalar = !ctx->device->coopmat2; + FaCodePath path = ctx->device->coopmat2 ? FA_COOPMAT2 : + ctx->device->coopmat1_fa_support ? 
FA_COOPMAT1 : FA_SCALAR; + + if (path == FA_COOPMAT1) { + const bool coopmat_shape_supported = (dst->op_params[3] == GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f32acc) || + (dst->op_params[3] != GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f16acc); + + const bool coopmat_shmem_supported = ggml_vk_flash_attn_coopmat_shmem_support(ctx->device, D, dst->op_params[3] == GGML_PREC_F32); + + if (!coopmat_shape_supported || !coopmat_shmem_supported) { + path = FA_SCALAR; + } + } uint32_t gqa_ratio = 1; uint32_t qk_ratio = neq2 / nek2; @@ -5746,9 +5843,21 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx uint32_t workgroups_y = (uint32_t)neq2; uint32_t workgroups_z = (uint32_t)neq3; - // For scalar FA, we can use the "large" size to accommodate qga. - // For coopmat FA, we always use the small size (which is still pretty large for gqa). - const uint32_t max_gqa = scalar ? scalar_flash_attention_num_large_rows : get_fa_num_small_rows(false); + // For scalar/coopmat1 FA, we can use the "large" size to accommodate qga. + // For coopmat2 FA, we always use the small size (which is still pretty large for gqa). 
+ uint32_t max_gqa; + switch (path) { + case FA_SCALAR: + case FA_COOPMAT1: + // We may switch from coopmat1 to scalar, so use the scalar limit for both + max_gqa = scalar_flash_attention_num_large_rows; + break; + case FA_COOPMAT2: + max_gqa = get_fa_num_small_rows(FA_COOPMAT2); + break; + default: + GGML_ASSERT(0); + } if (N == 1 && qk_ratio > 1 && qk_ratio <= max_gqa && qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1) { @@ -5761,11 +5870,23 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx } vk_pipeline *pipelines; - // XXX TODO other backends may be changing accumulator precision to default to f32 soon - bool f32acc = scalar || dst->op_params[3] == GGML_PREC_F32; - bool small_rows = N <= get_fa_num_small_rows(scalar); + bool small_rows = N <= get_fa_num_small_rows(path); + + // coopmat1 does not actually support "small rows" (it needs 16 rows). + // So use scalar instead. + if (small_rows && path == FA_COOPMAT1) { + path = FA_SCALAR; + } - if (scalar) { + // scalar is faster than coopmat2 when N==1 + if (N == 1 && path == FA_COOPMAT2) { + path = FA_SCALAR; + } + + bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32; + + switch (path) { + case FA_SCALAR: switch (D) { case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64[k->type][f32acc][small_rows][0]; break; case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80[k->type][f32acc][small_rows][0]; break; @@ -5777,7 +5898,21 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx GGML_ASSERT(!"unsupported D value"); return; } - } else { + break; + case FA_COOPMAT1: + switch (D) { + case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64_cm1[k->type][f32acc][small_rows][0]; break; + case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80_cm1[k->type][f32acc][small_rows][0]; break; + case 96: pipelines = 
&ctx->device->pipeline_flash_attn_f32_f16_D96_cm1[k->type][f32acc][small_rows][0]; break; + case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112_cm1[k->type][f32acc][small_rows][0]; break; + case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128_cm1[k->type][f32acc][small_rows][0]; break; + case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256_cm1[k->type][f32acc][small_rows][0]; break; + default: + GGML_ASSERT(!"unsupported D value"); + return; + } + break; + case FA_COOPMAT2: switch (D) { case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64_cm2[k->type][f32acc][small_rows][0]; break; case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80_cm2[k->type][f32acc][small_rows][0]; break; @@ -5789,6 +5924,9 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx GGML_ASSERT(!"unsupported D value"); return; } + break; + default: + GGML_ASSERT(0); } assert(pipelines); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt b/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt index ad13f69b3fb..e60e9d1e5b5 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +++ b/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt @@ -5,18 +5,35 @@ find_package (Threads REQUIRED) if (GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) add_compile_definitions(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + message(STATUS "Enabling coopmat glslc support") endif() if (GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) add_compile_definitions(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + message(STATUS "Enabling coopmat2 glslc support") endif() if (GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) add_compile_definitions(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + message(STATUS "Enabling dot glslc support") endif() if (GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) add_compile_definitions(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) + message(STATUS "Enabling bfloat16 glslc support") endif() + set(TARGET vulkan-shaders-gen) add_executable(${TARGET} 
vulkan-shaders-gen.cpp) install(TARGETS ${TARGET} RUNTIME) target_compile_features(${TARGET} PRIVATE cxx_std_17) target_link_libraries(vulkan-shaders-gen PUBLIC Threads::Threads) + +# Configure output directories for MSVC builds +if(MSVC) + # Get the main project's runtime output directory if possible + if(DEFINED CMAKE_RUNTIME_OUTPUT_DIRECTORY) + foreach(CONFIG ${CMAKE_CONFIGURATION_TYPES}) + string(TOUPPER ${CONFIG} CONFIG) + set_target_properties(${TARGET} PROPERTIES + RUNTIME_OUTPUT_DIRECTORY_${CONFIG} ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}) + endforeach() + endif() +endif() diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index e6545160d53..ce230a8f7d9 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -9,59 +9,13 @@ #extension GL_KHR_shader_subgroup_shuffle : enable #include "types.comp" +#include "flash_attn_base.comp" -layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; - -layout (constant_id = 1) const uint32_t Br = 1; -layout (constant_id = 2) const uint32_t Bc = 32; -layout (constant_id = 3) const uint32_t D = 32; - -layout (constant_id = 5) const uint32_t D_split = 16; const uint32_t D_per_thread = D / D_split; -const uint32_t cols_per_iter = gl_WorkGroupSize.x / D_split; +const uint32_t cols_per_iter = WorkGroupSize / D_split; const uint32_t cols_per_thread = Bc / cols_per_iter; -layout (push_constant) uniform parameter { - uint32_t N; - uint32_t KV; - - uint32_t ne1; - uint32_t ne2; - uint32_t ne3; - - uint32_t neq2; - uint32_t neq3; - uint32_t nek2; - uint32_t nek3; - uint32_t nev2; - uint32_t nev3; - uint32_t nem1; - - uint32_t nb01; - uint32_t nb02; - uint32_t nb03; - uint32_t nb11; - uint32_t nb12; - uint32_t nb13; - uint32_t nb21; - uint32_t nb22; - uint32_t nb23; - uint32_t nb31; - - float scale; - float max_bias; - float logit_softcap; - - uint32_t mask; - uint32_t n_head_log2; - float m0; - 
float m1; - - uint32_t gqa_ratio; - uint32_t split_kv; - uint32_t k_num; -} p; layout (binding = 0) readonly buffer Q {float data_q[];}; layout (binding = 0) readonly buffer QV4 {vec4 data_qv4[];}; @@ -70,39 +24,6 @@ layout (binding = 1) readonly buffer KV4 {f16vec4 data_kv4[];}; layout (binding = 2) readonly buffer V {float16_t data_v[];}; layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];}; layout (binding = 3) readonly buffer M {float16_t data_m[];}; -layout (binding = 4) writeonly buffer O {D_TYPE data_o[];}; - -#if defined(A_TYPE_PACKED16) -#define BINDING_IDX_K 0 -#define BINDING_IDX_V 1 -layout (binding = 1) readonly buffer KV_PACKED16 {A_TYPE_PACKED16 data_packed16[];} kv_packed[2]; -#endif - -#if defined(DATA_A_Q4_0) -#define BLOCK_BYTE_SIZE 18 - -vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { - uint vui_lo = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); - uint vui_hi = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); - uint shift = (iqs & 0x10) >> 2; - vui_lo >>= shift; - vui_hi >>= shift; - - return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f); -} -#endif - -#if defined(DATA_A_Q8_0) -#define BLOCK_BYTE_SIZE 34 -vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { - const i8vec2 v0 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147 - const i8vec2 v1 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy; - - return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y); -} -#endif - -#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b)) // Store the output when doing grouped query attention. // Rows index by Q's dimension 2, and the first N rows are valid. 
@@ -113,29 +34,8 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TY return elem; } -// Store column zero. This is used to save per-row m and L values for split_k. -ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) -{ - if (r < N && c == 0) { - uint32_t offset = iq2 + r; - data_o[o_offset + offset] = D_TYPE(elem); - } - return elem; -} - -// Load the slope matrix, indexed by Q's dimension 2. -ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2) -{ - const uint32_t h = iq2 + (r % p.gqa_ratio); - - const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1); - const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1); - - return ACC_TYPE(pow(base, ACC_TYPE(exph))); -} - -shared FLOAT_TYPE tmpsh[gl_WorkGroupSize.x]; -shared vec4 tmpshv4[gl_WorkGroupSize.x]; +shared FLOAT_TYPE tmpsh[WorkGroupSize]; +shared vec4 tmpshv4[WorkGroupSize]; shared float masksh[Bc][Br]; shared vec4 Qf[Br][D / 4]; @@ -145,58 +45,12 @@ void main() { init_iq_shmem(gl_WorkGroupSize); #endif - const uint32_t tid = gl_LocalInvocationIndex; - const uint32_t N = p.N; - const uint32_t KV = p.KV; + init_indices(); + const uint32_t tid = gl_LocalInvocationIndex; const uint32_t d_tid = gl_LocalInvocationIndex % D_split; const uint32_t col_tid = gl_LocalInvocationIndex / D_split; - uint32_t i = gl_WorkGroupID.x; - uint32_t split_k_index = 0; - - if (p.k_num > 1) { - i = 0; - split_k_index = gl_WorkGroupID.x; - } - - const uint32_t Tr = CEIL_DIV(N, Br); - - const uint32_t start_j = split_k_index * p.split_kv / Bc; - const uint32_t end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc); - - // When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y. 
- // When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2. - const uint32_t iq2 = gl_WorkGroupID.y * p.gqa_ratio; - const uint32_t iq3 = gl_WorkGroupID.z; - - // broadcast factors - const uint32_t rk2 = p.neq2/p.nek2; - const uint32_t rk3 = p.neq3/p.nek3; - - const uint32_t rv2 = p.neq2/p.nev2; - const uint32_t rv3 = p.neq3/p.nev3; - - // k indices - const uint32_t ik3 = iq3 / rk3; - const uint32_t ik2 = iq2 / rk2; - - // v indices - const uint32_t iv3 = iq3 / rv3; - const uint32_t iv2 = iq2 / rv2; - - // nb?1 are already divided by the type size and are in units of elements. - // When using grouped query attention, Q is indexed by iq2, so the stride - // should be nb02 (which is in bytes). - uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01; - uint32_t k_stride = p.nb11; - uint32_t v_stride = p.nb21; - // When using grouped query attention, all rows use the same mask (stride 0). - // "p.gqa_ratio >> 16" is just a roundabout way of writing zero - // that prevents the compiler from folding the "&" through the select - // and breaking the alignment detection. - uint32_t m_stride = (p.gqa_ratio > 1) ? 
(p.gqa_ratio >> 16) : KV; - uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4; [[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp new file mode 100644 index 00000000000..61d90e2d8ed --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp @@ -0,0 +1,162 @@ + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (constant_id = 0) const uint32_t WorkGroupSize = 128; +layout (constant_id = 1) const uint32_t Br = 1; +layout (constant_id = 2) const uint32_t Bc = 32; +layout (constant_id = 3) const uint32_t D = 32; +layout (constant_id = 4) const uint32_t Clamp = 0; +layout (constant_id = 5) const uint32_t D_split = 16; + + +layout (push_constant) uniform parameter { + uint32_t N; + uint32_t KV; + + uint32_t ne1; + uint32_t ne2; + uint32_t ne3; + + uint32_t neq2; + uint32_t neq3; + uint32_t nek2; + uint32_t nek3; + uint32_t nev2; + uint32_t nev3; + uint32_t nem1; + + uint32_t nb01; + uint32_t nb02; + uint32_t nb03; + uint32_t nb11; + uint32_t nb12; + uint32_t nb13; + uint32_t nb21; + uint32_t nb22; + uint32_t nb23; + uint32_t nb31; + + float scale; + float max_bias; + float logit_softcap; + + uint32_t mask; + uint32_t n_head_log2; + float m0; + float m1; + + uint32_t gqa_ratio; + uint32_t split_kv; + uint32_t k_num; +} p; + +layout (binding = 4) writeonly buffer O {D_TYPE data_o[];}; + +#if defined(A_TYPE_PACKED16) +#define BINDING_IDX_K 0 +#define BINDING_IDX_V 1 +layout (binding = 1) readonly buffer KV_PACKED16 {A_TYPE_PACKED16 data_packed16[];} kv_packed[2]; +#endif + +#if defined(DATA_A_Q4_0) +#define BLOCK_BYTE_SIZE 18 + +vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { + uint vui_lo = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); + uint vui_hi = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs 
& 0xF) / 2 + 1]); + uint shift = (iqs & 0x10) >> 2; + vui_lo >>= shift; + vui_hi >>= shift; + + return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f); +} +#endif + +#if defined(DATA_A_Q8_0) +#define BLOCK_BYTE_SIZE 34 +vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { + const i8vec2 v0 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147 + const i8vec2 v1 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy; + + return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y); +} +#endif + +#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b)) + + +// Store column zero. This is used to save per-row m and L values for split_k. +ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) +{ + if (r < N && c == 0) { + uint32_t offset = iq2 + r; + data_o[o_offset + offset] = D_TYPE(elem); + } + return elem; +} + +// Load the slope matrix, indexed by Q's dimension 2. +ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2) +{ + const uint32_t h = iq2 + (r % p.gqa_ratio); + + const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1); + const int exph = int(h < p.n_head_log2 ? 
h + 1 : 2*(h - p.n_head_log2) + 1); + + return ACC_TYPE(pow(base, ACC_TYPE(exph))); +} + +uint32_t i, N, KV, split_k_index, Tr, start_j, end_j, + iq2, iq3, rk2, rk3, rv2, rv3, ik2, ik3, iv2, iv3, + q_stride, k_stride, v_stride, m_stride; + +void init_indices() +{ + N = p.N; + KV = p.KV; + + i = gl_WorkGroupID.x; + split_k_index = 0; + + if (p.k_num > 1) { + i = 0; + split_k_index = gl_WorkGroupID.x; + } + + Tr = CEIL_DIV(N, Br); + + start_j = split_k_index * p.split_kv / Bc; + end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc); + + // When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y. + // When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2. + iq2 = gl_WorkGroupID.y * p.gqa_ratio; + iq3 = gl_WorkGroupID.z; + + // broadcast factors + rk2 = p.neq2/p.nek2; + rk3 = p.neq3/p.nek3; + + rv2 = p.neq2/p.nev2; + rv3 = p.neq3/p.nev3; + + // k indices + ik3 = iq3 / rk3; + ik2 = iq2 / rk2; + + // v indices + iv3 = iq3 / rv3; + iv2 = iq2 / rv2; + + // nb?1 are already divided by the type size and are in units of elements. + // When using grouped query attention, Q is indexed by iq2, so the stride + // should be nb02 (which is in bytes). + q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01; + k_stride = p.nb11; + v_stride = p.nb21; + // When using grouped query attention, all rows use the same mask (stride 0). + // "p.gqa_ratio >> 16" is just a roundabout way of writing zero + // that prevents the compiler from folding the "&" through the select + // and breaking the alignment detection. + m_stride = (p.gqa_ratio > 1) ? 
(p.gqa_ratio >> 16) : KV; +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp new file mode 100644 index 00000000000..da478be24fb --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp @@ -0,0 +1,360 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable +#extension GL_EXT_shader_16bit_storage : require + +#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#extension GL_KHR_shader_subgroup_basic : enable +#extension GL_KHR_memory_scope_semantics : enable +#extension GL_KHR_cooperative_matrix : enable + +#include "types.comp" +#include "flash_attn_base.comp" + +const uint32_t D_per_thread = D / D_split; +const uint32_t row_split = 4; +const uint32_t rows_per_thread = Br / row_split; +const uint32_t cols_per_iter = gl_WorkGroupSize.x / D_split / row_split; +const uint32_t cols_per_thread = Bc / cols_per_iter; + + +layout (binding = 0) readonly buffer Q {float data_q[];}; +layout (binding = 0) readonly buffer QV4 {vec4 data_qv4[];}; +layout (binding = 1) readonly buffer K {float16_t data_k[];}; +layout (binding = 1) readonly buffer KV4 {f16vec4 data_kv4[];}; +layout (binding = 2) readonly buffer V {float16_t data_v[];}; +layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];}; +layout (binding = 3) readonly buffer M {float16_t data_m[];}; + +// Store the output when doing grouped query attention. +// Rows index by Q's dimension 2, and the first N rows are valid. 
+D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) +{ + uint32_t offset = (iq2 + r) * D + c; + data_o[o_offset + offset] = D_TYPE(elem); + return elem; +} + +// These need to be supported N,M values for a MatBc x MatBr x 16 coopmatmuladd +const uint32_t MatBr = 16; +const uint32_t MatBc = 16; + +shared FLOAT_TYPE tmpsh[gl_WorkGroupSize.x]; +shared ACC_TYPEV4 tmpshv4[gl_WorkGroupSize.x]; + +const uint32_t qstride = D / 4 + 2; // in units of f16vec4 +shared f16vec4 Qf[Br * qstride]; + +// Avoid padding for D==256 to make it fit in 48KB shmem. +const uint32_t sfshstride = (D <= 128) ? (Br + 8) : Br; +shared ACC_TYPE sfsh[Bc * sfshstride]; + +const uint32_t kshstride = D / 4 + 2; // in units of f16vec4 +shared f16vec4 ksh[Bc * kshstride]; + +shared float slope[Br]; + +void main() { +#ifdef NEEDS_INIT_IQ_SHMEM + init_iq_shmem(gl_WorkGroupSize); +#endif + + init_indices(); + + const uint32_t tid = gl_LocalInvocationIndex; + + const uint32_t threads_per_rowgroup = gl_WorkGroupSize.x / row_split; + const uint32_t row_tid = gl_LocalInvocationIndex / threads_per_rowgroup; + const uint32_t d_tid = gl_LocalInvocationIndex % D_split; + const uint32_t col_tid = (gl_LocalInvocationIndex % threads_per_rowgroup) / D_split; + +#define tile_row(r) (row_tid * rows_per_thread + (r)) + + uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4; + + [[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) { + uint32_t d = (idx + tid) % (D / 4); + uint32_t r = (idx + tid) / (D / 4); + if (r < Br && d < D / 4 && + i * Br + r < N) { + Qf[r * qstride + d] = f16vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale); + } + } + barrier(); + + ACC_TYPEV4 Of[rows_per_thread][D_per_thread / 4]; + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Of[r][d] = ACC_TYPEV4(0.0); + } + } + + 
float Lf[rows_per_thread], Mf[rows_per_thread]; + + // Use -FLT_MAX/2 rather than -inf to reduce the possibility of NaNs, e.g. when computing Mold-M. + const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF); + + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Lf[r] = 0; + Mf[r] = NEG_FLT_MAX_OVER_2; + } + + // ALiBi + if (p.max_bias > 0.0f) { + if (tid < Br) { + uint r = tid; + slope[r] = perElemOpComputeSlope(r, col_tid, ACC_TYPE(0), iq2); + } + barrier(); + } else { + if (tid < Br) { + uint r = tid; + slope[r] = 1.0; + } + barrier(); + } + +#if BLOCK_SIZE > 1 + uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / BLOCK_BYTE_SIZE; + uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / BLOCK_BYTE_SIZE; +#else + uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2; + uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2; +#endif + + [[dont_unroll]] + for (uint32_t j = start_j; j < end_j; ++j) { + + [[unroll]] for (uint32_t idx = 0; idx < Bc * D / 4; idx += gl_WorkGroupSize.x) { + uint32_t d = (idx + tid) % (D / 4); + uint32_t c = (idx + tid) / (D / 4); + if (c < Bc && d < D / 4) { +#if BLOCK_SIZE > 1 + uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d; + uint ib = coord / BLOCK_SIZE; + uint iqs = (coord % BLOCK_SIZE); + f16vec4 K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K)); +#else + f16vec4 K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]); +#endif + + ksh[c * kshstride + d] = K_Tf; + } + } + barrier(); + + // K * Q^T -> S^T: Bc x D * D x Br -> Bc x Br + // Bc split across workgroup (four subgroups), loop over D in chunks of 16: 16 x 16 * 16 x 16 -> 16 x 16 + // This is written transposed in order to allow for N being 8 if implementations need it + coopmat SfMat = coopmat(0); + coopmat KMat; + coopmat QMat; + + for (uint32_t d = 0; d < D / 16; ++d) { + coopMatLoad(QMat, Qf, d * 16 / 4, qstride, gl_CooperativeMatrixLayoutColumnMajor); + + uint coord = (gl_SubgroupID * MatBc) * kshstride + d * 16 / 4; + 
coopMatLoad(KMat, ksh, coord, kshstride, gl_CooperativeMatrixLayoutRowMajor); + + SfMat = coopMatMulAdd(KMat, QMat, SfMat); + } + + uint coord = gl_SubgroupID * MatBc * sfshstride; + coopMatStore(SfMat, sfsh, coord, sfshstride, gl_CooperativeMatrixLayoutRowMajor); + barrier(); + + if (p.logit_softcap != 0.0f) { + [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) { + uint32_t c = (idx + tid) / Br; + uint32_t r = (idx + tid) % Br; + if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) { + sfsh[c * sfshstride + r] = ACC_TYPE(p.logit_softcap * tanh(sfsh[c * sfshstride + r])); + } + } + barrier(); + } + + if (p.mask != 0) { + [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) { + uint32_t c = (idx + tid) % Bc; + uint32_t r = (idx + tid) / Bc; + if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) { + sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[(i * Br + r) * m_stride + (j * Bc + c)])); + } + } + barrier(); + } + + float eMf[rows_per_thread]; + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + float rowmaxf = sfsh[tile_row(r) + (0 * cols_per_iter + col_tid) * sfshstride]; + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + rowmaxf = max(rowmaxf, float(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride])); + } + float Moldf = Mf[r]; + + // M = max(rowmax, Mold) + // P = e^(S - M) + // eM = e^(Mold - M) + Mf[r] = max(rowmaxf, Moldf); + eMf[r] = exp(Moldf - Mf[r]); + } + + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Of[r][d] = float16_t(eMf[r]) * Of[r][d]; + } + } + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Lf[r] = eMf[r]*Lf[r]; + } + + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + float Pf[rows_per_thread]; + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Pf[r] = exp(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * 
sfshstride] - Mf[r]); + Lf[r] += Pf[r]; + } + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { +#if BLOCK_SIZE > 1 + uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); + uint ib = coord / BLOCK_SIZE; + uint iqs = (coord % BLOCK_SIZE); + vec4 Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V); +#else + vec4 Vf = vec4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]); +#endif + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Of[r][d] += float16_t(Pf[r]) * ACC_TYPEV4(Vf); + } + } + } + + barrier(); + } + + // reduce across threads + + float rowmaxf[rows_per_thread], eMf[rows_per_thread], Moldf[rows_per_thread]; + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + FLOAT_TYPE M = Mf[r]; + tmpsh[tid] = M; + // Compute max across the row + barrier(); + [[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) { + M = max(M, tmpsh[tid ^ s]); + barrier(); + tmpsh[tid] = M; + barrier(); + } + rowmaxf[r] = tmpsh[d_tid + row_tid * threads_per_rowgroup]; + barrier(); + } + + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Moldf[r] = Mf[r]; + + // M = max(rowmax, Mold) + // eM = e^(Mold - M) + Mf[r] = max(rowmaxf[r], Moldf[r]); + eMf[r] = exp(Moldf[r] - Mf[r]); + + Lf[r] = eMf[r]*Lf[r]; + } + + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + FLOAT_TYPE L = Lf[r]; + tmpsh[tid] = L; + // Compute sum across the row + barrier(); + [[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) { + L += tmpsh[tid ^ s]; + barrier(); + tmpsh[tid] = L; + barrier(); + } + Lf[r] = tmpsh[d_tid + row_tid * threads_per_rowgroup]; + barrier(); + } + + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + + Of[r][d] = float16_t(eMf[r]) * Of[r][d]; + tmpshv4[tid] = Of[r][d]; + + barrier(); + [[unroll]] 
for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) { + Of[r][d] += tmpshv4[tid ^ s]; + barrier(); + tmpshv4[tid] = Of[r][d]; + barrier(); + } + Of[r][d] = tmpshv4[d_tid + row_tid * threads_per_rowgroup]; + barrier(); + } + } + + // If there is split_k, then the split_k resolve shader does the final + // division by L. Store the intermediate O value and per-row m and L values. + if (p.k_num > 1) { + uint32_t o_offset = D * p.ne1 * split_k_index; + + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + if (tile_row(r) < N) { + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { + perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N); + } + } + } + } + + o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2; + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + if (tile_row(r) < N) { + perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N); + perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N); + } + } + + return; + } + + float Lfrcp[rows_per_thread]; + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Lfrcp[r] = 1.0 / Lf[r]; + } + + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Of[r][d] *= float16_t(Lfrcp[r]); + } + } + + uint32_t o_offset = iq3*p.ne2*p.ne1; + + if (p.gqa_ratio > 1) { + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + if (tile_row(r) < N) { + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { + perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N); + } + } + } + } + } else { + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + if (i * Br + tile_row(r) < N) { + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) 
{ + [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { + data_o[o_offset + iq2 * D + (i * Br + tile_row(r)) * p.ne1 * D + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]); + } + } + } + } + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp index b926a578ade..6acf67a03a4 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp @@ -18,62 +18,12 @@ #include "types.comp" #include "dequant_funcs_cm2.comp" - -layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; - -layout (constant_id = 1) const uint32_t Br = 32; -layout (constant_id = 2) const uint32_t Bc = 32; -layout (constant_id = 3) const uint32_t D = 32; -layout (constant_id = 4) const uint32_t Clamp = gl_CooperativeMatrixClampModeConstantNV; - -layout (push_constant) uniform parameter { - uint32_t N; - uint32_t KV; - - uint32_t ne1; - uint32_t ne2; - uint32_t ne3; - - uint32_t neq2; - uint32_t neq3; - uint32_t nek2; - uint32_t nek3; - uint32_t nev2; - uint32_t nev3; - uint32_t nem1; - - uint32_t nb01; - uint32_t nb02; - uint32_t nb03; - uint32_t nb11; - uint32_t nb12; - uint32_t nb13; - uint32_t nb21; - uint32_t nb22; - uint32_t nb23; - uint32_t nb31; - - float scale; - float max_bias; - float logit_softcap; - - uint32_t mask; - uint32_t n_head_log2; - float m0; - float m1; - - uint32_t gqa_ratio; - uint32_t split_kv; - uint32_t k_num; -} p; +#include "flash_attn_base.comp" layout (binding = 0) readonly buffer Q {uint8_t data_q[];}; layout (binding = 1) readonly buffer K {uint8_t data_k[];}; layout (binding = 2) readonly buffer V {uint8_t data_v[];}; layout (binding = 3) readonly buffer M {uint8_t data_m[];}; -layout (binding = 4) writeonly buffer O {D_TYPE data_o[];}; - -#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b)) ACC_TYPE maxReduce(const in ACC_TYPE x, const in ACC_TYPE y) { return max(x, y); @@ -118,67 +68,12 @@ D_TYPE 
perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TY return elem; } -// Store column zero. This is used to save per-row m and L values for split_k. -ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) -{ - if (r < N && c == 0) { - uint32_t offset = iq2 + r; - data_o[o_offset + offset] = D_TYPE(elem); - } - return elem; -} - -// Load the slope matrix, indexed by Q's dimension 2. -ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2) -{ - const uint32_t h = iq2 + (r % p.gqa_ratio); - - const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1); - const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1); - - return ACC_TYPE(pow(base, ACC_TYPE(exph))); -} - void main() { #ifdef NEEDS_INIT_IQ_SHMEM init_iq_shmem(gl_WorkGroupSize); #endif - const uint32_t N = p.N; - const uint32_t KV = p.KV; - - uint32_t i = gl_WorkGroupID.x; - uint32_t split_k_index = 0; - - if (p.k_num > 1) { - i = 0; - split_k_index = gl_WorkGroupID.x; - } - - const uint32_t Tr = CEIL_DIV(N, Br); - - const uint32_t start_j = split_k_index * p.split_kv / Bc; - const uint32_t end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc); - - // When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y. - // When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2. 
- const uint32_t iq2 = gl_WorkGroupID.y * p.gqa_ratio; - const uint32_t iq3 = gl_WorkGroupID.z; - - // broadcast factors - const uint32_t rk2 = p.neq2/p.nek2; - const uint32_t rk3 = p.neq3/p.nek3; - - const uint32_t rv2 = p.neq2/p.nev2; - const uint32_t rv3 = p.neq3/p.nev3; - - // k indices - const uint32_t ik3 = iq3 / rk3; - const uint32_t ik2 = iq2 / rk2; - - // v indices - const uint32_t iv3 = iq3 / rv3; - const uint32_t iv2 = iq2 / rv2; + init_indices(); tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutQ = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); tensorLayoutNV<2, Clamp> tensorLayoutK = createTensorLayoutNV(2, Clamp); @@ -195,17 +90,6 @@ void main() { tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, D); tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, D); - // nb?1 are already divided by the type size and are in units of elements. - // When using grouped query attention, Q is indexed by iq2, so the stride - // should be nb02 (which is in bytes). - uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01; - uint32_t k_stride = p.nb11; - uint32_t v_stride = p.nb21; - // When using grouped query attention, all rows use the same mask (stride 0). - // "p.gqa_ratio >> 16" is just a roundabout way of writing zero - // that prevents the compiler from folding the "&" through the select - // and breaking the alignment detection. - uint32_t m_stride = (p.gqa_ratio > 1) ? 
(p.gqa_ratio >> 16) : KV; // hint to the compiler that strides are aligned for the aligned variant of the shader if (Clamp != gl_CooperativeMatrixClampModeConstantNV) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index d196137eb29..9361e2ac83b 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -215,7 +215,7 @@ static std::mutex compile_count_mutex; static std::condition_variable compile_count_cond; void string_to_spv_func(const std::string& _name, const std::string& in_fname, const std::map& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) { - std::string name = _name + (f16acc ? "_f16acc" : "") + (coopmat ? "_coopmat" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32")); + std::string name = _name + (f16acc ? "_f16acc" : "") + (coopmat ? "_cm1" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32")); std::string out_fname = join_paths(output_dir, name + ".spv"); std::string in_path = join_paths(input_dir, in_fname); @@ -424,6 +424,7 @@ void process_shaders() { // flash attention for (const auto& f16acc : {false, true}) { std::string acctype = f16acc ? "float16_t" : "float"; + std::string acctypev4 = f16acc ? 
"f16vec4" : "vec4"; for (const auto& tname : type_names) { if (tname == "f32") { @@ -440,6 +441,16 @@ void process_shaders() { string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc); } +#endif +#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + if (tname == "f16") { + string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp", + merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"ACC_TYPEV4", acctypev4}, {"COOPMAT", "1"}}), true, true, false, f16acc); + } else if (tname == "q4_0" || tname == "q8_0") { + std::string data_a_key = "DATA_A_" + to_uppercase(tname); + string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp", + merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"ACC_TYPEV4", acctypev4}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), true, true, false, f16acc); + } #endif if (tname == "f16") { string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 8a6546240f4..d48adb9afb8 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -64,12 +64,17 @@ // precomputed f32 table for f16 (256 KB) (ggml-impl.h) float ggml_table_f32_f16[1 << 16]; -#if (defined(__linux__) || defined(__APPLE__) || defined(__FreeBSD__) || defined(__NetBSD__) || defined(__OpenBSD__)) && \ - (!defined(TARGET_OS_TV) && !defined(TARGET_OS_WATCH)) +#if defined(__linux__) || \ + defined(__FreeBSD__) || defined(__NetBSD__) || defined(__OpenBSD__) || \ + (defined(__APPLE__) && !TARGET_OS_TV && !TARGET_OS_WATCH) + #include #include #include #include +#if defined(__linux__) +#include +#endif #if defined(__ANDROID__) #include @@ -133,10 +138,36 @@ static void 
ggml_print_backtrace(void) { if (GGML_NO_BACKTRACE) { return; } - char attach[32]; - snprintf(attach, sizeof(attach), "attach %d", getpid()); - int pid = fork(); - if (pid == 0) { +#if defined(__linux__) + FILE * f = fopen("/proc/self/status", "r"); + size_t size = 0; + char * line = NULL; + ssize_t length = 0; + while ((length = getline(&line, &size, f)) > 0) { + if (!strncmp(line, "TracerPid:", sizeof("TracerPid:") - 1) && + (length != sizeof("TracerPid:\t0\n") - 1 || line[length - 2] != '0')) { + // Already being debugged, and the breakpoint is the later abort() + free(line); + fclose(f); + return; + } + } + free(line); + fclose(f); + int lock[2] = { -1, -1 }; + (void) !pipe(lock); // Don't start gdb until after PR_SET_PTRACER +#endif + const int parent_pid = getpid(); + const int child_pid = fork(); + if (child_pid < 0) { // error + return; + } else if (child_pid == 0) { // child + char attach[32]; + snprintf(attach, sizeof(attach), "attach %d", parent_pid); +#if defined(__linux__) + close(lock[1]); + (void) !read(lock[0], lock, 1); +#endif // try gdb execlp("gdb", "gdb", "--batch", "-ex", "set style enabled on", @@ -149,18 +180,18 @@ static void ggml_print_backtrace(void) { execlp("lldb", "lldb", "--batch", "-o", "bt", "-o", "quit", - "-p", attach, + "-p", &attach[sizeof("attach ") - 1], (char *) NULL); - exit(EXIT_FAILURE); - } else { - int wstatus; - waitpid(pid, &wstatus, 0); - if (WIFEXITED(wstatus)) { - if (WEXITSTATUS(wstatus) == EXIT_FAILURE) { - // gdb failed, fallback to backtrace_symbols - ggml_print_backtrace_symbols(); - } - } + // gdb failed, fallback to backtrace_symbols + ggml_print_backtrace_symbols(); + _Exit(0); + } else { // parent +#if defined(__linux__) + prctl(PR_SET_PTRACER, child_pid); + close(lock[1]); + close(lock[0]); +#endif + waitpid(child_pid, NULL, 0); } } #else diff --git a/ggml/src/gguf.cpp b/ggml/src/gguf.cpp index 381a9c7dcbe..8667a80bd06 100644 --- a/ggml/src/gguf.cpp +++ b/ggml/src/gguf.cpp @@ -299,10 +299,10 @@ bool 
gguf_read_emplace_helper(const struct gguf_reader & gr, std::vectorversion)) { if (ctx->version == 1) { - fprintf(stderr, "%s: GGUFv1 is no longer supported, please use a more up-to-date version\n", __func__); + GGML_LOG_ERROR("%s: GGUFv1 is no longer supported, please use a more up-to-date version\n", __func__); ok = false; } if (ctx->version > GGUF_VERSION) { - fprintf(stderr, "%s: this GGUF file is version %" PRIu32 " but this software only supports up to version %d\n", + GGML_LOG_ERROR("%s: this GGUF file is version %" PRIu32 " but this software only supports up to version %d\n", __func__, ctx->version, GGUF_VERSION); ok = false; } @@ -363,7 +363,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par if (ok && gr.read(n_tensors)) { static_assert(sizeof(size_t) <= 8 && sizeof(gguf_tensor_info) >= 2, "int64_t insufficient for indexing"); if (n_tensors < 0 || n_tensors > int64_t(SIZE_MAX/sizeof(gguf_tensor_info))) { - fprintf(stderr, "%s: number of tensors is %" PRIi64 " but must be in [0, %zu]\n", + GGML_LOG_ERROR("%s: number of tensors is %" PRIi64 " but must be in [0, %zu]\n", __func__, n_tensors, SIZE_MAX/sizeof(gguf_tensor_info)); ok = false; } @@ -374,7 +374,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par if (ok && gr.read(n_kv)) { static_assert(sizeof(size_t) <= 8 && sizeof(gguf_tensor_info) >= 2, "int64_t insufficient for indexing"); if (n_kv < 0 || n_kv > int64_t(SIZE_MAX/sizeof(gguf_kv))) { - fprintf(stderr, "%s: number of key value pairs is %" PRIi64 " but must be in [0, %zu]\n", + GGML_LOG_ERROR("%s: number of key value pairs is %" PRIi64 " but must be in [0, %zu]\n", __func__, n_kv, SIZE_MAX/sizeof(gguf_kv)); ok = false; } @@ -383,7 +383,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par } if (!ok) { - fprintf(stderr, "%s: failed to read header\n", __func__); + GGML_LOG_ERROR("%s: failed to read header\n", __func__); gguf_free(ctx); return nullptr; 
} @@ -399,15 +399,15 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par try { ok = ok && gr.read(key); } catch (std::length_error &) { - fprintf(stderr, "%s: encountered length_error while reading key %" PRIi64 "\n", __func__, i); + GGML_LOG_ERROR("%s: encountered length_error while reading key %" PRIi64 "\n", __func__, i); ok = false; } catch (std::bad_alloc &) { - fprintf(stderr, "%s: encountered bad_alloc error while reading key %" PRIi64 "\n", __func__, i); + GGML_LOG_ERROR("%s: encountered bad_alloc error while reading key %" PRIi64 "\n", __func__, i); ok = false; } for (size_t j = 0; ok && j < ctx->kv.size(); ++j) { if (key == ctx->kv[j].key) { - fprintf(stderr, "%s: duplicate key '%s' for tensors %zu and %" PRIi64 " \n", __func__, key.c_str(), j, i); + GGML_LOG_ERROR("%s: duplicate key '%s' for tensors %zu and %" PRIi64 " \n", __func__, key.c_str(), j, i); ok = false; } } @@ -441,14 +441,14 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par case GGUF_TYPE_ARRAY: default: { - fprintf(stderr, "%s: key '%s' has invalid GGUF type %d\n", __func__, key.c_str(), type); + GGML_LOG_ERROR("%s: key '%s' has invalid GGUF type %d\n", __func__, key.c_str(), type); ok = false; } break; } } if (!ok) { - fprintf(stderr, "%s: failed to read key-value pairs\n", __func__); + GGML_LOG_ERROR("%s: failed to read key-value pairs\n", __func__); gguf_free(ctx); return nullptr; } @@ -458,7 +458,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par ctx->alignment = alignment_idx == -1 ? 
GGUF_DEFAULT_ALIGNMENT : gguf_get_val_u32(ctx, alignment_idx); if (ctx->alignment == 0 || (ctx->alignment & (ctx->alignment - 1)) != 0) { - fprintf(stderr, "%s: alignment %zu is not a power of 2\n", __func__, ctx->alignment); + GGML_LOG_ERROR("%s: alignment %zu is not a power of 2\n", __func__, ctx->alignment); gguf_free(ctx); return nullptr; } @@ -474,14 +474,14 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par try { ok = ok && gr.read(name); } catch (std::length_error &) { - fprintf(stderr, "%s: encountered length_error while reading tensor name %" PRIi64 "\n", __func__, i); + GGML_LOG_ERROR("%s: encountered length_error while reading tensor name %" PRIi64 "\n", __func__, i); ok = false; } catch (std::bad_alloc &) { - fprintf(stderr, "%s: encountered bad_alloc error while reading tensor name %" PRIi64 "\n", __func__, i); + GGML_LOG_ERROR("%s: encountered bad_alloc error while reading tensor name %" PRIi64 "\n", __func__, i); ok = false; } if (name.length() >= GGML_MAX_NAME) { - fprintf(stderr, "%s: tensor name %" PRIi64 " is too long: %zu >= %d\n", __func__, i, name.length(), GGML_MAX_NAME); + GGML_LOG_ERROR("%s: tensor name %" PRIi64 " is too long: %zu >= %d\n", __func__, i, name.length(), GGML_MAX_NAME); ok = false; break; } @@ -490,7 +490,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par // make sure there are no duplicate tensor names for (int64_t j = 0; ok && j < i; ++j) { if (strcmp(info.t.name, ctx->info[j].t.name) == 0) { - fprintf(stderr, "%s: duplicate tensor name '%s' for tensors %" PRIi64 " and %" PRIi64 "\n", __func__, info.t.name, j, i); + GGML_LOG_ERROR("%s: duplicate tensor name '%s' for tensors %" PRIi64 " and %" PRIi64 "\n", __func__, info.t.name, j, i); ok = false; break; } @@ -505,7 +505,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par uint32_t n_dims = -1; ok = ok && gr.read(n_dims); if (n_dims > GGML_MAX_DIMS) { - fprintf(stderr, "%s: 
tensor '%s' has invalid number of dimensions: %" PRIu32 " > %" PRIu32 "\n", + GGML_LOG_ERROR("%s: tensor '%s' has invalid number of dimensions: %" PRIu32 " > %" PRIu32 "\n", __func__, info.t.name, n_dims, GGML_MAX_DIMS); ok = false; break; @@ -518,7 +518,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par // check that all ne are non-negative if (info.t.ne[j] < 0) { - fprintf(stderr, "%s: tensor '%s' dimension %" PRIu32 " has invalid number of elements: %" PRIi64 " < 0\n", + GGML_LOG_ERROR("%s: tensor '%s' dimension %" PRIu32 " has invalid number of elements: %" PRIi64 " < 0\n", __func__, info.t.name, j, info.t.ne[j]); ok = false; break; @@ -530,7 +530,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par (INT64_MAX/info.t.ne[2] <= info.t.ne[0]*info.t.ne[1]) || (INT64_MAX/info.t.ne[3] <= info.t.ne[0]*info.t.ne[1]*info.t.ne[2]))) { - fprintf(stderr, "%s: total number of elements in tensor '%s' with shape " + GGML_LOG_ERROR("%s: total number of elements in tensor '%s' with shape " "(%" PRIi64 ", %" PRIi64 ", %" PRIi64 ", %" PRIi64 ") is >= %" PRIi64 "\n", __func__, info.t.name, info.t.ne[0], info.t.ne[1], info.t.ne[2], info.t.ne[3], INT64_MAX); ok = false; @@ -547,7 +547,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par // check that tensor type is within defined range if (info.t.type < 0 || info.t.type >= GGML_TYPE_COUNT) { - fprintf(stderr, "%s: tensor '%s' has invalid ggml type %d (%s)\n", + GGML_LOG_ERROR("%s: tensor '%s' has invalid ggml type %d (%s)\n", __func__, info.t.name, info.t.type, ggml_type_name(info.t.type)); ok = false; break; @@ -557,7 +557,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par // check that row size is divisible by block size if (blck_size == 0 || info.t.ne[0] % blck_size != 0) { - fprintf(stderr, "%s: tensor '%s' of type %d (%s) has %" PRId64 " elements per row, " + GGML_LOG_ERROR("%s: tensor '%s' of 
type %d (%s) has %" PRId64 " elements per row, " "not a multiple of block size (%" PRId64 ")\n", __func__, info.t.name, (int) info.t.type, ggml_type_name(info.t.type), info.t.ne[0], blck_size); ok = false; @@ -582,7 +582,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par } if (!ok) { - fprintf(stderr, "%s: failed to read tensor info\n", __func__); + GGML_LOG_ERROR("%s: failed to read tensor info\n", __func__); gguf_free(ctx); return nullptr; } @@ -590,7 +590,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par // we require the data section to be aligned, so take into account any padding if (fseek(file, GGML_PAD(ftell(file), ctx->alignment), SEEK_SET) != 0) { - fprintf(stderr, "%s: failed to seek to beginning of data section\n", __func__); + GGML_LOG_ERROR("%s: failed to seek to beginning of data section\n", __func__); gguf_free(ctx); return nullptr; } @@ -604,9 +604,9 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par for (size_t i = 0; i < ctx->info.size(); ++i) { const gguf_tensor_info & ti = ctx->info[i]; if (ti.offset != ctx->size) { - fprintf(stderr, "%s: tensor '%s' has offset %" PRIu64 ", expected %zu\n", + GGML_LOG_ERROR("%s: tensor '%s' has offset %" PRIu64 ", expected %zu\n", __func__, ti.t.name, ti.offset, ctx->size); - fprintf(stderr, "%s: failed to read tensor data\n", __func__); + GGML_LOG_ERROR("%s: failed to read tensor data\n", __func__); gguf_free(ctx); return nullptr; } @@ -634,7 +634,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par *params.ctx = ggml_init(pdata); if (*params.ctx == nullptr) { - fprintf(stderr, "%s: failed to initialize ggml context for storing tensors\n", __func__); + GGML_LOG_ERROR("%s: failed to initialize ggml context for storing tensors\n", __func__); gguf_free(ctx); return nullptr; } @@ -656,7 +656,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par ok = 
ok && gr.read(data->data, ctx->size); if (!ok) { - fprintf(stderr, "%s: failed to read tensor data binary blob\n", __func__); + GGML_LOG_ERROR("%s: failed to read tensor data binary blob\n", __func__); ggml_free(ctx_data); *params.ctx = nullptr; gguf_free(ctx); @@ -689,7 +689,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par } if (!ok) { - fprintf(stderr, "%s: failed to create tensors\n", __func__); + GGML_LOG_ERROR("%s: failed to create tensors\n", __func__); ggml_free(ctx_data); *params.ctx = nullptr; gguf_free(ctx); @@ -706,7 +706,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p FILE * file = ggml_fopen(fname, "rb"); if (!file) { - fprintf(stderr, "%s: failed to open GGUF file '%s'\n", __func__, fname); + GGML_LOG_ERROR("%s: failed to open GGUF file '%s'\n", __func__, fname); return nullptr; } @@ -1305,7 +1305,7 @@ bool gguf_write_to_file(const struct gguf_context * ctx, const char * fname, boo FILE * file = ggml_fopen(fname, "wb"); if (!file) { - fprintf(stderr, "%s: failed to open file '%s' for writing GGUF data\n", __func__, fname); + GGML_LOG_ERROR("%s: failed to open file '%s' for writing GGUF data\n", __func__, fname); return false; } diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index 2ad2ea1c651..6dbde75a843 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -148b286332db1259dcd299c04047a1fd31b02713 +c6202093c3fb4ce8f728d86838400b35cc01ac7c