From d8bebaed2e88f9fa0886620a2269eb6c17893fbc Mon Sep 17 00:00:00 2001 From: Daniel Tang Date: Sat, 17 May 2025 19:06:26 -0400 Subject: [PATCH 01/24] ggml : Fix missing backtrace on Linux (ggml/1228) * Modern Linux defaults /proc/sys/kernel/yama/ptrace_scope to 1 * Fixed lldb attach * Simplify by having the child do ggml_print_backtrace_symbols --- ggml/src/ggml.c | 59 ++++++++++++++++++++++++++++++++++++------------- 1 file changed, 44 insertions(+), 15 deletions(-) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 8a6546240f4..d0ed270edfc 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -70,6 +70,9 @@ float ggml_table_f32_f16[1 << 16]; #include #include #include +#if defined(__linux__) +#include +#endif #if defined(__ANDROID__) #include @@ -133,10 +136,36 @@ static void ggml_print_backtrace(void) { if (GGML_NO_BACKTRACE) { return; } - char attach[32]; - snprintf(attach, sizeof(attach), "attach %d", getpid()); - int pid = fork(); - if (pid == 0) { +#if defined(__linux__) + FILE * f = fopen("/proc/self/status", "r"); + size_t size = 0; + char * line = NULL; + ssize_t length = 0; + while ((length = getline(&line, &size, f)) > 0) { + if (!strncmp(line, "TracerPid:", sizeof("TracerPid:") - 1) && + (length != sizeof("TracerPid:\t0\n") - 1 || line[length - 2] != '0')) { + // Already being debugged, and the breakpoint is the later abort() + free(line); + fclose(f); + return; + } + } + free(line); + fclose(f); + int lock[2] = { -1, -1 }; + (void) !pipe(lock); // Don't start gdb until after PR_SET_PTRACER +#endif + const int parent_pid = getpid(); + const int child_pid = fork(); + if (child_pid < 0) { // error + return; + } else if (child_pid == 0) { // child + char attach[32]; + snprintf(attach, sizeof(attach), "attach %d", parent_pid); +#if defined(__linux__) + close(lock[1]); + (void) !read(lock[0], lock, 1); +#endif // try gdb execlp("gdb", "gdb", "--batch", "-ex", "set style enabled on", @@ -149,18 +178,18 @@ static void ggml_print_backtrace(void) { execlp("lldb", "lldb", "--batch", "-o", "bt", "-o", "quit", - "-p", attach, + "-p", &attach[sizeof("attach ") - 1], (char *) NULL); - exit(EXIT_FAILURE); - } else { - int wstatus; - waitpid(pid, &wstatus, 0); - if (WIFEXITED(wstatus)) { - if (WEXITSTATUS(wstatus) == EXIT_FAILURE) { - // gdb failed, fallback to backtrace_symbols - ggml_print_backtrace_symbols(); - } - } + // gdb failed, fallback to backtrace_symbols + ggml_print_backtrace_symbols(); + _Exit(0); + } else { // parent +#if defined(__linux__) + prctl(PR_SET_PTRACER, child_pid); + close(lock[1]); + close(lock[0]); +#endif + waitpid(child_pid, NULL, 0); } } #else From 6c2e5c77c264d0b3a895334961ef55ac85ac7ab6 Mon Sep 17 00:00:00 2001 From: Diego Devesa Date: Sun, 18 May 2025 18:30:13 -0700 Subject: [PATCH 02/24] ggml : fix apple OS check in ggml_print_backtrace (ggml/1229) --- ggml/src/ggml.c | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index d0ed270edfc..d48adb9afb8 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -64,8 +64,10 @@ // precomputed f32 table for f16 (256 KB) (ggml-impl.h) float ggml_table_f32_f16[1 << 16]; -#if (defined(__linux__) || defined(__APPLE__) || defined(__FreeBSD__) || defined(__NetBSD__) || defined(__OpenBSD__)) && \ - (!defined(TARGET_OS_TV) && !defined(TARGET_OS_WATCH)) +#if defined(__linux__) || \ + defined(__FreeBSD__) || defined(__NetBSD__) || defined(__OpenBSD__) || \ + (defined(__APPLE__) && !TARGET_OS_TV && !TARGET_OS_WATCH) + #include #include #include From 
f6e279977f51cbb8ab5b3f0423e793ad1b03dc7b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Mon, 19 May 2025 09:33:35 +0200 Subject: [PATCH 03/24] mnist: fix segmentation fault (ggml/1227) --- ggml/include/ggml-opt.h | 2 ++ ggml/src/ggml-opt.cpp | 5 +++++ 2 files changed, 7 insertions(+) diff --git a/ggml/include/ggml-opt.h b/ggml/include/ggml-opt.h index da0c24b46fe..74ec080a055 100644 --- a/ggml/include/ggml-opt.h +++ b/ggml/include/ggml-opt.h @@ -128,6 +128,8 @@ extern "C" { // set gradients to zero, initilize loss, and optionally reset the optimizer GGML_API void ggml_opt_reset(ggml_opt_context_t opt_ctx, bool optimizer); + GGML_API bool ggml_opt_static_graphs(ggml_opt_context_t opt_ctx); // whether the graphs are allocated_statically + // get underlying tensors that store data // if not using static graphs these pointers become invalid with the next call to ggml_opt_alloc GGML_API struct ggml_tensor * ggml_opt_inputs( ggml_opt_context_t opt_ctx); // forward graph input tensor diff --git a/ggml/src/ggml-opt.cpp b/ggml/src/ggml-opt.cpp index 58d77578f45..a3c82d67577 100644 --- a/ggml/src/ggml-opt.cpp +++ b/ggml/src/ggml-opt.cpp @@ -576,6 +576,10 @@ void ggml_opt_reset(ggml_opt_context_t opt_ctx, bool optimizer) { } } +bool ggml_opt_static_graphs(ggml_opt_context_t opt_ctx) { + return opt_ctx->static_graphs; +} + struct ggml_tensor * ggml_opt_inputs(ggml_opt_context_t opt_ctx) { return opt_ctx->inputs; } @@ -842,6 +846,7 @@ void ggml_opt_epoch( int64_t idata_split, ggml_opt_epoch_callback callback_train, ggml_opt_epoch_callback callback_eval) { + GGML_ASSERT(ggml_opt_static_graphs(opt_ctx) && "ggml_opt_epoch requires static graphs"); struct ggml_tensor * inputs = ggml_opt_inputs(opt_ctx); struct ggml_tensor * labels = ggml_opt_labels(opt_ctx); struct ggml_tensor * data = ggml_opt_dataset_data(dataset); From 3970f74c53fbeff941f8643dad2690104abd6fa9 Mon Sep 17 00:00:00 2001 From: Dan Johansson Date: Tue, 13 May 2025 17:02:28 +0200 Subject: [PATCH 04/24] ggml-cpu: Update KleidiAI to v1.6 and fix include directives (llama/13509) Signed-off-by: Dan Johansson --- ggml/src/ggml-cpu/CMakeLists.txt | 4 ++-- ggml/src/ggml-cpu/kleidiai/kernels.h | 1 + ggml/src/ggml-cpu/kleidiai/kleidiai.cpp | 2 ++ 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt index bdaec2881dd..1d4259dae5b 100644 --- a/ggml/src/ggml-cpu/CMakeLists.txt +++ b/ggml/src/ggml-cpu/CMakeLists.txt @@ -385,9 +385,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name) # Fetch KleidiAI sources: include(FetchContent) - set(KLEIDIAI_COMMIT_TAG "v1.5.0") + set(KLEIDIAI_COMMIT_TAG "v1.6.0") set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz") - set(KLEIDIAI_ARCHIVE_MD5 "ea22e1aefb800e9bc8c74d91633cc58e") + set(KLEIDIAI_ARCHIVE_MD5 "75b4ad68f25ab673dcc01065e5a0b05f") if (POLICY CMP0135) cmake_policy(SET CMP0135 NEW) diff --git a/ggml/src/ggml-cpu/kleidiai/kernels.h b/ggml/src/ggml-cpu/kleidiai/kernels.h index 5ac02bda7c0..3b268d4a22a 100644 --- a/ggml/src/ggml-cpu/kleidiai/kernels.h +++ b/ggml/src/ggml-cpu/kleidiai/kernels.h @@ -5,6 +5,7 @@ #pragma once #include +#include #include "ggml.h" enum cpu_feature { diff --git a/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp b/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp index f3dffdd6bf5..15f0cd15406 100644 --- a/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +++ b/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp @@ -3,7 +3,9 @@ // #include #include 
+#include #include +#include #include #include #if defined(__linux__) From 5b11174a0aa6b5805b2143c001354d478473149c Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 13 May 2025 18:04:00 +0300 Subject: [PATCH 05/24] metal : optimize multi-sequence FA vec kernel (llama/13493) * batched-bench : fix pp batch contents * metal : optimize multi-sequence FA vec kernel ggml-ci --- ggml/src/ggml-metal/ggml-metal.metal | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 9cfddf4503a..122ae597371 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -3887,6 +3887,11 @@ kernel void kernel_flash_attn_ext_vec( sm[tiisg] = pm[ic + tiisg]; } + // skip -INF blocks + if (simd_max(sm[tiisg]) == -INFINITY) { + continue; + } + // Q*K^T { // each simdgroup processes 1 query and NE (NW/NL) head elements From 7cdf758f48504532806fbeb1536f71b500c17129 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 13 May 2025 18:04:39 +0300 Subject: [PATCH 06/24] metal : use FA-vec kernel up to batch size 20 (llama/13496) * batched-bench : fix pp batch contents * metal : optimize multi-sequence FA vec kernel ggml-ci * metal : use FA-vec kernel up to batch size 20 ggml-ci --- ggml/src/ggml-metal/ggml-metal.m | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index 576f9581bda..f4b3d9cf592 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -4358,7 +4358,7 @@ static bool ggml_metal_encode_node( // TODO: add vec kernels for (ne00%64 == 0) and maybe also for (ne00%32 == 0) // for now avoiding mainly to keep the number of templates/kernels a bit lower // these are now trivial to add after: https://github.com/ggml-org/llama.cpp/pull/12612 - if (ne01 >= 4 || (ne00%128 != 0 && ne00 != 96 && ne00 != 192 && ne00 != 576)) { + if (ne01 >= 20 || (ne00%128 != 0 && ne00 != 96 && ne00 != 192 && ne00 != 576)) { switch (src1->type) { case GGML_TYPE_F16: { From 014d86012408fb46494c591e6b2674dc289be31b Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Wed, 14 May 2025 13:15:50 +0900 Subject: [PATCH 07/24] vulkan: workaround FA compile failures on macos (llama/13517) --- ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index e6545160d53..16835576814 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -12,6 +12,7 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; +layout (constant_id = 0) const uint32_t WorkGroupSize = 128; layout (constant_id = 1) const uint32_t Br = 1; layout (constant_id = 2) const uint32_t Bc = 32; layout (constant_id = 3) const uint32_t D = 32; @@ -19,7 +20,7 @@ layout (constant_id = 3) const uint32_t D = 32; layout (constant_id = 5) const uint32_t D_split = 16; const uint32_t D_per_thread = D / D_split; -const uint32_t cols_per_iter = gl_WorkGroupSize.x / D_split; +const uint32_t cols_per_iter = WorkGroupSize / D_split; const uint32_t cols_per_thread = Bc / cols_per_iter; layout (push_constant) uniform parameter { @@ -134,8 +135,8 @@ ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const i return ACC_TYPE(pow(base, ACC_TYPE(exph))); } -shared FLOAT_TYPE tmpsh[gl_WorkGroupSize.x]; 
-shared vec4 tmpshv4[gl_WorkGroupSize.x]; +shared FLOAT_TYPE tmpsh[WorkGroupSize]; +shared vec4 tmpshv4[WorkGroupSize]; shared float masksh[Bc][Br]; shared vec4 Qf[Br][D / 4]; From bf561898e7a8aadf93c77c2465f594d370dbf2ff Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Wed, 14 May 2025 18:55:26 +0900 Subject: [PATCH 08/24] vulkan: KHR_coopmat flash attention (llama/13506) This shader uses coopmat1 to do the Q*K^T multiply. The P*V multiply is more difficult for various reasons so I haven't done it. Performance for this shader is around 2.5x better than for the scalar shader when doing prompt processing. Some of the benefit may be from other optimizations like staging through shared memory, or splitting by rows. --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 237 ++++++-- .../vulkan-shaders/flash_attn_cm1.comp | 506 ++++++++++++++++++ .../vulkan-shaders/vulkan-shaders-gen.cpp | 13 +- 3 files changed, 702 insertions(+), 54 deletions(-) create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index e2b357fdc15..0856a112283 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -288,6 +288,9 @@ struct vk_device_struct { bool coopmat_acc_f32_support {}; bool coopmat_acc_f16_support {}; bool coopmat_bf16_support {}; + bool coopmat_support_16x16x16_f16acc {}; + bool coopmat_support_16x16x16_f32acc {}; + bool coopmat1_fa_support {}; uint32_t coopmat_m; uint32_t coopmat_n; uint32_t coopmat_k; @@ -410,6 +413,13 @@ struct vk_device_struct { vk_pipeline pipeline_flash_attn_f32_f16_D128_cm2[GGML_TYPE_COUNT][2][2][2]; vk_pipeline pipeline_flash_attn_f32_f16_D256_cm2[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_D64_cm1[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_D80_cm1[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_D96_cm1[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_D112_cm1[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_D128_cm1[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_D256_cm1[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2]; vk_pipeline pipeline_flash_attn_f32_f16_D80[GGML_TYPE_COUNT][2][2][2]; vk_pipeline pipeline_flash_attn_f32_f16_D96[GGML_TYPE_COUNT][2][2][2]; @@ -1588,19 +1598,36 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector&& events ); } +enum FaCodePath { + FA_SCALAR, + FA_COOPMAT1, + FA_COOPMAT2, +}; + // number of rows/cols for flash attention shader static constexpr uint32_t flash_attention_num_small_rows = 32; static constexpr uint32_t scalar_flash_attention_num_small_rows = 1; static constexpr uint32_t scalar_flash_attention_num_large_rows = 8; -static uint32_t get_fa_num_small_rows(bool scalar) { - return scalar ? scalar_flash_attention_num_small_rows : flash_attention_num_small_rows; +// The FA coopmat1 shader assumes 16x16x16 matrix multiply support. +// 128 threads split into four subgroups, each subgroup does 1/4 +// of the Bc dimension. 
+static constexpr uint32_t coopmat1_flash_attention_num_large_rows = 16; +static constexpr uint32_t scalar_flash_attention_Bc = 64; +static constexpr uint32_t scalar_flash_attention_workgroup_size = 128; + +static uint32_t get_fa_num_small_rows(FaCodePath path) { + if (path == FA_COOPMAT2) { + return flash_attention_num_small_rows; + } else { + return scalar_flash_attention_num_small_rows; + } } -static std::array fa_rows_cols(bool scalar, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) { +static std::array fa_rows_cols(FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) { GGML_UNUSED(clamp); - if (scalar) { + if (path == FA_SCALAR) { if (small_rows) { return {scalar_flash_attention_num_small_rows, 64}; } else { @@ -1608,9 +1635,17 @@ static std::array fa_rows_cols(bool scalar, uint32_t D, uint32_t cl } } + if (path == FA_COOPMAT1) { + if (small_rows) { + return {scalar_flash_attention_num_small_rows, scalar_flash_attention_Bc}; + } else { + return {coopmat1_flash_attention_num_large_rows, scalar_flash_attention_Bc}; + } + } + // small rows, large cols if (small_rows) { - return {get_fa_num_small_rows(scalar), 32}; + return {get_fa_num_small_rows(FA_COOPMAT2), 32}; } // small cols to reduce register count @@ -1907,17 +1942,19 @@ static void ggml_vk_load_shaders(vk_device& device) { parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size)); }; - auto const &fa_wg_denoms = [&](bool scalar, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::array { - return {fa_rows_cols(scalar, D, clamp, type, small_rows)[0], 1, 1}; + auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::array { + return {fa_rows_cols(path, D, clamp, type, small_rows)[0], 1, 1}; }; - auto const &fa_spec_constants = [&](bool scalar, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector { + auto const &fa_spec_constants = [&](FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector { // For large number of rows, 128 invocations seems to work best. // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we // can't use 256 for D==80. // For scalar, use 128 (arbitrary) - uint32_t wg_size = scalar ? 128 : ((small_rows && (D % 32) == 0) ? 256 : 128); - auto rows_cols = fa_rows_cols(scalar, D, clamp, type, small_rows); + uint32_t wg_size = (path == FA_SCALAR || path == FA_COOPMAT1) + ? scalar_flash_attention_workgroup_size + : ((small_rows && (D % 32) == 0) ? 256 : 128); + auto rows_cols = fa_rows_cols(path, D, clamp, type, small_rows); // D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it. // D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader. 
@@ -1929,36 +1966,43 @@ static void ggml_vk_load_shaders(vk_device& device) { return {wg_size, rows_cols[0], rows_cols[1], (D), clamp, D_split}; }; -#define CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, D) \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][0], "flash_attn_f32_f16_D" #D "_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,1,TYPE,false), fa_spec_constants(SCALAR, D,1,TYPE,false), 1, true); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,0,TYPE,false), fa_spec_constants(SCALAR, D,0,TYPE,false), fa_rows_cols(SCALAR,D,0,TYPE,false)[1], true); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][0], "flash_attn_f32_f16_D" #D "_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,1,TYPE,false), fa_spec_constants(SCALAR, D,1,TYPE,false), 1, true); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,0,TYPE,false), fa_spec_constants(SCALAR, D,0,TYPE,false), fa_rows_cols(SCALAR,D,0,TYPE,false)[1], true); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][0], "flash_attn_f32_f16_D" #D "_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,1,TYPE,true), fa_spec_constants(SCALAR, D,1,TYPE,true), 1, true); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,0,TYPE,true), fa_spec_constants(SCALAR, D,0,TYPE,true), fa_rows_cols(SCALAR,D,0,TYPE,true)[1], true); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][0], "flash_attn_f32_f16_D" #D "_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,1,TYPE,true), fa_spec_constants(SCALAR, D,1,TYPE,true), 1, true); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), 
fa_wg_denoms(SCALAR, D,0,TYPE,true), fa_spec_constants(SCALAR, D,0,TYPE,true), fa_rows_cols(SCALAR,D,0,TYPE,true)[1], true); \ - -#define CREATE_FA(TYPE, NAMELC, SCALAR, SUFFIX) \ - CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 64) \ - CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 80) \ - CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 96) \ - CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 112) \ - CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 128) \ - CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 256) - - CREATE_FA(GGML_TYPE_F16, f16, true, ) - CREATE_FA(GGML_TYPE_Q4_0, q4_0, true, ) - CREATE_FA(GGML_TYPE_Q8_0, q8_0, true, ) +#define CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, D) \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][0], "flash_attn_f32_f16_D" #D "_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,false), fa_spec_constants(FAPATH, D,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,false), fa_spec_constants(FAPATH, D,0,TYPE,false), fa_rows_cols(FAPATH,D,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][0], "flash_attn_f32_f16_D" #D "_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,false), fa_spec_constants(FAPATH, D,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,false), fa_spec_constants(FAPATH, D,0,TYPE,false), fa_rows_cols(FAPATH,D,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][0], "flash_attn_f32_f16_D" #D "_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,true), fa_spec_constants(FAPATH, D,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 
32 : 0)); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,true), fa_spec_constants(FAPATH, D,0,TYPE,true), fa_rows_cols(FAPATH,D,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][0], "flash_attn_f32_f16_D" #D "_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,true), fa_spec_constants(FAPATH, D,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,true), fa_spec_constants(FAPATH, D,0,TYPE,true), fa_rows_cols(FAPATH,D,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + +#define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \ + CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 64) \ + CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 80) \ + CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 96) \ + CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 112) \ + CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 128) \ + CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 256) + + CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, ) + CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, ) + CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, ) +#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + if (device->coopmat1_fa_support) { + CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT1, _cm1) + CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT1, _cm1) + CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT1, _cm1) + } +#endif #if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) if (device->coopmat2) { - CREATE_FA(GGML_TYPE_F16, f16, false, _cm2) - CREATE_FA(GGML_TYPE_Q4_0, q4_0, false, _cm2) - CREATE_FA(GGML_TYPE_Q4_1, q4_1, false, _cm2) - CREATE_FA(GGML_TYPE_Q5_0, q5_0, false, _cm2) - CREATE_FA(GGML_TYPE_Q5_1, q5_1, false, _cm2) - CREATE_FA(GGML_TYPE_Q8_0, q8_0, false, _cm2) - CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, false, _cm2) + CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT2, _cm2) + CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT2, _cm2) + CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_COOPMAT2, _cm2) + CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_COOPMAT2, _cm2) + CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_COOPMAT2, _cm2) + CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT2, _cm2) + CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_COOPMAT2, _cm2) } #endif #undef CREATE_FA2 @@ -2041,17 +2085,17 @@ static void ggml_vk_load_shaders(vk_device& device) { // Create 6 variants, {s,m,l}x{unaligned,aligned} #define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ if (device->mul_mat ## ID ## _l[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", 
PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, true); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, true); \ if (device->mul_mat ## ID ## _m[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, true); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, true); \ if (device->mul_mat ## ID ## _s[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, true); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, true); \ if (device->mul_mat ## ID ## _l[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, true); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, true); \ if (device->mul_mat ## ID ## _m[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, true); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, true); \ if (device->mul_mat ## ID ## _s[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, true); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, true); \ // Create 2 variants, {f16,f32} accumulator #define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ @@ -3009,6 +3053,11 @@ static vk_device ggml_vk_get_device(size_t idx) { #if defined(VK_KHR_cooperative_matrix) device->coopmat_support = device->coopmat_support && coopmat_features.cooperativeMatrix; + + // coopmat1 fa shader currently assumes 32 
invocations per subgroup + device->coopmat1_fa_support = device->coopmat_support && device->subgroup_require_full_support && + device->subgroup_size_control && device->subgroup_min_size <= 32 && + device->subgroup_max_size >= 32; #endif if (coopmat2_support) { @@ -3143,6 +3192,9 @@ static vk_device ggml_vk_get_device(size_t idx) { // Only enable if shape is identical device->coopmat_acc_f32_support = true; } + if (prop.MSize == 16 && prop.NSize == 16 && prop.KSize == 16) { + device->coopmat_support_16x16x16_f32acc = true; + } } else if ((vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eFloat16 && (vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eFloat16) { // coopmat sizes not set yet @@ -3155,6 +3207,9 @@ static vk_device ggml_vk_get_device(size_t idx) { // Only enable if shape is identical device->coopmat_acc_f16_support = true; } + if (prop.MSize == 16 && prop.NSize == 16 && prop.KSize == 16) { + device->coopmat_support_16x16x16_f16acc = true; + } } } else if ((vk::ComponentTypeKHR)prop.AType == vk::ComponentTypeKHR::eSint8 && (vk::ComponentTypeKHR)prop.BType == vk::ComponentTypeKHR::eSint8 && @@ -5688,6 +5743,36 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx } } +static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t D, bool f32acc) { + // Needs to be kept up to date on shader changes + const uint32_t wg_size = scalar_flash_attention_workgroup_size; + const uint32_t Br = scalar_flash_attention_num_large_rows; + const uint32_t Bc = scalar_flash_attention_Bc; + + const uint32_t acctype = f32acc ? 4 : 2; + const uint32_t f16vec4 = 8; + + const uint32_t tmpsh = wg_size * sizeof(float); + const uint32_t tmpshv4 = wg_size * 4 * acctype; + + const uint32_t Qf = Br * (D / 4 + 2) * f16vec4; + + const uint32_t sfshstride = (D <= 128) ? (Br + 8) : Br; + const uint32_t sfsh = Bc * sfshstride * acctype; + + const uint32_t kshstride = D / 4 + 2; + const uint32_t ksh = Bc * kshstride * f16vec4; + + const uint32_t slope = Br * sizeof(float); + + const uint32_t total_size = tmpsh + tmpshv4 + Qf + sfsh + ksh + slope; + const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize; + + VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(D=" << D << ", f32acc=" << f32acc << ", total_size=" << total_size << ", supported=" << supported); + + return supported; +} + static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * q, const ggml_tensor * k, const ggml_tensor * v, const ggml_tensor * mask, ggml_tensor * dst, bool dryrun = false) { VK_LOG_DEBUG("ggml_vk_flash_attn((" << q << ", name=" << q->name << ", type=" << q->type << ", ne0=" << q->ne[0] << ", ne1=" << q->ne[1] << ", ne2=" << q->ne[2] << ", ne3=" << q->ne[3] << ", nb0=" << q->nb[0] << ", nb1=" << q->nb[1] << ", nb2=" << q->nb[2] << ", nb3=" << q->nb[3]; std::cerr << "), (" << k << ", name=" << k->name << ", type=" << k->type << ", ne0=" << k->ne[0] << ", ne1=" << k->ne[1] << ", ne2=" << k->ne[2] << ", ne3=" << k->ne[3] << ", nb0=" << k->nb[0] << ", nb1=" << k->nb[1] << ", nb2=" << k->nb[2] << ", nb3=" << k->nb[3]; @@ -5738,7 +5823,19 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx assert(q->type == GGML_TYPE_F32); assert(k->type == v->type); - bool scalar = !ctx->device->coopmat2; + FaCodePath path = ctx->device->coopmat2 ? FA_COOPMAT2 : + ctx->device->coopmat1_fa_support ? 
FA_COOPMAT1 : FA_SCALAR; + + if (path == FA_COOPMAT1) { + const bool coopmat_shape_supported = (dst->op_params[3] == GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f32acc) || + (dst->op_params[3] != GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f16acc); + + const bool coopmat_shmem_supported = ggml_vk_flash_attn_coopmat_shmem_support(ctx->device, D, dst->op_params[3] == GGML_PREC_F32); + + if (!coopmat_shape_supported || !coopmat_shmem_supported) { + path = FA_SCALAR; + } + } uint32_t gqa_ratio = 1; uint32_t qk_ratio = neq2 / nek2; @@ -5746,9 +5843,21 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx uint32_t workgroups_y = (uint32_t)neq2; uint32_t workgroups_z = (uint32_t)neq3; - // For scalar FA, we can use the "large" size to accommodate qga. - // For coopmat FA, we always use the small size (which is still pretty large for gqa). - const uint32_t max_gqa = scalar ? scalar_flash_attention_num_large_rows : get_fa_num_small_rows(false); + // For scalar/coopmat1 FA, we can use the "large" size to accommodate qga. + // For coopmat2 FA, we always use the small size (which is still pretty large for gqa). + uint32_t max_gqa; + switch (path) { + case FA_SCALAR: + case FA_COOPMAT1: + // We may switch from coopmat1 to scalar, so use the scalar limit for both + max_gqa = scalar_flash_attention_num_large_rows; + break; + case FA_COOPMAT2: + max_gqa = get_fa_num_small_rows(FA_COOPMAT2); + break; + default: + GGML_ASSERT(0); + } if (N == 1 && qk_ratio > 1 && qk_ratio <= max_gqa && qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1) { @@ -5761,11 +5870,16 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx } vk_pipeline *pipelines; - // XXX TODO other backends may be changing accumulator precision to default to f32 soon - bool f32acc = scalar || dst->op_params[3] == GGML_PREC_F32; - bool small_rows = N <= get_fa_num_small_rows(scalar); + bool small_rows = N <= get_fa_num_small_rows(path); - if (scalar) { + if (small_rows && path == FA_COOPMAT1) { + path = FA_SCALAR; + } + + bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32; + + switch (path) { + case FA_SCALAR: switch (D) { case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64[k->type][f32acc][small_rows][0]; break; case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80[k->type][f32acc][small_rows][0]; break; @@ -5777,7 +5891,21 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx GGML_ASSERT(!"unsupported D value"); return; } - } else { + break; + case FA_COOPMAT1: + switch (D) { + case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64_cm1[k->type][f32acc][small_rows][0]; break; + case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80_cm1[k->type][f32acc][small_rows][0]; break; + case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96_cm1[k->type][f32acc][small_rows][0]; break; + case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112_cm1[k->type][f32acc][small_rows][0]; break; + case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128_cm1[k->type][f32acc][small_rows][0]; break; + case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256_cm1[k->type][f32acc][small_rows][0]; break; + default: + GGML_ASSERT(!"unsupported D value"); + return; + } + break; + case FA_COOPMAT2: switch (D) { case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64_cm2[k->type][f32acc][small_rows][0]; break; case 
80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80_cm2[k->type][f32acc][small_rows][0]; break; @@ -5789,6 +5917,9 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx GGML_ASSERT(!"unsupported D value"); return; } + break; + default: + GGML_ASSERT(0); } assert(pipelines); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp new file mode 100644 index 00000000000..8b86b623bd9 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp @@ -0,0 +1,506 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable +#extension GL_EXT_shader_16bit_storage : require + +#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#extension GL_KHR_shader_subgroup_basic : enable +#extension GL_KHR_memory_scope_semantics : enable +#extension GL_KHR_cooperative_matrix : enable + +#include "types.comp" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (constant_id = 1) const uint32_t Br = 1; +layout (constant_id = 2) const uint32_t Bc = 32; +layout (constant_id = 3) const uint32_t D = 32; + +layout (constant_id = 5) const uint32_t D_split = 16; + +const uint32_t D_per_thread = D / D_split; +const uint32_t row_split = 4; +const uint32_t rows_per_thread = Br / row_split; +const uint32_t cols_per_iter = gl_WorkGroupSize.x / D_split / row_split; +const uint32_t cols_per_thread = Bc / cols_per_iter; + +layout (push_constant) uniform parameter { + uint32_t N; + uint32_t KV; + + uint32_t ne1; + uint32_t ne2; + uint32_t ne3; + + uint32_t neq2; + uint32_t neq3; + uint32_t nek2; + uint32_t nek3; + uint32_t nev2; + uint32_t nev3; + uint32_t nem1; + + uint32_t nb01; + uint32_t nb02; + uint32_t nb03; + uint32_t nb11; + uint32_t nb12; + uint32_t nb13; + uint32_t nb21; + uint32_t nb22; + uint32_t nb23; + uint32_t nb31; + + float scale; + float max_bias; + float logit_softcap; + + uint32_t mask; + uint32_t n_head_log2; + float m0; + float m1; + + uint32_t gqa_ratio; + uint32_t split_kv; + uint32_t k_num; +} p; + +layout (binding = 0) readonly buffer Q {float data_q[];}; +layout (binding = 0) readonly buffer QV4 {vec4 data_qv4[];}; +layout (binding = 1) readonly buffer K {float16_t data_k[];}; +layout (binding = 1) readonly buffer KV4 {f16vec4 data_kv4[];}; +layout (binding = 2) readonly buffer V {float16_t data_v[];}; +layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];}; +layout (binding = 3) readonly buffer M {float16_t data_m[];}; +layout (binding = 4) writeonly buffer O {D_TYPE data_o[];}; + +#if defined(A_TYPE_PACKED16) +#define BINDING_IDX_K 0 +#define BINDING_IDX_V 1 +layout (binding = 1) readonly buffer KV_PACKED16 {A_TYPE_PACKED16 data_packed16[];} kv_packed[2]; +#endif + +#if defined(DATA_A_Q4_0) +#define BLOCK_BYTE_SIZE 18 + +vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { + uint vui_lo = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); + uint vui_hi = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); + uint shift = (iqs & 0x10) >> 2; + vui_lo >>= shift; + vui_hi >>= shift; + + return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f); +} +#endif + +#if defined(DATA_A_Q8_0) +#define BLOCK_BYTE_SIZE 34 +vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint 
binding_idx) { + const i8vec2 v0 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147 + const i8vec2 v1 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy; + + return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y); +} +#endif + +#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b)) + +// Store the output when doing grouped query attention. +// Rows index by Q's dimension 2, and the first N rows are valid. +D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) +{ + uint32_t offset = (iq2 + r) * D + c; + data_o[o_offset + offset] = D_TYPE(elem); + return elem; +} + +// Store column zero. This is used to save per-row m and L values for split_k. +ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) +{ + if (r < N && c == 0) { + uint32_t offset = iq2 + r; + data_o[o_offset + offset] = D_TYPE(elem); + } + return elem; +} + +// Load the slope matrix, indexed by Q's dimension 2. +ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2) +{ + const uint32_t h = iq2 + (r % p.gqa_ratio); + + const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1); + const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1); + + return ACC_TYPE(pow(base, ACC_TYPE(exph))); +} + +// These need to be supported N,M values for a MatBc x MatBr x 16 coopmatmuladd +const uint32_t MatBr = 16; +const uint32_t MatBc = 16; + +shared FLOAT_TYPE tmpsh[gl_WorkGroupSize.x]; +shared ACC_TYPEV4 tmpshv4[gl_WorkGroupSize.x]; + +const uint32_t qstride = D / 4 + 2; // in units of f16vec4 +shared f16vec4 Qf[Br * qstride]; + +// Avoid padding for D==256 to make it fit in 48KB shmem. +const uint32_t sfshstride = (D <= 128) ? (Br + 8) : Br; +shared ACC_TYPE sfsh[Bc * sfshstride]; + +const uint32_t kshstride = D / 4 + 2; // in units of f16vec4 +shared f16vec4 ksh[Bc * kshstride]; + +shared float slope[Br]; + +void main() { +#ifdef NEEDS_INIT_IQ_SHMEM + init_iq_shmem(gl_WorkGroupSize); +#endif + + const uint32_t tid = gl_LocalInvocationIndex; + const uint32_t N = p.N; + const uint32_t KV = p.KV; + + const uint32_t threads_per_rowgroup = gl_WorkGroupSize.x / row_split; + const uint32_t row_tid = gl_LocalInvocationIndex / threads_per_rowgroup; + const uint32_t d_tid = gl_LocalInvocationIndex % D_split; + const uint32_t col_tid = (gl_LocalInvocationIndex % threads_per_rowgroup) / D_split; + +#define tile_row(r) (row_tid * rows_per_thread + (r)) + + uint32_t i = gl_WorkGroupID.x; + uint32_t split_k_index = 0; + + if (p.k_num > 1) { + i = 0; + split_k_index = gl_WorkGroupID.x; + } + + const uint32_t Tr = CEIL_DIV(N, Br); + + const uint32_t start_j = split_k_index * p.split_kv / Bc; + const uint32_t end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc); + + // When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y. + // When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2. 
+ const uint32_t iq2 = gl_WorkGroupID.y * p.gqa_ratio; + const uint32_t iq3 = gl_WorkGroupID.z; + + // broadcast factors + const uint32_t rk2 = p.neq2/p.nek2; + const uint32_t rk3 = p.neq3/p.nek3; + + const uint32_t rv2 = p.neq2/p.nev2; + const uint32_t rv3 = p.neq3/p.nev3; + + // k indices + const uint32_t ik3 = iq3 / rk3; + const uint32_t ik2 = iq2 / rk2; + + // v indices + const uint32_t iv3 = iq3 / rv3; + const uint32_t iv2 = iq2 / rv2; + + // nb?1 are already divided by the type size and are in units of elements. + // When using grouped query attention, Q is indexed by iq2, so the stride + // should be nb02 (which is in bytes). + uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01; + uint32_t k_stride = p.nb11; + uint32_t v_stride = p.nb21; + // When using grouped query attention, all rows use the same mask (stride 0). + // "p.gqa_ratio >> 16" is just a roundabout way of writing zero + // that prevents the compiler from folding the "&" through the select + // and breaking the alignment detection. + uint32_t m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV; + + uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4; + + [[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) { + uint32_t d = (idx + tid) % (D / 4); + uint32_t r = (idx + tid) / (D / 4); + if (r < Br && d < D / 4 && + i * Br + r < N) { + Qf[r * qstride + d] = f16vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale); + } + } + barrier(); + + ACC_TYPEV4 Of[rows_per_thread][D_per_thread / 4]; + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Of[r][d] = ACC_TYPEV4(0.0); + } + } + + float Lf[rows_per_thread], Mf[rows_per_thread]; + + // Use -FLT_MAX/2 rather than -inf to reduce the possibility of NaNs, e.g. when computing Mold-M. 
+ const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF); + + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Lf[r] = 0; + Mf[r] = NEG_FLT_MAX_OVER_2; + } + + // ALiBi + if (p.max_bias > 0.0f) { + if (tid < Br) { + uint r = tid; + slope[r] = perElemOpComputeSlope(r, col_tid, ACC_TYPE(0), iq2); + } + barrier(); + } else { + if (tid < Br) { + uint r = tid; + slope[r] = 1.0; + } + barrier(); + } + +#if BLOCK_SIZE > 1 + uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / BLOCK_BYTE_SIZE; + uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / BLOCK_BYTE_SIZE; +#else + uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2; + uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2; +#endif + + [[dont_unroll]] + for (uint32_t j = start_j; j < end_j; ++j) { + + [[unroll]] for (uint32_t idx = 0; idx < Bc * D / 4; idx += gl_WorkGroupSize.x) { + uint32_t d = (idx + tid) % (D / 4); + uint32_t c = (idx + tid) / (D / 4); + if (c < Bc && d < D / 4) { +#if BLOCK_SIZE > 1 + uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d; + uint ib = coord / BLOCK_SIZE; + uint iqs = (coord % BLOCK_SIZE); + f16vec4 K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K)); +#else + f16vec4 K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]); +#endif + + ksh[c * kshstride + d] = K_Tf; + } + } + barrier(); + + // K * Q^T -> S^T: Bc x D * D x Br -> Bc x Br + // Bc split across workgroup (four subgroups), loop over D in chunks of 16: 16 x 16 * 16 x 16 -> 16 x 16 + // This is written transposed in order to allow for N being 8 if implementations need it + coopmat SfMat = coopmat(0); + coopmat KMat; + coopmat QMat; + + for (uint32_t d = 0; d < D / 16; ++d) { + coopMatLoad(QMat, Qf, d * 16 / 4, qstride, gl_CooperativeMatrixLayoutColumnMajor); + + uint coord = (gl_SubgroupID * MatBc) * kshstride + d * 16 / 4; + coopMatLoad(KMat, ksh, coord, kshstride, gl_CooperativeMatrixLayoutRowMajor); + + SfMat = coopMatMulAdd(KMat, QMat, SfMat); + } + + uint coord = gl_SubgroupID * MatBc * sfshstride; + coopMatStore(SfMat, sfsh, coord, sfshstride, gl_CooperativeMatrixLayoutRowMajor); + barrier(); + + if (p.logit_softcap != 0.0f) { + [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) { + uint32_t c = (idx + tid) / Br; + uint32_t r = (idx + tid) % Br; + if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) { + sfsh[c * sfshstride + r] = ACC_TYPE(p.logit_softcap * tanh(sfsh[c * sfshstride + r])); + } + } + barrier(); + } + + if (p.mask != 0) { + [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) { + uint32_t c = (idx + tid) % Bc; + uint32_t r = (idx + tid) / Bc; + if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) { + sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[(i * Br + r) * m_stride + (j * Bc + c)])); + } + } + barrier(); + } + + float eMf[rows_per_thread]; + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + float rowmaxf = sfsh[tile_row(r) + (0 * cols_per_iter + col_tid) * sfshstride]; + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + rowmaxf = max(rowmaxf, float(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride])); + } + float Moldf = Mf[r]; + + // M = max(rowmax, Mold) + // P = e^(S - M) + // eM = e^(Mold - M) + Mf[r] = max(rowmaxf, Moldf); + eMf[r] = exp(Moldf - Mf[r]); + } + + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Of[r][d] = float16_t(eMf[r]) * Of[r][d]; + } + } + [[unroll]] for (uint32_t r = 0; 
r < rows_per_thread; ++r) { + Lf[r] = eMf[r]*Lf[r]; + } + + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + float Pf[rows_per_thread]; + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Pf[r] = exp(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride] - Mf[r]); + Lf[r] += Pf[r]; + } + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { +#if BLOCK_SIZE > 1 + uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); + uint ib = coord / BLOCK_SIZE; + uint iqs = (coord % BLOCK_SIZE); + vec4 Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V); +#else + vec4 Vf = vec4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]); +#endif + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Of[r][d] += float16_t(Pf[r]) * ACC_TYPEV4(Vf); + } + } + } + + barrier(); + } + + // reduce across threads + + float rowmaxf[rows_per_thread], eMf[rows_per_thread], Moldf[rows_per_thread]; + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + FLOAT_TYPE M = Mf[r]; + tmpsh[tid] = M; + // Compute max across the row + barrier(); + [[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) { + M = max(M, tmpsh[tid ^ s]); + barrier(); + tmpsh[tid] = M; + barrier(); + } + rowmaxf[r] = tmpsh[d_tid + row_tid * threads_per_rowgroup]; + barrier(); + } + + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Moldf[r] = Mf[r]; + + // M = max(rowmax, Mold) + // eM = e^(Mold - M) + Mf[r] = max(rowmaxf[r], Moldf[r]); + eMf[r] = exp(Moldf[r] - Mf[r]); + + Lf[r] = eMf[r]*Lf[r]; + } + + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + FLOAT_TYPE L = Lf[r]; + tmpsh[tid] = L; + // Compute sum across the row + barrier(); + [[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) { + L += tmpsh[tid ^ s]; + barrier(); + tmpsh[tid] = L; + barrier(); + } + Lf[r] = tmpsh[d_tid + row_tid * threads_per_rowgroup]; + barrier(); + } + + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + + Of[r][d] = float16_t(eMf[r]) * Of[r][d]; + tmpshv4[tid] = Of[r][d]; + + barrier(); + [[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) { + Of[r][d] += tmpshv4[tid ^ s]; + barrier(); + tmpshv4[tid] = Of[r][d]; + barrier(); + } + Of[r][d] = tmpshv4[d_tid + row_tid * threads_per_rowgroup]; + barrier(); + } + } + + // If there is split_k, then the split_k resolve shader does the final + // division by L. Store the intermediate O value and per-row m and L values. 
+ if (p.k_num > 1) { + uint32_t o_offset = D * p.ne1 * split_k_index; + + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + if (tile_row(r) < N) { + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { + perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N); + } + } + } + } + + o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2; + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + if (tile_row(r) < N) { + perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N); + perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N); + } + } + + return; + } + + float Lfrcp[rows_per_thread]; + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Lfrcp[r] = 1.0 / Lf[r]; + } + + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Of[r][d] *= float16_t(Lfrcp[r]); + } + } + + uint32_t o_offset = iq3*p.ne2*p.ne1; + + if (p.gqa_ratio > 1) { + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + if (tile_row(r) < N) { + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { + perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N); + } + } + } + } + } else { + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + if (i * Br + tile_row(r) < N) { + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { + data_o[o_offset + iq2 * D + (i * Br + tile_row(r)) * p.ne1 * D + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]); + } + } + } + } + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index d196137eb29..9361e2ac83b 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -215,7 +215,7 @@ static std::mutex compile_count_mutex; static std::condition_variable compile_count_cond; void string_to_spv_func(const std::string& _name, const std::string& in_fname, const std::map& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) { - std::string name = _name + (f16acc ? "_f16acc" : "") + (coopmat ? "_coopmat" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32")); + std::string name = _name + (f16acc ? "_f16acc" : "") + (coopmat ? "_cm1" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32")); std::string out_fname = join_paths(output_dir, name + ".spv"); std::string in_path = join_paths(input_dir, in_fname); @@ -424,6 +424,7 @@ void process_shaders() { // flash attention for (const auto& f16acc : {false, true}) { std::string acctype = f16acc ? "float16_t" : "float"; + std::string acctypev4 = f16acc ? 
"f16vec4" : "vec4"; for (const auto& tname : type_names) { if (tname == "f32") { @@ -440,6 +441,16 @@ void process_shaders() { string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc); } +#endif +#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + if (tname == "f16") { + string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp", + merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"ACC_TYPEV4", acctypev4}, {"COOPMAT", "1"}}), true, true, false, f16acc); + } else if (tname == "q4_0" || tname == "q8_0") { + std::string data_a_key = "DATA_A_" + to_uppercase(tname); + string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp", + merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"ACC_TYPEV4", acctypev4}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), true, true, false, f16acc); + } #endif if (tname == "f16") { string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", From 4611d5581c3be1ee7cb269b038aa4bdd8853dc39 Mon Sep 17 00:00:00 2001 From: bandoti <141645996+bandoti@users.noreply.github.com> Date: Wed, 14 May 2025 07:53:57 -0300 Subject: [PATCH 09/24] cmake: simplify vulkan shader test logic (llama/13263) --- ggml/src/ggml-vulkan/CMakeLists.txt | 167 ++++++++---------- .../ggml-vulkan/vulkan-shaders/CMakeLists.txt | 17 ++ 2 files changed, 95 insertions(+), 89 deletions(-) diff --git a/ggml/src/ggml-vulkan/CMakeLists.txt b/ggml/src/ggml-vulkan/CMakeLists.txt index 31816219c06..16e10a9f399 100644 --- a/ggml/src/ggml-vulkan/CMakeLists.txt +++ b/ggml/src/ggml-vulkan/CMakeLists.txt @@ -15,6 +15,32 @@ function(detect_host_compiler) set(HOST_CXX_COMPILER "${HOST_CXX_COMPILER}" PARENT_SCOPE) endfunction() +# Function to test shader extension support +# Parameters: +# EXTENSION_NAME - Name of the extension to test (e.g., "GL_EXT_integer_dot_product") +# TEST_SHADER_FILE - Path to the test shader file +# RESULT_VARIABLE - Name of the variable to set (ON/OFF) based on test result +function(test_shader_extension_support EXTENSION_NAME TEST_SHADER_FILE RESULT_VARIABLE) + execute_process( + COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${TEST_SHADER_FILE}" + OUTPUT_VARIABLE glslc_output + ERROR_VARIABLE glslc_error + ) + + if (${glslc_error} MATCHES ".*extension not supported: ${EXTENSION_NAME}.*") + message(STATUS "${EXTENSION_NAME} not supported by glslc") + set(${RESULT_VARIABLE} OFF PARENT_SCOPE) + else() + message(STATUS "${EXTENSION_NAME} supported by glslc") + set(${RESULT_VARIABLE} ON PARENT_SCOPE) + add_compile_definitions(${RESULT_VARIABLE}) + + # Ensure the extension support is forwarded to vulkan-shaders-gen + list(APPEND VULKAN_SHADER_GEN_CMAKE_ARGS -D${RESULT_VARIABLE}=ON) + set(VULKAN_SHADER_GEN_CMAKE_ARGS "${VULKAN_SHADER_GEN_CMAKE_ARGS}" PARENT_SCOPE) + endif() +endfunction() + if (Vulkan_FOUND) message(STATUS "Vulkan found") @@ -23,69 +49,35 @@ if (Vulkan_FOUND) ../../include/ggml-vulkan.h ) - # Compile a test shader to determine whether GL_KHR_cooperative_matrix is supported. - # If it's not, there will be an error to stderr. 
- # If it's supported, set a define to indicate that we should compile those shaders - execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_coopmat_support.comp" - OUTPUT_VARIABLE glslc_output - ERROR_VARIABLE glslc_error) - - if (${glslc_error} MATCHES ".*extension not supported: GL_KHR_cooperative_matrix.*") - message(STATUS "GL_KHR_cooperative_matrix not supported by glslc") - set(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT OFF) - else() - message(STATUS "GL_KHR_cooperative_matrix supported by glslc") - set(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT ON) - add_compile_definitions(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) - endif() - - # Compile a test shader to determine whether GL_NV_cooperative_matrix2 is supported. - # If it's not, there will be an error to stderr. - # If it's supported, set a define to indicate that we should compile those shaders - execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_coopmat2_support.comp" - OUTPUT_VARIABLE glslc_output - ERROR_VARIABLE glslc_error) - - if (${glslc_error} MATCHES ".*extension not supported: GL_NV_cooperative_matrix2.*") - message(STATUS "GL_NV_cooperative_matrix2 not supported by glslc") - set(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT OFF) - else() - message(STATUS "GL_NV_cooperative_matrix2 supported by glslc") - set(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT ON) - add_compile_definitions(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) - endif() + set(VULKAN_SHADER_GEN_CMAKE_ARGS + -DCMAKE_INSTALL_PREFIX=${CMAKE_BINARY_DIR} + -DCMAKE_RUNTIME_OUTPUT_DIRECTORY=${CMAKE_RUNTIME_OUTPUT_DIRECTORY} + ) - # Compile a test shader to determine whether GL_EXT_integer_dot_product is supported. - # If it's not, there will be an error to stderr. - # If it's supported, set a define to indicate that we should compile those shaders - execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_integer_dot_support.comp" - OUTPUT_VARIABLE glslc_output - ERROR_VARIABLE glslc_error) + # Test all shader extensions + test_shader_extension_support( + "GL_KHR_cooperative_matrix" + "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_coopmat_support.comp" + "GGML_VULKAN_COOPMAT_GLSLC_SUPPORT" + ) - if (${glslc_error} MATCHES ".*extension not supported: GL_EXT_integer_dot_product.*") - message(STATUS "GL_EXT_integer_dot_product not supported by glslc") - set(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT OFF) - else() - message(STATUS "GL_EXT_integer_dot_product supported by glslc") - set(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT ON) - add_compile_definitions(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) - endif() + test_shader_extension_support( + "GL_NV_cooperative_matrix2" + "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_coopmat2_support.comp" + "GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT" + ) - # Compile a test shader to determine whether GL_EXT_bfloat16 is supported. - # If it's not, there will be an error to stderr. 
- # If it's supported, set a define to indicate that we should compile those shaders - execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_bfloat16_support.comp" - OUTPUT_VARIABLE glslc_output - ERROR_VARIABLE glslc_error) + test_shader_extension_support( + "GL_EXT_integer_dot_product" + "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_integer_dot_support.comp" + "GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT" + ) - if (${glslc_error} MATCHES ".*extension not supported: GL_EXT_bfloat16.*") - message(STATUS "GL_EXT_bfloat16 not supported by glslc") - set(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT OFF) - else() - message(STATUS "GL_EXT_bfloat16 supported by glslc") - set(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT ON) - add_compile_definitions(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) - endif() + test_shader_extension_support( + "GL_EXT_bfloat16" + "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_bfloat16_support.comp" + "GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT" + ) target_link_libraries(ggml-vulkan PRIVATE Vulkan::Vulkan) target_include_directories(ggml-vulkan PRIVATE ${CMAKE_CURRENT_BINARY_DIR}) @@ -124,16 +116,8 @@ if (Vulkan_FOUND) add_compile_definitions(GGML_VULKAN_RUN_TESTS) endif() - if (NOT CMAKE_CROSSCOMPILING) - add_subdirectory(vulkan-shaders) - if (MSVC) - foreach(CONFIG ${CMAKE_CONFIGURATION_TYPES}) - string(TOUPPER ${CONFIG} CONFIG) - set_target_properties(vulkan-shaders-gen PROPERTIES - RUNTIME_OUTPUT_DIRECTORY_${CONFIG} ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}) - endforeach() - endif() - else() + # Set up toolchain for host compilation whether cross-compiling or not + if (CMAKE_CROSSCOMPILING) if (GGML_VULKAN_SHADERS_GEN_TOOLCHAIN) set(HOST_CMAKE_TOOLCHAIN_FILE ${GGML_VULKAN_SHADERS_GEN_TOOLCHAIN}) else() @@ -146,25 +130,31 @@ if (Vulkan_FOUND) configure_file(${CMAKE_CURRENT_SOURCE_DIR}/cmake/host-toolchain.cmake.in ${CMAKE_BINARY_DIR}/host-toolchain.cmake @ONLY) set(HOST_CMAKE_TOOLCHAIN_FILE ${CMAKE_BINARY_DIR}/host-toolchain.cmake) endif() - message(STATUS "vulkan-shaders-gen toolchain file: ${HOST_CMAKE_TOOLCHAIN_FILE}") + else() + # For non-cross-compiling, use empty toolchain (use host compiler) + set(HOST_CMAKE_TOOLCHAIN_FILE "") + endif() + + # Always use ExternalProject_Add approach + include(ExternalProject) - include(ExternalProject) - # Native build through ExternalProject_Add - ExternalProject_Add( - vulkan-shaders-gen - SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders - CMAKE_ARGS -DCMAKE_TOOLCHAIN_FILE=${HOST_CMAKE_TOOLCHAIN_FILE} - -DCMAKE_INSTALL_PREFIX=${CMAKE_BINARY_DIR} - -DGGML_VULKAN_COOPMAT_GLSLC_SUPPORT=${GGML_VULKAN_COOPMAT_GLSLC_SUPPORT} - -DGGML_VULKAN_COOPMAT2_GLSLC_SUPPORT=${GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT} - -DGGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT=${GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT} - -DGGML_VULKAN_BFLOAT16_GLSLC_SUPPORT=${GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT} - BUILD_COMMAND ${CMAKE_COMMAND} --build . - INSTALL_COMMAND ${CMAKE_COMMAND} --install . 
- INSTALL_DIR ${CMAKE_BINARY_DIR} - ) - ExternalProject_Add_StepTargets(vulkan-shaders-gen build install) + # Add toolchain file if cross-compiling + if (CMAKE_CROSSCOMPILING) + list(APPEND VULKAN_SHADER_GEN_CMAKE_ARGS -DCMAKE_TOOLCHAIN_FILE=${HOST_CMAKE_TOOLCHAIN_FILE}) + message(STATUS "vulkan-shaders-gen toolchain file: ${HOST_CMAKE_TOOLCHAIN_FILE}") endif() + + # Native build through ExternalProject_Add + ExternalProject_Add( + vulkan-shaders-gen + SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders + CMAKE_ARGS ${VULKAN_SHADER_GEN_CMAKE_ARGS} + BUILD_COMMAND ${CMAKE_COMMAND} --build . + INSTALL_COMMAND ${CMAKE_COMMAND} --install . + INSTALL_DIR ${CMAKE_BINARY_DIR} + ) + ExternalProject_Add_StepTargets(vulkan-shaders-gen build install) + set (_ggml_vk_host_suffix $,.exe,>) set (_ggml_vk_genshaders_cmd ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/vulkan-shaders-gen${_ggml_vk_host_suffix}) set (_ggml_vk_header ${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.hpp) @@ -175,9 +165,8 @@ if (Vulkan_FOUND) file(GLOB _ggml_vk_shader_deps "${_ggml_vk_input_dir}/*.comp") set (_ggml_vk_shader_deps ${_ggml_vk_shader_deps} vulkan-shaders-gen) - if (CMAKE_CROSSCOMPILING) - set(_ggml_vk_shader_deps ${_ggml_vk_shader_deps} vulkan-shaders-gen-build vulkan-shaders-gen-install) - endif() + # Add build and install dependencies for all builds + set(_ggml_vk_shader_deps ${_ggml_vk_shader_deps} vulkan-shaders-gen-build vulkan-shaders-gen-install) add_custom_command( OUTPUT ${_ggml_vk_header} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt b/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt index ad13f69b3fb..e60e9d1e5b5 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +++ b/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt @@ -5,18 +5,35 @@ find_package (Threads REQUIRED) if (GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) add_compile_definitions(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + message(STATUS "Enabling coopmat glslc support") endif() if (GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) add_compile_definitions(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + message(STATUS "Enabling coopmat2 glslc support") endif() if (GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) add_compile_definitions(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + message(STATUS "Enabling dot glslc support") endif() if (GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) add_compile_definitions(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) + message(STATUS "Enabling bfloat16 glslc support") endif() + set(TARGET vulkan-shaders-gen) add_executable(${TARGET} vulkan-shaders-gen.cpp) install(TARGETS ${TARGET} RUNTIME) target_compile_features(${TARGET} PRIVATE cxx_std_17) target_link_libraries(vulkan-shaders-gen PUBLIC Threads::Threads) + +# Configure output directories for MSVC builds +if(MSVC) + # Get the main project's runtime output directory if possible + if(DEFINED CMAKE_RUNTIME_OUTPUT_DIRECTORY) + foreach(CONFIG ${CMAKE_CONFIGURATION_TYPES}) + string(TOUPPER ${CONFIG} CONFIG) + set_target_properties(${TARGET} PROPERTIES + RUNTIME_OUTPUT_DIRECTORY_${CONFIG} ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}) + endforeach() + endif() +endif() From cc84fce61860a51b914a2229d2ae9d077e88a91b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Wed, 14 May 2025 16:08:20 +0200 Subject: [PATCH 10/24] CUDA: faster Deepseek FA, add Turing support (llama/13435) --- ggml/src/ggml-cuda/fattn-common.cuh | 18 +- ggml/src/ggml-cuda/fattn-mma-f16.cuh | 323 +++++++++++++++++++++------ ggml/src/ggml-cuda/fattn.cu | 3 +- ggml/src/ggml-cuda/ggml-cuda.cu | 2 +- 4 files changed, 276 insertions(+), 70 
deletions(-) diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index b7180d5955c..a4fbd823638 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -678,10 +678,14 @@ void launch_fattn( ) { constexpr int ncols = ncols1 * ncols2; + const bool is_mla = DV == 512; // TODO better parameterization + const ggml_tensor * Q = dst->src[0]; const ggml_tensor * K = dst->src[1]; const ggml_tensor * V = dst->src[2]; + GGML_ASSERT(V || is_mla); + const ggml_tensor * mask = dst->src[3]; ggml_tensor * KQV = dst; @@ -689,6 +693,10 @@ void launch_fattn( GGML_ASSERT(Q->type == GGML_TYPE_F32); GGML_ASSERT(KQV->type == GGML_TYPE_F32); + GGML_ASSERT( Q->nb[0] == ggml_element_size(Q)); + GGML_ASSERT( K->nb[0] == ggml_element_size(K)); + GGML_ASSERT(!V || V->nb[0] == ggml_element_size(V)); + GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16); GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) && "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big"); @@ -713,10 +721,10 @@ void launch_fattn( size_t nb12 = K->nb[2]; size_t nb13 = K->nb[3]; - const char * V_data = (const char *) V->data; - size_t nb21 = V->nb[1]; - size_t nb22 = V->nb[2]; - size_t nb23 = V->nb[3]; + const char * V_data = V ? (const char *) V->data : nullptr; + size_t nb21 = V ? V->nb[1] : nb11; + size_t nb22 = V ? V->nb[2] : nb12; + size_t nb23 = V ? V->nb[3] : nb13; if (need_f16_K && K->type != GGML_TYPE_F16) { GGML_ASSERT(ggml_is_contiguously_allocated(K)); @@ -733,7 +741,7 @@ void launch_fattn( nb13 = nb13*bs*sizeof(half)/ts; } - if (need_f16_V && V->type != GGML_TYPE_F16) { + if (V && need_f16_V && V->type != GGML_TYPE_F16) { GGML_ASSERT(ggml_is_contiguously_allocated(V)); V_f16.alloc(ggml_nelements(V)); to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type); diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index 491780abd40..be0329d0e0c 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -33,9 +33,30 @@ struct fattn_mma_f16_config< 64, 64> { static constexpr int nwarps_max = 4; static constexpr bool Q_in_reg = true; static constexpr int nstages_target = 2; - static constexpr int nbatch_K2 = 32; - static constexpr int nbatch_V2 = 32; - static constexpr int nbatch_combine = 32; + + static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { + return 32; + } + + static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) { + return 32; + } + + static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) { + return 32; + } + + static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) { + return 32; + } + + static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) { + return 32; + } + + static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) { + return 32; + } }; template <> @@ -44,9 +65,30 @@ struct fattn_mma_f16_config< 80, 80> { static constexpr int nwarps_max = 4; static constexpr bool Q_in_reg = true; static constexpr int nstages_target = 2; - static constexpr int nbatch_K2 = 40; - static constexpr int nbatch_V2 = 40; - static constexpr int nbatch_combine = 40; + + static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { + return 40; + } + + static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) { + return 40; + } + + static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) { + return 40; + } + + static constexpr __device__ int 
get_nbatch_V2_device(int /*ncols*/) { + return 40; + } + + static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) { + return 40; + } + + static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) { + return 40; + } }; template <> @@ -55,9 +97,30 @@ struct fattn_mma_f16_config< 96, 96> { static constexpr int nwarps_max = 4; static constexpr bool Q_in_reg = true; static constexpr int nstages_target = 2; - static constexpr int nbatch_K2 = 48; - static constexpr int nbatch_V2 = 48; - static constexpr int nbatch_combine = 48; + + static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { + return 48; + } + + static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) { + return 48; + } + + static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) { + return 48; + } + + static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) { + return 48; + } + + static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) { + return 48; + } + + static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) { + return 48; + } }; template <> @@ -66,9 +129,30 @@ struct fattn_mma_f16_config<112, 112> { static constexpr int nwarps_max = 4; static constexpr bool Q_in_reg = true; static constexpr int nstages_target = 2; - static constexpr int nbatch_K2 = 56; - static constexpr int nbatch_V2 = 56; - static constexpr int nbatch_combine = 56; + + static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { + return 56; + } + + static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) { + return 56; + } + + static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) { + return 56; + } + + static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) { + return 56; + } + + static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) { + return 56; + } + + static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) { + return 56; + } }; template <> @@ -77,9 +161,30 @@ struct fattn_mma_f16_config<128, 128> { static constexpr int nwarps_max = 4; static constexpr bool Q_in_reg = true; static constexpr int nstages_target = 2; - static constexpr int nbatch_K2 = 64; - static constexpr int nbatch_V2 = 64; - static constexpr int nbatch_combine = 64; + + static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { + return 64; + } + + static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) { + return 64; + } + + static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) { + return 64; + } + + static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) { + return 64; + } + + static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) { + return 64; + } + + static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) { + return 64; + } }; template <> @@ -88,9 +193,38 @@ struct fattn_mma_f16_config<256, 256> { static constexpr int nwarps_max = 4; static constexpr bool Q_in_reg = true; static constexpr int nstages_target = 2; - static constexpr int nbatch_K2 = 128; - static constexpr int nbatch_V2 = 128; - static constexpr int nbatch_combine = 128; + + static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { + return 128; + } + + static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) { + return 128; + } + + static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) { + return 128; + } + + static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) { + return 128; + } + 
+ static int get_nbatch_combine_host(const int cc, const int ncols) { + if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) { + return ncols <= 16 ? 128 : 64; + } + return 64; + } + + static constexpr __device__ int get_nbatch_combine_device(int ncols) { +#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING + return ncols <= 16 ? 128 : 64; +#else + GGML_UNUSED(ncols); + return 128; +#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING + } }; template <> @@ -99,9 +233,44 @@ struct fattn_mma_f16_config<576, 512> { static constexpr int nwarps_max = 8; static constexpr bool Q_in_reg = false; static constexpr int nstages_target = 1; - static constexpr int nbatch_K2 = 160; - static constexpr int nbatch_V2 = 128; - static constexpr int nbatch_combine = 128; + + static int get_nbatch_K2_host(const int cc, const int ncols) { + if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) { + return ncols <= 16 ? 96 : 160; + } + return ncols <= 16 ? 288 : 160; + } + + static constexpr __device__ int get_nbatch_K2_device(int ncols) { +#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING + return ncols <= 16 ? 96 : 160; +#else + return ncols <= 16 ? 288 : 160; +#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING + } + + static int get_nbatch_V2_host(const int cc, const int ncols) { + if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) { + return ncols <= 16 ? 64 : 128; + } + return ncols <= 16 ? 256 : 128; + } + + static constexpr __device__ int get_nbatch_V2_device(int ncols) { +#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING + return ncols <= 16 ? 64 : 128; +#else + return ncols <= 16 ? 256 : 128; +#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING + } + + static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) { + return 128; + } + + static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) { + return 128; + } }; // ------------------------------------------------------------------------------------------------------------------ @@ -120,7 +289,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( const unsigned int tile_KV_32 = ggml_cuda_cvta_generic_to_shared(tile_KV); - auto load = [&] __device__ (const int n) { + auto load = [&] __device__ (auto n) { const int stride_k = WARP_SIZE >> n; const int k0_start = stride_k == WARP_SIZE ? 0 : chunks_per_row - chunks_per_row % (2*stride_k); const int k0_stop = chunks_per_row - chunks_per_row % (1*stride_k); @@ -223,7 +392,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask( } } -template +template static __device__ __forceinline__ void flash_attn_ext_f16_iter( const float2 * const __restrict__ Q_f2, const half2 * const __restrict__ K_h2, @@ -261,10 +430,15 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( constexpr int cols_per_warp = ntiles * tile_B::I; constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles; constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column. + constexpr int ncols = ncols1 * ncols2; + constexpr int nbatch_K2 = c::get_nbatch_K2_device(ncols); + constexpr int nbatch_V2 = c::get_nbatch_V2_device(ncols); - constexpr int stride_tile_Q = DKQ/2 + 4; - constexpr int stride_tile_K = c::nbatch_K2 + 4; - constexpr int stride_tile_V = c::nbatch_V2 + 4; + constexpr int stride_tile_Q = DKQ/2 + 4; + constexpr int stride_tile_K = nbatch_K2 + 4; + + static_assert(!mla || nbatch_K2 >= nbatch_V2, "bad nbatch_K2, nbatch_V2 for MLA"); + constexpr int stride_tile_V = mla ? 
stride_tile_K : nbatch_V2 + 4; const int k_VKQ_0 = kb0 * c::nbatch_fa; tile_C_KQ KQ_C[c::nbatch_fa/(np*tile_C_KQ::I) * ntiles]; @@ -275,12 +449,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( tile_C_KQ_16 * KQ_C_16 = (tile_C_KQ_16 *) KQ_C; if constexpr (nstages > 1) { - static_assert(c::nbatch_K2 == DKQ/2, "batching not implemented for multi stage loading"); + static_assert(!mla, "multi-stage loading not implemented for MLA"); + static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi stage loading"); constexpr bool use_cp_async = true; cp_async_wait_all(); __syncthreads(); flash_attn_ext_f16_load_tile - (V_h2 + k_VKQ_0*stride_V, tile_V, c::nbatch_V2, stride_V); + (V_h2 + k_VKQ_0*stride_V, tile_V, nbatch_V2, stride_V); } else { constexpr bool use_cp_async = nstages == 1; if (ncols2 > 1 || mask_h2) { @@ -289,8 +464,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } #pragma unroll - for (int k0_start = 0; k0_start < DKQ/2; k0_start += c::nbatch_K2) { - const int k0_stop = k0_start + c::nbatch_K2 < DKQ/2 ? k0_start + c::nbatch_K2 : DKQ/2; + for (int k0_start = 0; k0_start < DKQ/2; k0_start += nbatch_K2) { + const int k0_stop = k0_start + nbatch_K2 < DKQ/2 ? k0_start + nbatch_K2 : DKQ/2; const int k0_diff = k0_stop - k0_start; if (nstages <= 1) { @@ -537,16 +712,21 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( (mask_h2 + (k_VKQ_0 + c::nbatch_fa)/2, tile_mask, stride_mask); } flash_attn_ext_f16_load_tile - (K_h2 + (k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, c::nbatch_K2, stride_K); + (K_h2 + (k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K); } } + + // For MLA K and V have the same data. + // Therefore, iterate over V in reverse and re-use the data if possible. + static_assert(!mla || nstages <= 1, "combination of MLA and multi-stage loading not implemented"); + constexpr int reusable_cutoff = mla ? (DKQ - 1) - (DKQ - 1) % (2*nbatch_K2) - (DKQ - DV) : DV; #pragma unroll - for (int i0_start = 0; i0_start < DV; i0_start += 2*c::nbatch_V2) { - const int i0_stop = i0_start + 2*c::nbatch_V2 < DV ? i0_start + 2*c::nbatch_V2 : DV; - const int i0_diff = i0_stop - i0_start; + for (int i0_stop = DV; i0_stop > 0; i0_stop -= 2*nbatch_V2) { + const int i0_start = i0_stop - 2*nbatch_V2 > 0 ? i0_stop - 2*nbatch_V2 : 0; + const int i0_diff = i0_stop - i0_start; - if (nstages <= 1) { + if (nstages <= 1 && i0_start < reusable_cutoff) { constexpr bool use_cp_async = nstages == 1; flash_attn_ext_f16_load_tile (V_h2 + k_VKQ_0*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V); @@ -555,6 +735,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } __syncthreads(); } + const half2 * tile_V_i = i0_start < reusable_cutoff ? 
tile_V : tile_V + (i0_start - reusable_cutoff)/2; // Calculate VKQ tile: #pragma unroll @@ -565,7 +746,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( const int k0 = k00 + (threadIdx.y % np)*tile_A::J; tile_A A; - load_ldmatrix_trans(A, tile_V + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V); + load_ldmatrix_trans(A, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V); if (ntiles == 1) { mma(VKQ_C[i_VKQ_0/tile_C_VKQ::I], A, B[k00/(np*tile_A::J)]); } else { @@ -596,7 +777,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( #endif // NEW_MMA_AVAILABLE } -template +template static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const float2 * const __restrict__ Q_f2, const half2 * const __restrict__ K_h2, @@ -632,13 +813,16 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( constexpr int cols_per_warp = ntiles * tile_B::I; constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles; constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column. + constexpr int nbatch_K2 = c::get_nbatch_K2_device(ncols); + constexpr int nbatch_V2 = c::get_nbatch_V2_device(ncols); static_assert(nwarps * (cols_per_warp/ncols2) % ncols1 == 0, "bad nwarps"); - constexpr int stride_tile_Q = DKQ/2 + 4; - constexpr int stride_tile_K = c::nbatch_K2 + 4; - constexpr int stride_tile_V = c::nbatch_V2 + 4; + constexpr int stride_tile_Q = DKQ/2 + 4; + constexpr int stride_tile_K = nbatch_K2 + 4; + static_assert(!mla || nbatch_K2 >= nbatch_V2, "bad nbatch_K2, nbatch_V2 for MLA"); + constexpr int stride_tile_V = mla ? stride_tile_K : nbatch_V2 + 4; constexpr int stride_tile_KV_max = stride_tile_K > stride_tile_V ? stride_tile_K : stride_tile_V; extern __shared__ half2 tile_Q[]; @@ -726,26 +910,26 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( // Preload mask and K data for first iteration when using cp_async with multiple stages: if constexpr (nstages > 1) { - static_assert(c::nbatch_K2 == DKQ/2, "batching not implemented for multi-stage pipeline"); + static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi-stage pipeline"); constexpr bool use_cp_async = true; if (ncols2 > 1 || mask_h2) { flash_attn_ext_f16_load_mask (mask_h2 + kb0_start*c::nbatch_fa/2, tile_mask, stride_mask); } flash_attn_ext_f16_load_tile - (K_h2 + kb0_start*c::nbatch_fa*stride_K, tile_K, c::nbatch_K2, stride_K); + (K_h2 + kb0_start*c::nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K); } // Iterate over ne11 == previous tokens: for (int kb0 = kb0_start; kb0 < kb0_stop-1; ++kb0) { constexpr bool last_iter = false; - flash_attn_ext_f16_iter + flash_attn_ext_f16_iter (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap, ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0); } { // kb0_start is always < kb0_stop so the last iter can be executed unconditionally. constexpr bool last_iter = true; - flash_attn_ext_f16_iter + flash_attn_ext_f16_iter (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap, ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1); } @@ -774,7 +958,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( // It's also faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM. 
// So also write VKQ accumulators to shared memory in column-major format if np == 1. - constexpr int nbatch_combine = c::Q_in_reg ? DV/2 : DV/4; + constexpr int nbatch_combine = c::get_nbatch_combine_device(ncols); constexpr int tile_stride = nbatch_combine + 4; static_assert((DV/2) % nbatch_combine == 0, "bad nbatch_combine"); @@ -1012,7 +1196,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( #endif // NEW_MMA_AVAILABLE } -template +template __launch_bounds__(nwarps*WARP_SIZE, 1) static __global__ void flash_attn_ext_f16( const char * __restrict__ Q, @@ -1057,6 +1241,14 @@ static __global__ void flash_attn_ext_f16( NO_DEVICE_CODE; return; } +#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING + if (ncols1*ncols2 > 32) { + NO_DEVICE_CODE; + return; + } +#endif __CUDA_ARCH__ == GGML_CUDA_CC_TURING + + static_assert(!mla || DKQ >= DV, "MLA needs DKQ >= DV"); typedef fattn_mma_f16_config c; @@ -1067,9 +1259,10 @@ static __global__ void flash_attn_ext_f16( const int stride_Q1 = nb01 / sizeof(float2); const int stride_Q2 = nb02 / sizeof(float2); const int stride_K = nb11 / sizeof(half2); - const int stride_V = nb21 / sizeof(half2); const int stride_mask = nb31 / sizeof(half2); + const int stride_V = mla ? stride_K : nb21 / sizeof(half2); + const int iter_k = ne11 / FATTN_KQ_STRIDE; const int iter_j = (ne01 + (ncols1 - 1)) / ncols1; @@ -1092,10 +1285,11 @@ static __global__ void flash_attn_ext_f16( const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2); const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio)); - const half2 * V_h2 = (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio)); const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr; float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2); + const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio)); + const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f; const int kb0_start_kernel = kb0_start * kb_niter; @@ -1104,12 +1298,12 @@ static __global__ void flash_attn_ext_f16( constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer. if (kb0_start == 0) { constexpr bool needs_fixup = false; // CUDA block is working on an entire tile. - flash_attn_ext_f16_process_tile + flash_attn_ext_f16_process_tile (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap, ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); } else { constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile. - flash_attn_ext_f16_process_tile + flash_attn_ext_f16_process_tile (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap, ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); } @@ -1130,10 +1324,11 @@ static __global__ void flash_attn_ext_f16( const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2); const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio)); - const half2 * V_h2 = (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio)); // K and V have same shape const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr; float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2); + const half2 * V_h2 = mla ? 
K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio)); + const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f; const int kb0_start_kernel = kb0_start * kb_niter; @@ -1141,7 +1336,7 @@ static __global__ void flash_attn_ext_f16( constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks. constexpr bool needs_fixup = false; - flash_attn_ext_f16_process_tile + flash_attn_ext_f16_process_tile (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap, ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); #else @@ -1167,10 +1362,6 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml typedef fattn_mma_f16_config c; - constexpr int nbatch_K2 = c::nbatch_K2 < 1 ? DKQ/2 : c::nbatch_K2; - constexpr int nbatch_V2 = c::nbatch_V2 < 1 ? DV /2 : c::nbatch_V2; - constexpr int nbatch_combine = c::nbatch_combine < 1 ? DV /2 : c::nbatch_combine; - const int nstages = cp_async_available(cc) ? c::nstages_target : 0; constexpr int ncols = ncols1 * ncols2; @@ -1180,15 +1371,21 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml constexpr int nwarps_max_y = c::nbatch_fa / tile_A::I; constexpr int nwarps = nwarps_max_x*nwarps_max_y <= c::nwarps_max ? nwarps_max_x*nwarps_max_y : c::nwarps_max; + constexpr bool mla = DKQ == 576; + + const int nbatch_K2 = c::get_nbatch_K2_host (cc, ncols); + const int nbatch_V2 = c::get_nbatch_K2_host (cc, ncols); + const int nbatch_combine = c::get_nbatch_combine_host(cc, ncols); + static_assert(DKQ % tile_B::J == 0, "bad DKQ"); static_assert(DV % tile_A::J == 0, "bad DV"); static_assert(ncols % cols_per_warp == 0, "bad ncols"); - const size_t nbytes_shared_KV_1stage = c::nbatch_fa * std::max(c::nbatch_K2 + 4, c::nbatch_V2 + 4) * sizeof(half2); - const size_t nbytes_shared_KV_2stage = c::nbatch_fa * (c::nbatch_K2 + 4 + c::nbatch_V2 + 4) * sizeof(half2); - const size_t nbytes_shared_Q = ncols * (DKQ/2 + 4) * sizeof(half2); - const size_t nbytes_shared_mask = ncols1 * (c::nbatch_fa/2 + 4) * sizeof(half2); - const size_t nbytes_shared_combine = nwarps*cols_per_warp * (nbatch_combine + 4) * sizeof(half2); + const size_t nbytes_shared_KV_1stage = c::nbatch_fa * std::max(nbatch_K2 + 4, nbatch_V2 + 4) * sizeof(half2); + const size_t nbytes_shared_KV_2stage = c::nbatch_fa * (nbatch_K2 + 4 + nbatch_V2 + 4) * sizeof(half2); + const size_t nbytes_shared_Q = ncols * (DKQ/2 + 4) * sizeof(half2); + const size_t nbytes_shared_mask = ncols1 * (c::nbatch_fa/2 + 4) * sizeof(half2); + const size_t nbytes_shared_combine = nwarps*cols_per_warp * (nbatch_combine + 4) * sizeof(half2); const size_t nbytes_shared_KV = nstages <= 1 ? 
nbytes_shared_KV_1stage : nbytes_shared_KV_2stage; @@ -1202,7 +1399,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml fattn_kernel_t fattn_kernel; if (logit_softcap == 0.0f) { constexpr bool use_logit_softcap = false; - fattn_kernel = flash_attn_ext_f16; + fattn_kernel = flash_attn_ext_f16; #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; @@ -1213,7 +1410,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) } else { constexpr bool use_logit_softcap = true; - fattn_kernel = flash_attn_ext_f16; + fattn_kernel = flash_attn_ext_f16; #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 9c5c803d02b..6bc0096cc65 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -10,6 +10,7 @@ template static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; const ggml_tensor * Q = dst->src[0]; if constexpr (ncols2 <= 8) { @@ -24,7 +25,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_con return; } - if (Q->ne[1] <= 32/ncols2) { + if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING || Q->ne[1] <= 32/ncols2) { ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); return; } diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index b4b85abcda9..02dc8c12dbd 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3222,7 +3222,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g #endif // FLASH_ATTN_AVAILABLE if (op->src[1]->ne[0] != op->src[2]->ne[0]) { const int cc = ggml_cuda_info().devices[dev_ctx->device].cc; - if (!new_mma_available(cc) || cc < GGML_CUDA_CC_AMPERE) { + if (!new_mma_available(cc)) { return false; } const int gqa_ratio = op->src[0]->ne[2] / op->src[1]->ne[2]; From bd65c3bac989a6acd0623e0d3edbdd2e091df303 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Wed, 14 May 2025 16:41:02 +0200 Subject: [PATCH 11/24] CUDA: fix crash on large batch size for quant. 
MoE (llama/13537) --- ggml/src/ggml-cuda/mmq.cu | 2 ++ ggml/src/ggml-cuda/quantize.cu | 13 +++++++------ 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu index e1cf843de1a..2db5b4ab0f0 100644 --- a/ggml/src/ggml-cuda/mmq.cu +++ b/ggml/src/ggml-cuda/mmq.cu @@ -122,6 +122,7 @@ void ggml_cuda_mul_mat_q( const int64_t s13 = src1->nb[3] / ts_src1; quantize_mmq_q8_1_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded, ne11, ne12, ne13, stream); + CUDA_CHECK(cudaGetLastError()); } const int64_t s12 = ne11*ne10_padded * sizeof(block_q8_1)/(QK8_1*sizeof(int)); @@ -205,6 +206,7 @@ void ggml_cuda_mul_mat_q( const int64_t s13 = src1->nb[2] / ts_src1; quantize_mmq_q8_1_cuda(src1_d, ids_src1_dev, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream); + CUDA_CHECK(cudaGetLastError()); } const int64_t s12 = ne11*ne10_padded * sizeof(block_q8_1)/(QK8_1*sizeof(int)); diff --git a/ggml/src/ggml-cuda/quantize.cu b/ggml/src/ggml-cuda/quantize.cu index cb93181455d..a0b03a740d7 100644 --- a/ggml/src/ggml-cuda/quantize.cu +++ b/ggml/src/ggml-cuda/quantize.cu @@ -56,13 +56,13 @@ static __global__ void quantize_mmq_q8_1( constexpr int vals_per_scale = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 64 : 32; constexpr int vals_per_sum = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 16 : 32; - const int64_t i0 = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*4; + const int64_t i0 = ((int64_t)blockDim.x*blockIdx.y + threadIdx.x)*4; if (i0 >= ne0) { return; } - const int64_t i1 = blockIdx.y; + const int64_t i1 = blockIdx.x; const int64_t i2 = blockIdx.z % ne2; const int64_t i3 = blockIdx.z / ne2; @@ -75,8 +75,8 @@ static __global__ void quantize_mmq_q8_1( block_q8_1_mmq * y = (block_q8_1_mmq *) vy; - const int64_t ib0 = blockIdx.z*((int64_t)gridDim.y*gridDim.x*blockDim.x/QK8_1); // first block of channel - const int64_t ib = ib0 + (i0 / (4*QK8_1))*ne1 + blockIdx.y; // block index in channel + const int64_t ib0 = blockIdx.z*((int64_t)gridDim.x*gridDim.y*blockDim.x/QK8_1); // first block of channel + const int64_t ib = ib0 + (i0 / (4*QK8_1))*ne1 + blockIdx.x; // block index in channel const int64_t iqs = i0 % (4*QK8_1); // quant index in block // Load 4 floats per thread and calculate max. abs. value between them: @@ -166,8 +166,9 @@ void quantize_mmq_q8_1_cuda( GGML_ASSERT(ne00 % 4 == 0); GGML_ASSERT(ne0 % (4*QK8_1) == 0); - const int64_t block_num_x = (ne0 + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ); - const dim3 num_blocks(block_num_x, ne1, ne2*ne3); + // ne1 tends to assume the highest values, therefore use it as the "x" dimension of the CUDA grid: + const int64_t block_num_y = (ne0 + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ); + const dim3 num_blocks(ne1, block_num_y, ne2*ne3); const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE_MMQ, 1, 1); switch (mmq_get_q8_1_ds_layout(type_src0)) { case MMQ_Q8_1_DS_LAYOUT_D4: From cd5adba05bbfeb24849f763e9fa7f392f841d016 Mon Sep 17 00:00:00 2001 From: Yibo Cai Date: Thu, 15 May 2025 03:53:52 +0800 Subject: [PATCH 12/24] arm64: optimize q6_k_q8_k kernel with i8mm (llama/13519) This PR improves q6_k_q8_k gemm kernel with arm64 i8mm instruction. Tested on neoverse-n2 with llama3 8b q6_k quantization model. - 40% ~ 54% S_PP uplift for all batch sizes - 16% ~ 47% S_TG uplift for batch size 4 and above Perplexity doesn't change with this PR. 
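The core of the optimization is the i8mm SMMLA instruction, exposed as the `vmmlaq_s32` intrinsic: each call accumulates a 2x2 tile of int32 dot products from two pairs of int8 rows, which is why the kernel interleaves rows of x0/x1 and y0/y1 with vzip1q_s64/vzip2q_s64 before the multiply. A minimal standalone sketch of the intrinsic's semantics (illustrative only, not part of the patch; the file name and build flags are just examples):

```c
// i8mm_demo.c - build with e.g.: gcc -O2 -march=armv8.6-a+i8mm i8mm_demo.c
#include <arm_neon.h>
#include <stdio.h>

int main(void) {
    // Two rows of A and two rows of B, 8 int8 values each.
    int8_t a[16] = {1,2,3,4,5,6,7,8,  1,1,1,1,1,1,1,1};
    int8_t b[16] = {1,0,1,0,1,0,1,0,  2,2,2,2,2,2,2,2};

    int8x16_t va = vld1q_s8(a);
    int8x16_t vb = vld1q_s8(b);

    // acc holds a 2x2 int32 tile, row-major: acc[i*2 + j] += dot(A row i, B row j)
    int32x4_t acc = vdupq_n_s32(0);
    acc = vmmlaq_s32(acc, va, vb);

    int32_t c[4];
    vst1q_s32(c, acc);
    printf("C = [[%d, %d], [%d, %d]]\n", c[0], c[1], c[2], c[3]); // prints C = [[16, 72], [4, 16]]
    return 0;
}
```

In the kernel, this 2x2 tile covers the (x0,x1) x (y0,y1) row pairs, which is why two result rows are written out to s and s + bs at the end.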
``` // tested on neoverse-n2 $ llama-batched-bench \ -m Meta-Llama-3-8B-Instruct-Q6_K.gguf \ --no-mmap -fa \ -c 8192 -b 4096 -ub 512 -npp 128 -ntg 128 \ -npl 1,2,4,8,16,32 \ -t 64 --------------------------------------------------------------------- | PP | TG | B | S_PP t/s | S_TG t/s | | | | | original | this pr | original | this pr | |-------|--------|------|----------|----------|----------|----------| | 128 | 128 | 1 | 78.52 | 109.18 | 18.63 | 18.88 | | 128 | 128 | 2 | 84.62 | 123.94 | 34.54 | 36.92 | | 128 | 128 | 4 | 84.36 | 122.49 | 52.65 | 61.32 | | 128 | 128 | 8 | 90.52 | 138.87 | 63.46 | 84.41 | | 128 | 128 | 16 | 90.11 | 138.56 | 71.04 | 101.33 | | 128 | 128 | 32 | 89.81 | 137.79 | 75.14 | 110.47 | --------------------------------------------------------------------- ``` --- ggml/src/ggml-cpu/ggml-cpu-quants.c | 195 ++++++++++++++++++++++++++++ ggml/src/ggml-cpu/ggml-cpu.c | 4 + 2 files changed, 199 insertions(+) diff --git a/ggml/src/ggml-cpu/ggml-cpu-quants.c b/ggml/src/ggml-cpu/ggml-cpu-quants.c index ccd0651ebc7..a89ce9bb1e9 100644 --- a/ggml/src/ggml-cpu/ggml-cpu-quants.c +++ b/ggml/src/ggml-cpu/ggml-cpu-quants.c @@ -8519,7 +8519,11 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); +#ifdef __ARM_FEATURE_MATMUL_INT8 + assert((nrc == 2) || (nrc == 1)); +#else assert(nrc == 1); +#endif UNUSED(nrc); UNUSED(bx); UNUSED(by); @@ -8530,6 +8534,197 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi const int nb = n / QK_K; +#if defined(__ARM_FEATURE_MATMUL_INT8) + if (nrc == 2) { + const block_q6_K * GGML_RESTRICT x0 = x; + const block_q6_K * GGML_RESTRICT x1 = (const block_q6_K *) ((const uint8_t *)vx + bx); + const block_q8_K * GGML_RESTRICT y0 = y; + const block_q8_K * GGML_RESTRICT y1 = (const block_q8_K *) ((const uint8_t *)vy + by); + + float32x4_t vfsum = vdupq_n_f32(0.0f); + + for (int i = 0; i < nb; ++i, ++x0, ++x1, ++y0, ++y1) { + const uint8_t * GGML_RESTRICT ql0 = x0->ql; + const uint8_t * GGML_RESTRICT ql1 = x1->ql; + const uint8_t * GGML_RESTRICT qh0 = x0->qh; + const uint8_t * GGML_RESTRICT qh1 = x1->qh; + const int8_t * GGML_RESTRICT qy0 = y0->qs; + const int8_t * GGML_RESTRICT qy1 = y1->qs; + + const uint8x16_t mone = vdupq_n_u8(0x30); + const uint8x16_t m4b = vdupq_n_u8(0x0f); + + int32x4_t visum = vdupq_n_s32(0); + + // process 8 blocks per iteration, totally 16 blocks + for (int j = 0; j < 2; ++j, qh0 += 32, ql0 += 64, qh1 += 32, ql1 += 64) { + int8x16_t vx0[8], vx1[8]; + + // de-quantize vx0[8] + { + const uint8x16x2_t qh_bits = vld1q_u8_x2(qh0); + const uint8x16x4_t ql_bits = vld1q_u8_x4(ql0); + + uint8x16_t q6h_0 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 4)); + uint8x16_t q6h_1 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 4)); + uint8x16_t q6h_2 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 2)); + uint8x16_t q6h_3 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 2)); + + vx0[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[0], m4b), q6h_0)); + vx0[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[1], m4b), q6h_1)); + vx0[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[2], m4b), q6h_2)); + vx0[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[3], m4b), q6h_3)); + + q6h_0 = vandq_u8(mone, qh_bits.val[0]); + q6h_1 = vandq_u8(mone, qh_bits.val[1]); + q6h_2 = vandq_u8(mone, 
vshrq_n_u8(qh_bits.val[0], 2)); + q6h_3 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[1], 2)); + + vx0[4] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[0], 4), q6h_0)); + vx0[5] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[1], 4), q6h_1)); + vx0[6] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[2], 4), q6h_2)); + vx0[7] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[3], 4), q6h_3)); + } + + // de-quantize vx1[8] + { + const uint8x16x2_t qh_bits = vld1q_u8_x2(qh1); + const uint8x16x4_t ql_bits = vld1q_u8_x4(ql1); + + uint8x16_t q6h_0 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 4)); + uint8x16_t q6h_1 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 4)); + uint8x16_t q6h_2 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 2)); + uint8x16_t q6h_3 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 2)); + + vx1[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[0], m4b), q6h_0)); + vx1[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[1], m4b), q6h_1)); + vx1[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[2], m4b), q6h_2)); + vx1[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[3], m4b), q6h_3)); + + q6h_0 = vandq_u8(mone, qh_bits.val[0]); + q6h_1 = vandq_u8(mone, qh_bits.val[1]); + q6h_2 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[0], 2)); + q6h_3 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[1], 2)); + + vx1[4] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[0], 4), q6h_0)); + vx1[5] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[1], 4), q6h_1)); + vx1[6] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[2], 4), q6h_2)); + vx1[7] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[3], 4), q6h_3)); + } + + // process 16 elements (one block with same scale) per iteration + // - vx = concat(ql, qh) - 32 + // - r1,r2,r3,r4 = smmla(vx, vy) + for (int k = 0; k < 8; ++k) { + const int blk = j * 8 + k; + + const int8x16_t vy0 = vld1q_s8(qy0); + const int8x16_t vy1 = vld1q_s8(qy1); + qy0 += 16; + qy1 += 16; + + const int32x4_t block_scale = { + x0->scales[blk], + x0->scales[blk], + x1->scales[blk], + x1->scales[blk], + }; + + // calculate four results at once with outer product + const int8x16_t vx_l = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(vx0[k]), vreinterpretq_s64_s8(vx1[k]))); + const int8x16_t vx_h = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(vx0[k]), vreinterpretq_s64_s8(vx1[k]))); + const int8x16_t vy_l = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(vy0), vreinterpretq_s64_s8(vy1))); + const int8x16_t vy_h = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(vy0), vreinterpretq_s64_s8(vy1))); + int32x4_t vr = vdupq_n_s32(0); + vr = vmmlaq_s32(vr, vx_l, vy_l); + vr = vmmlaq_s32(vr, vx_h, vy_h); + + // apply block scale, will NOT overflow + // block_scale * sum_256(int6*int8) <= 2^(8+8+6+8) = 30 bits + visum = vmlaq_s32(visum, vr, block_scale); + } + } + + // adjust bias, apply superblock scale + { + int32_t bias[4]; +#ifdef __ARM_FEATURE_SVE + const svbool_t pg16_8 = svptrue_pat_b16(SV_VL8); + const svbool_t pg8_8 = svptrue_pat_b8(SV_VL8); + const svint16_t y0_q8sums_0 = svld1_s16(pg16_8, y0->bsums); + const svint16_t y0_q8sums_1 = svld1_s16(pg16_8, y0->bsums + 8); + const svint16_t y1_q8sums_0 = svld1_s16(pg16_8, y1->bsums); + const svint16_t y1_q8sums_1 = svld1_s16(pg16_8, y1->bsums + 8); + const svint16_t x0_q6scales_0 = svunpklo_s16(svld1_s8(pg8_8, x0->scales)); + const svint16_t x0_q6scales_1 = svunpklo_s16(svld1_s8(pg8_8, x0->scales + 8)); + const svint16_t x1_q6scales_0 = svunpklo_s16(svld1_s8(pg8_8, x1->scales)); + 
const svint16_t x1_q6scales_1 = svunpklo_s16(svld1_s8(pg8_8, x1->scales + 8)); + const svint64_t zero = svdup_n_s64(0); + bias[0] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y0_q8sums_0, x0_q6scales_0), + svdot_s64(zero, y0_q8sums_1, x0_q6scales_1))); + bias[1] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y1_q8sums_0, x0_q6scales_0), + svdot_s64(zero, y1_q8sums_1, x0_q6scales_1))); + bias[2] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y0_q8sums_0, x1_q6scales_0), + svdot_s64(zero, y0_q8sums_1, x1_q6scales_1))); + bias[3] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y1_q8sums_0, x1_q6scales_0), + svdot_s64(zero, y1_q8sums_1, x1_q6scales_1))); +#else + // NEON doesn't support int16 dot product, fallback to separated mul and add + const int16x8x2_t q8sums0 = vld1q_s16_x2(y0->bsums); + const int16x8x2_t q8sums1 = vld1q_s16_x2(y1->bsums); + + int8x16_t scales_s8 = vld1q_s8(x0->scales); + const int16x8x2_t q6scales0 = {{vmovl_s8(vget_low_s8(scales_s8)), vmovl_s8(vget_high_s8(scales_s8))}}; + scales_s8 = vld1q_s8(x1->scales); + const int16x8x2_t q6scales1 = {{vmovl_s8(vget_low_s8(scales_s8)), vmovl_s8(vget_high_s8(scales_s8))}}; + + int32x4_t prod; + prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[0]), vget_low_s16 (q6scales0.val[0])), + vmull_s16(vget_high_s16(q8sums0.val[0]), vget_high_s16(q6scales0.val[0]))), + vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[1]), vget_low_s16 (q6scales0.val[1])), + vmull_s16(vget_high_s16(q8sums0.val[1]), vget_high_s16(q6scales0.val[1])))); + bias[0] = vaddvq_s32(prod); + prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[0]), vget_low_s16 (q6scales0.val[0])), + vmull_s16(vget_high_s16(q8sums1.val[0]), vget_high_s16(q6scales0.val[0]))), + vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[1]), vget_low_s16 (q6scales0.val[1])), + vmull_s16(vget_high_s16(q8sums1.val[1]), vget_high_s16(q6scales0.val[1])))); + bias[1] = vaddvq_s32(prod); + prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[0]), vget_low_s16 (q6scales1.val[0])), + vmull_s16(vget_high_s16(q8sums0.val[0]), vget_high_s16(q6scales1.val[0]))), + vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[1]), vget_low_s16 (q6scales1.val[1])), + vmull_s16(vget_high_s16(q8sums0.val[1]), vget_high_s16(q6scales1.val[1])))); + bias[2] = vaddvq_s32(prod); + prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[0]), vget_low_s16 (q6scales1.val[0])), + vmull_s16(vget_high_s16(q8sums1.val[0]), vget_high_s16(q6scales1.val[0]))), + vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[1]), vget_low_s16 (q6scales1.val[1])), + vmull_s16(vget_high_s16(q8sums1.val[1]), vget_high_s16(q6scales1.val[1])))); + bias[3] = vaddvq_s32(prod); + +#endif + const int32x4_t vibias = vmulq_n_s32(vld1q_s32(bias), 32); + + const float32x4_t superblock_scale = { + GGML_FP16_TO_FP32(x0->d) * y0->d, + GGML_FP16_TO_FP32(x0->d) * y1->d, + GGML_FP16_TO_FP32(x1->d) * y0->d, + GGML_FP16_TO_FP32(x1->d) * y1->d, + }; + + visum = vsubq_s32(visum, vibias); + vfsum = vmlaq_f32(vfsum, vcvtq_f32_s32(visum), superblock_scale); + } + } + + // vfsum = ABCD -> ACBD + // AC -> s, BD -> (s+bs) + vfsum = vzip1q_f32(vfsum, vextq_f32(vfsum, vfsum, 2)); + vst1_f32(s, vget_low_f32 (vfsum)); + vst1_f32(s + bs, vget_high_f32(vfsum)); + + return; + } +#endif + #ifdef __ARM_FEATURE_SVE const int vector_length = ggml_cpu_get_sve_cnt()*8; float sum = 0; diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 
a30e67f2279..133b50606bc 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -282,7 +282,11 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = { .from_float = quantize_row_q6_K, .vec_dot = ggml_vec_dot_q6_K_q8_K, .vec_dot_type = GGML_TYPE_Q8_K, +#if defined (__ARM_FEATURE_MATMUL_INT8) + .nrows = 2, +#else .nrows = 1, +#endif }, [GGML_TYPE_IQ2_XXS] = { .from_float = NULL, From 0afc371b47c329e367bd1b26156b4c33f0fbe49b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=81ukasz=20=C5=9Alusarczyk?= <112692748+lslusarczyk@users.noreply.github.com> Date: Thu, 15 May 2025 16:53:41 +0200 Subject: [PATCH 13/24] sycl: use oneDNN for matrices multiplication (llama/12972) --- ggml/CMakeLists.txt | 1 + ggml/src/ggml-sycl/CMakeLists.txt | 48 ++++---- ggml/src/ggml-sycl/gemm.hpp | 45 +++++-- ggml/src/ggml-sycl/ggml-sycl.cpp | 194 +++++++++++++++++++----------- 4 files changed, 191 insertions(+), 97 deletions(-) diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index a8300e16d87..4746d5cb76c 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -193,6 +193,7 @@ option(GGML_RPC "ggml: use RPC" option(GGML_SYCL "ggml: use SYCL" OFF) option(GGML_SYCL_F16 "ggml: use 16 bit floats for sycl calculations" OFF) option(GGML_SYCL_GRAPH "ggml: enable graphs in the SYCL backend" ON) +option(GGML_SYCL_DNN "ggml: enable oneDNN in the SYCL backend" ON) set (GGML_SYCL_TARGET "INTEL" CACHE STRING "ggml: sycl target device") set (GGML_SYCL_DEVICE_ARCH "" CACHE STRING diff --git a/ggml/src/ggml-sycl/CMakeLists.txt b/ggml/src/ggml-sycl/CMakeLists.txt index 231fb71dab5..a2e26124802 100644 --- a/ggml/src/ggml-sycl/CMakeLists.txt +++ b/ggml/src/ggml-sycl/CMakeLists.txt @@ -49,34 +49,38 @@ endif() target_compile_options(ggml-sycl PRIVATE "-Wno-narrowing") # Link against oneDNN -find_package(DNNL) set(GGML_SYCL_DNNL 0) -if(DNNL_FOUND) - if (NOT DEFINED DNNL_GPU_VENDOR) - # default to intel target - set(DNNL_GPU_VENDOR "INTEL") - if(NOT "${GGML_SYCL_TARGET}" STREQUAL "INTEL") - message(WARNING "oneDNN builds bundled with oneapi release only support INTEL target") +if(GGML_SYCL_DNN) + find_package(DNNL) + if(DNNL_FOUND) + if (NOT DEFINED DNNL_GPU_VENDOR) + # default to intel target + set(DNNL_GPU_VENDOR "INTEL") + if(NOT "${GGML_SYCL_TARGET}" STREQUAL "INTEL") + message(WARNING "oneDNN builds bundled with oneapi release only support INTEL target") + endif() endif() - endif() - # Verify oneDNN was compiled for the same target as llama - if("${GGML_SYCL_TARGET}" STREQUAL "${DNNL_GPU_VENDOR}") - target_link_libraries(ggml-sycl PRIVATE DNNL::dnnl) - set(GGML_SYCL_DNNL 1) - get_target_property(CONFIGS DNNL::dnnl IMPORTED_CONFIGURATIONS) - foreach(CONFIG ${CONFIGS}) - get_target_property(DNNL_LIB DNNL::dnnl IMPORTED_LOCATION_${CONFIG}) - message(STATUS "Found oneDNN: ${DNNL_LIB}") - endforeach() + # Verify oneDNN was compiled for the same target as llama + if("${GGML_SYCL_TARGET}" STREQUAL "${DNNL_GPU_VENDOR}") + target_link_libraries(ggml-sycl PRIVATE DNNL::dnnl) + set(GGML_SYCL_DNNL 1) + get_target_property(CONFIGS DNNL::dnnl IMPORTED_CONFIGURATIONS) + foreach(CONFIG ${CONFIGS}) + get_target_property(DNNL_LIB DNNL::dnnl IMPORTED_LOCATION_${CONFIG}) + message(STATUS "Found oneDNN: ${DNNL_LIB}") + endforeach() + else() + message(WARNING + "oneDNN must be compiled for the same target as llama.cpp. + llama.cpp: ${GGML_SYCL_TARGET}, oneDNN: ${DNNL_GPU_VENDOR}. 
+ Disabling oneDNN support.") + endif() else() - message(WARNING - "oneDNN must be compiled for the same target as llama.cpp. - llama.cpp: ${GGML_SYCL_TARGET}, oneDNN: ${DNNL_GPU_VENDOR}. - Disabling oneDNN support.") + message(STATUS "oneDNN not found, disabling oneDNN support") endif() else() - message(STATUS "oneDNN not found, disabling oneDNN support") + message(STATUS "oneDNN support disabled by the user") endif() target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_DNNL=${GGML_SYCL_DNNL}) diff --git a/ggml/src/ggml-sycl/gemm.hpp b/ggml/src/ggml-sycl/gemm.hpp index 4ebbb5b66fb..6cbc7e0f693 100644 --- a/ggml/src/ggml-sycl/gemm.hpp +++ b/ggml/src/ggml-sycl/gemm.hpp @@ -32,16 +32,36 @@ class DnnlGemmWrapper { else static_assert(0); } - static inline void row_gemm(ggml_backend_sycl_context & ctx, bool a_trans, bool b_trans, int m, int n, int k, - const void * a, dt at, const void * b, dt bt, void * c, dt ct, const queue_ptr & q) { + // matrix A has m rows, k columns + // matrix B has k rows, n columns + // nra - number of elements to skip when moving into next row in A + // nrb - number of elements to skip when moving into next row in B + // nca - number of elements to skip when moving into next column in A + // ncb - number of elements to skip when moving into next column in B + // stride_a - number of elements to skip when moving to next A matrix + // stride_b - number of elements to skip when moving to next B matrix + // batches_a - number of A matrices + // batches_b - number of B matrices + static void gemm(ggml_backend_sycl_context & ctx, int m, int n, int k, + const void * a, dt at, dnnl_dim_t nra, dnnl_dim_t nca, dnnl_dim_t stride_a, + const void * b, dt bt, dnnl_dim_t nrb, dnnl_dim_t ncb, dnnl_dim_t stride_b, + void * c, dt ct, const queue_ptr & q, dnnl_dim_t batches_a, dnnl_dim_t batches_b) { + auto stream = ctx.stream_dnnl(q); auto eng = ctx.engine_dnnl(q); - dnnl::memory::dims a_dims = { m, k }; - dnnl::memory::dims b_dims = { k, n }; - dnnl::memory::dims c_dims = { m, n }; - const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab); - const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? 
tag::ba : tag::ab); - const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab); + + // { # strides, # rows, # columns } + dnnl::memory::dims a_dims = { batches_a, m, k }; + dnnl::memory::dims b_dims = { batches_b, k, n }; + dnnl::memory::dims c_dims = { std::max(batches_a, batches_b), m, n }; + + // { # elements to skip to next stride, # elements to skip to next row, # elements to skip to next column } + dnnl::memory::dims a_strides = { stride_a, nra, nca }; + dnnl::memory::dims b_strides = { stride_b, nrb, ncb }; + + const auto a_in_md = dnnl::memory::desc(a_dims, at, a_strides); + const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_strides); + const auto c_md = dnnl::memory::desc(c_dims, ct, tag::abc); dnnl::primitive_attr primitive_attr; primitive_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); @@ -63,6 +83,15 @@ class DnnlGemmWrapper { matmul_prim.execute(stream, matmul_args); } + + // matrices A and B are column major, both having k rows + // matrix A has m column, matrix B has n columns + // output: column major matrix C = A transposed * B + static void row_gemm(ggml_backend_sycl_context & ctx, int m, int n, int k, + const void * a, dt at, const void * b, dt bt, void * c, dt ct, const queue_ptr & q) { + + gemm(ctx, m, n, k, a, at, k, 1, k * m, b, bt, 1, k, n * k, c, ct, q, 1, 1); + } }; #endif diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 0ea729948ec..1205fce0e7c 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -49,6 +49,7 @@ static bool g_sycl_loaded = false; int g_ggml_sycl_debug = 0; int g_ggml_sycl_disable_optimize = 0; int g_ggml_sycl_disable_graph = 0; +int g_ggml_sycl_disable_dnn = 0; int g_ggml_sycl_prioritize_dmmv = 0; static ggml_sycl_device_info ggml_sycl_init() { @@ -196,12 +197,22 @@ static void ggml_check_sycl() try { g_ggml_sycl_debug = get_sycl_env("GGML_SYCL_DEBUG", 0); g_ggml_sycl_disable_optimize= get_sycl_env("GGML_SYCL_DISABLE_OPT", 1); g_ggml_sycl_disable_graph = get_sycl_env("GGML_SYCL_DISABLE_GRAPH", 1); + g_ggml_sycl_disable_dnn = get_sycl_env("GGML_SYCL_DISABLE_DNN", 0); g_ggml_sycl_prioritize_dmmv = get_sycl_env("GGML_SYCL_PRIORITIZE_DMMV", 0); GGML_SYCL_DEBUG("[SYCL] call ggml_check_sycl\n"); GGML_LOG_INFO("Running with Environment Variables:\n"); GGML_LOG_INFO(" GGML_SYCL_DEBUG: %d\n", g_ggml_sycl_debug); GGML_LOG_INFO(" GGML_SYCL_DISABLE_OPT: %d\n", g_ggml_sycl_disable_optimize); +#ifdef GGML_SYCL_GRAPH GGML_LOG_INFO(" GGML_SYCL_DISABLE_GRAPH: %d\n", g_ggml_sycl_disable_graph); +#else + GGML_LOG_INFO(" GGML_SYCL_DISABLE_GRAPH: graph disabled by compile flag\n"); +#endif +#if GGML_SYCL_DNNL + GGML_LOG_INFO(" GGML_SYCL_DISABLE_DNN: %d\n", g_ggml_sycl_disable_dnn); +#else + GGML_LOG_INFO(" GGML_SYCL_DISABLE_DNN: DNN disabled by compile flag\n"); +#endif GGML_LOG_INFO(" GGML_SYCL_PRIORITIZE_DMMV: %d\n", g_ggml_sycl_prioritize_dmmv); GGML_LOG_INFO("Build with Macros:\n"); #if defined(GGML_SYCL_FORCE_MMQ) @@ -1985,19 +1996,18 @@ inline void ggml_sycl_op_mul_mat_sycl( const int64_t ne00 = src0->ne[0]; const int64_t ne10 = src1->ne[0]; - + GGML_ASSERT(ne00 == ne10); const int64_t row_diff = row_high - row_low; int id; SYCL_CHECK( CHECK_TRY_ERROR(id = get_current_device_id())); -#if !GGML_SYCL_DNNL - const int64_t ne0 = dst->ne[0]; + + const int64_t ne0 = dst->ne[0]; // used by MKL only // the main device has a larger memory buffer to hold the results from all GPUs // ldc == nrows of the matrix that cuBLAS writes into - int ldc = id == ctx.device ? 
ne0 : row_diff; -#endif + int ldc = id == ctx.device ? ne0 : row_diff; // used by MKL only #ifdef GGML_SYCL_F16 bool use_fp16 = true; // TODO(Yu) SYCL capability check @@ -2033,25 +2043,29 @@ inline void ggml_sycl_op_mul_mat_sycl( : src1_as_f16.get(); ggml_sycl_pool_alloc dst_f16(ctx.pool(), row_diff * src1_ncols); -#if !GGML_SYCL_DNNL - const sycl::half alpha_f16 = 1.0f; - const sycl::half beta_f16 = 0.0f; - SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm( - *stream, oneapi::math::transpose::trans, - oneapi::math::transpose::nontrans, row_diff, src1_ncols, ne10, - &alpha_f16, src0_ptr, dpct::library_data_t::real_half, ne00, - src1_ptr, dpct::library_data_t::real_half, ne10, &beta_f16, - dst_f16.get(), dpct::library_data_t::real_half, ldc, - dpct::library_data_t::real_half))); - const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst); - to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream); -#else - DnnlGemmWrapper::row_gemm(ctx, false, true, src1_ncols, row_diff, ne10, src1_ptr, - DnnlGemmWrapper::to_dt(), src0_ptr, DnnlGemmWrapper::to_dt(), - dst_f16.get(), DnnlGemmWrapper::to_dt(), stream); - const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst); - to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream); +#if GGML_SYCL_DNNL + if (!g_ggml_sycl_disable_dnn) { + DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ptr, + DnnlGemmWrapper::to_dt(), src0_ptr, DnnlGemmWrapper::to_dt(), + dst_f16.get(), DnnlGemmWrapper::to_dt(), stream); + const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst); + to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream); + } + else #endif + { + const sycl::half alpha_f16 = 1.0f; + const sycl::half beta_f16 = 0.0f; + SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm( + *stream, oneapi::math::transpose::trans, + oneapi::math::transpose::nontrans, row_diff, src1_ncols, ne10, + &alpha_f16, src0_ptr, dpct::library_data_t::real_half, ne00, + src1_ptr, dpct::library_data_t::real_half, ne10, &beta_f16, + dst_f16.get(), dpct::library_data_t::real_half, ldc, + dpct::library_data_t::real_half))); + const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst); + to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream); + } } else { // GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat_sycl - fp32 path\n"); @@ -2072,18 +2086,22 @@ inline void ggml_sycl_op_mul_mat_sycl( const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32.get(); const float * src1_ddf1_i = src1->type == GGML_TYPE_F32 ? 
(const float *) src1_ddf_i : src1_ddq_as_f32.get(); -#if !GGML_SYCL_DNNL - const float alpha = 1.0f; - const float beta = 0.0f; - SYCL_CHECK(CHECK_TRY_ERROR(oneapi::math::blas::column_major::gemm( - get_onemath_backend(*stream), oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, row_diff, - src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10, - dpct::get_value(&beta, *stream), dst_dd_i, ldc))); -#else - DnnlGemmWrapper::row_gemm(ctx, false, true, src1_ncols, row_diff, ne10, src1_ddf1_i, - DnnlGemmWrapper::to_dt(), src0_ddf_i, DnnlGemmWrapper::to_dt(), - dst_dd_i, DnnlGemmWrapper::to_dt(), stream); +#if GGML_SYCL_DNNL + if (!g_ggml_sycl_disable_dnn) { + DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ddf1_i, + DnnlGemmWrapper::to_dt(), src0_ddf_i, DnnlGemmWrapper::to_dt(), + dst_dd_i, DnnlGemmWrapper::to_dt(), stream); + } + else #endif + { + const float alpha = 1.0f; + const float beta = 0.0f; + SYCL_CHECK(CHECK_TRY_ERROR(oneapi::math::blas::column_major::gemm( + get_onemath_backend(*stream), oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, row_diff, + src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10, + dpct::get_value(&beta, *stream), dst_dd_i, ldc))); + } } GGML_UNUSED(dst); GGML_UNUSED(src1_ddq_i); @@ -2697,7 +2715,7 @@ catch (sycl::exception const &exc) { std::exit(1); } -static void k_compute_batched_ptrs(const sycl::half * src0_as_f16, const sycl::half * src1_as_f16, char * dst, +static void k_compute_batched_ptrs(const sycl::half * src0_as_f16, const sycl::half * src1_as_f16, void * dst, const void ** ptrs_src, void ** ptrs_dst, int64_t ne12, int64_t ne13, int64_t ne23, size_t nb02, size_t nb03, size_t nb12, size_t nb13, size_t nbd2, size_t nbd3, int64_t r2, int64_t r3, const sycl::nd_item<3> & item_ct1) { @@ -2713,7 +2731,7 @@ static void k_compute_batched_ptrs(const sycl::half * src0_as_f16, const sycl::h const uint8_t * src0_bytes = reinterpret_cast(src0_as_f16); const uint8_t * src1_bytes = reinterpret_cast(src1_as_f16); - uint8_t * dst_bytes = reinterpret_cast(dst); + uint8_t * dst_bytes = static_cast(dst); ptrs_src[0 * ne23 + i12 + i13 * ne12] = src0_bytes + i02 * nb02 + i03 * nb03; ptrs_src[1 * ne23 + i12 + i13 * ne12] = src1_bytes + i12 * nb12 + i13 * nb13; @@ -2726,6 +2744,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons GGML_ASSERT(!ggml_is_transposed(src1)); GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer)); GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32); GGML_TENSOR_BINARY_OP_LOCALS @@ -2766,7 +2785,6 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons } ggml_sycl_pool_alloc dst_f16(ctx.pool()); - char * dst_t = reinterpret_cast(dst_ddf); dpct::library_data_t mkl_compute_type = dpct::library_data_t::real_float; dpct::library_data_t mkl_data_type = dpct::library_data_t::real_float; @@ -2783,42 +2801,83 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons GGML_ASSERT(ne12 % ne02 == 0); GGML_ASSERT(ne13 % ne03 == 0); + GGML_ASSERT(ne01 == static_cast(nb1/nb0)); + GGML_ASSERT(ne10 == ne00); // broadcast factors const int64_t r2 = ne12 / ne02; const int64_t r3 = ne13 / ne03; - if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) { - // there is no broadcast and src0, src1 are contiguous across dims 2, 3 - SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(*queue, 
oneapi::math::transpose::trans, - oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha, - src0_f16, dpct::library_data_t::real_half, nb01 / nb00, nb02 / nb00, - src1_f16, dpct::library_data_t::real_half, s11, s12, beta, dst_t, - mkl_data_type, ne0, ne1 * ne0, ne12 * ne13, mkl_compute_type))); - } else { - const int ne23 = ne12 * ne13; - - ggml_sycl_pool_alloc ptrs_src(ctx.pool(), 2 * ne23); - ggml_sycl_pool_alloc ptrs_dst(ctx.pool(), 1 * ne23); - ggml_sycl_pool_alloc> matrix_info(ctx.host_pool(), 1); - - sycl::range<3> block_dims(1, ne12, ne13); - queue->submit([&](sycl::handler & cgh) { - const void ** ptrs_src_get = ptrs_src.get(); - void ** ptrs_dst_get = ptrs_dst.get(); - size_t nb12_scaled = src1->type == GGML_TYPE_F16 ? nb12 : s12 * sizeof(sycl::half); - size_t nb13_scaled = src1->type == GGML_TYPE_F16 ? nb13 : s13 * sizeof(sycl::half); - cgh.parallel_for(sycl::nd_range<3>(block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { - k_compute_batched_ptrs(src0_f16, src1_f16, dst_t, ptrs_src_get, ptrs_dst_get, ne12, ne13, ne23, nb02, - nb03, nb12_scaled, nb13_scaled, nbd2, nbd3, r2, r3, item_ct1); +#if GGML_SYCL_DNNL + if (!g_ggml_sycl_disable_dnn) { + auto dnn_gemm = [&ctx, queue, ne11, ne01, ne10, nb00, nb01, nb02, s11, s12] + (const sycl::half* src1, const sycl::half* src0, float* dst, const dnnl_dim_t batches_a, const dnnl_dim_t batches_b) { + + DnnlGemmWrapper::gemm(ctx, ne11,ne01, ne10, + src1, DnnlGemmWrapper::to_dt(), s11, 1, s12, + src0, DnnlGemmWrapper::to_dt(), 1, nb01/nb00, nb02/nb00, + dst, DnnlGemmWrapper::to_dt(), queue, batches_a, batches_b); + }; + + if (r2 == 1 && r3 == 1) { + if (ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) { + dnn_gemm(src1_f16, src0_f16, dst_ddf, ne12*ne13, ne02 * ne03); + } + else { + for (int64_t ie03 = 0; ie03 < ne03; ++ie03) { + const sycl::half* src0_f16_shifted = src0_f16 + ((ie03*nb03)/sizeof(sycl::half)); // nb is in bytes + const sycl::half* src1_f16_shifted = src1_f16 + ie03*s13; + float* dst_shifted = dst_ddf + ((ie03*nb3)/sizeof(float)); + dnn_gemm(src1_f16_shifted, src0_f16_shifted, dst_shifted, ne12, ne02); + } + } + } else { + // iterate over batches from smaller set of matrices (matrix 0) + for (int64_t ie02 = 0; ie02 < ne02; ++ie02) { + for (int64_t ie03 = 0; ie03 < ne03; ++ie03) { + const sycl::half* src0_f16_shifted = src0_f16 + ((ie02*nb02 + ie03*nb03)/sizeof(sycl::half)); + const sycl::half* src1_f16_shifted = src1_f16 + ie02*s12*r2 + ie03*s13*r3; + float* dst_shifted = dst_ddf + ((ie02*nb2*r2 + ie03*nb3*r3)/sizeof(float)); + dnn_gemm(src1_f16_shifted, src0_f16_shifted, dst_shifted, r2*r3, 1); + } + } + } + } + else +#endif + { + if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) { + // there is no broadcast and src0, src1 are contiguous across dims 2, 3 + SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(*queue, oneapi::math::transpose::trans, + oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha, + src0_f16, dpct::library_data_t::real_half, nb01 / nb00, nb02 / nb00, + src1_f16, dpct::library_data_t::real_half, s11, s12, beta, dst_ddf, + mkl_data_type, ne0, ne1 * ne0, ne12 * ne13, mkl_compute_type))); + } else { + const int ne23 = ne12 * ne13; + + ggml_sycl_pool_alloc ptrs_src(ctx.pool(), 2 * ne23); + ggml_sycl_pool_alloc ptrs_dst(ctx.pool(), 1 * ne23); + ggml_sycl_pool_alloc> matrix_info(ctx.host_pool(), 1); + + sycl::range<3> block_dims(1, ne12, ne13); + queue->submit([&](sycl::handler & cgh) { + const void ** ptrs_src_get = ptrs_src.get(); + void ** ptrs_dst_get = 
ptrs_dst.get(); + size_t nb12_scaled = src1->type == GGML_TYPE_F16 ? nb12 : s12 * sizeof(sycl::half); + size_t nb13_scaled = src1->type == GGML_TYPE_F16 ? nb13 : s13 * sizeof(sycl::half); + cgh.parallel_for(sycl::nd_range<3>(block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { + k_compute_batched_ptrs(src0_f16, src1_f16, dst_ddf, ptrs_src_get, ptrs_dst_get, ne12, ne13, ne23, nb02, + nb03, nb12_scaled, nb13_scaled, nbd2, nbd3, r2, r3, item_ct1); + }); }); - }); - SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch( - *queue, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha, - (const void **) (ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00, - (const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, s11, beta, - (void **) (ptrs_dst.get() + 0 * ne23), mkl_data_type, ne0, ne23, mkl_compute_type, matrix_info.get()))); + SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch( + *queue, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha, + (const void **) (ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00, + (const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, s11, beta, + (void **) (ptrs_dst.get() + 0 * ne23), mkl_data_type, ne0, ne23, mkl_compute_type, matrix_info.get()))); + } } } catch (const sycl::exception & exc) { std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl; @@ -3713,7 +3772,8 @@ static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_ return GGML_STATUS_SUCCESS; } - sycl_ex::command_graph model_sycl_graph(*(sycl_ctx->stream())); + sycl_ex::command_graph model_sycl_graph(*(sycl_ctx->stream()), {sycl_ex::property::graph::assume_buffer_outlives_graph{}}); + model_sycl_graph.begin_recording(*(sycl_ctx->stream())); ggml_backend_sycl_graph_compute_impl(sycl_ctx, cgraph); model_sycl_graph.end_recording(); From 511ae56cbdcbecd3f66931511f0ce58568ea5fc7 Mon Sep 17 00:00:00 2001 From: Svetlozar Georgiev <55534064+sgeor255@users.noreply.github.com> Date: Thu, 15 May 2025 16:35:44 +0100 Subject: [PATCH 14/24] sycl: reordered Q4_K MMVQ (llama/13109) --- ggml/src/ggml-sycl/convert.cpp | 31 ++++++++- ggml/src/ggml-sycl/dequantize.hpp | 80 +++++++++++++++------ ggml/src/ggml-sycl/dmmv.cpp | 8 ++- ggml/src/ggml-sycl/ggml-sycl.cpp | 80 +++++++++++++++++---- ggml/src/ggml-sycl/mmvq.cpp | 31 ++++++++- ggml/src/ggml-sycl/quants.hpp | 22 ++++++ ggml/src/ggml-sycl/vecdotq.hpp | 112 ++++++++++++++++++------------ 7 files changed, 280 insertions(+), 84 deletions(-) diff --git a/ggml/src/ggml-sycl/convert.cpp b/ggml/src/ggml-sycl/convert.cpp index b2f8a656933..75bac98e5fb 100644 --- a/ggml/src/ggml-sycl/convert.cpp +++ b/ggml/src/ggml-sycl/convert.cpp @@ -183,6 +183,24 @@ static void dequantize_row_q4_K_sycl(const void *vx, dst_t *y, const int64_t k, } } +template +static void dequantize_row_q4_K_sycl_reorder(const void * vx, dst_t * y, const int64_t k, dpct::queue_ptr stream) { + const int64_t nb = k / QK_K; + const size_t local_size = 32; + const size_t global_size = nb * local_size; + + dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 }); + + stream->submit([&](sycl::handler & cgh) { + sycl::local_accessor scale_local_acc(sycl::range<1>(12), cgh); + + cgh.parallel_for(sycl::nd_range<1>(sycl::range<1>(global_size), sycl::range<1>(local_size)), + [=](sycl::nd_item<1> item_ct1) { + dequantize_block_q4_K_reorder(vx, y, get_pointer(scale_local_acc), 
item_ct1, nb); + }); + }); +} + template static void dequantize_row_q5_K_sycl(const void *vx, dst_t *y, const int64_t k, dpct::queue_ptr stream) { @@ -504,7 +522,11 @@ to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst) { case GGML_TYPE_Q3_K: return dequantize_row_q3_K_sycl; case GGML_TYPE_Q4_K: - return dequantize_row_q4_K_sycl; + if (dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { + return dequantize_row_q4_K_sycl_reorder; + } else { + return dequantize_row_q4_K_sycl; + } case GGML_TYPE_Q5_K: return dequantize_row_q5_K_sycl; case GGML_TYPE_Q6_K: @@ -556,7 +578,12 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) { case GGML_TYPE_Q3_K: return dequantize_row_q3_K_sycl; case GGML_TYPE_Q4_K: - return dequantize_row_q4_K_sycl; + if (dst->src[0]->extra && + ((ggml_tensor_extra_gpu*)dst->src[0]->extra)->optimized_feature.reorder) { + return dequantize_row_q4_K_sycl_reorder; + } else { + return dequantize_row_q4_K_sycl; + } case GGML_TYPE_Q5_K: return dequantize_row_q5_K_sycl; case GGML_TYPE_Q6_K: diff --git a/ggml/src/ggml-sycl/dequantize.hpp b/ggml/src/ggml-sycl/dequantize.hpp index 651c2160d24..64e92f73f26 100644 --- a/ggml/src/ggml-sycl/dequantize.hpp +++ b/ggml/src/ggml-sycl/dequantize.hpp @@ -357,6 +357,28 @@ static inline void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8 } #endif +template +inline void dequantize_q4_K_common(dst_t * __restrict__ y, const uint8_t * __restrict__ qs_ptr, const float dall, + const float dmin, uint8_t * __restrict__ scales_local, int il, int ir) { + const int is = 2 * il; + constexpr int n = 4; + + uint8_t sc, m; + get_scale_min_k4(is + 0, scales_local, sc, m); + const float d1 = dall * sc; + const float m1 = dmin * m; + + get_scale_min_k4(is + 1, scales_local, sc, m); + const float d2 = dall * sc; + const float m2 = dmin * m; + + sycl::vec q_vec = vec_aligned_load(qs_ptr + 32 * il + n * ir); + for (int l = 0; l < n; ++l) { + y[l + 0] = d1 * (q_vec[l] & 0xF) - m1; + y[l + 32] = d2 * (q_vec[l] >> 4) - m2; + } +} + template static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy, uint8_t* scales_local, const sycl::nd_item<3> &item_ct1) { @@ -365,36 +387,22 @@ static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restri const int64_t i = item_ct1.get_group(2); #if QK_K == 256 - // assume 32 threads const int64_t tid = item_ct1.get_local_id(2); - const int64_t il = tid/8; - const int64_t ir = tid%8; - const int64_t is = 2*il; - const int64_t n = 4; + const int64_t il = tid / 8; + const int64_t ir = tid % 8; - dst_t * y = yy + i*QK_K + 64*il + n*ir; + dst_t * y = yy + i * QK_K + 64 * il + 4 * ir; const sycl::half2 dm = x[i].dm; const float dall = dm[0]; const float dmin = dm[1]; - if (tid < 12) + if (tid < 12) { scales_local[tid] = x[i].scales[tid]; - item_ct1.barrier(sycl::access::fence_space::local_space); - - uint8_t sc, m; - get_scale_min_k4(is + 0, scales_local, sc, m); - const float d1 = dall * sc; - const float m1 = dmin * m; - get_scale_min_k4(is + 1, scales_local, sc, m); - const float d2 = dall * sc; - const float m2 = dmin * m; - - sycl::vec q_vec = vec_aligned_load(x[i].qs + 32*il + n*ir); - for (int l = 0; l < n; ++l) { - y[l + 0] = d1 * (q_vec[l] & 0xF) - m1; - y[l +32] = d2 * (q_vec[l] >> 4) - m2; } + + item_ct1.barrier(sycl::access::fence_space::local_space); + dequantize_q4_K_common(y, x[i].qs, dall, dmin, scales_local, il, ir); #else const int64_t tid = item_ct1.get_local_id(2); const 
uint8_t * q = x[i].qs; @@ -406,6 +414,36 @@ static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restri #endif } +template +static void dequantize_block_q4_K_reorder(const void * __restrict__ vx, dst_t * __restrict__ yy, uint8_t * scales_local, + const sycl::nd_item<1> & item_ct1, int64_t nb) { + const int64_t i = item_ct1.get_group(0); // block index + const int64_t tid = item_ct1.get_local_id(0); // thread index within block + const int64_t il = tid / 8; + const int64_t ir = tid % 8; + + dst_t * y = yy + i * QK_K + 64 * il + 4 * ir; + + const uint8_t * base = static_cast(vx); + const size_t qs_offset = i * (QK_K / 2); + const size_t scales_offset = nb * (QK_K / 2) + i * K_SCALE_SIZE; + const size_t dm_offset = nb * (QK_K / 2) + nb * K_SCALE_SIZE + i * sizeof(ggml_half2); + + const uint8_t * qs_ptr = base + qs_offset; + const uint8_t * scales_ptr = base + scales_offset; + ggml_half2 dm_values = *reinterpret_cast(base + dm_offset); + + const float dall = dm_values.x(); + const float dmin = dm_values.y(); + + if (tid < 12) { + scales_local[tid] = scales_ptr[tid]; + } + + item_ct1.barrier(sycl::access::fence_space::local_space); + dequantize_q4_K_common(y, qs_ptr, dall, dmin, scales_local, il, ir); +} + template static void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restrict__ yy, const sycl::nd_item<3> &item_ct1) { diff --git a/ggml/src/ggml-sycl/dmmv.cpp b/ggml/src/ggml-sycl/dmmv.cpp index 04a85fa35ff..b58150c687b 100644 --- a/ggml/src/ggml-sycl/dmmv.cpp +++ b/ggml/src/ggml-sycl/dmmv.cpp @@ -1129,7 +1129,13 @@ void ggml_sycl_op_dequantize_mul_mat_vec( dequantize_mul_mat_vec_q3_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); break; case GGML_TYPE_Q4_K: - dequantize_mul_mat_vec_q4_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); + if ((ggml_tensor_extra_gpu *) dst->src[0]->extra && + ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { + // reorder is currently not supported for dmmv + GGML_ABORT("Unimplemented dequantize case case for q4_k reorder"); + } else { + dequantize_mul_mat_vec_q4_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); + } break; case GGML_TYPE_Q5_K: dequantize_mul_mat_vec_q5_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 1205fce0e7c..5ff7fa13db0 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -352,7 +352,7 @@ ggml_backend_sycl_buffer_init_tensor(ggml_backend_buffer_t buffer, assert(tensor->view_src->buffer->buft == buffer->buft); return GGML_STATUS_SUCCESS; } - if (tensor->type == GGML_TYPE_Q4_0 && !g_ggml_sycl_disable_optimize) { + if ((tensor->type == GGML_TYPE_Q4_0 || tensor->type == GGML_TYPE_Q4_K) && !g_ggml_sycl_disable_optimize) { ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{}; tensor->extra = extra; ctx->tensor_extras.push_back(extra); //used to release it when destroy ctx. 
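For reference, the reordered Q4_K layout that dequantize_block_q4_K_reorder() reads and reorder_qw_q4_k() (further below) writes groups all quantized nibbles first, then all scale blocks, then all d/dmin half2 pairs, instead of interleaving them per block_q4_K. A minimal sketch of the resulting pointer arithmetic, assuming the usual ggml-common.h definitions of QK_K, K_SCALE_SIZE and ggml_half2; the view struct and helper name are illustrative only and not part of the patch:

    #include <cstdint>
    #include "ggml-common.h"   // QK_K, K_SCALE_SIZE, ggml_half2 (assumed available)

    // Pointers into one reordered Q4_K tensor holding `nblocks` super-blocks.
    struct q4_k_reorder_view {
        const uint8_t    * qs;      // nblocks * (QK_K / 2) bytes of packed 4-bit quants
        const uint8_t    * scales;  // nblocks * K_SCALE_SIZE scale/min bytes
        const ggml_half2 * dm;      // nblocks d/dmin pairs
    };

    static q4_k_reorder_view make_q4_k_reorder_view(const uint8_t * base, int64_t nblocks) {
        q4_k_reorder_view v;
        v.qs     = base;
        v.scales = base + nblocks * (QK_K / 2);
        v.dm     = reinterpret_cast<const ggml_half2 *>(v.scales + nblocks * K_SCALE_SIZE);
        return v;
    }

Keeping each field contiguous across all blocks is what lets the reordered dequantize and MMVQ kernels compute their qs/scales/dm pointers with plain offsets (the same scheme block_q_t<GGML_TYPE_Q4_K>::get_d_offset() encodes) rather than striding through interleaved block_q4_K structs.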
@@ -2900,6 +2900,8 @@ inline bool ggml_sycl_supports_reorder_mul_mat_sycl(enum ggml_type type) { switch (type) { case GGML_TYPE_Q4_0: return true; + case GGML_TYPE_Q4_K: + return !g_ggml_sycl_prioritize_dmmv; default: return false; } @@ -2917,6 +2919,7 @@ inline bool ggml_sycl_supports_reorder_dmmv(enum ggml_type type) { inline bool ggml_sycl_supports_reorder_mmvq(enum ggml_type type) { switch (type) { case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_K: return true; default: return false; @@ -2942,16 +2945,16 @@ static bool ggml_sycl_supports_dmmv(enum ggml_type type) { } } -static void reorder_qw(char *data_device, const int ncols, const int nrows, - size_t size, size_t offset, dpct::queue_ptr stream) { - auto tmp_buf = sycl::malloc_shared(size, *stream); +static void reorder_qw_q4_0(uint8_t * data_device, const int ncols, const int nrows, size_t size, size_t offset, + dpct::queue_ptr stream) { + auto * tmp_buf = sycl::malloc_shared(size, *stream); SYCL_CHECK( CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size) .wait())); GGML_ASSERT((size % sizeof(block_q4_0) == 0)); GGML_ASSERT((offset % sizeof(block_q4_0) == 0)); int offset_blks = offset / sizeof(block_q4_0); - auto qs_ptr = (uint8_t*)data_device + offset_blks * QK4_0 / 2; + auto qs_ptr = data_device + offset_blks * QK4_0 / 2; auto d_ptr = (sycl::half*)(qs_ptr + ncols * nrows / 2) + offset_blks; stream->parallel_for( @@ -2965,18 +2968,59 @@ static void reorder_qw(char *data_device, const int ncols, const int nrows, *(qs_ptr + ib * QK4_0 / 2 + j) = x[ib].qs[j]; } *(d_ptr + ib) = x[ib].d; - }); + }).wait_and_throw(); + + sycl::free(tmp_buf, *stream); +} + +static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) { + GGML_ASSERT(size % sizeof(block_q4_K) == 0); + GGML_ASSERT(offset % sizeof(block_q4_K) == 0); + + const int nblocks = size / sizeof(block_q4_K); + + auto * tmp_buf = sycl::malloc_shared(size, *stream); + SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size).wait())); + + auto * qs_ptr = data_device; + auto * scales_ptr = qs_ptr + QK_K / 2 * nblocks; + auto * dm_ptr = (sycl::half2 *) (scales_ptr + K_SCALE_SIZE * nblocks); + + stream->parallel_for(nblocks, [=](auto i) { + const block_q4_K * x = (const block_q4_K *) tmp_buf; + const int ib = i; + + for (int j = 0; j < QK_K / 2; ++j) { + qs_ptr[ib * (QK_K / 2) + j] = x[ib].qs[j]; + } + + for (int j = 0; j < K_SCALE_SIZE; ++j) { + scales_ptr[ib * K_SCALE_SIZE + j] = x[ib].scales[j]; + } + + dm_ptr[ib] = x[ib].dm; + }).wait_and_throw(); sycl::free(tmp_buf, *stream); } static void reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) { - char*data_device = (char*)src0->data; + uint8_t * data_device = (uint8_t *) src0->data; size_t ncols = src0->ne[0]; size_t nrows = src0->ne[1]; size_t size = ggml_nbytes(src0); - reorder_qw(data_device, ncols, nrows, size, 0, stream); + switch (src0->type) { + case GGML_TYPE_Q4_0: + reorder_qw_q4_0(data_device, ncols, nrows, size, 0, stream); + break; + case GGML_TYPE_Q4_K: + reorder_qw_q4_k(data_device, size, 0, stream); + break; + default: + GGML_ABORT("reorder_qw() called with unsupported type"); + break; + } } static bool should_reorder_tensor(ggml_backend_sycl_context& ctx, const ggml_tensor * dst) { @@ -3019,8 +3063,18 @@ static void opt_for_reorder(ggml_backend_sycl_context * ctx, const ggml_tensor * extra->optimized_feature.reorder = true; // Used to decode/dequan in next steps and avoid re-reordering } -static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, 
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static bool can_use_dequantize_mul_mat_vec(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + return ggml_sycl_supports_dmmv(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && + src0->ne[0] % GGML_SYCL_DMMV_X == 0 && src1->ne[1] == 1; +} + +static bool can_use_mul_mat_vec_q(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + return ggml_is_quantized(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && + src1->ne[1] <= MMVQ_MAX_BATCH_SIZE; +} + +static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { const bool split = ggml_backend_buffer_is_sycl_split(src0->buffer); int64_t min_compute_capability = INT_MAX; @@ -3043,13 +3097,9 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor } // check data types and tensor shapes for custom matrix multiplication kernels: - bool use_dequantize_mul_mat_vec = ggml_sycl_supports_dmmv(src0->type) - && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 - && src0->ne[0] % GGML_SYCL_DMMV_X == 0 && src1->ne[1] == 1; + bool use_dequantize_mul_mat_vec = can_use_dequantize_mul_mat_vec(src0, src1, dst); - bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) - && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 - && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE; + bool use_mul_mat_vec_q = can_use_mul_mat_vec_q(src0, src1, dst); bool use_mul_mat_q = ggml_sycl_supports_mmq(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32; diff --git a/ggml/src/ggml-sycl/mmvq.cpp b/ggml/src/ggml-sycl/mmvq.cpp index 3cade1a42a6..23eeb74da0d 100644 --- a/ggml/src/ggml-sycl/mmvq.cpp +++ b/ggml/src/ggml-sycl/mmvq.cpp @@ -24,6 +24,7 @@ static void mul_mat_vec_q_reorder(const void * __restrict__ vx, const void * __r const int blocks_per_row = ncols / block_traits::qk; constexpr int blocks_per_subgroup = ceil_div(block_traits::vdr_mmvq * WARP_SIZE, block_traits::qi); constexpr int block_elements_per_subgroup = block_traits::qi / block_traits::vdr_mmvq; + const int nblocks = nrows * (ncols / block_traits::qk); static_assert(blocks_per_subgroup > 0); static_assert(block_elements_per_subgroup > 0); @@ -45,7 +46,7 @@ static void mul_mat_vec_q_reorder(const void * __restrict__ vx, const void * __r // x block quant index when casting the quants to int const int iqs = elem + block_traits::vdr_mmvq * (sg.get_local_linear_id() % block_elements_per_subgroup); - partial_sum += reorder_vec_dot_q_sycl()(vx, bx_offset, d_offset, &y[iby], iqs); + partial_sum += reorder_vec_dot_q_sycl()(vx, bx_offset, d_offset, &y[iby], iqs, nblocks); } } @@ -739,6 +740,27 @@ static void mul_mat_vec_q4_K_q8_1_sycl(const void *vx, const void *vy, } } +static void reorder_mul_mat_vec_q4_k_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, + const int nrows, dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + + const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y); + constexpr size_t num_subgroups = 16; + GGML_ASSERT(block_num_y % num_subgroups == 0); + + const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE); + const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE); + + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size), + [=](sycl::nd_item<3> nd_item) 
[[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_reorder>(vx, vy, dst, ncols, + nrows, nd_item); + }); + }); +} + + static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy, float *dst, const int ncols, const int nrows, @@ -1035,7 +1057,12 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens mul_mat_vec_q3_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); break; case GGML_TYPE_Q4_K: - mul_mat_vec_q4_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + if ((ggml_tensor_extra_gpu *) dst->src[0]->extra && + ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { + reorder_mul_mat_vec_q4_k_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } else { + mul_mat_vec_q4_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } break; case GGML_TYPE_Q5_K: mul_mat_vec_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); diff --git a/ggml/src/ggml-sycl/quants.hpp b/ggml/src/ggml-sycl/quants.hpp index a74e30526c1..88ec13ea269 100644 --- a/ggml/src/ggml-sycl/quants.hpp +++ b/ggml/src/ggml-sycl/quants.hpp @@ -56,6 +56,28 @@ template <> struct block_q_t { static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; } }; +template <> struct block_q_t { + struct traits { + static constexpr uint32_t qk = QK_K; + static constexpr uint32_t qi = QI4_K; + static constexpr uint32_t qr = QR4_K; + static constexpr uint32_t vdr_mmvq = 2; + }; + + static constexpr int get_block_offset(const int block_index) { return block_index * (traits::qk / traits::qr); } + + static constexpr int get_d_offset(int nrows, int ncols, const int block_index) { + auto nblocks = (nrows * (ncols / traits::qk)); + return (nblocks * QK_K / 2) + (nblocks * K_SCALE_SIZE) + (block_index * sizeof(ggml_half2)); + } + + static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; } + + constexpr size_t get_total_qs_bytes(int nblocks) { return nblocks * QK_K / 2; } + + constexpr size_t get_dm_offset(int nblocks) { return get_total_qs_bytes(nblocks) + nblocks * K_SCALE_SIZE; } +}; + } // namespace ggml_sycl_reordered #endif // GGML_SYCL_QUANTS_HPP diff --git a/ggml/src/ggml-sycl/vecdotq.hpp b/ggml/src/ggml-sycl/vecdotq.hpp index cbf664fcf28..ed369931346 100644 --- a/ggml/src/ggml-sycl/vecdotq.hpp +++ b/ggml/src/ggml-sycl/vecdotq.hpp @@ -285,7 +285,7 @@ template <> struct reorder_vec_dot_q_sycl { } __dpct_inline__ float operator()(const void * __restrict__ vbq, const int ibx_offset, const int d_offset, - const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + const block_q8_1 * __restrict__ bq8_1, const int & iqs, int /* nblocks */) { const uint8_t * bq4_0 = static_cast(vbq) + ibx_offset; const ggml_half d = *(reinterpret_cast(static_cast(vbq) + d_offset)); int v[q4_0_traits::vdr_mmvq]; @@ -303,6 +303,67 @@ template <> struct reorder_vec_dot_q_sycl { }; }; +static inline float vec_dot_q4_K_q8_1_common(const int * __restrict__ q4, const uint16_t * __restrict__ scales, + const ggml_half2 & dm, const block_q8_1 * __restrict__ bq8_1, + const int & iqs) { + int v[2]; + int u[2 * QR4_K]; + float d8[QR4_K]; + + v[0] = q4[0]; + v[1] = q4[4]; + + uint16_t aux[2]; + const int j = (QR4_K * ((iqs / 2) / (QI8_1 / 2))) / 2; + if (j < 2) { + aux[0] = scales[j + 0] & 0x3f3f; + aux[1] = scales[j + 2] & 0x3f3f; + } else { + aux[0] = ((scales[j + 2] >> 0) & 0x0f0f) | ((scales[j - 2] & 0xc0c0) >> 2); + aux[1] = ((scales[j + 2] >> 4) & 0x0f0f) | 
((scales[j - 0] & 0xc0c0) >> 2); + } + + const uint8_t * sc = (const uint8_t *) aux; + const uint8_t * m = sc + 2; + + const int bq8_offset = QR4_K * ((iqs / 2) / (QI8_1 / 2)); + + for (int i = 0; i < QR4_K; ++i) { + const block_q8_1 * bq8i = bq8_1 + bq8_offset + i; + d8[i] = bq8i->ds[0]; + + const int * q8 = (const int *) bq8i->qs + ((iqs / 2) % 4); + u[2 * i + 0] = q8[0]; + u[2 * i + 1] = q8[4]; + } + + return vec_dot_q4_K_q8_1_impl_vmmq(v, u, sc, m, dm, d8); +} + +template <> struct reorder_vec_dot_q_sycl { + static constexpr ggml_type gtype = GGML_TYPE_Q4_K; + + using q4_k_block = ggml_sycl_reordered::block_q_t; + using q4_k_traits = typename q4_k_block::traits; + + float operator()(const void * __restrict__ vbq, const int ibx_offset, const int d_offset, + const block_q8_1 * __restrict__ bq8_1, const int & iqs, int nblocks) { + const int ib = ibx_offset / (QK_K / 2); + + const uint8_t * base = static_cast(vbq); + const uint8_t * qs = base + ibx_offset; + const int total_qs_bytes = nblocks * (QK_K / 2); + const uint8_t * scs = base + total_qs_bytes + ib * K_SCALE_SIZE; + const ggml_half2 * dms = reinterpret_cast(base + d_offset); + + const int bq8_offset = QR4_K * ((iqs / 2) / (QI8_1 / 2)); + const int * q4 = (const int *) (qs + 16 * bq8_offset + 4 * ((iqs / 2) % 4)); + const uint16_t * scales = (const uint16_t *) scs; + + return vec_dot_q4_K_q8_1_common(q4, scales, *dms, bq8_1, iqs); + } +}; + #define VDR_Q4_0_Q8_1_MMVQ 2 #define VDR_Q4_0_Q8_1_MMQ 4 @@ -649,52 +710,17 @@ vec_dot_q3_K_q8_1(const void *__restrict__ vbq, return vec_dot_q3_K_q8_1_impl_mmvq(vl, vh, u, bq3_K->scales, scale_offset, d, d8); } -static __dpct_inline__ float -vec_dot_q4_K_q8_1(const void *__restrict__ vbq, - const block_q8_1 *__restrict__ bq8_1, const int &iqs) { - +static __dpct_inline__ float vec_dot_q4_K_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, + const int & iqs) { #ifndef GGML_QKK_64 - const block_q4_K * bq4_K = (const block_q4_K *) vbq; - - int v[2]; - int u[2*QR4_K]; - float d8[QR4_K]; - // iqs is in 0,2..30. 
bq8_offset = iqs/4 -> bq8_offset = 0, 2, 4, 6 - const int bq8_offset = QR4_K * ((iqs/2) / (QI8_1/2)); - - // iqs = 0....3 -> bq8_offset = 0, want q4_offset = 0, 4, 8, 12 - // iqs = 4....7 -> bq8_offset = 2, want q4_offset = 32, 36, 40, 44 - // iqs = 8...11 -> bq8_offset = 4, want q4_offset = 64, 68, 72, 76 - // iqs = 12..15 -> bq8_offset = 6, want q4_offset = 96, 100, 104, 108 - - const int * q4 = (const int *)(bq4_K->qs + 16 * bq8_offset + 4 * ((iqs/2)%4)); - v[0] = q4[0]; - v[1] = q4[4]; - - const uint16_t * scales = (const uint16_t *)bq4_K->scales; - uint16_t aux[2]; - const int j = bq8_offset/2; - if (j < 2) { - aux[0] = scales[j+0] & 0x3f3f; - aux[1] = scales[j+2] & 0x3f3f; - } else { - aux[0] = ((scales[j+2] >> 0) & 0x0f0f) | ((scales[j-2] & 0xc0c0) >> 2); - aux[1] = ((scales[j+2] >> 4) & 0x0f0f) | ((scales[j-0] & 0xc0c0) >> 2); - } - const uint8_t * sc = (const uint8_t *)aux; - const uint8_t * m = sc + 2; - - for (int i = 0; i < QR4_K; ++i) { - const block_q8_1 * bq8i = bq8_1 + bq8_offset + i; - d8[i] = bq8i->ds[0]; + const block_q4_K * bq4_K = (const block_q4_K *) vbq; - const int * q8 = (const int *)bq8i->qs + ((iqs/2)%4); - u[2*i+0] = q8[0]; - u[2*i+1] = q8[4]; - } + const int bq8_offset = QR4_K * ((iqs / 2) / (QI8_1 / 2)); + const int * q4 = (const int *) (bq4_K->qs + 16 * bq8_offset + 4 * ((iqs / 2) % 4)); + const uint16_t * scales = (const uint16_t *) bq4_K->scales; - return vec_dot_q4_K_q8_1_impl_vmmq(v, u, sc, m, bq4_K->dm, d8); + return vec_dot_q4_K_q8_1_common(q4, scales, bq4_K->dm, bq8_1, iqs); #else From 4d8d4f87d11009594b08f032e715621a4cd732bc Mon Sep 17 00:00:00 2001 From: Atharva Dubey Date: Thu, 15 May 2025 16:39:52 +0100 Subject: [PATCH 15/24] sycl: simplify bin_bcast_kernel (llama/13383) --- ggml/src/ggml-sycl/binbcast.cpp | 353 +++++++++++--------------------- 1 file changed, 121 insertions(+), 232 deletions(-) diff --git a/ggml/src/ggml-sycl/binbcast.cpp b/ggml/src/ggml-sycl/binbcast.cpp index 0a9d3a927c2..aaa94176f16 100644 --- a/ggml/src/ggml-sycl/binbcast.cpp +++ b/ggml/src/ggml-sycl/binbcast.cpp @@ -1,93 +1,74 @@ #include "binbcast.hpp" +#include #include #include #include +#include "dpct/helper.hpp" #include "ggml.h" -template -static void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst, - int ne0, int ne1, int ne2, int ne3, - int ne10, int ne11, int ne12, int ne13, - /*int s0, */ int s1, int s2, int s3, - /*int s00,*/ int s01, int s02, int s03, - /*int s10,*/ int s11, int s12, int s13, - const sycl::nd_item<3> &item_ct1) { - const int i0s = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - const int i1 = (item_ct1.get_local_range(1) * item_ct1.get_group(1) + - item_ct1.get_local_id(1)); - const int i2 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) + - item_ct1.get_local_id(0)) / - ne3; - const int i3 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) + - item_ct1.get_local_id(0)) % - ne3; - - if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) { - return; - } - - const int i11 = i1 % ne11; - const int i12 = i2 % ne12; - const int i13 = i3 % ne13; - - const size_t i_src0 = i3*s03 + i2*s02 + i1*s01; - const size_t i_src1 = i13*s13 + i12*s12 + i11*s11; - const size_t i_dst = i3*s3 + i2*s2 + i1*s1; - - const src0_t * src0_row = src0 + i_src0; - const src1_t * src1_row = src1 + i_src1; - dst_t * dst_row = dst + i_dst; - - for (int i0 = i0s; i0 < ne0; - i0 += item_ct1.get_local_range(2) * item_ct1.get_group_range(2)) { - const int i10 = i0 % ne10; - dst_row[i0] = (dst_t)bin_op(src0 ? 
(float)src0_row[i0] : 0.0f, (float)src1_row[i10]); +template +static __dpct_inline__ void k_bin_bcast_contiguous(const src0_t * __restrict__ src0, const src1_t * __restrict__ src1, + dst_t * dst, std::size_t num_elements, const sycl::nd_item<1> & it) { + auto element_id = it.get_global_id(0); + auto global_range = it.get_global_range(0); + for (; element_id < num_elements; element_id += global_range) { + auto src0_float_val = sycl::vec(src0[element_id]).template convert(); + auto src1_float_val = sycl::vec(src1[element_id]).template convert(); + float dst_val = bin_op(src0_float_val[0], src1_float_val[0]); + auto val_to_store = sycl::vec(dst_val).template convert(); + dst[element_id] = val_to_store; } } -template -static void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst, - int ne0, int ne1, int ne2, int ne3, - int ne10, int ne11, int ne12, int ne13, - /*int s0, */ int s1, int s2, int s3, - /*int s00,*/ int s01, int s02, int s03, - /*int s10,*/ int s11, int s12, int s13, - const sycl::nd_item<3> &item_ct1) { - - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - - const int i3 = i/(ne2*ne1*ne0); - const int i2 = (i/(ne1*ne0)) % ne2; - const int i1 = (i/ne0) % ne1; - const int i0 = i % ne0; - - if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) { - return; +template +static __dpct_inline__ void k_bin_bcast(const src0_t * __restrict__ src0, const src1_t * __restrict__ src1, dst_t * dst, + int ne0, int ne1, int ne2, int ne3, int ne10, int ne11, int ne12, int ne13, + int s0, int s1, int s2, int s3, int s00, int s01, int s02, int s03, int s10, + int s11, int s12, int s13, std::size_t num_dst_elements, + const sycl::nd_item<1> & item_ct1) { + auto calculate_logical_index = + [](const std::array & dims, std::size_t element_id) __attribute__((always_inline))->std::array { + std::array logical_index; +#pragma unroll(4) + for (int i = 3; i >= 0; i--) { + logical_index[i] = element_id % dims[i]; + element_id /= dims[i]; + } + return logical_index; + }; + + auto calculate_index = [](const std::array & dims, const std::array & strides, + const std::array & indices) __attribute__((always_inline)) + ->std::size_t { + std::size_t index = 0; +#pragma unroll(4) + for (int i = 0; i < 4; i++) { + auto index_i = indices[i]; + if (indices[i] >= dims[i]) { + index_i = indices[i] % dims[i]; + } + index += strides[i] * index_i; + } + return index; + }; + + auto element_id = item_ct1.get_global_id(0); + for (; element_id < num_dst_elements; element_id += item_ct1.get_global_range(0)) { + auto logical_index = calculate_logical_index({ ne3, ne2, ne1, ne0 }, element_id); + auto src_0_index = calculate_index({ ne3, ne2, ne1, ne0 }, { s03, s02, s01, s00 }, logical_index); + auto src_1_index = calculate_index({ ne13, ne12, ne11, ne10 }, { s13, s12, s11, s10 }, logical_index); + auto dst_index = calculate_index({ ne3, ne2, ne1, ne0 }, { s3, s2, s1, s0 }, logical_index); + auto src0_float_val = sycl::vec(src0[src_0_index]).template convert(); + auto src1_float_val = sycl::vec(src1[src_1_index]).template convert(); + float dst_val = bin_op(src0_float_val[0], src1_float_val[0]); + auto val_to_store = sycl::vec(dst_val).template convert(); + dst[dst_index] = val_to_store; } - - const int i11 = i1 % ne11; - const int i12 = i2 % ne12; - const int i13 = i3 % ne13; - - const size_t i_src0 = i3*s03 + i2*s02 + i1*s01; - const size_t i_src1 = i13*s13 + i12*s12 + i11*s11; - const size_t i_dst = i3*s3 + i2*s2 + i1*s1; - - const src0_t * src0_row = src0 + 
i_src0; - const src1_t * src1_row = src1 + i_src1; - dst_t * dst_row = dst + i_dst; - - const int i10 = i0 % ne10; - dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]); } - -template -struct bin_bcast_sycl { +template struct bin_bcast_sycl { template void operator()(const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, const int64_t ne10, const int64_t ne11, @@ -96,165 +77,73 @@ struct bin_bcast_sycl { const size_t nb10, const size_t nb11, const size_t nb12, const size_t nb13, const size_t nb0, const size_t nb1, const size_t nb2, const size_t nb3, const bool src0_is_contiguous, const bool src1_is_contiguous, const bool dst_is_contiguous, queue_ptr stream) { - int nr0 = ne10 / ne0; - int nr1 = ne11/ne1; - int nr2 = ne12/ne2; - int nr3 = ne13/ne3; - - int nr[4] = { nr0, nr1, nr2, nr3 }; - - // collapse dimensions until first broadcast dimension - int64_t cne[] = {ne0, ne1, ne2, ne3}; - int64_t cne0[] = {ne00, ne01, ne02, ne03}; - int64_t cne1[] = {ne10, ne11, ne12, ne13}; - size_t cnb[] = {nb0, nb1, nb2, nb3}; - size_t cnb0[] = {nb00, nb01, nb02, nb03}; - size_t cnb1[] = {nb10, nb11, nb12, nb13}; - auto collapse = [](int64_t cne[]) { - cne[0] *= cne[1]; - cne[1] = cne[2]; - cne[2] = cne[3]; - cne[3] = 1; - }; - - auto collapse_nb = [](size_t cnb[], int64_t cne[]) { - cnb[1] *= cne[1]; - cnb[2] *= cne[2]; - cnb[3] *= cne[3]; - }; - - if (src0_is_contiguous && src1_is_contiguous && dst_is_contiguous) { + auto check_bcast_required = [](const std::array & src_dims, + const std::array & dst_dims) -> bool { for (int i = 0; i < 4; i++) { - if (nr[i] != 1) { - break; - } - if (i > 0) { - collapse_nb(cnb, cne); - collapse_nb(cnb0, cne0); - collapse_nb(cnb1, cne1); - collapse(cne); - collapse(cne0); - collapse(cne1); + if (dst_dims[i] > src_dims[i]) { + return true; } } - } - { - int64_t ne0 = cne[0]; - int64_t ne1 = cne[1]; - int64_t ne2 = cne[2]; - int64_t ne3 = cne[3]; - - int64_t ne10 = cne1[0]; - int64_t ne11 = cne1[1]; - int64_t ne12 = cne1[2]; - int64_t ne13 = cne1[3]; - - size_t nb0 = cnb[0]; - size_t nb1 = cnb[1]; - size_t nb2 = cnb[2]; - size_t nb3 = cnb[3]; - - size_t nb00 = cnb0[0]; - size_t nb01 = cnb0[1]; - size_t nb02 = cnb0[2]; - size_t nb03 = cnb0[3]; - - size_t nb10 = cnb1[0]; - size_t nb11 = cnb1[1]; - size_t nb12 = cnb1[2]; - size_t nb13 = cnb1[3]; - - size_t s0 = nb0 / sizeof(dst_t); - size_t s1 = nb1 / sizeof(dst_t); - size_t s2 = nb2 / sizeof(dst_t); - size_t s3 = nb3 / sizeof(dst_t); - - size_t s10 = nb10 / sizeof(src1_t); - size_t s11 = nb11 / sizeof(src1_t); - size_t s12 = nb12 / sizeof(src1_t); - size_t s13 = nb13 / sizeof(src1_t); - - size_t s00 = nb00 / sizeof(src0_t); - size_t s01 = nb01 / sizeof(src0_t); - size_t s02 = nb02 / sizeof(src0_t); - size_t s03 = nb03 / sizeof(src0_t); - - GGML_UNUSED(s00); - - GGML_ASSERT(nb0 % sizeof(dst_t) == 0); - GGML_ASSERT(nb1 % sizeof(dst_t) == 0); - GGML_ASSERT(nb2 % sizeof(dst_t) == 0); - GGML_ASSERT(nb3 % sizeof(dst_t) == 0); - - GGML_ASSERT(nb00 % sizeof(src0_t) == 0); - GGML_ASSERT(nb01 % sizeof(src0_t) == 0); - GGML_ASSERT(nb02 % sizeof(src0_t) == 0); - GGML_ASSERT(nb03 % sizeof(src0_t) == 0); - - GGML_ASSERT(nb10 % sizeof(src1_t) == 0); - GGML_ASSERT(nb11 % sizeof(src1_t) == 0); - GGML_ASSERT(nb12 % sizeof(src1_t) == 0); - GGML_ASSERT(nb13 % sizeof(src1_t) == 0); - - GGML_ASSERT(s0 == 1); - GGML_ASSERT(s10 == 1); - - const int block_size = 128; - - int64_t hne0 = std::max(ne0/2LL, 1LL); - - sycl::range<3> 
block_dims(1, 1, 1); - block_dims[2] = std::min(hne0, block_size); - block_dims[1] = std::min( - ne1, block_size / (unsigned int)block_dims[2]); - block_dims[0] = std::min( - std::min( - ne2 * ne3, block_size / (unsigned int)block_dims[2] / - (unsigned int)block_dims[1]), - 64U); - - sycl::range<3> block_nums( - (ne2 * ne3 + block_dims[0] - 1) / block_dims[0], - (ne1 + block_dims[1] - 1) / block_dims[1], - (hne0 + block_dims[2] - 1) / block_dims[2]); - - if (block_nums[0] > 65535) { - // this is the maximum number of blocks in z direction, fallback to 1D grid kernel - int block_num = (ne0*ne1*ne2*ne3 + block_size - 1) / block_size; - { - dpct::has_capability_or_fail(stream->get_device(), - {sycl::aspect::fp16}); - - stream->parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, block_num) * - sycl::range<3>(1, 1, block_size), - sycl::range<3>(1, 1, block_size)), - [=](sycl::nd_item<3> item_ct1) { - k_bin_bcast_unravel( - src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3, - ne10, ne11, ne12, ne13, s1, s2, s3, s01, s02, - s03, s11, s12, s13, item_ct1); - }); - } - } else { - /* - DPCT1049:16: The work-group size passed to the SYCL kernel may - exceed the limit. To get the device limit, query - info::device::max_work_group_size. Adjust the work-group size if - needed. - */ - dpct::has_capability_or_fail(stream->get_device(), - {sycl::aspect::fp16}); - - stream->parallel_for( - sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) { - k_bin_bcast(src0_dd, src1_dd, dst_dd, ne0, ne1, - ne2, ne3, ne10, ne11, ne12, ne13, - s1, s2, s3, s01, s02, s03, s11, s12, s13, - item_ct1); - }); - } + return false; + }; + + dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 }); + + GGML_ASSERT(nb0 % sizeof(dst_t) == 0); + GGML_ASSERT(nb1 % sizeof(dst_t) == 0); + GGML_ASSERT(nb2 % sizeof(dst_t) == 0); + GGML_ASSERT(nb3 % sizeof(dst_t) == 0); + + GGML_ASSERT(nb00 % sizeof(src0_t) == 0); + GGML_ASSERT(nb01 % sizeof(src0_t) == 0); + GGML_ASSERT(nb02 % sizeof(src0_t) == 0); + GGML_ASSERT(nb03 % sizeof(src0_t) == 0); + + GGML_ASSERT(nb10 % sizeof(src1_t) == 0); + GGML_ASSERT(nb11 % sizeof(src1_t) == 0); + GGML_ASSERT(nb12 % sizeof(src1_t) == 0); + GGML_ASSERT(nb13 % sizeof(src1_t) == 0); + + // dst strides in number of elements + size_t s0 = nb0 / sizeof(dst_t); + size_t s1 = nb1 / sizeof(dst_t); + size_t s2 = nb2 / sizeof(dst_t); + size_t s3 = nb3 / sizeof(dst_t); + + // src1 strides in number of elements + size_t s10 = nb10 / sizeof(src0_t); + size_t s11 = nb11 / sizeof(src1_t); + size_t s12 = nb12 / sizeof(src1_t); + size_t s13 = nb13 / sizeof(src1_t); + + // src0 strides in number of elements + size_t s00 = nb00 / sizeof(src0_t); + size_t s01 = nb01 / sizeof(src0_t); + size_t s02 = nb02 / sizeof(src0_t); + size_t s03 = nb03 / sizeof(src0_t); + + std::size_t num_dst_elements = static_cast(ne0) * static_cast(ne1) * + static_cast(ne2) * static_cast(ne3); + std::size_t local_range = 256; + std::size_t global_range = ceil_div(num_dst_elements, local_range) * local_range; + + bool needs_broadcasting = check_bcast_required({ ne00, ne01, ne02, ne03 }, { ne0, ne1, ne2, ne3 }) || + check_bcast_required({ ne10, ne11, ne12, ne13 }, { ne0, ne1, ne2, ne3 }); + bool all_contiguous = src0_is_contiguous && src1_is_contiguous && dst_is_contiguous; + + if (! 
needs_broadcasting && all_contiguous) { + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for(sycl::nd_range<1>({ global_range }, { local_range }), [=](sycl::nd_item<1> it) { + k_bin_bcast_contiguous(src0_dd, src1_dd, dst_dd, num_dst_elements, it); + }); + }); + } else { + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for(sycl::nd_range<1>({ global_range }, { local_range }), [=](sycl::nd_item<1> it) { + k_bin_bcast(src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3, ne10, ne11, ne12, ne13, s0, s1, + s2, s3, s00, s01, s02, s03, s10, s11, s12, s13, num_dst_elements, it); + }); + }); } } }; From 577402deebb98b0f9437741a8f8684bbb750643b Mon Sep 17 00:00:00 2001 From: Diego Devesa Date: Thu, 15 May 2025 10:13:11 -0700 Subject: [PATCH 16/24] gguf : use ggml log system (llama/13571) * gguf : use ggml log system * llama : remove unnecessary new lines in exception messages --- ggml/src/gguf.cpp | 66 +++++++++++++++++++++++------------------------ 1 file changed, 33 insertions(+), 33 deletions(-) diff --git a/ggml/src/gguf.cpp b/ggml/src/gguf.cpp index 381a9c7dcbe..8667a80bd06 100644 --- a/ggml/src/gguf.cpp +++ b/ggml/src/gguf.cpp @@ -299,10 +299,10 @@ bool gguf_read_emplace_helper(const struct gguf_reader & gr, std::vectorversion)) { if (ctx->version == 1) { - fprintf(stderr, "%s: GGUFv1 is no longer supported, please use a more up-to-date version\n", __func__); + GGML_LOG_ERROR("%s: GGUFv1 is no longer supported, please use a more up-to-date version\n", __func__); ok = false; } if (ctx->version > GGUF_VERSION) { - fprintf(stderr, "%s: this GGUF file is version %" PRIu32 " but this software only supports up to version %d\n", + GGML_LOG_ERROR("%s: this GGUF file is version %" PRIu32 " but this software only supports up to version %d\n", __func__, ctx->version, GGUF_VERSION); ok = false; } @@ -363,7 +363,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par if (ok && gr.read(n_tensors)) { static_assert(sizeof(size_t) <= 8 && sizeof(gguf_tensor_info) >= 2, "int64_t insufficient for indexing"); if (n_tensors < 0 || n_tensors > int64_t(SIZE_MAX/sizeof(gguf_tensor_info))) { - fprintf(stderr, "%s: number of tensors is %" PRIi64 " but must be in [0, %zu]\n", + GGML_LOG_ERROR("%s: number of tensors is %" PRIi64 " but must be in [0, %zu]\n", __func__, n_tensors, SIZE_MAX/sizeof(gguf_tensor_info)); ok = false; } @@ -374,7 +374,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par if (ok && gr.read(n_kv)) { static_assert(sizeof(size_t) <= 8 && sizeof(gguf_tensor_info) >= 2, "int64_t insufficient for indexing"); if (n_kv < 0 || n_kv > int64_t(SIZE_MAX/sizeof(gguf_kv))) { - fprintf(stderr, "%s: number of key value pairs is %" PRIi64 " but must be in [0, %zu]\n", + GGML_LOG_ERROR("%s: number of key value pairs is %" PRIi64 " but must be in [0, %zu]\n", __func__, n_kv, SIZE_MAX/sizeof(gguf_kv)); ok = false; } @@ -383,7 +383,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par } if (!ok) { - fprintf(stderr, "%s: failed to read header\n", __func__); + GGML_LOG_ERROR("%s: failed to read header\n", __func__); gguf_free(ctx); return nullptr; } @@ -399,15 +399,15 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par try { ok = ok && gr.read(key); } catch (std::length_error &) { - fprintf(stderr, "%s: encountered length_error while reading key %" PRIi64 "\n", __func__, i); + GGML_LOG_ERROR("%s: encountered length_error while reading key %" PRIi64 "\n", __func__, i); ok = false; } 
catch (std::bad_alloc &) { - fprintf(stderr, "%s: encountered bad_alloc error while reading key %" PRIi64 "\n", __func__, i); + GGML_LOG_ERROR("%s: encountered bad_alloc error while reading key %" PRIi64 "\n", __func__, i); ok = false; } for (size_t j = 0; ok && j < ctx->kv.size(); ++j) { if (key == ctx->kv[j].key) { - fprintf(stderr, "%s: duplicate key '%s' for tensors %zu and %" PRIi64 " \n", __func__, key.c_str(), j, i); + GGML_LOG_ERROR("%s: duplicate key '%s' for tensors %zu and %" PRIi64 " \n", __func__, key.c_str(), j, i); ok = false; } } @@ -441,14 +441,14 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par case GGUF_TYPE_ARRAY: default: { - fprintf(stderr, "%s: key '%s' has invalid GGUF type %d\n", __func__, key.c_str(), type); + GGML_LOG_ERROR("%s: key '%s' has invalid GGUF type %d\n", __func__, key.c_str(), type); ok = false; } break; } } if (!ok) { - fprintf(stderr, "%s: failed to read key-value pairs\n", __func__); + GGML_LOG_ERROR("%s: failed to read key-value pairs\n", __func__); gguf_free(ctx); return nullptr; } @@ -458,7 +458,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par ctx->alignment = alignment_idx == -1 ? GGUF_DEFAULT_ALIGNMENT : gguf_get_val_u32(ctx, alignment_idx); if (ctx->alignment == 0 || (ctx->alignment & (ctx->alignment - 1)) != 0) { - fprintf(stderr, "%s: alignment %zu is not a power of 2\n", __func__, ctx->alignment); + GGML_LOG_ERROR("%s: alignment %zu is not a power of 2\n", __func__, ctx->alignment); gguf_free(ctx); return nullptr; } @@ -474,14 +474,14 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par try { ok = ok && gr.read(name); } catch (std::length_error &) { - fprintf(stderr, "%s: encountered length_error while reading tensor name %" PRIi64 "\n", __func__, i); + GGML_LOG_ERROR("%s: encountered length_error while reading tensor name %" PRIi64 "\n", __func__, i); ok = false; } catch (std::bad_alloc &) { - fprintf(stderr, "%s: encountered bad_alloc error while reading tensor name %" PRIi64 "\n", __func__, i); + GGML_LOG_ERROR("%s: encountered bad_alloc error while reading tensor name %" PRIi64 "\n", __func__, i); ok = false; } if (name.length() >= GGML_MAX_NAME) { - fprintf(stderr, "%s: tensor name %" PRIi64 " is too long: %zu >= %d\n", __func__, i, name.length(), GGML_MAX_NAME); + GGML_LOG_ERROR("%s: tensor name %" PRIi64 " is too long: %zu >= %d\n", __func__, i, name.length(), GGML_MAX_NAME); ok = false; break; } @@ -490,7 +490,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par // make sure there are no duplicate tensor names for (int64_t j = 0; ok && j < i; ++j) { if (strcmp(info.t.name, ctx->info[j].t.name) == 0) { - fprintf(stderr, "%s: duplicate tensor name '%s' for tensors %" PRIi64 " and %" PRIi64 "\n", __func__, info.t.name, j, i); + GGML_LOG_ERROR("%s: duplicate tensor name '%s' for tensors %" PRIi64 " and %" PRIi64 "\n", __func__, info.t.name, j, i); ok = false; break; } @@ -505,7 +505,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par uint32_t n_dims = -1; ok = ok && gr.read(n_dims); if (n_dims > GGML_MAX_DIMS) { - fprintf(stderr, "%s: tensor '%s' has invalid number of dimensions: %" PRIu32 " > %" PRIu32 "\n", + GGML_LOG_ERROR("%s: tensor '%s' has invalid number of dimensions: %" PRIu32 " > %" PRIu32 "\n", __func__, info.t.name, n_dims, GGML_MAX_DIMS); ok = false; break; @@ -518,7 +518,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par 
// check that all ne are non-negative if (info.t.ne[j] < 0) { - fprintf(stderr, "%s: tensor '%s' dimension %" PRIu32 " has invalid number of elements: %" PRIi64 " < 0\n", + GGML_LOG_ERROR("%s: tensor '%s' dimension %" PRIu32 " has invalid number of elements: %" PRIi64 " < 0\n", __func__, info.t.name, j, info.t.ne[j]); ok = false; break; @@ -530,7 +530,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par (INT64_MAX/info.t.ne[2] <= info.t.ne[0]*info.t.ne[1]) || (INT64_MAX/info.t.ne[3] <= info.t.ne[0]*info.t.ne[1]*info.t.ne[2]))) { - fprintf(stderr, "%s: total number of elements in tensor '%s' with shape " + GGML_LOG_ERROR("%s: total number of elements in tensor '%s' with shape " "(%" PRIi64 ", %" PRIi64 ", %" PRIi64 ", %" PRIi64 ") is >= %" PRIi64 "\n", __func__, info.t.name, info.t.ne[0], info.t.ne[1], info.t.ne[2], info.t.ne[3], INT64_MAX); ok = false; @@ -547,7 +547,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par // check that tensor type is within defined range if (info.t.type < 0 || info.t.type >= GGML_TYPE_COUNT) { - fprintf(stderr, "%s: tensor '%s' has invalid ggml type %d (%s)\n", + GGML_LOG_ERROR("%s: tensor '%s' has invalid ggml type %d (%s)\n", __func__, info.t.name, info.t.type, ggml_type_name(info.t.type)); ok = false; break; @@ -557,7 +557,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par // check that row size is divisible by block size if (blck_size == 0 || info.t.ne[0] % blck_size != 0) { - fprintf(stderr, "%s: tensor '%s' of type %d (%s) has %" PRId64 " elements per row, " + GGML_LOG_ERROR("%s: tensor '%s' of type %d (%s) has %" PRId64 " elements per row, " "not a multiple of block size (%" PRId64 ")\n", __func__, info.t.name, (int) info.t.type, ggml_type_name(info.t.type), info.t.ne[0], blck_size); ok = false; @@ -582,7 +582,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par } if (!ok) { - fprintf(stderr, "%s: failed to read tensor info\n", __func__); + GGML_LOG_ERROR("%s: failed to read tensor info\n", __func__); gguf_free(ctx); return nullptr; } @@ -590,7 +590,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par // we require the data section to be aligned, so take into account any padding if (fseek(file, GGML_PAD(ftell(file), ctx->alignment), SEEK_SET) != 0) { - fprintf(stderr, "%s: failed to seek to beginning of data section\n", __func__); + GGML_LOG_ERROR("%s: failed to seek to beginning of data section\n", __func__); gguf_free(ctx); return nullptr; } @@ -604,9 +604,9 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par for (size_t i = 0; i < ctx->info.size(); ++i) { const gguf_tensor_info & ti = ctx->info[i]; if (ti.offset != ctx->size) { - fprintf(stderr, "%s: tensor '%s' has offset %" PRIu64 ", expected %zu\n", + GGML_LOG_ERROR("%s: tensor '%s' has offset %" PRIu64 ", expected %zu\n", __func__, ti.t.name, ti.offset, ctx->size); - fprintf(stderr, "%s: failed to read tensor data\n", __func__); + GGML_LOG_ERROR("%s: failed to read tensor data\n", __func__); gguf_free(ctx); return nullptr; } @@ -634,7 +634,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par *params.ctx = ggml_init(pdata); if (*params.ctx == nullptr) { - fprintf(stderr, "%s: failed to initialize ggml context for storing tensors\n", __func__); + GGML_LOG_ERROR("%s: failed to initialize ggml context for storing tensors\n", __func__); gguf_free(ctx); return nullptr; } @@ 
-656,7 +656,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par ok = ok && gr.read(data->data, ctx->size); if (!ok) { - fprintf(stderr, "%s: failed to read tensor data binary blob\n", __func__); + GGML_LOG_ERROR("%s: failed to read tensor data binary blob\n", __func__); ggml_free(ctx_data); *params.ctx = nullptr; gguf_free(ctx); @@ -689,7 +689,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par } if (!ok) { - fprintf(stderr, "%s: failed to create tensors\n", __func__); + GGML_LOG_ERROR("%s: failed to create tensors\n", __func__); ggml_free(ctx_data); *params.ctx = nullptr; gguf_free(ctx); @@ -706,7 +706,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p FILE * file = ggml_fopen(fname, "rb"); if (!file) { - fprintf(stderr, "%s: failed to open GGUF file '%s'\n", __func__, fname); + GGML_LOG_ERROR("%s: failed to open GGUF file '%s'\n", __func__, fname); return nullptr; } @@ -1305,7 +1305,7 @@ bool gguf_write_to_file(const struct gguf_context * ctx, const char * fname, boo FILE * file = ggml_fopen(fname, "wb"); if (!file) { - fprintf(stderr, "%s: failed to open file '%s' for writing GGUF data\n", __func__, fname); + GGML_LOG_ERROR("%s: failed to open file '%s' for writing GGUF data\n", __func__, fname); return false; } From 8bfb65c02eefadc7f92150bf067efe08672de235 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=81ukasz=20=C5=9Alusarczyk?= <112692748+lslusarczyk@users.noreply.github.com> Date: Fri, 16 May 2025 12:15:29 +0200 Subject: [PATCH 17/24] sycl : fixed compilation warnings (llama/13582) --- ggml/src/ggml-sycl/element_wise.cpp | 23 ----------------------- 1 file changed, 23 deletions(-) diff --git a/ggml/src/ggml-sycl/element_wise.cpp b/ggml/src/ggml-sycl/element_wise.cpp index dcc6ec809a7..becaac4048a 100644 --- a/ggml/src/ggml-sycl/element_wise.cpp +++ b/ggml/src/ggml-sycl/element_wise.cpp @@ -655,7 +655,6 @@ inline void ggml_sycl_op_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) } default: GGML_ABORT("GGML tensor type not supported!\n"); - break; } } @@ -688,7 +687,6 @@ inline void ggml_sycl_op_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst) } default: GGML_ABORT("GGML tensor type not supported!\n"); - break; } } @@ -722,7 +720,6 @@ inline void ggml_sycl_op_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) } default: GGML_ABORT("GGML tensor type not supported!\n"); - break; } } @@ -754,7 +751,6 @@ inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst } default: GGML_ABORT("GGML tensor type not supported!\n"); - break; } } @@ -786,7 +782,6 @@ inline void ggml_sycl_op_gelu(ggml_backend_sycl_context & ctx, ggml_tensor * dst } default: GGML_ABORT("GGML tensor type not supported!\n"); - break; } } @@ -818,7 +813,6 @@ inline void ggml_sycl_op_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor } default: GGML_ABORT("GGML tensor type not supported!\n"); - break; } } @@ -850,7 +844,6 @@ inline void ggml_sycl_op_tanh(ggml_backend_sycl_context & ctx, ggml_tensor * dst } default: GGML_ABORT("GGML tensor type not supported!\n"); - break; } } @@ -883,7 +876,6 @@ inline void ggml_sycl_op_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst } default: GGML_ABORT("GGML tensor type not supported!\n"); - break; } } @@ -917,7 +909,6 @@ inline void ggml_sycl_op_hardsigmoid(ggml_backend_sycl_context & ctx, ggml_tenso } default: GGML_ABORT("GGML tensor type not supported!\n"); - break; } } @@ -949,7 +940,6 @@ inline void 
ggml_sycl_op_hardswish(ggml_backend_sycl_context & ctx, ggml_tensor } default: GGML_ABORT("GGML tensor type not supported!\n"); - break; } } @@ -981,7 +971,6 @@ inline void ggml_sycl_op_exp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) } default: GGML_ABORT("GGML tensor type not supported!\n"); - break; } } @@ -1013,7 +1002,6 @@ inline void ggml_sycl_op_log(ggml_backend_sycl_context & ctx, ggml_tensor * dst) } default: GGML_ABORT("GGML tensor type not supported!\n"); - break; } } @@ -1045,7 +1033,6 @@ inline void ggml_sycl_op_sigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * } default: GGML_ABORT("GGML tensor type not supported!\n"); - break; } } @@ -1078,7 +1065,6 @@ inline void ggml_sycl_op_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst } default: GGML_ABORT("GGML tensor type not supported!\n"); - break; } } @@ -1110,7 +1096,6 @@ inline void ggml_sycl_op_sin(ggml_backend_sycl_context & ctx, ggml_tensor * dst) } default: GGML_ABORT("GGML tensor type not supported!\n"); - break; } } @@ -1142,7 +1127,6 @@ inline void ggml_sycl_op_cos(ggml_backend_sycl_context & ctx, ggml_tensor * dst) } default: GGML_ABORT("GGML tensor type not supported!\n"); - break; } } @@ -1174,7 +1158,6 @@ inline void ggml_sycl_op_step(ggml_backend_sycl_context & ctx, ggml_tensor * dst } default: GGML_ABORT("GGML tensor type not supported!\n"); - break; } } @@ -1206,7 +1189,6 @@ inline void ggml_sycl_op_neg(ggml_backend_sycl_context & ctx, ggml_tensor * dst) } default: GGML_ABORT("GGML tensor type not supported!\n"); - break; } } @@ -1241,7 +1223,6 @@ inline void ggml_sycl_op_leaky_relu(ggml_backend_sycl_context & ctx, ggml_tensor } default: GGML_ABORT("GGML tensor type not supported!\n"); - break; } } @@ -1273,7 +1254,6 @@ inline void ggml_sycl_op_sqr(ggml_backend_sycl_context & ctx, ggml_tensor * dst) } default: GGML_ABORT("GGML tensor type not supported!\n"); - break; } } @@ -1315,7 +1295,6 @@ inline void ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * } default: GGML_ABORT("GGML tensor type not supported!\n"); - break; } } @@ -1350,7 +1329,6 @@ inline void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) } default: GGML_ABORT("GGML tensor type not supported!\n"); - break; } } @@ -1388,7 +1366,6 @@ inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * ds } default: GGML_ABORT("GGML tensor type not supported!\n"); - break; } } From 82f8e5fa8c0b44459297bb13dacbf697288cfcae Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 16 May 2025 20:32:58 +0300 Subject: [PATCH 18/24] metal : add FA-vec kernel for head size 64 (llama/13583) ggml-ci --- ggml/src/ggml-metal/ggml-metal.m | 34 +++++++++++++++++++++++++++- ggml/src/ggml-metal/ggml-metal.metal | 10 ++++++++ 2 files changed, 43 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index f4b3d9cf592..85dbbcd5d7f 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -415,6 +415,13 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H64, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H64, + 
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H64, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H64, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96, @@ -1362,6 +1369,13 @@ @implementation GGMLMetalClass GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128, flash_attn_ext_q8_0_hk192_hv128, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512, flash_attn_ext_q8_0_hk576_hv512, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64, flash_attn_ext_vec_f16_h64, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64, flash_attn_ext_vec_bf16_h64, has_simdgroup_reduction && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64, flash_attn_ext_vec_q4_0_h64, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H64, flash_attn_ext_vec_q4_1_h64, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H64, flash_attn_ext_vec_q5_0_h64, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H64, flash_attn_ext_vec_q5_1_h64, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H64, flash_attn_ext_vec_q8_0_h64, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96, flash_attn_ext_vec_f16_h96, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96, flash_attn_ext_vec_bf16_h96, has_simdgroup_reduction && use_bfloat); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96, flash_attn_ext_vec_q4_0_h96, has_simdgroup_reduction); @@ -4358,7 +4372,7 @@ static bool ggml_metal_encode_node( // TODO: add vec kernels for (ne00%64 == 0) and maybe also for (ne00%32 == 0) // for now avoiding mainly to keep the number of templates/kernels a bit lower // these are now trivial to add after: https://github.com/ggml-org/llama.cpp/pull/12612 - if (ne01 >= 20 || (ne00%128 != 0 && ne00 != 96 && ne00 != 192 && ne00 != 576)) { + if (ne01 >= 20 || (ne00%128 != 0 && ne00 != 64 && ne00 != 96 && ne00 != 192 && ne00 != 576)) { switch (src1->type) { case GGML_TYPE_F16: { @@ -4539,6 +4553,24 @@ static bool ggml_metal_encode_node( use_vec_kernel = true; switch (ne00) { + case 64: + { + switch (src1->type) { + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64].pipeline; break; + case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64].pipeline; break; + case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64].pipeline; break; + case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H64].pipeline; break; + case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H64].pipeline; break; + case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H64].pipeline; break; + case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H64].pipeline; break; + default: + { + GGML_LOG_ERROR("unsupported type: %d\n", 
src1->type); + GGML_LOG_ERROR("add template specialization for this type\n"); + GGML_ABORT("add template specialization for this type"); + } + } + } break; case 96: { switch (src1->type) { diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 122ae597371..e94b6cd7564 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -4124,6 +4124,16 @@ kernel void kernel_flash_attn_ext_vec( typedef decltype(kernel_flash_attn_ext_vec) flash_attn_ext_vec_t; +template [[host_name("kernel_flash_attn_ext_vec_f16_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#if defined(GGML_METAL_USE_BF16) +template [[host_name("kernel_flash_attn_ext_vec_bf16_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#endif +template [[host_name("kernel_flash_attn_ext_vec_q4_0_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; + template [[host_name("kernel_flash_attn_ext_vec_f16_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #if defined(GGML_METAL_USE_BF16) template [[host_name("kernel_flash_attn_ext_vec_bf16_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; From 798a95c3bc1c81b4a993fdc26d8193781bb0c453 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Sat, 17 May 2025 15:35:47 +0900 Subject: [PATCH 19/24] vulkan: use scalar FA rather than coopmat2 when N==1 (llama/13554) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 0856a112283..fe3669b462c 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -5872,10 +5872,17 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx vk_pipeline *pipelines; bool small_rows = N <= get_fa_num_small_rows(path); + // coopmat1 does not actually support "small rows" (it needs 16 rows). + // So use scalar instead. 
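    // In both demotions below, the scalar path is the safe fallback: it has no minimum-row
    // requirement, at the cost of always using f32 accumulation (see f32acc below).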
if (small_rows && path == FA_COOPMAT1) { path = FA_SCALAR; } + // scalar is faster than coopmat2 when N==1 + if (N == 1 && path == FA_COOPMAT2) { + path = FA_SCALAR; + } + bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32; switch (path) { From 5378cbb80390aaed374bf8419b5666965ab8fd57 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Sat, 17 May 2025 16:14:55 +0900 Subject: [PATCH 20/24] vulkan: move common FA code to flash_attn_base.comp (llama/13556) * vulkan: move common FA code to flash_attn_base.comp * vulkan: move common FA index/stride setup code to flash_attn_base.comp * build fix --- .../vulkan-shaders/flash_attn.comp | 153 +---------------- .../vulkan-shaders/flash_attn_base.comp | 162 ++++++++++++++++++ .../vulkan-shaders/flash_attn_cm1.comp | 152 +--------------- .../vulkan-shaders/flash_attn_cm2.comp | 120 +------------ 4 files changed, 170 insertions(+), 417 deletions(-) create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index 16835576814..ce230a8f7d9 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -9,60 +9,13 @@ #extension GL_KHR_shader_subgroup_shuffle : enable #include "types.comp" +#include "flash_attn_base.comp" -layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; - -layout (constant_id = 0) const uint32_t WorkGroupSize = 128; -layout (constant_id = 1) const uint32_t Br = 1; -layout (constant_id = 2) const uint32_t Bc = 32; -layout (constant_id = 3) const uint32_t D = 32; - -layout (constant_id = 5) const uint32_t D_split = 16; const uint32_t D_per_thread = D / D_split; const uint32_t cols_per_iter = WorkGroupSize / D_split; const uint32_t cols_per_thread = Bc / cols_per_iter; -layout (push_constant) uniform parameter { - uint32_t N; - uint32_t KV; - - uint32_t ne1; - uint32_t ne2; - uint32_t ne3; - - uint32_t neq2; - uint32_t neq3; - uint32_t nek2; - uint32_t nek3; - uint32_t nev2; - uint32_t nev3; - uint32_t nem1; - - uint32_t nb01; - uint32_t nb02; - uint32_t nb03; - uint32_t nb11; - uint32_t nb12; - uint32_t nb13; - uint32_t nb21; - uint32_t nb22; - uint32_t nb23; - uint32_t nb31; - - float scale; - float max_bias; - float logit_softcap; - - uint32_t mask; - uint32_t n_head_log2; - float m0; - float m1; - - uint32_t gqa_ratio; - uint32_t split_kv; - uint32_t k_num; -} p; layout (binding = 0) readonly buffer Q {float data_q[];}; layout (binding = 0) readonly buffer QV4 {vec4 data_qv4[];}; @@ -71,39 +24,6 @@ layout (binding = 1) readonly buffer KV4 {f16vec4 data_kv4[];}; layout (binding = 2) readonly buffer V {float16_t data_v[];}; layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];}; layout (binding = 3) readonly buffer M {float16_t data_m[];}; -layout (binding = 4) writeonly buffer O {D_TYPE data_o[];}; - -#if defined(A_TYPE_PACKED16) -#define BINDING_IDX_K 0 -#define BINDING_IDX_V 1 -layout (binding = 1) readonly buffer KV_PACKED16 {A_TYPE_PACKED16 data_packed16[];} kv_packed[2]; -#endif - -#if defined(DATA_A_Q4_0) -#define BLOCK_BYTE_SIZE 18 - -vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { - uint vui_lo = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); - uint vui_hi = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); - uint shift = (iqs & 0x10) >> 2; - vui_lo >>= shift; - vui_hi >>= shift; - - return 
float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f); -} -#endif - -#if defined(DATA_A_Q8_0) -#define BLOCK_BYTE_SIZE 34 -vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { - const i8vec2 v0 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147 - const i8vec2 v1 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy; - - return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y); -} -#endif - -#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b)) // Store the output when doing grouped query attention. // Rows index by Q's dimension 2, and the first N rows are valid. @@ -114,27 +34,6 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TY return elem; } -// Store column zero. This is used to save per-row m and L values for split_k. -ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) -{ - if (r < N && c == 0) { - uint32_t offset = iq2 + r; - data_o[o_offset + offset] = D_TYPE(elem); - } - return elem; -} - -// Load the slope matrix, indexed by Q's dimension 2. -ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2) -{ - const uint32_t h = iq2 + (r % p.gqa_ratio); - - const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1); - const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1); - - return ACC_TYPE(pow(base, ACC_TYPE(exph))); -} - shared FLOAT_TYPE tmpsh[WorkGroupSize]; shared vec4 tmpshv4[WorkGroupSize]; @@ -146,58 +45,12 @@ void main() { init_iq_shmem(gl_WorkGroupSize); #endif - const uint32_t tid = gl_LocalInvocationIndex; - const uint32_t N = p.N; - const uint32_t KV = p.KV; + init_indices(); + const uint32_t tid = gl_LocalInvocationIndex; const uint32_t d_tid = gl_LocalInvocationIndex % D_split; const uint32_t col_tid = gl_LocalInvocationIndex / D_split; - uint32_t i = gl_WorkGroupID.x; - uint32_t split_k_index = 0; - - if (p.k_num > 1) { - i = 0; - split_k_index = gl_WorkGroupID.x; - } - - const uint32_t Tr = CEIL_DIV(N, Br); - - const uint32_t start_j = split_k_index * p.split_kv / Bc; - const uint32_t end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc); - - // When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y. - // When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2. - const uint32_t iq2 = gl_WorkGroupID.y * p.gqa_ratio; - const uint32_t iq3 = gl_WorkGroupID.z; - - // broadcast factors - const uint32_t rk2 = p.neq2/p.nek2; - const uint32_t rk3 = p.neq3/p.nek3; - - const uint32_t rv2 = p.neq2/p.nev2; - const uint32_t rv3 = p.neq3/p.nev3; - - // k indices - const uint32_t ik3 = iq3 / rk3; - const uint32_t ik2 = iq2 / rk2; - - // v indices - const uint32_t iv3 = iq3 / rv3; - const uint32_t iv2 = iq2 / rv2; - - // nb?1 are already divided by the type size and are in units of elements. - // When using grouped query attention, Q is indexed by iq2, so the stride - // should be nb02 (which is in bytes). - uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01; - uint32_t k_stride = p.nb11; - uint32_t v_stride = p.nb21; - // When using grouped query attention, all rows use the same mask (stride 0). 
- // "p.gqa_ratio >> 16" is just a roundabout way of writing zero - // that prevents the compiler from folding the "&" through the select - // and breaking the alignment detection. - uint32_t m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV; - uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4; [[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp new file mode 100644 index 00000000000..61d90e2d8ed --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp @@ -0,0 +1,162 @@ + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (constant_id = 0) const uint32_t WorkGroupSize = 128; +layout (constant_id = 1) const uint32_t Br = 1; +layout (constant_id = 2) const uint32_t Bc = 32; +layout (constant_id = 3) const uint32_t D = 32; +layout (constant_id = 4) const uint32_t Clamp = 0; +layout (constant_id = 5) const uint32_t D_split = 16; + + +layout (push_constant) uniform parameter { + uint32_t N; + uint32_t KV; + + uint32_t ne1; + uint32_t ne2; + uint32_t ne3; + + uint32_t neq2; + uint32_t neq3; + uint32_t nek2; + uint32_t nek3; + uint32_t nev2; + uint32_t nev3; + uint32_t nem1; + + uint32_t nb01; + uint32_t nb02; + uint32_t nb03; + uint32_t nb11; + uint32_t nb12; + uint32_t nb13; + uint32_t nb21; + uint32_t nb22; + uint32_t nb23; + uint32_t nb31; + + float scale; + float max_bias; + float logit_softcap; + + uint32_t mask; + uint32_t n_head_log2; + float m0; + float m1; + + uint32_t gqa_ratio; + uint32_t split_kv; + uint32_t k_num; +} p; + +layout (binding = 4) writeonly buffer O {D_TYPE data_o[];}; + +#if defined(A_TYPE_PACKED16) +#define BINDING_IDX_K 0 +#define BINDING_IDX_V 1 +layout (binding = 1) readonly buffer KV_PACKED16 {A_TYPE_PACKED16 data_packed16[];} kv_packed[2]; +#endif + +#if defined(DATA_A_Q4_0) +#define BLOCK_BYTE_SIZE 18 + +vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { + uint vui_lo = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); + uint vui_hi = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); + uint shift = (iqs & 0x10) >> 2; + vui_lo >>= shift; + vui_hi >>= shift; + + return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f); +} +#endif + +#if defined(DATA_A_Q8_0) +#define BLOCK_BYTE_SIZE 34 +vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { + const i8vec2 v0 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147 + const i8vec2 v1 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy; + + return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y); +} +#endif + +#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b)) + + +// Store column zero. This is used to save per-row m and L values for split_k. +ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) +{ + if (r < N && c == 0) { + uint32_t offset = iq2 + r; + data_o[o_offset + offset] = D_TYPE(elem); + } + return elem; +} + +// Load the slope matrix, indexed by Q's dimension 2. 
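+// This is the usual ALiBi-style slope: heads below n_head_log2 get increasing powers of m0,
+// while the remaining heads get odd powers of m1 (m0, m1 and n_head_log2 come from the push constants).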
+ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2) +{ + const uint32_t h = iq2 + (r % p.gqa_ratio); + + const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1); + const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1); + + return ACC_TYPE(pow(base, ACC_TYPE(exph))); +} + +uint32_t i, N, KV, split_k_index, Tr, start_j, end_j, + iq2, iq3, rk2, rk3, rv2, rv3, ik2, ik3, iv2, iv3, + q_stride, k_stride, v_stride, m_stride; + +void init_indices() +{ + N = p.N; + KV = p.KV; + + i = gl_WorkGroupID.x; + split_k_index = 0; + + if (p.k_num > 1) { + i = 0; + split_k_index = gl_WorkGroupID.x; + } + + Tr = CEIL_DIV(N, Br); + + start_j = split_k_index * p.split_kv / Bc; + end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc); + + // When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y. + // When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2. + iq2 = gl_WorkGroupID.y * p.gqa_ratio; + iq3 = gl_WorkGroupID.z; + + // broadcast factors + rk2 = p.neq2/p.nek2; + rk3 = p.neq3/p.nek3; + + rv2 = p.neq2/p.nev2; + rv3 = p.neq3/p.nev3; + + // k indices + ik3 = iq3 / rk3; + ik2 = iq2 / rk2; + + // v indices + iv3 = iq3 / rv3; + iv2 = iq2 / rv2; + + // nb?1 are already divided by the type size and are in units of elements. + // When using grouped query attention, Q is indexed by iq2, so the stride + // should be nb02 (which is in bytes). + q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01; + k_stride = p.nb11; + v_stride = p.nb21; + // When using grouped query attention, all rows use the same mask (stride 0). + // "p.gqa_ratio >> 16" is just a roundabout way of writing zero + // that prevents the compiler from folding the "&" through the select + // and breaking the alignment detection. + m_stride = (p.gqa_ratio > 1) ? 
(p.gqa_ratio >> 16) : KV; +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp index 8b86b623bd9..da478be24fb 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp @@ -11,14 +11,7 @@ #extension GL_KHR_cooperative_matrix : enable #include "types.comp" - -layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; - -layout (constant_id = 1) const uint32_t Br = 1; -layout (constant_id = 2) const uint32_t Bc = 32; -layout (constant_id = 3) const uint32_t D = 32; - -layout (constant_id = 5) const uint32_t D_split = 16; +#include "flash_attn_base.comp" const uint32_t D_per_thread = D / D_split; const uint32_t row_split = 4; @@ -26,46 +19,6 @@ const uint32_t rows_per_thread = Br / row_split; const uint32_t cols_per_iter = gl_WorkGroupSize.x / D_split / row_split; const uint32_t cols_per_thread = Bc / cols_per_iter; -layout (push_constant) uniform parameter { - uint32_t N; - uint32_t KV; - - uint32_t ne1; - uint32_t ne2; - uint32_t ne3; - - uint32_t neq2; - uint32_t neq3; - uint32_t nek2; - uint32_t nek3; - uint32_t nev2; - uint32_t nev3; - uint32_t nem1; - - uint32_t nb01; - uint32_t nb02; - uint32_t nb03; - uint32_t nb11; - uint32_t nb12; - uint32_t nb13; - uint32_t nb21; - uint32_t nb22; - uint32_t nb23; - uint32_t nb31; - - float scale; - float max_bias; - float logit_softcap; - - uint32_t mask; - uint32_t n_head_log2; - float m0; - float m1; - - uint32_t gqa_ratio; - uint32_t split_kv; - uint32_t k_num; -} p; layout (binding = 0) readonly buffer Q {float data_q[];}; layout (binding = 0) readonly buffer QV4 {vec4 data_qv4[];}; @@ -74,39 +27,6 @@ layout (binding = 1) readonly buffer KV4 {f16vec4 data_kv4[];}; layout (binding = 2) readonly buffer V {float16_t data_v[];}; layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];}; layout (binding = 3) readonly buffer M {float16_t data_m[];}; -layout (binding = 4) writeonly buffer O {D_TYPE data_o[];}; - -#if defined(A_TYPE_PACKED16) -#define BINDING_IDX_K 0 -#define BINDING_IDX_V 1 -layout (binding = 1) readonly buffer KV_PACKED16 {A_TYPE_PACKED16 data_packed16[];} kv_packed[2]; -#endif - -#if defined(DATA_A_Q4_0) -#define BLOCK_BYTE_SIZE 18 - -vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { - uint vui_lo = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); - uint vui_hi = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); - uint shift = (iqs & 0x10) >> 2; - vui_lo >>= shift; - vui_hi >>= shift; - - return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f); -} -#endif - -#if defined(DATA_A_Q8_0) -#define BLOCK_BYTE_SIZE 34 -vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { - const i8vec2 v0 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147 - const i8vec2 v1 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy; - - return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y); -} -#endif - -#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b)) // Store the output when doing grouped query attention. // Rows index by Q's dimension 2, and the first N rows are valid. 
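The grouped-query-attention broadcast that the new init_indices() helper in flash_attn_base.comp performs (rk2 = neq2/nek2, ik2 = iq2/rk2) is easier to follow outside shader syntax. Below is a minimal C++ sketch of the same mapping; the function and parameter names are illustrative only and are not part of the patch:

    #include <cassert>
    #include <cstdint>

    // Maps a query-head index to the K/V head it reads under grouped-query attention,
    // mirroring rk2 = neq2/nek2 and ik2 = iq2/rk2 in flash_attn_base.comp.
    static uint32_t kv_head_for_q_head(uint32_t iq2, uint32_t n_q_heads, uint32_t n_kv_heads) {
        assert(n_kv_heads > 0 && n_q_heads % n_kv_heads == 0);
        const uint32_t rk2 = n_q_heads / n_kv_heads; // broadcast factor: Q heads per K/V head
        return iq2 / rk2;                            // consecutive Q heads share one K/V head
    }

With 32 query heads and 8 K/V heads, heads 0-3 read K/V head 0, heads 4-7 read K/V head 1, and so on; the same ratio arithmetic is applied to the V head indices (rv2, iv2).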
@@ -117,27 +37,6 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TY return elem; } -// Store column zero. This is used to save per-row m and L values for split_k. -ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) -{ - if (r < N && c == 0) { - uint32_t offset = iq2 + r; - data_o[o_offset + offset] = D_TYPE(elem); - } - return elem; -} - -// Load the slope matrix, indexed by Q's dimension 2. -ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2) -{ - const uint32_t h = iq2 + (r % p.gqa_ratio); - - const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1); - const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1); - - return ACC_TYPE(pow(base, ACC_TYPE(exph))); -} - // These need to be supported N,M values for a MatBc x MatBr x 16 coopmatmuladd const uint32_t MatBr = 16; const uint32_t MatBc = 16; @@ -162,9 +61,9 @@ void main() { init_iq_shmem(gl_WorkGroupSize); #endif + init_indices(); + const uint32_t tid = gl_LocalInvocationIndex; - const uint32_t N = p.N; - const uint32_t KV = p.KV; const uint32_t threads_per_rowgroup = gl_WorkGroupSize.x / row_split; const uint32_t row_tid = gl_LocalInvocationIndex / threads_per_rowgroup; @@ -173,51 +72,6 @@ void main() { #define tile_row(r) (row_tid * rows_per_thread + (r)) - uint32_t i = gl_WorkGroupID.x; - uint32_t split_k_index = 0; - - if (p.k_num > 1) { - i = 0; - split_k_index = gl_WorkGroupID.x; - } - - const uint32_t Tr = CEIL_DIV(N, Br); - - const uint32_t start_j = split_k_index * p.split_kv / Bc; - const uint32_t end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc); - - // When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y. - // When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2. - const uint32_t iq2 = gl_WorkGroupID.y * p.gqa_ratio; - const uint32_t iq3 = gl_WorkGroupID.z; - - // broadcast factors - const uint32_t rk2 = p.neq2/p.nek2; - const uint32_t rk3 = p.neq3/p.nek3; - - const uint32_t rv2 = p.neq2/p.nev2; - const uint32_t rv3 = p.neq3/p.nev3; - - // k indices - const uint32_t ik3 = iq3 / rk3; - const uint32_t ik2 = iq2 / rk2; - - // v indices - const uint32_t iv3 = iq3 / rv3; - const uint32_t iv2 = iq2 / rv2; - - // nb?1 are already divided by the type size and are in units of elements. - // When using grouped query attention, Q is indexed by iq2, so the stride - // should be nb02 (which is in bytes). - uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01; - uint32_t k_stride = p.nb11; - uint32_t v_stride = p.nb21; - // When using grouped query attention, all rows use the same mask (stride 0). - // "p.gqa_ratio >> 16" is just a roundabout way of writing zero - // that prevents the compiler from folding the "&" through the select - // and breaking the alignment detection. - uint32_t m_stride = (p.gqa_ratio > 1) ? 
(p.gqa_ratio >> 16) : KV; - uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4; [[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp index b926a578ade..6acf67a03a4 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp @@ -18,62 +18,12 @@ #include "types.comp" #include "dequant_funcs_cm2.comp" - -layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; - -layout (constant_id = 1) const uint32_t Br = 32; -layout (constant_id = 2) const uint32_t Bc = 32; -layout (constant_id = 3) const uint32_t D = 32; -layout (constant_id = 4) const uint32_t Clamp = gl_CooperativeMatrixClampModeConstantNV; - -layout (push_constant) uniform parameter { - uint32_t N; - uint32_t KV; - - uint32_t ne1; - uint32_t ne2; - uint32_t ne3; - - uint32_t neq2; - uint32_t neq3; - uint32_t nek2; - uint32_t nek3; - uint32_t nev2; - uint32_t nev3; - uint32_t nem1; - - uint32_t nb01; - uint32_t nb02; - uint32_t nb03; - uint32_t nb11; - uint32_t nb12; - uint32_t nb13; - uint32_t nb21; - uint32_t nb22; - uint32_t nb23; - uint32_t nb31; - - float scale; - float max_bias; - float logit_softcap; - - uint32_t mask; - uint32_t n_head_log2; - float m0; - float m1; - - uint32_t gqa_ratio; - uint32_t split_kv; - uint32_t k_num; -} p; +#include "flash_attn_base.comp" layout (binding = 0) readonly buffer Q {uint8_t data_q[];}; layout (binding = 1) readonly buffer K {uint8_t data_k[];}; layout (binding = 2) readonly buffer V {uint8_t data_v[];}; layout (binding = 3) readonly buffer M {uint8_t data_m[];}; -layout (binding = 4) writeonly buffer O {D_TYPE data_o[];}; - -#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b)) ACC_TYPE maxReduce(const in ACC_TYPE x, const in ACC_TYPE y) { return max(x, y); @@ -118,67 +68,12 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TY return elem; } -// Store column zero. This is used to save per-row m and L values for split_k. -ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) -{ - if (r < N && c == 0) { - uint32_t offset = iq2 + r; - data_o[o_offset + offset] = D_TYPE(elem); - } - return elem; -} - -// Load the slope matrix, indexed by Q's dimension 2. -ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2) -{ - const uint32_t h = iq2 + (r % p.gqa_ratio); - - const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1); - const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1); - - return ACC_TYPE(pow(base, ACC_TYPE(exph))); -} - void main() { #ifdef NEEDS_INIT_IQ_SHMEM init_iq_shmem(gl_WorkGroupSize); #endif - const uint32_t N = p.N; - const uint32_t KV = p.KV; - - uint32_t i = gl_WorkGroupID.x; - uint32_t split_k_index = 0; - - if (p.k_num > 1) { - i = 0; - split_k_index = gl_WorkGroupID.x; - } - - const uint32_t Tr = CEIL_DIV(N, Br); - - const uint32_t start_j = split_k_index * p.split_kv / Bc; - const uint32_t end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc); - - // When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y. - // When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2. 
- const uint32_t iq2 = gl_WorkGroupID.y * p.gqa_ratio; - const uint32_t iq3 = gl_WorkGroupID.z; - - // broadcast factors - const uint32_t rk2 = p.neq2/p.nek2; - const uint32_t rk3 = p.neq3/p.nek3; - - const uint32_t rv2 = p.neq2/p.nev2; - const uint32_t rv3 = p.neq3/p.nev3; - - // k indices - const uint32_t ik3 = iq3 / rk3; - const uint32_t ik2 = iq2 / rk2; - - // v indices - const uint32_t iv3 = iq3 / rv3; - const uint32_t iv2 = iq2 / rv2; + init_indices(); tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutQ = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); tensorLayoutNV<2, Clamp> tensorLayoutK = createTensorLayoutNV(2, Clamp); @@ -195,17 +90,6 @@ void main() { tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, D); tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, D); - // nb?1 are already divided by the type size and are in units of elements. - // When using grouped query attention, Q is indexed by iq2, so the stride - // should be nb02 (which is in bytes). - uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01; - uint32_t k_stride = p.nb11; - uint32_t v_stride = p.nb21; - // When using grouped query attention, all rows use the same mask (stride 0). - // "p.gqa_ratio >> 16" is just a roundabout way of writing zero - // that prevents the compiler from folding the "&" through the select - // and breaking the alignment detection. - uint32_t m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV; // hint to the compiler that strides are aligned for the aligned variant of the shader if (Clamp != gl_CooperativeMatrixClampModeConstantNV) { From 2d07c5613d399f87cc783a91620e2691a586a873 Mon Sep 17 00:00:00 2001 From: "Gilad S." <7817232+giladgd@users.noreply.github.com> Date: Sat, 17 May 2025 21:26:43 +0300 Subject: [PATCH 21/24] cmake: use the current build config for vulkan-shaders-gen (llama/13595) * fix: use the current build config for `vulkan-shaders-gen` * fix: only pass a valid build type to `--config` --- ggml/src/ggml-vulkan/CMakeLists.txt | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-vulkan/CMakeLists.txt b/ggml/src/ggml-vulkan/CMakeLists.txt index 16e10a9f399..662f1377107 100644 --- a/ggml/src/ggml-vulkan/CMakeLists.txt +++ b/ggml/src/ggml-vulkan/CMakeLists.txt @@ -54,6 +54,11 @@ if (Vulkan_FOUND) -DCMAKE_RUNTIME_OUTPUT_DIRECTORY=${CMAKE_RUNTIME_OUTPUT_DIRECTORY} ) + set(VULKAN_SHADER_GEN_CMAKE_BUILD_ARGS "") + if (CMAKE_BUILD_TYPE AND CMAKE_BUILD_TYPE MATCHES "Debug|Release|MinSizeRel|RelWithDebInfo") + list(APPEND VULKAN_SHADER_GEN_CMAKE_BUILD_ARGS --config=${CMAKE_BUILD_TYPE}) + endif() + # Test all shader extensions test_shader_extension_support( "GL_KHR_cooperative_matrix" @@ -149,7 +154,7 @@ if (Vulkan_FOUND) vulkan-shaders-gen SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders CMAKE_ARGS ${VULKAN_SHADER_GEN_CMAKE_ARGS} - BUILD_COMMAND ${CMAKE_COMMAND} --build . + BUILD_COMMAND ${CMAKE_COMMAND} --build . ${VULKAN_SHADER_GEN_CMAKE_BUILD_ARGS} INSTALL_COMMAND ${CMAKE_COMMAND} --install . 
INSTALL_DIR ${CMAKE_BINARY_DIR} ) From 6dedddb5564fef27fccb223b48ed317c286c2eda Mon Sep 17 00:00:00 2001 From: Chenguang Li <757486878@qq.com> Date: Mon, 19 May 2025 14:21:17 +0800 Subject: [PATCH 22/24] CANN: Support MOE Model MUL_MAT_ID (llama/13042) Signed-off-by: noemotiovon <757486878@qq.com> --- ggml/src/ggml-cann/aclnn_ops.cpp | 147 +++++++++++++++++++++++++++++++ ggml/src/ggml-cann/aclnn_ops.h | 27 ++++++ ggml/src/ggml-cann/ggml-cann.cpp | 11 ++- 3 files changed, 183 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index 67c0223c010..cbf9783b744 100644 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -65,6 +65,7 @@ #include #include #include +#include #include #include @@ -2587,3 +2588,149 @@ void ggml_cann_step(ggml_backend_cann_context& ctx, ggml_tensor* dst){ ggml_cann_release_resources(ctx, acl_src, acl_dst, alpha); } + +/** + * @brief Performs expert-specific matrix multiplication (MoE) with + * floating-point precision using the CANN backend. + * + * This function executes a matrix multiplication operation tailored for + * Mixture of Experts (MoE) models, where the input tensor is multiplied + * with expert-specific weight matrices. It uses the CANN backend for + * efficient computation and stores the result in the destination tensor `dst`. + * The operation may leverage identity-based optimizations or routing masks + * as part of sparse expert selection. + * + * @param ctx The context for executing CANN backend operations. + * @param dst The destination tensor where the MoE multiplication result + * will be stored. + * + * @note This function assumes floating-point data types and is designed for + * MoE architectures, possibly involving sparse expert routing. 
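 *
 * For example, with A = 8 experts, K = 2 experts routed per token and N = 4
 * tokens, ids holds one expert index per (token, slot) pair, i.e. shape [2, 4],
 * and dst has shape [M, 2, 4, 1]: one [M]-sized output row per routed pair.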
+ */ +static void ggml_cann_mul_mat_id_fp(ggml_backend_cann_context& ctx, ggml_tensor* dst) { + //dst [M, K, N, 1] + ggml_tensor * src0 = dst->src[0]; //src0 [D, M, A, 1] + ggml_tensor * src1 = dst->src[1]; //src1 [D, B, N, 1], B = K or B = 1 + ggml_tensor * ids = dst->src[2]; //ids [K, N] + + GGML_TENSOR_BINARY_OP_LOCALS + + // copy index from npu to cpu + int64_t n_as = ne02; // A + int64_t n_ids = ids->ne[0]; // K + + std::vector ids_host(ggml_nbytes(ids)); + ggml_cann_async_memcpy(ctx, ids_host.data(), ids->data, ggml_nbytes(ids), + ACL_MEMCPY_DEVICE_TO_HOST); + ACL_CHECK(aclrtSynchronizeStream(ctx.stream())); + + char * src0_original = (char *) src0->data; + char * src1_original = (char *) src1->data; + char * dst_original = (char *) dst->data; + size_t ori_src0_nb[4] = {nb00, nb01, nb02, nb03}; + + // src0 is F16, src1 is F32, dst is F32 + ggml_cann_pool_alloc src0_cast_allocator; + if (src0->type == GGML_TYPE_F16) { + src0_cast_allocator.alloc(ctx.pool(), sizeof(float) * ggml_nelements(src0)); + void* src0_cast_buf = src0_cast_allocator.get(); + + size_t cast_nb[GGML_MAX_DIMS]; + cast_nb[0] = sizeof(float_t); + for (int i = 1; i < GGML_MAX_DIMS; i++) { + cast_nb[i] = cast_nb[i - 1] * src0->ne[i - 1]; + } + + aclTensor* acl_src0_f16 = ggml_cann_create_tensor(src0); + aclTensor* acl_cast = ggml_cann_create_tensor(src0_cast_buf, + ACL_FLOAT, sizeof(float), src0->ne, cast_nb, 4); + GGML_CANN_CALL_ACLNN_OP(ctx, Cast, acl_src0_f16, ACL_FLOAT, acl_cast); + ggml_cann_release_resources(ctx, acl_cast, acl_src0_f16); + + src0_original = (char *) src0_cast_buf; + memcpy(ori_src0_nb, cast_nb, sizeof(ori_src0_nb)); + } + + std::vector src0_tensor_vec; + std::vector src1_tensor_vec; + std::vector dst_tensor_vec; + for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) { + for (int64_t id = 0; id < n_ids; id++) { + // src0_row [M, D] -> weight && permute + int64_t src0_ne[2] = {ne01, ne00}; + size_t src0_nb[2] = {ori_src0_nb[1], ori_src0_nb[0]}; + // src1_row [D, 1] -> input + int64_t src1_ne[2] = {ne10, 1}; + size_t src1_nb[2] = {nb10, nb11}; + // dst_row [M, 1] -> out + int64_t dst_ne[2] = {ne0, 1}; + size_t dst_nb[2] = {nb0, nb1}; + + // expert index + int32_t i02 = *(int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]); + GGML_ASSERT(i02 >= 0 && i02 < n_as); + + // If B = 1 (broadcast), always use 0; otherwise, use id. + int64_t i11 = (ne11 == 1 ? 
0 : id); + int64_t i12 = iid1; + + int64_t i1 = id; + int64_t i2 = i12; + + void* src0_tmp_ptr = src0_original + i02*ori_src0_nb[2]; + void* src1_tmp_ptr = src1_original + i11*nb11 + i12*nb12; + void* dst_tmp_ptr = dst_original + i1*nb1 + i2*nb2; + + aclTensor* acl_src0 = ggml_cann_create_tensor(src0_tmp_ptr, + ACL_FLOAT, sizeof(float), + src0_ne, src0_nb, 2); + aclTensor* acl_src1 = ggml_cann_create_tensor(src1_tmp_ptr, + ACL_FLOAT, sizeof(float), + src1_ne, src1_nb, 2); + aclTensor* acl_dst = ggml_cann_create_tensor(dst_tmp_ptr, + ACL_FLOAT, sizeof(float), + dst_ne, dst_nb, 2); + + src0_tensor_vec.push_back(acl_src0); + src1_tensor_vec.push_back(acl_src1); + dst_tensor_vec.push_back(acl_dst); + } + } + + // GroupedMatmulV2 required tensor_list.size < 128 + size_t GROUP_SIZE = 128; + std::vector> src0_tensor_vec_vec; + std::vector> src1_tensor_vec_vec; + std::vector> dst_tensor_vec_vec; + + // split and call GroupedMatmulV2 + for (size_t i = 0; i < src0_tensor_vec.size(); i += GROUP_SIZE) { + size_t end = std::min(i + GROUP_SIZE, src0_tensor_vec.size()); + std::vector src0_tensor_vec_split(src0_tensor_vec.begin() + i, src0_tensor_vec.begin() + end); + std::vector src1_tensor_vec_split(src1_tensor_vec.begin() + i, src1_tensor_vec.begin() + end); + std::vector dst_tensor_vec_split(dst_tensor_vec.begin() + i, dst_tensor_vec.begin() + end); + + aclTensorList* src0_tensor_list = aclCreateTensorList(src0_tensor_vec_split.data(), src0_tensor_vec_split.size()); + aclTensorList* src1_tensor_list = aclCreateTensorList(src1_tensor_vec_split.data(), src1_tensor_vec_split.size()); + aclTensorList* dst_tensor_list = aclCreateTensorList(dst_tensor_vec_split.data(), dst_tensor_vec_split.size()); + + GGML_CANN_CALL_ACLNN_OP(ctx, GroupedMatmulV2, src1_tensor_list, src0_tensor_list, + nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, 0, -1, dst_tensor_list); + + ggml_cann_release_resources(ctx, src0_tensor_list, src1_tensor_list, dst_tensor_list); + } + return; +} + +void ggml_cann_mul_mat_id(ggml_backend_cann_context& ctx, ggml_tensor* dst) { + const enum ggml_type type = dst->src[0]->type; + switch (type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + ggml_cann_mul_mat_id_fp(ctx, dst); + break; + default: + GGML_ABORT("Unsupported type for mul_mat_id"); + break; + } +} diff --git a/ggml/src/ggml-cann/aclnn_ops.h b/ggml/src/ggml-cann/aclnn_ops.h index 462351542e5..15993cce66f 100644 --- a/ggml/src/ggml-cann/aclnn_ops.h +++ b/ggml/src/ggml-cann/aclnn_ops.h @@ -978,6 +978,33 @@ inline void ggml_cann_async_memset(ggml_backend_cann_context & ctx, void * buffe } } +/** + * @brief Performs sparse expert-based matrix multiplication using the CANN backend. + * + * @details This function implements a MoE-style batched matrix multiplication, where each input token + * is routed to one or more experts, and each expert corresponds to a specific [D, M] weight matrix + * in the source tensor `src0`. The routing indices are provided via the `ids` tensor. + * + * For each token (from `src1`), the function selects the corresponding expert(s) as specified by `ids`, + * performs the matrix multiplication with the selected expert's weight submatrix (from `src0`), + * and stores the results in `dst`. This operation is optimized and executed on the CANN backend. 
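 * The routing indices are first copied from device to host (with a stream
 * synchronization) so the per-token expert selection can run on the CPU; the
 * resulting (token, expert) matrix multiplications are then dispatched in
 * groups of at most GROUP_SIZE (128) through GroupedMatmulV2.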
+ * + * Dimensions: + * - src0: [D, M, A, 1], where A is the number of experts + * - src1: [D, B, N, 1], where N is batch size and B is the slot count per sample + * - ids : [K, N], where K is the number of experts each token is routed to + * - dst : [M, K, N, 1], output tensor storing the result of expert × token multiplication + * + * The function handles two main modes: + * - If `ne12 == 1`, a simpler per-token loop is used. + * - TODO: If `ne12 > 1`, grouped multiplication and memory copying is used for efficiency. + * + * @param ctx The CANN context used for operations. + * @param dst The destination tensor where the expert-weighted token outputs are stored. + * Expected to be of shape [M, K, N, 1]. + */ +void ggml_cann_mul_mat_id(ggml_backend_cann_context& ctx, ggml_tensor* dst); + /** * @brief Applies a element-wise operation to two input tensors using the CANN * backend. diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index e2617b06e9c..0cb7bbf17cc 100644 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -1672,7 +1672,8 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx, ggml_cann_mul_mat(ctx, dst); break; case GGML_OP_MUL_MAT_ID: - return false; + ggml_cann_mul_mat_id(ctx, dst); + break; case GGML_OP_SCALE: ggml_cann_scale(ctx, dst); break; @@ -2030,7 +2031,13 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, } } case GGML_OP_MUL_MAT_ID: - return false; + switch (op->src[0]->type) { + case GGML_TYPE_F16: + case GGML_TYPE_F32: + return true; + default: + return false; + } // embedding case GGML_OP_GET_ROWS: { switch (op->src[0]->type) { From abcf4b264086b0ad6c167ebba7a13bd1382fd92b Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 19 May 2025 13:38:44 +0300 Subject: [PATCH 23/24] sync : ggml ggml-ci --- scripts/sync-ggml.last | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index 2ad2ea1c651..6dbde75a843 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -148b286332db1259dcd299c04047a1fd31b02713 +c6202093c3fb4ce8f728d86838400b35cc01ac7c From e5ededd97893b0d61282ca54e1d1ea09f8094121 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 19 May 2025 13:39:12 +0300 Subject: [PATCH 24/24] talk-llama : sync llama.cpp ggml-ci --- examples/talk-llama/llama-arch.cpp | 3 + examples/talk-llama/llama-context.cpp | 6 +- examples/talk-llama/llama-kv-cache.cpp | 8 + examples/talk-llama/llama-kv-cache.h | 14 +- examples/talk-llama/llama-model-loader.cpp | 19 +- examples/talk-llama/llama-model.cpp | 240 ++++++++++++++++++--- examples/talk-llama/llama-quant.cpp | 24 ++- examples/talk-llama/llama.cpp | 5 + 8 files changed, 256 insertions(+), 63 deletions(-) diff --git a/examples/talk-llama/llama-arch.cpp b/examples/talk-llama/llama-arch.cpp index f2bc8ca7685..abf436adac4 100644 --- a/examples/talk-llama/llama-arch.cpp +++ b/examples/talk-llama/llama-arch.cpp @@ -1481,6 +1481,9 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, + { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, + { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, }, }, { diff --git a/examples/talk-llama/llama-context.cpp b/examples/talk-llama/llama-context.cpp index 62246c10dab..a3b84a6a82e 100644 --- 
a/examples/talk-llama/llama-context.cpp +++ b/examples/talk-llama/llama-context.cpp @@ -1704,10 +1704,12 @@ size_t llama_context::state_write_data(llama_io_write_i & io) { } } - LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__); llama_kv_cache * kv_self = static_cast(memory.get()); - kv_self->state_write(io); + if (kv_self != nullptr) { + LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__); + kv_self->state_write(io); + } return io.n_bytes(); } diff --git a/examples/talk-llama/llama-kv-cache.cpp b/examples/talk-llama/llama-kv-cache.cpp index 3dcad65bb6a..265db2527c7 100644 --- a/examples/talk-llama/llama-kv-cache.cpp +++ b/examples/talk-llama/llama-kv-cache.cpp @@ -441,6 +441,13 @@ void llama_kv_cache_unified::defrag_sched(float thold) { void llama_kv_cache_unified::set_full() { n = size; + + // when simulating a full KV cache, the specific value of the "head" pointer is not important because it does not + // affect the shapes of the tensors in the compute graph - it only affects the offsets of the K/V views. + // we should only guarantee that the head position won't cause out-of-bounds view of the K, V tensors, so + // setting it to 0 is the simplest way to achieve that + // ref: https://github.com/ggml-org/llama.cpp/issues/13359 + head = 0; } llama_sbatch llama_kv_cache_unified::sbatch_init( @@ -1712,6 +1719,7 @@ void llama_kv_cache_recurrent::defrag_sched(float thold) { void llama_kv_cache_recurrent::set_full() { n = size; + head = 0; } llama_sbatch llama_kv_cache_recurrent::sbatch_init( diff --git a/examples/talk-llama/llama-kv-cache.h b/examples/talk-llama/llama-kv-cache.h index bf3b4b6a443..e83e12c09f2 100644 --- a/examples/talk-llama/llama-kv-cache.h +++ b/examples/talk-llama/llama-kv-cache.h @@ -171,11 +171,8 @@ class llama_kv_cache_unified : public llama_kv_cache { void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override; void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override; - // Note: The value of head isn't only used to optimize searching - // for a free KV slot. llama_decode_impl also uses it, so it - // cannot be freely changed after a slot has been allocated. - uint32_t head = 0; - uint32_t size = 0; + uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot()) + uint32_t size = 0; // total number of cells, shared across all sequences uint32_t used = 0; // used cells (i.e. at least one seq_id) // computed before each graph build @@ -343,11 +340,8 @@ class llama_kv_cache_recurrent : public llama_kv_cache { void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override; void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override; - // Note: The value of head isn't only used to optimize searching - // for a free KV slot. llama_decode_impl also uses it, so it - // cannot be freely changed after a slot has been allocated. - uint32_t head = 0; - uint32_t size = 0; + uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot()) + uint32_t size = 0; // total number of cells, shared across all sequences uint32_t used = 0; // used cells (i.e. 
at least one seq_id) // computed before each graph build diff --git a/examples/talk-llama/llama-model-loader.cpp b/examples/talk-llama/llama-model-loader.cpp index 4cce51668b4..ddb1b03675b 100644 --- a/examples/talk-llama/llama-model-loader.cpp +++ b/examples/talk-llama/llama-model-loader.cpp @@ -469,7 +469,7 @@ llama_model_loader::llama_model_loader( meta.reset(gguf_init_from_file(fname.c_str(), params)); if (!meta) { - throw std::runtime_error(format("%s: failed to load model from %s\n", __func__, fname.c_str())); + throw std::runtime_error(format("%s: failed to load model from %s", __func__, fname.c_str())); } get_key(llm_kv(LLM_KV_GENERAL_ARCHITECTURE), arch_name, false); @@ -528,7 +528,7 @@ llama_model_loader::llama_model_loader( }; gguf_context_ptr ctx_gguf { gguf_init_from_file(fname_split, split_params) }; if (!ctx_gguf) { - throw std::runtime_error(format("%s: failed to load GGUF split from %s\n", __func__, fname_split)); + throw std::runtime_error(format("%s: failed to load GGUF split from %s", __func__, fname_split)); } // check idx @@ -822,13 +822,18 @@ void llama_model_loader::init_mappings(bool prefetch, llama_mlocks * mlock_mmaps mappings.reserve(files.size()); mmaps_used.reserve(files.size()); for (const auto & file : files) { - auto * reg = ggml_backend_dev_backend_reg(ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU)); - if (!reg) { - throw std::runtime_error(format("%s: no CPU backend found", __func__)); + bool is_numa = false; + + auto * dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + if (dev) { + auto * reg = ggml_backend_dev_backend_reg(dev); + auto * is_numa_fn = (decltype(ggml_is_numa) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_is_numa"); + if (is_numa_fn) { + is_numa = is_numa_fn(); + } } - auto * is_numa_fn = (decltype(ggml_is_numa) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_is_numa"); - std::unique_ptr mapping = std::make_unique(file.get(), prefetch ? -1 : 0, is_numa_fn()); + std::unique_ptr mapping = std::make_unique(file.get(), prefetch ? 
-1 : 0, is_numa);
             mmaps_used.emplace_back(mapping->size(), 0);
             if (mlock_mmaps) {
                 std::unique_ptr mlock_mmap(new llama_mlock());
diff --git a/examples/talk-llama/llama-model.cpp b/examples/talk-llama/llama-model.cpp
index 3a4e72a36b0..7fd094b63f2 100644
--- a/examples/talk-llama/llama-model.cpp
+++ b/examples/talk-llama/llama-model.cpp
@@ -1389,6 +1389,9 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     // Add additional layer/vocab/etc checks here for other model sizes
                     default: type = LLM_TYPE_UNKNOWN;
                 }
+
+                // For Granite MoE Shared
+                ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, /* required */ false);
             } break;
         case LLM_ARCH_CHAMELEON:
             {
@@ -1772,6 +1775,13 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, TENSOR_NOT_REQUIRED);
                         layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0);
                         layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {n_embd, n_ff, n_expert}, 0);
+
+                        // For Granite MoE Shared
+                        if (hparams.n_ff_shexp > 0) {
+                            layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0);
+                            layer.ffn_up_shexp   = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP,   "weight", i), {n_embd, hparams.n_ff_shexp}, 0);
+                            layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, 0);
+                        }
                     }
                 }
             } break;
@@ -4385,10 +4395,13 @@ void llama_model::print_info() const {
         LLAMA_LOG_INFO("%s: n_ff_exp          = %d\n", __func__, hparams.n_ff_exp);
     }
 
-    if (arch == LLM_ARCH_MINICPM || arch == LLM_ARCH_GRANITE || arch == LLM_ARCH_GRANITE_MOE) {
+    if (arch == LLM_ARCH_MINICPM ||
+        arch == LLM_ARCH_GRANITE ||
+        arch == LLM_ARCH_GRANITE_MOE) {
         LLAMA_LOG_INFO("%s: f_embedding_scale = %f\n", __func__, hparams.f_embedding_scale);
         LLAMA_LOG_INFO("%s: f_residual_scale  = %f\n", __func__, hparams.f_residual_scale);
         LLAMA_LOG_INFO("%s: f_attention_scale = %f\n", __func__, hparams.f_attention_scale);
+        LLAMA_LOG_INFO("%s: n_ff_shexp        = %d\n", __func__, hparams.n_ff_shexp);
     }
 
     if (arch == LLM_ARCH_BAILINGMOE) {
@@ -4598,11 +4611,6 @@ struct llm_build_llama : public llm_graph_context {
                 inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
             }
 
-            // For Granite architecture
-            if (hparams.f_residual_scale) {
-                cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
-            }
-
             ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
             cb(ffn_inp, "ffn_inp", il);
 
@@ -4674,11 +4682,6 @@ struct llm_build_llama : public llm_graph_context {
                 cb(cur, "ffn_moe_out", il);
             }
 
-            // For Granite architecture
-            if (hparams.f_residual_scale) {
-                cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
-            }
-
             cur = ggml_add(ctx0, cur, ffn_inp);
             cb(cur, "ffn_out", il);
 
@@ -4701,11 +4704,6 @@ struct llm_build_llama : public llm_graph_context {
         // lm_head
         cur = build_lora_mm(model.output, cur);
 
-        // For Granite architecture
-        if (hparams.f_logit_scale) {
-            cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale);
-        }
-
         cb(cur, "result_output", -1);
         res->t_logits = cur;
 
@@ -4816,11 +4814,6 @@ struct llm_build_deci : public llm_graph_context {
                 continue;
             }
 
-            // For Granite architecture
-            if (hparams.f_residual_scale) {
-                cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
-            }
-
             // modified to support attention-free layer of Llama-3_1-Nemotron-51B
             ggml_tensor * ffn_inp = cur;
             if (n_head > 0) {
@@ -4844,11 +4837,6 @@ struct llm_build_deci : public llm_graph_context {
                 cb(cur, "ffn_out", il);
             }
 
-            // For Granite architecture
-            if (hparams.f_residual_scale) {
-                cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
-            }
-
             cur = ggml_add(ctx0, cur, ffn_inp);
             cb(cur, "ffn_out", il);
 
@@ -4871,11 +4859,6 @@ struct llm_build_deci : public llm_graph_context {
         // lm_head
         cur = build_lora_mm(model.output, cur);
 
-        // For Granite architecture
-        if (hparams.f_logit_scale) {
-            cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale);
-        }
-
         cb(cur, "result_output", -1);
         res->t_logits = cur;
 
@@ -12214,6 +12197,194 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base {
     }
 };
 
+
+struct llm_build_granite : public llm_graph_context {
+    llm_build_granite(
+        const llama_model & model,
+        const llm_graph_params & params,
+        ggml_cgraph * gf,
+        const bool use_rope = true)
+        : llm_graph_context(params) {
+
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - built only if rope enabled
+        ggml_tensor * inp_pos = nullptr;
+        if (use_rope) {
+            inp_pos = build_inp_pos();
+        }
+
+        auto * inp_attn = build_attn_inp_kv_unified();
+
+        const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                // compute Q and K and (optionally) RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                if (model.layers[il].bq) {
+                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                    cb(Qcur, "Qcur", il);
+                }
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                if (model.layers[il].bk) {
+                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                if (model.layers[il].bv) {
+                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                    cb(Vcur, "Vcur", il);
+                }
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+                if (use_rope) {
+                    ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
+                    Qcur = ggml_rope_ext(
+                            ctx0, Qcur, inp_pos, rope_factors,
+                            n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                            ext_factor, attn_factor, beta_fast, beta_slow
+                            );
+
+                    Kcur = ggml_rope_ext(
+                            ctx0, Kcur, inp_pos, rope_factors,
+                            n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                            ext_factor, attn_factor, beta_fast, beta_slow
+                            );
+                }
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
+                cb(cur, "attn_out", il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            // For Granite architectures - scale residual
+            cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network (non-MoE)
+            if (model.layers[il].ffn_gate_inp == nullptr) {
+
+                cur = build_norm(ffn_inp,
+                        model.layers[il].ffn_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
+                        model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, il);
+                cb(cur, "ffn_out", il);
+
+            } else {
+                // MoE branch
+                cur = build_norm(ffn_inp,
+                        model.layers[il].ffn_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(cur, "ffn_norm", il);
+
+                ggml_tensor * moe_out = build_moe_ffn(cur,
+                        model.layers[il].ffn_gate_inp,
+                        model.layers[il].ffn_up_exps,
+                        model.layers[il].ffn_gate_exps,
+                        model.layers[il].ffn_down_exps,
+                        nullptr,
+                        n_expert, n_expert_used,
+                        LLM_FFN_SILU, true,
+                        false, 0.0,
+                        LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
+                        il);
+                cb(moe_out, "ffn_moe_out", il);
+
+                // For Granite MoE Shared
+                if (hparams.n_ff_shexp > 0) {
+                    ggml_tensor * ffn_shexp = build_ffn(cur,
+                        model.layers[il].ffn_up_shexp,   NULL, NULL,
+                        model.layers[il].ffn_gate_shexp, NULL, NULL,
+                        model.layers[il].ffn_down_shexp, NULL, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, il);
+                    cb(ffn_shexp, "ffn_shexp", il);
+
+                    cur = ggml_add(ctx0, moe_out, ffn_shexp);
+                    cb(cur, "ffn_out", il);
+                } else {
+                    cur = moe_out;
+                }
+            }
+
+            // For Granite architectures - scale residual
+            cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cb(cur, "ffn_out", il);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        // For Granite architectures - scale logits
+        cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale);
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
 // ref: https://github.com/facebookresearch/chameleon
 // based on the original build_llama() function, changes:
 // * qk-norm
@@ -12921,8 +13092,6 @@ llm_graph_result_ptr llama_model::build_graph(
         case LLM_ARCH_LLAMA:
         case LLM_ARCH_LLAMA4:
         case LLM_ARCH_MINICPM:
-        case LLM_ARCH_GRANITE:
-        case LLM_ARCH_GRANITE_MOE:
            {
                llm = std::make_unique(*this, params, gf);
            } break;
@@ -13153,6 +13322,11 @@ llm_graph_result_ptr llama_model::build_graph(
            {
                llm = std::make_unique(*this, params, gf);
            } break;
+        case LLM_ARCH_GRANITE:
+        case LLM_ARCH_GRANITE_MOE:
+            {
+                llm = std::make_unique(*this, params, gf);
+            } break;
        case LLM_ARCH_CHAMELEON:
            {
                llm = std::make_unique(*this, params, gf);
diff --git a/examples/talk-llama/llama-quant.cpp b/examples/talk-llama/llama-quant.cpp
index 820d5128e29..159b1307a4c 100644
--- a/examples/talk-llama/llama-quant.cpp
+++ b/examples/talk-llama/llama-quant.cpp
@@ -14,6 +14,12 @@
 #include
 #include
 
+// Quantization types. Changes to this struct must be replicated in quantize.cpp
+struct tensor_quantization {
+    std::string name;
+    ggml_type quant = GGML_TYPE_COUNT;
+};
+
 static void zeros(std::ofstream & file, size_t n) {
     char zero = 0;
     for (size_t i = 0; i < n; ++i) {
@@ -48,12 +54,6 @@ struct quantize_state_impl {
     {}
 };
 
-// changes to this struct must be replicated in quantize.cpp
-struct tensor_quantization {
-    std::string name;
-    ggml_type quant = GGML_TYPE_COUNT;
-};
-
 static void llama_tensor_dequantize_impl(
     ggml_tensor * tensor, std::vector> & output, std::vector & workers,
     const size_t nelements, const int nthread
@@ -796,17 +796,19 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
         // unless the user specifies a type
         if (params->tensor_types) {
             const std::vector & tensor_types = *static_cast *>(params->tensor_types);
+            const std::string tensor_name(tensor->name);
             for (const auto & [tname, qtype] : tensor_types) {
-                if (std::regex pattern(tname); std::regex_search(tensor->name, pattern)) {
-                    if (qtype != new_type) {
-                        LLAMA_LOG_DEBUG("(overriding %s -> %s), ", ggml_type_name(new_type), ggml_type_name(qtype));
+                if (std::regex pattern(tname); std::regex_search(tensor_name, pattern)) {
+                    if (qtype != new_type) {
+                        LLAMA_LOG_DEBUG("(overriding %s) ", ggml_type_name(new_type));
+                        new_type = qtype;
+                        break; // if two or more types are specified for the tensor, first match wins
                     }
-                    new_type = qtype;
-                    break;
                 }
             }
         }
    }
+
    if (params->token_embedding_type < GGML_TYPE_COUNT && strcmp(tensor->name, "token_embd.weight") == 0) {
        new_type = params->token_embedding_type;
    }
diff --git a/examples/talk-llama/llama.cpp b/examples/talk-llama/llama.cpp
index 9fdddf7b071..2f06e0f8ce1 100644
--- a/examples/talk-llama/llama.cpp
+++ b/examples/talk-llama/llama.cpp
@@ -140,6 +140,11 @@ static struct llama_model * llama_model_load_from_file_impl(
         struct llama_model_params params) {
     ggml_time_init();
 
+    if (!params.vocab_only && ggml_backend_reg_count() == 0) {
+        LLAMA_LOG_ERROR("%s: no backends are loaded. hint: use ggml_backend_load() or ggml_backend_load_all() to load a backend before calling this function\n", __func__);
+        return nullptr;
+    }
+
     unsigned cur_percentage = 0;
     if (params.progress_callback == NULL) {
         params.progress_callback_user_data = &cur_percentage;