diff --git a/.clang-format b/.clang-format
index 47d96b6b40983..117e6986f6f8c 100644
--- a/.clang-format
+++ b/.clang-format
@@ -22,7 +22,7 @@ AllowShortIfStatementsOnASingleLine: Never
AllowShortLambdasOnASingleLine: Inline
AllowShortLoopsOnASingleLine: false
AlwaysBreakBeforeMultilineStrings: true
-BinPackArguments: false
+BinPackArguments: true
BinPackParameters: false # OnePerLine
BitFieldColonSpacing: Both
BreakBeforeBraces: Custom # Attach
diff --git a/README.md b/README.md
index a01ef6d503e40..17f59e988e3d1 100644
--- a/README.md
+++ b/README.md
@@ -137,6 +137,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
- [X] [Trillion-7B-preview](https://huggingface.co/trillionlabs/Trillion-7B-preview)
- [x] [Ling models](https://huggingface.co/collections/inclusionAI/ling-67c51c85b34a7ea0aba94c32)
- [x] [LFM2 models](https://huggingface.co/collections/LiquidAI/lfm2-686d721927015b2ad73eaa38)
+- [x] [Hunyuan models](https://huggingface.co/collections/tencent/hunyuan-dense-model-6890632cda26b19119c9c5e7)
#### Multimodal
diff --git a/ci/run.sh b/ci/run.sh
index d51ba443859c2..a250393ee97ce 100755
--- a/ci/run.sh
+++ b/ci/run.sh
@@ -386,10 +386,10 @@ function gg_run_open_llama_7b_v2 {
(time ./bin/llama-imatrix --model ${model_f16} -f ${wiki_test} -t 1 -ngl 99 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-imatrix.log
- (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 0 ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
- (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 0 -fa ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
- (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
- (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 -fa ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
+ (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 0 -fa off ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
+ (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 0 -fa on ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
+ (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 -fa off ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
+ (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 -fa on ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
function check_ppl {
qnt="$1"
@@ -520,8 +520,8 @@ function gg_run_pythia_1_4b {
(time ./bin/llama-imatrix --model ${model_f16} -f ${wiki_test_60} -ngl 99 -c 128 -b 128 --chunks 1 ) 2>&1 | tee -a $OUT/${ci}-imatrix.log
- (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
- (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 -fa ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
+ (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 -fa off ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
+ (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 -fa on ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
function check_ppl {
qnt="$1"
@@ -651,10 +651,10 @@ function gg_run_pythia_2_8b {
(time ./bin/llama-imatrix --model ${model_f16} -f ${wiki_test} -t 1 -ngl 99 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-imatrix.log
- (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 0 ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
- (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 0 -fa ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
- (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
- (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 -fa ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
+ (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 0 -fa off ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
+ (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 0 -fa on ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
+ (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 -fa off ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
+ (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 -fa on ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
function check_ppl {
qnt="$1"
diff --git a/common/arg.cpp b/common/arg.cpp
index 81c4005c5e7fc..fcee0c4470077 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -1106,7 +1106,7 @@ static void common_params_print_completion(common_params_context & ctx_arg) {
printf("\"\n\n");
printf(" case \"$prev\" in\n");
- printf(" --model)\n");
+ printf(" --model|-m)\n");
printf(" COMPREPLY=( $(compgen -f -X '!*.gguf' -- \"$cur\") $(compgen -d -- \"$cur\") )\n");
printf(" return 0\n");
printf(" ;;\n");
@@ -1545,10 +1545,18 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
).set_examples({LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_RETRIEVAL}));
add_opt(common_arg(
- {"-fa", "--flash-attn"},
- string_format("enable Flash Attention (default: %s)", params.flash_attn ? "enabled" : "disabled"),
- [](common_params & params) {
- params.flash_attn = true;
+ {"-fa", "--flash-attn"}, "FA",
+ string_format("set Flash Attention use ('on', 'off', or 'auto', default: '%s')", llama_flash_attn_type_name(params.flash_attn_type)),
+ [](common_params & params, const std::string & value) {
+ if (value == "on" || value == "enabled" || value == "1") {
+ params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_ENABLED;
+ } else if (value == "off" || value == "disabled" || value == "0") {
+ params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED;
+ } else if (value == "auto" || value == "-1") {
+ params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO;
+ } else {
+                throw std::runtime_error(string_format("error: unknown value for --flash-attn: '%s'\n", value.c_str()));
+ }
}
).set_env("LLAMA_ARG_FLASH_ATTN"));
add_opt(common_arg(
@@ -2555,7 +2563,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
{"--lora"}, "FNAME",
"path to LoRA adapter (can be repeated to use multiple adapters)",
[](common_params & params, const std::string & value) {
- params.lora_adapters.push_back({ std::string(value), 1.0, nullptr });
+ params.lora_adapters.push_back({ std::string(value), 1.0, "", "", nullptr });
}
// we define this arg on both COMMON and EXPORT_LORA, so when showing help message of export-lora, it will be categorized as "example-specific" arg
).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_EXPORT_LORA}));
@@ -2563,7 +2571,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
{"--lora-scaled"}, "FNAME", "SCALE",
"path to LoRA adapter with user defined scaling (can be repeated to use multiple adapters)",
[](common_params & params, const std::string & fname, const std::string & scale) {
- params.lora_adapters.push_back({ fname, std::stof(scale), nullptr });
+ params.lora_adapters.push_back({ fname, std::stof(scale), "", "", nullptr });
}
// we define this arg on both COMMON and EXPORT_LORA, so when showing help message of export-lora, it will be categorized as "example-specific" arg
).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_EXPORT_LORA}));
@@ -2954,13 +2962,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.endpoint_metrics = true;
}
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_ENDPOINT_METRICS"));
- add_opt(common_arg(
- {"--slots"},
- string_format("enable slots monitoring endpoint (default: %s)", params.endpoint_slots ? "enabled" : "disabled"),
- [](common_params & params) {
- params.endpoint_slots = true;
- }
- ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_ENDPOINT_SLOTS"));
add_opt(common_arg(
{"--props"},
string_format("enable changing global properties via POST /props (default: %s)", params.endpoint_props ? "enabled" : "disabled"),
@@ -2968,6 +2969,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.endpoint_props = true;
}
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_ENDPOINT_PROPS"));
+ add_opt(common_arg(
+ {"--slots"},
+ string_format("enable slots monitoring endpoint (default: %s)", params.endpoint_slots ? "enabled" : "disabled"),
+ [](common_params & params) {
+ params.endpoint_slots = true;
+ }
+ ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_ENDPOINT_SLOTS"));
add_opt(common_arg(
{"--no-slots"},
"disables slots monitoring endpoint",
@@ -3459,8 +3467,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.model.hf_repo = "ggml-org/Qwen2.5-Coder-1.5B-Q8_0-GGUF";
params.model.hf_file = "qwen2.5-coder-1.5b-q8_0.gguf";
params.port = 8012;
- params.n_gpu_layers = 99;
- params.flash_attn = true;
params.n_ubatch = 1024;
params.n_batch = 1024;
params.n_ctx = 0;
@@ -3475,8 +3481,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.model.hf_repo = "ggml-org/Qwen2.5-Coder-3B-Q8_0-GGUF";
params.model.hf_file = "qwen2.5-coder-3b-q8_0.gguf";
params.port = 8012;
- params.n_gpu_layers = 99;
- params.flash_attn = true;
params.n_ubatch = 1024;
params.n_batch = 1024;
params.n_ctx = 0;
@@ -3491,8 +3495,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.model.hf_repo = "ggml-org/Qwen2.5-Coder-7B-Q8_0-GGUF";
params.model.hf_file = "qwen2.5-coder-7b-q8_0.gguf";
params.port = 8012;
- params.n_gpu_layers = 99;
- params.flash_attn = true;
params.n_ubatch = 1024;
params.n_batch = 1024;
params.n_ctx = 0;
@@ -3508,10 +3510,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.model.hf_file = "qwen2.5-coder-7b-q8_0.gguf";
params.speculative.model.hf_repo = "ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF";
params.speculative.model.hf_file = "qwen2.5-coder-0.5b-q8_0.gguf";
- params.speculative.n_gpu_layers = 99;
params.port = 8012;
- params.n_gpu_layers = 99;
- params.flash_attn = true;
params.n_ubatch = 1024;
params.n_batch = 1024;
params.n_ctx = 0;
@@ -3527,10 +3526,21 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.model.hf_file = "qwen2.5-coder-14b-q8_0.gguf";
params.speculative.model.hf_repo = "ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF";
params.speculative.model.hf_file = "qwen2.5-coder-0.5b-q8_0.gguf";
- params.speculative.n_gpu_layers = 99;
params.port = 8012;
- params.n_gpu_layers = 99;
- params.flash_attn = true;
+ params.n_ubatch = 1024;
+ params.n_batch = 1024;
+ params.n_ctx = 0;
+ params.n_cache_reuse = 256;
+ }
+ ).set_examples({LLAMA_EXAMPLE_SERVER}));
+
+ add_opt(common_arg(
+ {"--fim-qwen-30b-default"},
+ string_format("use default Qwen 3 Coder 30B A3B Instruct (note: can download weights from the internet)"),
+ [](common_params & params) {
+ params.model.hf_repo = "ggml-org/Qwen3-Coder-30B-A3B-Instruct-Q8_0-GGUF";
+ params.model.hf_file = "qwen3-coder-30b-a3b-instruct-q8_0.gguf";
+ params.port = 8012;
params.n_ubatch = 1024;
params.n_batch = 1024;
params.n_ctx = 0;
diff --git a/common/chat.cpp b/common/chat.cpp
index 111b4a21b368c..955c42852a9be 100644
--- a/common/chat.cpp
+++ b/common/chat.cpp
@@ -622,6 +622,7 @@ const char * common_chat_format_name(common_chat_format format) {
case COMMON_CHAT_FORMAT_COMMAND_R7B: return "Command R7B";
case COMMON_CHAT_FORMAT_GRANITE: return "Granite";
case COMMON_CHAT_FORMAT_GPT_OSS: return "GPT-OSS";
+ case COMMON_CHAT_FORMAT_SEED_OSS: return "Seed-OSS";
default:
throw std::runtime_error("Unknown chat format");
}
@@ -2059,6 +2060,94 @@ static void common_chat_parse_granite(common_chat_msg_parser & builder) {
}
}
+static void common_chat_parse_seed_oss(common_chat_msg_parser & builder) {
+ // Parse thinking tags first - this handles the main reasoning content
+    builder.try_parse_reasoning("<seed:think>", "</seed:think>");
+
+ if (!builder.syntax().parse_tool_calls) {
+ builder.add_content(builder.consume_rest());
+ return;
+ }
+
+    // Parse tool calls - Seed-OSS uses <seed:tool_call> format
+    static const common_regex tool_call_begin_regex("<seed:tool_call>");
+    static const common_regex tool_call_end_regex("</seed:tool_call>");
+    static const common_regex function_regex("<function=([^>]+)>");
+    static const common_regex param_regex("<parameter=([^>]+)>");
+
+ while (auto tool_res = builder.try_find_regex(tool_call_begin_regex)) {
+        builder.consume_spaces(); // Consume whitespace after <seed:tool_call>
+
+ // Look for function call inside tool call, ignore any content before it
+ if (auto func_res = builder.try_find_regex(function_regex, std::string::npos, false)) {
+ auto function_name = builder.str(func_res->groups[1]);
+
+ // Parse Seed-OSS parameters value
+ json args = json::object();
+ // Parse all parameters
+ while (auto param_res = builder.try_find_regex(param_regex, std::string::npos, false)) {
+ // again, ignore noise around parameters
+ auto param_name = builder.str(param_res->groups[1]);
+ builder.move_to(param_res->groups[0].end);
+ builder.consume_spaces(); // Consume whitespace after parameter
+ auto savedPos = builder.pos();
+                if (auto param_parse = builder.try_find_literal("</parameter>")) {
+ auto param = param_parse->prelude;
+ builder.move_to(savedPos);
+ try {
+ if (auto param_res = builder.try_consume_json()) {
+ args[param_name] = param_res->json;
+ } else {
+ args[param_name] = param;
+ }
+ } catch (json::exception &) {
+ args[param_name] = param;
+ }
+ } else {
+ throw common_chat_msg_partial_exception("Incomplete tool parameter");
+ }
+ }
+ // Look for closing function tag
+            auto end_func = builder.try_find_literal("</function>");
+ if (end_func) {
+ builder.move_to(end_func->groups[0].end);
+                builder.consume_spaces(); // Consume whitespace after </function>
+
+ // Add the tool call with parsed arguments, but only if we REALLY got the literal
+ auto eaten_fragment = builder.input().substr(end_func->groups[0].begin, end_func->groups[0].end);
+                auto funlen = std::string("</function>").length();
+                if (eaten_fragment.length() >= funlen && eaten_fragment.substr(0, funlen) == std::string("</function>")) {
+ if (!builder.add_tool_call(function_name, "", args.dump())) {
+ throw common_chat_msg_partial_exception("Incomplete tool call");
+ }
+ } else {
+ throw common_chat_msg_partial_exception("Incomplete tool call");
+ }
+ } else {
+ throw common_chat_msg_partial_exception("Incomplete tool call");
+ }
+ // Look for closing tool call tag
+ if (auto end_tool = builder.try_find_regex(tool_call_end_regex, std::string::npos, false)) {
+ builder.move_to(end_tool->groups[0].end);
+ builder.consume_spaces(); // Consume trailing whitespace after tool call
+ } else {
+ throw common_chat_msg_partial_exception("Incomplete tool call");
+ }
+ } else {
+ // No function found - don't consume content here, let it be handled at the end
+ break;
+ }
+ }
+
+ // Consume any remaining whitespace after all tool call processing
+ builder.consume_spaces();
+ auto remaining = builder.consume_rest();
+ // If there's any non-whitespace content remaining, add it as content
+ if (!string_strip(remaining).empty()) {
+ builder.add_content(remaining);
+ }
+}
+
static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
data.prompt = apply(tmpl, inputs);
@@ -2075,8 +2164,62 @@ static common_chat_params common_chat_params_init_without_tools(const common_cha
return data;
}
+static common_chat_params common_chat_params_init_seed_oss(
+ const common_chat_template & tmpl,
+ templates_params & params,
+ const common_chat_templates_inputs & inputs)
+{
+ common_chat_params data;
+ data.prompt = apply(tmpl, params);
+ data.format = COMMON_CHAT_FORMAT_SEED_OSS;
+    if (string_ends_with(data.prompt, "<seed:think>")) {
+        if (!inputs.enable_thinking) {
+            data.prompt += "</seed:think>";
+ } else {
+ data.thinking_forced_open = true;
+ }
+ }
+
+ if (params.tools.is_array() && !params.tools.empty()) {
+ data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
+ data.grammar = build_grammar([&](const common_grammar_builder & builder) {
+            std::vector<std::string> tool_rules;
+ foreach_function(params.tools, [&](const json & tool) {
+ const auto & function = tool.at("function");
+ std::string name = function.at("name");
+ auto parameters = function.at("parameters");
+ builder.resolve_refs(parameters);
+
+ // Create rule for Seed-OSS function call format
+ std::string param_rules;
+ if (parameters.contains("properties")) {
+ for (const auto & [key, value] : parameters.at("properties").items()) {
+                        param_rules += "\"<parameter=" + key + ">\"" + builder.add_schema(name + "-arg-" + key, value) +
+                                       "\"</parameter>\"";
+ }
+ }
+
+ tool_rules.push_back(builder.add_rule(name + "-call",
+                "\"<seed:tool_call>\" space \"<function=" + name + ">\" space " +
+                param_rules +
+                " \"</function>\" space \"</seed:tool_call>\""));
+ });
+
+            data.grammar_triggers.push_back({ COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<seed:tool_call>" });
+
+ data.preserved_tokens = {
+                "<seed:think>", "</seed:think>", "<seed:tool_call>", "</seed:tool_call>",
+                "<function=", "</function>",
+ };
+
+ builder.add_rule("root", string_join(tool_rules, " | "));
+ });
+ }
+ return data;
+}
+
static common_chat_params common_chat_templates_apply_jinja(
- const struct common_chat_templates * tmpls,
+ const struct common_chat_templates * tmpls,
const struct common_chat_templates_inputs & inputs)
{
templates_params params;
@@ -2145,6 +2288,11 @@ static common_chat_params common_chat_templates_apply_jinja(
return common_chat_params_init_gpt_oss(tmpl, params);
}
+ // Seed-OSS
+    if (src.find("<seed:think>") != std::string::npos) {
+ return common_chat_params_init_seed_oss(tmpl, params, inputs);
+ }
+
// Use generic handler when mixing tools + JSON schema.
// TODO: support that mix in handlers below.
if ((params.tools.is_array() && params.json_schema.is_object())) {
@@ -2303,6 +2451,9 @@ static void common_chat_parse(common_chat_msg_parser & builder) {
case COMMON_CHAT_FORMAT_GPT_OSS:
common_chat_parse_gpt_oss(builder);
break;
+ case COMMON_CHAT_FORMAT_SEED_OSS:
+ common_chat_parse_seed_oss(builder);
+ break;
default:
throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format));
}
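For reference, a minimal sketch of the tag layout that `common_chat_parse_seed_oss` above is written against; the function and parameter names here are made up for illustration and are not part of this patch:

```cpp
// Illustrative only: the shape of a Seed-OSS response handled by the parser above.
// "get_weather" and "location" are hypothetical names used for this sketch.
static const char * seed_oss_example =
    "<seed:think>model reasoning goes here</seed:think>"
    "<seed:tool_call>"
    "<function=get_weather>"
    "<parameter=location>\"Paris\"</parameter>"
    "</function>"
    "</seed:tool_call>";
```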
diff --git a/common/chat.h b/common/chat.h
index d1e480c918794..b09ff3b126a2b 100644
--- a/common/chat.h
+++ b/common/chat.h
@@ -111,6 +111,7 @@ enum common_chat_format {
COMMON_CHAT_FORMAT_COMMAND_R7B,
COMMON_CHAT_FORMAT_GRANITE,
COMMON_CHAT_FORMAT_GPT_OSS,
+ COMMON_CHAT_FORMAT_SEED_OSS,
COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
};
diff --git a/common/common.cpp b/common/common.cpp
index fdce1dcdec19b..0c92d4d57ddbf 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -901,7 +901,8 @@ struct common_init_result common_init_from_params(common_params & params) {
llama_model * model = llama_model_load_from_file(params.model.path.c_str(), mparams);
if (model == NULL) {
- LOG_ERR("%s: failed to load model '%s'\n", __func__, params.model.path.c_str());
+ LOG_ERR("%s: failed to load model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n",
+ __func__, params.model.path.c_str());
return iparams;
}
@@ -911,7 +912,8 @@ struct common_init_result common_init_from_params(common_params & params) {
llama_context * lctx = llama_init_from_model(model, cparams);
if (lctx == NULL) {
- LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.path.c_str());
+ LOG_ERR("%s: failed to create context with model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n",
+ __func__, params.model.path.c_str());
llama_model_free(model);
return iparams;
}
@@ -988,7 +990,12 @@ struct common_init_result common_init_from_params(common_params & params) {
return iparams;
}
+ char buf[1024];
la.ptr = lora.get();
+ llama_adapter_meta_val_str(la.ptr, "adapter.lora.task_name", buf, sizeof(buf));
+ la.task_name = buf;
+ llama_adapter_meta_val_str(la.ptr, "adapter.lora.prompt_prefix", buf, sizeof(buf));
+ la.prompt_prefix = buf;
iparams.lora.emplace_back(std::move(lora)); // copy to list of loaded adapters
}
@@ -1152,10 +1159,10 @@ struct llama_context_params common_context_params_to_llama(const common_params &
cparams.yarn_orig_ctx = params.yarn_orig_ctx;
cparams.pooling_type = params.pooling_type;
cparams.attention_type = params.attention_type;
+ cparams.flash_attn_type = params.flash_attn_type;
cparams.cb_eval = params.cb_eval;
cparams.cb_eval_user_data = params.cb_eval_user_data;
cparams.offload_kqv = !params.no_kv_offload;
- cparams.flash_attn = params.flash_attn;
cparams.no_perf = params.no_perf;
cparams.op_offload = !params.no_op_offload;
cparams.swa_full = params.swa_full;
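For downstream users of the C API, a minimal sketch of the migration implied by the new tri-state setting, assuming a model has already been loaded and leaving error handling out:

```cpp
// Previously: cparams.flash_attn = true/false (boolean).
// Now: an enum, where AUTO lets llama.cpp decide per backend/model.
llama_context_params cparams = llama_context_default_params();
cparams.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO; // or _ENABLED / _DISABLED
llama_context * ctx = llama_init_from_model(model, cparams);
```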
diff --git a/common/common.h b/common/common.h
index 390dda5e531be..85b3b879d4536 100644
--- a/common/common.h
+++ b/common/common.h
@@ -34,6 +34,9 @@ struct common_adapter_lora_info {
std::string path;
float scale;
+ std::string task_name;
+ std::string prompt_prefix;
+
struct llama_adapter_lora * ptr;
};
@@ -309,6 +312,7 @@ struct common_params {
enum llama_rope_scaling_type rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED;
enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings
enum llama_attention_type attention_type = LLAMA_ATTENTION_TYPE_UNSPECIFIED; // attention type for embeddings
+ enum llama_flash_attn_type flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO; // whether to use Flash Attention
struct common_params_sampling sampling;
struct common_params_speculative speculative;
@@ -372,7 +376,6 @@ struct common_params {
bool multiline_input = false; // reverse the usage of `\`
bool simple_io = false; // improves compatibility with subprocesses and limited consoles
bool cont_batching = true; // insert new sequences for decoding on-the-fly
- bool flash_attn = false; // flash attention
bool no_perf = false; // disable performance metrics
bool ctx_shift = false; // context shift on infinite text generation
bool swa_full = false; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
@@ -441,7 +444,7 @@ struct common_params {
// "advanced" endpoints are disabled by default for better security
bool webui = true;
- bool endpoint_slots = false;
+ bool endpoint_slots = true;
bool endpoint_props = false; // only control POST requests, not GET
bool endpoint_metrics = false;
diff --git a/common/sampling.cpp b/common/sampling.cpp
index 9c04d35fd00a2..c710ee173c0ed 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -426,8 +426,29 @@ uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
// helpers
-llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl) {
- return &gsmpl->cur_p;
+llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl, bool do_sort) {
+ auto * res = &gsmpl->cur_p;
+
+ if (do_sort && !res->sorted) {
+ // remember the selected token before sorting
+ const llama_token id = res->data[res->selected].id;
+
+ std::sort(res->data, res->data + res->size, [](const llama_token_data & a, const llama_token_data & b) {
+ return a.p > b.p;
+ });
+
+ // restore the selected token after sorting
+ for (size_t i = 0; i < res->size; ++i) {
+ if (res->data[i].id == id) {
+ res->selected = i;
+ break;
+ }
+ }
+
+ res->sorted = true;
+ }
+
+ return res;
}
llama_token common_sampler_last(const struct common_sampler * gsmpl) {
diff --git a/common/sampling.h b/common/sampling.h
index 2064421db4e80..e198eecda3810 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -86,7 +86,9 @@ uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl);
// helpers
// access the internal list of current candidate tokens
-llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl);
+// if do_sort == true, the candidates are guaranteed to be sorted afterwards (in descending order of probability)
+// the .sorted flag of the result indicates whether the returned candidates are sorted
+llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl, bool do_sort);
// get the last accepted token
llama_token common_sampler_last(const struct common_sampler * gsmpl);
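A short caller-side sketch of the new `do_sort` parameter, assuming `gsmpl` is an existing `common_sampler` on which `common_sampler_sample()` has already been called:

```cpp
// Request the candidate list in sorted order (descending probability).
llama_token_data_array * cur_p = common_sampler_get_candidates(gsmpl, /*do_sort=*/true);

// cur_p->selected still points at the token that was actually sampled,
// while cur_p->data[0] is the most probable candidate.
for (size_t i = 0; i < cur_p->size && i < 3; ++i) {
    LOG_DBG("candidate %zu: id=%d p=%.3f\n", i, cur_p->data[i].id, cur_p->data[i].p);
}
```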
diff --git a/common/speculative.cpp b/common/speculative.cpp
index 262b2c23e720f..3e83b0964c855 100644
--- a/common/speculative.cpp
+++ b/common/speculative.cpp
@@ -317,7 +317,7 @@ llama_tokens common_speculative_gen_draft(
common_sampler_sample(smpl, ctx_dft, 0, true);
- const auto * cur_p = common_sampler_get_candidates(smpl);
+ const auto * cur_p = common_sampler_get_candidates(smpl, true);
for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 35fadbc83ea1b..908435d7e5f00 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -72,6 +72,7 @@ class ModelBase:
endianess: gguf.GGUFEndian
use_temp_file: bool
lazy: bool
+ dry_run: bool
part_names: list[str]
is_safetensors: bool
hparams: dict[str, Any]
@@ -111,6 +112,7 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path,
self.endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE
self.use_temp_file = use_temp_file
self.lazy = not eager or (remote_hf_model_id is not None)
+ self.dry_run = dry_run
self.remote_hf_model_id = remote_hf_model_id
if remote_hf_model_id is not None:
self.is_safetensors = True
@@ -300,10 +302,6 @@ def prepare_tensors(self):
# data = data_torch.squeeze().numpy()
data = data_torch.numpy()
- # if data ends up empty, it means data_torch was a scalar tensor -> restore
- if len(data.shape) == 0:
- data = data_torch.numpy()
-
n_dims = len(data.shape)
data_qtype: gguf.GGMLQuantizationType | bool = self.tensor_force_quant(name, new_name, bid, n_dims)
@@ -1216,6 +1214,55 @@ def _try_set_pooling_type(self) -> None:
raise NotImplementedError("Only MEAN, CLS, and LAST pooling types supported")
self.gguf_writer.add_pooling_type(pooling_type)
+ def _set_vocab_interns1(self):
+ tokens: list[str] = []
+ toktypes: list[int] = []
+
+ from transformers import AutoTokenizer
+ tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True)
+ vocab = getattr(tokenizer, 'vocab', tokenizer.get_vocab())
+ vocab_size = self.hparams.get("vocab_size", len(vocab))
+ assert max(vocab.values()) < vocab_size
+
+ tokpre = self.get_vocab_base_pre(tokenizer)
+
+ reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in vocab.items()}
+ added_vocab = tokenizer.get_added_vocab()
+
+ added_tokens_decoder = tokenizer.added_tokens_decoder
+
+ for i in range(vocab_size):
+ if i not in reverse_vocab:
+ tokens.append(f"[PAD{i}]")
+ toktypes.append(gguf.TokenType.UNUSED)
+ else:
+ token: str = reverse_vocab[i]
+ if token in added_vocab:
+ # The tokenizer in llama.cpp assumes the CONTROL and USER_DEFINED tokens are pre-normalized.
+ # To avoid unexpected issues - we make sure to normalize non-normalized tokens
+ if not added_tokens_decoder[i].normalized:
+ previous_token = token
+ token = tokenizer.decode(tokenizer.encode(token, add_special_tokens=False))
+ if previous_token != token:
+ logger.info(f"{repr(previous_token)} is encoded and decoded back to {repr(token)} using AutoTokenizer")
+
+ if added_tokens_decoder[i].special or self.does_token_look_special(token):
+ toktypes.append(gguf.TokenType.CONTROL)
+ else:
+ toktypes.append(gguf.TokenType.USER_DEFINED)
+ else:
+ toktypes.append(gguf.TokenType.NORMAL)
+ tokens.append(token)
+
+ self.gguf_writer.add_tokenizer_model("gpt2")
+ self.gguf_writer.add_tokenizer_pre(tokpre)
+ self.gguf_writer.add_token_list(tokens)
+ self.gguf_writer.add_token_types(toktypes)
+
+ special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
+ special_vocab._set_special_token("bos", 151643)
+ special_vocab.add_to_gguf(self.gguf_writer)
+
class MmprojModel(ModelBase):
model_type = ModelType.MMPROJ
@@ -2932,7 +2979,8 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
if "language_model." in name:
name = name.replace("language_model.", "") # for InternVL
if name.startswith("mlp") or name.startswith("multi_modal_projector") \
- or name.startswith("vision_model") or name.startswith("audio_tower"):
+ or name.startswith("vision_model") or name.startswith("audio_tower") \
+ or name.startswith("model.vision_tower") or name.startswith("model.multi_modal_projector"):
# skip vision and audio tensors
return []
yield from super().modify_tensors(data_torch, name, bid)
@@ -3109,7 +3157,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
yield from super().modify_tensors(data_torch, name, bid)
-@ModelBase.register("Ernie4_5_ForCausalLM")
+@ModelBase.register("Ernie4_5_ForCausalLM", "Ernie4_5ForCausalLM")
class Ernie4_5Model(TextModel):
model_arch = gguf.MODEL_ARCH.ERNIE4_5
@@ -3604,6 +3652,19 @@ def prepare_tensors(self):
class Qwen3Model(Qwen2Model):
model_arch = gguf.MODEL_ARCH.QWEN3
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ hparams = ModelBase.load_hparams(self.dir_model, is_mistral_format=False)
+ self.origin_hf_arch = hparams.get('architectures', [None])[0]
+
+ def set_vocab(self):
+ # deal with intern-s1-mini
+ if self.origin_hf_arch == 'InternS1ForConditionalGeneration':
+ self._set_vocab_interns1()
+ return
+
+ super().set_vocab()
+
@ModelBase.register("Qwen3MoeForCausalLM")
class Qwen3MoeModel(Qwen2MoeModel):
@@ -3620,73 +3681,7 @@ def set_vocab(self):
self._set_vocab_interns1()
return
- try:
- self._set_vocab_sentencepiece()
- except FileNotFoundError:
- self._set_vocab_gpt2()
-
- def _set_vocab_interns1(self):
- tokens: list[str] = []
- toktypes: list[int] = []
-
- from transformers import AutoTokenizer
- tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True)
- vocab = getattr(tokenizer, 'vocab', tokenizer.get_vocab())
- vocab_size = self.hparams.get("vocab_size", len(vocab))
- assert max(vocab.values()) < vocab_size
-
- tokpre = self.get_vocab_base_pre(tokenizer)
-
- reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in vocab.items()}
- added_vocab = tokenizer.get_added_vocab()
-
- added_tokens_decoder = tokenizer.added_tokens_decoder
-
- for i in range(vocab_size):
- if i not in reverse_vocab:
- tokens.append(f"[PAD{i}]")
- toktypes.append(gguf.TokenType.UNUSED)
- else:
- token: str = reverse_vocab[i]
- if token in added_vocab:
- # The tokenizer in llama.cpp assumes the CONTROL and USER_DEFINED tokens are pre-normalized.
- # To avoid unexpected issues - we make sure to normalize non-normalized tokens
- if not added_tokens_decoder[i].normalized:
- previous_token = token
- token = tokenizer.decode(tokenizer.encode(token, add_special_tokens=False))
- if previous_token != token:
- logger.info(f"{repr(previous_token)} is encoded and decoded back to {repr(token)} using AutoTokenizer")
-
- if added_tokens_decoder[i].special or self.does_token_look_special(token):
- toktypes.append(gguf.TokenType.CONTROL)
- else:
- toktypes.append(gguf.TokenType.USER_DEFINED)
- else:
- toktypes.append(gguf.TokenType.NORMAL)
- tokens.append(token)
-
- self.gguf_writer.add_tokenizer_model("gpt2")
- self.gguf_writer.add_tokenizer_pre(tokpre)
- self.gguf_writer.add_token_list(tokens)
- self.gguf_writer.add_token_types(toktypes)
-
- special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
- special_tokens_map_file = self.dir_model / 'special_tokens_map.json'
- additional_special_tokens = []
- if special_tokens_map_file.is_file():
- with open(special_tokens_map_file, encoding = 'utf-8') as f:
- additional_special_tokens = json.load(f).get('additional_special_tokens', [])
- tokenizer_cfg_file = self.dir_model / 'special_tokens_map.json'
- if tokenizer_cfg_file.is_file():
- with open(tokenizer_cfg_file, encoding = 'utf-8') as f:
- added_tokens_decoder = json.load(f).get('added_tokens_decoder', {})
- token2ids_map = {data['content'] : int(token) for token, data in added_tokens_decoder.items() if data['special']}
- for token in additional_special_tokens:
- if token in token2ids_map:
- special_vocab._set_special_token(token, token2ids_map[token])
- special_vocab._set_special_token('eos', 151645)
- special_vocab._set_special_token("bos", 151643)
- special_vocab.add_to_gguf(self.gguf_writer)
+ super().set_vocab()
@ModelBase.register("GPT2LMHeadModel")
@@ -4874,11 +4869,35 @@ def modify_tensors(self, data_torch, name, bid):
@ModelBase.register("XLMRobertaModel", "XLMRobertaForSequenceClassification")
class XLMRobertaModel(BertModel):
model_arch = gguf.MODEL_ARCH.BERT
+ _lora_files = {}
+ _lora_names = []
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
+ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, **kwargs: Any):
+ hparams = kwargs.pop("hparams", None)
+ if hparams is None:
+ hparams = ModelBase.load_hparams(dir_model, False)
+
+ if lora_names := hparams.get("lora_adaptations"):
+ self._lora_names = lora_names
+ self.model_arch = gguf.MODEL_ARCH.JINA_BERT_V3
+
+ super().__init__(dir_model, ftype, fname_out, hparams=hparams, **kwargs)
self._xlmroberta_tokenizer_init()
+ def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
+ if self._lora_names:
+ for name in self._lora_names:
+ fname = self.add_prefix_to_filename(self.fname_out, f"lora-{name}-")
+ self._lora_files[name] = gguf.GGUFWriter(fname, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file, dry_run=self.dry_run)
+
+ return super().generate_extra_tensors()
+
+ def set_type(self):
+ for lora_writer in self._lora_files.values():
+ lora_writer.add_type(gguf.GGUFType.ADAPTER)
+ lora_writer.add_string(gguf.Keys.Adapter.TYPE, "lora")
+ super().set_type()
+
def set_vocab(self):
self._xlmroberta_set_vocab()
@@ -4888,13 +4907,62 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
if name.startswith("roberta."):
name = name[8:]
+ # jina-embeddings-v3
+ if ".parametrizations." in name:
+ name = name.replace(".parametrizations.", ".")
+ if name.endswith(".original"):
+ name = name[:-9]
+
# position embeddings start at pad_token_id + 1, so just chop down the weight tensor
if name == "embeddings.position_embeddings.weight":
if self._position_offset is not None:
data_torch = data_torch[self._position_offset:,:]
+ if name.endswith(".0.lora_A") or name.endswith(".0.lora_B"):
+ if name.startswith("pooler.dense"):
+ return []
+
+ num_loras = data_torch.size(0)
+ assert num_loras == len(self._lora_names)
+
+ # Split out each LoRA in their own GGUF
+ for i, lora_writer in enumerate(self._lora_files.values()):
+ new_name = self.map_tensor_name(name[:-9]) + name[-7:].lower()
+ data = data_torch[i, :, :]
+ # Transpose/flip token_embd/types into correct shape
+ if new_name == "token_embd.weight.lora_b":
+ data = data.T
+ elif new_name.startswith("token_types.weight."):
+ new_name = new_name[:-1] + ("a" if new_name[-1:] == "b" else "b")
+ lora_writer.add_tensor(new_name, data.float().numpy(), raw_dtype=gguf.GGMLQuantizationType.F32)
+
+ return []
+
return super().modify_tensors(data_torch, name, bid)
+ def set_gguf_parameters(self):
+ super().set_gguf_parameters()
+
+ # jina-embeddings-v3
+ if rotary_emb_base := self.hparams.get("rotary_emb_base"):
+ self.gguf_writer.add_rope_freq_base(rotary_emb_base)
+ lora_alpha = self.hparams.get("lora_alpha")
+ if lora_prompt_prefixes := self.hparams.get("task_instructions"):
+ assert self._lora_files and all(lora_name in lora_prompt_prefixes for lora_name in self._lora_files.keys())
+ for lora_name, lora_writer in self._lora_files.items():
+ lora_writer.add_float32(gguf.Keys.Adapter.LORA_ALPHA, lora_alpha if lora_alpha is not None else 1.0)
+ lora_writer.add_string(gguf.Keys.Adapter.LORA_TASK_NAME, lora_name)
+ if lora_prompt_prefixes:
+ lora_writer.add_string(gguf.Keys.Adapter.LORA_PROMPT_PREFIX, lora_prompt_prefixes[lora_name])
+
+ def write(self):
+ super().write()
+ for lora_writer in self._lora_files.values():
+ lora_writer.write_header_to_file()
+ lora_writer.write_kv_data_to_file()
+ lora_writer.write_tensors_to_file(progress=True)
+ lora_writer.close()
+
@ModelBase.register("GemmaForCausalLM")
class GemmaModel(TextModel):
@@ -6257,9 +6325,11 @@ def prepare_tensors(self):
raise ValueError(f"Unprocessed experts: {experts}")
-@ModelBase.register("DeepseekV2ForCausalLM")
-@ModelBase.register("DeepseekV3ForCausalLM")
-@ModelBase.register("KimiVLForConditionalGeneration")
+@ModelBase.register(
+ "DeepseekV2ForCausalLM",
+ "DeepseekV3ForCausalLM",
+ "KimiVLForConditionalGeneration",
+)
class DeepseekV2Model(TextModel):
model_arch = gguf.MODEL_ARCH.DEEPSEEK2
@@ -7472,9 +7542,13 @@ def __init__(self, *args, **kwargs):
]
# n_group and d_inner are used during reshape_tensors for mamba2
- self.d_model = self.find_hparam(["hidden_size", "d_model"])
- self.n_group = self.find_hparam(["n_groups"])
- self.d_inner = self.find_hparam(["expand"]) * self.d_model
+        # NOTE: Explicitly include the hparam prefix for d_model to
+        #       disambiguate with top-level head_dim
+        # NOTE 2: If needed for future models, this can be isolated in a method
+        #         to separate the prefix setting and the keys used
+ self.d_model = self.find_hparam([f"{self.hparam_prefixes[0]}_head_dim", "hidden_size", "d_model"])
+ self.n_group = self.find_hparam(["n_groups", "num_groups"])
+ self.d_inner = self.find_hparam(["expand", "num_heads"]) * self.d_model
def get_attn_layers(self):
# Explicit list of layer type names
@@ -7535,12 +7609,12 @@ def set_gguf_parameters(self):
## Mamba mixer params ##
self.gguf_writer.add_ssm_conv_kernel(self.find_hparam(["conv_kernel", "d_conv"]))
- self.gguf_writer.add_ssm_state_size(self.find_hparam(["state_size", "d_state"]))
+ self.gguf_writer.add_ssm_state_size(self.find_hparam(["state_size", "d_state", "state_dim", "ssm_state_size"]))
self.gguf_writer.add_ssm_group_count(self.n_group)
self.gguf_writer.add_ssm_inner_size(self.d_inner)
# NOTE: The mamba_dt_rank is _not_ the right field for how this is used
# in llama.cpp
- self.gguf_writer.add_ssm_time_step_rank(self.find_hparam(["n_heads"]))
+ self.gguf_writer.add_ssm_time_step_rank(self.find_hparam(["n_heads", "num_heads"]))
## Attention params ##
head_count_kv = self.find_hparam(["num_key_value_heads", "n_head_kv"])
@@ -7567,6 +7641,55 @@ def set_vocab(self):
Mamba2Model.set_vocab(self)
+@ModelBase.register("NemotronHForCausalLM")
+class NemotronHModel(GraniteHybridModel):
+ """Hybrid mamba2/attention model from NVIDIA"""
+ model_arch = gguf.MODEL_ARCH.NEMOTRON_H
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ # Save the top-level head_dim for later
+ self.head_dim = self.hparams.get("head_dim", self.hparams.get("attention_head_dim"))
+ assert self.head_dim is not None, "Could not find the attention head dim in config"
+
+ # Don't use expand to calculate d_inner
+ self.d_inner = self.find_hparam(["num_heads"]) * self.d_model
+
+ # Update the ssm / attn / mlp layers
+ # M: Mamba2, *: Attention, -: MLP
+ hybrid_override_pattern = self.hparams["hybrid_override_pattern"]
+ self._ssm_layers = [i for i, val in enumerate(hybrid_override_pattern) if val == "M"]
+ self._mlp_layers = [i for i, val in enumerate(hybrid_override_pattern) if val == "-"]
+
+ def get_attn_layers(self):
+ hybrid_override_pattern = self.hparams["hybrid_override_pattern"]
+ assert len(hybrid_override_pattern) == self.block_count, "Mismatch between hybrid override and num_hidden_layers!"
+ return [i for i, val in enumerate(hybrid_override_pattern) if val == "*"]
+
+ def set_gguf_parameters(self):
+ super().set_gguf_parameters()
+
+ self.gguf_writer.add_key_length(self.head_dim)
+ self.gguf_writer.add_value_length(self.head_dim)
+
+ # Set feed_forward_length
+        # NOTE: This will trigger an override warning. This is preferable to
+ # duplicating all the parent logic
+ n_ff = self.find_hparam(["intermediate_size", "n_inner", "hidden_dim"])
+ self.gguf_writer.add_feed_forward_length([
+ n_ff if i in self._mlp_layers else 0 for i in range(self.block_count)
+ ])
+
+ def set_vocab(self):
+ super().set_vocab()
+
+ # The tokenizer _does_ add a BOS token (via post_processor type
+ # TemplateProcessing) but does not set add_bos_token to true in the
+ # config, so we need to explicitly override it here.
+ self.gguf_writer.add_add_bos_token(True)
+
+
@ModelBase.register("BailingMoeForCausalLM")
class BailingMoeModel(TextModel):
model_arch = gguf.MODEL_ARCH.BAILINGMOE
@@ -8510,6 +8633,43 @@ def map_tensor_name(self, name: str, try_suffixes: Sequence[str] = (".weight", "
return "mm.2.weight"
return super().map_tensor_name(name, try_suffixes)
+
+@ModelBase.register("KimiVLForConditionalGeneration")
+class KimiVLModel(MmprojModel):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ assert self.hparams_vision is not None
+ self.hparams_vision["image_size"] = 64 * 14 # for compatibility
+
+ def set_gguf_parameters(self):
+ super().set_gguf_parameters()
+ self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.KIMIVL)
+ self.gguf_writer.add_vision_use_gelu(True)
+ self.gguf_writer.add_vision_projector_scale_factor(2)
+ # eps is the same as pytorch's default value
+ assert self.hparams_vision is not None
+ self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams_vision.get("layer_norm_eps", 1e-5))
+
+ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+ del bid # unused
+ is_vision_tensor = "vision_tower" in name or "multi_modal_projector" in name
+
+ if is_vision_tensor:
+ if "pos_emb.weight" in name:
+ data_torch = data_torch.view(data_torch.shape[0] * data_torch.shape[1], data_torch.shape[2])
+ elif "wqkv" in name:
+ split_dim = 0 if "weight" in name else -1
+ wq, wk, wv = data_torch.chunk(3, dim=split_dim)
+ return [
+ (self.map_tensor_name(name.replace("wqkv", "wq")), wq),
+ (self.map_tensor_name(name.replace("wqkv", "wk")), wk),
+ (self.map_tensor_name(name.replace("wqkv", "wv")), wv)
+ ]
+
+ return [(self.map_tensor_name(name), data_torch)]
+
+ return [] # skip other tensors
+
###### CONVERSION LOGIC ######
diff --git a/docs/backend/CANN.md b/docs/backend/CANN.md
index 325e09bd380c0..2d8866e3bd495 100755
--- a/docs/backend/CANN.md
+++ b/docs/backend/CANN.md
@@ -314,3 +314,7 @@ Controls automatic cleanup of the memory pool. This option is only effective whe
Converting the matmul weight format from ND to NZ can significantly improve performance on the 310I DUO NPU.
+### GGML_CANN_DISABLE_ACL_GRAPH
+
+When this variable is set, ACL graph execution is disabled and operators are executed in an op-by-op (eager) mode.
+This mode is mainly intended for debugging or for cases where the overhead of graph construction and execution is not desirable.
diff --git a/docs/build.md b/docs/build.md
index b35a898ba9ba4..dcbcce7549ad2 100644
--- a/docs/build.md
+++ b/docs/build.md
@@ -59,8 +59,6 @@ cmake --build build --config Release
cmake --preset arm64-windows-llvm-release -D GGML_OPENMP=OFF
cmake --build build-arm64-windows-llvm-release
```
- Building for arm64 can also be done with the MSVC compiler with the build-arm64-windows-MSVC preset, or the standard CMake build instructions. However, note that the MSVC compiler does not support inline ARM assembly code, used e.g. for the accelerated Q4_0_N_M CPU kernels.
-
For building with ninja generator and clang compiler as default:
-set path:set LIB=C:\Program Files (x86)\Windows Kits\10\Lib\10.0.22621.0\um\x64;C:\Program Files\Microsoft Visual Studio\2022\Community\VC\Tools\MSVC\14.41.34120\lib\x64\uwp;C:\Program Files (x86)\Windows Kits\10\Lib\10.0.22621.0\ucrt\x64
```bash
diff --git a/docs/function-calling.md b/docs/function-calling.md
index 37eacaf3100c1..67cf785c7a95d 100644
--- a/docs/function-calling.md
+++ b/docs/function-calling.md
@@ -21,6 +21,8 @@ Function calling is supported for all models (see https://github.com/ggml-org/ll
- Use `--chat-template-file` to override the template when appropriate (see examples below)
- Generic support may consume more tokens and be less efficient than a model's native format.
+- Multiple/parallel tool calling is supported on some models but disabled by default, enable it by passing `"parallel_tool_calls": true` in the completion endpoint payload.
+
Show some common templates and which format handler they use
diff --git a/docs/multimodal/minicpmv4.0.md b/docs/multimodal/minicpmv4.0.md
index 65887d96019d3..d04cb338cecb5 100644
--- a/docs/multimodal/minicpmv4.0.md
+++ b/docs/multimodal/minicpmv4.0.md
@@ -6,7 +6,7 @@ Download [MiniCPM-V-4](https://huggingface.co/openbmb/MiniCPM-V-4) PyTorch model
### Build llama.cpp
-Readme modification time: 20250206
+Readme modification time: 20250731
If there are differences in usage, please refer to the official build [documentation](https://github.com/ggerganov/llama.cpp/blob/master/docs/build.md)
diff --git a/docs/multimodal/minicpmv4.5.md b/docs/multimodal/minicpmv4.5.md
new file mode 100644
index 0000000000000..8fea5e611da90
--- /dev/null
+++ b/docs/multimodal/minicpmv4.5.md
@@ -0,0 +1,47 @@
+## MiniCPM-V 4.5
+
+### Prepare models and code
+
+Download the [MiniCPM-V-4_5](https://huggingface.co/openbmb/MiniCPM-V-4_5) PyTorch model from Hugging Face into a "MiniCPM-V-4_5" folder.
+
+
+### Build llama.cpp
+Readme modification time: 20250826
+
+If there are differences in usage, please refer to the official build [documentation](https://github.com/ggerganov/llama.cpp/blob/master/docs/build.md)
+
+Clone llama.cpp:
+```bash
+git clone https://github.com/ggerganov/llama.cpp
+cd llama.cpp
+```
+
+Build llama.cpp using `CMake`:
+```bash
+cmake -B build
+cmake --build build --config Release
+```
+
+
+### Usage of MiniCPM-V 4.5
+
+Convert the PyTorch model to GGUF files (you can also download the pre-converted [gguf](https://huggingface.co/openbmb/MiniCPM-V-4_5-gguf) files)
+
+```bash
+python ./tools/mtmd/legacy-models/minicpmv-surgery.py -m ../MiniCPM-V-4_5
+python ./tools/mtmd/legacy-models/minicpmv-convert-image-encoder-to-gguf.py -m ../MiniCPM-V-4_5 --minicpmv-projector ../MiniCPM-V-4_5/minicpmv.projector --output-dir ../MiniCPM-V-4_5/ --minicpmv_version 6
+python ./convert_hf_to_gguf.py ../MiniCPM-V-4_5/model
+
+# quantize int4 version
+./build/bin/llama-quantize ../MiniCPM-V-4_5/model/ggml-model-f16.gguf ../MiniCPM-V-4_5/model/ggml-model-Q4_K_M.gguf Q4_K_M
+```
+
+
+Inference on Linux or Mac
+```bash
+# run in single-turn mode
+./build/bin/llama-mtmd-cli -m ../MiniCPM-V-4_5/model/ggml-model-f16.gguf --mmproj ../MiniCPM-V-4_5/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?"
+
+# run in conversation mode
+./build/bin/llama-mtmd-cli -m ../MiniCPM-V-4_5/model/ggml-model-Q4_K_M.gguf --mmproj ../MiniCPM-V-4_5/mmproj-model-f16.gguf
+```
diff --git a/examples/diffusion/diffusion-cli.cpp b/examples/diffusion/diffusion-cli.cpp
index 8431dcea8fe2a..abf7fb357326b 100644
--- a/examples/diffusion/diffusion-cli.cpp
+++ b/examples/diffusion/diffusion-cli.cpp
@@ -564,7 +564,7 @@ int main(int argc, char ** argv) {
ctx_params.n_ctx = params.n_ctx;
ctx_params.n_batch = params.n_batch;
ctx_params.n_ubatch = params.n_ubatch;
- ctx_params.flash_attn = params.flash_attn;
+ ctx_params.flash_attn_type = params.flash_attn_type;
ctx_params.no_perf = params.no_perf;
ctx_params.type_k = params.cache_type_k;
ctx_params.type_v = params.cache_type_v;
diff --git a/examples/eval-callback/eval-callback.cpp b/examples/eval-callback/eval-callback.cpp
index 61eefc7248a7e..d4ef751fbb63a 100644
--- a/examples/eval-callback/eval-callback.cpp
+++ b/examples/eval-callback/eval-callback.cpp
@@ -28,9 +28,40 @@ static std::string ggml_ne_string(const ggml_tensor * t) {
return str;
}
+static float ggml_get_float_value(uint8_t * data, ggml_type type, const size_t * nb, size_t i0, size_t i1, size_t i2, size_t i3) {
+ size_t i = i3 * nb[3] + i2 * nb[2] + i1 * nb[1] + i0 * nb[0];
+ float v;
+ if (type == GGML_TYPE_F16) {
+ v = ggml_fp16_to_fp32(*(ggml_fp16_t *) &data[i]);
+ } else if (type == GGML_TYPE_F32) {
+ v = *(float *) &data[i];
+ } else if (type == GGML_TYPE_I64) {
+ v = (float) *(int64_t *) &data[i];
+ } else if (type == GGML_TYPE_I32) {
+ v = (float) *(int32_t *) &data[i];
+ } else if (type == GGML_TYPE_I16) {
+ v = (float) *(int16_t *) &data[i];
+ } else if (type == GGML_TYPE_I8) {
+ v = (float) *(int8_t *) &data[i];
+ } else {
+ GGML_ABORT("fatal error");
+ }
+ return v;
+}
+
static void ggml_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne, const size_t * nb, int64_t n) {
GGML_ASSERT(n > 0);
float sum = 0;
+ for (int64_t i3 = 0; i3 < ne[3]; i3++) {
+ for (int64_t i2 = 0; i2 < ne[2]; i2++) {
+ for (int64_t i1 = 0; i1 < ne[1]; i1++) {
+ for (int64_t i0 = 0; i0 < ne[0]; i0++) {
+ const float v = ggml_get_float_value(data, type, nb, i0, i1, i2, i3);
+ sum += v;
+ }
+ }
+ }
+ }
for (int64_t i3 = 0; i3 < ne[3]; i3++) {
LOG(" [\n");
for (int64_t i2 = 0; i2 < ne[2]; i2++) {
@@ -50,25 +81,8 @@ static void ggml_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne
LOG("..., ");
i0 = ne[0] - n;
}
- size_t i = i3 * nb[3] + i2 * nb[2] + i1 * nb[1] + i0 * nb[0];
- float v;
- if (type == GGML_TYPE_F16) {
- v = ggml_fp16_to_fp32(*(ggml_fp16_t *) &data[i]);
- } else if (type == GGML_TYPE_F32) {
- v = *(float *) &data[i];
- } else if (type == GGML_TYPE_I64) {
- v = (float) *(int64_t *) &data[i];
- } else if (type == GGML_TYPE_I32) {
- v = (float) *(int32_t *) &data[i];
- } else if (type == GGML_TYPE_I16) {
- v = (float) *(int16_t *) &data[i];
- } else if (type == GGML_TYPE_I8) {
- v = (float) *(int8_t *) &data[i];
- } else {
- GGML_ABORT("fatal error");
- }
+ const float v = ggml_get_float_value(data, type, nb, i0, i1, i2, i3);
LOG("%12.4f", v);
- sum += v;
if (i0 < ne[0] - 1) LOG(", ");
}
LOG("],\n");
diff --git a/examples/model-conversion/Makefile b/examples/model-conversion/Makefile
index 27d95b4f2bf5e..ac7a4147297c5 100644
--- a/examples/model-conversion/Makefile
+++ b/examples/model-conversion/Makefile
@@ -1,4 +1,5 @@
-# Validation functions
+MAKEFLAGS += --no-print-directory
+
define validate_model_path
@if [ -z "$(MODEL_PATH)" ]; then \
echo "Error: MODEL_PATH must be provided either as:"; \
@@ -17,6 +18,13 @@ define validate_embedding_model_path
fi
endef
+define quantize_model
+ @CONVERTED_MODEL="$(1)" QUANTIZED_TYPE="$(QUANTIZED_TYPE)" \
+ TOKEN_EMBD_TYPE="$(TOKEN_EMBD_TYPE)" OUTPUT_TYPE="$(OUTPUT_TYPE)" \
+ ./scripts/utils/quantize.sh "$(1)" "$(QUANTIZED_TYPE)" "$(TOKEN_EMBD_TYPE)" "$(OUTPUT_TYPE)"
+ @echo "Export the quantized model path to $(2) variable in your environment"
+endef
+
###
### Casual Model targets/recipes
###
@@ -29,6 +37,20 @@ causal-convert-model:
METADATA_OVERRIDE="$(METADATA_OVERRIDE)" \
./scripts/causal/convert-model.sh
+causal-convert-mm-model-bf16: OUTTYPE=bf16
+causal-convert-mm-model-bf16: MM_OUTTYPE=f16
+causal-convert-mm-model-bf16: causal-convert-mm-model
+
+causal-convert-mm-model:
+ $(call validate_model_path,causal-convert-mm-model)
+ @MODEL_NAME="$(MODEL_NAME)" OUTTYPE="$(OUTTYPE)" MODEL_PATH="$(MODEL_PATH)" \
+ METADATA_OVERRIDE="$(METADATA_OVERRIDE)" \
+ ./scripts/causal/convert-model.sh
+
+ @MODEL_NAME="$(MODEL_NAME)" OUTTYPE="$(MM_OUTTYPE)" MODEL_PATH="$(MODEL_PATH)" \
+ METADATA_OVERRIDE="$(METADATA_OVERRIDE)" \
+ ./scripts/causal/convert-model.sh --mmproj
+
causal-run-original-model:
$(call validate_model_path,causal-run-original-model)
@MODEL_PATH="$(MODEL_PATH)" ./scripts/causal/run-org-model.py
@@ -41,7 +63,7 @@ causal-verify-logits: causal-run-original-model causal-run-converted-model
@MODEL_PATH="$(MODEL_PATH)" ./scripts/utils/check-nmse.py -m ${MODEL_PATH}
causal-run-original-embeddings:
- @./scripts/causal/run-casual-gen-embeddings-org.sh
+ @./scripts/causal/run-casual-gen-embeddings-org.py
causal-run-converted-embeddings:
@./scripts/causal/run-converted-model-embeddings-logits.sh
@@ -67,9 +89,15 @@ causal-quantize-Q8_0: causal-quantize-model
causal-quantize-Q4_0: QUANTIZED_TYPE = Q4_0
causal-quantize-Q4_0: causal-quantize-model
+# For Quantization Aware Trained (QAT) models in Q4_0 we explicitly set the
+# token embedding and output types to Q8_0 instead of the default Q6_K.
+causal-quantize-qat-Q4_0: QUANTIZED_TYPE = Q4_0
+causal-quantize-qat-Q4_0: TOKEN_EMBD_TYPE = Q8_0
+causal-quantize-qat-Q4_0: OUTPUT_TYPE = Q8_0
+causal-quantize-qat-Q4_0: causal-quantize-model
+
causal-quantize-model:
- @CONVERTED_MODEL="$(CONVERTED_MODEL)" QUANTIZED_TYPE="$(QUANTIZED_TYPE)" ./scripts/utils/quantize.sh ${CONVERTED_MODEL} ${QUANTIZED_TYPE}
- @echo "Export the quantized model path to QUANTIZED_MODEL variable in your environment"
+ $(call quantize_model,$(CONVERTED_MODEL),QUANTIZED_MODEL)
causal-run-quantized-model:
@QUANTIZED_MODEL="$(QUANTIZED_MODEL)" ./scripts/causal/run-converted-model.sh ${QUANTIZED_MODEL}
@@ -117,9 +145,15 @@ embedding-quantize-Q8_0: embedding-quantize-model
embedding-quantize-Q4_0: QUANTIZED_TYPE = Q4_0
embedding-quantize-Q4_0: embedding-quantize-model
+# For Quantization Aware Trained (QAT) models in Q4_0 we explicitly set the
+# token embedding and output types to Q8_0 instead of the default Q6_K.
+embedding-quantize-qat-Q4_0: QUANTIZED_TYPE = Q4_0
+embedding-quantize-qat-Q4_0: TOKEN_EMBD_TYPE = Q8_0
+embedding-quantize-qat-Q4_0: OUTPUT_TYPE = Q8_0
+embedding-quantize-qat-Q4_0: embedding-quantize-model
+
embedding-quantize-model:
- @./scripts/utils/quantize.sh ${CONVERTED_EMBEDDING_MODEL} ${QUANTIZED_TYPE}
- @echo "Export the quantized model path to QUANTIZED_EMBEDDING_MODEL variable in your environment"
+ $(call quantize_model,$(CONVERTED_EMBEDDING_MODEL),QUANTIZED_EMBEDDING_MODEL)
embedding-run-quantized-model:
@./scripts/embedding/run-converted-model.sh ${QUANTIZED_EMBEDDING_MODEL}
@@ -144,6 +178,15 @@ perplexity-run:
hf-create-model:
@./scripts/utils/hf-create-model.py -m "${MODEL_NAME}" -ns "${NAMESPACE}" -b "${ORIGINAL_BASE_MODEL}"
+hf-create-model-dry-run:
+ @./scripts/utils/hf-create-model.py -m "${MODEL_NAME}" -ns "${NAMESPACE}" -b "${ORIGINAL_BASE_MODEL}" -d
+
+hf-create-model-embedding:
+ @./scripts/utils/hf-create-model.py -m "${MODEL_NAME}" -ns "${NAMESPACE}" -b "${ORIGINAL_BASE_MODEL}" -e
+
+hf-create-model-embedding-dry-run:
+ @./scripts/utils/hf-create-model.py -m "${MODEL_NAME}" -ns "${NAMESPACE}" -b "${ORIGINAL_BASE_MODEL}" -e -d
+
hf-create-model-private:
@./scripts/utils/hf-create-model.py -m "${MODEL_NAME}" -ns "${NAMESPACE}" -b "${ORIGINAL_BASE_MODEL}" -p
diff --git a/examples/model-conversion/README.md b/examples/model-conversion/README.md
index c924a6be3cd26..5e5992d964bbd 100644
--- a/examples/model-conversion/README.md
+++ b/examples/model-conversion/README.md
@@ -137,6 +137,18 @@ Then the quantized model can be run using the following command:
(venv) $ make causal-run-quantized-model
```
+### Quantizing QAT (Quantization Aware Training) models
+When quantizing to `Q4_0`, the default data type for the token embedding weights
+will be `Q6_K`. For models that are going to be uploaded to ggml-org it is
+recommended to use `Q8_0` instead for the embeddings and output tensors.
+The reason is that although `Q6_K` is smaller in size, it requires more compute
+to unpack, which can hurt performance during output generation when the entire
+embedding matrix must be dequantized to compute vocabulary logits. `Q8_0`
+provides practically full quality with better computational efficiency.
+```console
+(venv) $ make causal-quantize-qat-Q4_0
+```
+
## Embedding Language Model Conversion
@@ -238,6 +250,18 @@ Then the quantized model can be run using the following command:
(venv) $ make embedding-run-quantized-model
```
+### Quantizing QAT (Quantization Aware Training) models
+When quantizing to `Q4_0`, the default data type for the token embedding weights
+will be `Q6_K`. For models that are going to be uploaded to ggml-org it is
+recommended to use `Q8_0` instead for the embeddings and output tensors.
+The reason is that although `Q6_K` is smaller in size, it requires more compute
+to unpack, which can hurt performance during output generation when the entire
+embedding matrix must be dequantized to compute vocabulary logits. `Q8_0`
+provides practically full quality with better computational efficiency.
+```console
+(venv) $ make embedding-quantize-qat-Q4_0
+```
+
## Perplexity Evaluation
### Simple perplexity evaluation
@@ -285,13 +309,21 @@ For the following targets a `HF_TOKEN` environment variable is required.
This will create a new model repository on Hugging Face with the specified
model name.
```console
-(venv) $ make hf-create-model MODEL_NAME='TestModel' NAMESPACE="danbev"
+(venv) $ make hf-create-model MODEL_NAME='TestModel' NAMESPACE="danbev" ORIGINAL_BASE_MODEL="some-base-model"
Repository ID: danbev/TestModel-GGUF
Repository created: https://huggingface.co/danbev/TestModel-GGUF
```
Note that we append a `-GGUF` suffix to the model name to ensure a consistent
naming convention for GGUF models.
+An embedding model can be created using the following command:
+```console
+(venv) $ make hf-create-model-embedding MODEL_NAME='TestEmbeddingModel' NAMESPACE="danbev" ORIGINAL_BASE_MODEL="some-base-model"
+```
+The only difference is that the generated model card for an embedding model differs
+with regard to the `llama-server` command and how to access/call the embedding
+endpoint.
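+
+Both targets also have `-dry-run` variants which print the repository ID and the rendered
+model card template without creating the repository, for example:
+```console
+(venv) $ make hf-create-model-embedding-dry-run MODEL_NAME='TestEmbeddingModel' NAMESPACE="danbev" ORIGINAL_BASE_MODEL="some-base-model"
+```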
+
### Upload a GGUF model to model repository
The following target uploads a model to an existing Hugging Face model repository.
```console
diff --git a/examples/model-conversion/logits.cpp b/examples/model-conversion/logits.cpp
index 2cac6a3b3eaae..ddc5e9005f9e0 100644
--- a/examples/model-conversion/logits.cpp
+++ b/examples/model-conversion/logits.cpp
@@ -112,6 +112,7 @@ int main(int argc, char ** argv) {
ctx_params.no_perf = false;
if (embedding_mode) {
ctx_params.embeddings = true;
+ ctx_params.pooling_type = LLAMA_POOLING_TYPE_NONE;
ctx_params.n_ubatch = ctx_params.n_batch;
}
diff --git a/examples/model-conversion/scripts/causal/compare-embeddings-logits.sh b/examples/model-conversion/scripts/causal/compare-embeddings-logits.sh
index 287158f638717..c53c89d48acc6 100755
--- a/examples/model-conversion/scripts/causal/compare-embeddings-logits.sh
+++ b/examples/model-conversion/scripts/causal/compare-embeddings-logits.sh
@@ -1,4 +1,4 @@
-#/bin/bash
+#!/usr/bin/env bash
set -e
diff --git a/examples/model-conversion/scripts/causal/convert-model.sh b/examples/model-conversion/scripts/causal/convert-model.sh
index 56b21f9baa673..32ffe132e7853 100755
--- a/examples/model-conversion/scripts/causal/convert-model.sh
+++ b/examples/model-conversion/scripts/causal/convert-model.sh
@@ -1,4 +1,20 @@
-#!/bin/bash
+#!/usr/bin/env bash
+
+set -e
+
+# Parse command line arguments
+MMPROJ=""
+while [[ $# -gt 0 ]]; do
+ case $1 in
+ --mmproj)
+ MMPROJ="--mmproj"
+ shift
+ ;;
+ *)
+ shift
+ ;;
+ esac
+done
MODEL_NAME="${MODEL_NAME:-$(basename "$MODEL_PATH")}"
OUTPUT_DIR="${OUTPUT_DIR:-../../models}"
@@ -11,12 +27,20 @@ echo "Model name: ${MODEL_NAME}"
echo "Data type: ${TYPE}"
echo "Converted model path:: ${CONVERTED_MODEL}"
echo "Metadata override: ${METADATA_OVERRIDE}"
-python ../../convert_hf_to_gguf.py --verbose \
- ${MODEL_PATH} \
- --outfile ${CONVERTED_MODEL} \
- --outtype ${TYPE} \
- --metadata "${METADATA_OVERRIDE}"
+
+CMD_ARGS=("python" "../../convert_hf_to_gguf.py" "--verbose")
+CMD_ARGS+=("${MODEL_PATH}")
+CMD_ARGS+=("--outfile" "${CONVERTED_MODEL}")
+CMD_ARGS+=("--outtype" "${TYPE}")
+[[ -n "$METADATA_OVERRIDE" ]] && CMD_ARGS+=("--metadata" "${METADATA_OVERRIDE}")
+[[ -n "$MMPROJ" ]] && CMD_ARGS+=("${MMPROJ}")
+
+"${CMD_ARGS[@]}"
echo ""
echo "The environment variable CONVERTED_MODEL can be set to this path using:"
echo "export CONVERTED_MODEL=$(realpath ${CONVERTED_MODEL})"
+if [[ -n "$MMPROJ" ]]; then
+ mmproj_file="${OUTPUT_DIR}/mmproj-$(basename "${CONVERTED_MODEL}")"
+ echo "The mmproj model was created in $(realpath "$mmproj_file")"
+fi
diff --git a/examples/model-conversion/scripts/readme.md.template b/examples/model-conversion/scripts/causal/modelcard.template
similarity index 100%
rename from examples/model-conversion/scripts/readme.md.template
rename to examples/model-conversion/scripts/causal/modelcard.template
diff --git a/examples/model-conversion/scripts/causal/run-casual-gen-embeddings-org.sh b/examples/model-conversion/scripts/causal/run-casual-gen-embeddings-org.py
similarity index 100%
rename from examples/model-conversion/scripts/causal/run-casual-gen-embeddings-org.sh
rename to examples/model-conversion/scripts/causal/run-casual-gen-embeddings-org.py
diff --git a/examples/model-conversion/scripts/causal/run-converted-model-embeddings-logits.sh b/examples/model-conversion/scripts/causal/run-converted-model-embeddings-logits.sh
index 64709f1798cb6..fa16a02c6599c 100755
--- a/examples/model-conversion/scripts/causal/run-converted-model-embeddings-logits.sh
+++ b/examples/model-conversion/scripts/causal/run-converted-model-embeddings-logits.sh
@@ -1,4 +1,4 @@
-#!/bin/bash
+#!/usr/bin/env bash
set -e
diff --git a/examples/model-conversion/scripts/causal/run-converted-model.sh b/examples/model-conversion/scripts/causal/run-converted-model.sh
index e2762729e76a3..f5f567d4ffa12 100755
--- a/examples/model-conversion/scripts/causal/run-converted-model.sh
+++ b/examples/model-conversion/scripts/causal/run-converted-model.sh
@@ -1,4 +1,4 @@
-#!/bin/bash
+#!/usr/bin/env bash
set -e
diff --git a/examples/model-conversion/scripts/embedding/compare-embeddings-logits.sh b/examples/model-conversion/scripts/embedding/compare-embeddings-logits.sh
index 35b5d71984c6e..1401dcb43ee92 100755
--- a/examples/model-conversion/scripts/embedding/compare-embeddings-logits.sh
+++ b/examples/model-conversion/scripts/embedding/compare-embeddings-logits.sh
@@ -1,4 +1,4 @@
-#/bin/bash
+#!/usr/bin/env bash
set -e
diff --git a/examples/model-conversion/scripts/embedding/convert-model.sh b/examples/model-conversion/scripts/embedding/convert-model.sh
index 0609e35357bd4..0929e42413e67 100755
--- a/examples/model-conversion/scripts/embedding/convert-model.sh
+++ b/examples/model-conversion/scripts/embedding/convert-model.sh
@@ -1,4 +1,4 @@
-#!/bin/bash
+#!/usr/bin/env bash
set -e
diff --git a/examples/model-conversion/scripts/embedding/modelcard.template b/examples/model-conversion/scripts/embedding/modelcard.template
new file mode 100644
index 0000000000000..75c580524f667
--- /dev/null
+++ b/examples/model-conversion/scripts/embedding/modelcard.template
@@ -0,0 +1,48 @@
+---
+base_model:
+- {base_model}
+---
+# {model_name} GGUF
+
+Recommended way to run this model:
+
+```sh
+llama-server -hf {namespace}/{model_name}-GGUF
+```
+
+Then the endpoint can be accessed at http://localhost:8080/embedding, for
+example using `curl`:
+```console
+curl --request POST \
+ --url http://localhost:8080/embedding \
+ --header "Content-Type: application/json" \
+ --data '{{"input": "Hello embeddings"}}' \
+ --silent
+```
+
+Alternatively, the `llama-embedding` command line tool can be used:
+```sh
+llama-embedding -hf {namespace}/{model_name}-GGUF --verbose-prompt -p "Hello embeddings"
+```
+
+#### embd_normalize
+When a model uses pooling, or the pooling method is specified using `--pooling`,
+the normalization can be controlled by the `embd_normalize` parameter.
+
+The default value is `2`, which means that the embeddings are normalized using
+the Euclidean norm (L2). Other options are:
+* -1 No normalization
+* 0 Max absolute
+* 1 Taxicab
+* 2 Euclidean/L2
+* \>2 P-Norm
+
+This can be passed in the request body to `llama-server`, for example:
+```sh
+ --data '{{"input": "Hello embeddings", "embd_normalize": -1}}' \
+```
+
+For `llama-embedding`, the same can be done by passing `--embd-normalize`, for example:
+```sh
+llama-embedding -hf {namespace}/{model_name}-GGUF --embd-normalize -1 -p "Hello embeddings"
+```
diff --git a/examples/model-conversion/scripts/embedding/run-converted-model.sh b/examples/model-conversion/scripts/embedding/run-converted-model.sh
index 5896090411a79..24b28106275df 100755
--- a/examples/model-conversion/scripts/embedding/run-converted-model.sh
+++ b/examples/model-conversion/scripts/embedding/run-converted-model.sh
@@ -1,4 +1,4 @@
-#!/bin/bash
+#!/usr/bin/env bash
set -e
diff --git a/examples/model-conversion/scripts/utils/create-collection-add-model.sh b/examples/model-conversion/scripts/utils/create-collection-add-model.sh
index 4809da6cb6fac..485001b5feecc 100644
--- a/examples/model-conversion/scripts/utils/create-collection-add-model.sh
+++ b/examples/model-conversion/scripts/utils/create-collection-add-model.sh
@@ -1,4 +1,6 @@
+#!/usr/bin/env bash
+
COLLECTION_SLUG=$(python ./create_collection.py --return-slug)
echo "Created collection: $COLLECTION_SLUG"
diff --git a/examples/model-conversion/scripts/utils/curl-embedding-server.sh b/examples/model-conversion/scripts/utils/curl-embedding-server.sh
new file mode 100755
index 0000000000000..7ed69e1ea50f5
--- /dev/null
+++ b/examples/model-conversion/scripts/utils/curl-embedding-server.sh
@@ -0,0 +1,6 @@
+#!/usr/bin/env bash
+curl --request POST \
+ --url http://localhost:8080/embedding \
+ --header "Content-Type: application/json" \
+ --data '{"input": "Hello world today"}' \
+ --silent
diff --git a/examples/model-conversion/scripts/utils/hf-create-model.py b/examples/model-conversion/scripts/utils/hf-create-model.py
index 09bb8511ef13e..ea99bd886f4d1 100755
--- a/examples/model-conversion/scripts/utils/hf-create-model.py
+++ b/examples/model-conversion/scripts/utils/hf-create-model.py
@@ -26,21 +26,31 @@ def load_template_and_substitute(template_path, **kwargs):
parser.add_argument('--org-base-model', '-b', help='Original Base model name', default="")
parser.add_argument('--no-card', action='store_true', help='Skip creating model card')
parser.add_argument('--private', '-p', action='store_true', help='Create private model')
+parser.add_argument('--embedding', '-e', action='store_true', help='Use embedding model card template')
+parser.add_argument('--dry-run', '-d', action='store_true', help='Print repository info and template without creating repository')
args = parser.parse_args()
repo_id = f"{args.namespace}/{args.model_name}-GGUF"
print("Repository ID: ", repo_id)
-repo_url = api.create_repo(
- repo_id=repo_id,
- repo_type="model",
- private=args.private,
- exist_ok=False
-)
+repo_url = None
+if not args.dry_run:
+ repo_url = api.create_repo(
+ repo_id=repo_id,
+ repo_type="model",
+ private=args.private,
+ exist_ok=False
+ )
if not args.no_card:
- template_path = "scripts/readme.md.template"
+ if args.embedding:
+ template_path = "scripts/embedding/modelcard.template"
+ else:
+ template_path = "scripts/causal/modelcard.template"
+
+ print("Template path: ", template_path)
+
model_card_content = load_template_and_substitute(
template_path,
model_name=args.model_name,
@@ -48,16 +58,21 @@ def load_template_and_substitute(template_path, **kwargs):
base_model=args.org_base_model,
)
- if model_card_content:
- api.upload_file(
- path_or_fileobj=model_card_content.encode('utf-8'),
- path_in_repo="README.md",
- repo_id=repo_id
- )
- print("Model card created successfully.")
+ if args.dry_run:
+ print("\nTemplate Content:\n")
+ print(model_card_content)
else:
- print("Failed to create model card.")
+ if model_card_content:
+ api.upload_file(
+ path_or_fileobj=model_card_content.encode('utf-8'),
+ path_in_repo="README.md",
+ repo_id=repo_id
+ )
+ print("Model card created successfully.")
+ else:
+ print("Failed to create model card.")
-print(f"Repository created: {repo_url}")
+if not args.dry_run and repo_url:
+ print(f"Repository created: {repo_url}")
diff --git a/examples/model-conversion/scripts/utils/inspect-converted-model.sh b/examples/model-conversion/scripts/utils/inspect-converted-model.sh
index e5b9324542e73..32d84826fa089 100755
--- a/examples/model-conversion/scripts/utils/inspect-converted-model.sh
+++ b/examples/model-conversion/scripts/utils/inspect-converted-model.sh
@@ -1,4 +1,4 @@
-#!/bin/bash
+#!/usr/bin/env bash
# First try command line argument, then environment variable, then file
CONVERTED_MODEL="${1:-"$CONVERTED_MODEL"}"
diff --git a/examples/model-conversion/scripts/utils/perplexity-gen.sh b/examples/model-conversion/scripts/utils/perplexity-gen.sh
index 3db0b3fd27ac3..4885acbae24d1 100755
--- a/examples/model-conversion/scripts/utils/perplexity-gen.sh
+++ b/examples/model-conversion/scripts/utils/perplexity-gen.sh
@@ -1,4 +1,4 @@
-#!/bin/bash
+#!/usr/bin/env bash
set -e
diff --git a/examples/model-conversion/scripts/utils/perplexity-run-simple.sh b/examples/model-conversion/scripts/utils/perplexity-run-simple.sh
index 69b3438f598b3..a2545436a5c52 100755
--- a/examples/model-conversion/scripts/utils/perplexity-run-simple.sh
+++ b/examples/model-conversion/scripts/utils/perplexity-run-simple.sh
@@ -1,4 +1,4 @@
-#!/bin/bash
+#!/usr/bin/env bash
set -e
diff --git a/examples/model-conversion/scripts/utils/perplexity-run.sh b/examples/model-conversion/scripts/utils/perplexity-run.sh
index 3bce7c8472153..68b38e662859b 100755
--- a/examples/model-conversion/scripts/utils/perplexity-run.sh
+++ b/examples/model-conversion/scripts/utils/perplexity-run.sh
@@ -1,4 +1,4 @@
-#!/bin/bash
+#!/usr/bin/env bash
set -e
diff --git a/examples/model-conversion/scripts/utils/quantize.sh b/examples/model-conversion/scripts/utils/quantize.sh
index bcb87757543cb..c25c5c21f3c3e 100755
--- a/examples/model-conversion/scripts/utils/quantize.sh
+++ b/examples/model-conversion/scripts/utils/quantize.sh
@@ -1,9 +1,11 @@
-#!/bin/bash
+#!/usr/bin/env bash
set -e
CONVERTED_MODEL="${1:-"$CONVERTED_MODEL"}"
QUANTIZED_TYPE="${2:-"$QUANTIZED_TYPE"}"
+TOKEN_EMBD_TYPE="${3:-"${TOKEN_EMBD_TYPE}"}"
+OUTPUT_TYPE="${4:-"${OUTPUT_TYPE}"}"
QUANTIZED_MODEL=$CONVERTED_MODEL
# Final check if we have a model path
@@ -14,6 +16,11 @@ if [ -z "$CONVERTED_MODEL" ]; then
exit 1
fi
+if [ -z "$QUANTIZED_TYPE" ]; then
+ echo "Error: QUANTIZED_TYPE is required" >&2
+ exit 1
+fi
+
echo $CONVERTED_MODEL
# Process the quantized model filename
@@ -26,9 +33,16 @@ else
exit 1
fi
-
cmake --build ../../build --target llama-quantize -j8
-../../build/bin/llama-quantize $CONVERTED_MODEL $QUANTIZED_MODEL $QUANTIZED_TYPE
+echo "Token embedding type: ${TOKEN_EMBD_TYPE}"
+echo "Output tensor type: ${OUTPUT_TYPE}"
+
+CMD_ARGS=("../../build/bin/llama-quantize")
+[[ -n "$TOKEN_EMBD_TYPE" ]] && CMD_ARGS+=("--token-embedding-type" "$TOKEN_EMBD_TYPE")
+[[ -n "$OUTPUT_TYPE" ]] && CMD_ARGS+=("--output-tensor-type" "$OUTPUT_TYPE")
+CMD_ARGS+=("$CONVERTED_MODEL" "$QUANTIZED_MODEL" "$QUANTIZED_TYPE")
+
+"${CMD_ARGS[@]}"
echo "Quantized model saved to: $QUANTIZED_MODEL"
diff --git a/examples/model-conversion/scripts/utils/run-embedding-server.sh b/examples/model-conversion/scripts/utils/run-embedding-server.sh
index 828fc47069772..d30b765964b0c 100755
--- a/examples/model-conversion/scripts/utils/run-embedding-server.sh
+++ b/examples/model-conversion/scripts/utils/run-embedding-server.sh
@@ -1,4 +1,4 @@
-#!/bin/bash
+#!/usr/bin/env bash
set -e
#
diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp
index 8449406a6d27a..5f5ac5eb64d38 100644
--- a/examples/speculative/speculative.cpp
+++ b/examples/speculative/speculative.cpp
@@ -244,7 +244,7 @@ int main(int argc, char ** argv) {
// stochastic verification
common_sampler_sample(smpl, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft], true);
- auto & dist_tgt = *common_sampler_get_candidates(smpl);
+ auto & dist_tgt = *common_sampler_get_candidates(smpl, true);
float p_tgt = 0.0f;
float p_dft = 0.0f;
@@ -493,7 +493,7 @@ int main(int argc, char ** argv) {
common_sampler_sample(drafts[s].smpl, ctx_dft, drafts[s].i_batch_dft, true);
- const auto * cur_p = common_sampler_get_candidates(drafts[s].smpl);
+ const auto * cur_p = common_sampler_get_candidates(drafts[s].smpl, true);
for (int k = 0; k < std::min(n_seq_dft + 3, (int) cur_p->size); ++k) {
LOG_DBG(" - draft candidate %3d for seq %3d, pos %3d: %6d (%8.3f) '%s'\n",
diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt
index 2ead001e2c610..9ef88c6fd0a85 100644
--- a/ggml/CMakeLists.txt
+++ b/ggml/CMakeLists.txt
@@ -1,5 +1,5 @@
cmake_minimum_required(VERSION 3.14) # for add_link_options and implicit target directories.
-project("ggml" C CXX)
+project("ggml" C CXX ASM)
include(CheckIncludeFileCXX)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
@@ -129,7 +129,9 @@ endif()
option(GGML_LASX "ggml: enable lasx" ON)
option(GGML_LSX "ggml: enable lsx" ON)
option(GGML_RVV "ggml: enable rvv" ON)
-option(GGML_RV_ZFH "ggml: enable riscv zfh" OFF)
+option(GGML_RV_ZFH "ggml: enable riscv zfh" ON)
+option(GGML_RV_ZVFH "ggml: enable riscv zvfh" ON)
+option(GGML_RV_ZICBOP "ggml: enable riscv zicbop" ON)
option(GGML_XTHEADVECTOR "ggml: enable xtheadvector" OFF)
option(GGML_VXE "ggml: enable vxe" ON)
option(GGML_NNPA "ggml: enable nnpa" OFF) # temp disabled by default, see: https://github.com/ggml-org/llama.cpp/issues/14877
diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h
index a2977ea2e56d9..4f246f6ccd629 100644
--- a/ggml/include/ggml-backend.h
+++ b/ggml/include/ggml-backend.h
@@ -307,6 +307,9 @@ extern "C" {
GGML_API void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend);
GGML_API ggml_backend_t ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node);
+ // Split graph without allocating it
+ GGML_API void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph);
+
// Allocate and compute graph on the backend scheduler
GGML_API bool ggml_backend_sched_alloc_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph); // returns success
GGML_API enum ggml_status ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cgraph * graph);
diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp
index e34feccc98a5e..f615ab4beecb1 100644
--- a/ggml/src/ggml-backend.cpp
+++ b/ggml/src/ggml-backend.cpp
@@ -31,6 +31,7 @@
// backend buffer type
const char * ggml_backend_buft_name(ggml_backend_buffer_type_t buft) {
+ GGML_ASSERT(buft);
return buft->iface.get_name(buft);
}
@@ -40,14 +41,17 @@ ggml_backend_buffer_t ggml_backend_buft_alloc_buffer(ggml_backend_buffer_type_t
return ggml_backend_buffer_init(buft, {}, NULL, 0);
}
+ GGML_ASSERT(buft);
return buft->iface.alloc_buffer(buft, size);
}
size_t ggml_backend_buft_get_alignment(ggml_backend_buffer_type_t buft) {
+ GGML_ASSERT(buft);
return buft->iface.get_alignment(buft);
}
size_t ggml_backend_buft_get_max_size(ggml_backend_buffer_type_t buft) {
+ GGML_ASSERT(buft);
// get_max_size is optional, defaults to SIZE_MAX
if (buft->iface.get_max_size) {
return buft->iface.get_max_size(buft);
@@ -56,6 +60,7 @@ size_t ggml_backend_buft_get_max_size(ggml_backend_buffer_type_t buft) {
}
size_t ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {
+ GGML_ASSERT(buft);
// get_alloc_size is optional, defaults to ggml_nbytes
if (buft->iface.get_alloc_size) {
size_t size = buft->iface.get_alloc_size(buft, tensor);
@@ -66,6 +71,7 @@ size_t ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type_t buft, const s
}
bool ggml_backend_buft_is_host(ggml_backend_buffer_type_t buft) {
+ GGML_ASSERT(buft);
if (buft->iface.is_host) {
return buft->iface.is_host(buft);
}
@@ -73,6 +79,7 @@ bool ggml_backend_buft_is_host(ggml_backend_buffer_type_t buft) {
}
ggml_backend_dev_t ggml_backend_buft_get_device(ggml_backend_buffer_type_t buft) {
+ GGML_ASSERT(buft);
return buft->device;
}
@@ -110,10 +117,12 @@ void ggml_backend_buffer_free(ggml_backend_buffer_t buffer) {
}
size_t ggml_backend_buffer_get_size(ggml_backend_buffer_t buffer) {
+ GGML_ASSERT(buffer);
return buffer->size;
}
void * ggml_backend_buffer_get_base(ggml_backend_buffer_t buffer) {
+ GGML_ASSERT(buffer);
// get_base is optional if the buffer is zero-sized
if (buffer->size == 0) {
return NULL;
@@ -127,6 +136,7 @@ void * ggml_backend_buffer_get_base(ggml_backend_buffer_t buffer) {
}
enum ggml_status ggml_backend_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
+ GGML_ASSERT(buffer);
// init_tensor is optional
if (buffer->iface.init_tensor) {
return buffer->iface.init_tensor(buffer, tensor);
@@ -135,6 +145,7 @@ enum ggml_status ggml_backend_buffer_init_tensor(ggml_backend_buffer_t buffer, s
}
void ggml_backend_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
+ GGML_ASSERT(buffer);
// clear is optional if the buffer is zero-sized
if (buffer->size == 0) {
return;
@@ -160,6 +171,7 @@ bool ggml_backend_buffer_is_host(ggml_backend_buffer_t buffer) {
}
void ggml_backend_buffer_set_usage(ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage) {
+ GGML_ASSERT(buffer);
buffer->usage = usage;
// FIXME: add a generic callback to the buffer interface
@@ -169,14 +181,17 @@ void ggml_backend_buffer_set_usage(ggml_backend_buffer_t buffer, enum ggml_backe
}
enum ggml_backend_buffer_usage ggml_backend_buffer_get_usage(ggml_backend_buffer_t buffer) {
+ GGML_ASSERT(buffer);
return buffer->usage;
}
ggml_backend_buffer_type_t ggml_backend_buffer_get_type(ggml_backend_buffer_t buffer) {
+ GGML_ASSERT(buffer);
return buffer->buft;
}
void ggml_backend_buffer_reset(ggml_backend_buffer_t buffer) {
+ GGML_ASSERT(buffer);
if (buffer->iface.reset) {
buffer->iface.reset(buffer);
}
@@ -215,6 +230,7 @@ void ggml_backend_free(ggml_backend_t backend) {
}
ggml_backend_buffer_type_t ggml_backend_get_default_buffer_type(ggml_backend_t backend) {
+ GGML_ASSERT(backend);
return ggml_backend_dev_buffer_type(backend->device);
}
@@ -231,6 +247,8 @@ size_t ggml_backend_get_max_size(ggml_backend_t backend) {
}
void ggml_backend_tensor_set_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+ GGML_ASSERT(backend);
+ GGML_ASSERT(tensor);
GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
@@ -242,6 +260,8 @@ void ggml_backend_tensor_set_async(ggml_backend_t backend, struct ggml_tensor *
}
void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+ GGML_ASSERT(backend);
+ GGML_ASSERT(tensor);
GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
@@ -283,6 +303,7 @@ void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, siz
}
void ggml_backend_tensor_memset(struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
+ GGML_ASSERT(tensor);
ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
if (size == 0) {
@@ -298,6 +319,7 @@ void ggml_backend_tensor_memset(struct ggml_tensor * tensor, uint8_t value, size
}
void ggml_backend_synchronize(ggml_backend_t backend) {
+ GGML_ASSERT(backend);
if (backend->iface.synchronize == NULL) {
return;
}
@@ -306,18 +328,21 @@ void ggml_backend_synchronize(ggml_backend_t backend) {
}
ggml_backend_graph_plan_t ggml_backend_graph_plan_create(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
+ GGML_ASSERT(backend);
GGML_ASSERT(backend->iface.graph_plan_create != NULL);
return backend->iface.graph_plan_create(backend, cgraph);
}
void ggml_backend_graph_plan_free(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
+ GGML_ASSERT(backend);
GGML_ASSERT(backend->iface.graph_plan_free != NULL);
backend->iface.graph_plan_free(backend, plan);
}
enum ggml_status ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
+ GGML_ASSERT(backend);
GGML_ASSERT(backend->iface.graph_plan_compute != NULL);
return backend->iface.graph_plan_compute(backend, plan);
@@ -330,22 +355,27 @@ enum ggml_status ggml_backend_graph_compute(ggml_backend_t backend, struct ggml_
}
enum ggml_status ggml_backend_graph_compute_async(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
+ GGML_ASSERT(backend);
return backend->iface.graph_compute(backend, cgraph);
}
bool ggml_backend_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
+ GGML_ASSERT(backend);
return ggml_backend_dev_supports_op(backend->device, op);
}
bool ggml_backend_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
+ GGML_ASSERT(backend);
return ggml_backend_dev_supports_buft(backend->device, buft);
}
bool ggml_backend_offload_op(ggml_backend_t backend, const struct ggml_tensor * op) {
+ GGML_ASSERT(backend);
return ggml_backend_dev_offload_op(backend->device, op);
}
ggml_backend_dev_t ggml_backend_get_device(ggml_backend_t backend) {
+ GGML_ASSERT(backend);
return backend->device;
}
@@ -381,6 +411,7 @@ void ggml_backend_tensor_copy_async(ggml_backend_t backend_src, ggml_backend_t b
return;
}
+ GGML_ASSERT(backend_dst);
if (backend_dst->iface.cpy_tensor_async != NULL) {
if (backend_dst->iface.cpy_tensor_async(backend_src, backend_dst, src, dst)) {
return;
@@ -412,18 +443,21 @@ void ggml_backend_event_free(ggml_backend_event_t event) {
}
void ggml_backend_event_record(ggml_backend_event_t event, ggml_backend_t backend) {
+ GGML_ASSERT(backend);
GGML_ASSERT(backend->iface.event_record != NULL);
backend->iface.event_record(backend, event);
}
void ggml_backend_event_synchronize(ggml_backend_event_t event) {
+ GGML_ASSERT(event);
GGML_ASSERT(event->device->iface.event_synchronize);
event->device->iface.event_synchronize(event->device, event);
}
void ggml_backend_event_wait(ggml_backend_t backend, ggml_backend_event_t event) {
+ GGML_ASSERT(backend);
GGML_ASSERT(backend->iface.event_wait != NULL);
backend->iface.event_wait(backend, event);
@@ -432,18 +466,22 @@ void ggml_backend_event_wait(ggml_backend_t backend, ggml_backend_event_t event)
// Backend device
const char * ggml_backend_dev_name(ggml_backend_dev_t device) {
+ GGML_ASSERT(device);
return device->iface.get_name(device);
}
const char * ggml_backend_dev_description(ggml_backend_dev_t device) {
+ GGML_ASSERT(device);
return device->iface.get_description(device);
}
void ggml_backend_dev_memory(ggml_backend_dev_t device, size_t * free, size_t * total) {
+ GGML_ASSERT(device);
device->iface.get_memory(device, free, total);
}
enum ggml_backend_dev_type ggml_backend_dev_type(ggml_backend_dev_t device) {
+ GGML_ASSERT(device);
return device->iface.get_type(device);
}
@@ -453,18 +491,22 @@ void ggml_backend_dev_get_props(ggml_backend_dev_t device, struct ggml_backend_d
}
ggml_backend_reg_t ggml_backend_dev_backend_reg(ggml_backend_dev_t device) {
+ GGML_ASSERT(device);
return device->reg;
}
ggml_backend_t ggml_backend_dev_init(ggml_backend_dev_t device, const char * params) {
+ GGML_ASSERT(device);
return device->iface.init_backend(device, params);
}
ggml_backend_buffer_type_t ggml_backend_dev_buffer_type(ggml_backend_dev_t device) {
+ GGML_ASSERT(device);
return device->iface.get_buffer_type(device);
}
ggml_backend_buffer_type_t ggml_backend_dev_host_buffer_type(ggml_backend_dev_t device) {
+ GGML_ASSERT(device);
if (device->iface.get_host_buffer_type == NULL) {
return NULL;
}
@@ -473,18 +515,22 @@ ggml_backend_buffer_type_t ggml_backend_dev_host_buffer_type(ggml_backend_dev_t
}
ggml_backend_buffer_t ggml_backend_dev_buffer_from_host_ptr(ggml_backend_dev_t device, void * ptr, size_t size, size_t max_tensor_size) {
+ GGML_ASSERT(device);
return device->iface.buffer_from_host_ptr(device, ptr, size, max_tensor_size);
}
bool ggml_backend_dev_supports_op(ggml_backend_dev_t device, const struct ggml_tensor * op) {
+ GGML_ASSERT(device);
return device->iface.supports_op(device, op);
}
bool ggml_backend_dev_supports_buft(ggml_backend_dev_t device, ggml_backend_buffer_type_t buft) {
+ GGML_ASSERT(device);
return device->iface.supports_buft(device, buft);
}
bool ggml_backend_dev_offload_op(ggml_backend_dev_t device, const struct ggml_tensor * op) {
+ GGML_ASSERT(device);
if (device->iface.offload_op != NULL) {
return device->iface.offload_op(device, op);
}
@@ -495,18 +541,22 @@ bool ggml_backend_dev_offload_op(ggml_backend_dev_t device, const struct ggml_te
// Backend (reg)
const char * ggml_backend_reg_name(ggml_backend_reg_t reg) {
+ GGML_ASSERT(reg);
return reg->iface.get_name(reg);
}
size_t ggml_backend_reg_dev_count(ggml_backend_reg_t reg) {
+ GGML_ASSERT(reg);
return reg->iface.get_device_count(reg);
}
ggml_backend_dev_t ggml_backend_reg_dev_get(ggml_backend_reg_t reg, size_t index) {
+ GGML_ASSERT(reg);
return reg->iface.get_device(reg, index);
}
void * ggml_backend_reg_get_proc_address(ggml_backend_reg_t reg, const char * name) {
+ GGML_ASSERT(reg);
if (!reg->iface.get_proc_address) {
return NULL;
}
@@ -521,6 +571,7 @@ struct ggml_backend_multi_buffer_context {
};
static void ggml_backend_multi_buffer_free_buffer(ggml_backend_buffer_t buffer) {
+ GGML_ASSERT(buffer);
ggml_backend_multi_buffer_context * ctx = (ggml_backend_multi_buffer_context *) buffer->context;
for (size_t i = 0; i < ctx->n_buffers; i++) {
ggml_backend_buffer_free(ctx->buffers[i]);
@@ -531,6 +582,7 @@ static void ggml_backend_multi_buffer_free_buffer(ggml_backend_buffer_t buffer)
}
static void ggml_backend_multi_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
+ GGML_ASSERT(buffer);
ggml_backend_multi_buffer_context * ctx = (ggml_backend_multi_buffer_context *) buffer->context;
for (size_t i = 0; i < ctx->n_buffers; i++) {
ggml_backend_buffer_clear(ctx->buffers[i], value);
@@ -566,10 +618,12 @@ ggml_backend_buffer_t ggml_backend_multi_buffer_alloc_buffer(ggml_backend_buffer
}
bool ggml_backend_buffer_is_multi_buffer(ggml_backend_buffer_t buffer) {
+ GGML_ASSERT(buffer);
return buffer->iface.free_buffer == ggml_backend_multi_buffer_free_buffer;
}
void ggml_backend_multi_buffer_set_usage(ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage) {
+ GGML_ASSERT(buffer);
GGML_ASSERT(ggml_backend_buffer_is_multi_buffer(buffer));
ggml_backend_multi_buffer_context * ctx = (ggml_backend_multi_buffer_context *) buffer->context;
for (size_t i = 0; i < ctx->n_buffers; i++) {
@@ -597,7 +651,7 @@ static bool ggml_is_view_op(enum ggml_op op) {
#endif
#ifndef GGML_SCHED_MAX_SPLIT_INPUTS
-#define GGML_SCHED_MAX_SPLIT_INPUTS GGML_MAX_SRC
+#define GGML_SCHED_MAX_SPLIT_INPUTS 30
#endif
#ifndef GGML_SCHED_MAX_COPIES
@@ -848,7 +902,7 @@ static void ggml_backend_sched_set_if_supported(ggml_backend_sched_t sched, stru
}
// assigns backends to ops and splits the graph into subgraphs that can be computed on the same backend
-static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
+void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
// reset splits
sched->n_splits = 0;
sched->n_graph_inputs = 0;
@@ -1349,6 +1403,7 @@ static bool ggml_backend_sched_alloc_splits(ggml_backend_sched_t sched) {
}
static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t sched) {
+ GGML_ASSERT(sched);
struct ggml_backend_sched_split * splits = sched->splits;
ggml_tensor * prev_ids_tensor = nullptr;
@@ -1617,6 +1672,7 @@ void ggml_backend_sched_free(ggml_backend_sched_t sched) {
}
void ggml_backend_sched_reset(ggml_backend_sched_t sched) {
+ GGML_ASSERT(sched);
// reset state for the next run
if (!sched->is_reset) {
ggml_hash_set_reset(&sched->hash_set);
@@ -1628,8 +1684,11 @@ void ggml_backend_sched_reset(ggml_backend_sched_t sched) {
}
bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph) {
+ GGML_ASSERT(sched);
GGML_ASSERT((int)sched->hash_set.size >= measure_graph->n_nodes + measure_graph->n_leafs);
+ ggml_backend_sched_reset(sched);
+
ggml_backend_sched_synchronize(sched);
ggml_backend_sched_split_graph(sched, measure_graph);
@@ -1644,6 +1703,7 @@ bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph *
}
bool ggml_backend_sched_alloc_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
+ GGML_ASSERT(sched);
GGML_ASSERT((int)sched->hash_set.size >= graph->n_nodes + graph->n_leafs);
GGML_ASSERT(!sched->is_alloc);
@@ -1668,6 +1728,7 @@ enum ggml_status ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, st
}
enum ggml_status ggml_backend_sched_graph_compute_async(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
+ GGML_ASSERT(sched);
if (!sched->is_reset && !sched->is_alloc) {
ggml_backend_sched_reset(sched);
}
@@ -1682,6 +1743,7 @@ enum ggml_status ggml_backend_sched_graph_compute_async(ggml_backend_sched_t sch
}
void ggml_backend_sched_synchronize(ggml_backend_sched_t sched) {
+ GGML_ASSERT(sched);
for (int i = 0; i < sched->n_backends; i++) {
ggml_backend_synchronize(sched->backends[i]);
}
@@ -1694,28 +1756,34 @@ void ggml_backend_sched_synchronize(ggml_backend_sched_t sched) {
}
void ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backend_sched_eval_callback callback, void * user_data) {
+ GGML_ASSERT(sched);
sched->callback_eval = callback;
sched->callback_eval_user_data = user_data;
}
int ggml_backend_sched_get_n_splits(ggml_backend_sched_t sched) {
+ GGML_ASSERT(sched);
return sched->n_splits;
}
int ggml_backend_sched_get_n_copies(ggml_backend_sched_t sched) {
+ GGML_ASSERT(sched);
return sched->n_copies;
}
int ggml_backend_sched_get_n_backends(ggml_backend_sched_t sched) {
+ GGML_ASSERT(sched);
return sched->n_backends;
}
ggml_backend_t ggml_backend_sched_get_backend(ggml_backend_sched_t sched, int i) {
+ GGML_ASSERT(sched);
GGML_ASSERT(i >= 0 && i < sched->n_backends);
return sched->backends[i];
}
size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend) {
+ GGML_ASSERT(sched);
int backend_index = ggml_backend_sched_backend_id(sched, backend);
GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends);
@@ -1723,6 +1791,7 @@ size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backe
}
void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend) {
+ GGML_ASSERT(sched);
int backend_index = ggml_backend_sched_backend_id(sched, backend);
GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends);
tensor_backend_id(node) = backend_index;
@@ -1731,6 +1800,7 @@ void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct gg
}
ggml_backend_t ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node) {
+ GGML_ASSERT(sched);
int backend_index = tensor_backend_id(node);
if (backend_index == -1) {
return NULL;
@@ -1741,6 +1811,7 @@ ggml_backend_t ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched,
// utils
enum ggml_status ggml_backend_view_init(struct ggml_tensor * tensor) {
+ GGML_ASSERT(tensor);
GGML_ASSERT(tensor->buffer == NULL);
GGML_ASSERT(tensor->view_src != NULL);
GGML_ASSERT(tensor->view_src->buffer != NULL);
@@ -1752,6 +1823,7 @@ enum ggml_status ggml_backend_view_init(struct ggml_tensor * tensor) {
}
enum ggml_status ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr) {
+ GGML_ASSERT(tensor);
GGML_ASSERT(tensor->buffer == NULL);
GGML_ASSERT(tensor->data == NULL);
GGML_ASSERT(tensor->view_src == NULL);
@@ -1825,6 +1897,7 @@ static void graph_copy_init_tensor(struct ggml_hash_set * hash_set, struct ggml_
}
struct ggml_backend_graph_copy ggml_backend_graph_copy(ggml_backend_t backend, struct ggml_cgraph * graph) {
+ GGML_ASSERT(graph);
struct ggml_hash_set hash_set = ggml_hash_set_new(graph->visited_hash_set.size);
struct ggml_tensor ** node_copies = (ggml_tensor **) calloc(hash_set.size, sizeof(node_copies[0])); // NOLINT
bool * node_init = (bool *) calloc(hash_set.size, sizeof(node_init[0]));
@@ -1969,6 +2042,7 @@ bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t
// CPU backend - buffer
static void * ggml_backend_cpu_buffer_get_base(ggml_backend_buffer_t buffer) {
+ GGML_ASSERT(buffer);
uintptr_t data = (uintptr_t)buffer->context;
// align the buffer
@@ -1980,28 +2054,33 @@ static void * ggml_backend_cpu_buffer_get_base(ggml_backend_buffer_t buffer) {
}
static void ggml_backend_cpu_buffer_free_buffer(ggml_backend_buffer_t buffer) {
+ GGML_ASSERT(buffer);
ggml_aligned_free(buffer->context, buffer->size);
}
static void ggml_backend_cpu_buffer_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
+ GGML_ASSERT(tensor);
memset((char *)tensor->data + offset, value, size);
GGML_UNUSED(buffer);
}
static void ggml_backend_cpu_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+ GGML_ASSERT(tensor);
memcpy((char *)tensor->data + offset, data, size);
GGML_UNUSED(buffer);
}
static void ggml_backend_cpu_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+ GGML_ASSERT(tensor);
memcpy(data, (const char *)tensor->data + offset, size);
GGML_UNUSED(buffer);
}
static bool ggml_backend_cpu_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst) {
+ GGML_ASSERT(src);
if (ggml_backend_buffer_is_host(src->buffer)) {
memcpy(dst->data, src->data, ggml_nbytes(src));
return true;
@@ -2012,6 +2091,7 @@ static bool ggml_backend_cpu_buffer_cpy_tensor(ggml_backend_buffer_t buffer, con
}
static void ggml_backend_cpu_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
+ GGML_ASSERT(buffer);
memset(buffer->context, value, buffer->size);
}
diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp
index 8f65904b8fe51..80b9a932db988 100755
--- a/ggml/src/ggml-cann/aclnn_ops.cpp
+++ b/ggml/src/ggml-cann/aclnn_ops.cpp
@@ -70,6 +70,8 @@
#include
#include
#include
+#include
+#include
#include
#include
@@ -964,8 +966,8 @@ void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
}
aclTensor* acl_gamma = get_f32_cache_acl_tensor(
ctx,
- &ctx.f32_one_cache,
- ctx.f32_one_cache_element,
+ &ctx.rms_norm_one_tensor_cache.cache,
+ ctx.rms_norm_one_tensor_cache.size,
src->ne,
acl_gamma_nb,
1, // dims
@@ -980,8 +982,8 @@ void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
}
aclTensor* acl_rstd = get_f32_cache_acl_tensor(
ctx,
- &ctx.f32_zero_cache,
- ctx.f32_zero_cache_element,
+ &ctx.rms_norm_zero_tensor_cache.cache,
+ ctx.rms_norm_zero_tensor_cache.size,
src->ne,
acl_rstd_nb,
GGML_MAX_DIMS,
@@ -1257,12 +1259,20 @@ static void aclnn_exp(ggml_backend_cann_context& ctx, aclTensor* acl_src) {
void aclnn_cos(ggml_backend_cann_context& ctx, aclTensor* acl_src,
aclTensor* acl_dst) {
- GGML_CANN_CALL_ACLNN_OP(ctx, Cos, acl_src, acl_dst);
+    if (acl_dst == nullptr) {
+ GGML_CANN_CALL_ACLNN_OP(ctx, InplaceCos, acl_src);
+ } else {
+ GGML_CANN_CALL_ACLNN_OP(ctx, Cos, acl_src, acl_dst);
+ }
}
void aclnn_sin(ggml_backend_cann_context& ctx, aclTensor* acl_src,
aclTensor* acl_dst) {
- GGML_CANN_CALL_ACLNN_OP(ctx, Sin, acl_src, acl_dst);
+    if (acl_dst == nullptr) {
+ GGML_CANN_CALL_ACLNN_OP(ctx, InplaceSin, acl_src);
+ } else {
+ GGML_CANN_CALL_ACLNN_OP(ctx, Sin, acl_src, acl_dst);
+ }
}
void ggml_cann_timestep_embedding(ggml_backend_cann_context& ctx,
@@ -1415,21 +1425,25 @@ static void aclnn_pow_tensor_tensor(ggml_backend_cann_context& ctx,
* @param start Starting exponent offset.
* @param stop Stopping exponent offset (exclusive).
* @param step Step size for the exponent increment.
+ * @param dtype Data type for slope tensor.
*/
static void aclnn_get_slope_inner(ggml_backend_cann_context& ctx, void* slope_buffer,
- float m, int64_t size, float start, float stop, float step){
+ float m, int64_t size, float start, float stop, float step, ggml_type dtype){
+ aclDataType acl_type = ggml_cann_type_mapping(dtype);
+ size_t type_size = ggml_type_size(dtype);
+
int64_t ne[] = {size};
- size_t nb[] = {sizeof(float)};
+ size_t nb[] = {type_size};
- ggml_cann_pool_alloc arange_allocator(ctx.pool(), size * sizeof(float));
+ ggml_cann_pool_alloc arange_allocator(ctx.pool(), size * type_size);
void* arange_buffer = arange_allocator.get();
aclTensor* arange_tensor = ggml_cann_create_tensor(
- arange_buffer, ACL_FLOAT, sizeof(float), ne, nb, 1);
+ arange_buffer, acl_type, type_size, ne, nb, 1);
aclnn_arange(ctx, arange_tensor, start, stop, step, size);
aclTensor* slope_tensor = ggml_cann_create_tensor(
- slope_buffer, ACL_FLOAT, sizeof(float), ne, nb, 1);
+ slope_buffer, acl_type, type_size, ne, nb, 1);
aclScalar* sc = aclCreateScalar(&m, aclDataType::ACL_FLOAT);
@@ -1460,10 +1474,11 @@ static void aclnn_get_slope_inner(ggml_backend_cann_context& ctx, void* slope_bu
* @param n_head Total number of attention heads.
* @param slope_buffer Pointer to the output buffer (float array) for storing slopes.
* @param max_bias Maximum bias value for slope computation.
+ * @param dtype Data type for slope tensor.
*
*/
static void aclnn_get_slope(ggml_backend_cann_context & ctx, int64_t n_head,
- void* slope_buffer, float max_bias) {
+ void* slope_buffer, float max_bias, ggml_type dtype) {
const int n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
float m0 = powf(2.0f, -(max_bias) / n_head_log2);
@@ -1480,7 +1495,7 @@ static void aclnn_get_slope(ggml_backend_cann_context & ctx, int64_t n_head,
float step = 1;
float count = n_head_log2;
// end needs to be +1 because aclnn uses a left-closed, right-open interval.
- aclnn_get_slope_inner(ctx, slope_buffer, m0, count, start, end + 1, step);
+ aclnn_get_slope_inner(ctx, slope_buffer, m0, count, start, end + 1, step, dtype);
if (n_head_log2 < n_head) {
// arange2
start = 2 * (n_head_log2 - n_head_log2) + 1;
@@ -1489,7 +1504,7 @@ static void aclnn_get_slope(ggml_backend_cann_context & ctx, int64_t n_head,
count = n_head - n_head_log2;
aclnn_get_slope_inner(
ctx, (char *) slope_buffer + n_head_log2 * sizeof(float),
- m1, count, start, end + 1, step);
+ m1, count, start, end + 1, step, dtype);
}
}
@@ -1526,7 +1541,7 @@ static void aclnn_add_alibi(ggml_backend_cann_context& ctx, ggml_tensor* mask,
ggml_cann_pool_alloc bias_allocator(
ctx.pool(), ggml_nelements(dst) * ggml_element_size(dst));
bias_buffer = bias_allocator.get();
- aclnn_get_slope(ctx, n_heads, slope_buffer, max_bias);
+ aclnn_get_slope(ctx, n_heads, slope_buffer, max_bias, GGML_TYPE_F32);
}
    // broadcast for mask, slope and dst;
@@ -1752,10 +1767,10 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
case GGML_TYPE_F16: {
aclTensor* acl_src0 = ggml_cann_create_tensor(src0);
ggml_cann_pool_alloc src_buffer_allocator(
- ctx.pool(), ggml_nelements(src0) * sizeof(float_t));
+ ctx.pool(), ggml_nelements(src0) * sizeof(float));
void* src_trans_buffer = src_buffer_allocator.get();
size_t src_trans_nb[GGML_MAX_DIMS];
- src_trans_nb[0] = sizeof(float_t);
+ src_trans_nb[0] = sizeof(float);
for (int i = 1; i < GGML_MAX_DIMS; i++) {
src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1];
}
@@ -1799,14 +1814,14 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
// [3,4,5,64] -> [3,4,5,2,32]
dequant_ne = weight_ne;
- dequant_nb[0] = sizeof(float_t);
+ dequant_nb[0] = sizeof(float);
for (int i = 1; i < GGML_MAX_DIMS + 1; i++) {
dequant_nb[i] = dequant_nb[i - 1] * dequant_ne[i - 1];
}
scale_offset = ggml_nelements(src0) * sizeof(int8_t);
ggml_cann_pool_alloc dequant_buffer_allocator(
- ctx.pool(), ggml_nelements(src0) * sizeof(float_t));
+ ctx.pool(), ggml_nelements(src0) * sizeof(float));
aclTensor* acl_weight_tensor = ggml_cann_create_tensor(
src0->data, ACL_INT8, sizeof(int8_t), weight_ne, weight_nb,
@@ -1815,11 +1830,11 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
src0->data, ACL_FLOAT16, sizeof(uint16_t), scale_ne, scale_nb,
GGML_MAX_DIMS + 1, ACL_FORMAT_ND, scale_offset);
aclTensor* dequant_tensor = ggml_cann_create_tensor(
- dequant_buffer_allocator.get(), ACL_FLOAT, sizeof(float_t),
+ dequant_buffer_allocator.get(), ACL_FLOAT, sizeof(float),
dequant_ne, dequant_nb, GGML_MAX_DIMS + 1);
aclnn_mul(ctx, acl_weight_tensor, acl_scale_tensor, dequant_tensor);
- dequant_nb[0] = sizeof(float_t);
+ dequant_nb[0] = sizeof(float);
dequant_ne = src0->ne;
for (int i = 1; i < GGML_MAX_DIMS; i++) {
dequant_nb[i] = dequant_nb[i - 1] * src0->ne[i - 1];
@@ -2221,9 +2236,41 @@ static void aclnn_index_fill_tensor(ggml_backend_cann_context& ctx,
ggml_cann_release_resources(ctx, acl_index, acl_value);
}
+/**
+ * @brief Initializes and caches sine/cosine positional encoding values
+ * (used in RoPE, Rotary Position Embedding) for attention layers.
+ *
+ * This function computes and caches the sin/cos values of
+ * θ = position * theta_scale for RoPE encoding. The cache is shared
+ * across attention layers, and only the first attention layer will
+ * trigger initialization. The cache includes repeated sin/cos values
+ * with different repeat methods depending on the @param is_neox flag.
+ *
+ * Steps performed by this function:
+ * 1. Identify whether the target tensor belongs to Q/K in attention
+ * and restrict computation to the first layer only.
+ * 2. Initialize the theta scale array (arange → power → freq scaling).
+ * 3. Allocate sin/cos caches if the max prompt length increases.
+ * 4. Compute θ = position * theta_scale.
+ * 5. Compute sin(θ), cos(θ) and optionally scale by attn_factor.
+ * 6. Expand sin/cos values by repeat or repeat_interleave depending
+ * on whether @param is_neox is enabled.
+ *
+ * @param ctx The CANN backend context, holding memory pool,
+ * stream, and persistent buffers for rope init/cache.
+ * @param dst The destination ggml_tensor whose computation
+ * depends on the RoPE values (usually Qcur/Kcur).
+ * @param sin_tensor_buffer Pre-allocated buffer for storing repeated sin values.
+ * @param cos_tensor_buffer Pre-allocated buffer for storing repeated cos values.
+ * @param theta_scale Scalar exponent base for computing theta scale values.
+ * @param freq_scale Frequency scaling factor, applied to theta scale.
+ * @param attn_factor Attention scaling factor, applied to sin/cos.
+ * @param is_neox Whether to use Neox-style repeat strategy
+ * (dim expansion vs repeat_interleave).
+ */
static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
- aclTensor* acl_cos_repeat_tensor,
- aclTensor* acl_sin_repeat_tensor,
+ void* sin_tensor_buffer, void* cos_tensor_buffer,
+ float* corr_dims, float ext_factor,
float theta_scale, float freq_scale,
float attn_factor, bool is_neox) {
    // init sin/cos cache; the cache uses a different repeat method depending on is_neox.
@@ -2233,12 +2280,10 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
ggml_tensor* src1 = dst->src[1]; // position
ggml_tensor* src2 = dst->src[2]; // freq_factors
- GGML_TENSOR_BINARY_OP_LOCALS
-
- int64_t theta_scale_length = ne00 / 2;
+ int64_t theta_scale_length = src0->ne[0] / 2;
int64_t theta_scale_ne[] = {theta_scale_length, 1, 1, 1};
- size_t theta_scale_nb[] = {sizeof(float_t), sizeof(float_t), sizeof(float_t),
- theta_scale_length * sizeof(float_t)};
+ size_t theta_scale_nb[] = {sizeof(float), sizeof(float), sizeof(float),
+ theta_scale_length * sizeof(float)};
GGML_ASSERT(src1->type == GGML_TYPE_I32);
int64_t position_length = src1->ne[0];
@@ -2248,123 +2293,173 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
int64_t theta_ne[] = {theta_scale_length, 1, position_length, 1};
size_t theta_nb[GGML_MAX_DIMS];
- theta_nb[0] = sizeof(float_t);
+ theta_nb[0] = sizeof(float);
for (int i = 1; i < GGML_MAX_DIMS; i++) {
theta_nb[i] = theta_nb[i - 1] * theta_ne[i - 1];
}
- bool is_q = (std::strncmp(dst->name, "Qcur-", 5) == 0);
- bool is_k = (std::strncmp(dst->name, "Kcur-", 5) == 0);
-
- // used for accuracy testing
- bool is_attention = is_q || is_k;
-
- if(ctx.init_ptr == nullptr || !is_attention) {
- // theta_scale arange, [0,1,...,ne00/2 - 1]
- if(ctx.init_ptr != nullptr){
- ACL_CHECK(aclrtFree(ctx.init_ptr));
+ // theta_scale arange, [0,1,...,ne00/2 - 1]
+ aclTensor* acl_theta_scale_tensor = nullptr;
+ // cache theta scale
+ if (ctx.rope_cache.theta_scale_length != theta_scale_length ||
+ // theta_scale and freq_scale should not change during the current token inference process,
+ // so we can directly use == here instead of comparing the absolute difference.
+ ctx.rope_cache.theta_scale != theta_scale ||
+ ctx.rope_cache.freq_scale != freq_scale) {
+
+ ctx.rope_cache.theta_scale_length = theta_scale_length;
+ ctx.rope_cache.theta_scale = theta_scale;
+ ctx.rope_cache.freq_scale = freq_scale;
+
+ if (ctx.rope_cache.theta_scale_cache != nullptr) {
+ ACL_CHECK(aclrtFree(ctx.rope_cache.theta_scale_cache));
}
- ACL_CHECK(aclrtMalloc(&ctx.init_ptr, theta_scale_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
+ ACL_CHECK(aclrtMalloc(&ctx.rope_cache.theta_scale_cache, theta_scale_length * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST));
- aclTensor* acl_theta_scale_tensor =
- ggml_cann_create_tensor(ctx.init_ptr, ACL_FLOAT, sizeof(float_t),
+ acl_theta_scale_tensor =
+ ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float),
theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
+
float start = 0;
float step = 1;
- float stop = ne00 / 2;
- float n_elements = ne00 / 2;
+ float stop = theta_scale_length;
+ float n_elements = theta_scale_length;
aclnn_arange(ctx, acl_theta_scale_tensor, start, stop, step, n_elements);
+ ggml_cann_pool_alloc yarn_ramp_allocator(ctx.pool());
+ aclTensor* acl_yarn_ramp_tensor = nullptr;
+ if (ext_factor != 0) {
+ // -rope_yarn_ramp
+ // const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
+ // return MIN(1, MAX(0, y)) - 1;
+ yarn_ramp_allocator.alloc(theta_scale_length * sizeof(float));
+ void* yarn_ramp_buffer = yarn_ramp_allocator.get();
+        acl_yarn_ramp_tensor = ggml_cann_create_tensor(yarn_ramp_buffer, ACL_FLOAT, sizeof(float),
+ theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
+ float zero_value = 0, one_value = 1;
+ float denom_safe_value = MAX(0.001f, corr_dims[1] - corr_dims[0]);
+ aclScalar* low = aclCreateScalar(&corr_dims[0], aclDataType::ACL_FLOAT);
+ aclScalar* zero = aclCreateScalar(&zero_value, aclDataType::ACL_FLOAT);
+ aclScalar* one = aclCreateScalar(&one_value, aclDataType::ACL_FLOAT);
+ aclScalar* denom_safe = aclCreateScalar(&denom_safe_value, aclDataType::ACL_FLOAT);
+ aclScalar* ext_factor_sc = aclCreateScalar(&ext_factor, aclDataType::ACL_FLOAT);
+
+ GGML_CANN_CALL_ACLNN_OP(ctx, Subs, acl_theta_scale_tensor, low, one, acl_yarn_ramp_tensor);
+ GGML_CANN_CALL_ACLNN_OP(ctx, InplaceDivs, acl_yarn_ramp_tensor, denom_safe);
+ GGML_CANN_CALL_ACLNN_OP(ctx, InplaceThreshold, acl_yarn_ramp_tensor, zero, zero);
+ GGML_CANN_CALL_ACLNN_OP(ctx, InplaceClampMax, acl_yarn_ramp_tensor, one);
+ GGML_CANN_CALL_ACLNN_OP(ctx, InplaceSubs, acl_yarn_ramp_tensor, one, one);
+ GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_yarn_ramp_tensor, ext_factor_sc);
+
+ // theta_interp = freq_scale * theta_extrap;
+ // theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
+ // theta = freq_scale * theta_extrap * (1 - ramp_mix) + theta_extrap * ramp_mix;
+ // theta = freq_scale * theta_extrap - freq_scale * theta_extrap * ramp_mix + theta_extrap * ramp_mix;
+ // theta = theta_extrap * (freq_scale - freq_scale * ramp_mix + ramp_mix);
+ //
+    // we cache (freq_scale - freq_scale * ramp_mix + ramp_mix); since the ramp computed above is the
+    // negative of rope_yarn_ramp, this amounts to caching freq_scale + (freq_scale - 1) * ramp_mix
+ float freq_scale_1 = freq_scale - 1;
+ aclScalar* freq_scale_sc = aclCreateScalar(&freq_scale, aclDataType::ACL_FLOAT);
+ aclScalar* freq_scale_1_sc = aclCreateScalar(&freq_scale_1, aclDataType::ACL_FLOAT);
+ GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_yarn_ramp_tensor, freq_scale_1_sc);
+ GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdds, acl_yarn_ramp_tensor, freq_scale_sc, one);
+
+ ggml_cann_release_resources(ctx, low, zero, one, denom_safe, ext_factor_sc, freq_scale_sc, freq_scale_1_sc);
+ }
+
// power
aclScalar* acl_theta_scale = aclCreateScalar(&theta_scale, aclDataType::ACL_FLOAT);
GGML_CANN_CALL_ACLNN_OP(ctx, PowScalarTensor, acl_theta_scale, acl_theta_scale_tensor,
acl_theta_scale_tensor);
- // freq_scale
- if (freq_scale != 1) {
+ if (ext_factor != 0) {
+ aclnn_mul(ctx, acl_theta_scale_tensor, acl_yarn_ramp_tensor);
+ } else if (freq_scale != 1) {
aclnn_muls(ctx, acl_theta_scale_tensor, freq_scale, nullptr, true);
}
- // freq_factors
- if (src2) {
- aclTensor* acl_freq_factors_tensor = ggml_cann_create_tensor(
- src2->data, ggml_cann_type_mapping(src2->type),
- ggml_type_size(src2->type), theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
- aclnn_div(ctx, acl_theta_scale_tensor, acl_freq_factors_tensor);
- ggml_cann_release_resources(ctx, acl_freq_factors_tensor);
- }
- // release
- ggml_cann_release_resources(ctx, acl_theta_scale_tensor,acl_theta_scale);
+ ggml_cann_release_resources(ctx, acl_yarn_ramp_tensor, acl_theta_scale);
+ } else {
+ // use cache
+ acl_theta_scale_tensor =
+ ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float),
+ theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
}
- if(ctx.sin_ptr == nullptr) {
- int64_t theta_length = theta_scale_length * ctx.max_prompt_length;
- ACL_CHECK(aclrtMalloc(&ctx.sin_ptr, theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
- ACL_CHECK(aclrtMalloc(&ctx.cos_ptr, theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
- }
- if(position_length > ctx.max_prompt_length) {
- ctx.max_prompt_length = position_length;
- int64_t theta_length = theta_scale_length * ctx.max_prompt_length;
- ACL_CHECK(aclrtFree(ctx.sin_ptr));
- ACL_CHECK(aclrtFree(ctx.cos_ptr));
- ACL_CHECK(aclrtMalloc(&ctx.sin_ptr, theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
- ACL_CHECK(aclrtMalloc(&ctx.cos_ptr, theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
+ ggml_cann_pool_alloc freq_fac_res_allocator(ctx.pool());
+ // freq_factors
+ if (src2) {
+ freq_fac_res_allocator.alloc(theta_scale_length * sizeof(float));
+ void* freq_fac_res_ptr = freq_fac_res_allocator.get();
+ aclTensor* acl_freq_factors_tensor = ggml_cann_create_tensor(
+ src2->data, ggml_cann_type_mapping(src2->type),
+ ggml_type_size(src2->type), theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
+ aclTensor* acl_freq_fac_res_tensor = ggml_cann_create_tensor(
+ freq_fac_res_ptr, ACL_FLOAT, sizeof(float),
+ theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
+ aclnn_div(ctx, acl_theta_scale_tensor, acl_freq_factors_tensor, acl_freq_fac_res_tensor);
+ std::swap(acl_theta_scale_tensor, acl_freq_fac_res_tensor);
+ ggml_cann_release_resources(ctx, acl_freq_factors_tensor, acl_freq_fac_res_tensor);
}
- bool is_fisrt_layer = (std::strncmp(dst->name, "Qcur-0", GGML_MAX_NAME) == 0);
-
- if(is_fisrt_layer || !is_attention) {
+ // position
+ aclTensor* acl_position_tensor = ggml_cann_create_tensor(
+ src1->data, ggml_cann_type_mapping(src1->type),
+ ggml_type_size(src1->type), position_ne, position_nb, GGML_MAX_DIMS);
+
+ // power * position
+ int64_t theta_length = theta_scale_length * position_length;
+ ggml_cann_pool_alloc theta_allocator(ctx.pool(),
+ theta_length * sizeof(float));
+ void* theta_buffer = theta_allocator.get();
+
+ aclTensor* acl_theta_tensor =
+ ggml_cann_create_tensor(theta_buffer, ACL_FLOAT, sizeof(float),
+ theta_ne, theta_nb, GGML_MAX_DIMS);
+ aclnn_mul(ctx, acl_position_tensor, acl_theta_scale_tensor,
+ acl_theta_tensor);
+
+ // sin/cos
+ ggml_cann_pool_alloc sin_allocator(ctx.pool(),
+ theta_length * sizeof(float));
+ void* sin_buffer = sin_allocator.get();
+ aclTensor* acl_sin_tensor = ggml_cann_create_tensor(
+ sin_buffer, ACL_FLOAT, sizeof(float), theta_ne, theta_nb,
+ GGML_MAX_DIMS, ACL_FORMAT_ND);
+ aclnn_sin(ctx, acl_theta_tensor, acl_sin_tensor);
- aclTensor* acl_theta_scale_tensor =
- ggml_cann_create_tensor(ctx.init_ptr, ACL_FLOAT, sizeof(float_t),
- theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
+ ggml_cann_pool_alloc cos_allocator(ctx.pool(),
+ theta_length * sizeof(float));
+ void* cos_buffer = cos_allocator.get();
+ aclTensor* acl_cos_tensor = ggml_cann_create_tensor(
+ cos_buffer, ACL_FLOAT, sizeof(float), theta_ne, theta_nb,
+ GGML_MAX_DIMS, ACL_FORMAT_ND);
+ aclnn_cos(ctx, acl_theta_tensor, acl_cos_tensor);
- // position
- aclTensor* acl_position_tensor = ggml_cann_create_tensor(
- src1->data, ggml_cann_type_mapping(src1->type),
- ggml_type_size(src1->type), position_ne, position_nb, GGML_MAX_DIMS);
-
- // power * position
- int64_t theta_length = theta_scale_length * position_length;
- ggml_cann_pool_alloc theta_allocator(ctx.pool(),
- theta_length * sizeof(float_t));
- void* theta_buffer = theta_allocator.get();
-
- aclTensor* acl_theta_tensor =
- ggml_cann_create_tensor(theta_buffer, ACL_FLOAT, sizeof(float_t),
- theta_ne, theta_nb, GGML_MAX_DIMS);
- aclnn_mul(ctx, acl_position_tensor, acl_theta_scale_tensor,
- acl_theta_tensor);
-
- // sin/cos
- aclTensor* acl_sin_tensor = ggml_cann_create_tensor(
- ctx.sin_ptr, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb,
- GGML_MAX_DIMS, ACL_FORMAT_ND);
- aclnn_sin(ctx, acl_theta_tensor, acl_sin_tensor);
-
- aclTensor* acl_cos_tensor = ggml_cann_create_tensor(
- ctx.cos_ptr, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb,
- GGML_MAX_DIMS, ACL_FORMAT_ND);
- aclnn_cos(ctx, acl_theta_tensor, acl_cos_tensor);
-
- // release
- ggml_cann_release_resources(ctx, acl_theta_scale_tensor, acl_position_tensor,
- acl_theta_tensor, acl_sin_tensor, acl_cos_tensor);
+ if (ext_factor != 0) {
+ attn_factor *= 1.0f + 0.1f * logf(1.0f / freq_scale);
}
- aclTensor* acl_sin_tensor = ggml_cann_create_tensor(
- ctx.sin_ptr, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb,
- GGML_MAX_DIMS, ACL_FORMAT_ND);
- aclTensor* acl_cos_tensor = ggml_cann_create_tensor(
- ctx.cos_ptr, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb,
- GGML_MAX_DIMS, ACL_FORMAT_ND);
-
// attn_factor
if (attn_factor != 1) {
aclnn_muls(ctx, acl_sin_tensor, attn_factor, nullptr, true);
aclnn_muls(ctx, acl_cos_tensor, attn_factor, nullptr, true);
}
+ int64_t sin_reshape_ne[4] = {src0->ne[0], 1, src0->ne[2], 1};
+ size_t sin_reshape_nb[GGML_MAX_DIMS];
+ sin_reshape_nb[0] = sizeof(float);
+ for (int i = 1; i < GGML_MAX_DIMS; i++) {
+ sin_reshape_nb[i] = sin_reshape_nb[i - 1] * sin_reshape_ne[i - 1];
+ }
+ aclTensor* acl_sin_repeat_tensor =
+ ggml_cann_create_tensor(sin_tensor_buffer, ACL_FLOAT, sizeof(float),
+ sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
+ aclTensor* acl_cos_repeat_tensor =
+ ggml_cann_create_tensor(cos_tensor_buffer, ACL_FLOAT, sizeof(float),
+ sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
+
// repeat
if (is_neox) {
int64_t repeatsArray[] = {1, 1, 1, 2};
@@ -2380,8 +2475,9 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
num_repeats, output_size);
}
- // release
- ggml_cann_release_resources(ctx, acl_sin_tensor, acl_cos_tensor);
+ ggml_cann_release_resources(ctx, acl_theta_scale_tensor, acl_position_tensor,
+ acl_theta_tensor, acl_sin_tensor, acl_sin_repeat_tensor, acl_cos_tensor,
+ acl_cos_repeat_tensor);
}
#ifdef __cplusplus
@@ -2403,6 +2499,7 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
// TODO: use ascendc
// Only test with LLAMA model.
ggml_tensor* src0 = dst->src[0]; // input
+ ggml_tensor* src1 = dst->src[1];
// param
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
@@ -2424,8 +2521,6 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
// TODO: n_dims <= ne0
GGML_ASSERT(n_dims == ne0);
GGML_ASSERT(n_dims % 2 == 0);
- // TODO: ext_factor != 0
- GGML_ASSERT(ext_factor == 0);
const float theta_scale = powf(freq_base, -2.0f / n_dims);
@@ -2435,28 +2530,29 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
- // init cos/sin cache
- ggml_cann_pool_alloc sin_allocator(
- ctx.pool(), ne00 * ne02 * sizeof(float_t));
- ggml_cann_pool_alloc cos_allocator(
- ctx.pool(), ne00 * ne02 * sizeof(float_t));
- void* sin_buffer = sin_allocator.get();
- void* cos_buffer = cos_allocator.get();
+ // sin/cos tensor length.
+ int64_t repeat_theta_length = src0->ne[0] * src1->ne[0];
+ ggml_cann_pool_alloc sin_tensor_allocator(ctx.pool(), repeat_theta_length * sizeof(float));
+ ggml_cann_pool_alloc cos_tensor_allocator(ctx.pool(), repeat_theta_length * sizeof(float));
+ void *sin_tensor_buffer = sin_tensor_allocator.get();
+ void *cos_tensor_buffer = cos_tensor_allocator.get();
+
+    // fill the sin/cos buffers for this call (the theta-scale table is cached in ctx.rope_cache)
+ aclnn_cache_init(ctx, dst, sin_tensor_buffer, cos_tensor_buffer, corr_dims, ext_factor,
+ theta_scale, freq_scale, attn_factor, is_neox);
int64_t sin_reshape_ne[4] = {ne00, 1, ne02, 1};
size_t sin_reshape_nb[GGML_MAX_DIMS];
- sin_reshape_nb[0] = sizeof(float_t);
+ sin_reshape_nb[0] = sizeof(float);
for (int i = 1; i < GGML_MAX_DIMS; i++) {
sin_reshape_nb[i] = sin_reshape_nb[i - 1] * sin_reshape_ne[i - 1];
}
aclTensor* acl_sin_reshape_tensor =
- ggml_cann_create_tensor(sin_buffer, ACL_FLOAT, sizeof(float_t),
+ ggml_cann_create_tensor(sin_tensor_buffer, ACL_FLOAT, sizeof(float),
sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
aclTensor* acl_cos_reshape_tensor =
- ggml_cann_create_tensor(cos_buffer, ACL_FLOAT, sizeof(float_t),
+ ggml_cann_create_tensor(cos_tensor_buffer, ACL_FLOAT, sizeof(float),
sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
- aclnn_cache_init(ctx, dst, acl_cos_reshape_tensor, acl_sin_reshape_tensor,
- theta_scale, freq_scale, attn_factor, is_neox);
aclTensor* acl_src = ggml_cann_create_tensor(src0);
aclTensor* acl_dst = ggml_cann_create_tensor(dst);
@@ -2470,7 +2566,7 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
void* minus_one_scale_buffer = nullptr;
ggml_cann_pool_alloc roll_allocator(ctx.pool(), ggml_nbytes(src0));
ggml_cann_pool_alloc minus_one_scale_allocator(
- ctx.pool(), sizeof(float_t) * src0->ne[0]);
+ ctx.pool(), sizeof(float) * src0->ne[0]);
if (!is_neox) {
// roll input: [q0,q1,q2,q3,...] -> [q1,q0,q3,q2,...]
input_roll_buffer = roll_allocator.get();
@@ -2500,13 +2596,13 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
int64_t minus_one_ne[4] = {src0->ne[0], 1, 1, 1};
size_t minus_one_nb[GGML_MAX_DIMS];
- minus_one_nb[0] = sizeof(float_t);
+ minus_one_nb[0] = sizeof(float);
for (int i = 1; i < GGML_MAX_DIMS; i++) {
minus_one_nb[i] = minus_one_nb[i - 1] * minus_one_ne[i - 1];
}
acl_minus_one_tensor = aclnn_values(
- ctx, minus_one_scale_buffer, sizeof(float_t) * src0->ne[0],
- minus_one_ne, GGML_MAX_DIMS, ACL_FLOAT, sizeof(float_t), 1);
+ ctx, minus_one_scale_buffer, sizeof(float) * src0->ne[0],
+ minus_one_ne, GGML_MAX_DIMS, ACL_FLOAT, sizeof(float), 1);
int64_t dim = 3;
int64_t* index = new int64_t[src0->ne[0]];
for (int i = 0; i < src0->ne[0]; i++) {
@@ -2534,22 +2630,22 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
minus_one_scale_buffer = minus_one_scale_allocator.get();
int64_t minus_one_ne[4] = {src0->ne[0], 1, 1, 1};
size_t minus_one_nb[GGML_MAX_DIMS];
- minus_one_nb[0] = sizeof(float_t);
+ minus_one_nb[0] = sizeof(float);
for (int i = 1; i < GGML_MAX_DIMS; i++) {
minus_one_nb[i] = minus_one_nb[i - 1] * minus_one_ne[i - 1];
}
acl_minus_one_tensor = aclnn_values(
- ctx, minus_one_scale_buffer, sizeof(float_t) * src0->ne[0],
- minus_one_ne, GGML_MAX_DIMS, ACL_FLOAT, sizeof(float_t), 1);
+ ctx, minus_one_scale_buffer, sizeof(float) * src0->ne[0],
+ minus_one_ne, GGML_MAX_DIMS, ACL_FLOAT, sizeof(float), 1);
// -1 * first half
int64_t first_half_ne[4] = {src0->ne[0] / 2, 1, 1, 1};
size_t first_half_nb[GGML_MAX_DIMS];
- first_half_nb[0] = sizeof(float_t);
+ first_half_nb[0] = sizeof(float);
for (int i = 1; i < GGML_MAX_DIMS; i++) {
first_half_nb[i] = first_half_nb[i - 1] * first_half_ne[i - 1];
}
aclTensor* acl_first_half_tensor = ggml_cann_create_tensor(
- minus_one_scale_buffer, ACL_FLOAT, sizeof(float_t), first_half_ne,
+ minus_one_scale_buffer, ACL_FLOAT, sizeof(float), first_half_ne,
first_half_nb, GGML_MAX_DIMS);
bool inplace = true;
float scale = -1;
@@ -2589,28 +2685,28 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
// TODO: ne0 != n_dims in mode2
} else if (src0->type == GGML_TYPE_F16) {
size_t input_fp32_nb[GGML_MAX_DIMS];
- input_fp32_nb[0] = sizeof(float_t);
+ input_fp32_nb[0] = sizeof(float);
for (int i = 1; i < GGML_MAX_DIMS; i++) {
input_fp32_nb[i] = input_fp32_nb[i - 1] * dst->ne[i - 1];
}
ggml_cann_pool_alloc fp32_allocator1(
- ctx.pool(), ggml_nelements(dst) * sizeof(float_t));
+ ctx.pool(), ggml_nelements(dst) * sizeof(float));
void* input_fp32_buffer1 = fp32_allocator1.get();
aclTensor* input_fp32_tensor1 = ggml_cann_create_tensor(
- input_fp32_buffer1, ACL_FLOAT, sizeof(float_t), dst->ne,
+ input_fp32_buffer1, ACL_FLOAT, sizeof(float), dst->ne,
input_fp32_nb, GGML_MAX_DIMS);
ggml_cann_pool_alloc fp32_allocator2(
- ctx.pool(), ggml_nelements(dst) * sizeof(float_t));
+ ctx.pool(), ggml_nelements(dst) * sizeof(float));
void* input_fp32_buffer2 = fp32_allocator2.get();
aclTensor* input_fp32_tensor2 = ggml_cann_create_tensor(
- input_fp32_buffer2, ACL_FLOAT, sizeof(float_t), dst->ne,
+ input_fp32_buffer2, ACL_FLOAT, sizeof(float), dst->ne,
input_fp32_nb, GGML_MAX_DIMS);
ggml_cann_pool_alloc fp32_allocator(
- ctx.pool(), ggml_nelements(dst) * sizeof(float_t));
+ ctx.pool(), ggml_nelements(dst) * sizeof(float));
output_fp32_buffer = fp32_allocator.get();
aclTensor* output_fp32_tensor = ggml_cann_create_tensor(
- output_fp32_buffer, ACL_FLOAT, sizeof(float_t), dst->ne,
+ output_fp32_buffer, ACL_FLOAT, sizeof(float), dst->ne,
input_fp32_nb, GGML_MAX_DIMS);
aclnn_mul(ctx, acl_src, acl_cos_reshape_tensor, input_fp32_tensor1);
aclnn_mul(ctx, acl_input_roll_mul_scale_tensor, acl_sin_reshape_tensor,
@@ -2707,8 +2803,6 @@ void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* ds
aclIntArray *padding = aclCreateIntArray(paddingVal, 1);
int64_t dilationVal[] = {1};
aclIntArray *dilation = aclCreateIntArray(dilationVal, 1);
- bool transposed = true;
- int64_t groups = 1;
int8_t cubeMathType = 0;
#ifdef ASCEND_310P
@@ -2716,7 +2810,7 @@ void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* ds
#endif
GGML_CANN_CALL_ACLNN_OP(ctx, Convolution, acl_input, acl_weight, nullptr, stride,
- padding, dilation, transposed, padding, groups, acl_dst, cubeMathType);
+ padding, dilation, true, padding, 1, acl_dst, cubeMathType);
ggml_cann_release_resources(ctx, acl_weight, acl_dst, stride, padding, dilation);
}
@@ -2825,174 +2919,49 @@ void ggml_cann_step(ggml_backend_cann_context& ctx, ggml_tensor* dst){
*/
static void ggml_cann_mul_mat_id_fp(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
//dst [M, K, N, 1]
- ggml_tensor * src0 = dst->src[0]; //src0 [D, M, A, 1]
- ggml_tensor * src1 = dst->src[1]; //src1 [D, B, N, 1], B = K or B = 1
+ ggml_tensor * src0 = dst->src[0]; //src0 [D, M, A, 1] -> [D, M, K, 1]
+ ggml_tensor * src1 = dst->src[1]; //src1 [D, B, N, 1], B = K or B = 1 -> [D, 1, K, 1]
ggml_tensor * ids = dst->src[2]; //ids [K, N]
- GGML_TENSOR_BINARY_OP_LOCALS
-
- // copy index from npu to cpu
- int64_t n_as = ne02; // A
- int64_t n_ids = ids->ne[0]; // K
-
-    std::vector<char> ids_host(ggml_nbytes(ids));
- ggml_cann_async_memcpy(ctx, ids_host.data(), ids->data, ggml_nbytes(ids),
- ACL_MEMCPY_DEVICE_TO_HOST);
- ACL_CHECK(aclrtSynchronizeStream(ctx.stream()));
-
- char * src0_original = (char *) src0->data;
- char * src1_original = (char *) src1->data;
- char * dst_original = (char *) dst->data;
- size_t ori_src0_nb[4] = {nb00, nb01, nb02, nb03};
-
- // src0 is F16, src1 is F32, dst is F32
- ggml_cann_pool_alloc src0_cast_allocator;
- if (src0->type == GGML_TYPE_F16) {
- src0_cast_allocator.alloc(ctx.pool(), sizeof(float) * ggml_nelements(src0));
- void* src0_cast_buf = src0_cast_allocator.get();
-
- size_t cast_nb[GGML_MAX_DIMS];
- cast_nb[0] = sizeof(float_t);
- for (int i = 1; i < GGML_MAX_DIMS; i++) {
- cast_nb[i] = cast_nb[i - 1] * src0->ne[i - 1];
- }
-
- aclTensor* acl_src0_f16 = ggml_cann_create_tensor(src0);
- aclTensor* acl_cast = ggml_cann_create_tensor(src0_cast_buf,
- ACL_FLOAT, sizeof(float), src0->ne, cast_nb, 4);
- GGML_CANN_CALL_ACLNN_OP(ctx, Cast, acl_src0_f16, ACL_FLOAT, acl_cast);
- ggml_cann_release_resources(ctx, acl_cast, acl_src0_f16);
-
- src0_original = (char *) src0_cast_buf;
- memcpy(ori_src0_nb, cast_nb, sizeof(ori_src0_nb));
- }
-
-#ifdef ASCEND_310P
- ggml_tensor src0_row = *src0;
- ggml_tensor src1_row = *src1;
- ggml_tensor dst_row = *dst;
-
- if (src0->type == GGML_TYPE_F16) {
- src0_row.type = GGML_TYPE_F32;
- }
-
- // src0_row [D, M, 1, 1] weight without permute
- src0_row.ne[2] = 1;
- src0_row.ne[3] = 1;
- src0_row.nb[0] = ori_src0_nb[0];
- src0_row.nb[1] = ori_src0_nb[1];
- src0_row.nb[2] = ori_src0_nb[1];
- src0_row.nb[3] = ori_src0_nb[1];
-
- // src1_row [D, 1, 1, 1] -> input
- src1_row.ne[1] = 1;
- src1_row.ne[2] = 1;
- src1_row.ne[3] = 1;
- src1_row.nb[2] = nb11;
- src1_row.nb[3] = nb11;
-
- // dst_row [M, 1, 1, 1] -> out
- dst_row.ne[1] = 1;
- dst_row.ne[2] = 1;
- dst_row.ne[3] = 1;
- dst_row.nb[2] = nb1;
- dst_row.nb[3] = nb1;
-
- //create weight for one row
- for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
- for (int64_t id = 0; id < n_ids; id++) {
- // expert index
- int32_t i02 = *(int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
- GGML_ASSERT(i02 >= 0 && i02 < n_as);
-
- // If B = 1 (broadcast), always use 0; otherwise, use id.
- int64_t i11 = (ne11 == 1 ? 0 : id);
- int64_t i12 = iid1;
-
- int64_t i1 = id;
- int64_t i2 = i12;
+ GGML_ASSERT(src0->ne[3] == 1);
+ GGML_ASSERT(src1->ne[3] == 1);
+ GGML_ASSERT(dst->ne[3] == 1);
- void* src0_tmp_ptr = src0_original + i02*ori_src0_nb[2];
- void* src1_tmp_ptr = src1_original + i11*nb11 + i12*nb12;
- void* dst_tmp_ptr = dst_original + i1*nb1 + i2*nb2;
+ int64_t batch = src1->ne[2];
+ GGML_ASSERT(batch == ids->ne[1]);
- src0_row.data = src0_tmp_ptr;
- src1_row.data = src1_tmp_ptr;
- dst_row.data = dst_tmp_ptr;
- dst_row.src[0] = &src0_row;
- dst_row.src[1] = &src1_row;
+ ggml_cann_pool_alloc export_allocator(ctx.pool(), src0->ne[0] * src0->ne[1] * ids->ne[0] * ggml_element_size(src0));
+ void* export_ptr = export_allocator.get();
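+    // For each batch row, gather the expert weights selected by ids with IndexSelect, then multiply the activations against the transposed selection in a single BatchMatMul.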
+ for (int64_t i = 0; i < batch; i++) {
+ aclTensor *select_index = ggml_cann_create_tensor(ids, ids->ne, ids->nb, 1, ACL_FORMAT_ND, i * ids->nb[1]);
+ aclTensor *export_weight = ggml_cann_create_tensor(src0, src0->ne, src0->nb, 3);
- ggml_cann_mul_mat(ctx, &dst_row);
+ int64_t select_export_ne[] = {src0->ne[0], src0->ne[1], ids->ne[0]};
+ size_t select_export_nb[3];
+ select_export_nb[0] = src0->nb[0];
+        for (int k = 1; k < 3; k++) {
+ select_export_nb[k] = select_export_nb[k-1] * select_export_ne[k-1];
}
- }
- return;
-#endif
-
-    std::vector<aclTensor*> src0_tensor_vec;
-    std::vector<aclTensor*> src1_tensor_vec;
-    std::vector<aclTensor*> dst_tensor_vec;
- for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
- for (int64_t id = 0; id < n_ids; id++) {
- // src0_row [M, D] -> weight && permute
- int64_t src0_ne[2] = {ne01, ne00};
- size_t src0_nb[2] = {ori_src0_nb[1], ori_src0_nb[0]};
- // src1_row [D, 1] -> input
- int64_t src1_ne[2] = {ne10, 1};
- size_t src1_nb[2] = {nb10, nb11};
- // dst_row [M, 1] -> out
- int64_t dst_ne[2] = {ne0, 1};
- size_t dst_nb[2] = {nb0, nb1};
-
- // expert index
- int32_t i02 = *(int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
- GGML_ASSERT(i02 >= 0 && i02 < n_as);
- // If B = 1 (broadcast), always use 0; otherwise, use id.
- int64_t i11 = (ne11 == 1 ? 0 : id);
- int64_t i12 = iid1;
+ aclTensor *select_export = ggml_cann_create_tensor(export_ptr, ggml_cann_type_mapping(src0->type), ggml_element_size(src0), select_export_ne, select_export_nb, 3);
+ GGML_CANN_CALL_ACLNN_OP(ctx, IndexSelect, export_weight, 0, select_index, select_export);
- int64_t i1 = id;
- int64_t i2 = i12;
+ int64_t select_transpose_ne[] = {select_export_ne[1], select_export_ne[0], select_export_ne[2]};
+ size_t select_transpose_nb[] = {select_export_nb[1], select_export_nb[0], select_export_nb[2]};
+ aclTensor *select_export_transpose = ggml_cann_create_tensor(export_ptr, ggml_cann_type_mapping(src0->type), ggml_element_size(src0), select_transpose_ne, select_transpose_nb, 3);
- void* src0_tmp_ptr = src0_original + i02*ori_src0_nb[2];
- void* src1_tmp_ptr = src1_original + i11*nb11 + i12*nb12;
- void* dst_tmp_ptr = dst_original + i1*nb1 + i2*nb2;
+ int64_t active_tensor_ne[] = {src1->ne[0], 1, src1->ne[1]};
+ size_t active_tensor_nb[] = {src1->nb[0], src1->nb[1], src1->nb[1]};
+ aclTensor *active_tensor = ggml_cann_create_tensor(src1, active_tensor_ne, active_tensor_nb, 3, ACL_FORMAT_ND, i * src1->nb[2]);
- aclTensor* acl_src0 = ggml_cann_create_tensor(src0_tmp_ptr,
- ACL_FLOAT, sizeof(float),
- src0_ne, src0_nb, 2);
- aclTensor* acl_src1 = ggml_cann_create_tensor(src1_tmp_ptr,
- ACL_FLOAT, sizeof(float),
- src1_ne, src1_nb, 2);
- aclTensor* acl_dst = ggml_cann_create_tensor(dst_tmp_ptr,
- ACL_FLOAT, sizeof(float),
- dst_ne, dst_nb, 2);
+ int64_t dst_ne[] = {dst->ne[0], 1, dst->ne[1]};
+ size_t dst_nb[] = {dst->nb[0], dst->nb[1], dst->nb[1]};
+        aclTensor *acl_dst = ggml_cann_create_tensor(dst, dst_ne, dst_nb, 3, ACL_FORMAT_ND, i * dst->nb[2]);
- src0_tensor_vec.push_back(acl_src0);
- src1_tensor_vec.push_back(acl_src1);
- dst_tensor_vec.push_back(acl_dst);
- }
- }
+ GGML_CANN_CALL_ACLNN_OP(ctx, BatchMatMul, active_tensor, select_export_transpose, acl_dst, 2);
- size_t GROUP_SIZE = 128;
- // GroupedMatmulV3 required tensor_list.size < 128
- for (size_t i = 0; i < src0_tensor_vec.size(); i += GROUP_SIZE) {
- // split and call GroupedMatmulV3
- size_t end = std::min(i + GROUP_SIZE, src0_tensor_vec.size());
-        std::vector<aclTensor*> src0_tensor_vec_split(src0_tensor_vec.begin() + i, src0_tensor_vec.begin() + end);
-        std::vector<aclTensor*> src1_tensor_vec_split(src1_tensor_vec.begin() + i, src1_tensor_vec.begin() + end);
-        std::vector<aclTensor*> dst_tensor_vec_split(dst_tensor_vec.begin() + i, dst_tensor_vec.begin() + end);
-
- aclTensorList* src0_tensor_list = aclCreateTensorList(src0_tensor_vec_split.data(), src0_tensor_vec_split.size());
- aclTensorList* src1_tensor_list = aclCreateTensorList(src1_tensor_vec_split.data(), src1_tensor_vec_split.size());
- aclTensorList* dst_tensor_list = aclCreateTensorList(dst_tensor_vec_split.data(), dst_tensor_vec_split.size());
-
- GGML_CANN_CALL_ACLNN_OP(ctx, GroupedMatmulV3, src1_tensor_list, src0_tensor_list,
- nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, 0, -1, dst_tensor_list);
-
- ggml_cann_release_resources(ctx, src0_tensor_list, src1_tensor_list, dst_tensor_list);
+ ggml_cann_release_resources(ctx, select_index, export_weight, select_export, active_tensor, acl_dst, select_export_transpose);
}
- return;
}
/**
@@ -3141,11 +3110,38 @@ void ggml_cann_mul_mat_id(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
- ggml_tensor* src0 = dst->src[0]; // q, fp32
- ggml_tensor* src1 = dst->src[1]; // k, fp16
- ggml_tensor* src2 = dst->src[2]; // v, fp16
+ ggml_tensor* src0 = dst->src[0]; // q, fp32 | B, N, S, D (uncont) -> B, S, N, D (cont)
+ ggml_tensor* src1 = dst->src[1]; // k, fp16 | B, N, S, D (uncont) -> B, S, N, D (cont)
+ ggml_tensor* src2 = dst->src[2]; // v, fp16 | B, N, S, D (uncont) -> B, S, N, D (cont)
ggml_tensor* src3 = dst->src[3]; // mask, fp16
+ // B, N, S, D (uncont) -> B, S, N, D (cont)
+ int64_t src0_bsnd_ne[GGML_MAX_DIMS];
+ memcpy(src0_bsnd_ne, src0->ne, GGML_MAX_DIMS * sizeof(int64_t));
+ size_t src0_bsnd_nb[GGML_MAX_DIMS];
+ memcpy(src0_bsnd_nb, src0->nb, GGML_MAX_DIMS * sizeof(size_t));
+ int64_t src1_bsnd_ne[GGML_MAX_DIMS];
+ memcpy(src1_bsnd_ne, src1->ne, GGML_MAX_DIMS * sizeof(int64_t));
+ size_t src1_bsnd_nb[GGML_MAX_DIMS];
+ memcpy(src1_bsnd_nb, src1->nb, GGML_MAX_DIMS * sizeof(size_t));
+ int64_t src2_bsnd_ne[GGML_MAX_DIMS];
+ memcpy(src2_bsnd_ne, src2->ne, GGML_MAX_DIMS * sizeof(int64_t));
+ size_t src2_bsnd_nb[GGML_MAX_DIMS];
+ memcpy(src2_bsnd_nb, src2->nb, GGML_MAX_DIMS * sizeof(size_t));
+
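+    // swap dims 1 and 2 of the ne/nb views so the same buffer is read as B,S,N,D instead of B,N,S,D (no copy)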
+ auto transpose12 = [](int64_t* ne, size_t* nb) {
+ int64_t ne_tmp = ne[1];
+ size_t nb_tmp = nb[1];
+ ne[1] = ne[2];
+ nb[1] = nb[2];
+ ne[2] = ne_tmp;
+ nb[2] = nb_tmp;
+ };
+
+ transpose12(src0_bsnd_ne, src0_bsnd_nb);
+ transpose12(src1_bsnd_ne, src1_bsnd_nb);
+ transpose12(src2_bsnd_ne, src2_bsnd_nb);
+
float maxBias = 0.0f;
float scaleValue = 1.0f;
float logitSoftcap = 0.0f;
@@ -3167,11 +3163,12 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
void* src0_f16_buffer = nullptr;
if(ggml_cann_type_mapping(src0->type) != faDataType){
- aclTensor* acl_src0_f32_tensor = ggml_cann_create_tensor(src0);
+ aclTensor* acl_src0_f32_tensor = ggml_cann_create_tensor(src0, src0_bsnd_ne,
+ src0_bsnd_nb, GGML_MAX_DIMS);
src0_f16_buffer = src0_f16_allocator.alloc(
ggml_nelements(src0) * faElemSize);
- int64_t* src0_f16_ne = src0->ne;
+ int64_t* src0_f16_ne = src0_bsnd_ne;
size_t src0_f16_nb[GGML_MAX_DIMS];
src0_f16_nb[0] = sizeof(uint16_t);
for(int i = 1; i < GGML_MAX_DIMS; ++i){
@@ -3185,20 +3182,23 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
aclnn_cast(ctx, acl_src0_f32_tensor, acl_src0_f16_tensor, faDataType);
ggml_cann_release_resources(ctx, acl_src0_f32_tensor);
}else{
- acl_src0_f16_tensor = ggml_cann_create_tensor(src0);
+ acl_src0_f16_tensor = ggml_cann_create_tensor(src0, src0_bsnd_ne,
+ src0_bsnd_nb, GGML_MAX_DIMS);
}
// Step 2: create the acl tensors for src1 (Key), src2 (Value),
// and the direct output from FusedInferAttention
- acl_src1_f16_tensor = ggml_cann_create_tensor(src1);
- acl_src2_f16_tensor = ggml_cann_create_tensor(src2);
+ acl_src1_f16_tensor = ggml_cann_create_tensor(src1, src1_bsnd_ne,
+ src1_bsnd_nb, GGML_MAX_DIMS);
+ acl_src2_f16_tensor = ggml_cann_create_tensor(src2, src2_bsnd_ne,
+ src2_bsnd_nb, GGML_MAX_DIMS);
ggml_cann_pool_alloc out_f16_allocator(ctx.pool());
void* out_f16_buffer = out_f16_allocator.alloc(
ggml_nelements(dst) * faElemSize);
- int64_t* out_f16_ne = src0->ne;
+ int64_t* out_f16_ne = src0_bsnd_ne;
size_t out_f16_nb[GGML_MAX_DIMS];
out_f16_nb[0] = faElemSize;
for(int i = 1; i < GGML_MAX_DIMS; ++i){
@@ -3212,88 +3212,81 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
// Step 3: create the PSEShift tensor if needed
// this tensor is considered as mask (f16) in the llama.cpp
-
aclTensor* bcast_pse_tensor = nullptr;
- int64_t bcast_pse_ne[GGML_MAX_DIMS];
- size_t bcast_pse_nb[GGML_MAX_DIMS];
ggml_cann_pool_alloc bcast_pse_allocator(ctx.pool());
- void* bcast_pse_buffer = nullptr;
-
if(src3 != nullptr){
- bcast_pse_buffer = bcast_pse_allocator.alloc(
- ggml_nelements(src3) * src0->ne[2] * sizeof(uint16_t));
-
- if(src0->ne[1] > 1){
- // Case 1: broadcast pse for prefill stage with multiple head
- aclTensor* acl_mask_f16_tensor = ggml_cann_create_tensor(src3);
- bcast_pse_ne[0] = src3->ne[0];
- bcast_pse_ne[1] = src3->ne[1];
- bcast_pse_ne[2] = src0->ne[2];
- bcast_pse_ne[3] = src3->ne[3];
+ // Construct the truncated pse tensor (common for prefill/decode)
+ int64_t trunc_pse_ne[GGML_MAX_DIMS] = {
+ src3->ne[0], // D
+ src0->ne[1], // S (number of Q tokens)
+ src3->ne[2], // mask N
+ src3->ne[3] // B
+ };
+ size_t* trunc_pse_nb = src3->nb;
+
+ aclTensor* acl_mask_f16_trunc_tensor = ggml_cann_create_tensor(
+ src3->data, ACL_FLOAT16, sizeof(uint16_t),
+ trunc_pse_ne, trunc_pse_nb, GGML_MAX_DIMS
+ );
+ int64_t bcast_pse_ne[GGML_MAX_DIMS];
+ size_t bcast_pse_nb[GGML_MAX_DIMS];
+ bcast_pse_ne[0] = src3->ne[0]; // D
+ bcast_pse_ne[1] = src0->ne[1]; // S
+ bcast_pse_ne[2] = src0->ne[2]; // N (num_heads)
+ bcast_pse_ne[3] = src3->ne[3]; // B
+ if (maxBias == 0.0f) {
+                // When maxBias == 0.0f, set nb = 0 so the head dimension is broadcast instead of explicitly repeated (Qwen2)
+ // Construct the bcast tensor (simulate repeat on the head dimension using stride=0)
bcast_pse_nb[0] = sizeof(uint16_t);
- for(int i = 1; i < GGML_MAX_DIMS; ++i){
- bcast_pse_nb[i] = bcast_pse_nb[i - 1] * bcast_pse_ne[i - 1];
- }
+ bcast_pse_nb[1] = bcast_pse_nb[0] * bcast_pse_ne[0];
+ bcast_pse_nb[2] = 0; // <---- the head dimension shares the same data
+ bcast_pse_nb[3] = src3->nb[3];
bcast_pse_tensor = ggml_cann_create_tensor(
- bcast_pse_buffer, ACL_FLOAT16, sizeof(uint16_t),
- bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS);
-
- int64_t repeats[] = {1, src0->ne[2], 1, 1};
- aclnn_repeat(ctx, acl_mask_f16_tensor, bcast_pse_tensor, repeats);
-
- ggml_cann_release_resources(ctx, acl_mask_f16_tensor);
- }else{
- // Case 2: trunc the first row and broadcast pse for decode stage with multiple head
- int64_t trunc_pse_ne[GGML_MAX_DIMS] = {src3->ne[0], src0->ne[1], src3->ne[2], src3->ne[3]};
- size_t* trunc_pse_nb = src3->nb;
-
- aclTensor* acl_mask_f16_trunc_tensor = ggml_cann_create_tensor(
src3->data, ACL_FLOAT16, sizeof(uint16_t),
- trunc_pse_ne, trunc_pse_nb, GGML_MAX_DIMS);
-
- bcast_pse_ne[0] = src3->ne[0];
- bcast_pse_ne[1] = src0->ne[1];
- bcast_pse_ne[2] = src0->ne[2];
- bcast_pse_ne[3] = src3->ne[3];
+ bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS
+ );
+ ggml_cann_release_resources(ctx, acl_mask_f16_trunc_tensor);
+ } else {
bcast_pse_nb[0] = sizeof(uint16_t);
- for(int i = 1; i < GGML_MAX_DIMS; ++i){
+ for (int i = 1; i < GGML_MAX_DIMS; i++) {
bcast_pse_nb[i] = bcast_pse_nb[i - 1] * bcast_pse_ne[i - 1];
}
+ void* bcast_pse_buffer = bcast_pse_allocator.alloc(
+ ggml_nelements(src3) * src0->ne[2] * sizeof(uint16_t)
+ );
+
bcast_pse_tensor = ggml_cann_create_tensor(
bcast_pse_buffer, ACL_FLOAT16, sizeof(uint16_t),
- bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS);
+ bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS
+ );
int64_t repeats[] = {1, src0->ne[2], 1, 1};
aclnn_repeat(ctx, acl_mask_f16_trunc_tensor, bcast_pse_tensor, repeats);
- ggml_cann_release_resources(ctx, acl_mask_f16_trunc_tensor);
- }
-
- // Compute the slope if needed. Derived from ggml_cann_softmax().
- if(maxBias != 0.0f){
// alibi
+ // Compute the slope if needed. Derived from ggml_cann_softmax().
const int64_t n_heads = src0->ne[2];
- ggml_cann_pool_alloc slope_allocator(ctx.pool(), n_heads * sizeof(float));
+ ggml_cann_pool_alloc slope_allocator(ctx.pool(), n_heads * sizeof(uint16_t));
void* slope_buffer = slope_allocator.get();
- aclnn_get_slope(ctx, n_heads, slope_buffer, maxBias);
+ aclnn_get_slope(ctx, n_heads, slope_buffer, maxBias, GGML_TYPE_F16);
int64_t slope_ne[] = {1, 1, n_heads, 1};
size_t slope_nb[GGML_MAX_DIMS];
- slope_nb[0] = sizeof(float);
+ slope_nb[0] = sizeof(uint16_t);
for(int i = 1;ine[0]); // 1/sqrt(d)
int64_t preTokens = 65535;
int64_t nextTokens = 65535;
- char layout[5] = {'B', 'N', 'S', 'D', 0};
+ char layout[5] = {'B', 'S', 'N', 'D', 0};
int64_t sparseMode = 0;
int64_t innerPrecise = (src0->ne[1] == 1) ? 0 : 2;
int64_t blockSize = 0;
@@ -3347,32 +3340,9 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
);
// Step 6: post-processing, permute and cast to f32
-
- int64_t new_dim[] = {0, 2, 1, 3};
aclTensor* acl_dst_tensor = ggml_cann_create_tensor(dst);
-
- if(ggml_cann_type_mapping(dst->type) != faDataType){
- ggml_cann_pool_alloc perm_out_f16_allocator(ctx.pool());
- perm_out_f16_allocator.alloc(ggml_nelements(dst) * faElemSize);
- void* perm_out_f16_buffer = perm_out_f16_allocator.get();
-
- int64_t* perm_out_f16_ne = dst->ne;
- size_t perm_out_f16_nb[GGML_MAX_DIMS];
- perm_out_f16_nb[0] = faElemSize;
- for(int i = 1; i < GGML_MAX_DIMS; ++i){
- perm_out_f16_nb[i] = perm_out_f16_nb[i - 1] * perm_out_f16_ne[i - 1];
- }
- aclTensor* acl_perm_out_f16_tensor = ggml_cann_create_tensor(
- perm_out_f16_buffer, faDataType, faElemSize,
- perm_out_f16_ne, perm_out_f16_nb, GGML_MAX_DIMS);
- aclnn_permute(ctx, acl_dst_f16_tensor, acl_perm_out_f16_tensor, new_dim, GGML_MAX_DIMS);
- aclnn_cast(ctx,
- acl_perm_out_f16_tensor, acl_dst_tensor, ggml_cann_type_mapping(dst->type));
- ggml_cann_release_resources(ctx, acl_perm_out_f16_tensor);
- }else{
- // only need to permute
- aclnn_permute(ctx, acl_dst_f16_tensor, acl_dst_tensor, new_dim, GGML_MAX_DIMS);
- }
+    // TODO: skip the cast when dst is already fp16
+ aclnn_cast(ctx, acl_dst_f16_tensor, acl_dst_tensor, ggml_cann_type_mapping(dst->type));
ggml_cann_release_resources(ctx, acl_src0_f16_tensor,
acl_src1_f16_tensor,
acl_src2_f16_tensor,
diff --git a/ggml/src/ggml-cann/common.h b/ggml/src/ggml-cann/common.h
index 5858bd3f6a197..a041a157c333a 100755
--- a/ggml/src/ggml-cann/common.h
+++ b/ggml/src/ggml-cann/common.h
@@ -360,6 +360,30 @@ struct ggml_cann_graph {
};
#endif // USE_ACL_GRAPH
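+/**
+ * @brief Cache for rope theta-scale values kept in device memory; the buffer is freed in the destructor.
+ */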
+struct ggml_cann_rope_cache {
+ ~ggml_cann_rope_cache() {
+ if(theta_scale_cache != nullptr) {
+ ACL_CHECK(aclrtFree(theta_scale_cache));
+ }
+ }
+
+ void* theta_scale_cache = nullptr;
+ int64_t theta_scale_length = 0;
+ float theta_scale = 0.0f;
+ float freq_scale = 0.0f;
+};
+
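+/**
+ * @brief Generic device-memory tensor cache (used for the RMS-norm constant tensors); freed in the destructor.
+ */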
+struct ggml_cann_tensor_cache {
+ ~ggml_cann_tensor_cache() {
+ if(cache != nullptr) {
+ ACL_CHECK(aclrtFree(cache));
+ }
+ }
+
+ void* cache = nullptr;
+ int64_t size = 0;
+};
+
/**
* @brief Context for managing CANN backend operations.
*/
@@ -368,21 +392,18 @@ struct ggml_backend_cann_context {
std::string name; /**< Name of the device. */
std::string description; /**< Description of the device. */
aclrtEvent copy_event = nullptr; /**< Event for managing copy operations. */
- void* init_ptr = nullptr;
- void* sin_ptr = nullptr;
- void* cos_ptr = nullptr;
- int64_t max_prompt_length = 65536;
#ifdef USE_ACL_GRAPH
/// Cached CANN ACL graph used for executing the current ggml computation graph.
    std::unique_ptr<ggml_cann_graph> cann_graph;
+ bool acl_graph_mode = true;
#endif
cann_task_queue task_queue;
bool async_mode;
- bool support_set_rows;
- void* f32_zero_cache = nullptr;
- void* f32_one_cache = nullptr;
- int64_t f32_zero_cache_element = 0;
- int64_t f32_one_cache_element = 0;
+ // Rope Cache
+ ggml_cann_rope_cache rope_cache;
+ // Constant Pool
+ ggml_cann_tensor_cache rms_norm_one_tensor_cache;
+ ggml_cann_tensor_cache rms_norm_zero_tensor_cache;
aclrtStream streams[GGML_CANN_MAX_STREAMS] = {nullptr}; /**< Array of streams for the device. */
@@ -398,14 +419,13 @@ struct ggml_backend_cann_context {
async_mode = parse_bool(get_env("GGML_CANN_ASYNC_MODE").value_or(""));
GGML_LOG_INFO("%s: device %d async operator submission is %s\n", __func__,
device, async_mode ? "ON" : "OFF");
-
- support_set_rows = parse_bool(get_env("LLAMA_SET_ROWS").value_or(""));
- GGML_LOG_INFO("%s: LLAMA_SET_ROWS is %s\n", __func__, support_set_rows ? "ON" : "OFF");
-
- if (!support_set_rows) {
- GGML_LOG_INFO("%s: CANN Graph currently only supports execution when LLAMA_SET_ROWS is ON. "
- "Falling back to eager mode.\n", __func__);
- }
+#ifdef USE_ACL_GRAPH
+ acl_graph_mode = !(parse_bool(get_env("GGML_CANN_DISABLE_ACL_GRAPH").value_or("")));
+ GGML_LOG_INFO("%s: device %d execution mode is %s (%s)\n",
+ __func__, device,
+ acl_graph_mode ? "GRAPH" : "EAGER",
+ acl_graph_mode ? "acl graph enabled" : "acl graph disabled");
+#endif
}
/**
@@ -422,15 +442,6 @@ struct ggml_backend_cann_context {
ACL_CHECK(aclrtDestroyStream(streams[i]));
}
}
- if(init_ptr != nullptr) {
- ACL_CHECK(aclrtFree(init_ptr));
- }
- if(sin_ptr != nullptr) {
- ACL_CHECK(aclrtFree(sin_ptr));
- }
- if(cos_ptr != nullptr) {
- ACL_CHECK(aclrtFree(cos_ptr));
- }
}
/**
diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp
index cb8af42ebf956..64fb2beff0aef 100755
--- a/ggml/src/ggml-cann/ggml-cann.cpp
+++ b/ggml/src/ggml-cann/ggml-cann.cpp
@@ -1155,7 +1155,7 @@ namespace {
* @note The workspace buffer used in this function is managed globally and reused
* across calls. This reduces overhead from repeated memory allocation and deallocation.
*/
-static void weight_format_to_nz(ggml_tensor *tensor, const void *data, size_t offset) {
+static void weight_format_to_nz(ggml_tensor *tensor, size_t offset) {
aclTensor* weightTransposed = ggml_cann_create_tensor(tensor, tensor->ne,
tensor->nb, 2, ACL_FORMAT_ND, offset);
uint64_t workspaceSize = 0;
@@ -1203,7 +1203,7 @@ static void ggml_backend_cann_buffer_set_tensor(
if (weight_to_nz && is_matmul_weight((const ggml_tensor*)tensor)) {
GGML_ASSERT(tensor->ne[2] == 1);
GGML_ASSERT(tensor->ne[3] == 1);
- weight_format_to_nz(tensor, data, offset);
+ weight_format_to_nz(tensor, offset);
}
} else {
void *transform_buffer = malloc(size);
@@ -2247,12 +2247,12 @@ static enum ggml_status ggml_backend_cann_graph_compute(
(ggml_backend_cann_context*)backend->context;
ggml_cann_set_device(cann_ctx->device);
release_nz_workspace();
+
#ifdef USE_ACL_GRAPH
bool use_cann_graph = true;
bool cann_graph_update_required = false;
- // check environment LLAMA_SET_ROWS
- if (!cann_ctx->support_set_rows) {
+ if (!cann_ctx->acl_graph_mode) {
use_cann_graph = false;
}
@@ -2336,7 +2336,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q4_0:
#ifdef ASCEND_310P
- // Q4 && Q8 per group is not suppor on 310p device
+            // Q4 && Q8 per group is not supported on 310p device
return false;
#endif
// only support contiguous for quantized types.
@@ -2354,7 +2354,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q4_0:
#ifdef ASCEND_310P
- // Q4 && Q8 per group is not suppor on 310p device
+            // Q4 && Q8 per group is not supported on 310p device
return false;
#endif
// only support contiguous for quantized types.
@@ -2405,16 +2405,10 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
}
case GGML_OP_ROPE: {
// TODO: with ops-test v == 1
- float ext_factor = 0.0f;
- memcpy(&ext_factor, (const float *) op->op_params + 7, sizeof(float));
// TODO: n_dims <= ne0
if (op->src[0]->ne[0] != op->op_params[1]) {
return false;
}
- // TODO: ext_factor != 0
- if (ext_factor != 0) {
- return false;
- }
const int mode = ((const int32_t *) op->op_params)[2];
if (mode & GGML_ROPE_TYPE_MROPE) {
@@ -2423,10 +2417,11 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
if (mode & GGML_ROPE_TYPE_VISION) {
return false;
}
-
+#ifdef ASCEND_310P
if(!ggml_is_contiguous(op->src[0])){
return false;
}
+#endif
return true;
}
case GGML_OP_UPSCALE: {
@@ -2488,15 +2483,17 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
case GGML_OP_ARGMAX:
case GGML_OP_COS:
case GGML_OP_SIN:
- case GGML_OP_CONV_TRANSPOSE_1D:
case GGML_OP_LOG:
case GGML_OP_MEAN:
case GGML_OP_PAD_REFLECT_1D:
case GGML_OP_COUNT_EQUAL:
return true;
+ case GGML_OP_CONV_TRANSPOSE_1D:
+ // TODO: ((weightL - 1) * dilationW - padLeft)=1336 should not be larger than 255.
+ return (op->src[0]->ne[0] - 1) <= 255;
case GGML_OP_SCALE:
float bias;
- memcpy(&bias, (float*)op->op_params + 1, sizeof(float));
+ memcpy(&bias, (const float *)(op->op_params) + 1, sizeof(float));
return bias == 0.0f; // TODO: support bias != 0.0f
case GGML_OP_SOFT_MAX:
// TODO: support attention sinks [TAG_ATTN_SINKS]
@@ -2505,6 +2502,10 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
}
return true;
case GGML_OP_FLASH_ATTN_EXT:{
+#ifdef ASCEND_310P
+ // FA not support on 310p device
+ return false;
+#endif
// derived from [ggml-cuda.cu]
if(op->src[1]->type != GGML_TYPE_F16 || op->src[2]->type != GGML_TYPE_F16){
return false;
@@ -2523,15 +2524,12 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
// different head sizes of K and V are not supported yet
return false;
}
- if (op->src[0]->ne[0] == 192) {
- return false;
- }
- if (op->src[0]->ne[0] == 576) {
- // DeepSeek MLA
+ if (op->src[0]->ne[0] % 16 != 0) {
+ // TODO: padding to support
return false;
}
float logitSoftcap = 0.0f;
- memcpy(&logitSoftcap, (float*)op->op_params + 2, sizeof(float));
+ memcpy(&logitSoftcap, (const float *)(op->op_params) + 2, sizeof(float));
if(logitSoftcap != 0.0f) {
return false;
}
diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt
index ce0a3e1285eb0..dd8c1cf67840e 100644
--- a/ggml/src/ggml-cpu/CMakeLists.txt
+++ b/ggml/src/ggml-cpu/CMakeLists.txt
@@ -433,15 +433,22 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
ggml-cpu/arch/riscv/quants.c
ggml-cpu/arch/riscv/repack.cpp
)
- if (GGML_RVV)
- if (GGML_XTHEADVECTOR)
- list(APPEND ARCH_FLAGS -march=rv64gc_xtheadvector -mabi=lp64d)
- elseif (GGML_RV_ZFH)
- list(APPEND ARCH_FLAGS -march=rv64gcv_zfhmin -mabi=lp64d)
- else()
- list(APPEND ARCH_FLAGS -march=rv64gcv -mabi=lp64d)
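+        # Build the -march string from the enabled RISC-V extension options (e.g. rv64gc_zfh_v_zvfh_zicbop)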
+ set(MARCH_STR "rv64gc")
+ if (GGML_RV_ZFH)
+ string(APPEND MARCH_STR "_zfh")
+ endif()
+ if (GGML_XTHEADVECTOR)
+ string(APPEND MARCH_STR "_xtheadvector")
+ elseif (GGML_RVV)
+ string(APPEND MARCH_STR "_v")
+ if (GGML_RV_ZVFH)
+ string(APPEND MARCH_STR "_zvfh")
endif()
endif()
+ if (GGML_RV_ZICBOP)
+ string(APPEND MARCH_STR "_zicbop")
+ endif()
+ list(APPEND ARCH_FLAGS "-march=${MARCH_STR}" -mabi=lp64d)
elseif (GGML_SYSTEM_ARCH STREQUAL "s390x")
message(STATUS "s390x detected")
list(APPEND GGML_CPU_SOURCES ggml-cpu/arch/s390/quants.c)
@@ -497,9 +504,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
# Fetch KleidiAI sources:
include(FetchContent)
- set(KLEIDIAI_COMMIT_TAG "v1.11.0")
+ set(KLEIDIAI_COMMIT_TAG "v1.13.0")
set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz")
- set(KLEIDIAI_ARCHIVE_MD5 "3fe9e5ab964c375c53839296eb71eaa2")
+ set(KLEIDIAI_ARCHIVE_MD5 "d82a8de939d9814621a5ba23907bdac1")
if (POLICY CMP0135)
cmake_policy(SET CMP0135 NEW)
@@ -555,6 +562,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
list(APPEND GGML_KLEIDIAI_SOURCES
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32.c
+ ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p4x8sb_f32_neon.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32_neon.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.c)
@@ -576,7 +584,8 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_pack_bf16p2vlx2_f32_sme.c
- ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.c)
+ ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.c
+ ${KLEIDIAI_SRC}/kai/kai_common_sme_asm.S)
set(PRIVATE_ARCH_FLAGS "-fno-tree-vectorize;${PRIVATE_ARCH_FLAGS}+sve+sve2")
endif()
diff --git a/ggml/src/ggml-cpu/arch/riscv/quants.c b/ggml/src/ggml-cpu/arch/riscv/quants.c
index 6c74417c90c1f..ee41a3502e82d 100644
--- a/ggml/src/ggml-cpu/arch/riscv/quants.c
+++ b/ggml/src/ggml-cpu/arch/riscv/quants.c
@@ -1270,29 +1270,40 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
const float dmin = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);
- int tmp, tmp2, sumi;
+ float ftmp, ft2;
+ const uint8_t * restrict q40;
+ const uint8_t * restrict q41;
+ const uint8_t * restrict q42;
+ const uint8_t * restrict q43;
+ const int8_t * restrict q80;
+ const int8_t * restrict q81;
+ const int8_t * restrict q82;
+ const int8_t * restrict q83;
+ int s0, s1, s2, s3;
+
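+        // One fused asm block: apply the -dmin * sum(mins * bsums) correction, then accumulate the eight scale-weighted 32-element sub-block dot products and fmadd d into sumf.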
__asm__ __volatile__(
- "vsetivli zero, 12, e8, m1\n\t"
- "vle8.v v1, (%[s6b])\n\t" // {aux[0], aux[1], aux[2]}
- "vsetivli zero, 4, e32, m1\n\t"
+ "li %[s1], 8\n\t"
+ "vsetivli zero, 4, e32, m1, ta, ma\n\t"
+ "vle32.v v1, (%[s6b])\n\t"
+ "vslide1down.vx v1, v1, zero\n\t"
+ "vmv.v.x v16, zero\n\t"
"vslidedown.vi v2, v1, 2\n\t"
"vmv1r.v v3, v2\n\t"
"vslideup.vi v2, v3, 1\n\t" // {aux[2], aux[2]}
- "vsetivli zero, 2, e32, m1\n\t"
+ "vsetivli zero, 2, e32, m1, ta, ma\n\t"
"vmv.v.i v4, 4\n\t"
"vand.vx v8, v1, %[kmask1]\n\t"
"vslide1up.vx v5, v4, zero\n\t" // {0, 4}
"vsrl.vi v6, v1, 6\n\t"
"vsrl.vv v7, v2, v5\n\t"
+ "vsse32.v v8, (%[utmp]), %[s1]\n\t"
"vand.vx v0, v6, %[kmask3]\n\t"
"vand.vx v2, v7, %[kmask2]\n\t"
"vsll.vi v6, v0, 4\n\t"
- "li %[t2], 8\n\t"
- "addi %[t1], %[utmp], 4\n\t"
+ "addi %[s0], %[utmp], 4\n\t"
"vor.vv v1, v6, v2\n\t"
- "vsse32.v v8, (%[utmp]), %[t2]\n\t"
- "vsse32.v v1, (%[t1]), %[t2]\n\t"
- "vsetivli zero, 8, e16, m1\n\t"
+ "vsse32.v v1, (%[s0]), %[s1]\n\t"
+ "vsetivli zero, 8, e16, m1, ta, ma\n\t"
"vle32.v v2, (%[bsums])\n\t"
"vnsrl.wi v0, v2, 0\n\t"
"vnsrl.wi v1, v2, 16\n\t"
@@ -1300,13 +1311,131 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
"vle8.v v3, (%[mins])\n\t"
"vzext.vf2 v4, v3\n\t"
"vwmul.vv v6, v4, v2\n\t"
+ "vsetivli zero, 4, e32, m1, ta, ma\n\t"
+ "vredsum.vs v0, v6, v16\n\t"
+ "vredsum.vs v0, v7, v0\n\t"
+ "vfcvt.f.x.v v0, v0\n\t"
+ "vfmv.f.s %[ftmp], v0\n\t"
+ "vsetivli zero, 16, e8, m1, ta, ma\n\t"
+ "vle8.v v0, (%[xs])\n\t"
+ "fnmsub.s %[sumf], %[dmin], %[ftmp], %[sumf]\n\t"
+ "addi %[q40], %[xs], 64\n\t"
+ "addi %[q41], %[xs], 16\n\t"
+ "addi %[q42], %[xs], 32\n\t"
+ "addi %[q43], %[xs], 48\n\t"
+ "addi %[q80], %[ys], 64\n\t"
+ "vle8.v v1, (%[q41])\n\t"
+ "vle8.v v2, (%[q42])\n\t"
+ "addi %[q81], %[ys], 16\n\t"
+ "addi %[q41], %[q41], 64\n\t"
+ "addi %[q82], %[ys], 32\n\t"
+ "vle8.v v3, (%[q43])\n\t"
+ "vle8.v v8, (%[ys])\n\t"
+ "addi %[q42], %[q42], 64\n\t"
+ "addi %[q83], %[ys], 48\n\t"
+ "addi %[q43], %[q43], 64\n\t"
+ "vsrl.vi v4, v0, 4\n\t"
+ "vle8.v v9, (%[q81])\n\t"
+ "vle8.v v10, (%[q82])\n\t"
+ "vand.vi v0, v0, 0xF\n\t"
+ "addi %[q81], %[q81], 64\n\t"
+ "vsrl.vi v5, v1, 4\n\t"
+ "addi %[q82], %[q82], 64\n\t"
+ "vle8.v v11, (%[q83])\n\t"
+ "vle8.v v12, (%[q80])\n\t"
+ "vand.vi v1, v1, 0xF\n\t"
+ "addi %[q83], %[q83], 64\n\t"
+ "vsrl.vi v6, v2, 4\n\t"
+ "addi %[q80], %[q80], 64\n\t"
+ "vle8.v v13, (%[q81])\n\t"
+ "vle8.v v14, (%[q82])\n\t"
+ "vand.vi v2, v2, 0xF\n\t"
+ "addi %[q81], %[q81], 64\n\t"
+ "vsrl.vi v7, v3, 4\n\t"
+ "addi %[q82], %[q82], 64\n\t"
+ "vwmul.vv v16, v0, v8\n\t"
+ "vle8.v v15, (%[q83])\n\t"
+ "vle8.v v0, (%[q40])\n\t"
+ "vand.vi v3, v3, 0xF\n\t"
+ "addi %[q83], %[q83], 64\n\t"
+ "vwmul.vv v24, v2, v12\n\t"
+ "vwmul.vv v20, v4, v10\n\t"
+ "vwmul.vv v28, v6, v14\n\t"
+ "vwmacc.vv v16, v1, v9\n\t"
+ "vle8.v v1, (%[q41])\n\t"
+ "vle8.v v2, (%[q42])\n\t"
+ "vwmacc.vv v24, v3, v13\n\t"
+ "vwmacc.vv v20, v5, v11\n\t"
+ "vwmacc.vv v28, v7, v15\n\t"
+ "addi %[q40], %[q80], 64\n\t"
+ "addi %[q41], %[q81], 64\n\t"
+ "vle8.v v3, (%[q43])\n\t"
+ "vle8.v v8, (%[q80])\n\t"
+ "addi %[q42], %[q82], 64\n\t"
+ "addi %[q43], %[q83], 64\n\t"
+ "vsrl.vi v4, v0, 4\n\t"
+ "vle8.v v9, (%[q81])\n\t"
+ "vle8.v v10, (%[q82])\n\t"
+ "vand.vi v0, v0, 0xF\n\t"
+ "vsrl.vi v5, v1, 4\n\t"
+ "vsrl.vi v7, v3, 4\n\t"
+ "vand.vi v3, v3, 0xF\n\t"
+ "vle8.v v11, (%[q83])\n\t"
+ "vle8.v v12, (%[q40])\n\t"
+ "vand.vi v1, v1, 0xF\n\t"
+ "vsrl.vi v6, v2, 4\n\t"
+ "vand.vi v2, v2, 0xF\n\t"
+ "vwmul.vv v18, v0, v8\n\t"
+ "vle8.v v13, (%[q41])\n\t"
+ "vle8.v v14, (%[q42])\n\t"
+ "vwmul.vv v26, v2, v12\n\t"
+ "vwmul.vv v22, v4, v10\n\t"
+ "vwmul.vv v30, v6, v14\n\t"
+ "vwmacc.vv v18, v1, v9\n\t"
+ "vle8.v v15, (%[q43])\n\t"
+ "vwmacc.vv v26, v3, v13\n\t"
+ "vwmacc.vv v22, v5, v11\n\t"
+ "vwmacc.vv v30, v7, v15\n\t"
"vmv.v.x v0, zero\n\t"
- "vsetivli zero, 8, e32, m2\n\t"
- "vredsum.vs v0, v6, v0\n\t"
- "vmv.x.s %[sumi], v0"
- : [t1] "=&r" (tmp), [t2] "=&r" (tmp2), [sumi] "=&r" (sumi)
- : [bsums] "r" (y[i].bsums), [mins] "r" (mins), [utmp] "r" (utmp)
- , [s6b] "r" (x[i].scales), [kmask1] "r" (kmask1)
+ "vsetivli zero, 16, e16, m2, ta, ma\n\t"
+ "vwredsum.vs v4, v16, v0\n\t"
+ "lbu %[s0], 0(%[scale])\n\t"
+ "vwredsum.vs v5, v20, v0\n\t"
+ "lbu %[s1], 1(%[scale])\n\t"
+ "vwredsum.vs v6, v24, v0\n\t"
+ "lbu %[s2], 2(%[scale])\n\t"
+ "vwredsum.vs v7, v28, v0\n\t"
+ "lbu %[s3], 3(%[scale])\n\t"
+ "vwredsum.vs v8, v18, v0\n\t"
+ "lbu %[q40], 4(%[scale])\n\t"
+ "vwredsum.vs v9, v22, v0\n\t"
+ "lbu %[q41], 5(%[scale])\n\t"
+ "vwredsum.vs v10, v26, v0\n\t"
+ "lbu %[q42], 6(%[scale])\n\t"
+ "vwredsum.vs v11, v30, v0\n\t"
+ "lbu %[q43], 7(%[scale])\n\t"
+ "vsetivli zero, 4, e32, m1, ta, ma\n\t"
+ "vmul.vx v0, v4, %[s0]\n\t"
+ "vmul.vx v1, v8, %[q40]\n\t"
+ "vmacc.vx v0, %[s1], v5\n\t"
+ "vmacc.vx v1, %[q41], v9\n\t"
+ "vmacc.vx v0, %[s2], v6\n\t"
+ "vmacc.vx v1, %[q42], v10\n\t"
+ "vmacc.vx v0, %[s3], v7\n\t"
+ "vmacc.vx v1, %[q43], v11\n\t"
+ "vfcvt.f.x.v v0, v0\n\t"
+ "vfcvt.f.x.v v1, v1\n\t"
+ "vfmv.f.s %[ft2], v0\n\t"
+ "vfmv.f.s %[ftmp], v1\n\t"
+ "fadd.s %[ft2], %[ft2], %[ftmp]\n\t"
+ "fmadd.s %[sumf], %[d], %[ft2], %[sumf]"
+ : [ftmp] "=&f" (ftmp), [sumf] "+&f" (sumf), [ft2] "=&f" (ft2)
+ , [s0] "=&r" (s0), [s1] "=&r" (s1), [s2] "=&r" (s2), [s3] "=&r" (s3)
+ , [q40] "=&r" (q40), [q41] "=&r" (q41), [q42] "=&r" (q42), [q43] "=&r" (q43)
+ , [q80] "=&r" (q80), [q81] "=&r" (q81), [q82] "=&r" (q82), [q83] "=&r" (q83)
+ : [d] "f" (d), [ys] "r" (y[i].qs), [xs] "r" (x[i].qs), [scale] "r" (scales)
+ , [bsums] "r" (y[i].bsums), [mins] "r" (mins), [utmp] "r" (utmp)
+ , [s6b] "r" (&x[i]), [kmask1] "r" (kmask1), [dmin] "f" (dmin)
, [kmask2] "r" (kmask2), [kmask3] "r" (kmask3)
: "memory"
, "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
@@ -1314,59 +1443,6 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
, "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
, "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
);
- sumf -= dmin * sumi;
-
- const uint8_t * restrict q4 = x[i].qs;
- const int8_t * restrict q8 = y[i].qs;
-
- sumi = 0;
- const uint8_t * scale = scales;
-
- for (int j = 0; j < QK_K/128; ++j) {
- int vl128 = 128, vl64 = 64, vl32 = 32;
- __asm__ __volatile__(
- "vsetvli zero, %[vl128], e8, m8\n\t"
- "vle8.v v8, (%[q8])\n\t"
- "vsetvli zero, %[vl64], e8, m4\n\t"
- "vle8.v v0, (%[q4])\n\t"
- "vsrl.vi v4, v0, 4\n\t"
- "vand.vi v0, v0, 0xF\n\t"
- "vsetvli zero, %[vl32], e8, m2\n\t"
- "vwmul.vv v28, v6, v14\n\t"
- "vwmul.vv v20, v4, v10\n\t"
- "vwmul.vv v24, v2, v12\n\t"
- "vwmul.vv v16, v0, v8\n\t"
- "vsetivli zero, 4, e32, m1\n\t"
- "vle8.v v2, (%[scale])\n\t"
- "vmv.v.x v0, zero\n\t"
- "vzext.vf4 v1, v2\n\t"
- "vsetvli zero, %[vl32], e16, m4\n\t"
- "vwredsum.vs v6, v24, v0\n\t"
- "vwredsum.vs v7, v28, v0\n\t"
- "vwredsum.vs v4, v16, v0\n\t"
- "vwredsum.vs v5, v20, v0\n\t"
- "vsetivli zero, 4, e32, m1\n\t"
- "vslideup.vi v6, v7, 1\n\t"
- "vslideup.vi v4, v5, 1\n\t"
- "vslideup.vi v4, v6, 2\n\t"
- "vmul.vv v8, v4, v1\n\t"
- "vredsum.vs v0, v8, v0\n\t"
- "vmv.x.s %[tmp], v0\n\t"
- "add %[sumi], %[sumi], %[tmp]"
- : [tmp] "=&r" (tmp), [sumi] "+&r" (sumi)
- : [vl128] "r" (vl128), [vl64] "r" (vl64), [vl32] "r" (vl32)
- , [q4] "r" (q4), [q8] "r" (q8), [scale] "r" (scale)
- : "memory"
- , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
- , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
- , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
- , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
- );
-
- q4 += 64; q8 += 128; scale += 4;
- }
-
- sumf += d * sumi;
}
break;
default:
@@ -1693,6 +1769,8 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
case 128:
for (int i = 0; i < nb; ++i) {
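+        // prefetch the next super-block so its data is in cache by the next iteration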
+ __builtin_prefetch(&x[i + 1].d, 0, 1);
+
const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
const uint8_t * restrict q6 = x[i].ql;
@@ -1701,23 +1779,59 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
const int8_t * restrict scale = x[i].scales;
- int sum_t = 0;
- int t0;
+ int q6h;
+ float ftmp;
for (int j = 0; j < QK_K/128; ++j) {
__asm__ __volatile__(
+ "addi %[q6h], %[q6], 32\n\t"
+ "ld t0, 0(%[scale])\n\t"
+ "addi %[scale], %[scale], 8\n\t"
+ "slli t6, t0, 1 * 8\n\t"
+ "lb zero, 0(%[q6])\n\t"
+ "slli t5, t0, 2 * 8\n\t"
+ "slli t4, t0, 3 * 8\n\t"
+ "lb zero, 0(%[q6h])\n\t"
+ "slli t3, t0, 4 * 8\n\t"
+ "slli t2, t0, 5 * 8\n\t"
+ "lb zero, 0(%[qh])\n\t"
+ "lb zero, 31(%[q6h])\n\t"
+ "slli t1, t0, 6 * 8\n\t"
+ "srai a7, t0, 56\n\t"
"vsetvli zero, %[vl32], e8, m2\n\t"
+ "vle8.v v8, (%[q6])\n\t"
+ "srai t6, t6, 56\n\t"
+ "srai t5, t5, 56\n\t"
+ "srai t4, t4, 56\n\t"
+ "srai t3, t3, 56\n\t"
+ "vle8.v v10, (%[q6h])\n\t"
+ "addi %[q6], %[q6], 64\n\t"
+ "slli t0, t0, 7 * 8\n\t"
+ "srai t2, t2, 56\n\t"
+ "srai t1, t1, 56\n\t"
+ "srai t0, t0, 56\n\t"
"vle8.v v4, (%[qh])\n\t"
+ "vsrl.vi v12, v8, 4\n\t"
+ "vsrl.vi v14, v10, 4\n\t"
+ "lb zero, 0(%[q8])\n\t"
+ "vand.vi v8, v8, 0xF\n\t"
+ "vand.vi v10, v10, 0xF\n\t"
+ "lb zero, 32(%[q8])\n\t"
"vsll.vi v0, v4, 4\n\t"
"vsll.vi v2, v4, 2\n\t"
+ "lb zero, 64(%[q8])\n\t"
"vsrl.vi v6, v4, 2\n\t"
- "vsetvli zero, %[vl64], e8, m4\n\t"
- "vle8.v v8, (%[q6])\n\t"
- "vsrl.vi v12, v8, 4\n\t"
- "vand.vi v8, v8, 0xF\n\t"
- "vsetvli zero, %[vl128], e8, m8\n\t"
"vand.vx v0, v0, %[mask]\n\t"
+ "lb zero, 96(%[q8])\n\t"
+ "vand.vx v2, v2, %[mask]\n\t"
+ "vand.vx v4, v4, %[mask]\n\t"
+ "vand.vx v6, v6, %[mask]\n\t"
"vor.vv v8, v8, v0\n\t"
+ "lb zero, 127(%[q8])\n\t"
+ "vor.vv v10, v10, v2\n\t"
+ "vor.vv v12, v12, v4\n\t"
+ "vor.vv v14, v14, v6\n\t"
+ "vsetvli zero, %[vl128], e8, m8\n\t"
"vle8.v v0, (%[q8])\n\t"
"vsub.vx v8, v8, %[vl32]\n\t"
"vsetvli zero, %[vl64], e8, m4\n\t"
@@ -1734,34 +1848,34 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
"vwredsum.vs v13, v28, v0\n\t"
"vwredsum.vs v14, v30, v0\n\t"
"vsetivli zero, 4, e32, m1\n\t"
- "vslideup.vi v10, v9, 1\n\t"
- "vslideup.vi v8, v7, 1\n\t"
- "vslideup.vi v11, v12, 1\n\t"
- "vslideup.vi v13, v14, 1\n\t"
- "vslideup.vi v10, v8, 2\n\t"
- "vslideup.vi v11, v13, 2\n\t"
- "vsetivli zero, 8, e32, m2\n\t"
- "vle8.v v2, (%[scale])\n\t"
- "vsext.vf4 v4, v2\n\t"
- "vmul.vv v2, v4, v10\n\t"
- "vredsum.vs v0, v2, v0\n\t"
- "vmv.x.s %[t0], v0\n\t"
- "add %[sumi], %[sumi], %[t0]"
- : [sumi] "+&r" (sum_t), [t0] "=&r" (t0)
- : [qh] "r" (qh), [q6] "r" (q6), [q8] "r" (q8), [scale] "r" (scale)
+ "vmul.vx v0, v10, t0\n\t"
+ "vmul.vx v1, v9, t1\n\t"
+ "vmacc.vx v0, t2, v8\n\t"
+ "vmacc.vx v1, t3, v7\n\t"
+ "vmacc.vx v0, t4, v11\n\t"
+ "vmacc.vx v1, t5, v12\n\t"
+ "vmacc.vx v0, t6, v13\n\t"
+ "vmacc.vx v1, a7, v14\n\t"
+ "vadd.vv v0, v0, v1\n\t"
+ "vfcvt.f.x.v v0, v0\n\t"
+ "vfmv.f.s %[ftmp], v0\n\t"
+ "fmadd.s %[sumf], %[d], %[ftmp], %[sumf]"
+ : [q6] "+&r" (q6), [q6h] "=&r" (q6h)
+ , [scale] "+&r" (scale)
+ , [sumf] "+&f" (sumf), [ftmp] "=&f" (ftmp)
+ : [qh] "r" (qh), [q8] "r" (q8)
, [vl32] "r" (32), [vl64] "r" (64), [vl128] "r" (128)
- , [mask] "r" (0x30)
+ , [mask] "r" (0x30), [d] "f" (d)
: "memory"
, "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
, "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
, "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
, "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
+ , "t0", "t1", "t2", "t3", "t4", "t5", "t6", "a7"
+ , "a6", "a5", "a4", "a3"
);
- q6 += 64; qh += 32; q8 += 128; scale += 8;
+ qh += 32; q8 += 128;
}
-
- sumf += d * sum_t;
-
}
break;
default:
diff --git a/ggml/src/ggml-cpu/ggml-cpu-impl.h b/ggml/src/ggml-cpu/ggml-cpu-impl.h
index 1f6844e16cd34..e08c30a348aa1 100644
--- a/ggml/src/ggml-cpu/ggml-cpu-impl.h
+++ b/ggml/src/ggml-cpu/ggml-cpu-impl.h
@@ -489,7 +489,7 @@ inline static int16x8_t vec_padd_s16(int16x8_t a, int16x8_t b) {
/**
* @see https://github.com/ggml-org/llama.cpp/pull/14037
*/
-inline float vec_hsum(float32x4_t v) {
+inline static float vec_hsum(float32x4_t v) {
float32x4_t v_temp = v + vec_reve(v);
return v_temp[0] + v_temp[1];
}
diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c
index 0d5d3a3440aaf..78ec189d4c646 100644
--- a/ggml/src/ggml-cpu/ggml-cpu.c
+++ b/ggml/src/ggml-cpu/ggml-cpu.c
@@ -3221,6 +3221,13 @@ void ggml_cpu_fp32_to_fp16(const float * x, ggml_fp16_t * y, int64_t n) {
uint16x8_t v_y = vec_convert_to_fp16(v_yd, 0);
vec_xst(v_y, 0, (ggml_fp16_t *)(y + i));
}
+#elif defined(__riscv_zvfh)
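+    // RVV strip-mined loop: vsetvl picks how many floats fit this iteration, then narrow-convert f32 -> f16 and store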
+ for (int vl; i < n; i += vl) {
+ vl = __riscv_vsetvl_e32m2(n - i);
+ vfloat32m2_t vx = __riscv_vle32_v_f32m2(&x[i], vl);
+ vfloat16m1_t vy = __riscv_vfncvt_f_f_w_f16m1(vx, vl);
+ __riscv_vse16_v_f16m1((_Float16 *)&y[i], vy, vl);
+ }
#endif
for (; i < n; ++i) {
y[i] = GGML_CPU_FP32_TO_FP16(x[i]);
diff --git a/ggml/src/ggml-cpu/kleidiai/kernels.cpp b/ggml/src/ggml-cpu/kleidiai/kernels.cpp
index ddd29d002d1ca..7ba659124ca27 100644
--- a/ggml/src/ggml-cpu/kleidiai/kernels.cpp
+++ b/ggml/src/ggml-cpu/kleidiai/kernels.cpp
@@ -14,6 +14,7 @@
#include "kai_lhs_pack_bf16p2vlx2_f32_sme.h"
#include "kai_lhs_quant_pack_qsi8d32p_f32.h"
+#include "kai_lhs_quant_pack_qsi8d32p4x8sb_f32_neon.h"
#include "kai_lhs_quant_pack_qsi8d32p_f32_neon.h"
#include "kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.h"
@@ -127,6 +128,12 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
},
+ /* .gemm_lhs_info = */ {
+ /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32_neon,
+ /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32_neon,
+ /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32_neon,
+ /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32_neon,
+ },
/* SME GEMV */
/* .kern_info = */ {
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
@@ -141,7 +148,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
},
- /* .lhs_info = */ {
+ /* .gemv_lhs_info = */ {
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32_neon,
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32_neon,
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32_neon,
@@ -173,6 +180,12 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
/* .run_kernel = */ kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
},
+ /* .gemm_lhs_info = */ {
+ /* .get_offset = */ kai_get_lhs_offset_lhs_pack_bf16p2vlx2_f32_sme,
+ /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_pack_bf16p2vlx2_f32_sme,
+ /* .packed_size = */ kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme,
+ /* .pack_func = */ kai_run_lhs_pack_bf16p2vlx2_f32_sme,
+ },
/* SME GEMV */
/* .kern_info = */ {
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
@@ -187,7 +200,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
/* .run_kernel = */ kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
},
- /* .lhs_info = */ {
+ /* .gemv_lhs_info = */ {
/* .get_offset = */ kai_get_lhs_offset_lhs_pack_bf16p2vlx2_f32_sme,
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_pack_bf16p2vlx2_f32_sme,
/* .packed_size = */ kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme,
@@ -222,6 +235,12 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
},
+ /* .gemm_lhs_info = */ {
+ /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
+ /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
+ /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
+ /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
+ },
/* DOTPROD GEMV */
/* .kern_info = */ {
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
@@ -236,7 +255,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
},
- /* .lhs_info = */ {
+ /* .gemv_lhs_info = */ {
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
@@ -270,6 +289,12 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
},
+ /* .gemm_lhs_info = */ {
+ /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
+ /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
+ /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
+ /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
+ },
/* i8mm GEMV */
/* .kern_info = */ {
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
@@ -284,7 +309,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
},
- /* .lhs_info = */ {
+ /* .gemv_lhs_info = */ {
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
@@ -319,6 +344,12 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
},
+ /* .gemm_lhs_info = */ {
+ /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
+ /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
+ /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
+ /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
+ },
/* i8mm GEMV */
/* .kern_info = */ {
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
@@ -333,7 +364,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
},
- /* .lhs_info = */ {
+ /* .gemv_lhs_info = */ {
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
@@ -367,6 +398,12 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
},
+ /* .gemm_lhs_info = */ {
+ /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
+ /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
+ /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
+ /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
+ },
/* DOTPROD GEMV */
/* .kern_info = */ {
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
@@ -381,7 +418,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
},
- /* .lhs_info = */ {
+ /* .gemv_lhs_info = */ {
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
diff --git a/ggml/src/ggml-cpu/kleidiai/kernels.h b/ggml/src/ggml-cpu/kleidiai/kernels.h
index bc8f33405d1fe..2ad6ad6fd0bfc 100644
--- a/ggml/src/ggml-cpu/kleidiai/kernels.h
+++ b/ggml/src/ggml-cpu/kleidiai/kernels.h
@@ -84,8 +84,11 @@ struct rhs_packing_info {
struct ggml_kleidiai_kernels {
kernel_info gemm;
+ lhs_packing_info gemm_lhs_info;
+
kernel_info gemv;
- lhs_packing_info lhs_info;
+ lhs_packing_info gemv_lhs_info;
+
rhs_packing_info rhs_info;
cpu_feature required_cpu;
diff --git a/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp b/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp
index dff8fa244a1c9..7a830448eb3e1 100644
--- a/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp
+++ b/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp
@@ -123,7 +123,9 @@ class tensor_traits : public ggml::cpu::tensor_traits {
}
ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, op);
GGML_ASSERT(kernels);
- kernel_info * kernel = op->src[1]->ne[1] == 1 ? &kernels->gemv : &kernels->gemm;
+ bool is_gemv = op->src[1]->ne[1] == 1;
+ kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
+ lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
size_t k = op->src[0]->ne[0];
size_t n = op->src[0]->ne[1];
@@ -134,9 +136,9 @@ class tensor_traits : public ggml::cpu::tensor_traits {
size_t sr = kernel->get_sr();
if (kernels->rhs_type == GGML_TYPE_Q4_0) {
- size = variant_call(kernels->lhs_info.packed_size, m, k, QK4_0, mr, kr, sr);
+ size = variant_call(lhs_info->packed_size, m, k, QK4_0, mr, kr, sr);
} else if (kernels->rhs_type == GGML_TYPE_F16) {
- size = variant_call(kernels->lhs_info.packed_size, m, k, mr, kr, sr) +
+ size = variant_call(lhs_info->packed_size, m, k, mr, kr, sr) +
variant_call(kernels->rhs_info.packed_size, n, k) +
k * n * sizeof(float) + n * sizeof(float);
} else {
@@ -173,7 +175,9 @@ class tensor_traits : public ggml::cpu::tensor_traits {
ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
GGML_ASSERT(kernels);
- kernel_info * kernel = src1->ne[1] == 1 ? &kernels->gemv : &kernels->gemm;
+ bool is_gemv = src1->ne[1] == 1;
+ kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
+ lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
GGML_ASSERT(kernel);
const int nth = params->nth;
@@ -198,7 +202,7 @@ class tensor_traits : public ggml::cpu::tensor_traits {
        const int64_t kr = static_cast<int64_t>(kernel->get_kr());
        const int64_t sr = static_cast<int64_t>(kernel->get_sr());
- const size_t lhs_packed_size = variant_call(kernels->lhs_info.packed_size, m, k, mr, kr, sr);
+ const size_t lhs_packed_size = variant_call(lhs_info->packed_size, m, k, mr, kr, sr);
const size_t rhs_packed_size = variant_call(kernels->rhs_info.packed_size, n, k);
const size_t kxn_size = k * n * sizeof(float);
const size_t bias_size = n * sizeof(float);
@@ -229,12 +233,12 @@ class tensor_traits : public ggml::cpu::tensor_traits {
const int64_t num_m_per_thread = (ith == num_threads - 1) ? num_m_per_threadN_1 : num_m_per_thread0;
const size_t lhs_offset = variant_call(kernels->gemm.get_lhs_offset, m_start, lhs_stride);
- const size_t lhs_packed_offset = variant_call(kernels->lhs_info.get_packed_offset, m_start, k, mr, kr, sr);
+ const size_t lhs_packed_offset = variant_call(lhs_info->get_packed_offset, m_start, k, mr, kr, sr);
            const void * src_ptr = static_cast<const char *>(lhs_batch) + lhs_offset;
            void * dst_ptr = static_cast<char *>(lhs_packed) + lhs_packed_offset;
- variant_call(kernels->lhs_info.pack_func, num_m_per_thread, k, mr, kr, sr, 0, src_ptr, lhs_stride, dst_ptr);
+ variant_call(lhs_info->pack_func, num_m_per_thread, k, mr, kr, sr, 0, src_ptr, lhs_stride, dst_ptr);
}
}
@@ -306,8 +310,9 @@ class tensor_traits : public ggml::cpu::tensor_traits {
ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
GGML_ASSERT(kernels);
- kernel_info * kernel = src1->ne[1] == 1 ? &kernels->gemv : &kernels->gemm;
- lhs_packing_info * lhs_info = &kernels->lhs_info;
+ bool is_gemv = src1->ne[1] == 1;
+ kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
+ lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
GGML_ASSERT(kernel);
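
Note: the kleidiai changes above split the single lhs_info into gemm_lhs_info/gemv_lhs_info and select one per call. A minimal standalone sketch of that selection follows; the types and helper name are stand-ins for illustration, not the actual kernels.h definitions.

    // sketch only: placeholder types mirroring the shape of the real structs
    #include <stdint.h>

    typedef struct { int id; } lhs_packing_info_sketch;     // stand-in for lhs_packing_info

    typedef struct {
        lhs_packing_info_sketch gemm_lhs_info;  // packing used when src1 has more than one column
        lhs_packing_info_sketch gemv_lhs_info;  // packing used when src1 has exactly one column
    } kernels_sketch;

    // mirrors: bool is_gemv = src1->ne[1] == 1;
    static const lhs_packing_info_sketch * select_lhs_info(const kernels_sketch * k, int64_t src1_ne1) {
        const int is_gemv = (src1_ne1 == 1);
        return is_gemv ? &k->gemv_lhs_info : &k->gemm_lhs_info;
    }

Keeping two descriptors lets the GEMM path use the qsi8d32p4x8sb packing routines added above while the GEMV path keeps the plain qsi8d32p packing.
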
diff --git a/ggml/src/ggml-cpu/llamafile/sgemm.cpp b/ggml/src/ggml-cpu/llamafile/sgemm.cpp
index 2be54c31b5f3e..2c4ad9d58b9f2 100644
--- a/ggml/src/ggml-cpu/llamafile/sgemm.cpp
+++ b/ggml/src/ggml-cpu/llamafile/sgemm.cpp
@@ -2169,94 +2169,117 @@ class tinyBLAS_Q0_PPC {
class tinyBLAS_PPC {
public:
tinyBLAS_PPC(int64_t k,
- const float *A, int64_t lda,
- const float *B, int64_t ldb,
- float *C, int64_t ldc,
+ const float * A, int64_t lda,
+ const float * B, int64_t ldb,
+ float * C, int64_t ldc,
int ith, int nth)
: A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
}
void matmul(int64_t m, int64_t n) {
- mnpack(0, m, 0, n);
+ int64_t mc = 256; int64_t nc = 256; int64_t kc = 256;
+ if (m % mc == 0 && n % nc == 0 && k % kc == 0) {
+ matmul_tiled(m, n, mc, nc, kc);
+ } else {
+ mnpack(0, m, 0, n);
+ }
}
private:
- void (tinyBLAS_PPC::*kernel)(int64_t, int64_t);
+ inline void save_acc(acc_t * ACC, int64_t ii, int64_t jj) {
+ vec_t vec_C[4];
+ __builtin_mma_disassemble_acc(vec_C, ACC);
+ for (int I = 0; I < 4; I++) {
+ for (int J = 0; J < 4; J++) {
+ *((float *)(C+ii+((jj+J)*ldc)+I)) = *((float *)&vec_C[I]+J);
+ }
+ }
+ }
- inline void vector_permute_store_4(vector float *src, float *vecOffset) {
- vector float t1, t2, t3, t4, t5, t6, t7, t8;
- t1 = vec_mergeh(src[0], src[1]);
- t2 = vec_mergeh(src[2], src[3]);
- t3 = vec_mergel(src[0], src[1]);
- t4 = vec_mergel(src[2], src[3]);
+ inline void add_save_acc(acc_t * ACC, int64_t ii, int64_t jj) {
+ vec_t vec_C[4];
+ __builtin_mma_disassemble_acc(vec_C, ACC);
+ for (int I = 0; I < 4; I++) {
+ for (int J = 0; J < 4; J++) {
+ float * c_ptr = (float *)(C+ii+((jj+J)*ldc)+I);
+ *c_ptr += *((float *)&vec_C[I]+J);
+ }
+ }
+ }
- t5 = vec_xxpermdi(t1, t2, 0);
- t6 = vec_xxpermdi(t1, t2, 3);
- t7 = vec_xxpermdi(t3, t4, 0);
- t8 = vec_xxpermdi(t3, t4, 3);
+ inline void vector_permute_store_4(vector float * src, float * vecOffset) {
+ vector float t1, t2, t3, t4, t5, t6, t7, t8;
+ t1 = vec_mergeh(src[0], src[1]);
+ t2 = vec_mergeh(src[2], src[3]);
+ t3 = vec_mergel(src[0], src[1]);
+ t4 = vec_mergel(src[2], src[3]);
- vec_xst(t5, 0, vecOffset);
- vec_xst(t6, 0, vecOffset + 4);
- vec_xst(t7, 0, vecOffset + 8);
- vec_xst(t8, 0, vecOffset + 12);
- }
+ t5 = vec_xxpermdi(t1, t2, 0);
+ t6 = vec_xxpermdi(t1, t2, 3);
+ t7 = vec_xxpermdi(t3, t4, 0);
+ t8 = vec_xxpermdi(t3, t4, 3);
- inline void vector_permute_store_8(vector float *src, float *vecOffset) {
- vector float t1, t2, t3, t4, t5, t6, t7, t8;
- t1 = vec_mergeh(src[0], src[1]);
- t2 = vec_mergeh(src[2], src[3]);
- t3 = vec_mergeh(src[4], src[5]);
- t4 = vec_mergeh(src[6], src[7]);
+ vec_xst(t5, 0, vecOffset);
+ vec_xst(t6, 0, vecOffset + 4);
+ vec_xst(t7, 0, vecOffset + 8);
+ vec_xst(t8, 0, vecOffset + 12);
+ }
- t5 = vec_xxpermdi(t1, t2, 0);
- t6 = vec_xxpermdi(t3, t4, 0);
- t7 = vec_xxpermdi(t1, t2, 3);
- t8 = vec_xxpermdi(t3, t4, 3);
+ inline void vector_permute_store_8(vector float * src, float * vecOffset) {
+ vector float t1, t2, t3, t4, t5, t6, t7, t8;
+ t1 = vec_mergeh(src[0], src[1]);
+ t2 = vec_mergeh(src[2], src[3]);
+ t3 = vec_mergeh(src[4], src[5]);
+ t4 = vec_mergeh(src[6], src[7]);
- vec_xst(t5, 0, vecOffset);
- vec_xst(t6, 0, vecOffset + 4);
- vec_xst(t7, 0, vecOffset + 8);
- vec_xst(t8, 0, vecOffset + 12);
+ t5 = vec_xxpermdi(t1, t2, 0);
+ t6 = vec_xxpermdi(t3, t4, 0);
+ t7 = vec_xxpermdi(t1, t2, 3);
+ t8 = vec_xxpermdi(t3, t4, 3);
- t1 = vec_mergel(src[0], src[1]);
- t2 = vec_mergel(src[2], src[3]);
- t3 = vec_mergel(src[4], src[5]);
- t4 = vec_mergel(src[6], src[7]);
+ vec_xst(t5, 0, vecOffset);
+ vec_xst(t6, 0, vecOffset + 4);
+ vec_xst(t7, 0, vecOffset + 8);
+ vec_xst(t8, 0, vecOffset + 12);
- t5 = vec_xxpermdi(t1, t2, 0);
- t6 = vec_xxpermdi(t3, t4, 0);
- t7 = vec_xxpermdi(t1, t2, 3);
- t8 = vec_xxpermdi(t3, t4, 3);
+ t1 = vec_mergel(src[0], src[1]);
+ t2 = vec_mergel(src[2], src[3]);
+ t3 = vec_mergel(src[4], src[5]);
+ t4 = vec_mergel(src[6], src[7]);
- vec_xst(t5, 0, vecOffset + 16);
- vec_xst(t6, 0, vecOffset + 20);
- vec_xst(t7, 0, vecOffset + 24);
- vec_xst(t8, 0, vecOffset + 28);
+ t5 = vec_xxpermdi(t1, t2, 0);
+ t6 = vec_xxpermdi(t3, t4, 0);
+ t7 = vec_xxpermdi(t1, t2, 3);
+ t8 = vec_xxpermdi(t3, t4, 3);
+
+ vec_xst(t5, 0, vecOffset + 16);
+ vec_xst(t6, 0, vecOffset + 20);
+ vec_xst(t7, 0, vecOffset + 24);
+ vec_xst(t8, 0, vecOffset + 28);
}
- void packTranspose(const float* a, int64_t lda, int rows, int cols, float* vec) {
+ void packTranspose(const float * a, int64_t lda, int rows, int cols, float * vec) {
int64_t i, j;
float * aoffsets[8];
- float *aoffset = NULL, *boffset = NULL;
+ float * aoffset = NULL, * boffset = NULL;
__vector_pair arr[8];
vector float c[8][2] = {0};
vector float c1[8] = {0};
vector float c2[8] = {0};
-        aoffset = const_cast<float*>(a);
+        aoffset = const_cast<float *>(a);
boffset = vec;
j = (rows >> 3);
if (j > 0) {
-
do {
aoffsets[0] = aoffset;
- for (int it = 1; it< 8; it++)
+ for (int it = 1; it < 8; it++)
aoffsets[it] = aoffsets[it-1] + lda;
aoffset += 8 * lda;
i = (cols >> 3);
if (i > 0) {
do {
- for (int it = 0; it< 8; it++) {
+ for (int it = 0; it < 8; it++) {
arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]);
__builtin_vsx_disassemble_pair(c[it], &arr[it]);
c1[it] = c[it][0];
@@ -2264,11 +2287,14 @@ class tinyBLAS_PPC {
}
vector_permute_store_8(c1, boffset);
- vector_permute_store_8(c2, boffset+32);
- for (int it = 0; it < 4; it++)
- aoffsets[it] = aoffsets[it] + 8*lda;
+ vector_permute_store_8(c2, boffset + 32);
boffset += 64;
i--;
+ if (i > 0) {
+ for (int it = 0; it < 8; it++) {
+ aoffsets[it] = aoffsets[it] + 8;
+ }
+ }
} while(i > 0);
}
if (cols & 4) {
@@ -2295,9 +2321,9 @@ class tinyBLAS_PPC {
c2[it] = c[it][1];
}
vector_permute_store_4(c1, boffset);
- vector_permute_store_4(c2, boffset+16);
+ vector_permute_store_4(c2, boffset + 16);
for (int it = 0; it < 4; it++)
- aoffsets[it] += 8*lda;
+ aoffsets[it] += 8 * lda;
boffset += 32;
i--;
} while(i > 0);
@@ -2325,15 +2351,15 @@ class tinyBLAS_PPC {
vec_t vec_A[4], vec_B[4], vec_C[4];
acc_t acc_0;
__builtin_mma_xxsetaccz(&acc_0);
- for (int l = 0; l < k; l+=4) {
- packTranspose(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A);
- packTranspose(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
+ for (int l = 0; l < k; l += 4) {
+ packTranspose(A + (ii * lda) + l, lda, 4, 4, (float *)vec_A);
+ packTranspose(B + (jj * ldb) + l, ldb, 4, 4, (float *)vec_B);
__builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
__builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
__builtin_mma_xvf32gerpp(&acc_0, vec_A[2], vec_B[2]);
__builtin_mma_xvf32gerpp(&acc_0, vec_A[3], vec_B[3]);
}
- SAVE_ACC(&acc_0, ii, jj);
+ save_acc(&acc_0, ii, jj);
}
void KERNEL_4x8(int64_t ii, int64_t jj) {
@@ -2341,9 +2367,9 @@ class tinyBLAS_PPC {
acc_t acc_0, acc_1;
__builtin_mma_xxsetaccz(&acc_0);
__builtin_mma_xxsetaccz(&acc_1);
- for (int64_t l = 0; l < k; l+=4) {
- packTranspose(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A);
- packTranspose(B+(jj*ldb)+l, ldb, 8, 4, (float*)vec_B);
+ for (int64_t l = 0; l < k; l += 4) {
+ packTranspose(A + (ii * lda) + l, lda, 4, 4, (float *)vec_A);
+ packTranspose(B + (jj * ldb) + l, ldb, 8, 4, (float *)vec_B);
__builtin_mma_xvf32gerpp(&acc_0, vec_A[0], (vec_t)vec_B[0]);
__builtin_mma_xvf32gerpp(&acc_1, vec_A[0], (vec_t)vec_B[1]);
__builtin_mma_xvf32gerpp(&acc_0, vec_A[1], (vec_t)vec_B[2]);
@@ -2353,8 +2379,8 @@ class tinyBLAS_PPC {
__builtin_mma_xvf32gerpp(&acc_0, vec_A[3], (vec_t)vec_B[6]);
__builtin_mma_xvf32gerpp(&acc_1, vec_A[3], (vec_t)vec_B[7]);
}
- SAVE_ACC(&acc_0, ii, jj);
- SAVE_ACC(&acc_1, ii, jj+4);
+ save_acc(&acc_0, ii, jj);
+ save_acc(&acc_1, ii, jj + 4);
}
void KERNEL_8x4(int64_t ii, int64_t jj) {
@@ -2362,9 +2388,9 @@ class tinyBLAS_PPC {
acc_t acc_0, acc_1;
__builtin_mma_xxsetaccz(&acc_0);
__builtin_mma_xxsetaccz(&acc_1);
- for (int64_t l = 0; l < k; l+=4) {
- packTranspose(A+(ii*lda)+l, lda, 8, 4, (float*)vec_A);
- packTranspose(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
+ for (int64_t l = 0; l < k; l += 4) {
+ packTranspose(A + (ii * lda) + l, lda, 8, 4, (float *)vec_A);
+ packTranspose(B + (jj * ldb) + l, ldb, 4, 4, (float *)vec_B);
__builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[0], vec_B[0]);
__builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[1], vec_B[0]);
__builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[2], vec_B[1]);
@@ -2374,8 +2400,8 @@ class tinyBLAS_PPC {
__builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[6], vec_B[3]);
__builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[7], vec_B[3]);
}
- SAVE_ACC(&acc_0, ii, jj);
- SAVE_ACC(&acc_1, ii+4, jj);
+ save_acc(&acc_0, ii, jj);
+ save_acc(&acc_1, ii + 4, jj);
}
void KERNEL_8x8(int64_t ii, int64_t jj) {
@@ -2386,19 +2412,96 @@ class tinyBLAS_PPC {
__builtin_mma_xxsetaccz(&acc_2);
__builtin_mma_xxsetaccz(&acc_3);
for (int l = 0; l < k; l+=8) {
- packTranspose(A+(ii*lda)+l, lda, 8, 8, (float*)vec_A);
- packTranspose(B+(jj*ldb)+l, ldb, 8, 8, (float*)vec_B);
+ packTranspose(A + (ii * lda) + l, lda, 8, 8, (float *)vec_A);
+ packTranspose(B + (jj * ldb) + l, ldb, 8, 8, (float *)vec_B);
for(int x = 0; x < 16; x+=2) {
__builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[x], vec_B[x]);
- __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[x], vec_B[x+1]);
- __builtin_mma_xvf32gerpp(&acc_2, (vec_t)vec_A[x+1], vec_B[x]);
- __builtin_mma_xvf32gerpp(&acc_3, (vec_t)vec_A[x+1], vec_B[x+1]);
+ __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[x], vec_B[x + 1]);
+ __builtin_mma_xvf32gerpp(&acc_2, (vec_t)vec_A[x + 1], vec_B[x]);
+ __builtin_mma_xvf32gerpp(&acc_3, (vec_t)vec_A[x + 1], vec_B[x + 1]);
+ }
+ }
+ save_acc(&acc_0, ii, jj);
+ save_acc(&acc_1, ii, jj + 4);
+ save_acc(&acc_2, ii + 4, jj);
+ save_acc(&acc_3, ii + 4, jj + 4);
+ }
+
+ inline void MMA_16x8(vec_t * vec_A0, vec_t * vec_A1, vec_t * vec_B, acc_t * acc) {
+ for (int x = 0; x < 16; x += 2) {
+ __builtin_mma_xvf32gerpp(&acc[0], vec_A0[x + 0], vec_B[x]);
+ __builtin_mma_xvf32gerpp(&acc[1], vec_A0[x + 0], vec_B[x + 1]);
+ __builtin_mma_xvf32gerpp(&acc[2], vec_A0[x + 1], vec_B[x]);
+ __builtin_mma_xvf32gerpp(&acc[3], vec_A0[x + 1], vec_B[x + 1]);
+ __builtin_mma_xvf32gerpp(&acc[4], vec_A1[x + 0], vec_B[x]);
+ __builtin_mma_xvf32gerpp(&acc[5], vec_A1[x + 0], vec_B[x + 1]);
+ __builtin_mma_xvf32gerpp(&acc[6], vec_A1[x + 1], vec_B[x]);
+ __builtin_mma_xvf32gerpp(&acc[7], vec_A1[x + 1], vec_B[x + 1]);
+ }
+ }
+
+ void KERNEL(int64_t ii, int64_t jj, int64_t mc, int64_t nc, int64_t kc, vec_t * vec_A, vec_t * vec_B, int64_t kk) {
+ for (int64_t i = 0; i < mc; i += 16) {
+ int A_base_addr = (mc / 8) * (i / 8) * 16;
+ for (int64_t j = 0; j < nc; j += 8) {
+ int B_base_addr = (nc / 8) * (j / 8) * 16;
+ acc_t acc[8];
+ vec_t A0_block[16]; vec_t A1_block[16];
+ for (int x = 0; x < 8; x++)
+ __builtin_mma_xxsetaccz(&acc[x]);
+ for (int64_t l = 0; l < kc; l += 8) {
+ int A0_block_idx = A_base_addr + (l / 8) * 16;
+ int A1_block_idx = A0_block_idx + (mc / 8) * 16;
+ int B_block_idx = B_base_addr + (l / 8) * 16;
+ vec_t* A0_block = &vec_A[A0_block_idx];
+ vec_t* A1_block = &vec_A[A1_block_idx];
+ vec_t* B_block = &vec_B[B_block_idx];
+ MMA_16x8(A0_block, A1_block, B_block, acc);
+ }
+ if (kk == 0) {
+ save_acc(&acc[0], ii + i, jj + j);
+ save_acc(&acc[1], ii + i, jj + j + 4);
+ save_acc(&acc[2], ii + i + 4, jj + j);
+ save_acc(&acc[3], ii + i + 4, jj + j + 4);
+ save_acc(&acc[4], ii + i + 8, jj + j);
+ save_acc(&acc[5], ii + i + 8, jj + j + 4);
+ save_acc(&acc[6], ii + i + 12, jj + j);
+ save_acc(&acc[7], ii + i + 12, jj + j + 4);
+ } else {
+ add_save_acc(&acc[0], ii + i, jj + j);
+ add_save_acc(&acc[1], ii + i, jj + j + 4);
+ add_save_acc(&acc[2], ii + i + 4, jj + j);
+ add_save_acc(&acc[3], ii + i + 4, jj + j + 4);
+ add_save_acc(&acc[4], ii + i + 8, jj + j);
+ add_save_acc(&acc[5], ii + i + 8, jj + j + 4);
+ add_save_acc(&acc[6], ii + i + 12, jj + j);
+ add_save_acc(&acc[7], ii + i + 12, jj + j + 4);
+ }
+ }
+ }
+ }
+
+ void matmul_tiled(int64_t m , int64_t n, int64_t mc, int64_t nc, int64_t kc) {
+ int64_t ytiles = m / mc;
+ int64_t xtiles = n / nc;
+ int64_t tiles = xtiles * ytiles;
+ int64_t duty = (tiles + nth - 1) / nth;
+ int64_t start = duty * ith;
+ int64_t end = start + duty;
+ if (end > tiles) {
+ end = tiles;
+ }
+ for (int64_t job = start; job < end; ++job) {
+ int64_t ii = (job / xtiles) * mc;
+ int64_t jj = (job % xtiles) * nc;
+ for (int64_t kk = 0; kk < k; kk += kc) {
+ vec_t A_pack[kc * mc / 4];
+ vec_t B_pack[kc * nc / 4];
+ packTranspose(A + (ii * lda) + kk, lda, kc, mc, (float *)A_pack);
+ packTranspose(B + (jj * ldb) + kk, ldb, kc, nc, (float *)B_pack);
+ KERNEL(ii, jj, mc, nc, kc, A_pack, B_pack, kk);
}
}
- SAVE_ACC(&acc_0, ii, jj);
- SAVE_ACC(&acc_1, ii, jj+4);
- SAVE_ACC(&acc_2, ii+4, jj);
- SAVE_ACC(&acc_3, ii+4, jj+4);
}
void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
@@ -2406,35 +2509,35 @@ class tinyBLAS_PPC {
int n_rem = MIN(n - n0, 8);
int mc = 0, nc = 0;
if (m_rem >= 8 && n_rem >= 8) {
- mc = 8;
- nc = 8;
- gemm<8, 8>(m0, m, n0, n);
+ mc = 8;
+ nc = 8;
+ gemm<8, 8>(m0, m, n0, n);
} else if (m_rem >= 4 && n_rem >= 8) {
- mc = 4;
- nc = 8;
- gemm<4, 8>(m0, m, n0, n);
+ mc = 4;
+ nc = 8;
+ gemm<4, 8>(m0, m, n0, n);
} else if (m_rem >= 8 && n_rem >= 4) {
- mc = 8;
- nc = 4;
- gemm<8, 4>(m0, m, n0, n);
+ mc = 8;
+ nc = 4;
+ gemm<8, 4>(m0, m, n0, n);
} else if (m_rem >= 4 && n_rem >= 4) {
- mc = 4;
- nc = 4;
- gemm<4, 4>(m0, m, n0, n);
+ mc = 4;
+ nc = 4;
+ gemm<4, 4>(m0, m, n0, n);
} else {
mc = (m_rem >= 4) ? 4 : m_rem;
nc = (n_rem >= 4) ? 4 : n_rem;
if (mc == 0 || nc == 0)
- return;
+ return;
gemm_small(m0, m, n0, n, mc, nc);
}
int64_t mp = m0 + ((m - m0) / mc) * mc;
int64_t np = n0 + ((n - n0) / nc) * nc;
mnpack(mp, m, n0, np);
mnpack(m0, m, np, n);
- }
+ }
- void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
+ void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
int64_t ytiles = (m - m0) / RM;
int64_t xtiles = (n - n0) / RN;
int64_t tiles = xtiles * ytiles;
@@ -2449,30 +2552,30 @@ class tinyBLAS_PPC {
vec_t vec_C[4];
acc_t acc_0;
__builtin_mma_xxsetaccz(&acc_0);
-        vec_t vec_A[4] {0}, vec_B[4] = {0};
-        for (int l=0; l<k; l+=4) {
+        vec_t vec_A[4] = {0}, vec_B[4] = {0};
+        for (int l = 0; l < k; l += 4) {
             if (RM == 1) {
-                float* a = const_cast<float*>(A+(ii)*lda+l);
-                packTranspose(B+(jj*ldb)+l, ldb, RN, 4, (float*)vec_B);
+                float * a = const_cast<float *>(A + (ii) * lda + l);
+                packTranspose(B + (jj * ldb) + l, ldb, RN, 4, (float *)vec_B);
vec_A[0] = (vec_t)vec_xl(0,a);
- vec_A[1] = (vec_t)vec_splats(*((float*)&vec_A+1));
- vec_A[2] = (vec_t)vec_splats(*((float*)&vec_A+2));
- vec_A[3] = (vec_t)vec_splats(*((float*)&vec_A+3));
+ vec_A[1] = (vec_t)vec_splats(*((float *)&vec_A+1));
+ vec_A[2] = (vec_t)vec_splats(*((float *)&vec_A+2));
+ vec_A[3] = (vec_t)vec_splats(*((float *)&vec_A+3));
} else if (RN == 1) {
- packTranspose(A+(ii*lda)+l, lda, RM, 4, (float*)vec_A);
-                float* b = const_cast<float*>(B+(jj)*ldb+l);
+ packTranspose(A + (ii * lda) + l, lda, RM, 4, (float *)vec_A);
+                float * b = const_cast<float *>(B + (jj) * ldb + l);
vec_B[0] = (vec_t)vec_xl(0,b);
- vec_B[1] = (vec_t)vec_splats(*((float*)&vec_B+1));
- vec_B[2] = (vec_t)vec_splats(*((float*)&vec_B+2));
- vec_B[3] = (vec_t)vec_splats(*((float*)&vec_B+3));
+ vec_B[1] = (vec_t)vec_splats(*((float *)&vec_B+1));
+ vec_B[2] = (vec_t)vec_splats(*((float *)&vec_B+2));
+ vec_B[3] = (vec_t)vec_splats(*((float *)&vec_B+3));
} else {
- packTranspose(A+(ii*lda)+l, lda, RM, 4, (float*)vec_A);
- packTranspose(B+(jj*ldb)+l, ldb, RN, 4, (float*)vec_B);
+ packTranspose(A + (ii * lda) + l, lda, RM, 4, (float *)vec_A);
+ packTranspose(B + (jj * ldb) + l, ldb, RN, 4, (float *)vec_B);
}
__builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
__builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
@@ -2482,12 +2585,27 @@ class tinyBLAS_PPC {
__builtin_mma_disassemble_acc(vec_C, &acc_0);
for (int I = 0; I < RM; I++) {
for (int J = 0; J < RN; J++) {
- *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&vec_C[I]+J);
+ *((float *)(C+ii+((jj+J)*ldc)+I)) = *((float *)&vec_C[I]+J);
}
}
}
}
+    template <int RM, int RN>
+ inline void kernel(int64_t ii, int64_t jj) {
+ if constexpr(RM == 4 && RN == 4) {
+ KERNEL_4x4(ii, jj);
+ } else if constexpr(RM == 4 && RN == 8) {
+ KERNEL_4x8(ii, jj);
+ } else if constexpr(RM == 8 && RN == 4) {
+ KERNEL_8x4(ii, jj);
+ } else if constexpr(RM == 8 && RN == 8) {
+ KERNEL_8x8(ii, jj);
+ } else {
+ static_assert(false, "RN/RM values not supported");
+ }
+ }
+
    template <int RM, int RN>
NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
int64_t ytiles = (m - m0) / RM;
@@ -2496,27 +2614,18 @@ class tinyBLAS_PPC {
int64_t duty = (tiles + nth - 1) / nth;
int64_t start = duty * ith;
int64_t end = start + duty;
- if (RM == 4 && RN == 4) {
- kernel = &tinyBLAS_PPC::KERNEL_4x4;
- } else if (RM == 4 && RN == 8) {
- kernel = &tinyBLAS_PPC::KERNEL_4x8;
- } else if (RM == 8 && RN == 4) {
- kernel = &tinyBLAS_PPC::KERNEL_8x4;
- } else if (RM == 8 && RN == 8) {
- kernel = &tinyBLAS_PPC::KERNEL_8x8;
- }
if (end > tiles)
end = tiles;
for (int64_t job = start; job < end; ++job) {
int64_t ii = m0 + job / xtiles * RM;
int64_t jj = n0 + job % xtiles * RN;
- (this->*kernel)(ii, jj);
+ kernel(ii, jj);
}
}
- const float *const A;
- const float *const B;
- float *C;
+ const float * const A;
+ const float * const B;
+ float * C;
const int64_t k;
const int64_t lda;
const int64_t ldb;
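
Note: matmul_tiled above only runs when m, n and k are exact multiples of the 256-element tile sizes; otherwise mnpack handles the general case. A small hedged sketch of the job-to-tile arithmetic it uses, with an illustrative helper name that is not part of the patch:

    #include <stdint.h>
    #include <stdio.h>

    // Show which output tiles a given thread (ith of nth) would compute,
    // using the same duty/start/end arithmetic as matmul_tiled.
    static void show_tile_assignment(int64_t m, int64_t n, int64_t mc, int64_t nc, int ith, int nth) {
        const int64_t ytiles = m / mc;                  // tile rows (m assumed divisible by mc)
        const int64_t xtiles = n / nc;                  // tile columns (n assumed divisible by nc)
        const int64_t tiles  = xtiles * ytiles;
        const int64_t duty   = (tiles + nth - 1) / nth; // ceil(tiles / nth) jobs per thread
        const int64_t start  = duty * ith;
        const int64_t end    = start + duty > tiles ? tiles : start + duty;
        for (int64_t job = start; job < end; ++job) {
            const int64_t ii = (job / xtiles) * mc;     // row offset of the output tile
            const int64_t jj = (job % xtiles) * nc;     // column offset of the output tile
            printf("thread %d computes C[%lld.., %lld..]\n", ith, (long long) ii, (long long) jj);
        }
    }

    // e.g. show_tile_assignment(1024, 1024, 256, 256, /*ith=*/0, /*nth=*/4);
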
diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
index 460367cca09e9..8c1f7948855ac 100644
--- a/ggml/src/ggml-cpu/ops.cpp
+++ b/ggml/src/ggml-cpu/ops.cpp
@@ -9003,8 +9003,7 @@ static void ggml_compute_forward_ssm_scan_f32(
GGML_ASSERT(src4->nb[0] == sizeof(float));
GGML_ASSERT(src5->nb[0] == sizeof(float));
GGML_ASSERT(src6->nb[0] == sizeof(int32_t));
- // allows optimizing the modulo since n_group should be a power of 2
- GGML_ASSERT((ng & -ng) == ng);
+ GGML_ASSERT(nh % ng == 0);
// heads per thread
const int dh = (nh + nth - 1)/nth;
@@ -9035,6 +9034,7 @@ static void ggml_compute_forward_ssm_scan_f32(
// ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
const float dA = expf(dt_soft_plus * A[h]);
+ const int g = h / (nh / ng); // repeat_interleave
// dim
for (int i1 = 0; i1 < nr; ++i1) {
@@ -9057,8 +9057,8 @@ static void ggml_compute_forward_ssm_scan_f32(
// TODO: maybe unroll more?
for (int j = 0; j < 1; j++) {
GGML_F32_VEC t0 = GGML_F32_VEC_LOAD(s0 + i + j*ggml_f32_epr + ii*nc);
- GGML_F32_VEC t1 = GGML_F32_VEC_LOAD(B + i + j*ggml_f32_epr + (h & (ng - 1))*nc);
- GGML_F32_VEC t2 = GGML_F32_VEC_LOAD(C + i + j*ggml_f32_epr + (h & (ng - 1))*nc);
+ GGML_F32_VEC t1 = GGML_F32_VEC_LOAD(B + i + j*ggml_f32_epr + g*nc);
+ GGML_F32_VEC t2 = GGML_F32_VEC_LOAD(C + i + j*ggml_f32_epr + g*nc);
t0 = GGML_F32_VEC_MUL(t0, adA);
t1 = GGML_F32_VEC_MUL(t1, axdt);
@@ -9072,6 +9072,9 @@ static void ggml_compute_forward_ssm_scan_f32(
}
sumf = GGML_F32xt_REDUCE_ONE(sum);
+ #elif defined(__riscv_v_intrinsic)
+ // todo: RVV implementation
+ const int np = 0;
#else
const int np = (nc & ~(GGML_F32_STEP - 1));
@@ -9087,8 +9090,8 @@ static void ggml_compute_forward_ssm_scan_f32(
for (int i = 0; i < np; i += GGML_F32_STEP) {
for (int j = 0; j < GGML_F32_ARR; j++) {
ax[j] = GGML_F32_VEC_LOAD(s0 + i + j*GGML_F32_EPR + ii*nc);
- ay[j] = GGML_F32_VEC_LOAD(B + i + j*GGML_F32_EPR + (h & (ng - 1))*nc);
- az[j] = GGML_F32_VEC_LOAD(C + i + j*GGML_F32_EPR + (h & (ng - 1))*nc);
+ ay[j] = GGML_F32_VEC_LOAD(B + i + j*GGML_F32_EPR + g*nc);
+ az[j] = GGML_F32_VEC_LOAD(C + i + j*GGML_F32_EPR + g*nc);
ax[j] = GGML_F32_VEC_MUL(ax[j], adA);
ay[j] = GGML_F32_VEC_MUL(ay[j], axdt);
@@ -9110,7 +9113,7 @@ static void ggml_compute_forward_ssm_scan_f32(
// d_state
for (int i0 = np; i0 < nc; ++i0) {
const int i = i0 + ii*nc;
- const int ig = i0 + (h & (ng - 1))*nc;
+ const int ig = i0 + g*nc;
// state = prev_state * dA + dB * x
const float state = (s0[i] * dA) + (B[ig] * x_dt);
// y = rowwise_dotprod(state, C)
@@ -9127,6 +9130,7 @@ static void ggml_compute_forward_ssm_scan_f32(
for (int h = ih0; h < ih1; ++h) {
// ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
+ const int g = h / (nh / ng); // repeat_interleave
// dim
for (int i1 = 0; i1 < nr; ++i1) {
@@ -9141,8 +9145,8 @@ static void ggml_compute_forward_ssm_scan_f32(
// TODO: what happens when (d_state % svcntw()) != 0?
for (int64_t k = 0; k < nc; k += svcntw()) {
svfloat32_t vA = GGML_F32_VEC_LOAD(&A[h*nc + k]);
- svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k + (h & (ng - 1))*nc]);
- svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k + (h & (ng - 1))*nc]);
+ svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k + g*nc]);
+ svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k + g*nc]);
svfloat32_t vs0 = GGML_F32_VEC_LOAD(&s0[ii*nc + k]);
svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA);
@@ -9162,7 +9166,7 @@ static void ggml_compute_forward_ssm_scan_f32(
// d_state
for (int i0 = 0; i0 < nc; ++i0) {
const int i = i0 + ii*nc;
- const int ig = i0 + (h & (ng - 1))*nc;
+ const int ig = i0 + g*nc;
// state = prev_state * dA + dB * x
const float state = (s0[i] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt);
// y = rowwise_dotprod(state, C)
@@ -10023,8 +10027,8 @@ static void ggml_compute_forward_rwkv_wkv7_f32(
int64_t h_stride_2d = head_size * head_size;
#if defined(GGML_SIMD)
- #if defined(__ARM_FEATURE_SVE)
- // scalar Route to scalar implementation //TODO: Write SVE code
+ #if defined(__ARM_FEATURE_SVE) || defined(__riscv_v_intrinsic)
+    // route to the scalar implementation // TODO: write SVE and RVV code
for (int64_t t = 0; t < T; t++) {
int64_t t_offset = t * t_stride;
int64_t state_offset = head_size * C * (t / (T / n_seqs));
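
Note: the ssm_scan change above drops the power-of-two requirement on n_group and indexes B/C with g = h / (nh / ng), i.e. heads are mapped to groups repeat_interleave-style. A minimal sketch of that mapping (illustrative helper name, not in the patch):

    #include <assert.h>

    // Head h belongs to group h / (nh / ng): the first nh/ng heads share group 0,
    // the next nh/ng heads share group 1, and so on.
    static int head_to_group(int h, int nh, int ng) {
        assert(nh % ng == 0);    // mirrors the new GGML_ASSERT(nh % ng == 0)
        return h / (nh / ng);    // replaces the old (h & (ng - 1)), which required ng to be a power of two
    }
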
diff --git a/ggml/src/ggml-cpu/simd-mappings.h b/ggml/src/ggml-cpu/simd-mappings.h
index b4ad68c9fd647..8bd56bdac1b43 100644
--- a/ggml/src/ggml-cpu/simd-mappings.h
+++ b/ggml/src/ggml-cpu/simd-mappings.h
@@ -18,6 +18,10 @@
#include
#endif
+#if defined(__riscv_v_intrinsic)
+#include <riscv_vector.h>
+#endif
+
#ifdef __cplusplus
extern "C" {
#endif
@@ -94,24 +98,15 @@ extern "C" {
}
#elif defined(__riscv) && defined(__riscv_zfhmin)
static inline float riscv_compute_fp16_to_fp32(ggml_fp16_t h) {
- float f;
- __asm__(
- "fmv.h.x %[f], %[h]\n\t"
- "fcvt.s.h %[f], %[f]"
- : [f] "=&f" (f)
- : [h] "r" (h)
- );
- return f;
+ _Float16 hf;
+ memcpy(&hf, &h, sizeof(ggml_fp16_t));
+ return hf;
}
static inline ggml_fp16_t riscv_compute_fp32_to_fp16(float f) {
ggml_fp16_t res;
- __asm__(
- "fcvt.h.s %[f], %[f]\n\t"
- "fmv.x.h %[h], %[f]"
- : [h] "=&r" (res)
- : [f] "f" (f)
- );
+ _Float16 hf = (_Float16)f;
+ memcpy(&res, &hf, sizeof(ggml_fp16_t));
return res;
}
@@ -220,6 +215,47 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
#define GGML_F32_VEC_MUL GGML_F32xt_MUL
#define GGML_F32_VEC_REDUCE GGML_F32xt_REDUCE
+// F16 SVE
+#define DEFAULT_PG32 svptrue_b32()
+#define DEFAULT_PG16 svptrue_b16()
+
+#define GGML_F32Cxt svfloat16_t
+#define GGML_F32Cxt_ZERO svdup_n_f16(0.0f)
+#define GGML_F32Cxt_SET1(x) svdup_n_f16(x)
+#define GGML_F32Cxt_LOAD(p) svld1_f16(DEFAULT_PG16, (const __fp16 *)(p))
+#define GGML_F32Cxt_STORE(dst_ptr, src_vec) svst1_f16(DEFAULT_PG16, (__fp16 *)(dst_ptr), (src_vec))
+
+#define GGML_F32Cxt_FMA_IMPL(pg, a, b, c) svmad_f16_x(pg, b, c, a)
+#define GGML_F32Cxt_FMA(...) GGML_F32Cxt_FMA_IMPL(DEFAULT_PG16, __VA_ARGS__)
+#define GGML_F32Cxt_ADD_IMPL(pg, a, b) svadd_f16_x(pg, a, b)
+#define GGML_F32Cxt_ADD(...) GGML_F32Cxt_ADD_IMPL(DEFAULT_PG16, __VA_ARGS__)
+#define GGML_F32Cxt_MUL_IMPL(pg, a, b) svmul_f16_x(pg, a, b)
+#define GGML_F32Cxt_MUL(...) GGML_F32Cxt_MUL_IMPL(DEFAULT_PG16, __VA_ARGS__)
+#define GGML_F32Cxt_REDUCE GGML_F16xt_REDUCE_MIXED
+
+#define GGML_F16x_VEC GGML_F32Cxt
+#define GGML_F16x_VEC_ZERO GGML_F32Cxt_ZERO
+#define GGML_F16x_VEC_SET1 GGML_F32Cxt_SET1
+#define GGML_F16x_VEC_LOAD(p, i) GGML_F32Cxt_LOAD(p)
+#define GGML_F16x_VEC_STORE(p, r, i) GGML_F32Cxt_STORE((__fp16 *)(p), r)
+#define GGML_F16x_VEC_FMA GGML_F32Cxt_FMA
+#define GGML_F16x_VEC_ADD GGML_F32Cxt_ADD
+#define GGML_F16x_VEC_MUL GGML_F32Cxt_MUL
+#define GGML_F16x_VEC_REDUCE GGML_F32Cxt_REDUCE
+
+#define GGML_F16xt_REDUCE_ONE_IMPL(pg, a) svaddv_f16(pg, a)
+#define GGML_F16xt_REDUCE_ONE(...) GGML_F16xt_REDUCE_ONE_IMPL(DEFAULT_PG16, __VA_ARGS__)
+
+#define GGML_F16xt_REDUCE_MIXED_IMPL(pg16, res, sum1, sum2, sum3, sum4) \
+{ \
+ sum1 = svadd_f16_x(pg16, sum1, sum2); \
+ sum3 = svadd_f16_x(pg16, sum3, sum4); \
+ sum1 = svadd_f16_x(pg16, sum1, sum3); \
+ __fp16 sum_f16 = svaddv_f16(pg16, sum1); \
+ (res) = (ggml_float) sum_f16; \
+}
+#define GGML_F16xt_REDUCE_MIXED(...) GGML_F16xt_REDUCE_MIXED_IMPL(DEFAULT_PG16, __VA_ARGS__)
+
// F16 NEON
#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
@@ -1170,6 +1206,36 @@ static inline void __lzs_f16cx4_store(ggml_fp16_t * x, float32x4_t v_y) {
#define GGML_F16_VEC_MUL GGML_F32x4_MUL
#define GGML_F16_VEC_REDUCE GGML_F32x4_REDUCE
+#elif defined(__riscv_v_intrinsic)
+
+// compatible with vlen >= 128
+
+#define GGML_SIMD
+
+// F32
+
+#define GGML_F32_STEP 16
+#define GGML_F32_EPR 4
+
+#define GGML_F32x4 vfloat32m1_t
+#define GGML_F32x4_ZERO __riscv_vfmv_v_f_f32m1(0.0f, GGML_F32_EPR)
+#define GGML_F32x4_SET1(x) __riscv_vfmv_v_f_f32m1(x, GGML_F32_EPR)
+#define GGML_F32x4_LOAD(x) __riscv_vle32_v_f32m1(x, GGML_F32_EPR)
+#define GGML_F32x4_STORE(b, v) __riscv_vse32_v_f32m1(b, v, GGML_F32_EPR)
+#define GGML_F32x4_FMA(a, b, c) __riscv_vfmacc_vv_f32m1(a, b, c, GGML_F32_EPR)
+#define GGML_F32x4_ADD(a, b) __riscv_vfadd_vv_f32m1(a, b, GGML_F32_EPR)
+#define GGML_F32x4_MUL(a, b) __riscv_vfmul_vv_f32m1(a, b, GGML_F32_EPR)
+
+#define GGML_F32_VEC GGML_F32x4
+#define GGML_F32_VEC_ZERO GGML_F32x4_ZERO
+#define GGML_F32_VEC_SET1 GGML_F32x4_SET1
+#define GGML_F32_VEC_LOAD GGML_F32x4_LOAD
+#define GGML_F32_VEC_STORE GGML_F32x4_STORE
+#define GGML_F32_VEC_FMA GGML_F32x4_FMA
+#define GGML_F32_VEC_ADD GGML_F32x4_ADD
+#define GGML_F32_VEC_MUL GGML_F32x4_MUL
+#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
+
#endif
// GGML_F32_ARR / GGML_F16_ARR
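
Note: the RISC-V Zfhmin change earlier in simd-mappings.h replaces the inline-assembly FP16 conversions with _Float16 round-trips through memcpy, letting the compiler emit fcvt.s.h/fcvt.h.s. A self-contained sketch of the same pattern, assuming a toolchain with _Float16 support; the names here are illustrative, not the ggml ones:

    #include <stdint.h>
    #include <string.h>

    typedef uint16_t fp16_bits;                    // raw half-precision bit pattern (like ggml_fp16_t)

    static inline float fp16_bits_to_f32(fp16_bits h) {
        _Float16 hf;
        memcpy(&hf, &h, sizeof(hf));               // reinterpret the bits as a half-precision value
        return (float) hf;                         // widen to float (fcvt.s.h)
    }

    static inline fp16_bits f32_to_fp16_bits(float f) {
        const _Float16 hf = (_Float16) f;          // narrow to half (fcvt.h.s)
        fp16_bits res;
        memcpy(&res, &hf, sizeof(res));
        return res;
    }
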
diff --git a/ggml/src/ggml-cpu/vec.cpp b/ggml/src/ggml-cpu/vec.cpp
index 07b377bdd82a7..437192d525a34 100644
--- a/ggml/src/ggml-cpu/vec.cpp
+++ b/ggml/src/ggml-cpu/vec.cpp
@@ -84,6 +84,22 @@ void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * G
}
// reduce sum1,sum2 to sum1
GGML_F32_VEC_REDUCE(sumf, sum1, sum2, sum3, sum4, sum5, sum6, sum7, sum8);
+ #elif defined(__riscv_v_intrinsic)
+ int vl = __riscv_vsetvlmax_e32m8();
+ vfloat32m1_t vs = __riscv_vfmv_v_f_f32m1(0.0f, 1);
+ vfloat32m8_t vsum;
+ vfloat32m8_t ax;
+ vfloat32m8_t ay;
+ vsum = __riscv_vfmv_v_f_f32m8_tu(vsum, 0.0f, vl);
+ for (int i = 0; i < n; i += vl) {
+ vl = __riscv_vsetvl_e32m8(n - i);
+ ax = __riscv_vle32_v_f32m8_tu(ax, &x[i], vl);
+ ay = __riscv_vle32_v_f32m8_tu(ay, &y[i], vl);
+ vsum = __riscv_vfmacc_vv_f32m8_tu(vsum, ax, ay, vl);
+ }
+ vl = __riscv_vsetvlmax_e32m8();
+ vs = __riscv_vfredusum_vs_f32m8_f32m1(vsum, vs, vl);
+ sumf += __riscv_vfmv_f_s_f32m1_f32(vs);
#else
const int np = (n & ~(GGML_F32_STEP - 1));
@@ -197,38 +213,125 @@ void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * G
ggml_float sumf = 0.0;
-#if defined(GGML_SIMD)
- const int np = (n & ~(GGML_F16_STEP - 1));
- GGML_F16_VEC sum[GGML_F16_ARR] = { GGML_F16_VEC_ZERO };
+#if defined(GGML_SIMD)
+ #if defined(__ARM_FEATURE_SVE)
+        const int sve_register_length = svcntb() * 8;       // SVE vector length in bits
+        const int ggml_f16_epr = sve_register_length / 16;  // fp16 elements per SVE register
+        const int ggml_f16_step = 8 * ggml_f16_epr;          // unroll over 8 SVE registers
+
+        const int np = (n & ~(ggml_f16_step - 1));
+ svfloat16_t sum1 = svdup_n_f16(0.0f);
+ svfloat16_t sum2 = svdup_n_f16(0.0f);
+ svfloat16_t sum3 = svdup_n_f16(0.0f);
+ svfloat16_t sum4 = svdup_n_f16(0.0f);
+
+ svfloat16_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8;
+ svfloat16_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8;
+ for (int i = 0; i < np; i += ggml_f16_step) {
+ ax1 = GGML_F16x_VEC_LOAD(x + i + 0 * ggml_f16_epr, 0);
+ ay1 = GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0);
+ sum1 = GGML_F16x_VEC_FMA(sum1, ax1, ay1);
+
+ ax2 = GGML_F16x_VEC_LOAD(x + i + 1 * ggml_f16_epr, 1);
+ ay2 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1);
+ sum2 = GGML_F16x_VEC_FMA(sum2, ax2, ay2);
+
+ ax3 = GGML_F16x_VEC_LOAD(x + i + 2 * ggml_f16_epr, 2);
+ ay3 = GGML_F16x_VEC_LOAD(y + i + 2 * ggml_f16_epr, 2);
+ sum3 = GGML_F16x_VEC_FMA(sum3, ax3, ay3);
+
+ ax4 = GGML_F16x_VEC_LOAD(x + i + 3 * ggml_f16_epr, 3);
+ ay4 = GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3);
+ sum4 = GGML_F16x_VEC_FMA(sum4, ax4, ay4);
+
+ ax5 = GGML_F16x_VEC_LOAD(x + i + 4 * ggml_f16_epr, 4);
+ ay5 = GGML_F16x_VEC_LOAD(y + i + 4 * ggml_f16_epr, 4);
+ sum1 = GGML_F16x_VEC_FMA(sum1, ax5, ay5);
+
+ ax6 = GGML_F16x_VEC_LOAD(x + i + 5 * ggml_f16_epr, 5);
+ ay6 = GGML_F16x_VEC_LOAD(y + i + 5 * ggml_f16_epr, 5);
+ sum2 = GGML_F16x_VEC_FMA(sum2, ax6, ay6);
+
+ ax7 = GGML_F16x_VEC_LOAD(x + i + 6 * ggml_f16_epr, 6);
+ ay7 = GGML_F16x_VEC_LOAD(y + i + 6 * ggml_f16_epr, 6);
+ sum3 = GGML_F16x_VEC_FMA(sum3, ax7, ay7);
+
+ ax8 = GGML_F16x_VEC_LOAD(x + i + 7 * ggml_f16_epr, 7);
+ ay8 = GGML_F16x_VEC_LOAD(y + i + 7 * ggml_f16_epr, 7);
+ sum4 = GGML_F16x_VEC_FMA(sum4, ax8, ay8);
+ }
- GGML_F16_VEC ax[GGML_F16_ARR];
- GGML_F16_VEC ay[GGML_F16_ARR];
+    const int np2 = (n & ~(ggml_f16_epr - 1)); // round down to a multiple of ggml_f16_epr
+ for (int k = np; k < np2; k += ggml_f16_epr) {
+ svfloat16_t rx = GGML_F16x_VEC_LOAD(x + k, 0);
+ svfloat16_t ry = GGML_F16x_VEC_LOAD(y + k, 0);
+ sum1 = GGML_F16x_VEC_FMA(sum1, rx, ry);
+ }
- for (int i = 0; i < np; i += GGML_F16_STEP) {
- for (int j = 0; j < GGML_F16_ARR; j++) {
- ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
- ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
+ if (np2 < n) {
+ svbool_t pg = svwhilelt_b16(np2, n);
+ svfloat16_t hx = svld1_f16(pg, (const __fp16 *)(x + np2));
+ svfloat16_t hy = svld1_f16(pg, (const __fp16 *)(y + np2));
- sum[j] = GGML_F16_VEC_FMA(sum[j], ax[j], ay[j]);
+ sum1 = svmad_f16_x(pg, hx, hy, sum1);
}
- }
+ GGML_F16x_VEC_REDUCE(sumf, sum1, sum2, sum3, sum4);
+ #elif defined(__riscv_v_intrinsic)
+ #if defined(__riscv_zvfh)
+ int vl = __riscv_vsetvlmax_e32m2();
+ vfloat32m1_t vs = __riscv_vfmv_v_f_f32m1(0.0f, 1);
+ vfloat32m2_t vsum;
+ vfloat16m1_t ax;
+ vfloat16m1_t ay;
+ vsum = __riscv_vreinterpret_v_u32m2_f32m2(__riscv_vmv_v_x_u32m2(0, vl));
+ for (int i = 0; i < n; i += vl) {
+ vl = __riscv_vsetvl_e16m1(n - i);
+ ax = __riscv_vle16_v_f16m1_tu(ax, (const _Float16 *)&x[i], vl);
+ ay = __riscv_vle16_v_f16m1_tu(ay, (const _Float16 *)&y[i], vl);
+ vsum = __riscv_vfwmacc_vv_f32m2_tu(vsum, ax, ay, vl);
+ }
+ vl = __riscv_vsetvlmax_e32m1();
+ vfloat32m1_t ac0 = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m2_f32m1(vsum, 0), __riscv_vget_v_f32m2_f32m1(vsum, 1), vl);
+ vs = __riscv_vfredusum_vs_f32m1_f32m1(ac0, vs, vl);
+ sumf += __riscv_vfmv_f_s_f32m1_f32(vs);
+ #else
+ for (int i = 0; i < n; ++i) {
+ sumf += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[i])*GGML_CPU_FP16_TO_FP32(y[i]));
+ }
+ #endif // __riscv_zvfh
+ #else
+ const int np = (n & ~(GGML_F16_STEP - 1));
- // reduce sum0..sum3 to sum0
- GGML_F16_VEC_REDUCE(sumf, sum);
+ GGML_F16_VEC sum[GGML_F16_ARR] = { GGML_F16_VEC_ZERO };
- // leftovers
- for (int i = np; i < n; ++i) {
- sumf += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[i])*GGML_CPU_FP16_TO_FP32(y[i]));
- }
+ GGML_F16_VEC ax[GGML_F16_ARR];
+ GGML_F16_VEC ay[GGML_F16_ARR];
+
+ for (int i = 0; i < np; i += GGML_F16_STEP) {
+ for (int j = 0; j < GGML_F16_ARR; j++) {
+ ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
+ ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
- // if you hit this, you are likely running outside the FP range
- assert(!isnan(sumf) && !isinf(sumf));
+ sum[j] = GGML_F16_VEC_FMA(sum[j], ax[j], ay[j]);
+ }
+ }
+
+ // reduce sum0..sum3 to sum0
+ GGML_F16_VEC_REDUCE(sumf, sum);
+
+ // leftovers
+ for (int i = np; i < n; ++i) {
+ sumf += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[i])*GGML_CPU_FP16_TO_FP32(y[i]));
+ }
+ // if you hit this, you are likely running outside the FP range
+ assert(!isnan(sumf) && !isinf(sumf));
+ #endif
#else
for (int i = 0; i < n; ++i) {
sumf += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[i])*GGML_CPU_FP16_TO_FP32(y[i]));
}
-#endif
+#endif // GGML_SIMD
*s = sumf;
}
@@ -247,6 +350,12 @@ void ggml_vec_silu_f32(const int n, float * y, const float * x) {
for (; i + 3 < n; i += 4) {
_mm_storeu_ps(y + i, ggml_v_silu(_mm_loadu_ps(x + i)));
}
+#elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
+ const int vlen = svcntw();
+ for (; i < n; i += vlen) {
+ const svbool_t pg = svwhilelt_b32_s32(i, n);
+ svst1_f32(pg, y + i, ggml_v_silu(pg, svld1_f32(pg, x + i)));
+ }
#elif defined(__ARM_NEON) && defined(__aarch64__)
for (; i + 3 < n; i += 4) {
vst1q_f32(y + i, ggml_v_silu(vld1q_f32(x + i)));
@@ -271,10 +380,24 @@ void ggml_vec_swiglu_f32(const int n, float * y, const float * x, const float *
for (; i + 3 < n; i += 4) {
_mm_storeu_ps(y + i, _mm_mul_ps(ggml_v_silu(_mm_loadu_ps(x + i)), _mm_loadu_ps(g + i)));
}
+#elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
+ const int vlen = svcntw();
+ for (; i < n; i += vlen) {
+ const svbool_t pg = svwhilelt_b32_s32(i, n);
+ svst1_f32(pg, y + i, svmul_f32_x(pg, ggml_v_silu(pg, svld1_f32(pg, x + i)), svld1_f32(pg, g + i)));
+ }
#elif defined(__ARM_NEON) && defined(__aarch64__)
for (; i + 3 < n; i += 4) {
vst1q_f32(y + i, vmulq_f32(ggml_v_silu(vld1q_f32(x + i)), vld1q_f32(g + i)));
}
+#elif defined(__riscv_v_intrinsic)
+ for (int vl; i < n; i += vl) {
+ vl = __riscv_vsetvl_e32m2(n - i);
+ vfloat32m2_t vx = __riscv_vle32_v_f32m2(&x[i], vl);
+ vfloat32m2_t vg = __riscv_vle32_v_f32m2(&g[i], vl);
+ vfloat32m2_t vy = __riscv_vfmul_vv_f32m2(ggml_v_silu_m2(vx, vl), vg, vl);
+ __riscv_vse32_v_f32m2(&y[i], vy, vl);
+ }
#endif
for (; i < n; ++i) {
y[i] = ggml_silu_f32(x[i]) * g[i];
@@ -318,6 +441,15 @@ ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float
#endif
sum += (ggml_float)_mm_cvtss_f32(val);
}
+#elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
+ const int vlen = svcntw();
+ for (; i < n; i += vlen) {
+ const svbool_t pg = svwhilelt_b32_s32(i, n);
+ svfloat32_t val = ggml_v_expf(pg, svsub_f32_x(pg, svld1_f32(pg, x + i),
+ svdup_n_f32_x(pg, max)));
+ svst1_f32(pg, y + i, val);
+ sum += (ggml_float)svaddv_f32(pg, val);
+ }
#elif defined(__ARM_NEON) && defined(__aarch64__)
for (; i + 3 < n; i += 4) {
float32x4_t val = ggml_v_expf(vsubq_f32(vld1q_f32(x + i),
@@ -325,6 +457,15 @@ ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float
vst1q_f32(y + i, val);
sum += (ggml_float)vaddvq_f32(val);
}
+#elif defined(__riscv_v_intrinsic)
+ vfloat64m1_t vsum = __riscv_vfmv_v_f_f64m1(0, 1);
+ for (int avl; i < n; i += avl) {
+ avl = __riscv_vsetvl_e32m2(n - i);
+ vfloat32m2_t val = ggml_v_expf_m2(__riscv_vfsub_vf_f32m2(__riscv_vle32_v_f32m2(&x[i], avl), max, avl), avl);
+ __riscv_vse32_v_f32m2(&y[i], val, avl);
+ vsum = __riscv_vfwredusum_vs_f32m2_f64m1(val, vsum, avl);
+ }
+ return (ggml_float)__riscv_vfmv_f_s_f64m1_f64(vsum);
#endif
for (; i < n; ++i) {
float val = expf(x[i] - max);
diff --git a/ggml/src/ggml-cpu/vec.h b/ggml/src/ggml-cpu/vec.h
index 2250d93cb00d1..ef334d089d1f7 100644
--- a/ggml/src/ggml-cpu/vec.h
+++ b/ggml/src/ggml-cpu/vec.h
@@ -119,36 +119,149 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GG
}
#if defined(GGML_SIMD)
- const int np = (n & ~(GGML_F16_STEP - 1));
+ #if defined(__ARM_FEATURE_SVE)
+
+ const int sve_register_length = svcntb() * 8;
+    const int ggml_f16_epr = sve_register_length / 16;  // fp16 elements per SVE register
+    const int ggml_f16_step = 8 * ggml_f16_epr;          // unroll over 8 SVE registers
+
+ const int np = (n & ~(ggml_f16_step - 1));
+
+ svfloat16_t sum_00 = svdup_n_f16(0.0f);
+ svfloat16_t sum_01 = svdup_n_f16(0.0f);
+ svfloat16_t sum_02 = svdup_n_f16(0.0f);
+ svfloat16_t sum_03 = svdup_n_f16(0.0f);
+
+ svfloat16_t sum_10 = svdup_n_f16(0.0f);
+ svfloat16_t sum_11 = svdup_n_f16(0.0f);
+ svfloat16_t sum_12 = svdup_n_f16(0.0f);
+ svfloat16_t sum_13 = svdup_n_f16(0.0f);
+
+ svfloat16_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8;
+ svfloat16_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8;
+
+ for (int i = 0; i < np; i += ggml_f16_step) {
+ ay1 = GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0); // 8 elements
+
+        ax1 = GGML_F16x_VEC_LOAD(x[0] + i + 0*ggml_f16_epr, 0); // 8 elements
+ sum_00 = GGML_F16x_VEC_FMA(sum_00, ax1, ay1); // sum_00 = sum_00+ax1*ay1
+ ax1 = GGML_F16x_VEC_LOAD(x[1] + i + 0*ggml_f16_epr, 0); // 8 elements
+ sum_10 = GGML_F16x_VEC_FMA(sum_10, ax1, ay1);
+
+ ay2 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1); // next 8 elements
+
+        ax2 = GGML_F16x_VEC_LOAD(x[0] + i + 1*ggml_f16_epr, 1); // next 8 elements
+ sum_01 = GGML_F16x_VEC_FMA(sum_01, ax2, ay2);
+ ax2 = GGML_F16x_VEC_LOAD(x[1] + i + 1*ggml_f16_epr, 1);
+ sum_11 = GGML_F16x_VEC_FMA(sum_11, ax2, ay2);
+
+ ay3 = GGML_F16x_VEC_LOAD(y + i + 2 * ggml_f16_epr, 2);
+
+ ax3 = GGML_F16x_VEC_LOAD(x[0] + i + 2*ggml_f16_epr, 2);
+ sum_02 = GGML_F16x_VEC_FMA(sum_02, ax3, ay3);
+        ax3 = GGML_F16x_VEC_LOAD(x[1] + i + 2*ggml_f16_epr, 2);
+ sum_12 = GGML_F16x_VEC_FMA(sum_12, ax3, ay3);
+
+ ay4 = GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3);
+
+ ax4 = GGML_F16x_VEC_LOAD(x[0] + i + 3*ggml_f16_epr, 3);
+ sum_03 = GGML_F16x_VEC_FMA(sum_03, ax4, ay4);
+ ax4 = GGML_F16x_VEC_LOAD(x[1] + i + 3*ggml_f16_epr, 3);
+ sum_13 = GGML_F16x_VEC_FMA(sum_13, ax4, ay4);
+
+ ay5 = GGML_F16x_VEC_LOAD(y + i + 4 * ggml_f16_epr, 4);
+
+ ax5 = GGML_F16x_VEC_LOAD(x[0] + i + 4*ggml_f16_epr, 4);
+
+ sum_00 = GGML_F16x_VEC_FMA(sum_00, ax5, ay5);
+ ax5 = GGML_F16x_VEC_LOAD(x[1] + i + 4*ggml_f16_epr, 4);
+ sum_10 = GGML_F16x_VEC_FMA(sum_10, ax5, ay5);
+
+ ay6 = GGML_F16x_VEC_LOAD(y + i + 5 * ggml_f16_epr, 5);
+
+ ax6 = GGML_F16x_VEC_LOAD(x[0] + i + 5*ggml_f16_epr, 5);
+
+ sum_01 = GGML_F16x_VEC_FMA(sum_01, ax6, ay6);
+ ax6 = GGML_F16x_VEC_LOAD(x[1] + i + 5*ggml_f16_epr, 5);
+ sum_11 = GGML_F16x_VEC_FMA(sum_11, ax6, ay6);
+
+ ay7 = GGML_F16x_VEC_LOAD(y + i + 6 * ggml_f16_epr, 6);
+
+ ax7 = GGML_F16x_VEC_LOAD(x[0] + i + 6*ggml_f16_epr, 6);
+
+ sum_02 = GGML_F16x_VEC_FMA(sum_02, ax7, ay7);
+ ax7 = GGML_F16x_VEC_LOAD(x[1] + i + 6*ggml_f16_epr, 6);
+ sum_12 = GGML_F16x_VEC_FMA(sum_12, ax7, ay7);
+
+ ay8 = GGML_F16x_VEC_LOAD(y + i + 7 * ggml_f16_epr, 7);
+
+ ax8 = GGML_F16x_VEC_LOAD(x[0] + i + 7*ggml_f16_epr, 7);
+
+ sum_03 = GGML_F16x_VEC_FMA(sum_03, ax8, ay8);
+ ax8 = GGML_F16x_VEC_LOAD(x[1] + i + 7*ggml_f16_epr, 7);
+ sum_13 = GGML_F16x_VEC_FMA(sum_13, ax8, ay8);
+ }
+
+ const int np2 = (n & ~(ggml_f16_epr - 1));
+ for (int k = np; k < np2; k += ggml_f16_epr) {
+ svfloat16_t ry = GGML_F16x_VEC_LOAD(y + k, 0);
+
+ svfloat16_t rx = GGML_F16x_VEC_LOAD(x[0] + k, 0);
+ sum_00 = GGML_F16x_VEC_FMA(sum_00, rx, ry);
+ rx = GGML_F16x_VEC_LOAD(x[1] + k, 0);
+ sum_10 = GGML_F16x_VEC_FMA(sum_10, rx, ry);
+ }
+
+ if (np2 < n) {
+ svbool_t pg = svwhilelt_b16(np2, n);
+ svfloat16_t hx_0 = svld1_f16(pg, (const __fp16 *)(x[0] + np2));
+ svfloat16_t hx_1 = svld1_f16(pg, (const __fp16 *)(x[1] + np2));
+ svfloat16_t hy = svld1_f16(pg, (const __fp16 *)(y + np2));
+
+ sum_00 = svmad_f16_x(pg, hx_0, hy, sum_00);
+ sum_10 = svmad_f16_x(pg, hx_1, hy, sum_10);
+ }
+ GGML_F16x_VEC_REDUCE(sumf[0], sum_00, sum_01, sum_02, sum_03);
+ GGML_F16x_VEC_REDUCE(sumf[1], sum_10, sum_11, sum_12, sum_13);
+ #elif defined(__riscv_v_intrinsic)
+ // todo: RVV impl
+ for (int i = 0; i < n; ++i) {
+ for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
+ sumf[j] += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[j][i])*GGML_CPU_FP16_TO_FP32(y[i]));
+ }
+ }
+ #else
+ const int np = (n & ~(GGML_F16_STEP - 1));
- GGML_F16_VEC sum[GGML_VEC_DOT_UNROLL][GGML_F16_ARR] = { { GGML_F16_VEC_ZERO } };
+ GGML_F16_VEC sum[GGML_VEC_DOT_UNROLL][GGML_F16_ARR] = { { GGML_F16_VEC_ZERO } };
- GGML_F16_VEC ax[GGML_F16_ARR];
- GGML_F16_VEC ay[GGML_F16_ARR];
+ GGML_F16_VEC ax[GGML_F16_ARR];
+ GGML_F16_VEC ay[GGML_F16_ARR];
- for (int i = 0; i < np; i += GGML_F16_STEP) {
- for (int j = 0; j < GGML_F16_ARR; j++) {
- ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
+ for (int i = 0; i < np; i += GGML_F16_STEP) {
+ for (int j = 0; j < GGML_F16_ARR; j++) {
+ ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
- for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) {
- ax[j] = GGML_F16_VEC_LOAD(x[k] + i + j*GGML_F16_EPR, j);
+ for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) {
+ ax[j] = GGML_F16_VEC_LOAD(x[k] + i + j*GGML_F16_EPR, j);
- sum[k][j] = GGML_F16_VEC_FMA(sum[k][j], ax[j], ay[j]);
+ sum[k][j] = GGML_F16_VEC_FMA(sum[k][j], ax[j], ay[j]);
+ }
}
}
- }
- // reduce sum0..sum3 to sum0
- for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) {
- GGML_F16_VEC_REDUCE(sumf[k], sum[k]);
- }
+ // reduce sum0..sum3 to sum0
+ for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) {
+ GGML_F16_VEC_REDUCE(sumf[k], sum[k]);
+ }
- // leftovers
- for (int i = np; i < n; ++i) {
- for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
- sumf[j] += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[j][i])*GGML_CPU_FP16_TO_FP32(y[i]));
+ // leftovers
+ for (int i = np; i < n; ++i) {
+ for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
+ sumf[j] += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[j][i])*GGML_CPU_FP16_TO_FP32(y[i]));
+ }
}
- }
+ #endif
#else
for (int i = 0; i < n; ++i) {
for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
@@ -243,6 +356,14 @@ inline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const
svst1_f32(pg, y + np2, ay1);
}
+ #elif defined(__riscv_v_intrinsic)
+ for (int i = 0, avl; i < n; i += avl) {
+ avl = __riscv_vsetvl_e32m8(n - i);
+ vfloat32m8_t ax = __riscv_vle32_v_f32m8(&x[i], avl);
+ vfloat32m8_t ay = __riscv_vle32_v_f32m8(&y[i], avl);
+ vfloat32m8_t ny = __riscv_vfmadd_vf_f32m8(ax, v, ay, avl);
+ __riscv_vse32_v_f32m8(&y[i], ny, avl);
+ }
#else
const int np = (n & ~(GGML_F32_STEP - 1));
@@ -276,27 +397,112 @@ inline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const
inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * GGML_RESTRICT y, const ggml_fp16_t * GGML_RESTRICT x, const float v) {
#if defined(GGML_SIMD)
- const int np = (n & ~(GGML_F16_STEP - 1));
+ #if defined(__ARM_FEATURE_SVE)
+ const int sve_register_length = svcntb() * 8;
+ const int ggml_f16_epr = sve_register_length / 16;
+ const int ggml_f16_step = 8 * ggml_f16_epr;
+
+ GGML_F16x_VEC vx = GGML_F16x_VEC_SET1(v);
+
+    const int np = (n & ~(ggml_f16_step - 1));
+
+ svfloat16_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8;
+ svfloat16_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8;
+ for (int i = 0; i < np; i += ggml_f16_step) {
+ ax1 = GGML_F16x_VEC_LOAD(x + i + 0 * ggml_f16_epr, 0);
+ ay1 = GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0);
+ ay1 = GGML_F16x_VEC_FMA(ay1, ax1, vx);
+
+ GGML_F16x_VEC_STORE(y + i + 0 * ggml_f16_epr, ay1, 0);
+
+ ax2 = GGML_F16x_VEC_LOAD(x + i + 1 * ggml_f16_epr, 1);
+ ay2 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1);
+ ay2 = GGML_F16x_VEC_FMA(ay2, ax2, vx);
+
+ GGML_F16x_VEC_STORE(y + i + 1 * ggml_f16_epr, ay2, 1);
+
+ ax3 = GGML_F16x_VEC_LOAD(x + i + 2 * ggml_f16_epr, 2);
+ ay3 = GGML_F16x_VEC_LOAD(y + i + 2 * ggml_f16_epr, 2);
+ ay3 = GGML_F16x_VEC_FMA(ay3, ax3, vx);
+
+ GGML_F16x_VEC_STORE(y + i + 2 * ggml_f16_epr, ay3, 2);
+
+ ax4 = GGML_F16x_VEC_LOAD(x + i + 3 * ggml_f16_epr, 3);
+ ay4 = GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3);
+ ay4 = GGML_F16x_VEC_FMA(ay4, ax4, vx);
+
+ GGML_F16x_VEC_STORE(y + i + 3 * ggml_f16_epr, ay4, 3);
+
+ ax5 = GGML_F16x_VEC_LOAD(x + i + 4 * ggml_f16_epr, 4);
+ ay5 = GGML_F16x_VEC_LOAD(y + i + 4 * ggml_f16_epr, 4);
+ ay5 = GGML_F16x_VEC_FMA(ay5, ax5, vx);
+
+ GGML_F16x_VEC_STORE(y + i + 4 * ggml_f16_epr, ay5, 4);
+
+ ax6 = GGML_F16x_VEC_LOAD(x + i + 5 * ggml_f16_epr, 5);
+ ay6 = GGML_F16x_VEC_LOAD(y + i + 5 * ggml_f16_epr, 5);
+ ay6 = GGML_F16x_VEC_FMA(ay6, ax6, vx);
+
+ GGML_F16x_VEC_STORE(y + i + 5 * ggml_f16_epr, ay6, 5);
- GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
+ ax7 = GGML_F16x_VEC_LOAD(x + i + 6 * ggml_f16_epr, 6);
+ ay7 = GGML_F16x_VEC_LOAD(y + i + 6 * ggml_f16_epr, 6);
+ ay7 = GGML_F16x_VEC_FMA(ay7, ax7, vx);
- GGML_F16_VEC ax[GGML_F16_ARR];
- GGML_F16_VEC ay[GGML_F16_ARR];
+ GGML_F16x_VEC_STORE(y + i + 6 * ggml_f16_epr, ay7, 6);
- for (int i = 0; i < np; i += GGML_F16_STEP) {
- for (int j = 0; j < GGML_F16_ARR; j++) {
- ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
- ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
- ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx);
+ ax8 = GGML_F16x_VEC_LOAD(x + i + 7 * ggml_f16_epr, 7);
+ ay8 = GGML_F16x_VEC_LOAD(y + i + 7 * ggml_f16_epr, 7);
+ ay8 = GGML_F16x_VEC_FMA(ay8, ax8, vx);
- GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
+ GGML_F16x_VEC_STORE(y + i + 7 * ggml_f16_epr, ay8, 7);
}
- }
+ const int np2 = (n & ~(ggml_f16_epr - 1));
+ for (int k = np; k < np2; k += ggml_f16_epr) {
+ svfloat16_t rx = GGML_F16x_VEC_LOAD(x + k, 0);
+ svfloat16_t ry = GGML_F16x_VEC_LOAD(y + k, 0);
+ ry = GGML_F16x_VEC_FMA(ry, rx, vx);
- // leftovers
- for (int i = np; i < n; ++i) {
- y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v);
- }
+ GGML_F16x_VEC_STORE(y + k, ry, 0);
+ }
+
+ if (np2 < n) {
+ svbool_t pg = svwhilelt_b16(np2, n);
+ svfloat16_t hx = svld1_f16(pg, (const __fp16 *)(x + np2));
+ svfloat16_t hy = svld1_f16(pg, (const __fp16 *)(y + np2));
+ hy = svmad_f16_x(pg, hx, vx, hy);
+ svst1_f16(pg, (__fp16 *)(y + np2), hy);
+ }
+
+ #elif defined(__riscv_v_intrinsic)
+ // todo: RVV impl
+ // scalar
+ for (int i = 0; i < n; ++i) {
+ y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v);
+ }
+ #else
+ const int np = (n & ~(GGML_F16_STEP - 1));
+
+ GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
+
+ GGML_F16_VEC ax[GGML_F16_ARR];
+ GGML_F16_VEC ay[GGML_F16_ARR];
+
+ for (int i = 0; i < np; i += GGML_F16_STEP) {
+ for (int j = 0; j < GGML_F16_ARR; j++) {
+ ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
+ ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
+ ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx);
+
+ GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
+ }
+ }
+
+ // leftovers
+ for (int i = np; i < n; ++i) {
+ y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v);
+ }
+ #endif
#else
// scalar
for (int i = 0; i < n; ++i) {
@@ -324,6 +530,16 @@ inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int
y[i] += x[k][i]*v[k][0];
}
}
+ #elif defined(__riscv_v_intrinsic)
+ for (int i = 0, avl; i < n; i += avl) {
+ avl = __riscv_vsetvl_e32m8(n - i);
+ vfloat32m8_t ay = __riscv_vle32_v_f32m8(&y[i], avl);
+ for (int k = 0; k < GGML_VEC_MAD_UNROLL; k++) {
+ vfloat32m8_t ax = __riscv_vle32_v_f32m8(&x[k][i], avl);
+ ay = __riscv_vfmadd_vf_f32m8(ax, v[k][0], ay, avl);
+ }
+ __riscv_vse32_v_f32m8(&y[i], ay, avl);
+ }
#else
const int np = (n & ~(GGML_F32_STEP - 1));
@@ -375,6 +591,14 @@ inline static void ggml_vec_mad1_f32(const int n, float * y, const float * x, co
for (int i = 0; i < n; ++i) {
y[i] = x[i]*s + b;
}
+ #elif defined(__riscv_v_intrinsic)
+ for (int i = 0, avl; i < n; i += avl) {
+ avl = __riscv_vsetvl_e32m8(n - i);
+ vfloat32m8_t ax = __riscv_vle32_v_f32m8(&x[i], avl);
+ vfloat32m8_t vb = __riscv_vfmv_v_f_f32m8(b, avl);
+ vfloat32m8_t ny = __riscv_vfmadd_vf_f32m8(ax, s, vb, avl);
+ __riscv_vse32_v_f32m8(&y[i], ny, avl);
+ }
#else
const int np = (n & ~(GGML_F32_STEP - 1));
@@ -436,6 +660,13 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float v) {
ay1 = svmul_f32_m(pg, ay1, vx);
svst1_f32(pg, y + np, ay1);
}
+ #elif defined(__riscv_v_intrinsic)
+ for (int i = 0, avl; i < n; i += avl) {
+ avl = __riscv_vsetvl_e32m8(n - i);
+ vfloat32m8_t ay = __riscv_vle32_v_f32m8(&y[i], avl);
+ vfloat32m8_t ny = __riscv_vfmul_vf_f32m8(ay, v, avl);
+ __riscv_vse32_v_f32m8(&y[i], ny, avl);
+ }
#else
const int np = (n & ~(GGML_F32_STEP - 1));
@@ -467,25 +698,59 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float v) {
inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float v) {
#if defined(GGML_SIMD)
- const int np = (n & ~(GGML_F16_STEP - 1));
+ #if defined(__ARM_FEATURE_SVE)
+ const int sve_register_length = svcntb() * 8;
+ const int ggml_f16_epr = sve_register_length / 16;
+ const int ggml_f16_step = 2 * ggml_f16_epr;
+
+ GGML_F16x_VEC vx = GGML_F16x_VEC_SET1(v);
+ const int np = (n & ~(ggml_f16_step - 1));
+ svfloat16_t ay1, ay2;
+
+ for (int i = 0; i < np; i += ggml_f16_step) {
+ ay1 = GGML_F16x_VEC_LOAD(y + i + 0*ggml_f16_epr, 0);
+ ay1 = GGML_F16x_VEC_MUL(ay1, vx);
+ GGML_F16x_VEC_STORE(y + i + 0*ggml_f16_epr, ay1, 0);
+
+ ay2 = GGML_F16x_VEC_LOAD(y + i + 1*ggml_f16_epr, 1);
+ ay2 = GGML_F16x_VEC_MUL(ay2, vx);
+ GGML_F16x_VEC_STORE(y + i + 1*ggml_f16_epr, ay2, 1);
+ }
+ // leftovers
+    // the number of leftover elements is less than ggml_f16_epr; apply a predicated svmul to the available elements only
+ if (np < n) {
+ svbool_t pg = svwhilelt_b16(np, n);
+ svfloat16_t hy = svld1_f16(pg, (__fp16 *)(y + np));
+ svfloat16_t out = svmul_f16_m(pg, hy, vx);
+ svst1_f16(pg, (__fp16 *)(y + np), out);
+ }
+ #elif defined(__riscv_v_intrinsic)
+ // todo: RVV impl
+ // scalar
+ for (int i = 0; i < n; ++i) {
+ y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i])*v);
+ }
+ #else
+ const int np = (n & ~(GGML_F16_STEP - 1));
- GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
+ GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
- GGML_F16_VEC ay[GGML_F16_ARR];
+ GGML_F16_VEC ay[GGML_F16_ARR];
- for (int i = 0; i < np; i += GGML_F16_STEP) {
- for (int j = 0; j < GGML_F16_ARR; j++) {
- ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
- ay[j] = GGML_F16_VEC_MUL(ay[j], vx);
+ for (int i = 0; i < np; i += GGML_F16_STEP) {
+ for (int j = 0; j < GGML_F16_ARR; j++) {
+ ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
+ ay[j] = GGML_F16_VEC_MUL(ay[j], vx);
- GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
+ GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
+ }
}
- }
- // leftovers
- for (int i = np; i < n; ++i) {
- y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i])*v);
- }
+ // leftovers
+ for (int i = np; i < n; ++i) {
+ y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i])*v);
+ }
+ #endif
#else
// scalar
for (int i = 0; i < n; ++i) {
@@ -737,7 +1002,39 @@ inline static ggml_fp16_t ggml_silu_f16(ggml_fp16_t x) {
}
#endif
-#if defined(__ARM_NEON) && defined(__aarch64__)
+#if defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
+
+inline static svfloat32_t ggml_v_expf(svbool_t pg, svfloat32_t x) {
+ const svfloat32_t r = svdup_n_f32_x(pg, 0x1.8p23f);
+ const svfloat32_t z = svmla_n_f32_x(pg, r, x, 0x1.715476p+0f);
+ const svfloat32_t n = svsub_f32_x(pg, z, r);
+ const svfloat32_t b = svmls_n_f32_x(pg, svmls_n_f32_x(pg, x, n, 0x1.62e4p-1f), n, 0x1.7f7d1cp-20f);
+ const svuint32_t e = svlsl_n_u32_x(pg, svreinterpret_u32_f32(z), 23);
+ const svfloat32_t k = svreinterpret_f32_u32(svadd_u32_x(pg, e, svreinterpret_u32_f32(svdup_n_f32_x(pg, 1))));
+ const svbool_t c = svacgt_n_f32(pg, n, 126);
+ const svfloat32_t u = svmul_f32_x(pg, b, b);
+ const svfloat32_t j = svmla_f32_x(pg,
+ svmul_n_f32_x(pg, b, 0x1.ffffecp-1f),
+ svmla_f32_x(pg, svmla_f32_x(pg, svdup_n_f32_x(pg, 0x1.fffdb6p-2f), svdup_n_f32_x(pg, 0x1.555e66p-3f), b),
+ svmla_f32_x(pg, svdup_n_f32_x(pg, 0x1.573e2ep-5f), svdup_n_f32_x(pg, 0x1.0e4020p-7f), b), u), u);
+ const svuint32_t d = svdup_n_u32_z(svcmple_n_f32(pg, n, 0.0), 0x82000000);
+ const svfloat32_t s1 = svreinterpret_f32_u32(svadd_n_u32_x(pg, d, 0x7f000000));
+ const svfloat32_t s2 = svreinterpret_f32_u32(svsub_u32_x(pg, e, d));
+ return svsel_f32(svacgt_f32(pg, n, svdup_n_f32_x(pg, 192)), svmul_f32_x(pg, s1, s1),
+ svsel_f32(c, svmul_f32_x(pg, svmla_f32_x(pg, s2, s2, j), s1), svmla_f32_x(pg, k, k, j)));
+}
+
+// computes silu x/(1+exp(-x)) in single precision vector
+inline static svfloat32_t ggml_v_silu(svbool_t pg, svfloat32_t x) {
+ const svfloat32_t one = svdup_n_f32_x(pg, 1.0f);
+ const svfloat32_t zero = svdup_n_f32_x(pg, 0.0f);
+ const svfloat32_t neg_x = svsub_f32_x(pg, zero, x);
+ const svfloat32_t exp_neg_x = ggml_v_expf(pg, neg_x);
+ const svfloat32_t one_plus_exp_neg_x = svadd_f32_x(pg, one, exp_neg_x);
+ return svdiv_f32_x(pg, x, one_plus_exp_neg_x);
+}
+
+#elif defined(__ARM_NEON) && defined(__aarch64__)
// adapted from arm limited optimized routine
// the maximum error is 1.45358 plus 0.5 ulps
@@ -928,7 +1225,59 @@ inline static __m128 ggml_v_silu(__m128 x) {
return _mm_div_ps(x, one_plus_exp_neg_x);
}
-#endif // __ARM_NEON / __AVX2__ / __SSE2__
+#elif defined(__riscv_v_intrinsic)
+
+// adapted from arm limited optimized routine
+// the maximum error is 1.45358 plus 0.5 ulps
+// numbers above 88.38 will flush to infinity
+// numbers beneath -103.97 will flush to zero
+inline static vfloat32m2_t ggml_v_expf_m2(vfloat32m2_t x, int vl) {
+ const vfloat32m2_t r = __riscv_vfmv_v_f_f32m2(0x1.8p23f, vl);
+#ifdef __riscv_xtheadvector
+ // workaround for compiler bug (gcc 14.3.0: Error: unrecognized opcode `th.vmv1r.v v2,v4')
+ vfloat32m2_t z = __riscv_vfadd_vf_f32m2(r, 0.0f, vl);
+ z = __riscv_vfmacc_vf_f32m2(z, 0x1.715476p+0f, x, vl);
+#else
+ const vfloat32m2_t z = __riscv_vfmacc_vf_f32m2(r, 0x1.715476p+0f, x, vl);
+#endif
+ const vfloat32m2_t n = __riscv_vfsub_vv_f32m2(z, r, vl);
+ const vfloat32m2_t b = __riscv_vfnmsac_vf_f32m2(__riscv_vfnmsac_vf_f32m2(x, 0x1.62e4p-1f, n, vl),
+ 0x1.7f7d1cp-20f, n, vl);
+ const vuint32m2_t e = __riscv_vsll_vx_u32m2(__riscv_vreinterpret_v_f32m2_u32m2(z), 23, vl);
+ const vfloat32m2_t k = __riscv_vreinterpret_v_u32m2_f32m2(__riscv_vadd_vx_u32m2(e, 0x3f800000, vl)); // 1.0f
+ const vbool16_t c = __riscv_vmfgt_vf_f32m2_b16(__riscv_vfabs_v_f32m2(n, vl), 126.0f, vl);
+ const vfloat32m2_t u = __riscv_vfmul_vv_f32m2(b, b, vl);
+ const vfloat32m2_t j = __riscv_vfmacc_vv_f32m2(
+ __riscv_vfmul_vf_f32m2(b, 0x1.ffffecp-1f, vl),
+ __riscv_vfmacc_vv_f32m2(
+ __riscv_vfmacc_vf_f32m2(__riscv_vfmv_v_f_f32m2(0x1.fffdb6p-2f, vl), 0x1.555e66p-3f, b, vl),
+ __riscv_vfmacc_vf_f32m2(__riscv_vfmv_v_f_f32m2(0x1.573e2ep-5f, vl), 0x1.0e4020p-7f, b, vl),
+ u, vl), u, vl);
+ if (!__riscv_vcpop_m_b16(c, vl))
+ return __riscv_vfmacc_vv_f32m2(k, j, k, vl);
+ const vbool16_t dm = __riscv_vmfle_vf_f32m2_b16(n, 0.0f, vl);
+ const vuint32m2_t d = __riscv_vmerge_vxm_u32m2(__riscv_vmv_v_x_u32m2(0, vl), 0x82000000, dm, vl);
+ const vfloat32m2_t s1 = __riscv_vreinterpret_v_u32m2_f32m2(__riscv_vadd_vx_u32m2(d, 0x7f000000, vl));
+ const vfloat32m2_t s2 = __riscv_vreinterpret_v_u32m2_f32m2(__riscv_vsub_vv_u32m2(e, d, vl));
+ const vfloat32m2_t r1 = __riscv_vmerge_vvm_f32m2(
+ __riscv_vfmacc_vv_f32m2(k, k, j, vl),
+ __riscv_vfmul_vv_f32m2(__riscv_vfmacc_vv_f32m2(s2, s2, j, vl), s1, vl),
+ c, vl);
+ return __riscv_vmerge_vvm_f32m2(
+ r1, __riscv_vfmul_vv_f32m2(s1, s1, vl),
+ __riscv_vmfgt_vf_f32m2_b16(__riscv_vfabs_v_f32m2(n, vl), 192.0f, vl),
+ vl);
+}
+
+// computes silu x/(1+exp(-x)) in single precision vector
+inline static vfloat32m2_t ggml_v_silu_m2(vfloat32m2_t x, int vl) {
+ const vfloat32m2_t neg_x = __riscv_vfneg_v_f32m2(x, vl);
+ const vfloat32m2_t exp_neg_x = ggml_v_expf_m2(neg_x, vl);
+ const vfloat32m2_t one_plus_exp_neg_x = __riscv_vfadd_vf_f32m2(exp_neg_x, 1.0f, vl);
+ return __riscv_vfdiv_vv_f32m2(x, one_plus_exp_neg_x, vl);
+}
+
+#endif // __ARM_NEON / __AVX2__ / __SSE2__ / __riscv_v_intrinsic
inline static void ggml_vec_silu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
for (int i = 0; i < n; ++i) {
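For reference, the SVE and RISC-V expf variants above all follow the same scalar recipe: round x*log2(e) with the 0x1.8p23 shifter trick, evaluate a short polynomial on the reduced argument, and rebuild 2^n by adding n to the exponent bits of 1.0f. Below is a minimal scalar sketch of that recipe; expf_sketch/silu_sketch are illustrative names, and the overflow/underflow special cases handled by the vector code are omitted.

    // Scalar sketch of the range reduction used by the vector expf routines above:
    // exp(x) = 2^n * exp(r), n = round(x*log2(e)).
    #include <cstdint>
    #include <cstring>

    static inline float expf_sketch(float x) {
        const float r_big = 0x1.8p23f;                    // pushes round(x*log2(e)) into the mantissa
        const float z     = x * 0x1.715476p+0f + r_big;   // x*log2(e) + shifter
        const float n     = z - r_big;                    // n = round(x*log2(e))
        // r = x - n*ln2, with ln2 split into a high and a low part for accuracy
        const float r     = (x - n * 0x1.62e4p-1f) - n * 0x1.7f7d1cp-20f;

        // 2^n: the low mantissa bits of z hold n, shift them into the exponent field of 1.0f
        uint32_t zbits; std::memcpy(&zbits, &z, sizeof(zbits));
        const uint32_t kbits = (zbits << 23) + 0x3f800000u;
        float k; std::memcpy(&k, &kbits, sizeof(k));

        // polynomial approximation of exp(r) - 1 on the reduced range (same coefficients as above)
        const float r2 = r * r;
        const float p  = 0x1.ffffecp-1f * r
                       + (0x1.fffdb6p-2f + 0x1.555e66p-3f * r) * r2
                       + (0x1.573e2ep-5f + 0x1.0e4020p-7f * r) * r2 * r2;
        return k + k * p;
    }

    static inline float silu_sketch(float x) {
        return x / (1.0f + expf_sketch(-x));   // silu(x) = x / (1 + exp(-x))
    }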
diff --git a/ggml/src/ggml-cuda/CMakeLists.txt b/ggml/src/ggml-cuda/CMakeLists.txt
index ea824965aae2d..d3dfc7807de0e 100644
--- a/ggml/src/ggml-cuda/CMakeLists.txt
+++ b/ggml/src/ggml-cuda/CMakeLists.txt
@@ -94,7 +94,11 @@ if (CUDAToolkit_FOUND)
# As of 12.3.1 CUDA Toolkit for Windows does not offer a static cublas library
target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas)
else ()
- target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas_static)
+ if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "10.1")
+ target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
+ else()
+ target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas_static)
+ endif()
endif()
else()
target_link_libraries(ggml-cuda PRIVATE CUDA::cudart CUDA::cublas)
diff --git a/ggml/src/ggml-cuda/binbcast.cu b/ggml/src/ggml-cuda/binbcast.cu
index e1fbf0e13665d..1c76566344a88 100644
--- a/ggml/src/ggml-cuda/binbcast.cu
+++ b/ggml/src/ggml-cuda/binbcast.cu
@@ -1,5 +1,6 @@
#include "binbcast.cuh"
 #include <cstdint>
+#include <utility>
static __device__ __forceinline__ float op_repeat(const float a, const float b) {
return b;
@@ -22,13 +23,16 @@ static __device__ __forceinline__ float op_div(const float a, const float b) {
return a / b;
}
-template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
+
+
+template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t, typename... src1_ptrs>
static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
- int ne0, int ne1, int ne2, int ne3,
- int ne10, int ne11, int ne12, int ne13,
- /*int s0, */ int s1, int s2, int s3,
- /*int s00,*/ int s01, int s02, int s03,
- /*int s10,*/ int s11, int s12, int s13) {
+ const int ne0, const int ne1, const int ne2, const int ne3,
+ const int ne10, const int ne11, const int ne12, const int ne13,
+ /*int s0, */ const int s1, const int s2, const int s3,
+ /*int s00,*/ const int s01, const int s02, const int s03,
+ /*int s10,*/ const int s11, const int s12, const int s13,
+ src1_ptrs... src1s) {
const int i0s = blockDim.x*blockIdx.x + threadIdx.x;
const int i1 = (blockDim.y*blockIdx.y + threadIdx.y);
const int i2 = (blockDim.z*blockIdx.z + threadIdx.z) / ne3;
@@ -46,24 +50,31 @@ static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst
const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
const size_t i_dst = i3*s3 + i2*s2 + i1*s1;
- const src0_t * src0_row = src0 + i_src0;
- const src1_t * src1_row = src1 + i_src1;
+ const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr;
dst_t * dst_row = dst + i_dst;
for (int i0 = i0s; i0 < ne0; i0 += blockDim.x*gridDim.x) {
const int i10 = i0 % ne10;
- dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
+
+ float result = src0_row ? (float) src0_row[i0] : 0.0f;
+ if constexpr (sizeof...(src1_ptrs) > 0) {
+ result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10])));
+ } else {
+ result = bin_op(result, (float)src1[i_src1 + i10]);
+ }
+
+ dst_row[i0] = (dst_t) result;
}
}
-template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
-static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst,
- int ne0, int ne1, int ne2, int ne3,
- int ne10, int ne11, int ne12, int ne13,
- /*int s0, */ int s1, int s2, int s3,
- /*int s00,*/ int s01, int s02, int s03,
- /*int s10,*/ int s11, int s12, int s13) {
-
+template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t, typename... src1_ptrs>
+static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst,
+ const int ne0, const int ne1, const int ne2,const int ne3,
+ const int ne10, const int ne11, const int ne12, const int ne13,
+ /*int s0, */ const int s1, const int s2, const int s3,
+ /*int s00,*/ const int s01, const int s02, const int s03,
+ /*int s10,*/ const int s11, const int s12, const int s13,
+ src1_ptrs ... src1s) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
const int i3 = i/(ne2*ne1*ne0);
@@ -83,12 +94,190 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * s
const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
const size_t i_dst = i3*s3 + i2*s2 + i1*s1;
- const src0_t * src0_row = src0 + i_src0;
- const src1_t * src1_row = src1 + i_src1;
+ const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr;
dst_t * dst_row = dst + i_dst;
const int i10 = i0 % ne10;
- dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
+
+ float result = src0_row ? (float) src0_row[i0] : 0.0f;
+ if constexpr (sizeof...(src1_ptrs) > 0) {
+ result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10])));
+ } else {
+ result = bin_op(result, (float)src1[i_src1 + i10]);
+ }
+
+ dst_row[i0] = (dst_t) result;
+}
+
+template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t, size_t... I>
+static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+ const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd,
+                                  cudaStream_t stream, std::index_sequence<I...>) {
+ GGML_TENSOR_BINARY_OP_LOCALS
+
+ int nr0 = ne10 / ne0;
+ int nr1 = ne11 / ne1;
+ int nr2 = ne12 / ne2;
+ int nr3 = ne13 / ne3;
+
+ int nr[4] = { nr0, nr1, nr2, nr3 };
+
+ int64_t cne[] = { ne0, ne1, ne2, ne3 };
+ int64_t cne0[] = { ne00, ne01, ne02, ne03 };
+ int64_t cne1[] = { ne10, ne11, ne12, ne13 };
+
+ size_t cnb[] = { nb0, nb1, nb2, nb3 };
+ size_t cnb0[] = { nb00, nb01, nb02, nb03 };
+ size_t cnb1[] = { nb10, nb11, nb12, nb13 };
+
+ auto collapse = [](int64_t cne[]) {
+ cne[0] *= cne[1];
+ cne[1] = cne[2];
+ cne[2] = cne[3];
+ cne[3] = 1;
+ };
+
+ auto collapse_nb = [](size_t cnb[], const int64_t cne[]) {
+ cnb[1] *= cne[1];
+ cnb[2] *= cne[2];
+ cnb[3] *= cne[3];
+ };
+
+ if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) {
+ for (int i = 0; i < 4; i++) {
+ if (nr[i] != 1) {
+ break;
+ }
+ if (i > 0) {
+ collapse_nb(cnb, cne);
+ collapse_nb(cnb0, cne0);
+ collapse_nb(cnb1, cne1);
+ collapse(cne);
+ collapse(cne0);
+ collapse(cne1);
+ }
+ }
+ }
+
+ {
+ int64_t ne0 = cne[0];
+ int64_t ne1 = cne[1];
+ int64_t ne2 = cne[2];
+ int64_t ne3 = cne[3];
+
+ //int64_t ne00 = cne0[0]; GGML_UNUSED(ne00);
+ //int64_t ne01 = cne0[1]; GGML_UNUSED(ne01);
+ //int64_t ne02 = cne0[2]; GGML_UNUSED(ne02);
+ //int64_t ne03 = cne0[3]; GGML_UNUSED(ne03);
+
+ int64_t ne10 = cne1[0];
+ int64_t ne11 = cne1[1];
+ int64_t ne12 = cne1[2];
+ int64_t ne13 = cne1[3];
+
+ size_t nb0 = cnb[0];
+ size_t nb1 = cnb[1];
+ size_t nb2 = cnb[2];
+ size_t nb3 = cnb[3];
+
+ size_t nb00 = cnb0[0];
+ size_t nb01 = cnb0[1];
+ size_t nb02 = cnb0[2];
+ size_t nb03 = cnb0[3];
+
+ size_t nb10 = cnb1[0];
+ size_t nb11 = cnb1[1];
+ size_t nb12 = cnb1[2];
+ size_t nb13 = cnb1[3];
+
+ size_t s0 = nb0 / sizeof(dst_t);
+ size_t s1 = nb1 / sizeof(dst_t);
+ size_t s2 = nb2 / sizeof(dst_t);
+ size_t s3 = nb3 / sizeof(dst_t);
+
+ size_t s10 = nb10 / sizeof(src1_t);
+ size_t s11 = nb11 / sizeof(src1_t);
+ size_t s12 = nb12 / sizeof(src1_t);
+ size_t s13 = nb13 / sizeof(src1_t);
+
+ size_t s00 = nb00 / sizeof(src0_t);
+ size_t s01 = nb01 / sizeof(src0_t);
+ size_t s02 = nb02 / sizeof(src0_t);
+ size_t s03 = nb03 / sizeof(src0_t);
+
+ GGML_ASSERT(nb0 % sizeof(dst_t) == 0);
+ GGML_ASSERT(nb1 % sizeof(dst_t) == 0);
+ GGML_ASSERT(nb2 % sizeof(dst_t) == 0);
+ GGML_ASSERT(nb3 % sizeof(dst_t) == 0);
+
+ GGML_ASSERT(nb00 % sizeof(src0_t) == 0);
+ GGML_ASSERT(nb01 % sizeof(src0_t) == 0);
+ GGML_ASSERT(nb02 % sizeof(src0_t) == 0);
+ GGML_ASSERT(nb03 % sizeof(src0_t) == 0);
+
+ GGML_ASSERT(nb10 % sizeof(src1_t) == 0);
+ GGML_ASSERT(nb11 % sizeof(src1_t) == 0);
+ GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
+ GGML_ASSERT(nb13 % sizeof(src1_t) == 0);
+
+ GGML_ASSERT(s0 == 1);
+ GGML_ASSERT(s00 == 1);
+ GGML_ASSERT(s10 == 1);
+
+ const int block_size = 128;
+
+ int64_t hne0 = std::max(ne0 / 2LL, 1LL);
+
+ dim3 block_dims;
+        block_dims.x = std::min<unsigned int>(hne0, block_size);
+        block_dims.y = std::min<unsigned int>(ne1, block_size / block_dims.x);
+        block_dims.z = std::min(std::min<unsigned int>(ne2 * ne3, block_size / block_dims.x / block_dims.y), 64U);
+
+ dim3 block_nums((hne0 + block_dims.x - 1) / block_dims.x,
+ (ne1 + block_dims.y - 1) / block_dims.y,
+ (ne2 * ne3 + block_dims.z - 1) / block_dims.z);
+
+ if (block_nums.z > 65535) {
+ int block_num = (ne0 * ne1 * ne2 * ne3 + block_size - 1) / block_size;
+ if constexpr (sizeof...(I) > 0) {
+                k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t>
+                    <<<block_num, block_size, 0, stream>>>(src0_dd, src1_dd, dst_dd,
+                                                           ne0, ne1, ne2, ne3,
+                                                           ne10, ne11, ne12, ne13,
+                                                           /* s0, */ s1, s2, s3,
+                                                           /* s00,*/ s01, s02, s03,
+                                                           /* s10,*/ s11, s12, s13,
+                                                           (const src1_t *) dst->src[I + 1]->data...);
+ } else {
+                k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t>
+                    <<<block_num, block_size, 0, stream>>>(src0_dd, src1_dd, dst_dd,
+                                                           ne0, ne1, ne2, ne3,
+                                                           ne10, ne11, ne12, ne13,
+                                                           /* s0, */ s1, s2, s3,
+                                                           /* s00,*/ s01, s02, s03,
+                                                           /* s10,*/ s11, s12, s13);
+ }
+ } else {
+ if constexpr (sizeof...(I) > 0) {
+                k_bin_bcast<bin_op, src0_t, src1_t, dst_t>
+                    <<<block_nums, block_dims, 0, stream>>>(src0_dd, src1_dd, dst_dd,
+                                                            ne0, ne1, ne2, ne3,
+                                                            ne10, ne11, ne12, ne13,
+                                                            /* s0, */ s1, s2, s3,
+                                                            /* s00,*/ s01, s02, s03,
+                                                            /* s10,*/ s11, s12, s13,
+                                                            (const src1_t *) dst->src[I + 1]->data...);
+ } else {
+                k_bin_bcast<bin_op, src0_t, src1_t, dst_t>
+                    <<<block_nums, block_dims, 0, stream>>>(src0_dd, src1_dd, dst_dd,
+                                                            ne0, ne1, ne2, ne3,
+                                                            ne10, ne11, ne12, ne13,
+                                                            /* s0, */ s1, s2, s3,
+                                                            /* s00,*/ s01, s02, s03,
+                                                            /* s10,*/ s11, s12, s13);
+ }
+ }
+ }
}
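The fused kernels above rely on a C++17 comma fold to chain bin_op across the extra src1 pointers in the parameter pack. A standalone host-side sketch of that fold, with illustrative names such as fold_bin_op:

    // Host-side sketch (not the kernel) of the comma fold used by k_bin_bcast above.
    #include <cstdio>

    static float op_add(float a, float b) { return a + b; }

    template <float (*bin_op)(float, float), typename... srcs_t>
    static float fold_bin_op(float init, int i, const srcs_t *... srcs) {
        float result = init;
        // left to right: result = bin_op(result, srcs0[i]); result = bin_op(result, srcs1[i]); ...
        (..., (result = bin_op(result, (float) srcs[i])));
        return result;
    }

    int main() {
        const float a[] = { 1.0f, 2.0f };
        const float b[] = { 10.0f, 20.0f };
        const float c[] = { 100.0f, 200.0f };
        std::printf("%g\n", fold_bin_op<op_add>(0.0f, 1, a, b, c)); // (0 + 2) + 20 + 200 = 222
        return 0;
    }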
template
@@ -120,160 +309,14 @@ static __global__ void k_repeat_back(
dst[tid3*ne2*ne1*ne0 + tid2*ne1*ne0 + tid1*ne0 + tid0] = sum;
}
-template
+template
struct bin_bcast_cuda {
template
void operator()(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst,
const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd,
cudaStream_t stream) {
-
- GGML_TENSOR_BINARY_OP_LOCALS
-
- int nr0 = ne10/ne0;
- int nr1 = ne11/ne1;
- int nr2 = ne12/ne2;
- int nr3 = ne13/ne3;
-
- int nr[4] = { nr0, nr1, nr2, nr3 };
-
- // collapse dimensions until first broadcast dimension
- int64_t cne[] = {ne0, ne1, ne2, ne3};
- int64_t cne0[] = {ne00, ne01, ne02, ne03};
- int64_t cne1[] = {ne10, ne11, ne12, ne13};
-
- size_t cnb[] = {nb0, nb1, nb2, nb3};
- size_t cnb0[] = {nb00, nb01, nb02, nb03};
- size_t cnb1[] = {nb10, nb11, nb12, nb13};
-
- auto collapse = [](int64_t cne[]) {
- cne[0] *= cne[1];
- cne[1] = cne[2];
- cne[2] = cne[3];
- cne[3] = 1;
- };
-
- auto collapse_nb = [](size_t cnb[], const int64_t cne[]) {
- cnb[1] *= cne[1];
- cnb[2] *= cne[2];
- cnb[3] *= cne[3];
- };
-
- if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) {
- for (int i = 0; i < 4; i++) {
- if (nr[i] != 1) {
- break;
- }
- if (i > 0) {
- collapse_nb(cnb, cne);
- collapse_nb(cnb0, cne0);
- collapse_nb(cnb1, cne1);
- collapse(cne);
- collapse(cne0);
- collapse(cne1);
- }
- }
- }
-
- {
- int64_t ne0 = cne[0];
- int64_t ne1 = cne[1];
- int64_t ne2 = cne[2];
- int64_t ne3 = cne[3];
-
- //int64_t ne00 = cne0[0]; GGML_UNUSED(ne00);
- //int64_t ne01 = cne0[1]; GGML_UNUSED(ne01);
- //int64_t ne02 = cne0[2]; GGML_UNUSED(ne02);
- //int64_t ne03 = cne0[3]; GGML_UNUSED(ne03);
-
- int64_t ne10 = cne1[0];
- int64_t ne11 = cne1[1];
- int64_t ne12 = cne1[2];
- int64_t ne13 = cne1[3];
-
- size_t nb0 = cnb[0];
- size_t nb1 = cnb[1];
- size_t nb2 = cnb[2];
- size_t nb3 = cnb[3];
-
- size_t nb00 = cnb0[0];
- size_t nb01 = cnb0[1];
- size_t nb02 = cnb0[2];
- size_t nb03 = cnb0[3];
-
- size_t nb10 = cnb1[0];
- size_t nb11 = cnb1[1];
- size_t nb12 = cnb1[2];
- size_t nb13 = cnb1[3];
-
- size_t s0 = nb0 / sizeof(dst_t);
- size_t s1 = nb1 / sizeof(dst_t);
- size_t s2 = nb2 / sizeof(dst_t);
- size_t s3 = nb3 / sizeof(dst_t);
-
- size_t s10 = nb10 / sizeof(src1_t);
- size_t s11 = nb11 / sizeof(src1_t);
- size_t s12 = nb12 / sizeof(src1_t);
- size_t s13 = nb13 / sizeof(src1_t);
-
- size_t s00 = nb00 / sizeof(src0_t);
- size_t s01 = nb01 / sizeof(src0_t);
- size_t s02 = nb02 / sizeof(src0_t);
- size_t s03 = nb03 / sizeof(src0_t);
-
- GGML_ASSERT(nb0 % sizeof(dst_t) == 0);
- GGML_ASSERT(nb1 % sizeof(dst_t) == 0);
- GGML_ASSERT(nb2 % sizeof(dst_t) == 0);
- GGML_ASSERT(nb3 % sizeof(dst_t) == 0);
-
- GGML_ASSERT(nb00 % sizeof(src0_t) == 0);
- GGML_ASSERT(nb01 % sizeof(src0_t) == 0);
- GGML_ASSERT(nb02 % sizeof(src0_t) == 0);
- GGML_ASSERT(nb03 % sizeof(src0_t) == 0);
-
- GGML_ASSERT(nb10 % sizeof(src1_t) == 0);
- GGML_ASSERT(nb11 % sizeof(src1_t) == 0);
- GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
- GGML_ASSERT(nb13 % sizeof(src1_t) == 0);
-
- GGML_ASSERT(s0 == 1);
- GGML_ASSERT(s00 == 1);
- GGML_ASSERT(s10 == 1);
-
- const int block_size = 128;
-
- int64_t hne0 = std::max(ne0/2LL, 1LL);
-
- dim3 block_dims;
- block_dims.x = std::min(hne0, block_size);
- block_dims.y = std::min(ne1, block_size / block_dims.x);
- block_dims.z = std::min(std::min(ne2*ne3, block_size / block_dims.x / block_dims.y), 64U);
-
- dim3 block_nums(
- (hne0 + block_dims.x - 1) / block_dims.x,
- (ne1 + block_dims.y - 1) / block_dims.y,
- (ne2*ne3 + block_dims.z - 1) / block_dims.z
- );
-
- if (block_nums.z > 65535) {
- // this is the maximum number of blocks in z dimension, fallback to 1D grid kernel
- int block_num = (ne0*ne1*ne2*ne3 + block_size - 1) / block_size;
-            k_bin_bcast_unravel<bin_op><<<block_num, block_size, 0, stream>>>(
- src0_dd, src1_dd, dst_dd,
- ne0, ne1, ne2, ne3,
- ne10, ne11, ne12, ne13,
- /* s0, */ s1, s2, s3,
- /* s00, */ s01, s02, s03,
- /* s10, */ s11, s12, s13);
- } else {
-            k_bin_bcast<bin_op><<<block_nums, block_dims, 0, stream>>>(
- src0_dd, src1_dd, dst_dd,
- ne0, ne1, ne2, ne3,
- ne10, ne11, ne12, ne13,
- /* s0, */ s1, s2, s3,
- /* s00, */ s01, s02, s03,
- /* s10, */ s11, s12, s13);
- }
- }
+        launch_bin_bcast_pack<bin_op, src0_t, src1_t, dst_t>(
+            src0, src1, dst, src0_dd, src1_dd, dst_dd, stream, std::make_index_sequence<0>{});
}
};
@@ -312,7 +355,7 @@ static void ggml_cuda_op_bin_bcast(
}
void ggml_cuda_op_repeat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
- ggml_cuda_op_bin_bcast>(dst, dst->src[0], dst, nullptr, dst->src[0]->data, dst->data, ctx.stream());
+ ggml_cuda_op_bin_bcast>(dst, dst->src[0], dst, nullptr, dst->src[0]->data, dst->data, ctx.stream());
}
void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -331,6 +374,68 @@ void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_bin_bcast>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
}
+template <float (*op)(const float, const float), int n_fuse>
+static void ggml_cuda_op_fused_binbcast_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ cudaStream_t stream = ctx.stream();
+
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+        launch_bin_bcast_pack<op, float, float, float>(src0, src1, dst,
+            (const float *) src0->data, (const float *) src1->data, (float *) dst->data,
+            stream, std::make_index_sequence<n_fuse>{});
+    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
+        launch_bin_bcast_pack<op, half, half, half>(src0, src1, dst,
+            (const half *) src0->data, (const half *) src1->data, (half *) dst->data,
+            stream, std::make_index_sequence<n_fuse>{});
+    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
+        launch_bin_bcast_pack<op, half, float, half>(src0, src1, dst,
+            (const half *) src0->data, (const float *) src1->data, (half *) dst->data,
+            stream, std::make_index_sequence<n_fuse>{});
+    } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
+        launch_bin_bcast_pack<op, half, float, float>(src0, src1, dst,
+            (const half *) src0->data, (const float *) src1->data, (float *) dst->data,
+            stream, std::make_index_sequence<n_fuse>{});
+ } else {
+ fprintf(stderr,
+ "%s: unsupported types for fusion: dst: %s, src0: %s, src1: %s\n",
+ __func__, ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type));
+ GGML_ABORT("fatal error");
+ }
+}
+
+
+void ggml_cuda_op_fused_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst, int n_fuse) {
+ GGML_ASSERT(2 <= n_fuse && n_fuse <= 8);
+
+ switch (n_fuse) {
+        case 2:
+            ggml_cuda_op_fused_binbcast_impl<op_add, 2>(ctx, dst);
+            break;
+        case 3:
+            ggml_cuda_op_fused_binbcast_impl<op_add, 3>(ctx, dst);
+            break;
+        case 4:
+            ggml_cuda_op_fused_binbcast_impl<op_add, 4>(ctx, dst);
+            break;
+        case 5:
+            ggml_cuda_op_fused_binbcast_impl<op_add, 5>(ctx, dst);
+            break;
+        case 6:
+            ggml_cuda_op_fused_binbcast_impl<op_add, 6>(ctx, dst);
+            break;
+        case 7:
+            ggml_cuda_op_fused_binbcast_impl<op_add, 7>(ctx, dst);
+            break;
+        case 8:
+            ggml_cuda_op_fused_binbcast_impl<op_add, 8>(ctx, dst);
+            break;
+ default:
+ GGML_ASSERT(false && "Unsupported n_fuse value");
+ }
+}
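ggml_cuda_op_fused_add above turns a runtime fuse count into a compile-time pack length through an explicit switch. A minimal sketch of that dispatch pattern, with illustrative names:

    // Runtime -> compile-time dispatch: the fuse count picks an index-sequence length.
    #include <cstddef>
    #include <cstdio>
    #include <utility>

    template <std::size_t... I>
    static void run_fused_impl(std::index_sequence<I...>) {
        // here the real code forwards sizeof...(I) extra operand pointers to the kernel
        std::printf("launch kernel fused over %zu operands\n", sizeof...(I));
    }

    static void run_fused(int n_fuse) {
        switch (n_fuse) {
            case 2: run_fused_impl(std::make_index_sequence<2>{}); break;
            case 3: run_fused_impl(std::make_index_sequence<3>{}); break;
            case 4: run_fused_impl(std::make_index_sequence<4>{}); break;
            default: std::printf("unsupported n_fuse %d\n", n_fuse); break;
        }
    }

    int main() {
        run_fused(3); // prints: launch kernel fused over 3 operands
        return 0;
    }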
+
void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
diff --git a/ggml/src/ggml-cuda/binbcast.cuh b/ggml/src/ggml-cuda/binbcast.cuh
index 3ac1c9b03fcea..62bc950111b70 100644
--- a/ggml/src/ggml-cuda/binbcast.cuh
+++ b/ggml/src/ggml-cuda/binbcast.cuh
@@ -7,3 +7,5 @@ void ggml_cuda_op_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_fused_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst, int n_fuse);
diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index 767ad83f60eb5..85bc9e933bca5 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -107,9 +107,9 @@ constexpr bool ggml_cuda_has_arch(const int arch) {
return ggml_cuda_has_arch_impl(arch, __CUDA_ARCH_LIST__);
}
-constexpr int ggml_cuda_highest_compiled_arch_impl(const int arch, const int cur) {
+constexpr int ggml_cuda_highest_compiled_arch_impl(const int /*arch*/, const int cur) {
if (cur == 0) {
- GGML_ABORT("ggml was not compiled with any CUDA arch <= %d", arch);
+ return -1;
}
return cur;
}
@@ -420,16 +420,28 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
template
static __device__ __forceinline__ int warp_reduce_all(int x) {
-#ifdef GGML_USE_HIP
+ if (width == ggml_cuda_get_physical_warp_size()) {
+ return __all_sync(0xffffffff, x);
+ } else {
#pragma unroll
- for (int offset = width/2; offset > 0; offset >>= 1) {
- x = x && __shfl_xor_sync(0xffffffff, x, offset, width);
+ for (int offset = width/2; offset > 0; offset >>= 1) {
+ x = __shfl_xor_sync(0xffffffff, x, offset, width) && x;
+ }
+ return x;
+ }
+}
+
+template<int width = WARP_SIZE>
+static __device__ __forceinline__ int warp_reduce_any(int x) {
+ if (width == ggml_cuda_get_physical_warp_size()) {
+ return __any_sync(0xffffffff, x);
+ } else {
+#pragma unroll
+ for (int offset = width/2; offset > 0; offset >>= 1) {
+ x = __shfl_xor_sync(0xffffffff, x, offset, width) || x;
+ }
+ return x;
}
- return x;
-#else
- static_assert(width == WARP_SIZE, "width != WARP_SIZE not implemented");
- return __all_sync(0xffffffff, x);
-#endif // GGML_USE_HIP
}
template
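The new warp_reduce_any falls back to a butterfly of __shfl_xor_sync when the reduction width differs from the physical warp size. A self-contained CUDA sketch of that fallback; kernel and variable names are illustrative:

    // OR-reduction over a group of `width` lanes via a __shfl_xor_sync butterfly.
    #include <cstdio>
    #include <cuda_runtime.h>

    template <int width>
    static __device__ int warp_any_sketch(int x) {
    #pragma unroll
        for (int offset = width / 2; offset > 0; offset >>= 1) {
            // after log2(width) exchanges every lane holds the OR of its group
            x = __shfl_xor_sync(0xffffffff, x, offset, width) || x;
        }
        return x;
    }

    __global__ void warp_any_demo(int * out) {
        const int flag = (threadIdx.x == 5) ? 1 : 0; // only lane 5 raises the flag
        out[threadIdx.x] = warp_any_sketch<32>(flag);
    }

    int main() {
        int * d_out = nullptr;
        cudaMalloc(&d_out, 32 * sizeof(int));
        warp_any_demo<<<1, 32>>>(d_out);
        int h_out[32];
        cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
        std::printf("lane 0: %d, lane 31: %d\n", h_out[0], h_out[31]); // both print 1
        cudaFree(d_out);
        return 0;
    }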
diff --git a/ggml/src/ggml-cuda/conv2d.cu b/ggml/src/ggml-cuda/conv2d.cu
new file mode 100644
index 0000000000000..142dd66903aaa
--- /dev/null
+++ b/ggml/src/ggml-cuda/conv2d.cu
@@ -0,0 +1,166 @@
+#include "conv2d.cuh"
+#include "convert.cuh"
+
+struct conv_params {
+ const int64_t IW, IH;
+ const int64_t OW, OH;
+ const int64_t KW, KH;
+ const int64_t ST_X, ST_Y;
+ const int64_t PD_X, PD_Y;
+ const int64_t DL_X, DL_Y;
+ const int64_t IC, OC;
+ const int64_t B;
+ const int64_t TOTAL;
+};
+
+struct kernel_bounds {
+ int64_t y_min, y_max;
+ int64_t x_min, x_max;
+};
+
+__device__ __forceinline__ int64_t max64(int64_t a, int64_t b) {
+ return (a > b) ? a : b;
+}
+
+__device__ __forceinline__ int64_t min64(int64_t a, int64_t b) {
+ return (a < b) ? a : b;
+}
+
+__device__ __forceinline__ kernel_bounds calculate_kernel_bounds(int64_t out_x, int64_t out_y, const conv_params & P) {
+ kernel_bounds bounds;
+ bounds.y_min = max64(0, (P.PD_Y - out_y * P.ST_Y + P.DL_Y - 1) / P.DL_Y);
+ bounds.y_max = min64(P.KH, (P.IH + P.PD_Y - out_y * P.ST_Y + P.DL_Y - 1) / P.DL_Y);
+ bounds.x_min = max64(0, (P.PD_X - out_x * P.ST_X + P.DL_X - 1) / P.DL_X);
+ bounds.x_max = min64(P.KW, (P.IW + P.PD_X - out_x * P.ST_X + P.DL_X - 1) / P.DL_X);
+ return bounds;
+}
+
+__device__ __forceinline__ int calculate_input_coord(int64_t out_coord,
+ int64_t kern_coord,
+ int64_t stride,
+ int64_t dilation,
+ int64_t padding) {
+ return out_coord * stride + kern_coord * dilation - padding;
+}
+
+struct whcn_layout {
+ __device__ static int64_t input_index(int64_t n, int64_t c, int64_t y, int64_t x, const conv_params & P) {
+ return n * (P.IC * P.IW * P.IH) + c * P.IW * P.IH + y * P.IW + x;
+ }
+
+ __device__ static int64_t kernel_index(int64_t c_out, int64_t c_in, int64_t ky, int64_t kx, const conv_params & P) {
+ return c_out * (P.IC * P.KH * P.KW) + c_in * (P.KH * P.KW) + ky * P.KW + kx;
+ }
+
+ __device__ static int64_t output_index(int64_t n, int64_t c, int64_t y, int64_t x, const conv_params & P) {
+ return n * (P.OC * P.OW * P.OH) + c * P.OW * P.OH + y * P.OW + x;
+ }
+
+ __device__ static void unpack_indices(int64_t global_idx,
+ const conv_params & P,
+ int64_t & n,
+ int64_t & c,
+ int64_t & out_y,
+ int64_t & out_x) {
+ out_x = global_idx % P.OW;
+ out_y = (global_idx / P.OW) % P.OH;
+ c = (global_idx / (P.OW * P.OH)) % P.OC;
+ n = global_idx / (P.OW * P.OH * P.OC);
+ }
+};
+
+template <typename T, typename Layout>
+static __global__ void conv2d_kernel(const float * __restrict__ input,
+ const T * __restrict__ kernel,
+ float * __restrict__ output,
+ const conv_params P) {
+ const int64_t global_idx = blockIdx.x * blockDim.x + threadIdx.x;
+
+ if (global_idx >= P.TOTAL) {
+ return;
+ }
+
+ int64_t n, c_out, out_y, out_x;
+ Layout::unpack_indices(global_idx, P, n, c_out, out_y, out_x);
+
+ float acc = 0.0f;
+
+ for (int64_t c_in = 0; c_in < P.IC; ++c_in) {
+ kernel_bounds bounds = calculate_kernel_bounds(out_x, out_y, P);
+
+ for (int64_t ky = bounds.y_min; ky < bounds.y_max; ++ky) {
+ const int64_t in_y = calculate_input_coord(out_y, ky, P.ST_Y, P.DL_Y, P.PD_Y);
+
+ for (int64_t kx = bounds.x_min; kx < bounds.x_max; ++kx) {
+ const int64_t in_x = calculate_input_coord(out_x, kx, P.ST_X, P.DL_X, P.PD_X);
+
+ const float input_val = input[Layout::input_index(n, c_in, in_y, in_x, P)];
+ const T kernel_val = kernel[Layout::kernel_index(c_out, c_in, ky, kx, P)];
+                acc += (input_val * ggml_cuda_cast<float>(kernel_val));
+ }
+ }
+ }
+
+ // [N, OC, OH, OW]
+ output[Layout::output_index(n, c_out, out_y, out_x, P)] = acc;
+}
+
+template <typename T>
+static void conv2d_cuda(const float * X_D, const T * K_D, float * Y_D, const conv_params P, cudaStream_t st) {
+ const int blocks = (P.TOTAL + CUDA_CONV2D_BLOCK_SIZE - 1) / CUDA_CONV2D_BLOCK_SIZE;
+    conv2d_kernel<T, whcn_layout><<<blocks, CUDA_CONV2D_BLOCK_SIZE, 0, st>>>(X_D, K_D, Y_D, P);
+}
+
+static void conv2d_cuda_f16(const float * X_D, const half * K_D, float * Y_D, const conv_params P, cudaStream_t st) {
+    conv2d_cuda<half>(X_D, K_D, Y_D, P, st);
+}
+
+static void conv2d_cuda_f32(const float * X_D, const float * K_D, float * Y_D, const conv_params P, cudaStream_t st) {
+    conv2d_cuda<float>(X_D, K_D, Y_D, P, st);
+}
+
+void ggml_cuda_op_conv2d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * kernel = dst->src[0];
+ const ggml_tensor * input = dst->src[1];
+ float * K_D = (float *) kernel->data;
+ const float * X_D = (const float *) input->data;
+ float * Y_D = (float *) dst->data;
+
+ GGML_ASSERT(ggml_is_contiguous(kernel));
+ GGML_ASSERT(kernel->type == GGML_TYPE_F16 || kernel->type == GGML_TYPE_F32);
+
+ // same number of input channels
+ GGML_ASSERT(input->ne[2] == kernel->ne[2]);
+
+ cudaStream_t st = ctx.stream();
+
+ const int32_t * p = (const int32_t *) dst->op_params;
+ const int ST_X = p[0]; // stride_x
+ const int ST_Y = p[1]; // stride_y
+ const int PD_X = p[2]; // padding_x
+ const int PD_Y = p[3]; // padding_y
+ const int DL_X = p[4]; // dilation_x
+ const int DL_Y = p[5]; // dilation_y
+
+ // No cwhn
+ GGML_ASSERT(p[6] == false);
+
+ const int IW = input->ne[0]; // input_w
+ const int IH = input->ne[1]; // input_h
+ const int OW = dst->ne[0]; // output_w
+ const int OH = dst->ne[1]; // output_h
+ const int KW = kernel->ne[0]; // kernel_w
+ const int KH = kernel->ne[1]; // kernel_h
+    const int IC = input->ne[2]; // input_channels
+    const int OC = kernel->ne[3]; // output_channels
+ const int B = input->ne[3]; // n_batches
+
+ const int64_t total = B * OC * OH * OW;
+ conv_params params = { IW, IH, OW, OH, KW, KH, ST_X, ST_Y, PD_X, PD_Y, DL_X, DL_Y, IC, OC, B, total };
+
+ if (kernel->type == GGML_TYPE_F16) {
+ conv2d_cuda_f16(X_D, (half *) K_D, Y_D, params, st);
+ } else {
+ conv2d_cuda_f32(X_D, K_D, Y_D, params, st);
+ }
+}
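The conv2d kernel above assumes the usual relation between input and output extents for a strided, padded, dilated convolution. A small host-side sketch of that shape formula; conv_out_size is an illustrative helper, not a ggml function:

    // OW = (IW + 2*PD - DL*(KW - 1) - 1) / ST + 1, and likewise for the height.
    #include <cstdint>
    #include <cstdio>

    static int64_t conv_out_size(int64_t in, int64_t kernel, int64_t stride, int64_t pad, int64_t dil) {
        // effective kernel extent is dil*(kernel - 1) + 1
        return (in + 2 * pad - dil * (kernel - 1) - 1) / stride + 1;
    }

    int main() {
        // e.g. a 224-wide input with a 3-wide kernel, stride 2, padding 1, dilation 1 -> 112
        std::printf("OW = %lld\n", (long long) conv_out_size(224, 3, 2, 1, 1));
        return 0;
    }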
diff --git a/ggml/src/ggml-cuda/conv2d.cuh b/ggml/src/ggml-cuda/conv2d.cuh
new file mode 100644
index 0000000000000..ce4802c7ed797
--- /dev/null
+++ b/ggml/src/ggml-cuda/conv2d.cuh
@@ -0,0 +1,5 @@
+#pragma once
+#include "common.cuh"
+
+#define CUDA_CONV2D_BLOCK_SIZE 256
+void ggml_cuda_op_conv2d(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index aa45ab39ed89e..e06f95f0819ed 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -12,6 +12,7 @@
#include "ggml-cuda/clamp.cuh"
#include "ggml-cuda/concat.cuh"
#include "ggml-cuda/conv-transpose-1d.cuh"
+#include "ggml-cuda/conv2d.cuh"
#include "ggml-cuda/conv2d-dw.cuh"
#include "ggml-cuda/conv2d-transpose.cuh"
#include "ggml-cuda/convert.cuh"
@@ -204,6 +205,8 @@ static ggml_cuda_device_info ggml_cuda_init() {
GGML_LOG_INFO("%s: GGML_CUDA_FORCE_CUBLAS: no\n", __func__);
#endif // GGML_CUDA_FORCE_CUBLAS
GGML_LOG_INFO("%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, info.device_count);
+
+    std::vector<std::pair<int, std::string>> turing_devices_without_mma;
for (int id = 0; id < info.device_count; ++id) {
int device_vmm = 0;
@@ -261,7 +264,25 @@ static ggml_cuda_device_info ggml_cuda_init() {
info.devices[id].cc = 100*prop.major + 10*prop.minor;
GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s\n",
id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no");
-#endif // defined(GGML_USE_HIP)
+ std::string device_name(prop.name);
+ if (device_name == "NVIDIA GeForce MX450") {
+ turing_devices_without_mma.push_back({ id, device_name });
+ } else if (device_name == "NVIDIA GeForce MX550") {
+ turing_devices_without_mma.push_back({ id, device_name });
+ } else if (device_name.substr(0, 21) == "NVIDIA GeForce GTX 16") {
+ turing_devices_without_mma.push_back({ id, device_name });
+ }
+#endif // defined(GGML_USE_HIP)
+ }
+
+ if (ggml_cuda_highest_compiled_arch(GGML_CUDA_CC_TURING) >= GGML_CUDA_CC_TURING && !turing_devices_without_mma.empty()) {
+ GGML_LOG_INFO("The following devices will have suboptimal performance due to a lack of tensor cores:\n");
+ for (size_t device_pos = 0; device_pos < turing_devices_without_mma.size(); device_pos++) {
+ GGML_LOG_INFO(
+ " Device %d: %s\n", turing_devices_without_mma[device_pos].first, turing_devices_without_mma[device_pos].second.c_str());
+ }
+ GGML_LOG_INFO(
+ "Consider compiling with CMAKE_CUDA_ARCHITECTURES=61-virtual;80-virtual and DGGML_CUDA_FORCE_MMQ to force the use of the Pascal code for Turing.\n");
}
for (int id = 0; id < info.device_count; ++id) {
@@ -2431,6 +2452,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_IM2COL:
ggml_cuda_op_im2col(ctx, dst);
break;
+ case GGML_OP_CONV_2D:
+ ggml_cuda_op_conv2d(ctx, dst);
+ break;
case GGML_OP_CONV_2D_DW:
ggml_cuda_op_conv2d_dw(ctx, dst);
break;
@@ -2797,9 +2821,14 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
return false;
}
- if (ops.size() == 2 && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {
+ if ((ops.size() == 2 || ops.size() == 3) && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {
const ggml_tensor *rms_norm = cgraph->nodes[node_idx];
const ggml_tensor *mul = cgraph->nodes[node_idx+1];
+ const ggml_tensor *add = nullptr;
+
+ if (ops.size() == 3 && ops.begin()[2] == GGML_OP_ADD) {
+ add = cgraph->nodes[node_idx+2];
+ }
GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT(rms_norm->type == GGML_TYPE_F32);
@@ -2811,6 +2840,12 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
return false;
}
+ if (add && (add->src[0]->type != GGML_TYPE_F32 ||
+ add->src[1]->type != GGML_TYPE_F32 ||
+ add->type != GGML_TYPE_F32) ) {
+ return false;
+ }
+
//if rms norm is the B operand, then we don't handle broadcast
if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm->src[1])) {
return false;
@@ -2821,6 +2856,10 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
return false;
}
+ if (add && (!ggml_is_contiguous(add->src[0]) || !ggml_is_contiguous_rows(add->src[1]))) {
+ return false;
+ }
+
return true;
}
@@ -2867,7 +2906,46 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr);
if (!disable_fusion) {
- if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL }, {})) {
+
+ if (node->op == GGML_OP_ADD) {
+ int n_fuse = 0;
+ ggml_op ops[8];
+ std::fill(ops, ops + 8, GGML_OP_ADD);
+
+ for (; n_fuse <= 6; ++n_fuse){
+ if (!ggml_can_fuse(cgraph, i + n_fuse, ops + n_fuse, 2)) {
+ break;
+ }
+ if (cgraph->nodes[i + n_fuse] != cgraph->nodes[i + n_fuse + 1]->src[0]) {
+ break;
+ }
+ if (!ggml_are_same_layout(cgraph->nodes[i + n_fuse]->src[1], cgraph->nodes[i + n_fuse + 1]->src[1])) {
+ break;
+ }
+ }
+
+ n_fuse++;
+
+ if (n_fuse > 1) {
+ for (int j = 0; j < n_fuse - 1; ++j) {
+ node->src[j + 2] = cgraph->nodes[i + j + 1]->src[1];
+ }
+ cgraph->nodes[i + n_fuse - 1]->data = node->data;
+ ggml_cuda_op_fused_add(*cuda_ctx, node, n_fuse);
+ i += n_fuse - 1;
+
+ continue;
+ }
+ }
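The loop above greedily extends a chain of GGML_OP_ADD nodes as long as each node feeds src[0] of the next and the src[1] operands share a layout. A simplified host-side sketch of that detection; node_t and same_layout are illustrative stand-ins, and the ggml_can_fuse checks (for example, that the intermediate results have no other users) are omitted:

    // Count how many consecutive add nodes starting at index i can collapse into one launch.
    struct node_t {
        int            op;   // stand-in for ggml_op
        const node_t * src0;
        const node_t * src1;
    };

    static bool same_layout(const node_t * a, const node_t * b) {
        (void) a; (void) b;
        return true;         // placeholder for ggml_are_same_layout
    }

    static int count_fusable_adds(const node_t * nodes, int n_nodes, int i, int op_add) {
        int n_fuse = 1;      // the node at i itself
        while (n_fuse < 8 && i + n_fuse < n_nodes) {
            const node_t * cur  = &nodes[i + n_fuse - 1];
            const node_t * next = &nodes[i + n_fuse];
            if (next->op != op_add || next->src0 != cur || !same_layout(cur->src1, next->src1)) {
                break;
            }
            n_fuse++;
        }
        return n_fuse;       // 1 means no fusion, 2..8 means that many adds fused into one launch
    }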
+
+
+ if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ADD}, {})) {
+ ggml_cuda_op_rms_norm_fused_add(*cuda_ctx, node, cgraph->nodes[i+1], cgraph->nodes[i+2]);
+ i += 2;
+ continue;
+ }
+
+ if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL}, {})) {
ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i+1]);
i++;
continue;
@@ -3086,7 +3164,7 @@ bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size) {
return false;
}
-#if CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA)
+#if CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA) || defined(GGML_USE_HIP)
cudaError_t err = cudaHostRegister(buffer, size, cudaHostRegisterPortable | cudaHostRegisterReadOnly);
if (err != cudaSuccess) {
// clear the error
@@ -3481,6 +3559,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
return op->src[0]->nb[0] == ggml_type_size(op->src[0]->type) && ggml_is_contiguous_2(op->src[0]);
}
case GGML_OP_IM2COL:
+ case GGML_OP_CONV_2D:
case GGML_OP_CONV_2D_DW:
case GGML_OP_CONV_TRANSPOSE_2D:
case GGML_OP_POOL_2D:
diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu
index 576032a0ce0dd..714b23f9f49aa 100644
--- a/ggml/src/ggml-cuda/mmq.cu
+++ b/ggml/src/ggml-cuda/mmq.cu
@@ -3,6 +3,140 @@
#include
+// To reduce shared memory use, store "it" and "iex_used" with 22 and 10 bits, respectively.
+struct mmq_ids_helper_store {
+ uint32_t data;
+
+ __device__ mmq_ids_helper_store(const uint32_t it, const uint32_t iex_used) {
+ data = (it & 0x003FFFFF) | (iex_used << 22);
+ }
+
+ __device__ uint32_t it() const {
+ return data & 0x003FFFFF;
+ }
+
+ __device__ uint32_t iex_used() const {
+ return data >> 22;
+ }
+};
+static_assert(sizeof(mmq_ids_helper_store) == 4, "unexpected size for mmq_ids_helper_store");
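mmq_ids_helper_store packs the token index and the expert slot into a single 32-bit word. A small host-side sketch showing that the 22/10 split round-trips within the bounds the launcher further down asserts:

    // pack/unpack round-trip for it < 2^22 and iex_used < 2^10.
    #include <cassert>
    #include <cstdint>

    static uint32_t pack_it_iex(uint32_t it, uint32_t iex_used) {
        return (it & 0x003FFFFF) | (iex_used << 22);
    }

    static uint32_t unpack_it(uint32_t data)  { return data & 0x003FFFFF; }
    static uint32_t unpack_iex(uint32_t data) { return data >> 22; }

    int main() {
        const uint32_t it = 123456, iex = 7;       // it < 4194304, iex < 1024
        const uint32_t d  = pack_it_iex(it, iex);
        assert(unpack_it(d) == it && unpack_iex(d) == iex);
        return 0;
    }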
+
+// Helper function for mul_mat_id, converts ids to a more convenient format.
+// ids_src1 describes how to permute the flattened column indices of src1 in order to get a compact src1 tensor sorted by expert.
+// ids_dst describes the same mapping but for the dst tensor.
+// The upper and lower bounds for the ith expert in the compact src1 tensor are stored in expert_bounds[i:i+1].
+template <int n_expert_used_template>
+__launch_bounds__(ggml_cuda_get_physical_warp_size(), 1)
+static __global__ void mmq_ids_helper(
+ const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
+ const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1) {
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+ const int n_expert_used = n_expert_used_template == 0 ? n_expert_used_var : n_expert_used_template;
+ const int expert = blockIdx.x;
+
+ extern __shared__ char data_mmq_ids_helper[];
+ mmq_ids_helper_store * store = (mmq_ids_helper_store *) data_mmq_ids_helper;
+
+ int nex_prev = 0; // Number of columns for experts with a lower index.
+ int it_compact = 0; // Running index for the compact slice of this expert.
+
+ if constexpr (n_expert_used_template == 0) {
+ // Generic implementation:
+ for (int it = 0; it < n_tokens; ++it) {
+ int iex_used = -1; // The index at which the expert is used, if any.
+ for (int iex = threadIdx.x; iex < n_expert_used; iex += warp_size) {
+ const int expert_used = ids[it*si1 + iex];
+ nex_prev += expert_used < expert;
+ if (expert_used == expert) {
+ iex_used = iex;
+ }
+ }
+
+ if (iex_used != -1) {
+ store[it_compact] = mmq_ids_helper_store(it, iex_used);
+ }
+
+ if (warp_reduce_any(iex_used != -1)) {
+ it_compact++;
+ }
+ }
+ } else {
+ // Implementation optimized for specific numbers of experts used:
+        static_assert(n_expert_used_template == 6 || warp_size % n_expert_used_template == 0, "bad n_expert_used");
+ const int neu_padded = n_expert_used == 6 ? 8 : n_expert_used; // Padded to next higher power of 2.
+ for (int it0 = 0; it0 < n_tokens; it0 += warp_size/neu_padded) {
+ const int it = it0 + threadIdx.x / neu_padded;
+
+ const int iex = threadIdx.x % neu_padded; // The index at which the expert is used, if any.
+ const int expert_used = (neu_padded == n_expert_used || iex < n_expert_used) && it < n_tokens ?
+ ids[it*si1 + iex] : INT_MAX;
+ const int iex_used = expert_used == expert ? iex : -1;
+ nex_prev += expert_used < expert;
+
+ // Whether the threads at this token position have used the expert:
+ const int it_compact_add_self = warp_reduce_any(iex_used != -1);
+
+ // Do a scan over threads at lower token positions in warp to get the correct index for writing data:
+ int it_compact_add_lower = 0;
+#pragma unroll
+ for (int offset = neu_padded; offset < warp_size; offset += neu_padded) {
+ const int tmp = __shfl_up_sync(0xFFFFFFFF, it_compact_add_self, offset, warp_size);
+ if (threadIdx.x >= offset) {
+ it_compact_add_lower += tmp;
+ }
+ }
+
+ if (iex_used != -1) {
+ store[it_compact + it_compact_add_lower] = mmq_ids_helper_store(it, iex_used);
+ }
+
+ // The thread with the highest index in the warp always has the sum over the whole warp, use it to increment all threads:
+ it_compact += __shfl_sync(0xFFFFFFFF, it_compact_add_lower + it_compact_add_self, warp_size - 1, warp_size);
+ }
+ }
+ nex_prev = warp_reduce_sum(nex_prev);
+
+ for (int itc = threadIdx.x; itc < it_compact; itc += warp_size) {
+ const mmq_ids_helper_store store_it = store[itc];
+ const int it = store_it.it();
+ const int iex_used = store_it.iex_used();
+ ids_src1[nex_prev + itc] = it*sis1 + iex_used % nchannels_y;
+ ids_dst [nex_prev + itc] = it*n_expert_used + iex_used;
+ }
+
+ if (threadIdx.x != 0) {
+ return;
+ }
+
+ expert_bounds[expert] = nex_prev;
+
+ if (expert < gridDim.x - 1) {
+ return;
+ }
+
+ expert_bounds[gridDim.x] = nex_prev + it_compact;
+}
+
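The optimized branch above derives each thread's write offset from a scan of __shfl_up_sync results. A standalone CUDA sketch of the same idea with a stride of one lane instead of neu_padded; names are illustrative:

    // Each lane sums the flags of all lower lanes: an exclusive prefix sum via shuffles.
    #include <cstdio>
    #include <cuda_runtime.h>

    __global__ void exclusive_scan_demo(const int * flags, int * offsets) {
        const int lane = threadIdx.x;
        const int self = flags[lane];

        int lower = 0; // sum of `self` over all lanes with a smaller index
    #pragma unroll
        for (int offset = 1; offset < 32; ++offset) {
            const int tmp = __shfl_up_sync(0xffffffff, self, offset, 32);
            if (lane >= offset) {
                lower += tmp;
            }
        }
        offsets[lane] = lower; // exclusive prefix sum of the flags
    }

    int main() {
        int h_flags[32], h_off[32];
        for (int i = 0; i < 32; ++i) h_flags[i] = i % 2;   // 0,1,0,1,...
        int *d_flags, *d_off;
        cudaMalloc(&d_flags, sizeof(h_flags));
        cudaMalloc(&d_off, sizeof(h_off));
        cudaMemcpy(d_flags, h_flags, sizeof(h_flags), cudaMemcpyHostToDevice);
        exclusive_scan_demo<<<1, 32>>>(d_flags, d_off);
        cudaMemcpy(h_off, d_off, sizeof(h_off), cudaMemcpyDeviceToHost);
        std::printf("offset of lane 31: %d\n", h_off[31]); // 15 set flags below lane 31
        cudaFree(d_flags);
        cudaFree(d_off);
        return 0;
    }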
+template <int n_expert_used_template>
+static void launch_mmq_ids_helper(
+ const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
+ const int n_experts, const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) {
+ GGML_ASSERT(n_tokens < (1 << 22) && "too few bits in mmq_ids_helper_store");
+ GGML_ASSERT(n_expert_used_var < (1 << 10) && "too few bits in mmq_ids_helper_store");
+
+ const int id = ggml_cuda_get_device();
+ const int warp_size = ggml_cuda_info().devices[id].warp_size;
+ const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
+    CUDA_SET_SHARED_MEMORY_LIMIT(mmq_ids_helper<n_expert_used_template>, smpbo);
+
+ const dim3 num_blocks(n_experts, 1, 1);
+ const dim3 block_size(warp_size, 1, 1);
+ const size_t nbytes_shared = n_tokens*sizeof(mmq_ids_helper_store);
+ GGML_ASSERT(nbytes_shared <= smpbo);
+    mmq_ids_helper<n_expert_used_template><<<num_blocks, block_size, nbytes_shared, stream>>>
+ (ids, ids_src1, ids_dst, expert_bounds, n_tokens, n_expert_used_var, nchannels_y, si1, sis1);
+}
+
static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
switch (args.type_x) {
case GGML_TYPE_Q4_0:
@@ -137,7 +271,7 @@ void ggml_cuda_mul_mat_q(
ne00, ne01, ne1, s01, ne11, s1,
ne02, ne12, s02, s12, s2,
ne03, ne13, s03, s13, s3,
- use_stream_k};
+ use_stream_k, ne1};
ggml_cuda_mul_mat_q_switch_type(ctx, args, stream);
return;
}
@@ -148,53 +282,49 @@ void ggml_cuda_mul_mat_q(
const int64_t n_expert_used = ids->ne[0];
const int64_t ne_get_rows = ne12 * n_expert_used;
+ GGML_ASSERT(ne1 == n_expert_used);
-    std::vector<char> ids_host(ggml_nbytes(ids));
-    std::vector<int32_t> ids_src1_host;
-    ids_src1_host.reserve(ne_get_rows);
-    std::vector<int32_t> ids_dst_host;
-    ids_dst_host.reserve(ne_get_rows);
-    std::vector<int32_t> tokens_per_expert_host(ne02);
-    std::vector<int32_t> expert_bounds_host(ne02 + 1);
-    ggml_cuda_pool_alloc<int32_t> ids_buf_dev(ctx.pool());
-
- CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids->data, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
- CUDA_CHECK(cudaStreamSynchronize(stream));
-
- for (int64_t i02 = 0; i02 < ne02; ++i02) { // expert matrices
- for (int64_t i12 = 0; i12 < ne12; ++i12) { // tokens
- for (int64_t iex = 0; iex < n_expert_used; ++iex) {
- const int32_t expert_to_use = *(const int32_t *)(ids_host.data() + i12*ids->nb[1] + iex*ids->nb[0]);
- assert(expert_to_use >= 0 && expert_to_use < ne02);
- if (expert_to_use == i02) {
- ids_src1_host.push_back(i12*(nb12/nb11) + iex % ne11);
- ids_dst_host.push_back(i12*ne1 + iex);
- tokens_per_expert_host[i02]++;
- break;
- }
- }
- }
- }
+    ggml_cuda_pool_alloc<int32_t> ids_src1(ctx.pool(), ne_get_rows);
+    ggml_cuda_pool_alloc<int32_t> ids_dst(ctx.pool(), ne_get_rows);
+    ggml_cuda_pool_alloc<int32_t> expert_bounds(ctx.pool(), ne02 + 1);
- int32_t cumsum = 0;
- for (int64_t i = 0; i < ne02; ++i) {
- expert_bounds_host[i] = cumsum;
- cumsum += tokens_per_expert_host[i];
+ {
+ GGML_ASSERT(ids->nb[0] == ggml_element_size(ids));
+ const int si1 = ids->nb[1] / ggml_element_size(ids);
+ const int sis1 = nb12 / nb11;
+
+ switch (n_expert_used) {
+ case 2:
+ launch_mmq_ids_helper< 2> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
+ ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
+ break;
+ case 4:
+ launch_mmq_ids_helper< 4> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
+ ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
+ break;
+ case 6:
+ launch_mmq_ids_helper< 6> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
+ ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
+ break;
+ case 8:
+ launch_mmq_ids_helper< 8> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
+ ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
+ break;
+ case 16:
+ launch_mmq_ids_helper<16> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
+ ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
+ break;
+ case 32:
+ launch_mmq_ids_helper<32> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
+ ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
+ break;
+ default:
+ launch_mmq_ids_helper< 0> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
+ ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
+ break;
+ }
+ CUDA_CHECK(cudaGetLastError());
}
- expert_bounds_host[ne02] = cumsum;
-
-    std::vector<int32_t> ids_buf_host;
- ids_buf_host.reserve(ids_src1_host.size() + ids_dst_host.size() + expert_bounds_host.size());
- ids_buf_host.insert(ids_buf_host.end(), ids_src1_host.begin(), ids_src1_host.end());
- ids_buf_host.insert(ids_buf_host.end(), ids_dst_host.begin(), ids_dst_host.end());
- ids_buf_host.insert(ids_buf_host.end(), expert_bounds_host.begin(), expert_bounds_host.end());
- ids_buf_dev.alloc(ids_buf_host.size() + get_mmq_x_max_host(cc)); // Expert bounds are padded on device.
- CUDA_CHECK(cudaMemcpyAsync(ids_buf_dev.ptr, ids_buf_host.data(), ids_buf_host.size()*sizeof(int32_t), cudaMemcpyHostToDevice, stream));
- CUDA_CHECK(cudaStreamSynchronize(stream));
-
- const int32_t * ids_src1_dev = ids_buf_dev.ptr;
- const int32_t * ids_dst_dev = ids_src1_dev + ids_src1_host.size();
- const int32_t * expert_bounds_dev = ids_dst_dev + ids_dst_host.size();
const size_t nbytes_src1_q8_1 = ne12*n_expert_used*ne10_padded * sizeof(block_q8_1)/QK8_1 +
get_mmq_x_max_host(cc)*sizeof(block_q8_1_mmq);
@@ -208,7 +338,7 @@ void ggml_cuda_mul_mat_q(
const int64_t s11 = src1->nb[1] / ts_src1;
const int64_t s12 = src1->nb[2] / ts_src1;
const int64_t s13 = src1->nb[2] / ts_src1;
- quantize_mmq_q8_1_cuda(src1_d, ids_src1_dev, src1_q8_1.get(), src0->type,
+ quantize_mmq_q8_1_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type,
ne10, s11, s12, s13, ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);
CUDA_CHECK(cudaGetLastError());
}
@@ -218,11 +348,11 @@ void ggml_cuda_mul_mat_q(
// Note that ne02 is used instead of ne12 because the number of y channels determines the z dimension of the CUDA grid.
const mmq_args args = {
- src0_d, src0->type, (const int *) src1_q8_1.ptr, ids_dst_dev, expert_bounds_dev, dst_d,
+ src0_d, src0->type, (const int *) src1_q8_1.get(), ids_dst.get(), expert_bounds.get(), dst_d,
ne00, ne01, ne_get_rows, s01, ne_get_rows, s1,
ne02, ne02, s02, s12, s2,
ne03, ne13, s03, s13, s3,
- use_stream_k};
+ use_stream_k, ne12};
ggml_cuda_mul_mat_q_switch_type(ctx, args, stream);
}
@@ -262,7 +392,7 @@ void ggml_cuda_op_mul_mat_q(
ne00, row_diff, src1_ncols, stride01, ne11, nrows_dst,
1, 1, 0, 0, 0,
1, 1, 0, 0, 0,
- use_stream_k};
+ use_stream_k, src1_ncols};
ggml_cuda_mul_mat_q_switch_type(ctx, args, stream);
diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh
index 650f7080677ad..c9a07e82fedf2 100644
--- a/ggml/src/ggml-cuda/mmq.cuh
+++ b/ggml/src/ggml-cuda/mmq.cuh
@@ -3138,7 +3138,8 @@ static __global__ void mul_mat_q(
const int32_t * __restrict__ expert_bounds, float * __restrict__ dst, float * __restrict__ tmp_fixup,
const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_row_x, const int ncols_y, const int stride_col_dst,
const int channel_ratio, const int nchannels_y, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
- const int sample_ratio, const int nsamples_y, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
+ const int sample_ratio, const int nsamples_y, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
+ const int ncols_max) {
// Skip unused template specializations for faster compilation:
if (mmq_x > get_mmq_x_max_device() || mmq_x % mmq_get_granularity_device(mmq_x) != 0) {
@@ -3152,7 +3153,7 @@ static __global__ void mul_mat_q(
constexpr int qk = ggml_cuda_type_traits::qk;
constexpr int mmq_y = get_mmq_y_device();
- const int ntx = (ncols_dst + mmq_x - 1) / mmq_x; // Number of tiles x
+ const int ntx = (ncols_max + mmq_x - 1) / mmq_x; // Number of tiles x
const int nty = (nrows_x + mmq_y - 1) / mmq_y; // Number of tiles y
// Initialize the ids for writing back data with just the index.
@@ -3376,7 +3377,8 @@ template
static __global__ void mul_mat_q_stream_k_fixup(
const int32_t * ids_dst, const int32_t * expert_bounds, float * __restrict__ dst, const float * __restrict__ tmp_last_tile,
const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_col_dst,
- const int nchannels_y, const int stride_channel_dst, const int nsamples_y, const int stride_sample_dst) {
+ const int nchannels_y, const int stride_channel_dst, const int nsamples_y, const int stride_sample_dst,
+ const int ncols_max) {
constexpr int mmq_y = get_mmq_y_device();
constexpr int qk = ggml_cuda_type_traits::qk;
constexpr int blocks_per_iter = MMQ_ITER_K / qk;
@@ -3387,7 +3389,7 @@ static __global__ void mul_mat_q_stream_k_fixup(
float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f};
- const int ntx = (ncols_dst + mmq_x - 1) / mmq_x;
+ const int ntx = (ncols_max + mmq_x - 1) / mmq_x;
const int nty = (nrows_x + mmq_y - 1) / mmq_y;
const int bidx0 = blockIdx.x;
@@ -3528,7 +3530,7 @@ struct mmq_args {
int64_t ncols_x; int64_t nrows_x; int64_t ncols_dst; int64_t stride_row_x; int64_t ncols_y; int64_t nrows_dst;
int64_t nchannels_x; int64_t nchannels_y; int64_t stride_channel_x; int64_t stride_channel_y; int64_t stride_channel_dst;
int64_t nsamples_x; int64_t nsamples_y; int64_t stride_sample_x; int64_t stride_sample_y; int64_t stride_sample_dst;
- bool use_stream_k;
+ bool use_stream_k; int64_t ncols_max;
};
template
@@ -3558,7 +3560,7 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q), nbytes_shared);
const int nty = (args.nrows_x + mmq_y - 1) / mmq_y;
- const int ntx = (args.ncols_dst + mmq_x - 1) / mmq_x;
+ const int ntx = (args.ncols_max + mmq_x - 1) / mmq_x;
const int ntzw = args.nchannels_y * args.nsamples_y;
const dim3 block_nums_xy_tiling(nty, ntx, ntzw);
@@ -3574,14 +3576,16 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
(args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
- sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
+ sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
+ args.ncols_max);
} else {
constexpr bool need_check = true;
mul_mat_q<<>>
(args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
- sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
+ sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
+ args.ncols_max);
}
return;
}
@@ -3601,7 +3605,8 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
(args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
- sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
+ sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
+ args.ncols_max);
if (!fixup_needed) {
return;
@@ -3609,14 +3614,16 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
mul_mat_q_stream_k_fixup<<>>
(args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst,
- args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst);
+ args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst,
+ args.ncols_max);
} else {
constexpr bool need_check = true;
mul_mat_q<<>>
(args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
- sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
+ sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
+ args.ncols_max);
if (!fixup_needed) {
return;
@@ -3624,7 +3631,8 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
mul_mat_q_stream_k_fixup<<>>
(args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst,
- args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst);
+ args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst,
+ args.ncols_max);
}
}
@@ -3649,7 +3657,7 @@ void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cuda
continue;
}
- const int ntiles_x = (args.ncols_y + mmq_x - 1) / mmq_x;
+ const int ntiles_x = (args.ncols_max + mmq_x - 1) / mmq_x;
if (ntiles_x < ntiles_x_best) {
mmq_x_best = mmq_x;
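With ncols_max threaded through, the tile heuristic in mul_mat_q_case picks the mmq_x that minimizes the number of x tiles for the widest slice. A host-side sketch of that criterion; candidate values and names are illustrative, and the real selection also honors granularity constraints:

    // Choose the mmq_x candidate with the fewest x tiles for ncols_max columns.
    #include <cstdint>
    #include <cstdio>

    static int64_t ceil_div(int64_t a, int64_t b) {
        return (a + b - 1) / b;
    }

    int main() {
        const int64_t ncols_max    = 300;  // widest slice of the batch
        const int     candidates[] = { 8, 16, 32, 64, 128 };

        int64_t best_ntiles = INT64_MAX;
        int     best_mmq_x  = 0;
        for (int mmq_x : candidates) {
            const int64_t ntiles_x = ceil_div(ncols_max, mmq_x);
            if (ntiles_x < best_ntiles) {
                best_ntiles = ntiles_x;
                best_mmq_x  = mmq_x;
            }
        }
        std::printf("mmq_x = %d -> %lld x tiles\n", best_mmq_x, (long long) best_ntiles);
        return 0;
    }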
diff --git a/ggml/src/ggml-cuda/norm.cu b/ggml/src/ggml-cuda/norm.cu
index bddcca51b7bfc..d5157d958b717 100644
--- a/ggml/src/ggml-cuda/norm.cu
+++ b/ggml/src/ggml-cuda/norm.cu
@@ -104,12 +104,30 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
}
}
-template <int block_size, bool do_multiply = false>
-static __global__ void rms_norm_f32(
- const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
- const int64_t stride_sample, const float eps, const float * mul = nullptr, const int64_t mul_stride_row = 0,
- const int64_t mul_stride_channel = 0, const int64_t mul_stride_sample = 0, const int mul_ncols = 0,
- const int mul_nrows = 0, const int mul_nchannels = 0, const int mul_nsamples = 0) {
+template <int block_size, bool do_multiply = false, bool do_add = false>
+static __global__ void rms_norm_f32(const float * x, float * dst,
+ const int ncols,
+ const int64_t stride_row,
+ const int64_t stride_channel,
+ const int64_t stride_sample,
+ const float eps,
+ const float * mul = nullptr,
+ const int64_t mul_stride_row = 0,
+ const int64_t mul_stride_channel = 0,
+ const int64_t mul_stride_sample = 0,
+ const int mul_ncols = 0,
+ const int mul_nrows = 0,
+ const int mul_nchannels = 0,
+ const int mul_nsamples = 0,
+ const float * add = nullptr,
+ const int64_t add_stride_row = 0,
+ const int64_t add_stride_channel = 0,
+ const int64_t add_stride_sample = 0,
+ const int add_ncols = 0,
+ const int add_nrows = 0,
+ const int add_nchannels = 0,
+ const int add_nsamples = 0) {
+
const int nrows = gridDim.x;
const int nchannels = gridDim.y;
@@ -118,6 +136,8 @@ static __global__ void rms_norm_f32(
const int sample = blockIdx.z;
const int tid = threadIdx.x;
+ static_assert(!do_add || do_multiply, "fusing add is not supported without multiplying");
+
x += sample*stride_sample + channel*stride_channel + row*stride_row;
dst += ((sample*nchannels + channel)*nrows + row)*ncols;
@@ -128,6 +148,13 @@ static __global__ void rms_norm_f32(
mul += mul_sample*mul_stride_sample + mul_channel*mul_stride_channel + mul_row*mul_stride_row;
}
+ if constexpr (do_add) {
+ const int add_row = row % add_nrows;
+ const int add_channel = channel % add_nchannels;
+ const int add_sample = sample % add_nsamples;
+ add += add_sample * add_stride_sample + add_channel * add_stride_channel + add_row * add_stride_row;
+ }
+
float tmp = 0.0f; // partial sum for thread in warp
for (int col = tid; col < ncols; col += block_size) {
@@ -154,7 +181,11 @@ static __global__ void rms_norm_f32(
const float scale = rsqrtf(mean + eps);
for (int col = tid; col < ncols; col += block_size) {
- if constexpr (do_multiply) {
+ if constexpr (do_multiply && do_add) {
+ const int mul_col = col % mul_ncols;
+ const int add_col = col % add_ncols;
+ dst[col] = scale * x[col] * mul[mul_col] + add[add_col];
+ } else if constexpr (do_multiply) {
const int mul_col = col % mul_ncols;
dst[col] = scale * x[col] * mul[mul_col];
} else {
@@ -331,23 +362,70 @@ static void rms_norm_f32_cuda(
}
}
-static void rms_norm_mul_f32_cuda(
- const float * x, const float * mul, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
- const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample,
- const int64_t mul_stride_row, const int64_t mul_stride_channel, const int64_t mul_stride_sample,
- const int mul_ncols, const int mul_nrows, const int mul_nchannels, const int mul_nsamples,
- const float eps, cudaStream_t stream) {
+static void rms_norm_mul_f32_cuda(const float * x,
+ const float * mul,
+ const float * add,
+ float * dst,
+ const int ncols,
+ const int nrows,
+ const int nchannels,
+ const int nsamples,
+ const int64_t stride_row,
+ const int64_t stride_channel,
+ const int64_t stride_sample,
+ const int64_t mul_stride_row,
+ const int64_t mul_stride_channel,
+ const int64_t mul_stride_sample,
+ const int mul_ncols,
+ const int mul_nrows,
+ const int mul_nchannels,
+ const int mul_nsamples,
+ const int64_t add_stride_row,
+ const int64_t add_stride_channel,
+ const int64_t add_stride_sample,
+ const int add_ncols,
+ const int add_nrows,
+ const int add_nchannels,
+ const int add_nsamples,
+ const float eps,
+ cudaStream_t stream) {
const dim3 blocks_num(nrows, nchannels, nsamples);
if (mul == nullptr) {
rms_norm_f32_cuda(x, dst, ncols, nrows, nchannels, nsamples, stride_row, stride_channel, stride_sample, eps, stream);
return;
}
- if (ncols < 1024) {
- const dim3 block_dims(WARP_SIZE, 1, 1);
- rms_norm_f32<<>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, mul_stride_sample, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples);
+ if (add == nullptr) {
+ if (ncols < 1024) {
+ const dim3 block_dims(WARP_SIZE, 1, 1);
+            rms_norm_f32<WARP_SIZE, true><<<blocks_num, block_dims, 0, stream>>>(x, dst,
+ ncols, stride_row, stride_channel, stride_sample, eps,
+ mul, mul_stride_row, mul_stride_channel, mul_stride_sample,
+ mul_ncols, mul_nrows, mul_nchannels, mul_nsamples);
+ } else {
+ const dim3 block_dims(1024, 1, 1);
+ rms_norm_f32<1024, true><<<blocks_num, block_dims, 0, stream>>>(x, dst,
+ ncols, stride_row, stride_channel, stride_sample, eps,
+ mul, mul_stride_row, mul_stride_channel, mul_stride_sample,
+ mul_ncols, mul_nrows, mul_nchannels, mul_nsamples);
+ }
} else {
- const dim3 block_dims(1024, 1, 1);
- rms_norm_f32<1024, true><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, mul_stride_sample, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples);
+ if (ncols < 1024) {
+ const dim3 block_dims(WARP_SIZE, 1, 1);
+ rms_norm_f32<WARP_SIZE, true, true><<<blocks_num, block_dims, 0, stream>>>(x, dst,
+ ncols, stride_row, stride_channel, stride_sample, eps,
+ mul, mul_stride_row, mul_stride_channel, mul_stride_sample,
+ mul_ncols, mul_nrows, mul_nchannels, mul_nsamples,
+ add, add_stride_row, add_stride_channel, add_stride_sample,
+ add_ncols, add_nrows, add_nchannels, add_nsamples);
+ } else {
+ const dim3 block_dims(1024, 1, 1);
+ rms_norm_f32<1024, true, true><<<blocks_num, block_dims, 0, stream>>>(x, dst,
+ ncols, stride_row, stride_channel, stride_sample, eps,
+ mul, mul_stride_row, mul_stride_channel, mul_stride_sample,
+ mul_ncols, mul_nrows, mul_nchannels, mul_nsamples,
+ add, add_stride_row, add_stride_channel, add_stride_sample,
+ add_ncols, add_nrows, add_nchannels, add_nsamples);
+ }
}
}
@@ -491,7 +569,102 @@ void ggml_cuda_op_rms_norm_fused(ggml_backend_cuda_context & ctx, ggml_tensor *
const int mul_nchannels = mul_src->ne[2];
const int mul_nsamples = mul_src->ne[3];
- rms_norm_mul_f32_cuda(src0_d, mul_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, mul_s01, mul_s02, mul_s03, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples, eps, stream);
+ rms_norm_mul_f32_cuda(src0_d, mul_d, nullptr, dst_d,
+ ne00, ne01, ne02, ne03,
+ /*s00*/ s01, s02, s03,
+ /*mul_s00*/ mul_s01, mul_s02, mul_s03,
+ mul_ncols, mul_nrows, mul_nchannels, mul_nsamples,
+ /*add_s00*/ 0, 0, 0,
+ 0, 0, 0, 0,
+ eps, stream);
+}
+
+void ggml_cuda_op_rms_norm_fused_add(ggml_backend_cuda_context & ctx,
+ ggml_tensor * dst,
+ ggml_tensor * mul_tensor,
+ ggml_tensor * add_tensor) {
+ const ggml_tensor * rms_norm_src = (ggml_tensor *) dst->src[0];
+ float eps = 0.0f;
+
+ memcpy(&eps, dst->op_params, sizeof(float));
+
+ const float * src0_d = (const float *) rms_norm_src->data;
+ const float * mul_d = nullptr;
+ const ggml_tensor * mul_src = nullptr;
+
+ if (mul_tensor->src[0] == dst) {
+ mul_d = (float *) mul_tensor->src[1]->data;
+ mul_src = mul_tensor->src[1];
+ } else if (mul_tensor->src[1] == dst) {
+ mul_d = (float *) mul_tensor->src[0]->data;
+ mul_src = mul_tensor->src[0];
+ } else {
+ GGML_ASSERT(false);
+ }
+
+ const float * add_d = nullptr;
+ const ggml_tensor * add_src = nullptr;
+
+ if (add_tensor->src[0] == mul_tensor) {
+ add_d = (float *) add_tensor->src[1]->data;
+ add_src = add_tensor->src[1];
+ } else if (add_tensor->src[1] == mul_tensor) {
+ add_d = (float *) add_tensor->src[0]->data;
+ add_src = add_tensor->src[0];
+ } else {
+ GGML_ASSERT(false);
+ }
+
+ float * dst_d = (float *) add_tensor->data;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(rms_norm_src->type == GGML_TYPE_F32);
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
+ GGML_ASSERT(mul_tensor->type == GGML_TYPE_F32);
+ GGML_ASSERT(add_tensor->type == GGML_TYPE_F32);
+ GGML_ASSERT(eps >= 0.0f);
+
+ const int64_t ne00 = rms_norm_src->ne[0];
+ const int64_t ne01 = rms_norm_src->ne[1];
+ const int64_t ne02 = rms_norm_src->ne[2];
+ const int64_t ne03 = rms_norm_src->ne[3];
+
+ const size_t ts0 = ggml_type_size(rms_norm_src->type);
+ GGML_ASSERT(rms_norm_src->nb[0] == ts0);
+ const int64_t s01 = rms_norm_src->nb[1] / ts0;
+ const int64_t s02 = rms_norm_src->nb[2] / ts0;
+ const int64_t s03 = rms_norm_src->nb[3] / ts0;
+
+ const size_t ts_mul = ggml_type_size(mul_src->type);
+ GGML_ASSERT(mul_src->nb[0] == ts_mul);
+ const int64_t mul_s01 = mul_src->nb[1] / ts_mul;
+ const int64_t mul_s02 = mul_src->nb[2] / ts_mul;
+ const int64_t mul_s03 = mul_src->nb[3] / ts_mul;
+
+ const int mul_ncols = mul_src->ne[0];
+ const int mul_nrows = mul_src->ne[1];
+ const int mul_nchannels = mul_src->ne[2];
+ const int mul_nsamples = mul_src->ne[3];
+
+ const size_t ts_add = ggml_type_size(add_src->type);
+ GGML_ASSERT(add_src->nb[0] == ts_add);
+ const int64_t add_s01 = add_src->nb[1] / ts_add;
+ const int64_t add_s02 = add_src->nb[2] / ts_add;
+ const int64_t add_s03 = add_src->nb[3] / ts_add;
+
+ const int add_ncols = add_src->ne[0];
+ const int add_nrows = add_src->ne[1];
+ const int add_nchannels = add_src->ne[2];
+ const int add_nsamples = add_src->ne[3];
+
+ rms_norm_mul_f32_cuda(src0_d, mul_d, add_d, dst_d,
+ ne00, ne01, ne02, ne03,
+ /*s00*/ s01, s02, s03,
+ /*mul_s00*/ mul_s01, mul_s02, mul_s03,
+ mul_ncols, mul_nrows, mul_nchannels, mul_nsamples,
+ /*add_s00*/ add_s01, add_s02, add_s03,
+ add_ncols, add_nrows, add_nchannels, add_nsamples,
+ eps, stream);
}
void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
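Note (editor): the fused path above must reproduce rms_norm followed by a broadcast multiply and a broadcast add. A minimal scalar sketch of that per-row contract, with an illustrative function name that is not part of the patch (broadcasting mirrors the `%` indexing in the kernel):

    #include <cmath>

    // Reference for one row: dst[i] = rms_norm(x)[i] * mul[i % mul_ncols] + add[i % add_ncols]
    static void rms_norm_mul_add_row_ref(const float * x, const float * mul, const float * add,
                                         float * dst, int ncols, int mul_ncols, int add_ncols, float eps) {
        float sum = 0.0f;
        for (int i = 0; i < ncols; ++i) {
            sum += x[i]*x[i];
        }
        const float scale = 1.0f/std::sqrt(sum/ncols + eps); // same as rsqrtf(mean + eps) in the kernel
        for (int i = 0; i < ncols; ++i) {
            dst[i] = scale*x[i]*mul[i % mul_ncols] + add[i % add_ncols];
        }
    }
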
diff --git a/ggml/src/ggml-cuda/norm.cuh b/ggml/src/ggml-cuda/norm.cuh
index 7ea7bd4df3cc6..a74f6376720ab 100644
--- a/ggml/src/ggml-cuda/norm.cuh
+++ b/ggml/src/ggml-cuda/norm.cuh
@@ -8,6 +8,11 @@ void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_rms_norm_fused(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * mul_tensor);
+void ggml_cuda_op_rms_norm_fused_add(ggml_backend_cuda_context & ctx,
+ ggml_tensor * dst,
+ ggml_tensor * mul_tensor,
+ ggml_tensor * add_tensor);
+
void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_l2_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-cuda/ssm-scan.cu b/ggml/src/ggml-cuda/ssm-scan.cu
index dc9a7d58d057c..6b424381df5a7 100644
--- a/ggml/src/ggml-cuda/ssm-scan.cu
+++ b/ggml/src/ggml-cuda/ssm-scan.cu
@@ -129,7 +129,7 @@ __global__ void __launch_bounds__(d_state, 1)
const int head_off = ((blockIdx.x * splitH) % d_head) * sizeof(float);
const int seq_idx = blockIdx.y;
- const int group_off = (head_idx & (n_group - 1)) * d_state * sizeof(float);
+ const int group_off = (head_idx / (n_head / n_group)) * d_state * sizeof(float);
const float * s0_block = (const float *) ((const char *) src0 + src6[seq_idx] * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);
const float * x_block = (const float *) ((const char *) src1 + (seq_idx * src1_nb3) + blockIdx.x * splitH * sizeof(float));
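Note (editor): the group offset now follows repeat_interleave semantics, i.e. heads map to groups in contiguous blocks, whereas the old mask assigned heads round-robin and also required n_group to be a power of two. A tiny standalone check (values chosen only for illustration) shows where the two differ:

    #include <cstdio>

    int main() {
        const int n_head  = 8;
        const int n_group = 2; // heads 0..3 should use group 0, heads 4..7 group 1
        for (int head = 0; head < n_head; ++head) {
            const int old_g = head & (n_group - 1);       // round-robin, power-of-two n_group only
            const int new_g = head / (n_head / n_group);  // contiguous blocks (repeat_interleave)
            printf("head %d: old -> %d, new -> %d\n", head, old_g, new_g);
        }
        return 0;
    }
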
diff --git a/ggml/src/ggml-cuda/vecdotq.cuh b/ggml/src/ggml-cuda/vecdotq.cuh
index d60292b83b106..6baab1176ffe1 100644
--- a/ggml/src/ggml-cuda/vecdotq.cuh
+++ b/ggml/src/ggml-cuda/vecdotq.cuh
@@ -28,7 +28,58 @@ static __device__ __forceinline__ int get_int_b4(const void * x, const int & i32
return ((const int *) x)[i32]; // assume at least 4 byte alignment
}
+// q4 contains 8 indices of 4 bits each.
+// This function selects the bytes of table at those indices and returns them as an int2.
+// The first int holds the bytes selected by the indices at even positions in q4, the second int those at odd positions.
static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, const int8_t * table) {
+#if defined(GGML_USE_HIP)
+ // Load the 16-byte table into four 32-bit unsigned integers.
+ const uint32_t * values = (const uint32_t *) table;
+
+ const uint32_t q_even = q4;
+ const uint32_t q_odd = (q4 >> 4);
+
+ // Perform lookups in the lower half of the table (indices 0-7).
+ uint32_t v_even_low = __builtin_amdgcn_perm(values[1], values[0], q_even & 0x07070707);
+ uint32_t v_odd_low = __builtin_amdgcn_perm(values[1], values[0], q_odd & 0x07070707);
+
+ // Perform lookups in the upper half of the table (indices 8-15).
+ uint32_t v_even_high = __builtin_amdgcn_perm(values[3], values[2], q_even & 0x07070707);
+ uint32_t v_odd_high = __builtin_amdgcn_perm(values[3], values[2], q_odd & 0x07070707);
+
+ // Select between the low and high results based on the MSB of each index nibble.
+ uint32_t mask_even = 0x03020100 | ((q_even & 0x08080808) >> 1);
+ uint32_t res_x = __builtin_amdgcn_perm(v_even_high, v_even_low, mask_even);
+ uint32_t mask_odd = 0x03020100 | ((q_odd & 0x08080808) >> 1);
+ uint32_t res_y = __builtin_amdgcn_perm(v_odd_high, v_odd_low, mask_odd);
+
+ return make_int2(res_x, res_y);
+#elif !defined(GGML_USE_MUSA)
+ // CUDA does not have an instruction for selecting bytes with 4-bit indices.
+ // However, __byte_perm selects bytes using 3-bit indices and can be used instead.
+ const uint32_t * table32 = (const uint32_t *) table;
+
+ // __byte_perm selects bytes based on the lower 16 bits of its third argument.
+ // Therefore, do 2 iterations over the 32 bits in q4, with shifts of 0 and 16.
+ // To handle the fourth bit, first call __byte_perm for both the low and the high 64 bits of table, using the low 3 bits.
+ // Then, call __byte_perm again to select between the low and high bytes based on the fourth bit.
+ uint32_t tmp[2];
+ const uint32_t low_high_selection_indices = (0x32103210 | ((q4 & 0x88888888) >> 1));
+#pragma unroll
+ for (uint32_t i = 0; i < 2; ++i) {
+ const uint32_t shift = 16 * i;
+
+ const uint32_t low = __byte_perm(table32[0], table32[1], q4 >> shift);
+ const uint32_t high = __byte_perm(table32[2], table32[3], q4 >> shift);
+ tmp[i] = __byte_perm(low, high, low_high_selection_indices >> shift);
+ }
+
+ // tmp contains the bytes from table in the same order as the 4-bit indices in q4.
+ // However, the result needs one int for the even and one int for the odd 4-bit indices in q4.
+ // Therefore, 2 more calls to __byte_perm are needed to put the bytes in the correct order.
+ return make_int2(__byte_perm(tmp[0], tmp[1], 0x6420), __byte_perm(tmp[0], tmp[1], 0x7531));
+#else
+ // Generic implementation.
const int q0_32 = (q4 >> 0) & 0x0F0F0F0F;
const int8_t * q0_8 = (const int8_t *) &q0_32;
const char4 val0_8 = make_char4(
@@ -40,6 +91,7 @@ static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, con
table[q1_8[0]], table[q1_8[1]], table[q1_8[2]], table[q1_8[3]]);
return make_int2(*((const int *) &val0_8), *((const int *) &val1_8));
+#endif
}
// VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called
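Note (editor): for reference, the contract all three branches above implement can be written as a plain scalar loop; this host-side sketch (a hypothetical helper, not part of the file) performs the same nibble-indexed table lookup and the same even/odd split:

    #include <cstdint>
    #include <utility>

    // q4 holds 8 nibbles; return table[nibble] for each, with bytes from even nibble
    // positions packed into .first and bytes from odd nibble positions into .second.
    static std::pair<uint32_t, uint32_t> get_int_from_table_16_ref(uint32_t q4, const int8_t table[16]) {
        uint32_t lo = 0;
        uint32_t hi = 0;
        for (int i = 0; i < 4; ++i) {
            const uint32_t even_idx = (q4 >> (8*i    )) & 0x0F;
            const uint32_t odd_idx  = (q4 >> (8*i + 4)) & 0x0F;
            lo |= (uint32_t)(uint8_t)table[even_idx] << (8*i);
            hi |= (uint32_t)(uint8_t)table[odd_idx ] << (8*i);
        }
        return {lo, hi};
    }
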
diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h
index 6e9c67aca096e..c6a33d5de310f 100644
--- a/ggml/src/ggml-cuda/vendors/hip.h
+++ b/ggml/src/ggml-cuda/vendors/hip.h
@@ -22,7 +22,10 @@
#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite
#define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }}
#define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width)
+#define __shfl_up_sync(mask, var, laneMask, width) __shfl_up(var, laneMask, width)
#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
+#define __all_sync(mask, var) __all(var)
+#define __any_sync(mask, var) __any(var)
#define cublasCreate hipblasCreate
#define cublasDestroy hipblasDestroy
#define cublasGemmEx hipblasGemmEx
diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h
index fc6526d6d5dc6..b9d3639448500 100644
--- a/ggml/src/ggml-metal/ggml-metal-impl.h
+++ b/ggml/src/ggml-metal/ggml-metal-impl.h
@@ -249,6 +249,7 @@ typedef struct {
uint64_t nb33;
int32_t ne1;
int32_t ne2;
+ int32_t ne3;
float scale;
float max_bias;
float m0;
@@ -257,6 +258,11 @@ typedef struct {
float logit_softcap;
} ggml_metal_kargs_flash_attn_ext;
+typedef struct {
+ int32_t nrows;
+ int32_t ne20;
+} ggml_metal_kargs_flash_attn_ext_reduce;
+
typedef struct {
int32_t ne00;
int32_t ne02;
@@ -320,40 +326,31 @@ typedef struct {
} ggml_metal_kargs_mul_mv_ext;
typedef struct {
+ int32_t ne02;
int32_t ne10;
int32_t ne11; // n_expert_used (bcast)
uint64_t nb11;
uint64_t nb12;
- int32_t neh11; // n_tokens
- uint64_t nbh11;
+ int32_t ne21; // n_tokens
int32_t ne20; // n_expert_used
uint64_t nb21;
} ggml_metal_kargs_mul_mm_id_map0;
-typedef struct {
- int32_t ne20; // n_expert_used
- int32_t neh0;
- int32_t neh1;
- uint64_t nbh1;
- uint64_t nbh2;
- int32_t ne0;
- uint64_t nb1;
- uint64_t nb2;
-} ggml_metal_kargs_mul_mm_id_map1;
-
typedef struct {
int32_t ne00;
int32_t ne02;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
- int32_t neh12;
- uint64_t nbh10;
- uint64_t nbh11;
- uint64_t nbh12;
- uint64_t nbh13;
- int32_t neh0;
- int32_t neh1;
+ int32_t ne11;
+ uint64_t nb10;
+ uint64_t nb11;
+ uint64_t nb12;
+ uint64_t nb13;
+ int32_t ne20;
+ int32_t ne21;
+ int32_t ne0;
+ int32_t ne1;
int16_t r2;
int16_t r3;
} ggml_metal_kargs_mul_mm_id;
diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
index 7c70d352dfddf..3d16a1dcd461e 100644
--- a/ggml/src/ggml-metal/ggml-metal.m
+++ b/ggml/src/ggml-metal/ggml-metal.m
@@ -93,35 +93,37 @@
if (ctx->mtl_device == nil) {
ctx->mtl_device = MTLCreateSystemDefaultDevice();
- ctx->has_simdgroup_reduction = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
- ctx->has_simdgroup_reduction |= [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
+ if (ctx->mtl_device) {
+ ctx->has_simdgroup_reduction = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
+ ctx->has_simdgroup_reduction |= [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
- ctx->has_simdgroup_mm = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
+ ctx->has_simdgroup_mm = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
#if defined(GGML_METAL_HAS_RESIDENCY_SETS)
- ctx->has_residency_sets = getenv("GGML_METAL_NO_RESIDENCY") == nil;
+ ctx->has_residency_sets = getenv("GGML_METAL_NO_RESIDENCY") == nil;
#endif
- ctx->has_bfloat = [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
- ctx->has_bfloat |= [ctx->mtl_device supportsFamily:MTLGPUFamilyApple6];
+ ctx->has_bfloat = [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
+ ctx->has_bfloat |= [ctx->mtl_device supportsFamily:MTLGPUFamilyApple6];
#if defined(GGML_METAL_USE_BF16)
- ctx->use_bfloat = ctx->has_bfloat;
+ ctx->use_bfloat = ctx->has_bfloat;
#else
- ctx->use_bfloat = false;
+ ctx->use_bfloat = false;
#endif
- ctx->use_fusion = getenv("GGML_METAL_FUSION_DISABLE") == nil;
+ ctx->use_fusion = getenv("GGML_METAL_FUSION_DISABLE") == nil;
- {
- const char * val = getenv("GGML_METAL_FUSION_DEBUG");
- ctx->debug_fusion = val ? atoi(val) : 0;
- }
+ {
+ const char * val = getenv("GGML_METAL_FUSION_DEBUG");
+ ctx->debug_fusion = val ? atoi(val) : 0;
+ }
- memset(ctx->fuse_cnt, 0, sizeof(ctx->fuse_cnt));
+ memset(ctx->fuse_cnt, 0, sizeof(ctx->fuse_cnt));
- ctx->max_size = ctx->mtl_device.maxBufferLength;
+ ctx->max_size = ctx->mtl_device.maxBufferLength;
- strncpy(ctx->name, [[ctx->mtl_device name] UTF8String], sizeof(ctx->name) - 1);
+ strncpy(ctx->name, [[ctx->mtl_device name] UTF8String], sizeof(ctx->name) - 1);
+ }
}
ctx->mtl_device_ref_count++;
@@ -289,6 +291,10 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_2,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_3,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_4,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_5,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4,
@@ -396,8 +402,12 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16,
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_1,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_2,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_4,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_6,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_8,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_16,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16,
@@ -443,6 +453,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC,
GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H40,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,
@@ -452,6 +463,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H40,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96,
@@ -461,6 +473,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H40,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96,
@@ -470,6 +483,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H40,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96,
@@ -479,6 +493,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H40,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96,
@@ -488,6 +503,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H40,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96,
@@ -497,6 +513,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H40,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96,
@@ -555,6 +572,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_REDUCE,
GGML_METAL_KERNEL_TYPE_SET_I32,
GGML_METAL_KERNEL_TYPE_SET_F32,
GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
@@ -1304,6 +1322,10 @@ @implementation GGMLMetalClass
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32, mul_mv_mxfp4_f32, has_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_2, mul_mv_ext_f32_f32_r1_2, has_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_3, mul_mv_ext_f32_f32_r1_3, has_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_4, mul_mv_ext_f32_f32_r1_4, has_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_5, mul_mv_ext_f32_f32_r1_5, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2, mul_mv_ext_f16_f32_r1_2, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3, mul_mv_ext_f16_f32_r1_3, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4, mul_mv_ext_f16_f32_r1_4, has_simdgroup_reduction);
@@ -1412,8 +1434,12 @@ @implementation GGMLMetalClass
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16, mul_mm_id_map0_f16, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32, mul_mm_id_map1_f32, has_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_1, mul_mm_id_map0_f16_ne20_1, has_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_2, mul_mm_id_map0_f16_ne20_2, has_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_4, mul_mm_id_map0_f16_ne20_4, has_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_6, mul_mm_id_map0_f16_ne20_6, has_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_8, mul_mm_id_map0_f16_ne20_8, has_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_16, mul_mm_id_map0_f16_ne20_16, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16, mul_mm_id_f32_f16, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16, mul_mm_id_f16_f16, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16, mul_mm_id_bf16_f16, has_simdgroup_mm && use_bfloat);
@@ -1459,6 +1485,7 @@ @implementation GGMLMetalClass
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H40, flash_attn_ext_f16_h40, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, has_simdgroup_mm);
@@ -1468,6 +1495,7 @@ @implementation GGMLMetalClass
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128, flash_attn_ext_f16_hk192_hv128, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512, flash_attn_ext_f16_hk576_hv512, has_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H40, flash_attn_ext_bf16_h40, has_simdgroup_mm && use_bfloat);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64, flash_attn_ext_bf16_h64, has_simdgroup_mm && use_bfloat);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80, flash_attn_ext_bf16_h80, has_simdgroup_mm && use_bfloat);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96, flash_attn_ext_bf16_h96, has_simdgroup_mm && use_bfloat);
@@ -1477,6 +1505,7 @@ @implementation GGMLMetalClass
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128, flash_attn_ext_bf16_hk192_hv128, has_simdgroup_mm && use_bfloat);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256, flash_attn_ext_bf16_h256, has_simdgroup_mm && use_bfloat);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512, flash_attn_ext_bf16_hk576_hv512, has_simdgroup_mm && use_bfloat);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H40, flash_attn_ext_q4_0_h40, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64, flash_attn_ext_q4_0_h64, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80, flash_attn_ext_q4_0_h80, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96, flash_attn_ext_q4_0_h96, has_simdgroup_mm);
@@ -1486,6 +1515,7 @@ @implementation GGMLMetalClass
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128, flash_attn_ext_q4_0_hk192_hv128, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256, flash_attn_ext_q4_0_h256, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512, flash_attn_ext_q4_0_hk576_hv512, has_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H40, flash_attn_ext_q4_1_h40, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64, flash_attn_ext_q4_1_h64, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80, flash_attn_ext_q4_1_h80, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96, flash_attn_ext_q4_1_h96, has_simdgroup_mm);
@@ -1495,6 +1525,7 @@ @implementation GGMLMetalClass
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128, flash_attn_ext_q4_1_hk192_hv128, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256, flash_attn_ext_q4_1_h256, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512, flash_attn_ext_q4_1_hk576_hv512, has_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H40, flash_attn_ext_q5_0_h40, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64, flash_attn_ext_q5_0_h64, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80, flash_attn_ext_q5_0_h80, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96, flash_attn_ext_q5_0_h96, has_simdgroup_mm);
@@ -1504,6 +1535,7 @@ @implementation GGMLMetalClass
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128, flash_attn_ext_q5_0_hk192_hv128, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256, flash_attn_ext_q5_0_h256, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512, flash_attn_ext_q5_0_hk576_hv512, has_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H40, flash_attn_ext_q5_1_h40, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64, flash_attn_ext_q5_1_h64, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80, flash_attn_ext_q5_1_h80, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96, flash_attn_ext_q5_1_h96, has_simdgroup_mm);
@@ -1513,6 +1545,7 @@ @implementation GGMLMetalClass
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128, flash_attn_ext_q5_1_hk192_hv128, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256, flash_attn_ext_q5_1_h256, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512, flash_attn_ext_q5_1_hk576_hv512, has_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H40, flash_attn_ext_q8_0_h40, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64, flash_attn_ext_q8_0_h64, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80, flash_attn_ext_q8_0_h80, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96, flash_attn_ext_q8_0_h96, has_simdgroup_mm);
@@ -1571,6 +1604,7 @@ @implementation GGMLMetalClass
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512, flash_attn_ext_vec_q5_0_hk576_hv512, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512, flash_attn_ext_vec_q5_1_hk576_hv512, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512, flash_attn_ext_vec_q8_0_hk576_hv512, has_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_REDUCE, flash_attn_ext_reduce, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_F32, set_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_I32, set_i32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
@@ -1846,7 +1880,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
case GGML_OP_ROPE:
return true;
case GGML_OP_IM2COL:
- return op->src[1]->type == GGML_TYPE_F32 && (op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32);
+ return ggml_is_contiguous(op->src[1]) && op->src[1]->type == GGML_TYPE_F32 && (op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32);
case GGML_OP_POOL_1D:
return false;
case GGML_OP_UPSCALE:
@@ -1861,9 +1895,15 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
case GGML_OP_ARANGE:
return true;
case GGML_OP_FLASH_ATTN_EXT:
- if (op->src[0]->ne[0] == 32) {
- // head size == 32 (e.g. bert-bge-small)
- // TODO: not sure if it is worth adding kernels for this size
+ // for new head sizes, add checks here
+ if (op->src[0]->ne[0] != 40 &&
+ op->src[0]->ne[0] != 64 &&
+ op->src[0]->ne[0] != 80 &&
+ op->src[0]->ne[0] != 96 &&
+ op->src[0]->ne[0] != 112 &&
+ op->src[0]->ne[0] != 128 &&
+ op->src[0]->ne[0] != 192 &&
+ op->src[0]->ne[0] != 256) {
return false;
}
if (op->src[0]->ne[0] == 576) {
@@ -3347,15 +3387,16 @@ static int ggml_metal_encode_node(
// find the break-even point where the matrix-matrix kernel becomes more efficient compared
// to the matrix-vector kernel
- const int ne11_mm_min = 4;
+ const int ne11_mm_min = 8;
// first try to use small-batch mat-mv kernels
// these should be efficient for BS [2, ~8]
- if (src1t == GGML_TYPE_F32 && (ne00%256 == 0) &&
+ if (src1t == GGML_TYPE_F32 && (ne00%128 == 0) &&
(
(
(
- src0t == GGML_TYPE_F16 || // TODO: helper function
+ src0t == GGML_TYPE_F32 || // TODO: helper function
+ src0t == GGML_TYPE_F16 ||
src0t == GGML_TYPE_Q4_0 ||
src0t == GGML_TYPE_Q4_1 ||
src0t == GGML_TYPE_Q5_0 ||
@@ -3383,7 +3424,17 @@ static int ggml_metal_encode_node(
// values and there can be some tail effects when nsg is high. need to confirm this
//
const int nsg = 2; // num simdgroups per threadgroup
- const int nxpsg = ne11 < 3 ? 16 : 8; // num threads along row per simdgroup
+
+ // num threads along row per simdgroup
+ int nxpsg = 0;
+ if (ne00 % 256 == 0 && ne11 < 3) {
+ nxpsg = 16;
+ } else if (ne00 % 128 == 0) {
+ nxpsg = 8;
+ } else {
+ nxpsg = 4;
+ }
+
const int nypsg = 32/nxpsg; // num threads along col per simdgroup (i.e. a simdgroup processes that many src0 rows at a time)
const int r0ptg = nypsg*nsg; // num src0 rows per threadgroup
int r1ptg = 4; // num src1 rows per threadgroup
@@ -3406,6 +3457,14 @@ static int ggml_metal_encode_node(
id<MTLComputePipelineState> pipeline = nil;
switch (src0->type) {
+ case GGML_TYPE_F32:
+ switch (r1ptg) {
+ case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_2].pipeline; break;
+ case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_3].pipeline; break;
+ case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_4].pipeline; break;
+ case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_5].pipeline; break;
+ default: GGML_ABORT("not implemented");
+ } break;
case GGML_TYPE_F16:
switch (r1ptg) {
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2].pipeline; break;
@@ -3560,7 +3619,7 @@ static int ggml_metal_encode_node(
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ].pipeline; break;
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32 ].pipeline; break;
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32 ].pipeline; break;
- case GGML_TYPE_MXFP4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32 ].pipeline; break;
+ case GGML_TYPE_MXFP4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32 ].pipeline; break;
case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32 ].pipeline; break;
case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32 ].pipeline; break;
case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32 ].pipeline; break;
@@ -3878,38 +3937,6 @@ static int ggml_metal_encode_node(
default: break;
}
- const int64_t neh10 = ne10; // n_embd
- const int64_t neh11 = ne21; // n_tokens
- const int64_t neh12 = ne02; // n_expert
-
- const uint64_t nbh10 = ggml_type_size(GGML_TYPE_F16);
- const uint64_t nbh11 = nbh10*neh10;
- const uint64_t nbh12 = nbh11*neh11;
- const uint64_t nbh13 = nbh12*neh12;
-
- const size_t s_src1 = ggml_type_size(GGML_TYPE_F16)*neh10*neh11*neh12;
- id<MTLBuffer> h_src1 = ggml_metal_mem_pool_alloc(mem_pool, s_src1);
- if (!h_src1) {
- GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_src1);
- return 0;
- }
-
- const int64_t neh0 = ne0;
- const int64_t neh1 = ne21;
- const int64_t neh2 = ne02;
-
- const uint64_t nbh0 = ggml_type_size(GGML_TYPE_F32);
- const uint64_t nbh1 = nbh0*neh0;
- const uint64_t nbh2 = nbh1*neh1;
- //const uint64_t nbh3 = nbh2*neh2;
-
- const size_t s_dst = ggml_type_size(GGML_TYPE_F32)*neh0*neh1*neh2;
- id<MTLBuffer> h_dst = ggml_metal_mem_pool_alloc(mem_pool, s_dst);
- if (!h_dst) {
- GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_dst);
- return 0;
- }
-
// tokens per expert
const size_t s_tpe = ggml_type_size(GGML_TYPE_I32)*ne02;
id<MTLBuffer> h_tpe = ggml_metal_mem_pool_alloc(mem_pool, s_tpe);
@@ -3919,8 +3946,8 @@ static int ggml_metal_encode_node(
}
// id map
- // [n_expert_used, n_tokens]
- const size_t s_ids = ggml_type_size(GGML_TYPE_I32)*ne20*ne21;
+ // [n_tokens, n_expert]
+ const size_t s_ids = ggml_type_size(GGML_TYPE_I32)*ne21*ne02;
id<MTLBuffer> h_ids = ggml_metal_mem_pool_alloc(mem_pool, s_ids);
if (!h_ids) {
GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_ids);
@@ -3928,32 +3955,45 @@ static int ggml_metal_encode_node(
}
{
- const int nth = MIN(1024, ne10/4);
-
ggml_metal_kargs_mul_mm_id_map0 args = {
+ ne02,
ne10,
- ne11, // n_expert_used (bcast)
+ ne11, // n_expert_used (bcast)
nb11,
nb12,
- neh11, // n_tokens
- nbh11,
- ne20, // n_expert_used
+ ne21, // n_tokens
+ ne20, // n_expert_used
nb21,
};
id<MTLComputePipelineState> pipeline = nil;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16].pipeline;
+ pipeline = nil;
+
+ switch (ne20) {
+ case 1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_1 ].pipeline; break;
+ case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_2 ].pipeline; break;
+ case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_4 ].pipeline; break;
+ case 6: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_6 ].pipeline; break;
+ case 8: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_8 ].pipeline; break;
+ case 16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_16].pipeline; break;
+ default: GGML_ABORT("missing specialization for ne20 = %d", (int) ne20);
+ }
+
+ GGML_ASSERT(ne02 <= (int) pipeline.maxTotalThreadsPerThreadgroup);
+
+ const size_t smem = ne02*ne20*sizeof(uint16_t);
+
+ GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
[encoder setComputePipelineState:pipeline];
[encoder setBytes:&args length:sizeof(args) atIndex:0];
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
- [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
- [encoder setBuffer: h_src1 offset:0 atIndex:3];
- [encoder setBuffer: h_tpe offset:0 atIndex:4];
- [encoder setBuffer: h_ids offset:0 atIndex:5];
+ [encoder setBuffer:id_src2 offset:offs_src2 atIndex:1];
+ [encoder setBuffer: h_tpe offset:0 atIndex:2];
+ [encoder setBuffer: h_ids offset:0 atIndex:3];
+ [encoder setThreadgroupMemoryLength:smem atIndex:0];
- [encoder dispatchThreadgroups:MTLSizeMake(ne02, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(ne02, 1, 1)];
}
{
@@ -3992,13 +4032,15 @@ static int ggml_metal_encode_node(
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
- /*.neh12 =*/ neh12,
- /*.nbh10 =*/ nbh10,
- /*.nbh11 =*/ nbh11,
- /*.nbh12 =*/ nbh12,
- /*.nbh13 =*/ nbh13,
- /*.neh0 =*/ neh0,
- /*.neh1 =*/ neh1,
+ /*.ne11 =*/ ne11, // n_expert_used (bcast)
+ /*.nb10 =*/ nb10,
+ /*.nb11 =*/ nb11,
+ /*.nb12 =*/ nb12,
+ /*.nb13 =*/ nb13,
+ /*.ne20 =*/ ne20, // n_expert_used
+ /*.ne21 =*/ ne21, // n_tokens
+ /*.ne0 =*/ ne0,
+ /*.ne1 =*/ ne1,
/*.r2 =*/ r2,
/*.r3 =*/ r3,
};
@@ -4006,42 +4048,14 @@ static int ggml_metal_encode_node(
[encoder setComputePipelineState:pipeline];
[encoder setBytes:&args length:sizeof(args) atIndex:0];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
- [encoder setBuffer: h_src1 offset:0 atIndex:2];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
[encoder setBuffer: h_tpe offset:0 atIndex:3];
- [encoder setBuffer: h_dst offset:0 atIndex:4];
+ [encoder setBuffer: h_ids offset:0 atIndex:4];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:5];
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, ne02) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
}
-
- {
- GGML_ASSERT(ne0 % 4 == 0);
-
- const int nth = MIN(1024, ne0/4);
-
- ggml_metal_kargs_mul_mm_id_map1 args = {
- ne20, // n_expert_used
- neh0,
- neh1,
- nbh1,
- nbh2,
- ne0,
- nb1,
- nb2,
- };
-
- id<MTLComputePipelineState> pipeline = nil;
-
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32].pipeline;
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBytes:&args length:sizeof(args) atIndex:0];
- [encoder setBuffer: h_dst offset:0 atIndex:1];
- [encoder setBuffer: h_ids offset:0 atIndex:2];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
-
- [encoder dispatchThreadgroups:MTLSizeMake(ne20, ne21, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
- }
} else {
id<MTLComputePipelineState> pipeline = nil;
@@ -4701,7 +4715,6 @@ static int ggml_metal_encode_node(
} break;
case GGML_OP_IM2COL:
{
- GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_is_contiguous(src1));
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
@@ -5117,10 +5130,8 @@ static int ggml_metal_encode_node(
bool use_vec_kernel = false;
- // TODO: add vec kernels for (ne00%64 == 0) and maybe also for (ne00%32 == 0)
- // for now avoiding mainly to keep the number of templates/kernels a bit lower
- // these are now trivial to add after: https://github.com/ggml-org/llama.cpp/pull/12612
- if (ne01 >= 20 || (ne00%128 != 0 && ne00 != 64 && ne00 != 96 && ne00 != 192 && ne00 != 576)) {
+ // use non-vec kernel if the batch size is large or if the vec-kernel is not supported for this head size
+ if (ne01 >= 20 || (ne00 == 40 || ne00 == 80 || ne00 == 112)) {
switch (src1->type) {
case GGML_TYPE_F16:
{
@@ -5130,6 +5141,7 @@ static int ggml_metal_encode_node(
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512].pipeline;
} else {
switch (ne00) {
+ case 40: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H40 ].pipeline; break;
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
@@ -5154,6 +5166,7 @@ static int ggml_metal_encode_node(
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512].pipeline;
} else {
switch (ne00) {
+ case 40: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H40 ].pipeline; break;
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64 ].pipeline; break;
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80 ].pipeline; break;
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96 ].pipeline; break;
@@ -5178,6 +5191,7 @@ static int ggml_metal_encode_node(
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512].pipeline;
} else {
switch (ne00) {
+ case 40: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H40 ].pipeline; break;
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64 ].pipeline; break;
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80 ].pipeline; break;
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96 ].pipeline; break;
@@ -5202,6 +5216,7 @@ static int ggml_metal_encode_node(
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512].pipeline;
} else {
switch (ne00) {
+ case 40: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H40 ].pipeline; break;
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64 ].pipeline; break;
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80 ].pipeline; break;
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96 ].pipeline; break;
@@ -5226,6 +5241,7 @@ static int ggml_metal_encode_node(
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512].pipeline;
} else {
switch (ne00) {
+ case 40: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H40 ].pipeline; break;
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64 ].pipeline; break;
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80 ].pipeline; break;
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96 ].pipeline; break;
@@ -5250,6 +5266,7 @@ static int ggml_metal_encode_node(
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512].pipeline;
} else {
switch (ne00) {
+ case 40: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H40 ].pipeline; break;
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64 ].pipeline; break;
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80 ].pipeline; break;
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96 ].pipeline; break;
@@ -5274,6 +5291,7 @@ static int ggml_metal_encode_node(
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512].pipeline;
} else {
switch (ne00) {
+ case 40: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H40 ].pipeline; break;
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64 ].pipeline; break;
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80 ].pipeline; break;
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96 ].pipeline; break;
@@ -5465,6 +5483,7 @@ static int ggml_metal_encode_node(
/*.nb33 =*/ nb33,
/*.ne1 =*/ ne1,
/*.ne2 =*/ ne2,
+ /*.ne3 =*/ ne3,
/*.scale =*/ scale,
/*.max_bias =*/ max_bias,
/*.m0 =*/ m0,
@@ -5488,7 +5507,6 @@ static int ggml_metal_encode_node(
} else {
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:5];
}
- [encoder setBuffer:id_dst offset:offs_dst atIndex:6];
if (!use_vec_kernel) {
// half8x8 kernel
@@ -5514,7 +5532,7 @@ static int ggml_metal_encode_node(
while (true) {
const size_t smem = FATTN_SMEM(nsgmax);
- if (smem > device.maxThreadgroupMemoryLength) {
+ if (smem > device.maxThreadgroupMemoryLength/2) {
break;
}
nsgmax *= 2;
@@ -5526,15 +5544,18 @@ static int ggml_metal_encode_node(
const size_t smem = FATTN_SMEM(nsg);
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:6];
+
//printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg);
GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
[encoder setThreadgroupMemoryLength:smem atIndex:0];
-#undef FATTN_SMEM
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
+#undef FATTN_SMEM
} else {
// half4x4 kernel
const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !!
const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
+ const int64_t nkpsg = 1*ncpsg; // TODO: make adjustable
GGML_ASSERT(nqptg <= 32);
GGML_ASSERT(nqptg % 1 == 0);
@@ -5544,15 +5565,17 @@ static int ggml_metal_encode_node(
// for each query, we load it as f16 in shared memory (ne00)
// and store the soft_max values and the mask
//
- // ne00*(nsg)
+ // ne20*(nsg)
// each simdgroup has a full f32 head vector in shared mem to accumulate results
//
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + 2*ne20*(nsg))*(sizeof(float)/2), 16))
+//#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)))*(sizeof(float)/2), 16))
int64_t nsgmax = 2;
while (true) {
const size_t smem = FATTN_SMEM(nsgmax);
- if (smem > device.maxThreadgroupMemoryLength) {
+ // avoid using more than half of the threadgroup memory, as that can cause slowdowns, especially for large head sizes
+ if (smem > device.maxThreadgroupMemoryLength/2) {
break;
}
nsgmax *= 2;
@@ -5560,7 +5583,7 @@ static int ggml_metal_encode_node(
nsgmax /= 2;
// simdgroups per threadgroup (a.k.a. warps)
- const int64_t nsgt = MAX(2, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)));
+ const int64_t nsgt = MAX(2, MIN(nsgmax, MIN((ne11 + nkpsg - 1)/(nkpsg), (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)));
int64_t nsg = 1;
while (nsg <= nsgt) {
@@ -5568,13 +5591,74 @@ static int ggml_metal_encode_node(
}
nsg /= 2;
- const size_t smem = FATTN_SMEM(nsg);
+ // workgroups
+ // each workgroup handles nsg*nkpsg cache values
+ uint16_t nwg = 1;
+ if (4*nsg*nkpsg >= ne11) {
+ const size_t smem = FATTN_SMEM(nsg);
- //printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg);
- GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
- [encoder setThreadgroupMemoryLength:smem atIndex:0];
+ //printf("smem: %zu, max: %zu, nsg = %d, nsgmax = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg, (int) nsgmax);
+ GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
+
+ // using 1 workgroup -> write the result directly into dst
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:6];
+ [encoder setBytes:&nwg length:sizeof(uint16_t) atIndex:7];
+
+ [encoder setThreadgroupMemoryLength:smem atIndex:0];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
+ } else {
+ nwg = 32;
+ nsg = MIN(4, nsg);
+
+ const size_t smem = FATTN_SMEM(nsg);
+
+ //printf("smem: %zu, max: %zu, nsg = %d, nsgmax = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg, (int) nsgmax);
+ GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
+
+ // sanity checks
+ GGML_ASSERT(ne01*ne02*ne03 == ne1*ne2*ne3);
+ GGML_ASSERT(ne1*ne2*ne3 <= (1u << 31));
+
+ const int32_t nrows = ne1*ne2*ne3;
+
+ // temp buffer for writing the results from each workgroup
+ // - ne20: the size of the head vector
+ // - + 2: the S and M values for each intermediate result
+ const size_t s_tmp = ggml_type_size(GGML_TYPE_F32)*(nrows*nwg*(ne20 + 2));
+ id<MTLBuffer> h_tmp = ggml_metal_mem_pool_alloc(mem_pool, s_tmp);
+ if (!h_tmp) {
+ GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_tmp);
+ return 0;
+ }
+
+ //printf("ne01 = %d, ne02 = %d, ne03 = %d, ne20 = %d\n", ne01, ne02, ne03, ne20);
+ //printf("needed memory: %.3f MiB\n", (float) (ne01*ne02*ne03*ne20*sizeof(float))/1024.0f/1024.0f);
+
+ [encoder setBuffer:h_tmp offset:0 atIndex:6];
+ [encoder setBytes:&nwg length:sizeof(uint16_t) atIndex:7];
+
+ [encoder setThreadgroupMemoryLength:smem atIndex:0];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
+
+ // reduce the results from the workgroups
+ {
+ ggml_metal_kargs_flash_attn_ext_reduce args0 = {
+ nrows,
+ ne20,
+ };
+
+ id<MTLComputePipelineState> pipeline0 = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_REDUCE].pipeline;
+
+ [encoder setComputePipelineState:pipeline0];
+ [encoder setBytes:&args0 length:sizeof(args0) atIndex:0];
+ [encoder setBuffer:h_tmp offset:0 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+
+ //printf("ne1 = %d, ne2 = %d, ne3 = %d, ne20 = %d\n", ne1, ne2, ne3, ne20);
+ [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(32*32, 1, 1)];
+ }
+ }
#undef FATTN_SMEM
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
}
} break;
case GGML_OP_DUP:
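Note (editor): the multi-workgroup branch above stores, per workgroup, an intermediate head vector plus its S and M values, and the reduce kernel folds them into one output. Assuming the usual split-softmax convention (each partial holds an unnormalized accumulator together with its running max M and denominator S; this is a sketch of the general technique, not of the Metal kernel itself), the merge is:

    #include <algorithm>
    #include <cmath>
    #include <vector>

    // Combine nwg partial flash-attention results for one row (log-sum-exp style merge).
    static void fattn_merge_partials(const std::vector<std::vector<float>> & acc, // [nwg][ne20], unnormalized
                                     const std::vector<float> & S,                // [nwg] softmax denominators
                                     const std::vector<float> & M,                // [nwg] running maxima
                                     std::vector<float> & out) {                  // [ne20]
        const size_t nwg = acc.size();

        float M_all = -INFINITY;
        for (size_t w = 0; w < nwg; ++w) {
            M_all = std::max(M_all, M[w]);
        }

        float S_all = 0.0f;
        std::fill(out.begin(), out.end(), 0.0f);
        for (size_t w = 0; w < nwg; ++w) {
            const float ms = std::exp(M[w] - M_all); // rescale each partial to the global max
            S_all += S[w]*ms;
            for (size_t d = 0; d < out.size(); ++d) {
                out[d] += acc[w][d]*ms;
            }
        }

        for (size_t d = 0; d < out.size(); ++d) {
            out[d] /= S_all; // final normalization
        }
    }
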
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index b35a3bbdc317f..9c5933d24a0e3 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -68,6 +68,11 @@ void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg)
reg = (type4x4)(*src);
}
+template <typename type4>
+void dequantize_f32_t4(device const float4 * src, short il, thread type4 & reg) {
+ reg = (type4)(*src);
+}
+
template <typename type4x4>
void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
reg = (type4x4)(*src);
@@ -974,9 +979,16 @@ kernel void kernel_mul(
device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs;
- for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
- const int i10 = i0%args.ne10;
- *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * *((device float *)(src1_ptr + i10*args.nb10));
+ if (args.ne10 == 1) {
+ const float x = *((device float *)(src1_ptr));
+ for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+ *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * x;
+ }
+ } else {
+ for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+ const int i10 = i0%args.ne10;
+ *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * *((device float *)(src1_ptr + i10*args.nb10));
+ }
}
}
@@ -1000,9 +1012,16 @@ kernel void kernel_div(
device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs;
- for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
- const int i10 = i0%args.ne10;
- *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) / *((device float *)(src1_ptr + i10*args.nb10));
+ if (args.ne10 == 1) {
+ const float x = 1.0f / *((device float *)(src1_ptr));
+ for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+ *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * x;
+ }
+ } else {
+ for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+ const int i10 = i0%args.ne10;
+ *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) / *((device float *)(src1_ptr + i10*args.nb10));
+ }
}
}
@@ -1964,14 +1983,15 @@ kernel void kernel_ssm_scan_f32(
device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
device float * s_buff = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
const int64_t i = i0 + i1*nc;
+ const int64_t g = ir / (nh / ng); // repeat_interleave
float s0 = s0_buff[i];
float s = s_buff[i];
device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31);
device const float * x_block = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13);
device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22);
- device const float * B_block = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i3*args.nb43);
- device const float * C_block = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i3*args.nb53);
+ device const float * B_block = (device const float *) ((device const char *) src4 + g*args.nb41 + i3*args.nb43);
+ device const float * C_block = (device const float *) ((device const char *) src5 + g*args.nb51 + i3*args.nb53);
device float * y_block = (device float *) ((device char *) dst + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00);
for (int64_t i2 = 0; i2 < n_t; ++i2) {
@@ -2079,14 +2099,15 @@ kernel void kernel_ssm_scan_f32_group(
device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
device float * s_buff = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
const int64_t i = i0 + i1*nc;
+ const int64_t g = ir / (nh / ng); // repeat_interleave
float s0 = s0_buff[i];
float s = s_buff[i];
device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh}
device const float * x_block = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13);
device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22);
- device const float * B_block = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i3*args.nb43);
- device const float * C_block = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i3*args.nb53);
+ device const float * B_block = (device const float *) ((device const char *) src4 + g*args.nb41 + i3*args.nb43);
+ device const float * C_block = (device const float *) ((device const char *) src5 + g*args.nb51 + i3*args.nb53);
device float * y_block = (device float *) ((device char *) dst + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00);
for (int64_t i2 = 0; i2 < n_t; ++i2) {
@@ -3001,7 +3022,6 @@ void kernel_mul_mv_ext_q4_f32_impl(
#pragma unroll(r1ptg)
for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
sumf[ir1] += dot(lx[ch], y4[ir1][ch*nxpsg]);
-
}
}
@@ -3186,6 +3206,11 @@ kernel void kernel_mul_mv_ext_q4x4_f32_disp(
typedef decltype(kernel_mul_mv_ext_q4_f32_disp <2, block_q8_0, 32, dequantize_q8_0_t4>) mul_mv_ext_q4_f32_t;
typedef decltype(kernel_mul_mv_ext_q4x4_f32_disp<2, block_q4_K, 256, dequantize_q4_K>) mul_mv_ext_q4x4_f32_t;
+template [[host_name("kernel_mul_mv_ext_f32_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, float4, 4, dequantize_f32_t4>;
+template [[host_name("kernel_mul_mv_ext_f32_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, float4, 4, dequantize_f32_t4>;
+template [[host_name("kernel_mul_mv_ext_f32_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, float4, 4, dequantize_f32_t4>;
+template [[host_name("kernel_mul_mv_ext_f32_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, float4, 4, dequantize_f32_t4>;
+
template [[host_name("kernel_mul_mv_ext_f16_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, half4, 4, dequantize_f16_t4>;
template [[host_name("kernel_mul_mv_ext_f16_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, half4, 4, dequantize_f16_t4>;
template [[host_name("kernel_mul_mv_ext_f16_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, half4, 4, dequantize_f16_t4>;
@@ -4663,6 +4688,7 @@ kernel void kernel_flash_attn_ext(
typedef decltype(kernel_flash_attn_ext) flash_attn_ext_t;
+template [[host_name("kernel_flash_attn_ext_f16_h40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
@@ -4674,6 +4700,7 @@ template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_f16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
#if defined(GGML_METAL_USE_BF16)
+template [[host_name("kernel_flash_attn_ext_bf16_h40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
template [[host_name("kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
template [[host_name("kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
@@ -4685,6 +4712,7 @@ template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_bf16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
#endif
+template [[host_name("kernel_flash_attn_ext_q4_0_h40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
template [[host_name("kernel_flash_attn_ext_q4_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
template [[host_name("kernel_flash_attn_ext_q4_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
@@ -4695,6 +4723,7 @@ template [[host_name("kernel_flash_attn_ext_q4_0_hk192_hv128")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_q4_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
template [[host_name("kernel_flash_attn_ext_q4_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q4_1_h40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
template [[host_name("kernel_flash_attn_ext_q4_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
@@ -4705,6 +4734,7 @@ template [[host_name("kernel_flash_attn_ext_q4_1_hk192_hv128")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_q4_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
template [[host_name("kernel_flash_attn_ext_q4_1_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q5_0_h40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
template [[host_name("kernel_flash_attn_ext_q5_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
@@ -4715,6 +4745,7 @@ template [[host_name("kernel_flash_attn_ext_q5_0_hk192_hv128")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_q5_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
template [[host_name("kernel_flash_attn_ext_q5_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q5_1_h40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
template [[host_name("kernel_flash_attn_ext_q5_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
@@ -4725,6 +4756,7 @@ template [[host_name("kernel_flash_attn_ext_q5_1_hk192_hv128")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_q5_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
template [[host_name("kernel_flash_attn_ext_q5_1_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q8_0_h40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
template [[host_name("kernel_flash_attn_ext_q8_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
@@ -4765,14 +4797,19 @@ kernel void kernel_flash_attn_ext_vec(
device const char * mask,
device const char * sinks,
device char * dst,
+ constant uint16_t & nwg,
threadgroup half * shmem_f16 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 ntg[[threads_per_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+ static_assert(DK % 32 == 0, "DK must be divisible by 32");
+ static_assert(DV % 32 == 0, "DV must be divisible by 32");
+
const short nsg = ntg.y; // number of simdgroups
+ const short iwg = tgpig[2]%nwg;
- const int iq3 = tgpig[2];
+ const int iq3 = tgpig[2]/nwg;
const int iq2 = tgpig[1];
const int iq1 = tgpig[0];
@@ -4851,7 +4888,7 @@ kernel void kernel_flash_attn_ext_vec(
// loop over the KV cache
// each simdgroup handles blocks of Q rows and C columns
- for (int ic0 = 0; ic0 < args.ne11; ic0 += C*nsg) {
+ for (int ic0 = (int) iwg*C*nsg; ic0 < args.ne11; ic0 += (int) nwg*C*nsg) {
const int ic = ic0 + C*sgitg;
if (ic >= args.ne11) {
break;
@@ -4981,7 +5018,7 @@ kernel void kernel_flash_attn_ext_vec(
}
}
- if (sinks != q && sgitg == 0) {
+ if (sinks != q && sgitg == 0 && iwg == 0) {
const float m = M;
const float s = tiisg == 0 ? ((device const float *) sinks)[iq2] : -FLT_MAX/2;
@@ -5090,14 +5127,25 @@ kernel void kernel_flash_attn_ext_vec(
threadgroup_barrier(mem_flags::mem_threadgroup);
}
- device float4 * dst4 = (device float4 *) dst;
-
// final rescale with 1/S and store to global memory
if (sgitg == 0) {
- const float S = ss[0];
+ const int64_t nrows = args.ne3*args.ne2*args.ne1;
+ const int64_t rid = iq3*args.ne2*args.ne1 + iq2 + iq1*args.ne1;
+
+ device float4 * dst4 = (device float4 *) dst;
+ device float * dst1 = (device float *) dst + nrows*DV*nwg; // the S and M are stored after the results
+
+ const float S = nwg == 1 ? 1.0f/ss[0] : 1.0f;
+ // interleave the workgroup data
for (short i = tiisg; i < DV4; i += NW) {
- dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)iq1*args.ne1)*DV4 + i] = (float4) sr4[i]/S;
+ dst4[rid*DV4*nwg + nwg*i + iwg] = (float4) sr4[i]*S;
+ }
+
+ // store S and M
+ if (nwg > 1 && tiisg == 0) {
+ dst1[rid*(2*nwg) + 2*iwg + 0] = ss[0];
+ dst1[rid*(2*nwg) + 2*iwg + 1] = ss[1];
}
}
}
@@ -5187,6 +5235,41 @@ template [[host_name("kernel_flash_attn_ext_vec_q8_0_hk576_hv512")]] kernel flas
#undef FA_TYPES
+kernel void kernel_flash_attn_ext_reduce(
+ constant ggml_metal_kargs_flash_attn_ext_reduce & args,
+ device const char * htmp,
+ device char * dst,
+ uint tgpig[[threadgroup_position_in_grid]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+ const uint64_t rid = tgpig;
+
+ const short nwg = 32;
+ const short iwg = tiisg;
+ const short DV = args.ne20;
+ const short DV4 = DV/4;
+
+ device const float4 * htmp4 = (device const float4 *) htmp + rid*DV4*nwg;
+ device const float * ss = (device const float *) htmp + (uint64_t)args.nrows*DV*nwg;
+ device float4 * dst4 = (device float4 *) dst + rid*DV4;
+
+ float S = ss[rid*(2*nwg) + 2*iwg + 0];
+ float M = ss[rid*(2*nwg) + 2*iwg + 1];
+
+ const float m = simd_max(M);
+ const float ms = exp(M - m);
+
+ S = 1.0f/simd_sum(S*ms);
+
+ for (int i = sgitg; i < DV4; i += nwg) {
+ const float4 v = simd_sum(htmp4[i*nwg + iwg]*ms);
+
+ if (iwg == 0) {
+ dst4[i] = v*S;
+ }
+ }
+}
+
template
kernel void kernel_set(
constant ggml_metal_kargs_set & args,
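Taken together, the hunks above implement a split-KV variant of the vec flash-attention path: when nwg > 1, each workgroup covers a disjoint stripe of the KV cache, writes its unnormalized output interleaved into the temporary buffer, and appends its softmax sum S and running maximum M; kernel_flash_attn_ext_reduce then merges the partials. A minimal host-side C++ sketch of that merge, using hypothetical container types for illustration only:

    #include <algorithm>
    #include <cmath>
    #include <cstddef>
    #include <vector>

    // o[w] is workgroup w's unnormalized output row, S[w]/M[w] its softmax sum and maximum.
    static std::vector<float> merge_fa_partials(const std::vector<std::vector<float>> & o,
                                                const std::vector<float> & S,
                                                const std::vector<float> & M) {
        const size_t nwg = o.size();
        const size_t dv  = o[0].size();

        float m = -INFINITY;
        for (size_t w = 0; w < nwg; ++w) {
            m = std::max(m, M[w]);
        }

        std::vector<float> out(dv, 0.0f);
        float s = 0.0f;
        for (size_t w = 0; w < nwg; ++w) {
            const float ms = std::exp(M[w] - m); // rescale each partial to the common maximum
            s += S[w]*ms;
            for (size_t d = 0; d < dv; ++d) {
                out[d] += o[w][d]*ms;
            }
        }

        const float inv_s = 1.0f/s;
        for (size_t d = 0; d < dv; ++d) {
            out[d] *= inv_s; // final 1/S normalization, done once after the merge
        }
        return out;
    }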
@@ -7474,97 +7557,81 @@ kernel void kernel_mul_mm(
}
}
-template <typename T4>
+template <short ne20> // n_expert_used
kernel void kernel_mul_mm_id_map0(
constant ggml_metal_kargs_mul_mm_id_map0 & args,
- device const char * src1,
device const char * src2,
- device char * hsrc1,
device char * htpe,
device char * hids,
- uint3 tgpig[[threadgroup_position_in_grid]],
- ushort3 tpitg[[thread_position_in_threadgroup]],
- ushort3 ntg[[threads_per_threadgroup]]) {
- const int ide = tgpig[0]; // expert id
-
- int n_all = 0;
-
- device int32_t * ids_i32 = (device int32_t *) (hids);
+ threadgroup char * shmem [[threadgroup(0)]],
+ ushort tpitg[[thread_position_in_threadgroup]],
+ ushort ntg[[threads_per_threadgroup]]) {
+ const short ide = tpitg; // expert id
- for (int i21 = 0; i21 < args.neh11; i21++) { // n_tokens
- device const int32_t * src2_i32 = (device const int32_t *) (src2 + i21*args.nb21);
+ uint32_t n_all = 0;
- for (int i20 = 0; i20 < args.ne20; i20++) { // n_expert_used
- if (src2_i32[i20] != ide) {
- continue;
- }
+ device int32_t * ids_i32 = (device int32_t *) hids + ide*args.ne21;
- device const float4 * src1_f32x4 = (device const float4 *) ( src1 + i21*args.nb12 + (i20%args.ne11)*args.nb11);
- device T4 * hsrc1_f32x4 = (device T4 *) (hsrc1 + (ide*args.neh11 + n_all)*args.nbh11);
+ for (int i21 = 0; i21 < args.ne21; i21 += ntg) { // n_tokens
+ if (i21 + tpitg < args.ne21) {
+ device const int32_t * src2_i32 = (device const int32_t *) (src2 + (i21 + tpitg)*args.nb21);
- for (int64_t i00 = tpitg.x; i00 < args.ne10/4; i00 += ntg.x) {
- hsrc1_f32x4[i00] = (T4) (src1_f32x4[i00]);
- }
+ threadgroup uint16_t * sids = (threadgroup uint16_t *) shmem + tpitg*ne20;
- if (tpitg.x == 0) {
- ids_i32[i21*args.ne20 + i20] = ide*args.neh11 + n_all;
+ #pragma unroll(ne20)
+ for (short i20 = 0; i20 < ne20; i20++) {
+ sids[i20] = src2_i32[i20];
}
-
- ++n_all;
}
- }
- if (tpitg.x == 0) {
- device int32_t * tpe_i32 = (device int32_t *) (htpe);
- tpe_i32[ide] = n_all;
- }
-}
-
-typedef decltype(kernel_mul_mm_id_map0<half4>) kernel_mul_mm_id_map0_t;
-
-template [[host_name("kernel_mul_mm_id_map0_f16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<half4>;
+ threadgroup_barrier(mem_flags::mem_threadgroup);
-template
-kernel void kernel_mul_mm_id_map1(
- constant ggml_metal_kargs_mul_mm_id_map1 & args,
- device const char * hdst,
- device const char * hids,
- device char * dst,
- uint3 tgpig[[threadgroup_position_in_grid]],
- ushort3 tpitg[[thread_position_in_threadgroup]],
- ushort3 ntg[[threads_per_threadgroup]]) {
- const int i20 = tgpig[0]; // used expert
- const int i21 = tgpig[1]; // token
+ for (short t = 0; t < ntg; t++) {
+ if (i21 + t >= args.ne21) {
+ break;
+ }
- device const int32_t * ids_i32 = (device const int32_t *) (hids);
- device float4 * dst_f32x4 = (device float4 *) (dst + i20*args.nb1 + i21*args.nb2);
+ threadgroup const uint16_t * sids = (threadgroup const uint16_t *) shmem + t*ne20;
- const int id = ids_i32[i21*args.ne20 + i20];
+ short sel = 0;
+ #pragma unroll(ne20)
+ for (short i20 = 0; i20 < ne20; i20++) {
+ sel += (sids[i20] == ide)*(i20 + 1);
+ }
- const int ide = id / args.neh1;
- const int idt = id % args.neh1;
+ ids_i32[n_all] = (i21 + t)*ne20 + sel - 1;
- device const float4 * hdst_f32x4 = (device const float4 *) (hdst + idt*args.nbh1 + ide*args.nbh2);
+ n_all += sel > 0;
+ }
- for (int64_t i0 = tpitg.x; i0 < args.neh0/4; i0 += ntg.x) {
- dst_f32x4[i0] = hdst_f32x4[i0];
+ threadgroup_barrier(mem_flags::mem_threadgroup);
}
+
+ device uint32_t * tpe_u32 = (device uint32_t *) (htpe);
+ tpe_u32[ide] = n_all;
}
-typedef decltype(kernel_mul_mm_id_map1) kernel_mul_mm_id_map1_t;
+typedef decltype(kernel_mul_mm_id_map0<1>) kernel_mul_mm_id_map0_t;
-template [[host_name("kernel_mul_mm_id_map1_f32")]] kernel kernel_mul_mm_id_map1_t kernel_mul_mm_id_map1;
+template [[host_name("kernel_mul_mm_id_map0_f16_ne20_1" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<1>;
+template [[host_name("kernel_mul_mm_id_map0_f16_ne20_2" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<2>;
+template [[host_name("kernel_mul_mm_id_map0_f16_ne20_4" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<4>;
+template [[host_name("kernel_mul_mm_id_map0_f16_ne20_6" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<6>;
+template [[host_name("kernel_mul_mm_id_map0_f16_ne20_8" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>;
+template [[host_name("kernel_mul_mm_id_map0_f16_ne20_16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<16>;
template
kernel void kernel_mul_mm_id(
constant ggml_metal_kargs_mul_mm_id & args,
device const char * src0,
device const char * src1,
- device const char * tpe,
+ device const char * htpe,
+ device const char * hids,
device char * dst,
threadgroup char * shmem [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiitg[[thread_index_in_threadgroup]],
+ ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
threadgroup T * sa = (threadgroup T *)(shmem);
@@ -7572,19 +7639,20 @@ kernel void kernel_mul_mm_id(
const int r0 = tgpig.y;
const int r1 = tgpig.x;
- const int im = tgpig.z;
+ const int im = tgpig.z; // expert
- device const int32_t * tpe_i32 = (device const int32_t *) (tpe);
+ device const uint32_t * tpe_u32 = (device const uint32_t *) (htpe);
+ device const int32_t * ids_i32 = (device const int32_t *) (hids);
- const int neh1 = tpe_i32[im];
+ const int32_t neh1 = tpe_u32[im];
if (r1*BLOCK_SIZE_N >= neh1) {
return;
}
// if this block is of 64x32 shape or smaller
- const short n_rows = (args.neh0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (args.neh0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M;
- const short n_cols = ( neh1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? ( neh1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N;
+ const short n_rows = (args.ne0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (args.ne0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M;
+ const short n_cols = ( neh1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? ( neh1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N;
// a thread shouldn't load data outside of the matrix
const short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
@@ -7600,20 +7668,23 @@ kernel void kernel_mul_mm_id(
short il = (tiitg % THREAD_PER_ROW);
- const int i12 = im%args.neh12;
- const int i13 = im/args.neh12;
+ const int id = ids_i32[im*args.ne21 + r1*BLOCK_SIZE_N + thread_col];
- const uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+ const short i11 = (id % args.ne20) % args.ne11;
+ const short i12 = (id / args.ne20);
+ const short i13 = 0;
+
+ const uint64_t offset0 = im*args.nb02 + i13*args.nb03;
const short offset1 = il/nl;
device const block_q * x = (device const block_q *)(src0
+ args.nb01*(r0*BLOCK_SIZE_M + thread_row) + offset0) + offset1;
- device const half * y = (device const half *)(src1
- + args.nbh13*i13
- + args.nbh12*i12
- + args.nbh11*(r1*BLOCK_SIZE_N + thread_col)
- + args.nbh10*(BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
+ device const float * y = (device const float *)(src1
+ + args.nb13*i13
+ + args.nb12*i12
+ + args.nb11*i11
+ + args.nb10*(BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
for (int loop_k = 0; loop_k < args.ne00; loop_k += BLOCK_SIZE_K) {
// load data and store to threadgroup memory
@@ -7629,7 +7700,7 @@ kernel void kernel_mul_mm_id(
+ (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = temp_a[i/4][i%4];
}
- *(threadgroup half2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = *((device half2x4 *) y);
+ *(threadgroup half2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = (half2x4)(*((device float2x4 *) y));
il = (il + 2 < nl) ? il + 2 : il % 2;
x = (il < 2) ? x + (2 + nl - 1)/nl : x;
@@ -7665,43 +7736,38 @@ kernel void kernel_mul_mm_id(
}
}
- if ((r0 + 1) * BLOCK_SIZE_M <= args.neh0 && (r1 + 1) * BLOCK_SIZE_N <= neh1) {
- device float * C = (device float *) dst +
- (BLOCK_SIZE_M * r0 + 32*(sgitg & 1)) + \
- (BLOCK_SIZE_N * r1 + 16*(sgitg >> 1)) * args.neh0 + im*args.neh1*args.neh0;
+ threadgroup_barrier(mem_flags::mem_threadgroup);
- for (short i = 0; i < 8; i++) {
- simdgroup_store(mc[i], C + 8 * (i%4) + 8 * args.neh0 * (i/4), args.neh0);
- }
- } else {
- // block is smaller than 64x32, we should avoid writing data outside of the matrix
- threadgroup_barrier(mem_flags::mem_threadgroup);
- threadgroup float * temp_str = ((threadgroup float *) shmem) \
- + 32*(sgitg&1) + (16*(sgitg >> 1))*BLOCK_SIZE_M;
- for (short i = 0; i < 8; i++) {
- simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*BLOCK_SIZE_M*(i/4), BLOCK_SIZE_M);
- }
+ threadgroup float * temp_str = ((threadgroup float *) shmem) \
+ + 32*(sgitg&1) + (16*(sgitg >> 1))*BLOCK_SIZE_M;
- threadgroup_barrier(mem_flags::mem_threadgroup);
+ #pragma unroll(8)
+ for (short i = 0; i < 8; i++) {
+ simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*BLOCK_SIZE_M*(i/4), BLOCK_SIZE_M);
+ }
- if (sgitg == 0) {
- for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
- device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + (r1*BLOCK_SIZE_N + j)*args.neh0 + im*args.neh1*args.neh0;
- device float4 * D4 = (device float4 *) D;
+ threadgroup_barrier(mem_flags::mem_threadgroup);
- threadgroup float * C = temp_str + (j*BLOCK_SIZE_M);
- threadgroup float4 * C4 = (threadgroup float4 *) C;
+ for (short j = sgitg; j < n_cols; j += 4) {
+ const int id = ids_i32[im*args.ne21 + r1*BLOCK_SIZE_N + j];
- int i = 0;
- for (; i < n_rows/4; i++) {
- *(D4 + i) = *(C4 + i);
- }
+ const short ide = id % args.ne20;
+ const short idt = id / args.ne20;
- i *= 4;
- for (; i < n_rows; i++) {
- *(D + i) = *(C + i);
- }
- }
+ device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + ide*args.ne0 + idt*args.ne1*args.ne0;
+ device float4 * D4 = (device float4 *) D;
+
+ threadgroup float * C = (threadgroup float *) shmem + (j*BLOCK_SIZE_M);
+ threadgroup float4 * C4 = (threadgroup float4 *) C;
+
+ int i = tiisg;
+ for (; i < n_rows/4; i += 32) {
+ *(D4 + i) = *(C4 + i);
+ }
+
+ i = (4*(n_rows/4)) + tiisg;
+ for (; i < n_rows; i += 32) {
+ *(D + i) = *(C + i);
}
}
}
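The rework above drops the intermediate hsrc1 gather entirely: kernel_mul_mm_id_map0 now only records, per expert, which (token, slot) pairs route to it, encoded as token*ne20 + slot, plus a per-expert count in htpe, and kernel_mul_mm_id reads src1 rows and scatters dst rows through those ids directly. A scalar C++ sketch of the same mapping, with hypothetical helper names, for illustration only:

    #include <cstdint>
    #include <vector>

    struct mmid_map {
        std::vector<std::vector<int32_t>> ids; // per expert: token*ne20 + slot entries
        std::vector<uint32_t>             tpe; // tokens per expert
    };

    static mmid_map build_mmid_map(const std::vector<int32_t> & routing, // [n_tokens*ne20]
                                   int n_tokens, int ne20, int n_expert) {
        mmid_map m;
        m.ids.resize(n_expert);
        m.tpe.assign(n_expert, 0);

        for (int t = 0; t < n_tokens; ++t) {
            for (int s = 0; s < ne20; ++s) {
                const int32_t e = routing[t*ne20 + s];
                if (e >= 0 && e < n_expert) {
                    m.ids[e].push_back(t*ne20 + s); // decoded later as idt = id/ne20, slot = id%ne20
                    m.tpe[e]++;
                }
            }
        }
        return m;
    }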
diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp
index df27501361f7f..a9a91ca585b2d 100644
--- a/ggml/src/ggml-opencl/ggml-opencl.cpp
+++ b/ggml/src/ggml-opencl/ggml-opencl.cpp
@@ -420,9 +420,9 @@ struct ggml_backend_opencl_context {
cl_kernel kernel_clamp;
cl_kernel kernel_geglu, kernel_reglu, kernel_swiglu, kernel_swiglu_oai, kernel_geglu_erf, kernel_geglu_quick,
kernel_geglu_f16, kernel_reglu_f16, kernel_swiglu_f16, kernel_geglu_erf_f16, kernel_geglu_quick_f16;
- cl_kernel kernel_norm;
+ cl_kernel kernel_norm, kernel_norm_mul_add;
cl_kernel kernel_rms_norm, kernel_rms_norm_mul;
- cl_kernel kernel_group_norm;
+ cl_kernel kernel_group_norm, kernel_group_norm_mul_add;
cl_kernel kernel_diag_mask_inf, kernel_diag_mask_inf_8;
cl_kernel kernel_soft_max, kernel_soft_max_4;
cl_kernel kernel_soft_max_f16, kernel_soft_max_4_f16;
@@ -1161,7 +1161,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
backend_ctx->program_norm =
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
- CL_CHECK((backend_ctx->kernel_norm = clCreateKernel(backend_ctx->program_norm, "kernel_norm", &err), err));
+ CL_CHECK((backend_ctx->kernel_norm = clCreateKernel(backend_ctx->program_norm, "kernel_norm", &err), err));
+ CL_CHECK((backend_ctx->kernel_norm_mul_add = clCreateKernel(backend_ctx->program_norm, "kernel_norm_mul_add", &err), err));
GGML_LOG_CONT(".");
}
@@ -1487,7 +1488,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
backend_ctx->program_group_norm =
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
- CL_CHECK((backend_ctx->kernel_group_norm = clCreateKernel(backend_ctx->program_group_norm, "kernel_group_norm", &err), err));
+ CL_CHECK((backend_ctx->kernel_group_norm = clCreateKernel(backend_ctx->program_group_norm, "kernel_group_norm", &err), err));
+ CL_CHECK((backend_ctx->kernel_group_norm_mul_add = clCreateKernel(backend_ctx->program_group_norm, "kernel_group_norm_mul_add", &err), err));
GGML_LOG_CONT(".");
}
@@ -2498,12 +2500,47 @@ static bool ggml_opencl_can_fuse(const struct ggml_cgraph * cgraph, int node_idx
if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
return false;
}
+ } else if (ops.size() == 3 && ops.begin()[0] == GGML_OP_NORM && ops.begin()[1] == GGML_OP_MUL && ops.begin()[2] == GGML_OP_ADD) {
+ const ggml_tensor *norm = cgraph->nodes[node_idx];
+ const ggml_tensor *mul = cgraph->nodes[node_idx+1];
+ const ggml_tensor *add = cgraph->nodes[node_idx+2];
+ const ggml_tensor *w = mul->src[0] == norm ? mul->src[1] : mul->src[0];
+ const ggml_tensor *b = add->src[0] == mul ? add->src[1] : add->src[0];
+
+ // norm fusion only supports F32
+ if (norm->src[0]->type != GGML_TYPE_F32 || w->type != GGML_TYPE_F32 || b->type != GGML_TYPE_F32) {
+ return false;
+ }
+
+ if (norm->src[0]->ne[0] % 4 != 0) {
+ return false;
+ }
+
+ if (!ggml_is_contiguous(norm->src[0]) || !ggml_is_contiguous(w) || !ggml_is_contiguous(b)) {
+ return false;
+ }
+ } else if (ops.size() == 3 && ops.begin()[0] == GGML_OP_GROUP_NORM && ops.begin()[1] == GGML_OP_MUL && ops.begin()[2] == GGML_OP_ADD) {
+ const ggml_tensor *gn = cgraph->nodes[node_idx];
+ const ggml_tensor *mul = cgraph->nodes[node_idx+1];
+ const ggml_tensor *add = cgraph->nodes[node_idx+2];
+ const ggml_tensor *w = mul->src[0] == gn ? mul->src[1] : mul->src[0];
+ const ggml_tensor *b = add->src[0] == mul ? add->src[1] : add->src[0];
+
+ if (gn->src[0]->type != GGML_TYPE_F32 || w->type != GGML_TYPE_F32 || b->type != GGML_TYPE_F32) {
+ return false;
+ }
+
+ if (!ggml_is_contiguous(gn->src[0]) || !ggml_is_contiguous(w) || !ggml_is_contiguous(b)) {
+ return false;
+ }
}
return true;
}
static void ggml_opencl_op_rms_norm_fused(ggml_backend_t backend, ggml_tensor * rms_norm_tensor, ggml_tensor * mul_tensor);
+static void ggml_opencl_op_norm_fused(ggml_backend_t backend, ggml_tensor * norm_tensor, ggml_tensor * mul_tensor, ggml_tensor * add_tensor);
+static void ggml_opencl_op_group_norm_fused(ggml_backend_t backend, ggml_tensor * gn_tensor, ggml_tensor * mul_tensor, ggml_tensor * add_tensor);
static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
@@ -2520,6 +2557,16 @@ static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggm
continue;
}
+ if (!backend_ctx->disable_fusion && ggml_opencl_can_fuse(cgraph, i, { GGML_OP_NORM, GGML_OP_MUL, GGML_OP_ADD })) {
+ ggml_opencl_op_norm_fused(backend, node, cgraph->nodes[i+1], cgraph->nodes[i+2]);
+ i += 2;
+ continue;
+ }
+ if (!backend_ctx->disable_fusion && ggml_opencl_can_fuse(cgraph, i, { GGML_OP_GROUP_NORM, GGML_OP_MUL, GGML_OP_ADD })) {
+ ggml_opencl_op_group_norm_fused(backend, node, cgraph->nodes[i+1], cgraph->nodes[i+2]);
+ i += 2;
+ continue;
+ }
if (!backend_ctx->disable_fusion && ggml_opencl_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
ggml_opencl_op_rms_norm_fused(backend, node, cgraph->nodes[i+1]);
i++;
@@ -2647,8 +2694,9 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_SOFT_MAX:
case GGML_OP_NORM:
- case GGML_OP_RMS_NORM:
return true;
+ case GGML_OP_RMS_NORM:
+ return op->ne[0] % 4 == 0 && ggml_is_contiguous_rows(op->src[0]);
case GGML_OP_REPEAT:
return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; // Assuming F32 for now, can be expanded
case GGML_OP_PAD:
@@ -2728,10 +2776,6 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]);
case GGML_OP_FLASH_ATTN_EXT:
{
- if (op->src[4]) {
- return false;
- }
-
const ggml_tensor * q = op->src[0];
const ggml_tensor * k = op->src[1];
const ggml_tensor * v = op->src[2];
@@ -5038,6 +5082,140 @@ static void ggml_opencl_op_rms_norm_fused(ggml_backend_t backend, ggml_tensor *
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
}
+static void ggml_opencl_op_norm_fused(ggml_backend_t backend, ggml_tensor * norm_tensor, ggml_tensor * mul_tensor, ggml_tensor * add_tensor) {
+ GGML_ASSERT(norm_tensor && mul_tensor && add_tensor);
+
+ const ggml_tensor * src0 = norm_tensor->src[0];
+ const ggml_tensor * src1 = mul_tensor->src[0] == norm_tensor ? mul_tensor->src[1] : mul_tensor->src[0];
+ const ggml_tensor * src2 = add_tensor->src[0] == mul_tensor ? add_tensor->src[1] : add_tensor->src[0];
+ const ggml_tensor * dst = add_tensor;
+
+ ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
+ ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
+ ggml_tensor_extra_cl * extra2 = (ggml_tensor_extra_cl *)src2->extra;
+ ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+
+ cl_ulong offset0 = extra0->offset + src0->view_offs;
+ cl_ulong offset1 = extra1->offset + src1->view_offs;
+ cl_ulong offset2 = extra2->offset + src2->view_offs;
+ cl_ulong offsetd = extrad->offset + dst->view_offs;
+
+ ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+
+ float eps;
+ memcpy(&eps, norm_tensor->op_params, sizeof(float));
+
+ const int ne00 = src0->ne[0], ne01 = src0->ne[1], ne02 = src0->ne[2], ne03 = src0->ne[3];
+ const cl_ulong nb01 = src0->nb[1], nb02 = src0->nb[2], nb03 = src0->nb[3];
+ const int ne10 = src1->ne[0], ne11 = src1->ne[1], ne12 = src1->ne[2], ne13 = src1->ne[3];
+ const cl_ulong nb11 = src1->nb[1], nb12 = src1->nb[2], nb13 = src1->nb[3];
+ const int ne20 = src2->ne[0], ne21 = src2->ne[1], ne22 = src2->ne[2], ne23 = src2->ne[3];
+ const cl_ulong nb21 = src2->nb[1], nb22 = src2->nb[2], nb23 = src2->nb[3];
+ const cl_ulong nbd1 = dst->nb[1], nbd2 = dst->nb[2], nbd3 = dst->nb[3];
+
+ size_t sgs;
+ if (backend_ctx->gpu_family == ADRENO) sgs = 64;
+ else if (backend_ctx->gpu_family == INTEL) sgs = 32;
+ else GGML_ASSERT(false && "Unsupported GPU");
+
+ cl_kernel kernel = backend_ctx->kernel_norm_mul_add;
+
+ int nth = sgs;
+ int max_workgroup_size = backend_ctx->get_kernel_workgroup_size(kernel);
+ while (nth < ne00/4 && nth < max_workgroup_size) nth *= 2;
+ nth = MIN(nth, max_workgroup_size);
+ nth = MIN(nth, ne00/4);
+
+ size_t gws[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
+ size_t lws[] = {(size_t)nth, 1, 1};
+ size_t num_subgroups = (nth + sgs - 1) / sgs;
+
+ CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
+ CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
+ CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
+ CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
+ CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra2->data_device));
+ CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset2));
+ CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device));
+ CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd));
+ CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00));
+ CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne01));
+ CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne02));
+ CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne03));
+ CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb01));
+ CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb02));
+ CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb03));
+ CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne10));
+ CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne11));
+ CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &ne12));
+ CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &ne13));
+ CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb11));
+ CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb12));
+ CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb13));
+ CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &ne20));
+ CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int), &ne21));
+ CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int), &ne22));
+ CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int), &ne23));
+ CL_CHECK(clSetKernelArg(kernel, 26, sizeof(cl_ulong), &nb21));
+ CL_CHECK(clSetKernelArg(kernel, 27, sizeof(cl_ulong), &nb22));
+ CL_CHECK(clSetKernelArg(kernel, 28, sizeof(cl_ulong), &nb23));
+ CL_CHECK(clSetKernelArg(kernel, 29, sizeof(cl_ulong), &nbd1));
+ CL_CHECK(clSetKernelArg(kernel, 30, sizeof(cl_ulong), &nbd2));
+ CL_CHECK(clSetKernelArg(kernel, 31, sizeof(cl_ulong), &nbd3));
+ CL_CHECK(clSetKernelArg(kernel, 32, sizeof(float), &eps));
+ CL_CHECK(clSetKernelArg(kernel, 33, sizeof(cl_float2) * num_subgroups, NULL));
+
+ backend_ctx->enqueue_ndrange_kernel(kernel, 3, gws, lws, dst);
+}
+
+static void ggml_opencl_op_group_norm_fused(ggml_backend_t backend, ggml_tensor * gn_tensor, ggml_tensor * mul_tensor, ggml_tensor * add_tensor) {
+ GGML_ASSERT(gn_tensor && mul_tensor && add_tensor);
+
+ const ggml_tensor * src0 = gn_tensor->src[0];
+ const ggml_tensor * src1 = mul_tensor->src[0] == gn_tensor ? mul_tensor->src[1] : mul_tensor->src[0];
+ const ggml_tensor * src2 = add_tensor->src[0] == mul_tensor ? add_tensor->src[1] : add_tensor->src[0];
+ const ggml_tensor * dst = add_tensor;
+
+ ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
+ ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
+ ggml_tensor_extra_cl * extra2 = (ggml_tensor_extra_cl *)src2->extra;
+ ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+
+ cl_ulong offset0 = extra0->offset + src0->view_offs;
+ cl_ulong offset1 = extra1->offset + src1->view_offs;
+ cl_ulong offset2 = extra2->offset + src2->view_offs;
+ cl_ulong offsetd = extrad->offset + dst->view_offs;
+
+ ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+
+ int groups;
+ float eps;
+ memcpy(&groups, gn_tensor->op_params, sizeof(int));
+ memcpy(&eps, (char *)gn_tensor->op_params + sizeof(int), sizeof(float));
+
+ cl_kernel kernel = backend_ctx->kernel_group_norm_mul_add;
+ int max_workgroup_size = backend_ctx->get_kernel_workgroup_size(kernel);
+ int ne = ggml_nelements(src0);
+ int group_size = ne / groups;
+
+ size_t lws[] = { (size_t)MIN(max_workgroup_size, group_size) };
+ size_t gws[] = { (size_t)groups * lws[0] };
+
+ CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
+ CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
+ CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
+ CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
+ CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra2->data_device));
+ CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset2));
+ CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device));
+ CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd));
+ CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne));
+ CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &group_size));
+ CL_CHECK(clSetKernelArg(kernel, 10, sizeof(float), &eps));
+
+ backend_ctx->enqueue_ndrange_kernel(kernel, 1, gws, lws, dst);
+}
+
static void ggml_cl_group_norm(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(src0);
GGML_ASSERT(src0->extra);
@@ -5583,6 +5761,7 @@ static void ggml_cl_timestep_embedding(ggml_backend_t backend, const ggml_tensor
static void ggml_cl_flash_attn(ggml_backend_t backend, const ggml_tensor * q, const ggml_tensor * k, ggml_tensor * dst) {
const ggml_tensor * v = dst->src[2];
const ggml_tensor * mask = dst->src[3];
+ const ggml_tensor * sinks = dst->src[4];
GGML_ASSERT(q->extra);
GGML_ASSERT(k->extra);
GGML_ASSERT(v->extra);
@@ -5590,6 +5769,9 @@ static void ggml_cl_flash_attn(ggml_backend_t backend, const ggml_tensor * q, co
if (mask) {
GGML_ASSERT(mask->extra);
}
+ if (sinks) {
+ GGML_ASSERT(sinks->extra);
+ }
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
@@ -5631,6 +5813,7 @@ static void ggml_cl_flash_attn(ggml_backend_t backend, const ggml_tensor * q, co
ggml_tensor_extra_cl * extra_v = (ggml_tensor_extra_cl *)v->extra;
ggml_tensor_extra_cl * extra_o = (ggml_tensor_extra_cl *)dst->extra;
ggml_tensor_extra_cl * extra_mask = mask ? (ggml_tensor_extra_cl *)mask->extra : NULL;
+ ggml_tensor_extra_cl * extra_sinks = sinks ? (ggml_tensor_extra_cl *)sinks->extra : NULL;
cl_ulong offset_q = extra_q->offset + q->view_offs;
cl_ulong offset_k = extra_k->offset + k->view_offs;
@@ -5638,6 +5821,8 @@ static void ggml_cl_flash_attn(ggml_backend_t backend, const ggml_tensor * q, co
cl_ulong offset_o = extra_o->offset + dst->view_offs;
cl_mem mask_buffer = extra_mask ? extra_mask->data_device : NULL;
cl_ulong offset_mask = extra_mask ? extra_mask->offset + mask->view_offs : 0;
+ cl_mem sinks_buffer = extra_sinks ? extra_sinks->data_device : NULL;
+ cl_ulong offset_sinks = extra_sinks ? extra_sinks->offset + sinks->view_offs : 0;
const cl_ulong q_nb1 = q->nb[1], q_nb2 = q->nb[2], q_nb3 = q->nb[3];
const cl_ulong k_nb1 = k->nb[1], k_nb2 = k->nb[2], k_nb3 = k->nb[3];
@@ -5692,6 +5877,8 @@ static void ggml_cl_flash_attn(ggml_backend_t backend, const ggml_tensor * q, co
CL_CHECK(clSetKernelArg(kernel, 35, sizeof(cl_ulong), &mask_nb3));
CL_CHECK(clSetKernelArg(kernel, 36, sizeof(int), &mask_ne2));
CL_CHECK(clSetKernelArg(kernel, 37, sizeof(int), &mask_ne3));
+ CL_CHECK(clSetKernelArg(kernel, 38, sizeof(cl_mem), &sinks_buffer));
+ CL_CHECK(clSetKernelArg(kernel, 39, sizeof(cl_ulong), &offset_sinks));
if (n_q == 1) {
const size_t wg_size = 64;
diff --git a/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl b/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl
index fea06867e1020..8f43c4f27d58c 100644
--- a/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl
+++ b/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl
@@ -49,7 +49,9 @@ __kernel void flash_attn_f16(
const ulong mask_nb2,
const ulong mask_nb3,
const int mask_ne2,
- const int mask_ne3
+ const int mask_ne3,
+ const global void* sinks_void,
+ const ulong sinks_offset
) {
const int tid = get_local_id(0);
const int block_q_idx = get_group_id(0);
@@ -171,6 +173,20 @@ __kernel void flash_attn_f16(
}
if (my_query_row < n_q) {
+ if (sinks_void != NULL) {
+ const global ACC_TYPE* sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
+ const ACC_TYPE m_sink = sinks_ptr[head_idx];
+ const ACC_TYPE m_final = max(m_i, m_sink);
+
+ const ACC_TYPE scale_o = exp(m_i - m_final);
+ #pragma unroll
+ for (int i = 0; i < DV_VEC; ++i) {
+ o_acc[i] *= scale_o;
+ }
+
+ l_i = l_i * exp(m_i - m_final) + exp(m_sink - m_final);
+ }
+
const ulong o_row_offset = batch_idx * o_nb3 + my_query_row * o_nb2 + head_idx * o_nb1;
global DATA_TYPE4 *o_row = (global DATA_TYPE4 *)(o_base + o_row_offset);
if (l_i > 0.0f) {
@@ -214,7 +230,9 @@ __kernel void flash_attn_f16_q1(
const ulong mask_nb2,
const ulong mask_nb3,
const int mask_ne2,
- const int mask_ne3
+ const int mask_ne3,
+ const global void* sinks_void,
+ const ulong sinks_offset
) {
const int tid = get_local_id(0);
const int head_batch_idx = get_global_id(1);
@@ -247,7 +265,12 @@ __kernel void flash_attn_f16_q1(
float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
- ACC_TYPE m_i = -INFINITY;
+ const global ACC_TYPE* sinks_ptr = NULL;
+ if (sinks_void != NULL) {
+ sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
+ }
+
+ ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : -INFINITY;
for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) {
const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset);
@@ -320,7 +343,11 @@ __kernel void flash_attn_f16_q1(
const ulong o_row_offset = batch_idx * o_nb3 + head_idx * o_nb1;
global DATA_TYPE4 *o_row = (global DATA_TYPE4 *)(o_base + o_row_offset);
- const ACC_TYPE l_final = local_l[0];
+ ACC_TYPE l_final = local_l[0];
+
+ if (sinks_ptr != NULL) {
+ l_final += exp(sinks_ptr[head_idx] - m_final);
+ }
if (l_final > 0.0f) {
const ACC_TYPE l_inv = 1.0f / l_final;
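The sink handling added above treats the per-head sink value as one extra logit: it joins the running maximum and contributes exp(sink - m) to the softmax denominator, while adding nothing to the output accumulator. A small standalone C++ sketch of the resulting normalization (illustration only, not part of the kernels):

    #include <algorithm>
    #include <cmath>
    #include <cstddef>
    #include <vector>

    static std::vector<float> softmax_with_sink(const std::vector<float> & scores, float sink) {
        float m = sink; // the sink participates in the running maximum
        for (float s : scores) {
            m = std::max(m, s);
        }

        float l = std::exp(sink - m); // the sink adds to the denominator only
        std::vector<float> p(scores.size());
        for (size_t i = 0; i < scores.size(); ++i) {
            p[i] = std::exp(scores[i] - m);
            l += p[i];
        }

        for (float & v : p) {
            v /= l; // the sink's share of probability mass is simply discarded
        }
        return p;
    }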
diff --git a/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl b/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl
index 2d657327d6460..9c0bab135a912 100644
--- a/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl
+++ b/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl
@@ -49,7 +49,9 @@ __kernel void flash_attn_f32(
const ulong mask_nb2,
const ulong mask_nb3,
const int mask_ne2,
- const int mask_ne3
+ const int mask_ne3,
+ const global void* sinks_void,
+ const ulong sinks_offset
) {
const int tid = get_local_id(0);
const int block_q_idx = get_group_id(0);
@@ -171,6 +173,20 @@ __kernel void flash_attn_f32(
}
if (my_query_row < n_q) {
+ if (sinks_void != NULL) {
+ const global ACC_TYPE* sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
+ const ACC_TYPE m_sink = sinks_ptr[head_idx];
+ const ACC_TYPE m_final = max(m_i, m_sink);
+
+ const ACC_TYPE scale_o = exp(m_i - m_final);
+ #pragma unroll
+ for (int i = 0; i < DV_VEC; ++i) {
+ o_acc[i] *= scale_o;
+ }
+
+ l_i = l_i * exp(m_i - m_final) + exp(m_sink - m_final);
+ }
+
const ulong o_row_offset = batch_idx * o_nb3 + my_query_row * o_nb2 + head_idx * o_nb1;
global DATA_TYPE4 *o_row = (global DATA_TYPE4 *)(o_base + o_row_offset);
if (l_i > 0.0f) {
@@ -214,7 +230,9 @@ __kernel void flash_attn_f32_q1(
const ulong mask_nb2,
const ulong mask_nb3,
const int mask_ne2,
- const int mask_ne3
+ const int mask_ne3,
+ const global void* sinks_void,
+ const ulong sinks_offset
) {
const int tid = get_local_id(0);
const int head_batch_idx = get_global_id(1);
@@ -247,7 +265,12 @@ __kernel void flash_attn_f32_q1(
float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
- ACC_TYPE m_i = -INFINITY;
+ const global ACC_TYPE* sinks_ptr = NULL;
+ if (sinks_void != NULL) {
+ sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
+ }
+
+ ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : -INFINITY;
for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) {
const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset);
@@ -320,7 +343,11 @@ __kernel void flash_attn_f32_q1(
const ulong o_row_offset = batch_idx * o_nb3 + head_idx * o_nb1;
global DATA_TYPE4 *o_row = (global DATA_TYPE4 *)(o_base + o_row_offset);
- const ACC_TYPE l_final = local_l[0];
+ ACC_TYPE l_final = local_l[0];
+
+ if (sinks_ptr != NULL) {
+ l_final += exp(sinks_ptr[head_idx] - m_final);
+ }
if (l_final > 0.0f) {
const ACC_TYPE l_inv = 1.0f / l_final;
diff --git a/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl b/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl
index 7067bd2591fa7..ec7361b9e3709 100644
--- a/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl
+++ b/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl
@@ -52,7 +52,9 @@ __kernel void flash_attn_f32_f16(
const ulong mask_nb2,
const ulong mask_nb3,
const int mask_ne2,
- const int mask_ne3
+ const int mask_ne3,
+ const global void* sinks_void,
+ const ulong sinks_offset
) {
const int tid = get_local_id(0);
const int block_q_idx = get_group_id(0);
@@ -174,6 +176,20 @@ __kernel void flash_attn_f32_f16(
}
if (my_query_row < n_q) {
+ if (sinks_void != NULL) {
+ const global ACC_TYPE* sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
+ const ACC_TYPE m_sink = sinks_ptr[head_idx];
+ const ACC_TYPE m_final = max(m_i, m_sink);
+
+ const ACC_TYPE scale_o = exp(m_i - m_final);
+ #pragma unroll
+ for (int i = 0; i < DV_VEC; ++i) {
+ o_acc[i] *= scale_o;
+ }
+
+ l_i = l_i * exp(m_i - m_final) + exp(m_sink - m_final);
+ }
+
const ulong o_row_offset = batch_idx * o_nb3 + my_query_row * o_nb2 + head_idx * o_nb1;
global O_DATA_TYPE4 *o_row = (global O_DATA_TYPE4 *)(o_base + o_row_offset);
if (l_i > 0.0f) {
@@ -217,7 +233,9 @@ __kernel void flash_attn_f32_f16_q1(
const ulong mask_nb2,
const ulong mask_nb3,
const int mask_ne2,
- const int mask_ne3
+ const int mask_ne3,
+ const global void* sinks_void,
+ const ulong sinks_offset
) {
const int tid = get_local_id(0);
const int head_batch_idx = get_global_id(1);
@@ -250,7 +268,12 @@ __kernel void flash_attn_f32_f16_q1(
float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
- ACC_TYPE m_i = -INFINITY;
+ const global ACC_TYPE* sinks_ptr = NULL;
+ if (sinks_void != NULL) {
+ sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
+ }
+
+ ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : -INFINITY;
for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) {
const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
const global KV_DATA_TYPE4* k_ptr = (const global KV_DATA_TYPE4*)(k_base + k_row_offset);
@@ -323,7 +346,11 @@ __kernel void flash_attn_f32_f16_q1(
const ulong o_row_offset = batch_idx * o_nb3 + head_idx * o_nb1;
global O_DATA_TYPE4 *o_row = (global O_DATA_TYPE4 *)(o_base + o_row_offset);
- const ACC_TYPE l_final = local_l[0];
+ ACC_TYPE l_final = local_l[0];
+
+ if (sinks_ptr != NULL) {
+ l_final += exp(sinks_ptr[head_idx] - m_final);
+ }
if (l_final > 0.0f) {
const ACC_TYPE l_inv = 1.0f / l_final;
diff --git a/ggml/src/ggml-opencl/kernels/group_norm.cl b/ggml/src/ggml-opencl/kernels/group_norm.cl
index 57c9df4d35b09..8e4fa0ed12d11 100644
--- a/ggml/src/ggml-opencl/kernels/group_norm.cl
+++ b/ggml/src/ggml-opencl/kernels/group_norm.cl
@@ -70,3 +70,52 @@ kernel void kernel_group_norm(
dst[j] *= scale;
}
}
+
+//------------------------------------------------------------------------------
+// group_norm_mul_add
+//------------------------------------------------------------------------------
+#ifdef INTEL_GPU
+REQD_SUBGROUP_SIZE_32
+#elif defined (ADRENO_GPU)
+REQD_SUBGROUP_SIZE_64
+#endif
+kernel void kernel_group_norm_mul_add(
+ global float * src0, ulong offset0,
+ global float * src1, ulong offset1,
+ global float * src2, ulong offset2,
+ global float * dst, ulong offsetd,
+ int ne,
+ int group_size,
+ float eps
+) {
+ src0 = (global float *)((global char *)src0 + offset0);
+ src1 = (global float *)((global char *)src1 + offset1);
+ src2 = (global float *)((global char *)src2 + offset2);
+ dst = (global float *)((global char *)dst + offsetd);
+
+ int start = get_group_id(0) * group_size;
+ int end = start + group_size;
+ if (end > ne) {
+ end = ne;
+ }
+
+ float sum = 0.0f;
+ float sum_sq = 0.0f;
+
+ for (int j = start + get_local_id(0); j < end; j += get_local_size(0)) {
+ float val = src0[j];
+ sum += val;
+ sum_sq += val*val;
+ }
+
+ sum = sub_group_reduce_add(sum);
+ sum_sq = sub_group_reduce_add(sum_sq);
+
+ const float mean = sum / group_size;
+ const float var = sum_sq / group_size - mean * mean;
+ const float scale = rsqrt(var + eps);
+
+ for (int j = start + get_local_id(0); j < end; j += get_local_size(0)) {
+ dst[j] = ((src0[j] - mean) * scale) * src1[j] + src2[j];
+ }
+}
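A scalar reference for the fused kernel above, mirroring its flat per-element indexing of the weight and bias (a sketch for illustration; the real kernel splits each group across a workgroup and uses sub_group_reduce_add):

    #include <algorithm>
    #include <cmath>

    static void group_norm_mul_add_ref(const float * x, const float * w, const float * b,
                                       float * y, int ne, int group_size, float eps) {
        for (int start = 0; start < ne; start += group_size) {
            const int end = std::min(start + group_size, ne);

            float sum = 0.0f, sum_sq = 0.0f;
            for (int j = start; j < end; ++j) {
                sum    += x[j];
                sum_sq += x[j]*x[j];
            }

            const float mean  = sum / group_size;
            const float var   = sum_sq / group_size - mean*mean;
            const float scale = 1.0f / std::sqrt(var + eps);

            for (int j = start; j < end; ++j) {
                y[j] = (x[j] - mean)*scale * w[j] + b[j];
            }
        }
    }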
diff --git a/ggml/src/ggml-opencl/kernels/norm.cl b/ggml/src/ggml-opencl/kernels/norm.cl
index 43167ba4d2212..170f822787be2 100644
--- a/ggml/src/ggml-opencl/kernels/norm.cl
+++ b/ggml/src/ggml-opencl/kernels/norm.cl
@@ -79,3 +79,83 @@ kernel void kernel_norm(
y[i00] = y[i00] * scale;
}
}
+
+//------------------------------------------------------------------------------
+// norm_mul_add
+//------------------------------------------------------------------------------
+#ifdef INTEL_GPU
+REQD_SUBGROUP_SIZE_32
+#elif defined (ADRENO_GPU)
+REQD_SUBGROUP_SIZE_64
+#endif
+kernel void kernel_norm_mul_add(
+ global char * src0_ptr, ulong src0_offset,
+ global char * src1_ptr, ulong src1_offset,
+ global char * src2_ptr, ulong src2_offset,
+ global char * dst_ptr, ulong dst_offset,
+ int ne00, int ne01, int ne02, int ne03,
+ ulong nb01, ulong nb02, ulong nb03,
+ int ne10, int ne11, int ne12, int ne13,
+ ulong nb11, ulong nb12, ulong nb13,
+ int ne20, int ne21, int ne22, int ne23,
+ ulong nb21, ulong nb22, ulong nb23,
+ ulong nbd1, ulong nbd2, ulong nbd3,
+ float eps,
+ local float2 * sums
+) {
+ const int i03 = get_group_id(2);
+ const int i02 = get_group_id(1);
+ const int i01 = get_group_id(0);
+
+ global float4 * x = (global float4 *)(src0_ptr + src0_offset + i01*nb01 + i02*nb02 + i03*nb03);
+ global float4 * w = (global float4 *)(src1_ptr + src1_offset + (i01%ne11)*nb11 + (i02%ne12)*nb12 + (i03%ne13)*nb13);
+ global float4 * b = (global float4 *)(src2_ptr + src2_offset + (i01%ne21)*nb21 + (i02%ne22)*nb22 + (i03%ne23)*nb23);
+ global float4 * y = (global float4 *)(dst_ptr + dst_offset + i01*nbd1 + i02*nbd2 + i03*nbd3);
+
+ float p_sum = 0.0f;
+ float p_sum_sq = 0.0f;
+
+ const int n_chunks = ne00 / 4;
+ for (int i00 = get_local_id(0); i00 < n_chunks; i00 += get_local_size(0)) {
+ float4 val = x[i00];
+ p_sum += val.x + val.y + val.z + val.w;
+ p_sum_sq += dot(val, val);
+ }
+
+ p_sum = sub_group_reduce_add(p_sum);
+ p_sum_sq = sub_group_reduce_add(p_sum_sq);
+
+ if (get_sub_group_local_id() == 0) {
+ sums[get_sub_group_id()] = (float2)(p_sum, p_sum_sq);
+ }
+ barrier(CLK_LOCAL_MEM_FENCE);
+
+ if (get_local_id(0) == 0) {
+ float sum = 0.0f;
+ float sum_sq = 0.0f;
+ for (uint i = 0; i < get_num_sub_groups(); ++i) {
+ float2 s = sums[i];
+ sum += s.x;
+ sum_sq += s.y;
+ }
+
+ const float inv_ne00 = 1.0f / (float)ne00;
+ const float mean = sum * inv_ne00;
+ const float variance = mad(-mean, mean, sum_sq * inv_ne00);
+
+ sums[0] = (float2)(mean, rsqrt(variance + eps));
+ }
+ barrier(CLK_LOCAL_MEM_FENCE);
+
+ const float2 mean_scale = sums[0];
+ const float mean = mean_scale.x;
+ const float scale = mean_scale.y;
+ const float neg_mean_scale = -mean * scale;
+
+ for (int i00 = get_local_id(0); i00 < n_chunks; i00 += get_local_size(0)) {
+ const int w_idx = ne10 > 1 ? i00 : 0;
+ const int b_idx = ne20 > 1 ? i00 : 0;
+ const float4 norm_x = mad(x[i00], (float4)scale, (float4)neg_mean_scale);
+ y[i00] = mad(norm_x, w[w_idx], b[b_idx]);
+ }
+}
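kernel_norm_mul_add above uses a two-level reduction: each subgroup reduces its own (sum, sum of squares), lane 0 of each subgroup writes the pair to local memory, and a single thread folds the partials into (mean, rsqrt(var + eps)) that is then broadcast through sums[0]. A plain C++ stand-in where "subgroups" are just chunks of the input, for illustration only:

    #include <algorithm>
    #include <cmath>
    #include <cstddef>
    #include <utility>
    #include <vector>

    static std::pair<float, float> norm_stats_two_level(const std::vector<float> & x,
                                                        size_t subgroup_size, float eps) {
        std::vector<std::pair<float, float>> partials; // (sum, sum_sq) per "subgroup"

        for (size_t i = 0; i < x.size(); i += subgroup_size) {
            float sum = 0.0f, sum_sq = 0.0f;
            for (size_t j = i; j < std::min(i + subgroup_size, x.size()); ++j) {
                sum    += x[j];
                sum_sq += x[j]*x[j];
            }
            partials.emplace_back(sum, sum_sq); // sub_group_reduce_add + store to local memory
        }

        float sum = 0.0f, sum_sq = 0.0f;
        for (const auto & p : partials) { // done by local id 0 in the kernel
            sum    += p.first;
            sum_sq += p.second;
        }

        const float mean = sum / x.size();
        const float var  = sum_sq / x.size() - mean*mean;
        return { mean, 1.0f / std::sqrt(var + eps) }; // broadcast via sums[0] in the kernel
    }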
diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp
index 12dd5dd2e6287..18ff4e0b0c7cf 100644
--- a/ggml/src/ggml-sycl/ggml-sycl.cpp
+++ b/ggml/src/ggml-sycl/ggml-sycl.cpp
@@ -4364,11 +4364,12 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
return (op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32) && (op->type == op->src[0]->type);
#endif
case GGML_OP_NORM:
- case GGML_OP_RMS_NORM:
return true;
case GGML_OP_L2_NORM:
case GGML_OP_GROUP_NORM:
return ggml_is_contiguous(op->src[0]);
+ case GGML_OP_RMS_NORM:
+ return ((op->src[0]->ne[0] % WARP_SIZE) == 0);
case GGML_OP_SCALE:
return true;
case GGML_OP_CONT:
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index a5406f761274d..7189fc1cfa057 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -360,6 +360,13 @@ struct vk_fa_pipeline_state {
}
};
+enum shader_reduction_mode {
+ SHADER_REDUCTION_MODE_SHMEM,
+ SHADER_REDUCTION_MODE_HYBRID,
+ SHADER_REDUCTION_MODE_SUBGROUP,
+ SHADER_REDUCTION_MODE_COUNT,
+};
+
static constexpr uint32_t num_argsort_pipelines = 11;
static constexpr uint32_t max_argsort_cols = 1 << (num_argsort_pipelines-1);
@@ -386,14 +393,18 @@ struct vk_device_struct {
bool uma;
bool prefer_host_memory;
bool float_controls_rte_fp16;
- bool subgroup_add;
+ bool subgroup_arithmetic;
bool subgroup_shuffle;
+ bool subgroup_ballot;
+ bool subgroup_clustered;
bool multi_add;
bool add_rms_fusion;
uint32_t partials_binding_alignment;
bool integer_dot_product;
+ // 0: default, 1: force mmvq, -1: disable mmvq
+ int32_t mmvq_mode;
bool subgroup_size_control;
uint32_t subgroup_min_size;
@@ -451,12 +462,15 @@ struct vk_device_struct {
vk_pipeline pipeline_matmul_split_k_reduce;
vk_pipeline pipeline_quantize_q8_1;
+ vk_pipeline pipeline_quantize_q8_1_x4;
vk_pipeline pipeline_dequant[GGML_TYPE_COUNT];
vk_pipeline pipeline_dequant_mul_mat_vec_f32_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT][mul_mat_vec_max_cols];
vk_pipeline pipeline_dequant_mul_mat_vec_f16_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT][mul_mat_vec_max_cols];
vk_pipeline pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_COUNT];
+ vk_pipeline pipeline_dequant_mul_mat_vec_q8_1_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT][mul_mat_vec_max_cols];
+
vk_pipeline pipeline_mul_mat_vec_p021_f16_f32[p021_max_gqa_ratio];
vk_pipeline pipeline_mul_mat_vec_nc_f16_f32;
vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT];
@@ -565,6 +579,7 @@ struct vk_device_struct {
bool disable_fusion;
bool disable_host_visible_vidmem;
+ bool allow_sysmem_fallback;
#ifdef GGML_VULKAN_MEMORY_DEBUG
std::unique_ptr<vk_memory_logger> memory_logger;
@@ -1044,7 +1059,7 @@ struct vk_op_sum_rows_push_constants
uint32_t ne0_1mp, ne0_1L;
};
-vk_op_sum_rows_push_constants vk_op_sum_rows_push_constants_init(const ggml_tensor * src, const ggml_tensor * dst, int64_t n_cols) {
+static vk_op_sum_rows_push_constants vk_op_sum_rows_push_constants_init(const ggml_tensor * src, const ggml_tensor * dst, int64_t n_cols) {
uint32_t type_size = (uint32_t)ggml_type_size(src->type);
vk_op_sum_rows_push_constants p = {};
p.n_cols = (uint32_t)n_cols;
@@ -1355,6 +1370,7 @@ struct vk_instance_t {
PFN_vkCmdInsertDebugUtilsLabelEXT pfn_vkCmdInsertDebugUtilsLabelEXT = {};
std::vector<size_t> device_indices;
+ std::vector<bool> device_supports_membudget;
vk_device devices[GGML_VK_MAX_DEVICES];
};
@@ -1807,8 +1823,8 @@ static uint32_t find_properties(const vk::PhysicalDeviceMemoryProperties* mem_pr
return UINT32_MAX;
}
-static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, vk::MemoryPropertyFlags req_flags, vk::MemoryPropertyFlags fallback_flags = vk::MemoryPropertyFlags(0)) {
- VK_LOG_DEBUG("ggml_vk_create_buffer(" << device->name << ", " << size << ", " << to_string(req_flags) << ", " << to_string(fallback_flags) << ")");
+static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std::initializer_list<vk::MemoryPropertyFlags> & req_flags_list) {
+ VK_LOG_DEBUG("ggml_vk_create_buffer(" << device->name << ", " << size << ", " << to_string(req_flags_list.begin()[0]) << ", " << to_string(req_flags_list.begin()[req_flags_list.size()-1]) << ")");
if (size > device->max_memory_allocation_size) {
throw vk::OutOfDeviceMemoryError("Requested buffer size exceeds device memory allocation limit");
}
@@ -1835,42 +1851,27 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, vk::Memor
vk::PhysicalDeviceMemoryProperties mem_props = device->physical_device.getMemoryProperties();
- uint32_t memory_type_index = UINT32_MAX;
+ for (auto &req_flags : req_flags_list) {
+ uint32_t memory_type_index = find_properties(&mem_props, &mem_req, req_flags);
- memory_type_index = find_properties(&mem_props, &mem_req, req_flags);
- buf->memory_property_flags = req_flags;
+ if (memory_type_index == UINT32_MAX) {
+ continue;
+ }
+ buf->memory_property_flags = req_flags;
- if (memory_type_index == UINT32_MAX && fallback_flags) {
- memory_type_index = find_properties(&mem_props, &mem_req, fallback_flags);
- buf->memory_property_flags = fallback_flags;
+ try {
+ buf->device_memory = device->device.allocateMemory({ mem_req.size, memory_type_index });
+ break;
+ } catch (const vk::SystemError& e) {
+ // loop and retry
+ }
}
- if (memory_type_index == UINT32_MAX) {
+ if (buf->device_memory == VK_NULL_HANDLE) {
device->device.destroyBuffer(buf->buffer);
throw vk::OutOfDeviceMemoryError("No suitable memory type found");
}
- try {
- buf->device_memory = device->device.allocateMemory({ mem_req.size, memory_type_index });
- } catch (const vk::SystemError& e) {
- if (buf->memory_property_flags != fallback_flags) {
- // Try again with fallback flags
- memory_type_index = find_properties(&mem_props, &mem_req, fallback_flags);
- buf->memory_property_flags = fallback_flags;
-
- try {
- buf->device_memory = device->device.allocateMemory({ mem_req.size, memory_type_index });
- }
- catch (const vk::SystemError& e) {
- device->device.destroyBuffer(buf->buffer);
- throw e;
- }
- } else {
- // Out of Host/Device memory, clean up buffer
- device->device.destroyBuffer(buf->buffer);
- throw e;
- }
- }
buf->ptr = nullptr;
if (buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) {
@@ -1891,7 +1892,7 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, vk::Memor
static vk_buffer ggml_vk_create_buffer_check(vk_device& device, size_t size, vk::MemoryPropertyFlags req_flags, vk::MemoryPropertyFlags fallback_flags = vk::MemoryPropertyFlags(0)) {
try {
- return ggml_vk_create_buffer(device, size, req_flags, fallback_flags);
+ return ggml_vk_create_buffer(device, size, {req_flags, fallback_flags});
} catch (const vk::SystemError& e) {
std::cerr << "ggml_vulkan: Memory allocation of size " << size << " failed." << std::endl;
std::cerr << "ggml_vulkan: " << e.what() << std::endl;
@@ -1903,15 +1904,29 @@ static vk_buffer ggml_vk_create_buffer_device(vk_device& device, size_t size) {
vk_buffer buf;
try {
if (device->prefer_host_memory) {
- buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent, vk::MemoryPropertyFlagBits::eDeviceLocal);
+ buf = ggml_vk_create_buffer(device, size, {vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent,
+ vk::MemoryPropertyFlagBits::eDeviceLocal});
} else if (device->uma) {
// Fall back to host memory type
- buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eDeviceLocal, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
+ buf = ggml_vk_create_buffer(device, size, {vk::MemoryPropertyFlagBits::eDeviceLocal,
+ vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent});
} else if (device->disable_host_visible_vidmem) {
- buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eDeviceLocal, vk::MemoryPropertyFlagBits::eDeviceLocal);
+ if (device->allow_sysmem_fallback) {
+ buf = ggml_vk_create_buffer(device, size, {vk::MemoryPropertyFlagBits::eDeviceLocal,
+ vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent});
+ } else {
+ buf = ggml_vk_create_buffer(device, size, {vk::MemoryPropertyFlagBits::eDeviceLocal});
+ }
} else {
// use rebar if available, otherwise fallback to device only visible memory
- buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eDeviceLocal | vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent, vk::MemoryPropertyFlagBits::eDeviceLocal);
+ if (device->allow_sysmem_fallback) {
+ buf = ggml_vk_create_buffer(device, size, {vk::MemoryPropertyFlagBits::eDeviceLocal | vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent,
+ vk::MemoryPropertyFlagBits::eDeviceLocal,
+ vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent});
+ } else {
+ buf = ggml_vk_create_buffer(device, size, {vk::MemoryPropertyFlagBits::eDeviceLocal | vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent,
+ vk::MemoryPropertyFlagBits::eDeviceLocal});
+ }
}
} catch (const vk::SystemError& e) {
std::cerr << "ggml_vulkan: Device memory allocation of size " << size << " failed." << std::endl;
@@ -2089,10 +2104,11 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
const uint32_t warps = warptile[0] / warptile[10];
const uint32_t load_bufs = (warptile[1] + warptile[2]) * (warptile[3] + bank_conflict_offset) * type_size;
- const uint32_t mmid_row_ids = mul_mat_id ? (4096 * sizeof(uint32_t) + 4/*_ne1*/) : 0;
+ const uint32_t mmid_row_ids = mul_mat_id ? (warptile[2] * 2 * sizeof(uint16_t)) : 0;
const uint32_t coopmat_stage = device->coopmat_support ? warptile[7] * warptile[8] / warps * sizeof(float) : 0;
+ const uint32_t ballots_sh = mul_mat_id ? (warps * 4 * sizeof(uint32_t)) : 0;
- const uint32_t total_size = load_bufs + mmid_row_ids + coopmat_stage + lut_size;
+ const uint32_t total_size = load_bufs + mmid_row_ids + coopmat_stage + lut_size + ballots_sh;
const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
VK_LOG_DEBUG("ggml_vk_matmul_shmem_support(warptile=(" << warptile[0] << "," << warptile[1] << "," << warptile[2] << "), "
@@ -2176,8 +2192,17 @@ static void ggml_vk_load_shaders(vk_device& device) {
const uint32_t subgroup_size_16 = std::max(device->subgroup_size, 16u);
const uint32_t subgroup_size_32 = std::max(device->subgroup_size, 32u);
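+    // on Intel, use the minimum controllable subgroup size for the matmul shaders when subgroup size control is available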
+ const uint32_t mul_mat_subgroup_size = (device->vendor_id == VK_VENDOR_ID_INTEL && device->subgroup_size_control) ? device->subgroup_min_size : device->subgroup_size;
+ const uint32_t mul_mat_subgroup_size_8 = std::max(mul_mat_subgroup_size, 8u);
+ const uint32_t mul_mat_subgroup_size_16 = std::max(mul_mat_subgroup_size, 16u);
+ const uint32_t mul_mat_subgroup_size_32 = std::max(mul_mat_subgroup_size, 32u);
+
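+    // true when a subgroup size of at least 16 is available, either natively or by requiring it via subgroup size control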
+ const bool subgroup_min_size_16 = (!device->subgroup_size_control && device->subgroup_size >= 16) ||
+ (device->subgroup_size_control && device->subgroup_max_size >= 16);
+
// mulmat
std::vector l_warptile, m_warptile, s_warptile,
+ l_warptile_id, m_warptile_id, s_warptile_id,
l_warptile_mmq, m_warptile_mmq, s_warptile_mmq,
l_warptile_mmq_int, m_warptile_mmq_int, s_warptile_mmq_int,
l_warptile_mmq_k, m_warptile_mmq_k, s_warptile_mmq_k,
@@ -2214,7 +2239,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
s_mmq_wg_denoms_k = { 32, 64, 1 };
// spec constants and tile sizes for quant matmul_id
- l_warptile_mmqid = { 256, 128, 128, 16, 0, device->subgroup_size };
+ l_warptile_mmqid = { 256, 128, 128, 16, 1, device->subgroup_size };
m_warptile_mmqid = { 256, 128, 64, 16, 0, device->subgroup_size };
s_warptile_mmqid = { 256, 128, 64, 16, 0, device->subgroup_size };
l_mmqid_wg_denoms = { 128, 128, 1 };
@@ -2248,9 +2273,18 @@ static void ggml_vk_load_shaders(vk_device& device) {
m_warptile_mmq_int = { 128, 64, 64, 32, subgroup_size_8, 32, 2, 2, 2, 1, subgroup_size_8 };
s_warptile_mmq_int = { subgroup_size_32, 32, 32, 32, 32, 32, 2, 2, 1, 1, subgroup_size_8 };
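+    // spec constants and tile sizes for the matmul_id variants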
+ l_warptile_id = { 128, 128, 128, 16, mul_mat_subgroup_size_16 * 2, 64, 2, tm_l, tn_l, tk_l, mul_mat_subgroup_size_16 };
+ m_warptile_id = { 128, 64, 64, 16, mul_mat_subgroup_size_16, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_16 };
+ s_warptile_id = { mul_mat_subgroup_size_16, 32, 32, 16, 32, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_16 };
+
+ l_warptile_mmqid = { 128, 128, 128, 32, mul_mat_subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, mul_mat_subgroup_size_8 };
+ m_warptile_mmqid = { 128, 64, 64, 32, mul_mat_subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_8 };
+ s_warptile_mmqid = { mul_mat_subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_8 };
+
// chip specific tuning
if ((device->architecture == AMD_GCN) && (device->driver_id != vk::DriverId::eAmdProprietary)) {
m_warptile_mmq = m_warptile_mmq_int = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 };
+ m_warptile_mmqid = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 };
}
l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 };
@@ -2276,14 +2310,14 @@ static void ggml_vk_load_shaders(vk_device& device) {
}
// Disable mul_mat_id if not enough shared memory is available
- if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmq, true, t)) {
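+        // check against the matmul_id tile sizes, since those are what the mul_mat_id pipelines are built with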
+ if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmqid, true, t)) {
device->mul_mat_id_s[i] = false;
device->mul_mat_id_m[i] = false;
device->mul_mat_id_l[i] = false;
- } else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmq, true, t)) {
+ } else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmqid, true, t)) {
device->mul_mat_id_m[i] = false;
device->mul_mat_id_l[i] = false;
- } else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmq, true, t)) {
+ } else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmqid, true, t)) {
device->mul_mat_id_l[i] = false;
}
}
@@ -2461,32 +2495,34 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_MXFP4], matmul_mxfp4_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
- CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
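+    // the subgroup-based matmul_id shaders require subgroup ballot support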
+ GGML_ASSERT(device->subgroup_ballot);
+
+ CREATE_MM2(pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
if (device->coopmat_bf16_support) {
- CREATE_MM(pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
+ CREATE_MM(pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
}
#endif
- CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
- CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
- CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
- CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_q5_1_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
- CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_q8_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
- CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_q2_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
- CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_q3_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
- CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_q4_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
- CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_q5_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
- CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_q6_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
- CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_iq1_s_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
- CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_iq1_m_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
- CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_iq2_xxs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
- CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_iq2_xs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
- CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_iq2_s_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
- CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_iq3_xxs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
- CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_iq3_s_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
- CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_iq4_xs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
- CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
- CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_mxfp4_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_subgroup_iq1_s_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_subgroup_iq1_m_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_subgroup_iq2_xxs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_subgroup_iq2_xs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_subgroup_iq2_s_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_subgroup_iq3_xxs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_subgroup_iq3_s_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
#undef CREATE_MM
#undef CREATE_MM2
} else
@@ -2573,55 +2609,56 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4].f32acc, matmul_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
}
- CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
+ GGML_ASSERT(device->subgroup_ballot);
+
+ CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
+ CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
+ CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_subgroup_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
if (device->coopmat_bf16_support) {
- CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
+ CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
}
#endif
- CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
-
- CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_q5_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_q6_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_iq1_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_iq1_m_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_iq2_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_iq2_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_iq2_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_iq3_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_subgroup_iq1_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_subgroup_iq1_m_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_subgroup_iq2_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_subgroup_iq2_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_subgroup_iq2_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_subgroup_iq3_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_subgroup_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
#undef CREATE_MM2
#undef CREATE_MM
} else
#endif // defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
if (device->fp16) {
// Create 6 variants, {s,m,l}x{unaligned,aligned}
-#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
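+// when REQSUBGROUPSIZE is non-zero, the pipelines are created with that required subgroup size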
+#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
if (device->mul_mat ## ID ## _l[TYPE]) \
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _m[TYPE]) \
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _s[TYPE]) \
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _l[TYPE]) \
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _m[TYPE]) \
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _s[TYPE]) \
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
if (device->mul_mat ## ID ## _l[TYPE]) { \
@@ -2638,38 +2675,38 @@ static void ggml_vk_load_shaders(vk_device& device) {
} \
// Create 2 variants, {f16,f32} accumulator
-#define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
- CREATE_MM(TYPE, PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
- CREATE_MM(TYPE, PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
-
- CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
- CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
- CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
-
- CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
-
- CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0], matmul_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1], matmul_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0], matmul_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1], matmul_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0], matmul_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-
- CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K], matmul_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K], matmul_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K], matmul_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K], matmul_q5_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K], matmul_q6_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S], matmul_iq1_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M], matmul_iq1_m_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS], matmul_iq2_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS], matmul_iq2_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S], matmul_iq2_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS], matmul_iq3_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S], matmul_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4], matmul_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+#define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
+ CREATE_MM(TYPE, PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
+ CREATE_MM(TYPE, PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
+
+ CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
+
+ CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
+
+ CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0], matmul_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1], matmul_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0], matmul_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1], matmul_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0], matmul_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+
+ CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K], matmul_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K], matmul_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K], matmul_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K], matmul_q5_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K], matmul_q6_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S], matmul_iq1_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M], matmul_iq1_m_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS], matmul_iq2_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS], matmul_iq2_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S], matmul_iq2_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS], matmul_iq3_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S], matmul_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4], matmul_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
if (device->integer_dot_product) {
@@ -2681,51 +2718,77 @@ static void ggml_vk_load_shaders(vk_device& device) {
}
#endif
- CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
-
- CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id);
-
- CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
-
- CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_q5_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_q6_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_iq1_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_iq1_m_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_iq2_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_iq2_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_iq2_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_iq3_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
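+    // use the subgroup-optimized matmul_id shaders when ballot, full-subgroup support and a subgroup size of at least 16 are available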
+ if (device->subgroup_ballot && device->subgroup_require_full_support && subgroup_min_size_16) {
+ CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16);
+ CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16);
+ CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_subgroup_f16_f32, wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16);
+ CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
+
+ CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_subgroup_iq1_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_subgroup_iq1_m_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_subgroup_iq2_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_subgroup_iq2_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_subgroup_iq2_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_subgroup_iq3_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_subgroup_iq3_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ } else {
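+        // subgroup requirements not met: fall back to the generic matmul_id shaders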
+ CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0);
+ CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0);
+ CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id, 0);
+
+ CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_q5_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_q8_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_q2_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_q3_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_q4_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_q5_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_q6_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_iq1_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_iq1_m_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_iq2_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_iq2_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_iq2_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_iq3_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_iq3_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ }
#undef CREATE_MM2
#undef CREATE_MMQ
#undef CREATE_MM
} else {
// Create 6 variants, {s,m,l}x{unaligned,aligned}
-#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
+#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
if (device->mul_mat ## ID ## _l[TYPE]) \
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _m[TYPE]) \
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _s[TYPE]) \
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _l[TYPE]) \
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _m[TYPE]) \
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _s[TYPE]) \
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
if (device->mul_mat ## ID ## _l[TYPE]) \
@@ -2735,34 +2798,34 @@ static void ggml_vk_load_shaders(vk_device& device) {
if (device->mul_mat ## ID ## _s[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC "_s", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
- CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16_f32.f32acc, matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
-
- CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
-
- CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-
- CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S].f32acc, matmul_iq1_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M].f32acc, matmul_iq1_m_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f32acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f32acc, matmul_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f32acc, matmul_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f32acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4].f32acc, matmul_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+ CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16_f32.f32acc, matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
+
+ CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
+
+ CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+
+ CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S].f32acc, matmul_iq1_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M].f32acc, matmul_iq1_m_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f32acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f32acc, matmul_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f32acc, matmul_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f32acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4].f32acc, matmul_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
if (device->integer_dot_product) {
@@ -2774,33 +2837,59 @@ static void ggml_vk_load_shaders(vk_device& device) {
}
#endif
- CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
-
- CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id);
-
- CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
-
- CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f32acc, matmul_id_iq1_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f32acc, matmul_id_iq1_m_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc, matmul_id_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ if (device->subgroup_ballot && device->subgroup_require_full_support && subgroup_min_size_16) {
+ CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16);
+ CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_subgroup_f16, , wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16);
+ CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_subgroup_f16_f32, , wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16);
+ CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
+
+ CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_subgroup_q4_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_subgroup_q4_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_subgroup_q5_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_subgroup_q5_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_subgroup_q8_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_subgroup_q2_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_subgroup_q3_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_subgroup_q4_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_subgroup_q5_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_subgroup_q6_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f32acc, matmul_id_subgroup_iq1_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f32acc, matmul_id_subgroup_iq1_m_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_subgroup_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc, matmul_id_subgroup_iq2_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc, matmul_id_subgroup_iq2_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_subgroup_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_subgroup_iq3_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_subgroup_iq4_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_subgroup_iq4_nl_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_subgroup_mxfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ } else {
+ CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id, 0);
+
+ CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f32acc, matmul_id_iq1_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f32acc, matmul_id_iq1_m_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc, matmul_id_iq2_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_mxfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ }
}
// reusing CREATE_MM from the fp32 path
if ((device->coopmat2 || device->coopmat_support)
@@ -2817,8 +2906,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
m_wg_denoms = { 64, 64, 1 };
s_wg_denoms = { 32, 32, 1 };
- CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id, 0);
}
#undef CREATE_MM
@@ -2836,60 +2925,91 @@ static void ggml_vk_load_shaders(vk_device& device) {
rm_stdq = 2;
uint32_t rm_iq = 2 * rm_kq;
+ const bool use_subgroups = device->subgroup_arithmetic && device->architecture != vk_device_architecture::AMD_GCN;
+ // Ensure a subgroup size >= 16 is available
+ const bool use_subgroups16 = use_subgroups &&
+ (!device->subgroup_size_control && device->subgroup_size >= 16 ||
+ device->subgroup_size_control && device->subgroup_min_size <= 16 && device->subgroup_max_size >= 16);
+
+ const uint32_t subgroup_size = (device->vendor_id == VK_VENDOR_ID_INTEL && device->subgroup_size_control && device->subgroup_min_size <= 16 && device->subgroup_max_size >= 16) ? 16 : device->subgroup_size;
+ const uint32_t subgroup_size16 = std::max(subgroup_size, 16u);
+
+ const uint32_t force_subgroup_size = use_subgroups ? subgroup_size : 0;
+ const uint32_t force_subgroup_size16 = use_subgroups16 ? subgroup_size16 : 0;
+
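// Illustrative sketch, not part of this patch: the use_subgroups16 / subgroup_size16 logic above
// answers "can this device run these shaders with a subgroup of at least 16 invocations" -- either
// the fixed subgroup size is already >= 16, or subgroup size control lets one be requested within
// [min_size, max_size]. The struct and helper below are hypothetical stand-ins for the device fields.
#include <cstdint>

struct subgroup_caps_sketch {
    bool     size_control; // subgroup size control usable for compute shaders
    uint32_t size;         // default subgroup size
    uint32_t min_size;     // smallest requestable subgroup size
    uint32_t max_size;     // largest requestable subgroup size
};

static bool subgroup16_available_sketch(const subgroup_caps_sketch & c) {
    return (!c.size_control && c.size >= 16) ||
           ( c.size_control && c.min_size <= 16 && c.max_size >= 16);
}
// e.g. {size_control = true, size = 32, min_size = 8, max_size = 32} qualifies, and the code above
// then clamps the requested size up to 16 for the shader variants that assume 16-wide subgroups.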
for (uint32_t w = 0; w < DMMV_WG_SIZE_COUNT; ++w) {
- uint32_t wg_size_subgroup16 = (w == DMMV_WG_SIZE_SUBGROUP) ? subgroup_size_16 : (subgroup_size_16 * 4);
- uint32_t wg_size_subgroup = (w == DMMV_WG_SIZE_SUBGROUP) ? device->subgroup_size : (device->subgroup_size * 4);
+ const uint32_t wg_size_subgroup = (w == DMMV_WG_SIZE_SUBGROUP) ? subgroup_size : (subgroup_size * 4);
+ const uint32_t wg_size_subgroup16 = (w == DMMV_WG_SIZE_SUBGROUP) ? subgroup_size16 : (subgroup_size16 * 4);
+
+ const shader_reduction_mode reduc = (use_subgroups && w == DMMV_WG_SIZE_SUBGROUP) ? SHADER_REDUCTION_MODE_SUBGROUP :
+ (use_subgroups && w == DMMV_WG_SIZE_LARGE) ? SHADER_REDUCTION_MODE_HYBRID :
+ SHADER_REDUCTION_MODE_SHMEM;
- const bool s = device->subgroup_add && device->architecture != vk_device_architecture::AMD_GCN;
+ const shader_reduction_mode reduc16 = (use_subgroups16 && w == DMMV_WG_SIZE_SUBGROUP) ? SHADER_REDUCTION_MODE_SUBGROUP :
+ (use_subgroups16 && w == DMMV_WG_SIZE_LARGE) ? SHADER_REDUCTION_MODE_HYBRID :
+ SHADER_REDUCTION_MODE_SHMEM;
for (uint32_t i = 0; i < mul_mat_vec_max_cols; ++i) {
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f32_f32", arr_dmmv_f32_f32_f32_len[s], arr_dmmv_f32_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f32_f32", arr_dmmv_f16_f32_f32_len[s], arr_dmmv_f16_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f32_f32", arr_dmmv_bf16_f32_f32_len[s], arr_dmmv_bf16_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f32_f32", arr_dmmv_q4_0_f32_f32_len[s], arr_dmmv_q4_0_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f32_f32", arr_dmmv_q4_1_f32_f32_len[s], arr_dmmv_q4_1_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f32_f32", arr_dmmv_q5_0_f32_f32_len[s], arr_dmmv_q5_0_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_f32_f32", arr_dmmv_q5_1_f32_f32_len[s], arr_dmmv_q5_1_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_f32_f32", arr_dmmv_q8_0_f32_f32_len[s], arr_dmmv_q8_0_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup, 1*rm_stdq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_f32_f32", arr_dmmv_q2_k_f32_f32_len[s], arr_dmmv_q2_k_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_f32_f32", arr_dmmv_q3_k_f32_f32_len[s], arr_dmmv_q3_k_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f32_f32", arr_dmmv_q4_k_f32_f32_len[s], arr_dmmv_q4_k_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f32_f32", arr_dmmv_q5_k_f32_f32_len[s], arr_dmmv_q5_k_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f32_f32", arr_dmmv_q6_k_f32_f32_len[s], arr_dmmv_q6_k_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_f32_f32", arr_dmmv_iq1_s_f32_f32_len[s], arr_dmmv_iq1_s_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_f32_f32", arr_dmmv_iq1_m_f32_f32_len[s], arr_dmmv_iq1_m_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f32_f32", arr_dmmv_iq2_xxs_f32_f32_len[s], arr_dmmv_iq2_xxs_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ2_XS][i], "mul_mat_vec_iq2_xs_f32_f32", arr_dmmv_iq2_xs_f32_f32_len[s], arr_dmmv_iq2_xs_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f32_f32", arr_dmmv_iq2_s_f32_f32_len[s], arr_dmmv_iq2_s_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ3_XXS][i], "mul_mat_vec_iq3_xxs_f32_f32", arr_dmmv_iq3_xxs_f32_f32_len[s], arr_dmmv_iq3_xxs_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f32_f32", arr_dmmv_iq3_s_f32_f32_len[s], arr_dmmv_iq3_s_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f32_f32", arr_dmmv_iq4_xs_f32_f32_len[s], arr_dmmv_iq4_xs_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f32_f32", arr_dmmv_iq4_nl_f32_f32_len[s], arr_dmmv_iq4_nl_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f32_f32", arr_dmmv_mxfp4_f32_f32_len[s], arr_dmmv_mxfp4_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true);
-
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32", arr_dmmv_f32_f16_f32_len[s], arr_dmmv_f32_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32", arr_dmmv_f16_f16_f32_len[s], arr_dmmv_f16_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f16_f32", arr_dmmv_bf16_f16_f32_len[s], arr_dmmv_bf16_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f16_f32", arr_dmmv_q4_0_f16_f32_len[s], arr_dmmv_q4_0_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f16_f32", arr_dmmv_q4_1_f16_f32_len[s], arr_dmmv_q4_1_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f16_f32", arr_dmmv_q5_0_f16_f32_len[s], arr_dmmv_q5_0_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_f16_f32", arr_dmmv_q5_1_f16_f32_len[s], arr_dmmv_q5_1_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_f16_f32", arr_dmmv_q8_0_f16_f32_len[s], arr_dmmv_q8_0_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup, 1*rm_stdq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_f16_f32", arr_dmmv_q2_k_f16_f32_len[s], arr_dmmv_q2_k_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_f16_f32", arr_dmmv_q3_k_f16_f32_len[s], arr_dmmv_q3_k_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f16_f32", arr_dmmv_q4_k_f16_f32_len[s], arr_dmmv_q4_k_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f16_f32", arr_dmmv_q5_k_f16_f32_len[s], arr_dmmv_q5_k_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f16_f32", arr_dmmv_q6_k_f16_f32_len[s], arr_dmmv_q6_k_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_f16_f32", arr_dmmv_iq1_s_f16_f32_len[s], arr_dmmv_iq1_s_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_f16_f32", arr_dmmv_iq1_m_f16_f32_len[s], arr_dmmv_iq1_m_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f16_f32", arr_dmmv_iq2_xxs_f16_f32_len[s], arr_dmmv_iq2_xxs_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ2_XS][i], "mul_mat_vec_iq2_xs_f16_f32", arr_dmmv_iq2_xs_f16_f32_len[s], arr_dmmv_iq2_xs_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f16_f32", arr_dmmv_iq2_s_f16_f32_len[s], arr_dmmv_iq2_s_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ3_XXS][i], "mul_mat_vec_iq3_xxs_f16_f32", arr_dmmv_iq3_xxs_f16_f32_len[s], arr_dmmv_iq3_xxs_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f16_f32", arr_dmmv_iq3_s_f16_f32_len[s], arr_dmmv_iq3_s_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f16_f32", arr_dmmv_iq4_xs_f16_f32_len[s], arr_dmmv_iq4_xs_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f16_f32", arr_dmmv_iq4_nl_f16_f32_len[s], arr_dmmv_iq4_nl_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f16_f32", arr_dmmv_mxfp4_f16_f32_len[s], arr_dmmv_mxfp4_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f32_f32", arr_dmmv_f32_f32_f32_len[reduc], arr_dmmv_f32_f32_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f32_f32", arr_dmmv_f16_f32_f32_len[reduc], arr_dmmv_f16_f32_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f32_f32", arr_dmmv_bf16_f32_f32_len[reduc], arr_dmmv_bf16_f32_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f32_f32", arr_dmmv_q4_0_f32_f32_len[reduc], arr_dmmv_q4_0_f32_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f32_f32", arr_dmmv_q4_1_f32_f32_len[reduc], arr_dmmv_q4_1_f32_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f32_f32", arr_dmmv_q5_0_f32_f32_len[reduc], arr_dmmv_q5_0_f32_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_f32_f32", arr_dmmv_q5_1_f32_f32_len[reduc], arr_dmmv_q5_1_f32_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_f32_f32", arr_dmmv_q8_0_f32_f32_len[reduc], arr_dmmv_q8_0_f32_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup, 1*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_f32_f32", arr_dmmv_q2_k_f32_f32_len[reduc16], arr_dmmv_q2_k_f32_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_f32_f32", arr_dmmv_q3_k_f32_f32_len[reduc16], arr_dmmv_q3_k_f32_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f32_f32", arr_dmmv_q4_k_f32_f32_len[reduc16], arr_dmmv_q4_k_f32_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f32_f32", arr_dmmv_q5_k_f32_f32_len[reduc16], arr_dmmv_q5_k_f32_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f32_f32", arr_dmmv_q6_k_f32_f32_len[reduc16], arr_dmmv_q6_k_f32_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_f32_f32", arr_dmmv_iq1_s_f32_f32_len[reduc16], arr_dmmv_iq1_s_f32_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_f32_f32", arr_dmmv_iq1_m_f32_f32_len[reduc16], arr_dmmv_iq1_m_f32_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f32_f32", arr_dmmv_iq2_xxs_f32_f32_len[reduc16], arr_dmmv_iq2_xxs_f32_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ2_XS][i], "mul_mat_vec_iq2_xs_f32_f32", arr_dmmv_iq2_xs_f32_f32_len[reduc16], arr_dmmv_iq2_xs_f32_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f32_f32", arr_dmmv_iq2_s_f32_f32_len[reduc16], arr_dmmv_iq2_s_f32_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ3_XXS][i], "mul_mat_vec_iq3_xxs_f32_f32", arr_dmmv_iq3_xxs_f32_f32_len[reduc16], arr_dmmv_iq3_xxs_f32_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f32_f32", arr_dmmv_iq3_s_f32_f32_len[reduc16], arr_dmmv_iq3_s_f32_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f32_f32", arr_dmmv_iq4_xs_f32_f32_len[reduc16], arr_dmmv_iq4_xs_f32_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f32_f32", arr_dmmv_iq4_nl_f32_f32_len[reduc16], arr_dmmv_iq4_nl_f32_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f32_f32", arr_dmmv_mxfp4_f32_f32_len[reduc16], arr_dmmv_mxfp4_f32_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
+
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32", arr_dmmv_f32_f16_f32_len[reduc], arr_dmmv_f32_f16_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32", arr_dmmv_f16_f16_f32_len[reduc], arr_dmmv_f16_f16_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f16_f32", arr_dmmv_bf16_f16_f32_len[reduc], arr_dmmv_bf16_f16_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f16_f32", arr_dmmv_q4_0_f16_f32_len[reduc], arr_dmmv_q4_0_f16_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f16_f32", arr_dmmv_q4_1_f16_f32_len[reduc], arr_dmmv_q4_1_f16_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f16_f32", arr_dmmv_q5_0_f16_f32_len[reduc], arr_dmmv_q5_0_f16_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_f16_f32", arr_dmmv_q5_1_f16_f32_len[reduc], arr_dmmv_q5_1_f16_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_f16_f32", arr_dmmv_q8_0_f16_f32_len[reduc], arr_dmmv_q8_0_f16_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup, 1*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_f16_f32", arr_dmmv_q2_k_f16_f32_len[reduc16], arr_dmmv_q2_k_f16_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_f16_f32", arr_dmmv_q3_k_f16_f32_len[reduc16], arr_dmmv_q3_k_f16_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f16_f32", arr_dmmv_q4_k_f16_f32_len[reduc16], arr_dmmv_q4_k_f16_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f16_f32", arr_dmmv_q5_k_f16_f32_len[reduc16], arr_dmmv_q5_k_f16_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f16_f32", arr_dmmv_q6_k_f16_f32_len[reduc16], arr_dmmv_q6_k_f16_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_f16_f32", arr_dmmv_iq1_s_f16_f32_len[reduc16], arr_dmmv_iq1_s_f16_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_f16_f32", arr_dmmv_iq1_m_f16_f32_len[reduc16], arr_dmmv_iq1_m_f16_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f16_f32", arr_dmmv_iq2_xxs_f16_f32_len[reduc16], arr_dmmv_iq2_xxs_f16_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ2_XS][i], "mul_mat_vec_iq2_xs_f16_f32", arr_dmmv_iq2_xs_f16_f32_len[reduc16], arr_dmmv_iq2_xs_f16_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f16_f32", arr_dmmv_iq2_s_f16_f32_len[reduc16], arr_dmmv_iq2_s_f16_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ3_XXS][i], "mul_mat_vec_iq3_xxs_f16_f32", arr_dmmv_iq3_xxs_f16_f32_len[reduc16], arr_dmmv_iq3_xxs_f16_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f16_f32", arr_dmmv_iq3_s_f16_f32_len[reduc16], arr_dmmv_iq3_s_f16_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f16_f32", arr_dmmv_iq4_xs_f16_f32_len[reduc16], arr_dmmv_iq4_xs_f16_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f16_f32", arr_dmmv_iq4_nl_f16_f32_len[reduc16], arr_dmmv_iq4_nl_f16_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f16_f32", arr_dmmv_mxfp4_f16_f32_len[reduc16], arr_dmmv_mxfp4_f16_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
+
+#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
+ if (device->integer_dot_product) {
+ const uint32_t subgroup_size_int = (device->vendor_id == VK_VENDOR_ID_INTEL && device->subgroup_size_control) ? device->subgroup_min_size : device->subgroup_size;
+ const uint32_t wg_size_subgroup_int = (w == DMMV_WG_SIZE_SUBGROUP) ? subgroup_size_int : (subgroup_size_int * 4);
+
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_q8_1_f32", arr_dmmv_q4_0_q8_1_f32_len[reduc], arr_dmmv_q4_0_q8_1_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_q8_1_f32", arr_dmmv_q4_1_q8_1_f32_len[reduc], arr_dmmv_q4_1_q8_1_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_q8_1_f32", arr_dmmv_q5_0_q8_1_f32_len[reduc], arr_dmmv_q5_0_q8_1_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_q8_1_f32", arr_dmmv_q5_1_q8_1_f32_len[reduc], arr_dmmv_q5_1_q8_1_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_q8_1_f32", arr_dmmv_q8_0_q8_1_f32_len[reduc], arr_dmmv_q8_0_q8_1_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int);
+ }
+#endif // GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT
}
}
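// Illustrative sketch, not part of this patch: the scalar math the q8_1 integer-dot mul_mat_vec
// pipelines above are understood to accelerate. Layouts follow ggml's CPU blocks: q4_0 packs
// 32 unsigned 4-bit quants (implicit -8 offset) with one fp16 scale; q8_1 packs 32 int8 quants
// with an fp16 scale d and s = d * sum(quants). Scales are passed here already decoded to float.
#include <cstdint>

static float dot_q4_0_q8_1_sketch(const uint8_t x_qs[16], const int8_t y_qs[32],
                                  float x_d, float y_d, float y_s) {
    int32_t sumi = 0;
    for (int j = 0; j < 16; ++j) {
        const int lo = x_qs[j] & 0x0F; // low nibble  -> element j
        const int hi = x_qs[j] >> 4;   // high nibble -> element j + 16
        sumi += lo * y_qs[j] + hi * y_qs[j + 16];
    }
    // Folding the -8 offset through y's precomputed sum:
    //   sum((q - 8) * yq) * x_d * y_d = sumi * x_d * y_d - 8 * x_d * (y_d * sum(yq))
    //                                 = sumi * x_d * y_d - 8 * x_d * y_s
    return x_d * y_d * (float) sumi - 8.0f * x_d * y_s;
}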
@@ -2981,10 +3101,17 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 3, 5 * sizeof(uint32_t), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1, "quantize_q8_1", quantize_q8_1_len, quantize_q8_1_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1);
+
+ if (device->subgroup_clustered && device->subgroup_require_full_support) {
+ ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1, "quantize_q8_1", quantize_q8_1_subgroup_len, quantize_q8_1_subgroup_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1, true, true);
+ ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1_x4, "quantize_q8_1_x4", quantize_q8_1_x4_subgroup_len, quantize_q8_1_x4_subgroup_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1, true, true);
+ } else {
+ ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1, "quantize_q8_1", quantize_q8_1_len, quantize_q8_1_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1_x4, "quantize_q8_1_x4", quantize_q8_1_x4_len, quantize_q8_1_x4_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1);
+ }
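// Illustrative sketch, not part of this patch: the block layout the quantize_q8_1 pipelines above
// produce, mirroring ggml's CPU definition -- 32 int8 quants plus an fp16 scale d and an fp16
// s = d * sum(quants). The _x4 variant is assumed to group four such blocks for better-aligned access.
#include <cstdint>

struct block_q8_1_sketch {
    uint16_t d;      // scale (fp16 bits)
    uint16_t s;      // d * sum(qs) (fp16 bits), consumed by the q8_1 dot products
    int8_t   qs[32]; // quantized values
};
static_assert(sizeof(block_q8_1_sketch) == 36, "q8_1 block is 36 bytes");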
for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) {
- if (device->subgroup_add && device->subgroup_require_full_support) {
+ if (device->subgroup_arithmetic && device->subgroup_require_full_support) {
ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32[i], "mul_mat_vec_p021_f16_f32"+std::to_string(i+1), mul_mat_vec_p021_f16_f32_subgroup_add_len, mul_mat_vec_p021_f16_f32_subgroup_add_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {device->subgroup_size, i + 1}, 1, true, true);
} else {
ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32[i], "mul_mat_vec_p021_f16_f32"+std::to_string(i+1), mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {device->subgroup_size, i + 1}, 1, true);
@@ -3362,6 +3489,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
const char* GGML_VK_DISABLE_HOST_VISIBLE_VIDMEM = getenv("GGML_VK_DISABLE_HOST_VISIBLE_VIDMEM");
device->disable_host_visible_vidmem = GGML_VK_DISABLE_HOST_VISIBLE_VIDMEM != nullptr;
+ const char* GGML_VK_ALLOW_SYSMEM_FALLBACK = getenv("GGML_VK_ALLOW_SYSMEM_FALLBACK");
+ device->allow_sysmem_fallback = GGML_VK_ALLOW_SYSMEM_FALLBACK != nullptr;
+
bool fp16_storage = false;
bool fp16_compute = false;
bool maintenance4_support = false;
@@ -3500,11 +3630,15 @@ static vk_device ggml_vk_get_device(size_t idx) {
}
device->float_controls_rte_fp16 = vk12_props.shaderRoundingModeRTEFloat16;
- device->subgroup_add = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
- (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eArithmetic);
-
+ device->subgroup_arithmetic = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
+ (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eArithmetic);
device->subgroup_shuffle = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
(vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eShuffle);
+ device->subgroup_clustered = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
+ (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eClustered);
+
+ device->subgroup_ballot = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
+ (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eBallot);
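// Illustrative sketch, not part of this patch: where flags like subgroup_clustered and
// subgroup_ballot ultimately come from. The plain Vulkan 1.1 C API exposes the same bits through
// VkPhysicalDeviceSubgroupProperties chained into vkGetPhysicalDeviceProperties2.
#include <vulkan/vulkan.h>

static bool compute_subgroup_supports_sketch(VkPhysicalDevice pd, VkSubgroupFeatureFlagBits feature) {
    VkPhysicalDeviceSubgroupProperties sub = {};
    sub.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_PROPERTIES;

    VkPhysicalDeviceProperties2 props2 = {};
    props2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2;
    props2.pNext = &sub;

    vkGetPhysicalDeviceProperties2(pd, &props2);

    return (sub.supportedStages & VK_SHADER_STAGE_COMPUTE_BIT) &&
           (sub.supportedOperations & feature);
}
// e.g. compute_subgroup_supports_sketch(pd, VK_SUBGROUP_FEATURE_CLUSTERED_BIT) corresponds to
// device->subgroup_clustered above.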
const bool force_disable_f16 = getenv("GGML_VK_DISABLE_F16") != nullptr;
@@ -3655,9 +3789,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
(subgroup_size_control_props.requiredSubgroupSizeStages & vk::ShaderStageFlagBits::eCompute) &&
subgroup_size_control_features.subgroupSizeControl;
- if (device->subgroup_size_control) {
- device->subgroup_require_full_support = subgroup_size_control_features.computeFullSubgroups;
- }
+ device->subgroup_require_full_support = subgroup_size_control_features.computeFullSubgroups;
#if defined(VK_KHR_cooperative_matrix)
device->coopmat_support = device->coopmat_support && coopmat_features.cooperativeMatrix;
@@ -3959,11 +4091,18 @@ static vk_device ggml_vk_get_device(size_t idx) {
device->disable_fusion = getenv("GGML_VK_DISABLE_FUSION") != nullptr;
device->add_rms_fusion = !device->disable_fusion &&
- device->subgroup_add &&
+ device->subgroup_arithmetic &&
device->vendor_id != VK_VENDOR_ID_INTEL;
device->partials_binding_alignment =
std::max(4u, (uint32_t)device->properties.limits.minStorageBufferOffsetAlignment);
+ device->mmvq_mode = 0;
+ if (getenv("GGML_VK_DISABLE_MMVQ")) {
+ device->mmvq_mode = -1;
+ } else if (getenv("GGML_VK_FORCE_MMVQ")) {
+ device->mmvq_mode = 1;
+ }
+
return device;
}
@@ -4202,15 +4341,16 @@ static void ggml_vk_instance_init() {
vk_instance.pfn_vkCmdBeginDebugUtilsLabelEXT = (PFN_vkCmdBeginDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkCmdBeginDebugUtilsLabelEXT");
vk_instance.pfn_vkCmdEndDebugUtilsLabelEXT = (PFN_vkCmdEndDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkCmdEndDebugUtilsLabelEXT");
vk_instance.pfn_vkCmdInsertDebugUtilsLabelEXT = (PFN_vkCmdInsertDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkCmdInsertDebugUtilsLabelEXT");
-
}
vk_perf_logger_enabled = getenv("GGML_VK_PERF_LOGGER") != nullptr;
+    std::vector<vk::PhysicalDevice> devices = vk_instance.instance.enumeratePhysicalDevices();
+
// Emulate behavior of CUDA_VISIBLE_DEVICES for Vulkan
char * devices_env = getenv("GGML_VK_VISIBLE_DEVICES");
if (devices_env != nullptr) {
- size_t num_available_devices = vk_instance.instance.enumeratePhysicalDevices().size();
+ size_t num_available_devices = devices.size();
std::string devices(devices_env);
std::replace(devices.begin(), devices.end(), ',', ' ');
@@ -4225,8 +4365,6 @@ static void ggml_vk_instance_init() {
vk_instance.device_indices.push_back(tmp);
}
} else {
-        std::vector<vk::PhysicalDevice> devices = vk_instance.instance.enumeratePhysicalDevices();
-
// If no vulkan devices are found, return early
if (devices.empty()) {
GGML_LOG_INFO("ggml_vulkan: No devices found.\n");
@@ -4331,6 +4469,19 @@ static void ggml_vk_instance_init() {
GGML_LOG_DEBUG("ggml_vulkan: Found %zu Vulkan devices:\n", vk_instance.device_indices.size());
for (size_t i = 0; i < vk_instance.device_indices.size(); i++) {
+ vk::PhysicalDevice vkdev = devices[vk_instance.device_indices[i]];
+        std::vector<vk::ExtensionProperties> extensionprops = vkdev.enumerateDeviceExtensionProperties();
+
+ bool membudget_supported = false;
+ for (const auto & ext : extensionprops) {
+ if (strcmp(VK_EXT_MEMORY_BUDGET_EXTENSION_NAME, ext.extensionName) == 0) {
+ membudget_supported = true;
+ break;
+ }
+ }
+
+ vk_instance.device_supports_membudget.push_back(membudget_supported);
+
ggml_vk_print_gpu_info(i);
}
}
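// Illustrative sketch, not part of this patch: what detecting VK_EXT_memory_budget above makes
// possible later (assumed usage) -- querying per-heap budget by chaining
// VkPhysicalDeviceMemoryBudgetPropertiesEXT into vkGetPhysicalDeviceMemoryProperties2.
#include <vulkan/vulkan.h>

static VkDeviceSize device_local_budget_sketch(VkPhysicalDevice pd) {
    VkPhysicalDeviceMemoryBudgetPropertiesEXT budget = {};
    budget.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MEMORY_BUDGET_PROPERTIES_EXT;

    VkPhysicalDeviceMemoryProperties2 mem2 = {};
    mem2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MEMORY_PROPERTIES_2;
    mem2.pNext = &budget;

    vkGetPhysicalDeviceMemoryProperties2(pd, &mem2);

    VkDeviceSize total = 0;
    for (uint32_t h = 0; h < mem2.memoryProperties.memoryHeapCount; ++h) {
        if (mem2.memoryProperties.memoryHeaps[h].flags & VK_MEMORY_HEAP_DEVICE_LOCAL_BIT) {
            total += budget.heapBudget[h]; // remaining budget the driver expects allocations to stay under
        }
    }
    return total;
}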
@@ -4477,9 +4628,22 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type, uint32_t num_cols, uint32_t m, uint32_t k) {
VK_LOG_DEBUG("ggml_vk_get_dequantize_mul_mat_vec()");
- GGML_ASSERT(b_type == GGML_TYPE_F32 || b_type == GGML_TYPE_F16);
+ GGML_ASSERT(b_type == GGML_TYPE_F32 || b_type == GGML_TYPE_F16 || b_type == GGML_TYPE_Q8_1);
GGML_ASSERT(num_cols >= 1 && num_cols <= mul_mat_vec_max_cols);
+ if (b_type == GGML_TYPE_Q8_1) {
+ switch (a_type) {
+ case GGML_TYPE_Q4_0:
+ case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q5_0:
+ case GGML_TYPE_Q5_1:
+ case GGML_TYPE_Q8_0:
+ break;
+ default:
+ return nullptr;
+ }
+ }
+
switch (a_type) {
case GGML_TYPE_F32:
case GGML_TYPE_F16:
@@ -4511,7 +4675,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context *
// heuristic to choose workgroup size
uint32_t dmmv_wg = DMMV_WG_SIZE_SUBGROUP;
- if (ctx->device->vendor_id == VK_VENDOR_ID_NVIDIA || ctx->device->vendor_id == VK_VENDOR_ID_INTEL) {
+ if ((ctx->device->vendor_id == VK_VENDOR_ID_NVIDIA && ctx->device->architecture != vk_device_architecture::NVIDIA_PRE_TURING) || ctx->device->vendor_id == VK_VENDOR_ID_INTEL) {
// Prefer larger workgroups when M is small, to spread the work out more
// and keep more SMs busy.
// q6_k seems to prefer small workgroup size even for "medium" values of M.
@@ -4526,6 +4690,13 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context *
}
}
+ if (b_type == GGML_TYPE_Q8_1) {
+ if (ctx->device->vendor_id == VK_VENDOR_ID_INTEL) {
+ dmmv_wg = DMMV_WG_SIZE_SUBGROUP;
+ }
+ return ctx->device->pipeline_dequant_mul_mat_vec_q8_1_f32[dmmv_wg][a_type][num_cols-1];
+ }
+
return b_type == GGML_TYPE_F32 ? ctx->device->pipeline_dequant_mul_mat_vec_f32_f32[dmmv_wg][a_type][num_cols-1] : ctx->device->pipeline_dequant_mul_mat_vec_f16_f32[dmmv_wg][a_type][num_cols-1];
}
@@ -4595,7 +4766,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
}
static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type) {
- VK_LOG_DEBUG("ggml_vk_get_dequantize_mul_mat_vec()");
+ VK_LOG_DEBUG("ggml_vk_get_dequantize_mul_mat_vec_id()");
GGML_ASSERT(b_type == GGML_TYPE_F32);
switch (a_type) {
@@ -4698,8 +4869,8 @@ static vk_buffer ggml_vk_create_buffer_temp(ggml_backend_vk_context * ctx, size_
static void * ggml_vk_host_malloc(vk_device& device, size_t size) {
VK_LOG_MEMORY("ggml_vk_host_malloc(" << size << ")");
vk_buffer buf = ggml_vk_create_buffer(device, size,
- vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached,
- vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
+ {vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached,
+ vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent});
if(!(buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible)) {
fprintf(stderr, "WARNING: failed to allocate %.2f MB of pinned memory\n",
@@ -5508,20 +5679,20 @@ static void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context&
ggml_vk_sync_buffers(ctx, subctx);
}
-static vk_pipeline ggml_vk_get_quantize_pipeline(ggml_backend_vk_context * ctx, ggml_type type) {
+static vk_pipeline ggml_vk_get_quantize_pipeline(ggml_backend_vk_context * ctx, ggml_type type, bool use_x4_blocks) {
switch(type) {
case GGML_TYPE_Q8_1:
- return ctx->device->pipeline_quantize_q8_1;
+ return use_x4_blocks ? ctx->device->pipeline_quantize_q8_1_x4 : ctx->device->pipeline_quantize_q8_1;
default:
std::cerr << "Missing quantize pipeline for type: " << ggml_type_name(type) << std::endl;
GGML_ABORT("fatal error");
}
}
-static void ggml_vk_quantize_q8_1(ggml_backend_vk_context * ctx, vk_context& subctx, vk_subbuffer&& in, vk_subbuffer&& out, uint32_t ne) {
+static void ggml_vk_quantize_q8_1(ggml_backend_vk_context * ctx, vk_context& subctx, vk_subbuffer&& in, vk_subbuffer&& out, uint32_t ne, bool use_x4_blocks = false) {
VK_LOG_DEBUG("ggml_vk_quantize_q8_1(" << "buffer in size=" << in.buffer->size << ", buffer out size=" << out.buffer->size << ", " << ne << ")");
- vk_pipeline pipeline = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
+ vk_pipeline pipeline = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1, use_x4_blocks);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, std::array{ne}, { ne, 1, 1 });
ggml_vk_sync_buffers(ctx, subctx);
@@ -5641,12 +5812,15 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT
if (quantize_y) {
- to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
+ to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1, true);
}
if (dryrun) {
const uint64_t x_sz_upd = x_sz * ne02 * ne03;
- const uint64_t y_sz_upd = y_sz * ne12 * ne13;
+ uint64_t y_sz_upd = y_sz * ne12 * ne13;
+ if (quantize_y) {
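+ // 144 bytes is one block_q8_1_x4 (4 packed q8_1 blocks of 36 bytes); round the quantized copy of src1 up to whole x4 blocks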
+ y_sz_upd = CEIL_DIV(y_sz_upd, 144) * 144;
+ }
const uint64_t split_k_size = split_k > 1 ? d_sz * ne12 * ne13 * split_k : 0;
if (
(qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) ||
@@ -5712,7 +5886,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
GGML_ASSERT(d_Y->size >= y_sz * ne12 * ne13);
} else if (quantize_y) {
d_Y = ctx->prealloc_y;
- GGML_ASSERT(d_Y->size >= y_ne * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1));
+ GGML_ASSERT(d_Y->size >= CEIL_DIV(y_sz * ne12 * ne13, 144) * 144);
} else {
d_Y = d_Qy;
y_buf_offset = qy_buf_offset;
@@ -5724,11 +5898,6 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
ggml_vk_sync_buffers(ctx, subctx);
}
}
- if (y_non_contig || quantize_y) {
- if (ctx->prealloc_y_need_sync) {
- ggml_vk_sync_buffers(ctx, subctx);
- }
- }
if (x_non_contig) {
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE });
@@ -5740,6 +5909,9 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
if (y_non_contig) {
if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() ||
ctx->prealloc_y_last_tensor_used != src1) {
+ if (ctx->prealloc_y_need_sync) {
+ ggml_vk_sync_buffers(ctx, subctx);
+ }
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get();
ctx->prealloc_y_last_tensor_used = src1;
@@ -5748,7 +5920,10 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
if (quantize_y) {
if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() ||
ctx->prealloc_y_last_tensor_used != src1) {
- ggml_vk_quantize_q8_1(ctx, subctx, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }, y_ne * ne12 * ne13);
+ if (ctx->prealloc_y_need_sync) {
+ ggml_vk_sync_buffers(ctx, subctx);
+ }
+ ggml_vk_quantize_q8_1(ctx, subctx, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }, y_ne * ne12 * ne13, true);
ctx->prealloc_y_last_pipeline_used = to_q8_1.get();
ctx->prealloc_y_last_tensor_used = src1;
}
@@ -5765,10 +5940,15 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
stride_batch_y = src1->nb[0] / ggml_type_size(src1->type);
}
+ uint32_t y_sz_total = y_sz * ne12 * ne13;
+ if (quantize_y) {
+ y_sz_total = CEIL_DIV(y_sz_total, 144) * 144;
+ }
+
// compute
ggml_vk_matmul(
ctx, subctx, pipeline,
- { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz * ne12 * ne13 },
+ { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz_total },
{ d_D, d_buf_offset, d_sz * ne12 * ne13 }, { ctx->prealloc_split_k, 0, d_sz * ne12 * ne13 * split_k },
ne01, ne11, ne10,
ne10, ne10, ne01, stride_batch_x, stride_batch_y, ne20*ne21,
@@ -5783,6 +5963,51 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
}
}
+// Device tuning
+static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_t n, uint32_t k, ggml_type src0_type) {
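+ // mmvq_mode: 1 forces the integer dot product (MMVQ) path, -1 disables it, anything else falls through to the per-vendor heuristic below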
+ if (device->mmvq_mode == 1) {
+ return true;
+ } else if (device->mmvq_mode == -1) {
+ return false;
+ }
+
+ // MMVQ is generally good for batches
+ if (n > 1) {
+ return true;
+ }
+
+ switch (device->vendor_id) {
+ case VK_VENDOR_ID_NVIDIA:
+ switch (src0_type) {
+ case GGML_TYPE_Q8_0:
+ return device->architecture == vk_device_architecture::NVIDIA_PRE_TURING;
+ default:
+ return true;
+ }
+ case VK_VENDOR_ID_AMD:
+ switch (src0_type) {
+ case GGML_TYPE_Q8_0:
+ return device->architecture == vk_device_architecture::AMD_GCN;
+ default:
+ return true;
+ }
+ case VK_VENDOR_ID_INTEL:
+ switch (src0_type) {
+ // From tests on A770 Linux, may need more tuning
+ case GGML_TYPE_Q4_0:
+ case GGML_TYPE_Q5_1:
+ return false;
+ default:
+ return true;
+ }
+ default:
+ return true;
+ }
+
+ GGML_UNUSED(m);
+ GGML_UNUSED(k);
+}
+
static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
VK_LOG_DEBUG("ggml_vk_mul_mat_vec_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
@@ -5837,22 +6062,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
const bool y_non_contig = !ggml_vk_dim01_contiguous(src1);
const bool f16_f32_kernel = src1->type == GGML_TYPE_F32;
-
- const bool qx_needs_dequant = x_non_contig;
- const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig;
-
- // Not implemented
- GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
-
- const uint64_t x_ne = ne01 * ne00;
- const uint64_t y_ne = ne11 * ne10;
- const uint64_t d_ne = ne11 * ne01;
-
- const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment);
- const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
- const uint64_t x_sz = x_non_contig ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : qx_sz;
- const uint64_t y_sz = f16_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne;
- const uint64_t d_sz = sizeof(float) * d_ne;
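+ // quantize src1 to q8_1 on the GPU when integer dot product is supported and the MMVQ heuristic favors it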
+ bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && (ne11 * ne10) % 4 == 0 && ggml_vk_should_use_mmvq(ctx->device, ne01, ne11, ne10, src0->type);
vk_pipeline to_fp16_vk_0 = nullptr;
vk_pipeline to_fp16_vk_1 = nullptr;
@@ -5864,14 +6074,47 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
} else {
to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
}
- vk_pipeline dmmv = ggml_vk_get_dequantize_mul_mat_vec(ctx, src0->type, src1->type, ne11, ne20, ne00);
+
+ // Check for mmq first
+ vk_pipeline dmmv = quantize_y ? ggml_vk_get_dequantize_mul_mat_vec(ctx, src0->type, GGML_TYPE_Q8_1, ne11, ne20, ne00) : nullptr;
+ vk_pipeline to_q8_1 = nullptr;
+
+ if (dmmv == nullptr) {
+ // Fall back to f16 dequant mul mat
+ dmmv = ggml_vk_get_dequantize_mul_mat_vec(ctx, src0->type, src1->type, ne11, ne20, ne00);
+ quantize_y = false;
+ }
+
+ if (quantize_y) {
+ to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1, true);
+ }
+
+ const bool qx_needs_dequant = x_non_contig;
+ const bool qy_needs_dequant = !quantize_y && ((src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig);
+
+ // Not implemented
+ GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
+
GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT
GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT
GGML_ASSERT(dmmv != nullptr);
+ const uint64_t x_ne = ne01 * ne00;
+ const uint64_t y_ne = ne11 * ne10;
+ const uint64_t d_ne = ne11 * ne01;
+
+ const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment);
+ const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
+ const uint64_t x_sz = x_non_contig ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : qx_sz;
+ const uint64_t y_sz = quantize_y ? (y_ne * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) : (f16_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne);
+ const uint64_t d_sz = sizeof(float) * d_ne;
+
if (dryrun) {
const uint64_t x_sz_upd = x_sz * ne02 * ne03;
- const uint64_t y_sz_upd = y_sz * ne12 * ne13;
+ uint64_t y_sz_upd = y_sz * ne12 * ne13;
+ if (quantize_y) {
+ y_sz_upd = CEIL_DIV(y_sz_upd, 144) * 144;
+ }
if (
(qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) ||
(qy_needs_dequant && y_sz_upd > ctx->device->max_memory_allocation_size)) {
@@ -5880,7 +6123,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) {
ctx->prealloc_size_x = x_sz_upd;
}
- if (qy_needs_dequant && ctx->prealloc_size_y < y_sz_upd) {
+ if ((qy_needs_dequant || quantize_y) && ctx->prealloc_size_y < y_sz_upd) {
ctx->prealloc_size_y = y_sz_upd;
}
@@ -5891,6 +6134,9 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
if (qy_needs_dequant) {
ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_1, 1);
}
+ if (quantize_y) {
+ ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1);
+ }
ggml_pipeline_request_descriptor_sets(ctx, dmmv, 1);
return;
}
@@ -5921,6 +6167,9 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
}
if (qy_needs_dequant) {
d_Y = ctx->prealloc_y;
+ } else if (quantize_y) {
+ d_Y = ctx->prealloc_y;
+ GGML_ASSERT(d_Y->size >= CEIL_DIV(y_sz * ne12 * ne13, 144) * 144);
} else {
d_Y = d_Qy;
y_buf_offset = qy_buf_offset;
@@ -5931,14 +6180,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
if (ctx->prealloc_x_need_sync) {
ggml_vk_sync_buffers(ctx, subctx);
}
- }
- if (y_non_contig) {
- if (ctx->prealloc_y_need_sync) {
- ggml_vk_sync_buffers(ctx, subctx);
- }
- }
- if (x_non_contig) {
GGML_ASSERT(x_sz == ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment));
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE });
}
@@ -5946,11 +6188,25 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
GGML_ASSERT(y_sz == ggml_type_size(src1->type) * y_ne);
if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() ||
ctx->prealloc_y_last_tensor_used != src1) {
+ if (ctx->prealloc_y_need_sync) {
+ ggml_vk_sync_buffers(ctx, subctx);
+ }
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get();
ctx->prealloc_y_last_tensor_used = src1;
}
}
+ if (quantize_y) {
+ if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() ||
+ ctx->prealloc_y_last_tensor_used != src1) {
+ if (ctx->prealloc_y_need_sync) {
+ ggml_vk_sync_buffers(ctx, subctx);
+ }
+ ggml_vk_quantize_q8_1(ctx, subctx, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }, y_ne * ne12 * ne13, true);
+ ctx->prealloc_y_last_pipeline_used = to_q8_1.get();
+ ctx->prealloc_y_last_tensor_used = src1;
+ }
+ }
// For batch_n, the A matrix is the same for each batch, and B/D use the row stride as the batch stride
uint32_t stride_batch_x = batch_n ? 0 : ne00*ne01;
@@ -5975,6 +6231,12 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
groups_x = CEIL_DIV(groups_x, groups_z);
}
+ // TODO: Clean up this whole sz * ne_2 * ne_3 thing, it hasn't been necessary for a long time
+ uint32_t y_sz_total = y_sz * ne12 * ne13;
+ if (quantize_y) {
+ y_sz_total = CEIL_DIV(y_sz_total, 144) * 144;
+ }
+
// compute
const vk_mat_vec_push_constants pc = {
(uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
@@ -5982,13 +6244,13 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
(uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3,
};
ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
- { vk_subbuffer{ d_X, x_buf_offset, x_sz * ne02 * ne03 }, vk_subbuffer{ d_Y, y_buf_offset, y_sz * ne12 * ne13 }, vk_subbuffer{ d_D, d_buf_offset, d_sz * ne22 * ne23} },
+ { vk_subbuffer{ d_X, x_buf_offset, x_sz * ne02 * ne03 }, vk_subbuffer{ d_Y, y_buf_offset, y_sz_total }, vk_subbuffer{ d_D, d_buf_offset, d_sz * ne22 * ne23} },
pc, { groups_x, (uint32_t)(ne12 * ne13), groups_z });
if (x_non_contig) {
ctx->prealloc_x_need_sync = true;
}
- if (y_non_contig) {
+ if (y_non_contig || quantize_y) {
ctx->prealloc_y_need_sync = true;
}
}
@@ -6213,7 +6475,6 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
const uint64_t nei0 = ids->ne[0];
const uint64_t nei1 = ids->ne[1];
- GGML_ASSERT(nei0 * nei1 <= 4096);
const uint32_t nbi1 = ids->nb[1];
const uint32_t nbi2 = ids->nb[2];
@@ -6379,11 +6640,6 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
ggml_vk_sync_buffers(ctx, subctx);
}
}
- if (y_non_contig) {
- if (ctx->prealloc_y_need_sync) {
- ggml_vk_sync_buffers(ctx, subctx);
- }
- }
if (x_non_contig) {
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE });
@@ -6396,6 +6652,9 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
if (y_non_contig) {
if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() ||
ctx->prealloc_y_last_tensor_used != src1) {
+ if (ctx->prealloc_y_need_sync) {
+ ggml_vk_sync_buffers(ctx, subctx);
+ }
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get();
ctx->prealloc_y_last_tensor_used = src1;
@@ -6593,11 +6852,6 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
ggml_vk_sync_buffers(ctx, subctx);
}
}
- if (y_non_contig) {
- if (ctx->prealloc_y_need_sync) {
- ggml_vk_sync_buffers(ctx, subctx);
- }
- }
if (x_non_contig) {
GGML_ASSERT(x_sz == ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment));
@@ -6607,6 +6861,9 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
GGML_ASSERT(y_sz == ggml_type_size(src1->type) * y_ne);
if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() ||
ctx->prealloc_y_last_tensor_used != src1) {
+ if (ctx->prealloc_y_need_sync) {
+ ggml_vk_sync_buffers(ctx, subctx);
+ }
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get();
ctx->prealloc_y_last_tensor_used = src1;
@@ -6653,37 +6910,7 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
if (src2->ne[1] == 1 && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) {
ggml_vk_mul_mat_vec_id_q_f16(ctx, subctx, src0, src1, src2, dst, dryrun);
} else {
- // Split based on number of ids, to fit in shared memory
- const uint32_t nei0 = (uint32_t)src2->ne[0];
- const uint32_t nei1 = (uint32_t)src2->ne[1];
-
- GGML_ASSERT(nei0 <= 4096);
- const uint32_t split_size = std::min(nei1, 4096u / nei0);
-
- if (split_size == nei1) {
- ggml_vk_mul_mat_id_q_f16(ctx, subctx, src0, src1, src2, dst, dryrun);
- } else {
- ggml_tensor src1_copy = *src1;
- ggml_tensor src2_copy = *src2;
- ggml_tensor dst_copy = *dst;
-
- for (uint32_t token_start = 0; token_start < nei1; token_start += split_size) {
- const uint32_t n_tokens = std::min(split_size, nei1 - token_start);
-
- src1_copy.view_offs = src1->view_offs + token_start * src1_copy.nb[2];
- src2_copy.view_offs = src2->view_offs + token_start * src2_copy.nb[1];
- dst_copy.view_offs = dst->view_offs + token_start * dst_copy.nb[2];
-
- src1_copy.ne[2] = n_tokens;
- src2_copy.ne[1] = n_tokens;
- dst_copy.ne[2] = n_tokens;
-
- ggml_vk_mul_mat_id_q_f16(ctx, subctx, src0, &src1_copy, &src2_copy, &dst_copy, dryrun);
- // invalidate cached prealloc_y, can't cache based on the copy of the ggml_tensor
- ctx->prealloc_y_last_pipeline_used = {};
- ctx->prealloc_y_last_tensor_used = nullptr;
- }
- }
+ ggml_vk_mul_mat_id_q_f16(ctx, subctx, src0, src1, src2, dst, dryrun);
}
}
@@ -7806,6 +8033,8 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
break;
case GGML_OP_GET_ROWS:
elements = { (uint32_t)ne00, (uint32_t)ne10, (uint32_t)(ne11 * ne12) };
+ elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
+ elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
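+ // the get_rows shaders stride their loops by gl_NumWorkGroups, so rows/batches beyond these clamped counts are still processed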
break;
case GGML_OP_ARGSORT:
elements = { (uint32_t)ne00, (uint32_t)ggml_nrows(src0), 1 };
@@ -9142,7 +9371,7 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
if (ctx->prealloc_split_k != nullptr) {
ggml_vk_destroy_buffer(ctx->prealloc_split_k);
}
- ctx->prealloc_split_k = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne * split_k, vk::MemoryPropertyFlagBits::eDeviceLocal);
+ ctx->prealloc_split_k = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne * split_k, {vk::MemoryPropertyFlagBits::eDeviceLocal});
}
}
@@ -9152,9 +9381,9 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
ggml_pipeline_allocate_descriptor_sets(ctx);
- vk_buffer d_X = ggml_vk_create_buffer_check(ctx->device, sizeof(X_TYPE) * x_ne, vk::MemoryPropertyFlagBits::eDeviceLocal);
- vk_buffer d_Y = ggml_vk_create_buffer_check(ctx->device, sizeof(Y_TYPE) * y_ne, vk::MemoryPropertyFlagBits::eDeviceLocal);
- vk_buffer d_D = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne, vk::MemoryPropertyFlagBits::eDeviceLocal);
+ vk_buffer d_X = ggml_vk_create_buffer_check(ctx->device, sizeof(X_TYPE) * x_ne, {vk::MemoryPropertyFlagBits::eDeviceLocal});
+ vk_buffer d_Y = ggml_vk_create_buffer_check(ctx->device, sizeof(Y_TYPE) * y_ne, {vk::MemoryPropertyFlagBits::eDeviceLocal});
+ vk_buffer d_D = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne, {vk::MemoryPropertyFlagBits::eDeviceLocal});
X_TYPE* x = (X_TYPE *) malloc(sizeof(X_TYPE) * x_ne);
Y_TYPE* y = (Y_TYPE *) malloc(sizeof(Y_TYPE) * y_ne);
@@ -9380,8 +9609,8 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_
const size_t qx_sz = ne * ggml_type_size(quant)/ggml_blck_size(quant);
float * x = (float *) malloc(x_sz);
void * qx = malloc(qx_sz);
- vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
- vk_buffer x_buf = ggml_vk_create_buffer_check(ctx->device, x_sz_f16, vk::MemoryPropertyFlagBits::eDeviceLocal);
+ vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, {vk::MemoryPropertyFlagBits::eDeviceLocal});
+ vk_buffer x_buf = ggml_vk_create_buffer_check(ctx->device, x_sz_f16, {vk::MemoryPropertyFlagBits::eDeviceLocal});
float * x_ref = (float *) malloc(x_sz);
ggml_fp16_t * x_chk = (ggml_fp16_t *) malloc(x_sz_f16);
@@ -9486,8 +9715,8 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_
// float * x = (float *) malloc(x_sz);
// block_q8_1 * qx = (block_q8_1 *)malloc(qx_sz);
// block_q8_1 * qx_res = (block_q8_1 *)malloc(qx_sz);
-// vk_buffer x_buf = ggml_vk_create_buffer_check(ctx->device, x_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
-// vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
+// vk_buffer x_buf = ggml_vk_create_buffer_check(ctx->device, x_sz, {vk::MemoryPropertyFlagBits::eDeviceLocal});
+// vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, {vk::MemoryPropertyFlagBits::eDeviceLocal});
//
// for (size_t i = 0; i < ne; i++) {
// x[i] = rand() / (float)RAND_MAX;
@@ -9634,10 +9863,10 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
float * x = (float *) malloc(x_sz);
float * y = (float *) malloc(y_sz);
void * qx = malloc(qx_sz);
- vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
- vk_buffer y_buf = ggml_vk_create_buffer_check(ctx->device, y_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
- vk_buffer qy_buf = ggml_vk_create_buffer_check(ctx->device, qy_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
- vk_buffer d_buf = ggml_vk_create_buffer_check(ctx->device, d_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
+ vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, {vk::MemoryPropertyFlagBits::eDeviceLocal});
+ vk_buffer y_buf = ggml_vk_create_buffer_check(ctx->device, y_sz, {vk::MemoryPropertyFlagBits::eDeviceLocal});
+ vk_buffer qy_buf = ggml_vk_create_buffer_check(ctx->device, qy_sz, {vk::MemoryPropertyFlagBits::eDeviceLocal});
+ vk_buffer d_buf = ggml_vk_create_buffer_check(ctx->device, d_sz, {vk::MemoryPropertyFlagBits::eDeviceLocal});
float * d = (float *) malloc(d_sz);
float * d_chk = (float *) malloc(d_sz);
@@ -9664,7 +9893,7 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
if (ctx->prealloc_split_k != nullptr) {
ggml_vk_destroy_buffer(ctx->prealloc_split_k);
}
- ctx->prealloc_split_k = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne * split_k, vk::MemoryPropertyFlagBits::eDeviceLocal);
+ ctx->prealloc_split_k = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne * split_k, {vk::MemoryPropertyFlagBits::eDeviceLocal});
}
}
if (mmq) {
@@ -10194,12 +10423,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
}
}
if (need_sync) {
- VK_LOG_DEBUG("node_idx=" << i << " sync");
ctx->unsynced_nodes_written.clear();
ctx->unsynced_nodes_read.clear();
ggml_vk_sync_buffers(ctx, compute_ctx);
- } else {
- VK_LOG_DEBUG("node_idx=" << i << " unsynced");
}
// Add the last fused node and all fused source nodes to the unsynchronized list.
const ggml_tensor * last_node = cgraph->nodes[node_idx + ctx->num_additional_fused_ops];
@@ -11441,15 +11667,29 @@ void ggml_backend_vk_get_device_description(int device, char * description, size
void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total) {
GGML_ASSERT(device < (int) vk_instance.device_indices.size());
+ GGML_ASSERT(device < (int) vk_instance.device_supports_membudget.size());
vk::PhysicalDevice vkdev = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[device]];
+ vk::PhysicalDeviceMemoryBudgetPropertiesEXT budgetprops;
+ vk::PhysicalDeviceMemoryProperties2 memprops = {};
+ bool membudget_supported = vk_instance.device_supports_membudget[device];
+
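+ // chain the VK_EXT_memory_budget struct into the query so free memory can be reported as heapBudget - heapUsage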
+ if (membudget_supported) {
+ memprops.pNext = &budgetprops;
+ }
+ vkdev.getMemoryProperties2(&memprops);
- vk::PhysicalDeviceMemoryProperties memprops = vkdev.getMemoryProperties();
+ for (uint32_t i = 0; i < memprops.memoryProperties.memoryHeapCount; ++i) {
+ const vk::MemoryHeap & heap = memprops.memoryProperties.memoryHeaps[i];
- for (const vk::MemoryHeap& heap : memprops.memoryHeaps) {
if (heap.flags & vk::MemoryHeapFlagBits::eDeviceLocal) {
*total = heap.size;
- *free = heap.size;
+
+ if (membudget_supported && i < budgetprops.heapUsage.size()) {
+ *free = budgetprops.heapBudget[i] - budgetprops.heapUsage[i];
+ } else {
+ *free = heap.size;
+ }
break;
}
}
@@ -11975,16 +12215,13 @@ static bool ggml_vk_instance_validation_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions) {
#ifdef __APPLE__
- bool portability_enumeration_ext = false;
// Check for portability enumeration extension for MoltenVK support
for (const auto& properties : instance_extensions) {
if (strcmp("VK_KHR_portability_enumeration", properties.extensionName) == 0) {
return true;
}
}
- if (!portability_enumeration_ext) {
- std::cerr << "ggml_vulkan: WARNING: Instance extension VK_KHR_portability_enumeration not found." << std::endl;
- }
+ std::cerr << "ggml_vulkan: WARNING: Instance extension VK_KHR_portability_enumeration not found." << std::endl;
#endif
return false;
@@ -12241,7 +12478,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
} else if (tensor->op == GGML_OP_CONCAT) {
tensor_clone = ggml_concat(ggml_ctx, src_clone[0], src_clone[1], *(int *)tensor->op_params);
} else if (tensor->op == GGML_OP_UPSCALE) {
- tensor_clone = ggml_upscale_ext(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], (ggml_scale_mode) tensor->op_params[0]);
+ tensor_clone = ggml_interpolate(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], (ggml_scale_mode) tensor->op_params[0]);
} else if (tensor->op == GGML_OP_SCALE) {
const float * params = (const float *)tensor->op_params;
tensor_clone = ggml_scale_bias(ggml_ctx, src_clone[0], params[0], params[1]);
@@ -12480,11 +12717,9 @@ static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph *
if (tensor->op == GGML_OP_TRANSPOSE || tensor->op == GGML_OP_SET_ROWS) {
return;
}
- bool fused_rms_norm_mul = false;
if (ctx->num_additional_fused_ops == 1 &&
tensor->op == GGML_OP_RMS_NORM &&
cgraph->nodes[tensor_idx + 1]->op == GGML_OP_MUL) {
- fused_rms_norm_mul = true;
tensor = cgraph->nodes[tensor_idx + 1];
}
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp
index d40848e15fe97..482445c6fea2c 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp
@@ -334,6 +334,9 @@ void main() {
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
Of[r][d] *= Lfrcp[r];
+#if defined(ACC_TYPE_MAX)
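+ // clamp the scaled output to the accumulator type's finite range (ACC_TYPE_MAX) so reduced-precision accumulators do not overflow to infinity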
+ Of[r][d] = clamp(Of[r][d], -vec4(ACC_TYPE_MAX), vec4(ACC_TYPE_MAX));
+#endif
}
}
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp
index 97c2a54129709..63b32171b0c07 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp
@@ -373,6 +373,9 @@ void main() {
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Of[r][d] *= ACC_TYPE(Lfrcp[r]);
+#if defined(ACC_TYPE_MAX)
+ Of[r][d] = clamp(Of[r][d], -ACC_TYPE_MAX, ACC_TYPE_MAX);
+#endif
}
}
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp
index 77ae5ff01d03e..ab647e9bc8b68 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp
@@ -283,6 +283,10 @@ void main() {
O = Ldiag*O;
+#if defined(ACC_TYPE_MAX)
+ [[unroll]] for (uint i = 0; i < O.length(); ++i) { O[i] = clamp(O[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); }
+#endif
+
uint32_t o_offset = iq3*p.ne2*p.ne1*HSV;
coopmat O_D = coopmat(O);
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp
index 76ef4b6dfb571..06e83822fe326 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp
@@ -111,6 +111,10 @@ void main() {
}
}
O *= L;
+
+ const float FLT_MAX = uintBitsToFloat(0x7F7FFFFF);
+ O = clamp(O, -FLT_MAX, FLT_MAX);
+
data_d[iq3 * D * N + D * n + d] = O;
}
}
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp b/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp
index ee6b86a18ddf2..7ef75cd7a492e 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp
@@ -7,27 +7,36 @@ layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
void main() {
const uint i00 = gl_GlobalInvocationID.x;
- const uint i10 = gl_GlobalInvocationID.y;
- const uint i11 = (gl_GlobalInvocationID.z)/p.ne12;
- const uint i12 = (gl_GlobalInvocationID.z)%p.ne12;
if (i00 >= p.ne00) {
return;
}
- const uint i01 = data_b[get_boffset() + i10*p.nb10 + i11*p.nb11 + i12*p.nb12];
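+ // the host-side dispatch is clamped to maxComputeWorkGroupCount, so stride over any remaining rows and batches here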
+ uint gid_z = gl_GlobalInvocationID.z;
+ while (gid_z < p.ne11 * p.ne12) {
+ uint gid_y = gl_GlobalInvocationID.y;
+ while (gid_y < p.ne10) {
+ const uint i10 = gid_y;
+ const uint i11 = gid_z / p.ne12;
+ const uint i12 = gid_z % p.ne12;
- const uint a_offset = get_aoffset() + i01*p.nb01 + i11*p.nb02 + i12*p.nb03;
- const uint d_offset = get_doffset() + i10*p.nb21 + i11*p.nb22 + i12*p.nb23;
+ const uint i01 = data_b[get_boffset() + i10*p.nb10 + i11*p.nb11 + i12*p.nb12];
+
+ const uint a_offset = get_aoffset() + i01*p.nb01 + i11*p.nb02 + i12*p.nb03;
+ const uint d_offset = get_doffset() + i10*p.nb21 + i11*p.nb22 + i12*p.nb23;
#if defined(DATA_A_BF16)
- FLOAT_TYPE v = FLOAT_TYPE(bf16_to_fp32(data_a[a_offset + i00]));
+ FLOAT_TYPE v = FLOAT_TYPE(bf16_to_fp32(data_a[a_offset + i00]));
#else
- FLOAT_TYPE v = FLOAT_TYPE(data_a[a_offset + i00]);
+ FLOAT_TYPE v = FLOAT_TYPE(data_a[a_offset + i00]);
#endif
#ifndef OPTIMIZATION_ERROR_WORKAROUND
- data_d[d_offset + i00] = D_TYPE(v);
+ data_d[d_offset + i00] = D_TYPE(v);
#else
- data_d[d_offset + i00] = D_TYPE(v);
+ data_d[d_offset + i00] = D_TYPE(v);
#endif
+ gid_y += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
+ }
+ gid_z += gl_WorkGroupSize.z * gl_NumWorkGroups.z;
+ }
}
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp b/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp
index cfd645a38a8ba..339f905fc7566 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp
@@ -10,9 +10,6 @@ layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
void main() {
const uint i00 = (gl_GlobalInvocationID.x)*2;
- const uint i10 = gl_GlobalInvocationID.y;
- const uint i11 = (gl_GlobalInvocationID.z)/p.ne12;
- const uint i12 = (gl_GlobalInvocationID.z)%p.ne12;
#ifdef NEEDS_INIT_IQ_SHMEM
init_iq_shmem(gl_WorkGroupSize);
@@ -22,20 +19,33 @@ void main() {
return;
}
- const uint i01 = data_b[i10*p.nb10 + i11*p.nb11 + i12*p.nb12];
+ uint gid_z = gl_GlobalInvocationID.z;
+ while (gid_z < p.ne11 * p.ne12) {
+ uint gid_y = gl_GlobalInvocationID.y;
+ while (gid_y < p.ne10) {
+ const uint i10 = gid_y;
+ const uint i11 = gid_z / p.ne12;
+ const uint i12 = gid_z % p.ne12;
- const uint a_offset = i01*p.nb01 + i11*p.nb02 + i12*p.nb03;
- const uint d_offset = i10*p.nb21 + i11*p.nb22 + i12*p.nb23;
+ const uint i01 = data_b[i10*p.nb10 + i11*p.nb11 + i12*p.nb12];
- const uint ib = a_offset + i00/QUANT_K; // block index
- const uint iqs = (i00%QUANT_K)/QUANT_R; // quant index
- const uint iybs = i00 - i00%QUANT_K; // dst block start index
- const uint y_offset = QUANT_R == 1 ? 1 : QUANT_K/2;
+ const uint a_offset = i01*p.nb01 + i11*p.nb02 + i12*p.nb03;
+ const uint d_offset = i10*p.nb21 + i11*p.nb22 + i12*p.nb23;
- vec2 v = dequantize(ib, iqs, 0);
- const vec2 dm = get_dm(ib, 0);
- v = v * dm.x + dm.y;
+ const uint ib = a_offset + i00/QUANT_K; // block index
+ const uint iqs = (i00%QUANT_K)/QUANT_R; // quant index
+ const uint iybs = i00 - i00%QUANT_K; // dst block start index
+ const uint y_offset = QUANT_R == 1 ? 1 : QUANT_K/2;
- data_d[d_offset + iybs + iqs ] = D_TYPE(v.x);
- data_d[d_offset + iybs + iqs + y_offset] = D_TYPE(v.y);
+ vec2 v = dequantize(ib, iqs, 0);
+ const vec2 dm = get_dm(ib, 0);
+ v = v * dm.x + dm.y;
+
+ data_d[d_offset + iybs + iqs ] = D_TYPE(v.x);
+ data_d[d_offset + iybs + iqs + y_offset] = D_TYPE(v.y);
+
+ gid_y += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
+ }
+ gid_z += gl_WorkGroupSize.z * gl_NumWorkGroups.z;
+ }
}
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp
index b93e9948f7fc9..f761391eaed71 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp
@@ -1,7 +1,8 @@
#extension GL_EXT_control_flow_attributes : enable
#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_shader_8bit_storage : require
-#if USE_SUBGROUP_ADD
+
+#if USE_SUBGROUP_ADD || USE_SUBGROUP_ADD_NO_SHMEM
#extension GL_KHR_shader_subgroup_basic : require
#extension GL_KHR_shader_subgroup_arithmetic : require
#endif
@@ -12,10 +13,19 @@
#include "types.comp"
+#ifndef MMQ
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
+#else
+layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];};
+#endif
+
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
+#ifdef B_TYPE_VEC2
layout (binding = 1) readonly buffer BV2 {B_TYPE_VEC2 data_b_v2[];};
+#endif
+#ifdef B_TYPE_VEC4
layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];};
+#endif
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
#ifdef MUL_MAT_ID
@@ -92,6 +102,23 @@ layout (constant_id = 0) const uint BLOCK_SIZE = 32;
layout (constant_id = 1) const uint NUM_ROWS = 1;
layout (constant_id = 2) const uint NUM_COLS = 1;
+#ifdef USE_SUBGROUP_ADD_NO_SHMEM
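+// subgroup-only reduction: sum the partial results with subgroupAdd and have invocation 0 store them, skipping the shared-memory staging used by the variant below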
+void reduce_result(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offset, const in uint32_t first_row, const in uint32_t num_rows, const in uint32_t tid) {
+ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
+ [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+ temp[j][n] = subgroupAdd(temp[j][n]);
+ }
+ }
+
+ if (tid == 0) {
+ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
+ [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+ data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(temp[j][n]);
+ }
+ }
+ }
+}
+#else
shared FLOAT_TYPE tmpsh[NUM_COLS][NUM_ROWS][BLOCK_SIZE];
void reduce_result(FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offset, const in uint32_t first_row, const in uint32_t num_rows, const in uint32_t tid) {
@@ -152,3 +179,4 @@ void reduce_result(FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offs
}
#endif
}
+#endif
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp
new file mode 100644
index 0000000000000..8fb314fa0aaa4
--- /dev/null
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp
@@ -0,0 +1,140 @@
+#version 450
+
+#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
+#extension GL_EXT_integer_dot_product : require
+
+#define MMQ
+#define B_TYPE block_q8_1_x4
+
+#include "mul_mat_vec_base.comp"
+
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+
+#define K_PER_ITER 8
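+// each invocation consumes 8 quant values of B per iteration: two int32s, each packing four int8 values for dotPacked4x8EXT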
+
+#include "mul_mmq_funcs.comp"
+
+uint a_offset, b_offset, d_offset;
+
+int32_t cache_b_qs[2];
+vec2 cache_b_ds;
+
+void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i) {
+ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
+ const uint col = i*BLOCK_SIZE + tid*K_PER_ITER;
+
+ // Preload data_b block
+ const uint b_block_idx = (j*p.batch_stride_b + col) / QUANT_K_Q8_1 + b_offset;
+ const uint b_qs_idx = tid % 4;
+ const uint b_block_idx_outer = b_block_idx / 4;
+ const uint b_block_idx_inner = b_block_idx % 4;
+ cache_b_ds = vec2(data_b[b_block_idx_outer].ds[b_block_idx_inner]);
+
+#if QUANT_R == 2
+ cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx];
+ cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx + 4];
+#else
+ cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 2];
+ cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 2 + 1];
+#endif
+
+ uint ibi = first_row*p.ncols;
+ [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+ const uint a_block_idx = (ibi + col)/QUANT_K + a_offset;
+ ibi += p.ncols;
+
+ int32_t q_sum = 0;
+#if QUANT_R == 2
+ const i32vec2 data_a_qs = repack(a_block_idx, b_qs_idx);
+ q_sum += dotPacked4x8EXT(data_a_qs.x,
+ cache_b_qs[0]);
+ q_sum += dotPacked4x8EXT(data_a_qs.y,
+ cache_b_qs[1]);
+#else
+ int32_t data_a_qs = repack(a_block_idx, b_qs_idx * 2);
+ q_sum += dotPacked4x8EXT(data_a_qs,
+ cache_b_qs[0]);
+ data_a_qs = repack(a_block_idx, b_qs_idx * 2 + 1);
+ q_sum += dotPacked4x8EXT(data_a_qs,
+ cache_b_qs[1]);
+#endif
+
+#if QUANT_AUXF == 1
+ temp[j][n] += mul_q8_1(q_sum, get_d(a_block_idx), cache_b_ds, 4);
+#else
+ temp[j][n] += mul_q8_1(q_sum, get_dm(a_block_idx), cache_b_ds, 4);
+#endif
+ }
+ }
+}
+
+void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
+ const uint tid = gl_LocalInvocationID.x;
+
+ get_offsets(a_offset, b_offset, d_offset);
+ a_offset /= QUANT_K;
+ b_offset /= QUANT_K_Q8_1;
+
+ FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
+
+ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
+ [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+ temp[j][n] = FLOAT_TYPE(0.0f);
+ }
+ }
+
+ uint num_iters = p.ncols / (K_PER_ITER * BLOCK_SIZE);
+ if (num_iters * K_PER_ITER * BLOCK_SIZE + K_PER_ITER*tid < p.ncols) {
+ num_iters++;
+ }
+ int unroll_count = 4;
+ uint unrolled_iters = num_iters & ~(unroll_count - 1);
+
+ uint i = 0;
+ while (i < unrolled_iters) {
+ // Manually partially unroll the loop
+ [[unroll]] for (uint k = 0; k < unroll_count; ++k) {
+ iter(temp, first_row, num_rows, tid, i*K_PER_ITER);
+ i++;
+ }
+ }
+
+ unroll_count = 2;
+ unrolled_iters = num_iters & ~(unroll_count - 1);
+
+#if K_PER_ITER == 2
+ if ((p.ncols & 1) != 0 &&
+ unrolled_iters == num_iters &&
+ unrolled_iters > 0) {
+ unrolled_iters -= unroll_count;
+ }
+#endif
+
+ while (i < unrolled_iters) {
+ // Manually partially unroll the loop
+ [[unroll]] for (uint k = 0; k < unroll_count; ++k) {
+ iter(temp, first_row, num_rows, tid, i*K_PER_ITER);
+ i++;
+ }
+ }
+ while (i < num_iters) {
+ iter(temp, first_row, num_rows, tid, i*K_PER_ITER);
+ i++;
+ }
+
+ reduce_result(temp, d_offset, first_row, num_rows, tid);
+}
+
+void main() {
+ const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
+
+ // do NUM_ROWS at a time, unless there aren't enough remaining rows
+ if (first_row + NUM_ROWS <= p.stride_d) {
+ compute_outputs(first_row, NUM_ROWS);
+ } else {
+ if (first_row >= p.stride_d) {
+ return;
+ }
+ compute_outputs(first_row, p.stride_d - first_row);
+ }
+}
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp
index d57cc6bdec5df..7e10e99e9e877 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp
@@ -17,6 +17,9 @@
#ifdef COOPMAT
#extension GL_KHR_cooperative_matrix : enable
#extension GL_KHR_memory_scope_semantics : enable
+#endif
+
+#if defined(COOPMAT) || defined(MUL_MAT_ID_USE_SUBGROUPS)
#extension GL_KHR_shader_subgroup_basic : enable
#extension GL_KHR_shader_subgroup_ballot : enable
#endif
@@ -106,11 +109,13 @@ shared FLOAT_TYPE buf_b[BN * SHMEM_STRIDE];
#define NUM_WARPS (BLOCK_SIZE / WARP)
#ifdef MUL_MAT_ID
-shared u16vec2 row_ids[4096];
+shared u16vec2 row_ids[BN];
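+// only the row ids for this workgroup's BN-wide slice of the output are kept (indices are rebased by ic * BN below)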
uint _ne1;
-#ifdef COOPMAT
+
+#ifdef MUL_MAT_ID_USE_SUBGROUPS
shared uvec4 ballots_sh[NUM_WARPS];
-void load_row_ids(uint expert_idx, bool nei0_is_pow2) {
+
+void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {
_ne1 = 0;
uint num_elements = p.nei1 * p.nei0;
uint nei0shift = findLSB(p.nei0);
@@ -160,15 +165,18 @@ void load_row_ids(uint expert_idx, bool nei0_is_pow2) {
barrier();
uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot);
- if (in_range && id == expert_idx) {
- row_ids[_ne1 + idx] = u16vec2(ii0, ii1);
+ if (in_range && id == expert_idx && _ne1 + idx >= ic * BN && _ne1 + idx < (ic + 1) * BN) {
+ row_ids[_ne1 + idx - ic * BN] = u16vec2(ii0, ii1);
}
_ne1 += total;
iter &= 15;
+ if (_ne1 >= (ic + 1) * BN) {
+ break;
+ }
}
barrier();
}
-#endif
+#endif // MUL_MAT_ID_USE_SUBGROUPS
#endif // MUL_MAT_ID
#ifdef COOPMAT
@@ -235,18 +243,20 @@ void main() {
const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B / BK;
#ifdef MUL_MAT_ID
-#ifdef COOPMAT
+#ifdef MUL_MAT_ID_USE_SUBGROUPS
if (bitCount(p.nei0) == 1) {
- load_row_ids(expert_idx, true);
+ load_row_ids(expert_idx, true, ic);
} else {
- load_row_ids(expert_idx, false);
+ load_row_ids(expert_idx, false, ic);
}
#else
_ne1 = 0;
- for (uint ii1 = 0; ii1 < p.nei1; ii1++) {
- for (uint ii0 = 0; ii0 < p.nei0; ii0++) {
+ for (uint ii1 = 0; ii1 < p.nei1 && _ne1 < (ic + 1) * BN; ii1++) {
+ for (uint ii0 = 0; ii0 < p.nei0 && _ne1 < (ic + 1) * BN; ii0++) {
if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) {
- row_ids[_ne1] = u16vec2(ii0, ii1);
+ if (_ne1 >= ic * BN) {
+ row_ids[_ne1 - ic * BN] = u16vec2(ii0, ii1);
+ }
_ne1++;
}
}
@@ -792,7 +802,7 @@ void main() {
[[unroll]] for (uint l = 0; l < BN; l += loadstride_b) {
#if LOAD_VEC_B == 8
#ifdef MUL_MAT_ID
- const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l];
+ const u16vec2 row_idx = row_ids[loadc_b + l];
const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b;
#else
const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
@@ -808,7 +818,7 @@ void main() {
buf_b[buf_idx + 7] = FLOAT_TYPE(data_b[idx][1].w);
#elif LOAD_VEC_B == 4
#ifdef MUL_MAT_ID
- const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l];
+ const u16vec2 row_idx = row_ids[loadc_b + l];
const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b;
#else
const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
@@ -827,7 +837,7 @@ void main() {
#else
const uint row_i = ic * BN + loadc_b + l;
if (row_i < _ne1 && block + loadr_b < end_k) {
- const u16vec2 row_idx = row_ids[row_i];
+ const u16vec2 row_idx = row_ids[loadc_b + l];
buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = TO_FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]);
} else {
buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f);
@@ -881,6 +891,20 @@ void main() {
barrier();
}
+#if defined(ACC_TYPE_MAX)
+#ifdef COOPMAT
+ [[unroll]] for (uint j = 0; j < cms_per_row * cms_per_col; j++) {
+ [[unroll]] for (uint i = 0; i < sums[j].length(); ++i) {
+ sums[j][i] = clamp(sums[j][i], -ACC_TYPE_MAX, ACC_TYPE_MAX);
+ }
+ }
+#else
+ [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
+ sums[i] = clamp(sums[i], -ACC_TYPE_MAX, ACC_TYPE_MAX);
+ }
+#endif
+#endif
+
const uint dr = ir * BM + warp_r * WM;
const uint dc = ic * BN + warp_c * WN;
@@ -898,7 +922,7 @@ void main() {
const uint row_i = dc + cm_col * TN + col + store_c;
if (row_i >= _ne1) break;
- const u16vec2 row_idx = row_ids[row_i];
+ const u16vec2 row_idx = row_ids[row_i - ic * BN];
if (dr + cm_row * TM + store_r < p.M) {
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
@@ -948,7 +972,7 @@ void main() {
const uint row_i = dc_warp + cc;
if (row_i >= _ne1) break;
- const u16vec2 row_idx = row_ids[row_i];
+ const u16vec2 row_idx = row_ids[row_i - ic * BN];
#endif // MUL_MAT_ID
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
#ifdef MUL_MAT_ID
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp
index 4d16eb0791ddc..69ac38fd4196c 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp
@@ -93,7 +93,7 @@ layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
#ifdef MUL_MAT_ID
layout (binding = 3) readonly buffer IDS {int data_ids[];};
-shared u16vec4 row_ids[4096];
+shared u16vec4 row_ids[BN];
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufB {
B_TYPE b[];
@@ -111,7 +111,7 @@ B_TYPE decodeFuncB(const in decodeBufB bl, const in uint blockCoords[2], const i
return B_TYPE(0.0);
}
- const u16vec4 row_idx = row_ids[row_i];
+ const u16vec4 row_idx = row_ids[row_i & (BN - 1)];
B_TYPE ret = data_b[row_idx.y * p.batch_stride_b + row_idx.x * p.stride_b + blockCoords[1]];
return ret;
@@ -123,14 +123,14 @@ D_TYPE perElemOpD(const in uint32_t r, const in uint32_t c, const in D_TYPE elem
uint dc = ic * BN + c;
if (dr < p.M && dc < _ne1) {
- uint row_i = dc;
+ uint row_i = c;
const u16vec4 row_idx = row_ids[row_i];
data_d[row_idx.y * p.batch_stride_d + row_idx.z * p.stride_d + dr] = elem;
}
return elem;
}
-void load_row_ids(uint expert_idx, bool nei0_is_pow2) {
+void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {
_ne1 = 0;
uint num_elements = p.nei1 * p.nei0;
uint nei0shift = findLSB(p.nei0);
@@ -180,11 +180,14 @@ void load_row_ids(uint expert_idx, bool nei0_is_pow2) {
barrier();
uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot);
- if (in_range && id == expert_idx) {
- row_ids[_ne1 + idx] = u16vec4(fastmod(ii0, p.ne11), ii1, ii0, 0);
+ if (in_range && id == expert_idx && _ne1 + idx >= ic * BN && _ne1 + idx < (ic + 1) * BN) {
+ row_ids[_ne1 + idx - ic * BN] = u16vec4(fastmod(ii0, p.ne11), ii1, ii0, 0);
}
_ne1 += total;
iter &= 15;
+ if (_ne1 >= (ic + 1) * BN) {
+ break;
+ }
}
barrier();
}
@@ -218,9 +221,9 @@ void main() {
#ifdef MUL_MAT_ID
if (bitCount(p.nei0) == 1) {
- load_row_ids(expert_idx, true);
+ load_row_ids(expert_idx, true, ic);
} else {
- load_row_ids(expert_idx, false);
+ load_row_ids(expert_idx, false, ic);
}
// Workgroup has no work
@@ -346,6 +349,10 @@ void main() {
sum = coopMatMulAdd(mat_a, mat_b, sum);
block_k += BK;
}
+#if defined(ACC_TYPE_MAX)
+ [[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); }
+#endif
+
coopmat mat_d = coopmat(sum);
coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BNover4, ir * BM, BM), tensorViewTranspose);
@@ -385,6 +392,10 @@ void main() {
sum = coopMatMulAdd(mat_a, mat_b, sum);
block_k += BK;
}
+#if defined(ACC_TYPE_MAX)
+ [[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); }
+#endif
+
coopmat mat_d = coopmat(sum);
coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BNover2, ir * BM, BM), tensorViewTranspose);
@@ -425,6 +436,10 @@ void main() {
sum = coopMatMulAdd(mat_a, mat_b, sum);
block_k += BK;
}
+#if defined(ACC_TYPE_MAX)
+ [[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); }
+#endif
+
coopmat mat_d = coopmat(sum);
coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BN, ir * BM, BM), tensorViewTranspose);
@@ -441,18 +456,111 @@ void main() {
tensorLayoutBClamp = setTensorLayoutStrideNV(tensorLayoutBClamp, stride_b, 1);
- coopmat sum;
- sum = coopmat(0.0);
-
uint k_iters = (end_k - start_k + BK - 1) / BK;
fetch_scales(ir * BM, pos_a, stride_a, start_k, tid, false);
+ store_scales(tid);
+
+#ifdef MUL_MAT_ID
+ if (enable_smaller_matrices && ic * BN + BNover4 >= _ne1) {
+ coopmat sum;
+ sum = coopmat(0.0);
+
+ [[dont_unroll]]
+ for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
+
+ if ((block_k % QUANT_K) == 0) {
+ store_scales(tid);
+ }
+ if (block_k + BK < end_k && ((block_k + BK) % QUANT_K) == 0) {
+ fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false);
+ }
+
+ if ((ir + 1) * BM <= p.M && block_k + BK <= end_k) {
+ coopmat mat_a;
+ coopmat mat_b;
+
+ coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose, decodeFuncB);
+
+ sum = coopMatMulAdd(mat_a, mat_b, sum);
+ } else {
+ coopmat mat_a;
+ coopmat mat_b;
+
+ coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose, decodeFuncB);
+
+ sum = coopMatMulAdd(mat_a, mat_b, sum);
+ }
+ }
+#if defined(ACC_TYPE_MAX)
+ [[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); }
+#endif
+
+ // Convert from ACC_TYPE to D_TYPE
+ coopmat mat_d;
+ mat_d = coopmat(sum);
+
+ // Call callback to store each element, remapping row through shared memory
+ coopMatPerElementNV(mat_d, mat_d, perElemOpD, ir, ic);
+ return;
+ }
+ if (enable_smaller_matrices && ic * BN + BNover2 >= _ne1) {
+ coopmat sum;
+ sum = coopmat(0.0);
+
+ [[dont_unroll]]
+ for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
+
+ if ((block_k % QUANT_K) == 0) {
+ store_scales(tid);
+ }
+ if (block_k + BK < end_k && ((block_k + BK) % QUANT_K) == 0) {
+ fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false);
+ }
+
+ if ((ir + 1) * BM <= p.M && block_k + BK <= end_k) {
+ coopmat mat_a;
+ coopmat mat_b;
+
+ coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose, decodeFuncB);
+
+ sum = coopMatMulAdd(mat_a, mat_b, sum);
+ } else {
+ coopmat mat_a;
+ coopmat mat_b;
+
+ coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose, decodeFuncB);
+
+ sum = coopMatMulAdd(mat_a, mat_b, sum);
+ }
+ }
+#if defined(ACC_TYPE_MAX)
+ [[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); }
+#endif
+
+ // Convert from ACC_TYPE to D_TYPE
+ coopmat mat_d;
+ mat_d = coopmat(sum);
+
+ // Call callback to store each element, remapping row through shared memory
+ coopMatPerElementNV(mat_d, mat_d, perElemOpD, ir, ic);
+ return;
+ }
+#endif
+ coopmat sum;
+ sum = coopmat