From 0c4a2291549392d8258d7ee52e1aa43e98274243 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alberto=20Cabrera=20P=C3=A9rez?= Date: Thu, 8 May 2025 10:08:01 +0100 Subject: [PATCH 01/21] sycl: addressing non-contiguous src1 mul_mats (nc and batched) (llama/13343) * sycl: fixed non-contiguous src1 mul_mats (nc and batched) * Fixed wrong static_cast inside kernel --- ggml/src/ggml-sycl/common.hpp | 17 ++-- ggml/src/ggml-sycl/convert.cpp | 66 ++++++++----- ggml/src/ggml-sycl/convert.hpp | 21 ++-- ggml/src/ggml-sycl/ggml-sycl.cpp | 159 +++++++++++++++---------------- 4 files changed, 138 insertions(+), 125 deletions(-) diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp index c71cc89c09e..69aad938e88 100644 --- a/ggml/src/ggml-sycl/common.hpp +++ b/ggml/src/ggml-sycl/common.hpp @@ -114,17 +114,12 @@ static void crash() { GGML_ABORT("SYCL error"); } -#define SYCL_CHECK(err) \ - do { \ - auto err_ = (err); \ - if (err_ != 0) \ - ggml_sycl_error( \ - #err, \ - __func__, \ - __FILE__, \ - __LINE__, \ - "Meet error in this line code!"); \ - } while (0) +#define SYCL_CHECK(err) \ + do { \ + auto err_ = (err); \ + if (err_ != 0) \ + ggml_sycl_error(#err, __func__, __FILE__, __LINE__, "Exception caught in this line of code."); \ + } while (0) #if DPCT_COMPAT_RT_VERSION >= 11100 #define GGML_SYCL_ASSUME(x) __builtin_assume(x) diff --git a/ggml/src/ggml-sycl/convert.cpp b/ggml/src/ggml-sycl/convert.cpp index 76ac6a4dd1f..b2f8a656933 100644 --- a/ggml/src/ggml-sycl/convert.cpp +++ b/ggml/src/ggml-sycl/convert.cpp @@ -437,41 +437,52 @@ static void dequantize_row_iq4_nl_sycl(const void *vx, dst_t *y, const int64_t k } template -static void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, - const sycl::nd_item<3> &item_ct1) { +static void convert_unary_nc(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01, + const int64_t ne02, const int64_t s01, const int64_t s02, const int64_t s03, + const sycl::nd_item<3> & item_ct1) { + const int64_t work_group_size = item_ct1.get_local_range(2); - const int64_t global_id = item_ct1.get_local_id(2) + work_group_size * item_ct1.get_group(2); + const int64_t global_id = item_ct1.get_local_id(2) + work_group_size * item_ct1.get_group(2); + + const int64_t i01 = item_ct1.get_group(1); + const int64_t i02 = item_ct1.get_group(0) % ne02; + const int64_t i03 = item_ct1.get_group(0) / ne02; // make each work-item deal with more elements since sycl global range can not exceed max int - const src_t * x = (const src_t *) vx; - for (int64_t i = global_id; i < k; i += work_group_size * item_ct1.get_group_range(2)) { - y[i] = x[i]; + const src_t * x = static_cast(vx); + const int64_t ix = i03 * s03 + i02 * s02 + i01 * s01; + const int64_t iy = ((i03 * ne02 + i02) * ne01 + i01) * ne00; + +#pragma unroll + for (int64_t i00 = global_id; i00 < ne00; i00 += work_group_size * item_ct1.get_group_range(2)) { + y[iy + i00] = static_cast(x[ix + i00]); } } template -static void convert_unary_sycl(const void *__restrict__ vx, - dst_t *__restrict__ y, const int64_t k, - dpct::queue_ptr stream) { - const int64_t num_blocks = (k + SYCL_DEQUANTIZE_BLOCK_SIZE - 1) / SYCL_DEQUANTIZE_BLOCK_SIZE; +static void convert_unary_nc_sycl(const void * __restrict__ vx, dst_t * __restrict__ y, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, + const int64_t s01, const int64_t s02, const int64_t s03, dpct::queue_ptr queue) { + dpct::has_capability_or_fail(queue->get_device(), { sycl::aspect::fp16 }); + + sycl::range<3> global_size(ne02 * ne03, ne01, ceil_div(ne00, SYCL_DEQUANTIZE_BLOCK_SIZE)); // decrease global range when it exceeds the max int - int64_t local_size = downsample_sycl_global_range(num_blocks, SYCL_DEQUANTIZE_BLOCK_SIZE); - sycl::range<3> block_nums(1, 1, num_blocks); - sycl::range<3> local_range(1, 1, local_size); - { - dpct::has_capability_or_fail(stream->get_device(), - {sycl::aspect::fp16}); + // TODO: Downsample logic is separated from the kernel, a rewrite is desirable + int64_t downsized_workgroup = downsample_sycl_global_range(global_size[0], SYCL_DEQUANTIZE_BLOCK_SIZE); + sycl::range<3> workgroup_size(1, 1, downsized_workgroup); - stream->parallel_for( - sycl::nd_range<3>(block_nums * local_range, local_range), - [=](sycl::nd_item<3> item_ct1) { - convert_unary(vx, y, k, item_ct1); - }); - } + queue->parallel_for(sycl::nd_range<3>(global_size * workgroup_size, workgroup_size), [=](sycl::nd_item<3> item_ct1) { + convert_unary_nc(vx, y, ne00, ne01, ne02, s01, s02, s03, item_ct1); + }); } -to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor *dst) { +template +static void convert_unary_sycl(const void * vx, dst_t * y, const int64_t k, dpct::queue_ptr queue) { + convert_unary_nc_sycl(vx, y, k, 1, 1, 1, k, k, k, queue); +} + +to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst) { switch (type) { case GGML_TYPE_Q4_0: if (dst->src[0]->extra && @@ -574,3 +585,12 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) { return nullptr; } } + +to_fp16_nc_sycl_t get_to_fp16_nc_sycl(ggml_type type) { + switch (type) { + case GGML_TYPE_F32: + return convert_unary_nc_sycl; + default: + return nullptr; + } +} diff --git a/ggml/src/ggml-sycl/convert.hpp b/ggml/src/ggml-sycl/convert.hpp index 355dae22b40..f8cb573e368 100644 --- a/ggml/src/ggml-sycl/convert.hpp +++ b/ggml/src/ggml-sycl/convert.hpp @@ -1,6 +1,6 @@ // // MIT license -// Copyright (C) 2024 Intel Corporation +// Copyright (C) 2025 Intel Corporation // SPDX-License-Identifier: MIT // @@ -16,12 +16,19 @@ #include "common.hpp" template -using to_t_sycl_t = void (*)(const void *__restrict__ x, T *__restrict__ y, - int64_t k, dpct::queue_ptr stream); -typedef to_t_sycl_t to_fp32_sycl_t; +using to_t_sycl_t = void (*)(const void * __restrict__ x, T * __restrict__ y, int64_t k, dpct::queue_ptr stream); +typedef to_t_sycl_t to_fp32_sycl_t; typedef to_t_sycl_t to_fp16_sycl_t; -to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor *dst); -to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst); +to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst); +to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor * dst); -#endif // GGML_SYCL_CONVERT_HPP +// Nc = Non-contiguous +template +using to_t_nc_sycl_t = void (*)(const void * x, T * y, int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03, + int64_t s01, int64_t s02, int64_t s03, dpct::queue_ptr queue); + +typedef to_t_nc_sycl_t to_fp16_nc_sycl_t; +to_fp16_nc_sycl_t get_to_fp16_nc_sycl(ggml_type type); + +#endif // GGML_SYCL_CONVERT_HPP diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index ea5d10f40ee..68a26fa481d 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -2694,35 +2694,31 @@ catch (sycl::exception const &exc) { std::exit(1); } -static void k_compute_batched_ptrs(const sycl::half *src0_as_f16, - const sycl::half *src1_as_f16, char *dst, - const void **ptrs_src, void **ptrs_dst, - int64_t ne12, int64_t ne13, int64_t ne23, - size_t nb02, size_t nb03, size_t nb12, - size_t nb13, size_t nbd2, size_t nbd3, - int64_t r2, int64_t r3, - const sycl::nd_item<3> &item_ct1) { - int64_t i13 = item_ct1.get_group(2) * item_ct1.get_local_range(2) + - item_ct1.get_local_id(2); - int64_t i12 = item_ct1.get_group(1) * item_ct1.get_local_range(1) + - item_ct1.get_local_id(1); +static void k_compute_batched_ptrs(const sycl::half * src0_as_f16, const sycl::half * src1_as_f16, char * dst, + const void ** ptrs_src, void ** ptrs_dst, int64_t ne12, int64_t ne13, int64_t ne23, + size_t nb02, size_t nb03, size_t nb12, size_t nb13, size_t nbd2, size_t nbd3, + int64_t r2, int64_t r3, const sycl::nd_item<3> & item_ct1) { + const int64_t i13 = item_ct1.get_group(2) * item_ct1.get_local_range(2) + item_ct1.get_local_id(2); + const int64_t i12 = item_ct1.get_group(1) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1); if (i13 >= ne13 || i12 >= ne12) { return; } - int64_t i03 = i13 / r3; - int64_t i02 = i12 / r2; + const int64_t i03 = i13 / r3; + const int64_t i02 = i12 / r2; + + const uint8_t * src0_bytes = reinterpret_cast(src0_as_f16); + const uint8_t * src1_bytes = reinterpret_cast(src1_as_f16); + uint8_t * dst_bytes = reinterpret_cast(dst); - ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_as_f16 + i02*nb02 + i03*nb03; - ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_as_f16 + i12*nb12 + i13*nb13; - ptrs_dst[0*ne23 + i12 + i13*ne12] = ( char *) dst + i12*nbd2 + i13*nbd3; + ptrs_src[0 * ne23 + i12 + i13 * ne12] = src0_bytes + i02 * nb02 + i03 * nb03; + ptrs_src[1 * ne23 + i12 + i13 * ne12] = src1_bytes + i12 * nb12 + i13 * nb13; + ptrs_dst[0 * ne23 + i12 + i13 * ne12] = dst_bytes + i12 * nbd2 + i13 * nbd3; } -static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, - const ggml_tensor *src0, - const ggml_tensor *src1, - ggml_tensor *dst) try { +static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, + const ggml_tensor * src1, ggml_tensor * dst) try { GGML_ASSERT(!ggml_is_transposed(src0)); GGML_ASSERT(!ggml_is_transposed(src1)); GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer)); @@ -2730,102 +2726,100 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, GGML_TENSOR_BINARY_OP_LOCALS + // TODO: see https://github.com/ggml-org/llama.cpp/pull/13155 + // Batched mul_mat requires a rewrite to support both oneDNN and non-contiguous dst + GGML_ASSERT(ggml_is_contiguous(dst)); SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - queue_ptr main_stream = ctx.stream();; + queue_ptr queue = ctx.stream(); - void * src0_ddq = src0->data; - sycl::half *src0_as_f16 = (sycl::half *)src0_ddq; - float * src1_ddf = (float *) src1->data; - float * dst_ddf = (float *) dst->data; + dpct::has_capability_or_fail(queue->get_device(), { sycl::aspect::fp16 }); - // convert src1 to fp16 + const sycl::half * src0_f16 = static_cast(src0->data); + float * dst_ddf = static_cast(dst->data); + + const sycl::half * src1_f16 = static_cast(src1->data); + const size_t type_size_src1 = ggml_type_size(src1->type); + GGML_ASSERT(nb10 == type_size_src1); + + // SRC1 strides + int64_t s11 = nb11 / type_size_src1; + int64_t s12 = nb12 / type_size_src1; + int64_t s13 = nb13 / type_size_src1; ggml_sycl_pool_alloc src1_f16_alloc(ctx.pool()); + + // convert src1 to fp16 if (src1->type != GGML_TYPE_F16) { - const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type, dst); + const to_fp16_nc_sycl_t to_fp16_nc_sycl = get_to_fp16_nc_sycl(src1->type); + GGML_ASSERT(to_fp16_nc_sycl != nullptr); const int64_t ne_src1 = ggml_nelements(src1); src1_f16_alloc.alloc(ne_src1); - GGML_ASSERT(to_fp16_sycl != nullptr); - to_fp16_sycl(src1_ddf, src1_f16_alloc.get(), ne_src1, main_stream); + to_fp16_nc_sycl(src1_f16, src1_f16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, queue); + + src1_f16 = src1_f16_alloc.get(); + s11 = ne10; + s12 = ne11 * s11; + s13 = ne12 * s12; } - sycl::half *src1_f16 = src1->type == GGML_TYPE_F16 ? (sycl::half *)src1_ddf - : src1_f16_alloc.get(); - char * dst_t; + ggml_sycl_pool_alloc dst_f16(ctx.pool()); + char * dst_t = reinterpret_cast(dst_ddf); - dpct::library_data_t cu_compute_type = dpct::library_data_t::real_float; - dpct::library_data_t cu_data_type = dpct::library_data_t::real_float; + dpct::library_data_t mkl_compute_type = dpct::library_data_t::real_float; + dpct::library_data_t mkl_data_type = dpct::library_data_t::real_float; // dst strides size_t nbd2 = dst->nb[2]; size_t nbd3 = dst->nb[3]; const float alpha_f32 = 1.0f; - const float beta_f32 = 0.0f; + const float beta_f32 = 0.0f; const void * alpha = &alpha_f32; const void * beta = &beta_f32; - dst_t = (char *) dst_ddf; - GGML_ASSERT(ne12 % ne02 == 0); GGML_ASSERT(ne13 % ne03 == 0); // broadcast factors - const int64_t r2 = ne12/ne02; - const int64_t r3 = ne13/ne03; + const int64_t r2 = ne12 / ne02; + const int64_t r3 = ne13 / ne03; if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) { // there is no broadcast and src0, src1 are contiguous across dims 2, 3 - SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch( - *main_stream, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha, - (const char *) src0_as_f16, dpct::library_data_t::real_half, nb01 / nb00, nb02 / nb00, - (const char *) src1_f16, dpct::library_data_t::real_half, nb11 / nb10, nb12 / nb10, beta, (char *) dst_t, - cu_data_type, ne01, nb2 / nb0, ne12 * ne13, cu_compute_type))); + SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(*queue, oneapi::math::transpose::trans, + oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha, + src0_f16, dpct::library_data_t::real_half, nb01 / nb00, nb02 / nb00, + src1_f16, dpct::library_data_t::real_half, s11, s12, beta, dst_t, + mkl_data_type, ne0, ne1 * ne0, ne12 * ne13, mkl_compute_type))); } else { - const int ne23 = ne12*ne13; + const int ne23 = ne12 * ne13; - ggml_sycl_pool_alloc ptrs_src(ctx.pool(), 2*ne23); - ggml_sycl_pool_alloc< void *> ptrs_dst(ctx.pool(), 1*ne23); + ggml_sycl_pool_alloc ptrs_src(ctx.pool(), 2 * ne23); + ggml_sycl_pool_alloc ptrs_dst(ctx.pool(), 1 * ne23); ggml_sycl_pool_alloc> matrix_info(ctx.host_pool(), 1); sycl::range<3> block_dims(1, ne12, ne13); - /* - DPCT1049:47: The work-group size passed to the SYCL kernel may exceed - the limit. To get the device limit, query - info::device::max_work_group_size. Adjust the work-group size if needed. - */ - { - dpct::has_capability_or_fail(main_stream->get_device(), - {sycl::aspect::fp16}); - - main_stream->submit([&](sycl::handler &cgh) { - const void **ptrs_src_get = ptrs_src.get(); - void **ptrs_dst_get = ptrs_dst.get(); - size_t nb12_scaled = src1->type == GGML_TYPE_F16 ? nb12 : nb12 / 2; - size_t nb13_scaled = src1->type == GGML_TYPE_F16 ? nb13 : nb13 / 2; - cgh.parallel_for(sycl::nd_range<3>(block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) { - k_compute_batched_ptrs( - src0_as_f16, src1_f16, - dst_t, ptrs_src_get, - ptrs_dst_get, ne12, ne13, ne23, - nb02, nb03, nb12_scaled, nb13_scaled, - nbd2, nbd3, r2, r3, item_ct1); - }); + queue->submit([&](sycl::handler & cgh) { + const void ** ptrs_src_get = ptrs_src.get(); + void ** ptrs_dst_get = ptrs_dst.get(); + size_t nb12_scaled = src1->type == GGML_TYPE_F16 ? nb12 : s12 * sizeof(sycl::half); + size_t nb13_scaled = src1->type == GGML_TYPE_F16 ? nb13 : s13 * sizeof(sycl::half); + cgh.parallel_for(sycl::nd_range<3>(block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { + k_compute_batched_ptrs(src0_f16, src1_f16, dst_t, ptrs_src_get, ptrs_dst_get, ne12, ne13, ne23, nb02, + nb03, nb12_scaled, nb13_scaled, nbd2, nbd3, r2, r3, item_ct1); }); - } + }); + SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch( - *main_stream, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha, + *queue, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha, (const void **) (ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00, - (const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, nb11 / nb10, beta, - (void **) (ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23, cu_compute_type, matrix_info.get()))); + (const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, s11, beta, + (void **) (ptrs_dst.get() + 0 * ne23), mkl_data_type, ne0, ne23, mkl_compute_type, matrix_info.get()))); } -} -catch (sycl::exception const &exc) { - std::cerr << exc.what() << "Exception caught at file:" << __FILE__ - << ", line:" << __LINE__ << std::endl; - std::exit(1); +} catch (const sycl::exception & exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl; + std::exit(1); } inline bool ggml_sycl_supports_mmq(enum ggml_type type) { @@ -2966,7 +2960,7 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor // The kernel from the if path is faster for that specific case, but does not support all mul mats. ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst); } - } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) { + } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && !ggml_is_transposed(src1) && src1->ne[1] == 1) { // KQV single-batch ggml_sycl_mul_mat_vec_nc(ctx, src0, src1, dst); } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) { @@ -3873,9 +3867,6 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g if (a->ne[3] != b->ne[3]) { return false; } - if (!ggml_is_contiguous(b)) { - return false; - } ggml_type a_type = a->type; if (a_type == GGML_TYPE_IQ4_NL || a_type == GGML_TYPE_IQ4_XS || a_type == GGML_TYPE_IQ3_XXS || a_type == GGML_TYPE_IQ3_S || From 19d8d9a928118c1b34fd49d73eaf1fc0c2613e4f Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Fri, 9 May 2025 02:23:41 -0500 Subject: [PATCH 02/21] vulkan: Allow up to 4096 elements for mul_mat_id row_ids (llama/13326) This assert fired running Qwen_Qwen3-30B-A3B-Q2_K.gguf: GGML_ASSERT(nei0 * nei1 <= 3072); The tensor is 8 x 512. Increase this array size to accommodate. --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 4 ++-- ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp | 2 +- ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp | 2 +- ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index c61a8cf0af4..2dc2883a70c 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -1632,7 +1632,7 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec const uint32_t warps = warptile[0] / warptile[10]; const uint32_t load_bufs = (warptile[1] + warptile[2]) * (warptile[3] + bank_conflict_offset) * type_size; - const uint32_t mmid_row_ids = mul_mat_id ? 3072 * sizeof(uint32_t) : 0; + const uint32_t mmid_row_ids = mul_mat_id ? 4096 * sizeof(uint32_t) : 0; const uint32_t coopmat_stage = device->coopmat_support ? warptile[7] * warptile[8] / warps * sizeof(float) : 0; const uint32_t total_size = load_bufs + mmid_row_ids + coopmat_stage + lut_size; @@ -5260,7 +5260,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& const uint64_t nei0 = ids->ne[0]; const uint64_t nei1 = ids->ne[1]; - GGML_ASSERT(nei0 * nei1 <= 3072); + GGML_ASSERT(nei0 * nei1 <= 4096); const uint32_t nbi1 = ids->nb[1]; const uint32_t nbi2 = ids->nb[2]; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp index 529ac4d44fe..7859a1a60e2 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp @@ -103,7 +103,7 @@ shared FLOAT_TYPE buf_a[BM * SHMEM_STRIDE]; shared FLOAT_TYPE buf_b[BN * SHMEM_STRIDE]; #ifdef MUL_MAT_ID -shared u16vec2 row_ids[3072]; +shared u16vec2 row_ids[4096]; #endif // MUL_MAT_ID #define NUM_WARPS (BLOCK_SIZE / WARP) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp index 344b466101b..91846575732 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp @@ -92,7 +92,7 @@ layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; #ifdef MUL_MAT_ID layout (binding = 3) readonly buffer IDS {int data_ids[];}; -shared u16vec4 row_ids[3072]; +shared u16vec4 row_ids[4096]; layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufB { B_TYPE b[]; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp index 284a35caa68..83de90eb7e0 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp @@ -101,7 +101,7 @@ shared FLOAT_TYPE_VEC2 buf_b_ds[BN]; #define LOAD_VEC_B 4 #ifdef MUL_MAT_ID -shared u16vec2 row_ids[3072]; +shared u16vec2 row_ids[4096]; #endif // MUL_MAT_ID #define NUM_WARPS (BLOCK_SIZE / WARP) From 00c805671581dbe6eb487a6f3c90a27682ad2210 Mon Sep 17 00:00:00 2001 From: Radoslav Gerganov Date: Fri, 9 May 2025 10:31:07 +0300 Subject: [PATCH 03/21] rpc : add rpc_msg_set_tensor_hash_req (llama/13353) * rpc : add rpc_msg_set_tensor_hash_req Use a dedicated struct for the request of RPC_CMD_SET_TENSOR_HASH which makes the code cleaner. * fix --- ggml/src/ggml-rpc/ggml-rpc.cpp | 52 ++++++++++++++++------------------ 1 file changed, 25 insertions(+), 27 deletions(-) diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp index 039214dc2bc..4f0abb5a60f 100644 --- a/ggml/src/ggml-rpc/ggml-rpc.cpp +++ b/ggml/src/ggml-rpc/ggml-rpc.cpp @@ -151,6 +151,12 @@ struct rpc_msg_buffer_clear_req { uint8_t value; }; +struct rpc_msg_set_tensor_hash_req { + rpc_tensor tensor; + uint64_t offset; + uint64_t hash; +}; + struct rpc_msg_set_tensor_hash_rsp { uint8_t result; }; @@ -548,15 +554,12 @@ static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t buffer, ggm ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; rpc_tensor rpc_tensor = serialize_tensor(tensor); if (size > HASH_THRESHOLD) { - // input serialization format: | rpc_tensor | offset (8 bytes) | hash (8 bytes) - size_t input_size = sizeof(rpc_tensor) + sizeof(uint64_t) + sizeof(uint64_t); - std::vector input(input_size, 0); - uint64_t hash = fnv_hash((const uint8_t*)data, size); - memcpy(input.data(), &rpc_tensor, sizeof(rpc_tensor)); - memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset)); - memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), &hash, sizeof(hash)); + rpc_msg_set_tensor_hash_req request; + request.tensor = rpc_tensor; + request.offset = offset; + request.hash = fnv_hash((const uint8_t*)data, size); rpc_msg_set_tensor_hash_rsp response; - bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR_HASH, input.data(), input.size(), &response, sizeof(response)); + bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR_HASH, &request, sizeof(request), &response, sizeof(response)); GGML_ASSERT(status); if (response.result) { // the server has the same data, no need to send it @@ -864,7 +867,7 @@ class rpc_server { bool free_buffer(const rpc_msg_free_buffer_req & request); bool buffer_clear(const rpc_msg_buffer_clear_req & request); bool set_tensor(const std::vector & input); - bool set_tensor_hash(const std::vector & input, rpc_msg_set_tensor_hash_rsp & response); + bool set_tensor_hash(const rpc_msg_set_tensor_hash_req & request, rpc_msg_set_tensor_hash_rsp & response); bool get_tensor(const rpc_msg_get_tensor_req & request, std::vector & response); bool copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response); bool graph_compute(const std::vector & input, rpc_msg_graph_compute_rsp & response); @@ -1101,18 +1104,10 @@ bool rpc_server::get_cached_file(uint64_t hash, std::vector & data) { return true; } -bool rpc_server::set_tensor_hash(const std::vector & input, rpc_msg_set_tensor_hash_rsp & response) +bool rpc_server::set_tensor_hash(const rpc_msg_set_tensor_hash_req & request, rpc_msg_set_tensor_hash_rsp & response) { - // serialization format: | rpc_tensor | offset (8 bytes) | hash (8 bytes) | - if (input.size() != sizeof(rpc_tensor) + 16) { - return false; - } - const rpc_tensor * in_tensor = (const rpc_tensor *)input.data(); - uint64_t offset; - memcpy(&offset, input.data() + sizeof(rpc_tensor), sizeof(offset)); - const uint64_t * hash = (const uint64_t *)(input.data() + sizeof(rpc_tensor) + sizeof(offset)); std::vector cached_file; - if (!get_cached_file(*hash, cached_file)) { + if (!get_cached_file(request.hash, cached_file)) { response.result = 0; return true; } @@ -1125,25 +1120,28 @@ bool rpc_server::set_tensor_hash(const std::vector & input, rpc_msg_set ggml_context_ptr ctx_ptr { ggml_init(params) }; GGML_ASSERT(ctx_ptr != nullptr); ggml_context * ctx = ctx_ptr.get(); - ggml_tensor * tensor = deserialize_tensor(ctx, in_tensor); + ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor); if (tensor == nullptr) { GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__); return false; } - GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %zu, hash: %" PRIx64 "\n", __func__, (void*)tensor->buffer, tensor->data, offset, size, *hash); + GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %zu, hash: %" PRIx64 "\n", + __func__, (void*)tensor->buffer, tensor->data, request.offset, size, request.hash); // sanitize tensor->data { const size_t p0 = (size_t) ggml_backend_buffer_get_base(tensor->buffer); const size_t p1 = p0 + ggml_backend_buffer_get_size(tensor->buffer); - if (in_tensor->data + offset < p0 || in_tensor->data + offset >= p1 || size > (p1 - in_tensor->data - offset)) { + if (request.tensor.data + request.offset < p0 + || request.tensor.data + request.offset >= p1 + || size > (p1 - request.tensor.data - request.offset)) { GGML_LOG_ERROR("[%s] tensor data region (data=0x%" PRIx64 ", offset=%" PRIu64 ", size=%zu, hash=0x%" PRIx64 ") out of buffer bounds [0x%zx, 0x%zx)\n", - __func__, in_tensor->data, offset, size, *hash, p0, p1); + __func__, request.tensor.data, request.offset, size, request.hash, p0, p1); return false; } } - ggml_backend_tensor_set(tensor, cached_file.data(), offset, size); + ggml_backend_tensor_set(tensor, cached_file.data(), request.offset, size); response.result = 1; return true; } @@ -1503,12 +1501,12 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir, break; } case RPC_CMD_SET_TENSOR_HASH: { - std::vector input; - if (!recv_msg(sockfd, input)) { + rpc_msg_set_tensor_hash_req request; + if (!recv_msg(sockfd, &request, sizeof(request))) { return; } rpc_msg_set_tensor_hash_rsp response; - if (!server.set_tensor_hash(input, response)) { + if (!server.set_tensor_hash(request, response)) { return; } if (!send_msg(sockfd, &response, sizeof(response))) { From f8c75dc43e10f4c7d747cf907bc448f03511edcf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Fri, 9 May 2025 12:14:04 +0200 Subject: [PATCH 04/21] CUDA: fix crash on large batch size for MoE models (llama/13384) --- ggml/src/ggml-cuda/getrows.cu | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/ggml/src/ggml-cuda/getrows.cu b/ggml/src/ggml-cuda/getrows.cu index ea8bf691609..963e4d03dd7 100644 --- a/ggml/src/ggml-cuda/getrows.cu +++ b/ggml/src/ggml-cuda/getrows.cu @@ -10,10 +10,11 @@ static __global__ void k_get_rows( /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03, const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) { - const int i00 = (blockIdx.x*blockDim.x + threadIdx.x)*2; - const int i10 = blockDim.y*blockIdx.y + threadIdx.y; - const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12; - const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12; + // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher. + const int i00 = (blockIdx.y * blockDim.x + threadIdx.x)*2; + const int i10 = blockIdx.x; + const int i11 = blockIdx.z / ne12; + const int i12 = blockIdx.z % ne12; if (i00 >= ne00) { return; @@ -46,10 +47,11 @@ static __global__ void k_get_rows_float( /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03, const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) { - const int i00 = blockIdx.x*blockDim.x + threadIdx.x; - const int i10 = blockDim.y*blockIdx.y + threadIdx.y; - const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12; - const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12; + // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher. + const int i00 = blockIdx.y * blockDim.x + threadIdx.x; + const int i10 = blockIdx.x; + const int i11 = blockIdx.z / ne12; + const int i12 = blockIdx.z % ne12; if (i00 >= ne00) { return; @@ -94,8 +96,8 @@ static void get_rows_cuda_q( const size_t nb1, const size_t nb2, const size_t nb3, cudaStream_t stream) { const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1); - const int block_num_x = (ne00 + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE); - const dim3 block_nums(block_num_x, ne10, ne11*ne12); + const int block_num_y = (ne00 + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE); + const dim3 block_nums(ne10, block_num_y, ne11*ne12); // strides in elements // const size_t s0 = nb0 / sizeof(dst_t); @@ -127,8 +129,8 @@ static void get_rows_cuda_float( const size_t nb1, const size_t nb2, const size_t nb3, cudaStream_t stream) { const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1); - const int block_num_x = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE; - const dim3 block_nums(block_num_x, ne10, ne11*ne12); + const int block_num_y = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE; + const dim3 block_nums(ne10, block_num_y, ne11*ne12); // strides in elements // const size_t s0 = nb0 / sizeof(dst_t); From aef59f485107f1f7461757e740ec352618f293d4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Fri, 9 May 2025 13:34:58 +0200 Subject: [PATCH 05/21] CUDA: FA support for Deepseek (Ampere or newer) (llama/13306) * CUDA: FA support for Deepseek (Ampere or newer) * do loop unrolling via C++ template --- ggml/src/ggml-cuda/CMakeLists.txt | 2 +- ggml/src/ggml-cuda/common.cuh | 19 + ggml/src/ggml-cuda/cp-async.cuh | 11 + ggml/src/ggml-cuda/fattn-common.cuh | 26 +- ggml/src/ggml-cuda/fattn-mma-f16.cuh | 908 +++++++++++------- ggml/src/ggml-cuda/fattn-tile-f16.cu | 4 +- ggml/src/ggml-cuda/fattn-tile-f32.cu | 4 +- ggml/src/ggml-cuda/fattn-vec-f16.cuh | 2 +- ggml/src/ggml-cuda/fattn-vec-f32.cuh | 2 +- ggml/src/ggml-cuda/fattn-wmma-f16.cu | 2 +- ggml/src/ggml-cuda/fattn.cu | 116 ++- ggml/src/ggml-cuda/ggml-cuda.cu | 12 +- ...ttn-mma-f16-instance-ncols1_1-ncols2_16.cu | 5 + ...attn-mma-f16-instance-ncols1_1-ncols2_8.cu | 12 +- ...ttn-mma-f16-instance-ncols1_16-ncols2_1.cu | 12 +- ...ttn-mma-f16-instance-ncols1_16-ncols2_2.cu | 12 +- ...ttn-mma-f16-instance-ncols1_16-ncols2_4.cu | 12 +- ...ttn-mma-f16-instance-ncols1_2-ncols2_16.cu | 5 + ...attn-mma-f16-instance-ncols1_2-ncols2_4.cu | 12 +- ...attn-mma-f16-instance-ncols1_2-ncols2_8.cu | 12 +- ...ttn-mma-f16-instance-ncols1_32-ncols2_1.cu | 12 +- ...ttn-mma-f16-instance-ncols1_32-ncols2_2.cu | 12 +- ...ttn-mma-f16-instance-ncols1_4-ncols2_16.cu | 5 + ...attn-mma-f16-instance-ncols1_4-ncols2_2.cu | 12 +- ...attn-mma-f16-instance-ncols1_4-ncols2_4.cu | 12 +- ...attn-mma-f16-instance-ncols1_4-ncols2_8.cu | 12 +- ...ttn-mma-f16-instance-ncols1_64-ncols2_1.cu | 12 +- ...attn-mma-f16-instance-ncols1_8-ncols2_1.cu | 12 +- ...attn-mma-f16-instance-ncols1_8-ncols2_2.cu | 12 +- ...attn-mma-f16-instance-ncols1_8-ncols2_4.cu | 12 +- ...attn-mma-f16-instance-ncols1_8-ncols2_8.cu | 12 +- .../template-instances/generate_cu_files.py | 21 +- 32 files changed, 815 insertions(+), 521 deletions(-) create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu diff --git a/ggml/src/ggml-cuda/CMakeLists.txt b/ggml/src/ggml-cuda/CMakeLists.txt index 969a178f6c3..c9ff4aa321b 100644 --- a/ggml/src/ggml-cuda/CMakeLists.txt +++ b/ggml/src/ggml-cuda/CMakeLists.txt @@ -118,7 +118,7 @@ if (CUDAToolkit_FOUND) set(CUDA_CXX_FLAGS "") - set(CUDA_FLAGS -use_fast_math) + set(CUDA_FLAGS -use_fast_math -extended-lambda) if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "12.8") # Options are: diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 919217dfaee..64fb4ff4cec 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -296,6 +296,25 @@ static __device__ void no_device_code( #define NO_DEVICE_CODE //GGML_ABORT("NO_DEVICE_CODE not valid in host code.") #endif // __CUDA_ARCH__ +// The compiler is always able to unroll loops if they contain continue expressions. +// In such cases loop unrolling can still be achieved via recursion: +template +struct ggml_cuda_unroll { + template + __device__ void operator()(const Func & f, Args... args) const { + f(n - 1, args...); + ggml_cuda_unroll{}(f, args...); + } +}; + +template <> +struct ggml_cuda_unroll<1> { + template + __device__ void operator()(const Func & f, Args... args) const { + f(0, args...); + } +}; + template static __device__ __forceinline__ int warp_reduce_sum(int x) { #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE diff --git a/ggml/src/ggml-cuda/cp-async.cuh b/ggml/src/ggml-cuda/cp-async.cuh index ecb659997ba..63d0c482ff7 100644 --- a/ggml/src/ggml-cuda/cp-async.cuh +++ b/ggml/src/ggml-cuda/cp-async.cuh @@ -2,6 +2,17 @@ #include "common.cuh" + +static __device__ __forceinline__ unsigned int ggml_cuda_cvta_generic_to_shared(void * generic_ptr) { +#ifdef CP_ASYNC_AVAILABLE + return __cvta_generic_to_shared(generic_ptr); +#else + GGML_UNUSED(generic_ptr); + NO_DEVICE_CODE; + return 0; +#endif // CP_ASYNC_AVAILABLE +} + // Copies data from global to shared memory, cg == cache global. // Both the src and dst pointers must be aligned to 16 bit. // Shared memory uses 32 bit addressing, the pointer is passed as unsigned int. diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index c7dc728821e..b7180d5955c 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -516,7 +516,7 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) { nullptr; } -template // D == head size +template // D == head size __launch_bounds__(D, 1) static __global__ void flash_attn_stream_k_fixup( float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) { @@ -665,13 +665,13 @@ static void on_no_fattn_vec_case(const int D) { fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for all combinations of q4_0, q4_1, q5_0, q5_1, q8_0, and f16.\n"); GGML_ABORT("fatal error"); } else { - fprintf(stderr, "Unsupported KV type combination for head_size 256.\n"); + fprintf(stderr, "Unsupported KV type combination for head_size %d.\n", D); fprintf(stderr, "Only f16 is supported.\n"); GGML_ABORT("fatal error"); } } -template +template void launch_fattn( ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared, const int KQ_row_granularity, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE @@ -691,7 +691,7 @@ void launch_fattn( GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16); GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) && - "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big"); + "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big"); GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding."); @@ -754,10 +754,13 @@ void launch_fattn( const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3]; const dim3 block_dim(warp_size, nwarps, 1); + int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy. + CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared)); + dim3 blocks_num; if (stream_k) { // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup. - const int max_blocks = 2*nsm; + const int max_blocks = max_blocks_per_sm*nsm; const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks; const int tiles_efficiency_percent = 100 * ntiles_total / (max_blocks*tiles_nwaves); @@ -769,14 +772,11 @@ void launch_fattn( blocks_num.y = 1; blocks_num.z = 1; - dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + D) * sizeof(float)); + dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + DV) * sizeof(float)); } else { GGML_ASSERT(K->ne[1] % KQ_row_granularity == 0); const int ntiles_KQ = K->ne[1] / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size. - int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy. - CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared)); - // parallel_blocks should be at least large enough to achieve max. occupancy for a single wave: parallel_blocks = std::max((nsm * max_blocks_per_sm) / ntiles_total, 1); @@ -853,19 +853,19 @@ void launch_fattn( if (stream_k) { if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles. - const dim3 block_dim_combine(D, 1, 1); + const dim3 block_dim_combine(DV, 1, 1); const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2}; - flash_attn_stream_k_fixup + flash_attn_stream_k_fixup <<>> ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]); } } else if (parallel_blocks > 1) { - const dim3 block_dim_combine(D, 1, 1); + const dim3 block_dim_combine(DV, 1, 1); const dim3 blocks_num_combine(Q->ne[1], 1, blocks_num.z); const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2); - flash_attn_combine_results + flash_attn_combine_results <<>> (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data, parallel_blocks); } diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index 04804a15c9d..2b6bdc30c02 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -13,104 +13,217 @@ typedef tile<16, 16, float> tile_C_KQ_16; typedef tile<16, 4, half2> tile_C_VKQ; typedef tile<16, 8, half2> tile_C_VKQ_16; -template +// Config options for specific head sizes. +// Should not affect results, only speed/register pressure/shared memory use. +// +// nbatch_fa: number of KV rows per softmax rescaling of KQ rowsums and VKQ accumulators. +// nwarps_max: maximum number of warps per CUDA block, up to 8 warps in total can run per SM (given enough shared memory). +// Q_in_reg: whether the Q values should be kept permanently in registers. +// nstages_target: targeted number of pipeline stages for cp_async (if available), 0 means synchronous data loading. +// nbatch_K2: number of K half2 values in direction of DKQ to load in parallel. +// nbatch_V2: number of V half2 values in direction of DV to load in parallel. +// nbatch_combine: number of VKQ half2 values in direction of DV to combine in parallel. + +template +struct fattn_mma_f16_config; + +template <> +struct fattn_mma_f16_config< 64, 64> { + static constexpr int nbatch_fa = 64; + static constexpr int nwarps_max = 4; + static constexpr bool Q_in_reg = true; + static constexpr int nstages_target = 2; + static constexpr int nbatch_K2 = 32; + static constexpr int nbatch_V2 = 32; + static constexpr int nbatch_combine = 32; +}; + +template <> +struct fattn_mma_f16_config< 80, 80> { + static constexpr int nbatch_fa = 64; + static constexpr int nwarps_max = 4; + static constexpr bool Q_in_reg = true; + static constexpr int nstages_target = 2; + static constexpr int nbatch_K2 = 40; + static constexpr int nbatch_V2 = 40; + static constexpr int nbatch_combine = 40; +}; + +template <> +struct fattn_mma_f16_config< 96, 96> { + static constexpr int nbatch_fa = 64; + static constexpr int nwarps_max = 4; + static constexpr bool Q_in_reg = true; + static constexpr int nstages_target = 2; + static constexpr int nbatch_K2 = 48; + static constexpr int nbatch_V2 = 48; + static constexpr int nbatch_combine = 48; +}; + +template <> +struct fattn_mma_f16_config<112, 112> { + static constexpr int nbatch_fa = 64; + static constexpr int nwarps_max = 4; + static constexpr bool Q_in_reg = true; + static constexpr int nstages_target = 2; + static constexpr int nbatch_K2 = 56; + static constexpr int nbatch_V2 = 56; + static constexpr int nbatch_combine = 56; +}; + +template <> +struct fattn_mma_f16_config<128, 128> { + static constexpr int nbatch_fa = 64; + static constexpr int nwarps_max = 4; + static constexpr bool Q_in_reg = true; + static constexpr int nstages_target = 2; + static constexpr int nbatch_K2 = 64; + static constexpr int nbatch_V2 = 64; + static constexpr int nbatch_combine = 64; +}; + +template <> +struct fattn_mma_f16_config<256, 256> { + static constexpr int nbatch_fa = 32; + static constexpr int nwarps_max = 4; + static constexpr bool Q_in_reg = true; + static constexpr int nstages_target = 2; + static constexpr int nbatch_K2 = 128; + static constexpr int nbatch_V2 = 128; + static constexpr int nbatch_combine = 128; +}; + +template <> +struct fattn_mma_f16_config<576, 512> { + static constexpr int nbatch_fa = 32; + static constexpr int nwarps_max = 8; + static constexpr bool Q_in_reg = false; + static constexpr int nstages_target = 1; + static constexpr int nbatch_K2 = 160; + static constexpr int nbatch_V2 = 128; + static constexpr int nbatch_combine = 128; +}; + +// ------------------------------------------------------------------------------------------------------------------ + +template static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( - const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int stride_KV) { - constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts. + const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV) { - // If cp.async is available, load up to the highest power of 2 in D asynchronously: -#ifdef CP_ASYNC_AVAILABLE - static_assert(D >= 64 && D < 512, "bad D"); - constexpr int k0_sync_start = D/2 < 64 ? 32 : (D/2 < 128 ? 64 : 128); + // K/V data is loaded with decreasing granularity for D for better memory bandwidth. + // The minimum granularity with cp.async is 16 bytes, with synchronous data loading it's 4 bytes. + + if (use_cp_async) { + constexpr int preload = 64; + constexpr int h2_per_chunk = 16/sizeof(half2); + const int chunks_per_row = D2 / h2_per_chunk; - const unsigned int tile_KV_32 = __cvta_generic_to_shared(tile_KV); + const unsigned int tile_KV_32 = ggml_cuda_cvta_generic_to_shared(tile_KV); + + auto load = [&] __device__ (const int n) { + const int stride_k = WARP_SIZE >> n; + const int k0_start = stride_k == WARP_SIZE ? 0 : chunks_per_row - chunks_per_row % (2*stride_k); + const int k0_stop = chunks_per_row - chunks_per_row % (1*stride_k); + const int stride_i = WARP_SIZE / stride_k; + + if (k0_start == k0_stop) { + return; + } - constexpr int preload = 64; - constexpr int h2_per_chunk = 16/sizeof(half2); - constexpr int chunks_per_row = k0_sync_start / h2_per_chunk; - constexpr int stride_i = WARP_SIZE / chunks_per_row; #pragma unroll - for (int i0 = 0; i0 < KQ_per_iter; i0 += nwarps*stride_i) { - const int i = i0 + threadIdx.y*stride_i + (chunks_per_row == WARP_SIZE ? 0 : threadIdx.x / chunks_per_row); - const int k = (chunks_per_row == WARP_SIZE ? threadIdx.x : threadIdx.x % chunks_per_row)*h2_per_chunk; + for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) { + const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); - cp_async_cg_16(tile_KV_32 + (i*D2_padded + k)*sizeof(half2), KV + i*stride_KV + k); - } -#else - constexpr int k0_sync_start = 0; -#endif // CP_ASYNC_AVAILABLE - static_assert(k0_sync_start % WARP_SIZE == 0, "bad k0_sync_start"); + if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) { + break; + } - // If D is not a power of 2, the rest is loaded synchronously. - // K/V data is loaded with decreasing granularity for D for better memory bandwidth. - static_assert(KQ_per_iter % (4*nwarps) == 0, "out of bounds"); #pragma unroll - for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) { - const int k0_start = stride_k == WARP_SIZE ? k0_sync_start : D/2 - (D/2) % (2*stride_k); - const int k0_stop = D/2 - (D/2) % (1*stride_k); - const int stride_i = WARP_SIZE / stride_k; + for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { + const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); - if (k0_start == k0_stop || k0_stop <= k0_sync_start) { - continue; - } + cp_async_cg_16(tile_KV_32 + i*(stride_tile*sizeof(half2)) + k*16, KV + i*stride_KV + k*h2_per_chunk); + } + } + }; + ggml_cuda_unroll<5>{}(load); + } else { + static_assert(nbatch_fa % (4*nwarps) == 0, "out of bounds"); + auto load = [&] __device__ (const int n) { + const int stride_k = WARP_SIZE >> n; + const int k0_start = stride_k == WARP_SIZE ? 0 : D2 - D2 % (2*stride_k); + const int k0_stop = D2 - D2 % (1*stride_k); + const int stride_i = WARP_SIZE / stride_k; + + if (k0_start == k0_stop) { + return; + } #pragma unroll - for (int i0 = 0; i0 < KQ_per_iter; i0 += nwarps*stride_i) { - const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); + for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) { + const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); + + if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) { + break; + } #pragma unroll - for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { - const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); + for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { + const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); - tile_KV[i*D2_padded + k] = KV[i*stride_KV + k]; + tile_KV[i*stride_tile + k] = KV[i*stride_KV + k]; + } } - } + }; + ggml_cuda_unroll<3>{}(load); } } -template +template static __device__ __forceinline__ void flash_attn_ext_f16_load_mask( const half2 * const __restrict__ mask_h2, half2 * const __restrict__ tile_mask, const int stride_mask) { - static_assert(KQ_per_iter == 2*WARP_SIZE || KQ_per_iter == WARP_SIZE, "bad KQ_per_iter"); -#ifdef CP_ASYNC_AVAILABLE - constexpr int preload = KQ_per_iter * sizeof(half); - constexpr int cols_per_warp = 8*WARP_SIZE/KQ_per_iter; - constexpr int stride_j = nwarps * cols_per_warp; + static_assert(nbatch_fa == 2*WARP_SIZE || WARP_SIZE % nbatch_fa == 0, "bad KQ_per_iter"); + + if (use_cp_async) { + constexpr int preload = nbatch_fa >= 32 ? nbatch_fa * sizeof(half) : 64; + constexpr int cols_per_warp = 8*WARP_SIZE/nbatch_fa; + constexpr int stride_j = nwarps * cols_per_warp; - const unsigned int tile_mask_32 = __cvta_generic_to_shared(tile_mask); + const unsigned int tile_mask_32 = ggml_cuda_cvta_generic_to_shared(tile_mask); #pragma unroll - for (int j0 = 0; j0 < ncols1; j0 += stride_j) { - const int j = j0 + threadIdx.y*cols_per_warp + - (KQ_per_iter == 2*WARP_SIZE ? threadIdx.x / (WARP_SIZE/4) : threadIdx.x / (WARP_SIZE/8)); + for (int j0 = 0; j0 < ncols1; j0 += stride_j) { + const int j = j0 + threadIdx.y*cols_per_warp + + (nbatch_fa == 2*WARP_SIZE ? threadIdx.x / (WARP_SIZE/4) : threadIdx.x / (WARP_SIZE/cols_per_warp)); - if (j0 + stride_j > ncols1 && j >= ncols1) { - break; - } + if (j0 + stride_j > ncols1 && j >= ncols1) { + break; + } - const int i = 4 * (KQ_per_iter == 2*WARP_SIZE ? threadIdx.x % (WARP_SIZE/4) : threadIdx.x % (WARP_SIZE/8)); + const int i = 4 * (threadIdx.x % (nbatch_fa/8)); - cp_async_cg_16(tile_mask_32 + j*(KQ_per_iter*sizeof(half) + 16) + i*sizeof(half2), mask_h2 + j*stride_mask + i); + cp_async_cg_16(tile_mask_32 + j*(nbatch_fa*sizeof(half) + 16) + i*sizeof(half2), mask_h2 + j*stride_mask + i); + } + return; } -#else - constexpr int cols_per_warp = 2*WARP_SIZE/KQ_per_iter; + + constexpr int cols_per_warp = 2*WARP_SIZE/nbatch_fa; constexpr int stride_j = nwarps * cols_per_warp; #pragma unroll for (int j0 = 0; j0 < ncols1; j0 += stride_j) { - const int j = j0 + threadIdx.y*cols_per_warp + (KQ_per_iter == 2*WARP_SIZE ? 0 : threadIdx.x / (WARP_SIZE/2)); + const int j = j0 + threadIdx.y*cols_per_warp + (nbatch_fa == 2*WARP_SIZE ? 0 : threadIdx.x / (WARP_SIZE/cols_per_warp)); if (j0 + stride_j > ncols1 && j >= ncols1) { break; } - const int i = KQ_per_iter == 2*WARP_SIZE ? threadIdx.x : threadIdx.x % (WARP_SIZE/2); + const int i = nbatch_fa == 2*WARP_SIZE ? threadIdx.x : threadIdx.x % (WARP_SIZE/cols_per_warp); - tile_mask[j*(KQ_per_iter/2 + 4) + i] = mask_h2[j*stride_mask + i]; + tile_mask[j*(nbatch_fa/2 + 4) + i] = mask_h2[j*stride_mask + i]; } -#endif // CP_ASYNC_AVAILABLE } -template +template static __device__ __forceinline__ void flash_attn_ext_f16_iter( const float2 * const __restrict__ Q_f2, const half2 * const __restrict__ K_h2, @@ -123,9 +236,11 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( const float logit_softcap, const int ne01, const int ne02, - const int stride_KV, + const int stride_K, + const int stride_V, const int stride_mask, const int jt, + half2 * const __restrict__ tile_Q, half2 * const __restrict__ tile_K, half2 * const __restrict__ tile_V, half2 * const __restrict__ tile_mask, @@ -135,59 +250,107 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( float * const __restrict__ KQ_rowsum, const int kb0) { #ifdef NEW_MMA_AVAILABLE + typedef fattn_mma_f16_config c; + +#ifdef CP_ASYNC_AVAILABLE + constexpr int nstages = c::nstages_target; +#else + constexpr int nstages = 0; +#endif // CP_ASYNC_AVAILABLE + constexpr int cols_per_warp = ntiles * tile_B::I; constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles; constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column. - constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts. - const int k_VKQ_0 = kb0 * KQ_per_iter; - tile_C_KQ KQ_C[KQ_per_iter/(np*tile_C_KQ::I) * ntiles]; + constexpr int stride_tile_Q = DKQ/2 + 4; + constexpr int stride_tile_K = c::nbatch_K2 + 4; + constexpr int stride_tile_V = c::nbatch_V2 + 4; + + const int k_VKQ_0 = kb0 * c::nbatch_fa; + tile_C_KQ KQ_C[c::nbatch_fa/(np*tile_C_KQ::I) * ntiles]; // Use wide variants of tiles if ntiles >= 2. tile_B_16 * Q_B_16 = (tile_B_16 *) Q_B; tile_C_VKQ_16 * VKQ_C_16 = (tile_C_VKQ_16 *) VKQ_C; tile_C_KQ_16 * KQ_C_16 = (tile_C_KQ_16 *) KQ_C; -#ifdef CP_ASYNC_AVAILABLE - cp_async_wait_all(); - __syncthreads(); - flash_attn_ext_f16_load_tile(V_h2 + k_VKQ_0*stride_KV, tile_V, stride_KV); -#else - if (ncols2 > 1 || mask_h2) { - flash_attn_ext_f16_load_mask(mask_h2 + k_VKQ_0/2, tile_mask, stride_mask); + if constexpr (nstages > 1) { + static_assert(c::nbatch_K2 == DKQ/2, "batching not implemented for multi stage loading"); + constexpr bool use_cp_async = true; + cp_async_wait_all(); + __syncthreads(); + flash_attn_ext_f16_load_tile + (V_h2 + k_VKQ_0*stride_V, tile_V, c::nbatch_V2, stride_V); + } else { + constexpr bool use_cp_async = nstages == 1; + if (ncols2 > 1 || mask_h2) { + flash_attn_ext_f16_load_mask(mask_h2 + k_VKQ_0/2, tile_mask, stride_mask); + } } - flash_attn_ext_f16_load_tile(K_h2 + k_VKQ_0*stride_KV, tile_K, stride_KV); - __syncthreads(); -#endif // CP_ASYNC_AVAILABLE - // Calculate tile of KQ: #pragma unroll - for (int i_KQ_00 = 0; i_KQ_00 < KQ_per_iter; i_KQ_00 += np*tile_A::I) { - const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*tile_A::I; + for (int k0_start = 0; k0_start < DKQ/2; k0_start += c::nbatch_K2) { + const int k0_stop = k0_start + c::nbatch_K2 < DKQ/2 ? k0_start + c::nbatch_K2 : DKQ/2; + const int k0_diff = k0_stop - k0_start; + + if (nstages <= 1) { + constexpr bool use_cp_async = nstages == 1; + flash_attn_ext_f16_load_tile + (K_h2 + k_VKQ_0*stride_K + k0_start, tile_K, k0_diff, stride_K); + if (use_cp_async) { + cp_async_wait_all(); + } + __syncthreads(); + } + + // Calculate tile of KQ: + if constexpr (c::Q_in_reg) { #pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += tile_A::J) { - tile_A K_A; - load_ldmatrix(K_A, tile_K + i_KQ_0*D2_padded + k_KQ_0, D2_padded); - if (ntiles == 1) { - mma(KQ_C[i_KQ_00/(np*tile_A::I)], K_A, Q_B[k_KQ_0/tile_A::J]); - } else { + for (int i_KQ_00 = 0; i_KQ_00 < c::nbatch_fa; i_KQ_00 += np*tile_A::I) { + const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*tile_A::I; #pragma unroll - for (int t = 0; t < ntiles/2; ++t) { + for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += tile_A::J) { + tile_A K_A; + load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K); + if (ntiles == 1) { + mma(KQ_C[i_KQ_00/(np*tile_A::I)], K_A, Q_B[k_KQ_0/tile_A::J]); + } else { +#pragma unroll + for (int t = 0; t < ntiles/2; ++t) { + // Wide version of KQ_C is column-major => swap A and B. + mma(KQ_C_16[i_KQ_00/(np*tile_A::I) * ntiles/2 + t], Q_B_16[k_KQ_0/tile_A::J * ntiles/2 + t], K_A); + } + } + } + } + } else { + static_assert(ntiles == 2, "ntiles != 2 not implemented"); +#pragma unroll + for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += tile_A::J) { + load_ldmatrix(Q_B_16[0], tile_Q + (threadIdx.y / np)*(tile_B_16::I*stride_tile_Q) + k_KQ_0, stride_tile_Q); + +#pragma unroll + for (int i_KQ_00 = 0; i_KQ_00 < c::nbatch_fa; i_KQ_00 += np*tile_A::I) { + const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*tile_A::I; + + tile_A K_A; + load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K); + // Wide version of KQ_C is column-major => swap A and B. - mma(KQ_C_16[i_KQ_00/(np*tile_A::I) * ntiles/2 + t], Q_B_16[k_KQ_0/tile_A::J * ntiles/2 + t], K_A); + mma(KQ_C_16[i_KQ_00/(np*tile_A::I)], Q_B_16[0], K_A); } } } - } -#ifndef CP_ASYNC_AVAILABLE - __syncthreads(); // Only needed if tile_K == tile_V. -#endif // CP_ASYNC_AVAILABLE + if (nstages <= 1) { + __syncthreads(); // Only needed if tile_K == tile_V. + } + } if (use_logit_softcap) { - static_assert(KQ_per_iter % (np*tile_C_KQ::I) == 0, "bad loop size"); + static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size"); #pragma unroll - for (int i = 0; i < KQ_per_iter/(np*tile_C_KQ::I) * ntiles; ++i) { + for (int i = 0; i < c::nbatch_fa/(np*tile_C_KQ::I) * ntiles; ++i) { #pragma unroll for (int l = 0; l < tile_C_KQ::ne; ++l) { KQ_C[i].x[l] = logit_softcap*tanhf(KQ_C[i].x[l]); @@ -205,7 +368,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( if (ntiles == 1) { if (ncols2 > 1 || mask_h2) { #pragma unroll - for (int i00 = 0; i00 < KQ_per_iter; i00 += np*tile_C_KQ::I) { + for (int i00 = 0; i00 < c::nbatch_fa; i00 += np*tile_C_KQ::I) { const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ::I; #pragma unroll for (int l = 0; l < tile_C_KQ::ne; ++l) { @@ -213,16 +376,16 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( const int j = ((threadIdx.y / np)*tile_C_KQ::J + tile_C_KQ::get_j(l)) / ncols2; KQ_C[i00/(np*tile_C_KQ::I)].x[l] += slope * - __half2float(((const half *) tile_mask)[j*(KQ_per_iter + 8) + i]); + __half2float(((const half *) tile_mask)[j*(c::nbatch_fa + 8) + i]); } } } // Calculate softmax for each KQ column using the current max. value. // The divisor is stored in KQ_rowsum and will be applied at the end. - static_assert(KQ_per_iter % (np*tile_C_KQ::I) == 0, "bad loop size"); + static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size"); #pragma unroll - for (int k = 0; k < KQ_per_iter/(np*tile_C_KQ::I); ++k) { + for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ::I); ++k) { #pragma unroll for (int l = 0; l < tile_C_KQ::ne; ++l) { KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k].x[l]); @@ -238,10 +401,9 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } } - static_assert(KQ_per_iter % (np*tile_C_KQ::I) == 0, "bad loop size"); - + static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size"); #pragma unroll - for (int k = 0; k < KQ_per_iter/(np*tile_C_KQ::I); ++k) { + for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ::I); ++k) { #pragma unroll for (int l = 0; l < tile_C_KQ::ne; ++l) { KQ_C[k].x[l] = expf(KQ_C[k].x[l] - KQ_max_new[l % 2]); @@ -252,7 +414,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } else { // ntiles > 1 if (ncols2 > 1 || mask_h2) { #pragma unroll - for (int i00 = 0; i00 < KQ_per_iter; i00 += np*tile_C_KQ_16::J) { + for (int i00 = 0; i00 < c::nbatch_fa; i00 += np*tile_C_KQ_16::J) { const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ_16::J; #pragma unroll for (int t = 0; t < ntiles/2; ++t) { @@ -261,7 +423,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( const int i = (i0 + tile_C_KQ_16::get_j(l0)) / 2; const int j = ((threadIdx.y / np)*cols_per_warp + t*tile_C_KQ_16::I + tile_C_KQ_16::get_i(l0)) / ncols2; - const float2 tmp = __half22float2(tile_mask[j*(KQ_per_iter/2 + 4) + i]); + const float2 tmp = __half22float2(tile_mask[j*(c::nbatch_fa/2 + 4) + i]); const int KQ_index = i00/(np*tile_C_KQ_16::J) * ntiles/2 + t; KQ_C_16[KQ_index].x[l0 + 0] += slope*tmp.x; KQ_C_16[KQ_index].x[l0 + 1] += slope*tmp.y; @@ -272,9 +434,9 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( // Calculate softmax for each KQ column using the current max. value. // The divisor is stored in KQ_rowsum and will be applied at the end. - static_assert(KQ_per_iter % (np*tile_C_KQ::I) == 0, "bad loop size"); + static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size"); #pragma unroll - for (int k = 0; k < KQ_per_iter/(np*tile_C_KQ_16::J); ++k) { + for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ_16::J); ++k) { #pragma unroll for (int t = 0; t < ntiles/2; ++t) { #pragma unroll @@ -294,9 +456,9 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } } - static_assert(KQ_per_iter % (np*tile_C_KQ_16::J) == 0, "bad loop size"); + static_assert(c::nbatch_fa % (np*tile_C_KQ_16::J) == 0, "bad loop size"); #pragma unroll - for (int k = 0; k < KQ_per_iter/(np*tile_C_KQ_16::J); ++k) { + for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ_16::J); ++k) { #pragma unroll for (int t = 0; t < ntiles/2; ++t) { #pragma unroll @@ -325,7 +487,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( if (ntiles == 1) { const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]); #pragma unroll - for (int i = 0; i < D/tile_C_VKQ::I; ++i) { + for (int i = 0; i < DV/tile_C_VKQ::I; ++i) { #pragma unroll for (int l = 0; l < tile_C_VKQ::ne; ++l) { VKQ_C[i].x[l] *= KQ_max_scale_h2; @@ -336,7 +498,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( for (int col = 0; col < cols_per_thread; ++col) { const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]); #pragma unroll - for (int i = 0; i < D/tile_C_VKQ_16::J; ++i) { + for (int i = 0; i < DV/tile_C_VKQ_16::J; ++i) { #pragma unroll for (int l0 = 0; l0 < tile_C_VKQ_16::ne; l0 += 2) { VKQ_C_16[i*ntiles/2 + col/2].x[l0 + col % 2] *= KQ_max_scale_h2; @@ -347,16 +509,16 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } // Convert KQ C tiles into B tiles for VKQ calculation: - tile_B B[KQ_per_iter/(np*2*tile_B::J) * ntiles]; + tile_B B[c::nbatch_fa/(np*2*tile_B::J) * ntiles]; tile_B_16 * B_16 = (tile_B_16 *) B; - static_assert(KQ_per_iter % (np*2*tile_B::J) == 0, "bad loop size"); + static_assert(c::nbatch_fa % (np*2*tile_B::J) == 0, "bad loop size"); if (ntiles == 1) { #pragma unroll - for (int k = 0; k < KQ_per_iter/(np*2*tile_B::J); ++k) { + for (int k = 0; k < c::nbatch_fa/(np*2*tile_B::J); ++k) { B[k] = get_transposed(get_half2(KQ_C[k])); } } else { - for (int k = 0; k < KQ_per_iter/(np*2*tile_B_16::J); ++k) { + for (int k = 0; k < c::nbatch_fa/(np*2*tile_B_16::J); ++k) { #pragma unroll for (int t = 0; t < ntiles/2; ++t) { B_16[k*ntiles/2 + t] = get_half2(KQ_C_16[k*ntiles/2 + t]); @@ -364,52 +526,67 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } } -#ifdef CP_ASYNC_AVAILABLE - // Preload K tile for next iteration: - cp_async_wait_all(); - __syncthreads(); - if (!last_iter) { - if (ncols2 > 1 || mask_h2) { - flash_attn_ext_f16_load_mask(mask_h2 + (k_VKQ_0 + KQ_per_iter)/2, tile_mask, stride_mask); + if (nstages > 1) { + // Preload K tile for next iteration: + constexpr bool use_cp_async = true; + cp_async_wait_all(); + __syncthreads(); + if (!last_iter) { + if (ncols2 > 1 || mask_h2) { + flash_attn_ext_f16_load_mask + (mask_h2 + (k_VKQ_0 + c::nbatch_fa)/2, tile_mask, stride_mask); + } + flash_attn_ext_f16_load_tile + (K_h2 + (k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, c::nbatch_K2, stride_K); } - flash_attn_ext_f16_load_tile(K_h2 + (k_VKQ_0 + KQ_per_iter)*stride_KV, tile_K, stride_KV); } -#else - flash_attn_ext_f16_load_tile(V_h2 + k_VKQ_0*stride_KV, tile_V, stride_KV); - __syncthreads(); -#endif // CP_ASYNC_AVAILABLE - // Calculate VKQ tile: #pragma unroll - for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += tile_C_VKQ::I) { - static_assert((KQ_per_iter/2) % (np*tile_A::J) == 0, "bad loop size"); + for (int i0_start = 0; i0_start < DV; i0_start += 2*c::nbatch_V2) { + const int i0_stop = i0_start + 2*c::nbatch_V2 < DV ? i0_start + 2*c::nbatch_V2 : DV; + const int i0_diff = i0_stop - i0_start; + + if (nstages == 1) { + constexpr bool use_cp_async = nstages == 1; + flash_attn_ext_f16_load_tile + (V_h2 + k_VKQ_0*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V); + if (use_cp_async) { + cp_async_wait_all(); + } + __syncthreads(); + } + + // Calculate VKQ tile: +#pragma unroll + for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 += tile_C_VKQ::I) { + static_assert((c::nbatch_fa/2) % (np*tile_A::J) == 0, "bad loop size"); #pragma unroll - for (int k00 = 0; k00 < KQ_per_iter/2; k00 += np*tile_A::J) { - const int k0 = k00 + (threadIdx.y % np)*tile_A::J; + for (int k00 = 0; k00 < c::nbatch_fa/2; k00 += np*tile_A::J) { + const int k0 = k00 + (threadIdx.y % np)*tile_A::J; - tile_A A; - load_ldmatrix_trans(A, tile_V + 2*k0*D2_padded + i_VKQ_0/2, D2_padded); - if (ntiles == 1) { - mma(VKQ_C[i_VKQ_0/tile_C_VKQ::I], A, B[k00/(np*tile_A::J)]); - } else { + tile_A A; + load_ldmatrix_trans(A, tile_V + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V); + if (ntiles == 1) { + mma(VKQ_C[i_VKQ_0/tile_C_VKQ::I], A, B[k00/(np*tile_A::J)]); + } else { #pragma unroll - for (int t = 0; t < ntiles/2; ++t) { - // Wide version of VKQ_C is column-major => swap A and B. - mma(VKQ_C_16[i_VKQ_0/tile_C_VKQ::I * ntiles/2 + t], B_16[k00/(np*tile_A::J) * ntiles/2 + t], A); + for (int t = 0; t < ntiles/2; ++t) { + // Wide version of VKQ_C is column-major => swap A and B. + mma(VKQ_C_16[i_VKQ_0/tile_C_VKQ::I * ntiles/2 + t], B_16[k00/(np*tile_A::J) * ntiles/2 + t], A); + } } } } - } - -#ifndef CP_ASYNC_AVAILABLE - __syncthreads(); // Only needed if tile_K == tile_V. -#endif // CP_ASYNC_AVAILABLE + if (nstages <= 1) { + __syncthreads(); // Only needed if tile_K == tile_V. + } + } #else GGML_UNUSED(Q_f2); GGML_UNUSED(K_h2); GGML_UNUSED(V_h2); GGML_UNUSED(mask_h2); GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup); GGML_UNUSED(scale); GGML_UNUSED(slope); GGML_UNUSED(logit_softcap); - GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_KV); + GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_K); GGML_UNUSED(stride_V); GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K); GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K); GGML_UNUSED(tile_V); GGML_UNUSED(tile_mask); GGML_UNUSED(Q_B); @@ -419,7 +596,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( #endif // NEW_MMA_AVAILABLE } -template +template static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const float2 * const __restrict__ Q_f2, const half2 * const __restrict__ K_h2, @@ -434,7 +611,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const int ne02, const int stride_Q1, const int stride_Q2, - const int stride_KV, + const int stride_K, + const int stride_V, const int stride_mask, const int jt, const int kb0_start, @@ -442,6 +620,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( #ifdef NEW_MMA_AVAILABLE //In this kernel Q, K, V are matrices while i, j, k are matrix indices. + typedef fattn_mma_f16_config c; + +#ifdef CP_ASYNC_AVAILABLE + constexpr int nstages = c::nstages_target; +#else + constexpr int nstages = 0; +#endif // CP_ASYNC_AVAILABLE + constexpr int ncols = ncols1 * ncols2; constexpr int cols_per_warp = ntiles * tile_B::I; constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles; @@ -449,22 +635,19 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( static_assert(nwarps * (cols_per_warp/ncols2) % ncols1 == 0, "bad nwarps"); - static_assert(D % nwarps == 0, "bad D"); - static_assert(KQ_per_iter % nwarps == 0, "bad KQ_per_iter"); + constexpr int stride_tile_Q = DKQ/2 + 4; + constexpr int stride_tile_K = c::nbatch_K2 + 4; + constexpr int stride_tile_V = c::nbatch_V2 + 4; - constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts. + constexpr int stride_tile_KV_max = stride_tile_K > stride_tile_V ? stride_tile_K : stride_tile_V; - // Temporary shared buffer for loading K/V data with KQ_per_iter*D logical elements: - extern __shared__ half2 tile_K[]; -#ifdef CP_ASYNC_AVAILABLE - half2 * tile_V = tile_K + KQ_per_iter*D2_padded; -#else - half2 * tile_V = tile_K; -#endif // CP_ASYNC_AVAILABLE - half2 * tile_mask = tile_V + KQ_per_iter*D2_padded; + extern __shared__ half2 tile_Q[]; + half2 * tile_K = c::Q_in_reg ? tile_Q : tile_Q + ncols * stride_tile_Q; + half2 * tile_V = nstages > 1 ? tile_K + c::nbatch_fa * stride_tile_K : tile_K; + half2 * tile_mask = nstages > 1 ? tile_V + c::nbatch_fa * stride_tile_V : tile_V + c::nbatch_fa * stride_tile_KV_max; - tile_B Q_B[D/(2*tile_B::J) * ntiles]; - tile_C_VKQ VKQ_C[D/tile_C_VKQ::I * ntiles]; + tile_B Q_B[(c::Q_in_reg ? DKQ/(2*tile_B::J) : 1) * ntiles]; + tile_C_VKQ VKQ_C[DV/tile_C_VKQ::I * ntiles]; tile_B_16 * Q_B_16 = (tile_B_16 *) Q_B; tile_C_VKQ_16 * VKQ_C_16 = (tile_C_VKQ_16 *) VKQ_C; @@ -476,13 +659,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( KQ_max[col] = -FLT_MAX/2.0f; } - // Temporarily load Q data into tile_K, will be loaded into registers afterwards. + // Load Q data into tile_Q, either temporarily or permanently. + // Q in registers is faster, but register pressure is the biggest bottleneck. // The loading is done with decreasing granularity for D for better memory bandwidth. const half2 scale_h2 = make_half2(scale, scale); #pragma unroll for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) { - const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k); - const int k0_stop = D/2 - (D/2) % (1*stride_k); + const int k0_start = stride_k == WARP_SIZE ? 0 : DKQ/2 - (DKQ/2) % (2*stride_k); + const int k0_stop = DKQ/2 - (DKQ/2) % (1*stride_k); const int stride_jc = WARP_SIZE / stride_k; if (k0_start == k0_stop) { @@ -506,14 +690,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); const float2 tmp = Q_f2[(jt*ncols1 + j)*stride_Q1 + c*stride_Q2 + k]; - tile_K[jc*D2_padded + k] = scale_h2 * make_half2(tmp.x, tmp.y); + tile_Q[jc*stride_tile_Q + k] = scale_h2 * make_half2(tmp.x, tmp.y); } } else { #pragma unroll for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); - tile_K[jc*D2_padded + k] = make_half2(0.0f, 0.0f); + tile_Q[jc*stride_tile_Q + k] = make_half2(0.0f, 0.0f); } } } @@ -521,18 +705,18 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( __syncthreads(); - { + if (c::Q_in_reg) { const int j0 = (threadIdx.y / np) * cols_per_warp; #pragma unroll - for (int k0 = 0; k0 < D/2; k0 += tile_B::J) { + for (int k0 = 0; k0 < DKQ/2; k0 += tile_B::J) { if (ntiles == 1) { - load_ldmatrix(Q_B[k0/tile_B::J], tile_K + j0*D2_padded + k0, D2_padded); + load_ldmatrix(Q_B[k0/tile_B::J], tile_Q + j0*stride_tile_Q + k0, stride_tile_Q); } else { #pragma unroll for (int t = 0; t < ntiles/2; ++t) { load_ldmatrix(Q_B_16[k0/tile_B_16::J * ntiles/2 + t], - tile_K + (j0 + t*tile_B_16::I)*D2_padded + k0, D2_padded); + tile_Q + (j0 + t*tile_B_16::I)*stride_tile_Q + k0, stride_tile_Q); } } } @@ -540,35 +724,37 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( __syncthreads(); - // Preload mask and K data for first iteration when using cp_async: -#ifdef CP_ASYNC_AVAILABLE - if (ncols2 > 1 || mask_h2) { - flash_attn_ext_f16_load_mask(mask_h2 + kb0_start*KQ_per_iter/2, tile_mask, stride_mask); + // Preload mask and K data for first iteration when using cp_async with multiple stages: + if constexpr (nstages > 1) { + static_assert(c::nbatch_K2 == DKQ/2, "batching not implemented for multi-stage pipeline"); + constexpr bool use_cp_async = true; + if (ncols2 > 1 || mask_h2) { + flash_attn_ext_f16_load_mask + (mask_h2 + kb0_start*c::nbatch_fa/2, tile_mask, stride_mask); + } + flash_attn_ext_f16_load_tile + (K_h2 + kb0_start*c::nbatch_fa*stride_K, tile_K, c::nbatch_K2, stride_K); } - flash_attn_ext_f16_load_tile(K_h2 + kb0_start*KQ_per_iter*stride_KV, tile_K, stride_KV); -#endif // CP_ASYNC_AVAILABLE // Iterate over ne11 == previous tokens: for (int kb0 = kb0_start; kb0 < kb0_stop-1; ++kb0) { constexpr bool last_iter = false; - flash_attn_ext_f16_iter + flash_attn_ext_f16_iter (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap, - ne01, ne02, stride_KV, stride_mask, jt, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0); + ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0); } { // kb0_start is always < kb0_stop so the last iter can be executed unconditionally. constexpr bool last_iter = true; - flash_attn_ext_f16_iter + flash_attn_ext_f16_iter (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap, - ne01, ne02, stride_KV, stride_mask, jt, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1); + ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1); } - // With cp_async there is no __syncthreads at the end of the iter, + // With multi-stage loading there is no __syncthreads at the end of the iter, // there can be a race condition on shared memory access for combining/writing back results. -#ifdef CP_ASYNC_AVAILABLE - if (nwarps*cols_per_warp > KQ_per_iter) { + if (nstages > 1 && nwarps*cols_per_warp > c::nbatch_fa) { __syncthreads(); } -#endif // CP_ASYNC_AVAILABLE // Finally, sum up partial KQ rowsums. // The partial sums are spread across 8/4 threads each, does not need full reduce. @@ -584,38 +770,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( } } - // Write VKQ accumulators to shared memory in column-major format. - // It's faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM. - // Also for np > 1 the combination is done via these values in shared memory. - if (ntiles == 1) { - const int jc_cwd = threadIdx.y*tile_B::I + tile_B::get_i(-1); // jc combine write data -#pragma unroll - for (int k0 = 0; k0 < D/2; k0 += tile_B::J) { - const tile_B B = get_transposed(VKQ_C[k0/tile_B::J]); // Conversion of C to B matrix puts it in column-major format. + // Combine VKQ accumulator values if np > 1. + // It's also faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM. + // So also write VKQ accumulators to shared memory in column-major format if np == 1. -#pragma unroll - for (int l = 0; l < tile_B::ne; ++l) { - const int k = k0 + tile_B::get_j(l); - - tile_K[jc_cwd*D2_padded + k] = B.x[l]; - } - } - } else { -#pragma unroll - for (int t = 0; t < ntiles/2; ++t) { - const int j0 = threadIdx.y*cols_per_warp + t*tile_C_VKQ_16::I; -#pragma unroll - for (int k0 = 0; k0 < D/2; k0 += tile_C_VKQ_16::J) { -#pragma unroll - for (int l = 0; l < tile_C_VKQ_16::ne; ++l) { - const int j = j0 + tile_C_VKQ_16::get_i(l); - const int k = k0 + tile_C_VKQ_16::get_j(l); - - tile_K[j*D2_padded + k] = VKQ_C_16[k0/tile_C_VKQ_16::J * ntiles/2 + t].x[l]; - } - } - } - } + constexpr int nbatch_combine = c::Q_in_reg ? DV/2 : DV/4; + constexpr int tile_stride = nbatch_combine + 4; + static_assert((DV/2) % nbatch_combine == 0, "bad nbatch_combine"); if constexpr (ntiles == 1) { const int jc_cwmo = (threadIdx.x % (2*tile_C_VKQ::J)) / tile_C_VKQ::J; // jc combine write meta offset @@ -624,7 +785,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*tile_C_VKQ::J) { // Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale. - ((float2 *) tile_K)[jc_cwm*(D2_padded/2) + D/4] = KQ_cmr; + ((float2 *) tile_Q)[jc_cwm*(tile_stride/2) + nbatch_combine/2] = KQ_cmr; } __syncthreads(); @@ -649,7 +810,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( if (((!needs_fixup && !is_fixup) || np > 1) && (ntiles == 4 || threadIdx.x % 4 < cols_per_thread)) { // Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale. - ((float2 *) tile_K)[jc_cwm*(D2_padded/2) + D/4] = KQ_cmr; + ((float2 *) tile_Q)[jc_cwm*(tile_stride/2) + nbatch_combine/2] = KQ_cmr; } __syncthreads(); @@ -676,11 +837,11 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( constexpr int nmeta = np*cols_per_warp >= WARP_SIZE ? np*cols_per_warp/WARP_SIZE : 1; const int jc_meta = threadIdx.y*cols_per_warp + (np*cols_per_warp < WARP_SIZE ? threadIdx.x % (np*cols_per_warp) : threadIdx.x); - float2 * const meta_ptr = ((float2 *) tile_K) + jc_meta*(D2_padded/2) + D/4; + float2 * const meta_ptr = ((float2 *) tile_Q) + jc_meta*(tile_stride/2) + nbatch_combine/2; float2 meta[nmeta]; #pragma unroll for (int imeta = 0; imeta < nmeta; ++imeta) { - meta[imeta] = meta_ptr[imeta * WARP_SIZE * D2_padded/2]; + meta[imeta] = meta_ptr[imeta * WARP_SIZE * tile_stride/2]; } float KQ_cmn = meta[0].x; // KQ combine max new, max between all parallel warps. @@ -690,10 +851,9 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( } #pragma unroll for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) { - if (offset >= WARP_SIZE) { - continue; + if (offset < WARP_SIZE) { + KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE)); } - KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE)); } float KQ_cms[nmeta]; // KQ combine max scale per warp. @@ -709,10 +869,9 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( } #pragma unroll for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) { - if (offset >= WARP_SIZE) { - continue; + if (offset < WARP_SIZE) { + KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE); } - KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE); } // Write back combined meta data: @@ -720,7 +879,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( for (int imeta = 0; imeta < nmeta; ++imeta) { if (np*cols_per_warp >= WARP_SIZE || threadIdx.x < np*cols_per_warp) { // Combined KQ max scale + rowsum. - meta_ptr[imeta * WARP_SIZE * D2_padded/2] = make_float2(KQ_cms[imeta], KQ_crs); + meta_ptr[imeta * WARP_SIZE * tile_stride/2] = make_float2(KQ_cms[imeta], KQ_crs); } } @@ -736,88 +895,118 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( } } - if (np > 1) { - __syncthreads(); - } - - if (np == 1 || threadIdx.y % np == 0) { - // The first 2*2*gridDim.x*ncols floats in dstk_fixup are for storing max. values and row sums. - // The values after that are for the partial results of the individual blocks. - float2 * dstk_fixup_data = dstk_fixup + gridDim.x*(2*ncols) + blockIdx.x*(ncols*(D/2)); +#pragma unroll + for (int k00 = 0; k00 < DV/2; k00 += nbatch_combine) { + if (ntiles == 1) { + const int jc_cwd = threadIdx.y*tile_B::I + tile_B::get_i(-1); // jc combine write data +#pragma unroll + for (int k0 = 0; k0 < nbatch_combine; k0 += tile_B::J) { + const tile_B B = get_transposed(VKQ_C[(k00 + k0)/tile_B::J]); // Conversion of C to B matrix puts it in column-major format. #pragma unroll - for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) { - const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k); - const int k0_stop = D/2 - (D/2) % (1*stride_k); - const int stride_jc = WARP_SIZE / stride_k; + for (int l = 0; l < tile_B::ne; ++l) { + const int k = k0 + tile_B::get_j(l); - if (k0_start == k0_stop) { - continue; + tile_Q[jc_cwd*tile_stride + k] = B.x[l]; + } } - + } else { +#pragma unroll + for (int t = 0; t < ntiles/2; ++t) { + const int j0 = threadIdx.y*cols_per_warp + t*tile_C_VKQ_16::I; +#pragma unroll + for (int k0 = 0; k0 < nbatch_combine; k0 += tile_C_VKQ_16::J) { #pragma unroll - for (int jc0_dst = 0; jc0_dst < ncols; jc0_dst += (nwarps/np)*stride_jc) { - const int jc_dst = jc0_dst + (threadIdx.y/np)*stride_jc + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); + for (int l = 0; l < tile_C_VKQ_16::ne; ++l) { + const int j = j0 + tile_C_VKQ_16::get_i(l); + const int k = k0 + tile_C_VKQ_16::get_j(l); - if (jc0_dst + (nwarps/np)*stride_jc > ncols && jc_dst >= ncols) { - break; + tile_Q[j*tile_stride + k] = VKQ_C_16[(k00 + k0)/tile_C_VKQ_16::J * ntiles/2 + t].x[l]; + } } + } + } - const int jc_tile_K = (jc_dst/cols_per_warp)*(np*cols_per_warp) + jc_dst % cols_per_warp; + __syncthreads(); + + if (np == 1 || threadIdx.y % np == 0) { + // The first 2*2*gridDim.x*ncols floats in dstk_fixup are for storing max. values and row sums. + // The values after that are for the partial results of the individual blocks. + float2 * dstk_fixup_data = dstk_fixup + gridDim.x*(2*ncols) + blockIdx.x*(ncols*(DV/2)); - const int j_dst = jc_dst / ncols2; - const int c_dst = jc_dst % ncols2; +#pragma unroll + for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) { + const int k0_start = stride_k == WARP_SIZE ? 0 : nbatch_combine - nbatch_combine % (2*stride_k); + const int k0_stop = nbatch_combine - nbatch_combine % (1*stride_k); + const int stride_jc = WARP_SIZE / stride_k; - if (!is_fixup && jt*ncols1 + j_dst >= ne01) { + if (k0_start == k0_stop) { continue; } - const float * meta_j = (const float *) tile_K + jc_tile_K*D2_padded + D/2; #pragma unroll - for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { - const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); + for (int jc0_dst = 0; jc0_dst < ncols; jc0_dst += (nwarps/np)*stride_jc) { + const int jc_dst = jc0_dst + (threadIdx.y/np)*stride_jc + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); - float2 dstk_val = make_float2(0.0f, 0.0f); -#pragma unroll - for (int ip = 0; ip < np; ++ip) { - const float KQ_crs = np == 1 ? 1.0f : meta_j[ip*cols_per_warp * D2_padded + 0]; - const float2 dstk_val_add = __half22float2(tile_K[(jc_tile_K + ip*cols_per_warp) * D2_padded + k]); - dstk_val.x += dstk_val_add.x*KQ_crs; - dstk_val.y += dstk_val_add.y*KQ_crs; + if (jc0_dst + (nwarps/np)*stride_jc > ncols && jc_dst >= ncols) { + break; } - if (!needs_fixup && !is_fixup) { - const float KQ_rowsum_j = meta_j[1]; - dstk_val.x /= KQ_rowsum_j; - dstk_val.y /= KQ_rowsum_j; + const int jc_tile_K = (jc_dst/cols_per_warp)*(np*cols_per_warp) + jc_dst % cols_per_warp; + + const int j_dst = jc_dst / ncols2; + const int c_dst = jc_dst % ncols2; + + if (!is_fixup && jt*ncols1 + j_dst >= ne01) { + continue; } - if (is_fixup) { - dstk_fixup_data[jc_dst*(D/2) + k] = dstk_val; - } else { - dstk[((jt*ncols1 + j_dst)*ne02 + c_dst)*(D/2) + k] = dstk_val; + const float * meta_j = (const float *) tile_Q + jc_tile_K*tile_stride + nbatch_combine; +#pragma unroll + for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { + const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); + + float2 dstk_val = make_float2(0.0f, 0.0f); +#pragma unroll + for (int ip = 0; ip < np; ++ip) { + const float KQ_crs = np == 1 ? 1.0f : meta_j[ip*cols_per_warp * tile_stride + 0]; + const float2 dstk_val_add = __half22float2(tile_Q[(jc_tile_K + ip*cols_per_warp) * tile_stride + k]); + dstk_val.x += dstk_val_add.x*KQ_crs; + dstk_val.y += dstk_val_add.y*KQ_crs; + } + + if (!needs_fixup && !is_fixup) { + const float KQ_rowsum_j = meta_j[1]; + dstk_val.x /= KQ_rowsum_j; + dstk_val.y /= KQ_rowsum_j; + } + + if (is_fixup) { + dstk_fixup_data[jc_dst*(DV/2) + k00 + k] = dstk_val; + } else { + dstk[((jt*ncols1 + j_dst)*ne02 + c_dst)*(DV/2) + k00 + k] = dstk_val; + } } } } } - } - - if (np > 1) { - __syncthreads(); + if (np > 1) { + __syncthreads(); + } } #else GGML_UNUSED(Q_f2); GGML_UNUSED(K_h2); GGML_UNUSED(V_h2); GGML_UNUSED(mask_h2); GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup); GGML_UNUSED(scale); GGML_UNUSED(slope); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_Q1); - GGML_UNUSED(stride_Q2); GGML_UNUSED(stride_KV); GGML_UNUSED(stride_mask); + GGML_UNUSED(stride_Q2); GGML_UNUSED(stride_K); GGML_UNUSED(stride_V); GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(kb0_start); GGML_UNUSED(kb0_stop); NO_DEVICE_CODE; #endif // NEW_MMA_AVAILABLE } -template -__launch_bounds__(nwarps*WARP_SIZE, 2) +template +__launch_bounds__(nwarps*WARP_SIZE, 1) static __global__ void flash_attn_ext_f16( const char * __restrict__ Q, const char * __restrict__ K, @@ -857,24 +1046,27 @@ static __global__ void flash_attn_ext_f16( #if defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE) // Skip unused kernel variants for faster compilation: - if (use_logit_softcap && !(D == 128 || D == 256)) { + if (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) { NO_DEVICE_CODE; return; } - static_assert(FATTN_KQ_STRIDE % KQ_per_iter == 0, "bad KQ_per_iter"); + typedef fattn_mma_f16_config c; + + static_assert(FATTN_KQ_STRIDE % fattn_mma_f16_config::nbatch_fa == 0, "bad nbatch_fa"); const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. const int stride_Q1 = nb01 / sizeof(float2); const int stride_Q2 = nb02 / sizeof(float2); - const int stride_KV = nb11 / sizeof(half2); + const int stride_K = nb11 / sizeof(half2); + const int stride_V = nb21 / sizeof(half2); const int stride_mask = nb31 / sizeof(half2); const int iter_k = ne11 / FATTN_KQ_STRIDE; const int iter_j = (ne01 + (ncols1 - 1)) / ncols1; - constexpr int kb_niter = FATTN_KQ_STRIDE / KQ_per_iter; // Number of kernel iterations per assigned KQ slice. + constexpr int kb_niter = FATTN_KQ_STRIDE / c::nbatch_fa; // Number of kernel iterations per assigned KQ slice. // kbc == k block continuous, current index in continuous ijk space. int kbc = (blockIdx.x + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x; @@ -893,9 +1085,9 @@ static __global__ void flash_attn_ext_f16( const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2); const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio)); - const half2 * V_h2 = (const half2 *) (V + nb12*(channel*ncols2 / gqa_ratio)); // K and V have same shape + const half2 * V_h2 = (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio)); const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr; - float2 * dstk = ((float2 *) dst) + channel*(ncols2 * D/2); + float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2); const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f; @@ -905,14 +1097,14 @@ static __global__ void flash_attn_ext_f16( constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer. if (kb0_start == 0) { constexpr bool needs_fixup = false; // CUDA block is working on an entire tile. - flash_attn_ext_f16_process_tile + flash_attn_ext_f16_process_tile (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap, - ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); + ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); } else { constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile. - flash_attn_ext_f16_process_tile + flash_attn_ext_f16_process_tile (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap, - ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); + ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); } kbc += iter_k; @@ -931,9 +1123,9 @@ static __global__ void flash_attn_ext_f16( const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2); const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio)); - const half2 * V_h2 = (const half2 *) (V + nb12*(channel*ncols2 / gqa_ratio)); // K and V have same shape + const half2 * V_h2 = (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio)); // K and V have same shape const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr; - float2 * dstk = ((float2 *) dst) + channel*(ncols2 * D/2); + float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2); const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f; @@ -942,9 +1134,9 @@ static __global__ void flash_attn_ext_f16( constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks. constexpr bool needs_fixup = false; - flash_attn_ext_f16_process_tile + flash_attn_ext_f16_process_tile (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap, - ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); + ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); #else GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale); @@ -960,28 +1152,42 @@ static __global__ void flash_attn_ext_f16( #endif // defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE) } -template +template void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * KQV = dst; + const int id = ggml_cuda_get_device(); + const int cc = ggml_cuda_info().devices[id].cc; + + typedef fattn_mma_f16_config c; + + constexpr int nbatch_K2 = c::nbatch_K2 < 1 ? DKQ/2 : c::nbatch_K2; + constexpr int nbatch_V2 = c::nbatch_V2 < 1 ? DV /2 : c::nbatch_V2; + constexpr int nbatch_combine = c::nbatch_combine < 1 ? DV /2 : c::nbatch_combine; + + const int nstages = cp_async_available(cc) ? c::nstages_target : 0; + constexpr int ncols = ncols1 * ncols2; - constexpr int KQ_per_iter = D <= 128 && ncols1 <= 64 ? 64 : 32; - constexpr int nwarps = (KQ_per_iter == 32 && ncols <= 16) ? 2 : 4; - constexpr int ntiles = ncols <= 8 ? 1 : (ncols <= 64 ? 2 : 4); + constexpr int ntiles = ncols <= 8 ? 1 : 2; // Number of tiles per warp. constexpr int cols_per_warp = ntiles * tile_B::I; + constexpr int nwarps_max_x = ncols / cols_per_warp; + constexpr int nwarps_max_y = c::nbatch_fa / tile_A::I; + constexpr int nwarps = nwarps_max_x*nwarps_max_y <= c::nwarps_max ? nwarps_max_x*nwarps_max_y : c::nwarps_max; - static_assert(D % tile_B::J == 0, "bad D"); + static_assert(DKQ % tile_B::J == 0, "bad DKQ"); + static_assert(DV % tile_A::J == 0, "bad DV"); static_assert(ncols % cols_per_warp == 0, "bad ncols"); - const ggml_tensor * KQV = dst; - const int id = ggml_cuda_get_device(); - const int cc = ggml_cuda_info().devices[id].cc; + const size_t nbytes_shared_KV_1stage = c::nbatch_fa * std::max(c::nbatch_K2 + 4, c::nbatch_V2 + 4) * sizeof(half2); + const size_t nbytes_shared_KV_2stage = c::nbatch_fa * (c::nbatch_K2 + 4 + c::nbatch_V2 + 4) * sizeof(half2); + const size_t nbytes_shared_Q = ncols * (DKQ/2 + 4) * sizeof(half2); + const size_t nbytes_shared_mask = ncols1 * (c::nbatch_fa/2 + 4) * sizeof(half2); + const size_t nbytes_shared_combine = nwarps*cols_per_warp * (nbatch_combine + 4) * sizeof(half2); - const int KQ_shared_rows = cp_async_available(cc) ? 2*KQ_per_iter : KQ_per_iter; + const size_t nbytes_shared_KV = nstages <= 1 ? nbytes_shared_KV_1stage : nbytes_shared_KV_2stage; - const size_t nbytes_shared_KV = KQ_shared_rows * (D + 8) * sizeof(half); - const size_t nbytes_shared_mask = ncols1 * (KQ_per_iter + 8) * sizeof(half); - const size_t nbytes_shared_combine = nwarps*cols_per_warp * (D + 8) * sizeof(half); - - const size_t nbytes_shared_total = std::max(nbytes_shared_KV + nbytes_shared_mask, nbytes_shared_combine); + const size_t nbytes_shared_total = std::max(nbytes_shared_combine, c::Q_in_reg ? + std::max(nbytes_shared_Q, nbytes_shared_KV + nbytes_shared_mask) : + nbytes_shared_Q + nbytes_shared_KV + nbytes_shared_mask); float logit_softcap; memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); @@ -989,59 +1195,73 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml fattn_kernel_t fattn_kernel; if (logit_softcap == 0.0f) { constexpr bool use_logit_softcap = false; - fattn_kernel = flash_attn_ext_f16; + fattn_kernel = flash_attn_ext_f16; + +#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) + static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; + if (!shared_memory_limit_raised[id]) { + CUDA_CHECK(cudaFuncSetAttribute(fattn_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total)); + shared_memory_limit_raised[id] = true; + } +#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) } else { constexpr bool use_logit_softcap = true; - fattn_kernel = flash_attn_ext_f16; + fattn_kernel = flash_attn_ext_f16; + +#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) + static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; + if (!shared_memory_limit_raised[id]) { + CUDA_CHECK(cudaFuncSetAttribute(fattn_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total)); + shared_memory_limit_raised[id] = true; + } +#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) } - launch_fattn + launch_fattn (ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, FATTN_KQ_STRIDE, true, true, true); } -#define DECL_FATTN_MMA_F16_CASE(D, ncols1, ncols2) \ - template void ggml_cuda_flash_attn_ext_mma_f16_case \ - (ggml_backend_cuda_context & ctx, ggml_tensor * dst) \ - -#define DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(D, ncols) \ - extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/1, 1); \ - extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/2, 2); \ - extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/4, 4); \ - extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/8, 8); \ - -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 8) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 8) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 8) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 8) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 8) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 8) - -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 16) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 16) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 16) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 16) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 16) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 16) - -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 32) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 32) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 32) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 32) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 32) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 32) - -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 64) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 64) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 64) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 64) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 64) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 64) - -// Kernels with ncols == 128 are only 4% faster due to register pressure. -// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 128) -// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 128) -// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 128) -// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 128) -// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128) -// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 128) // Needs too much shared memory. +#define DECL_FATTN_MMA_F16_CASE(DKQ, DV, ncols1, ncols2) \ + template void ggml_cuda_flash_attn_ext_mma_f16_case \ + (ggml_backend_cuda_context & ctx, ggml_tensor * dst) \ + +#define DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(DKQ, DV, ncols) \ + extern DECL_FATTN_MMA_F16_CASE(DKQ, DV, (ncols)/ 1, 1); \ + extern DECL_FATTN_MMA_F16_CASE(DKQ, DV, (ncols)/ 2, 2); \ + extern DECL_FATTN_MMA_F16_CASE(DKQ, DV, (ncols)/ 4, 4); \ + extern DECL_FATTN_MMA_F16_CASE(DKQ, DV, (ncols)/ 8, 8); \ + extern DECL_FATTN_MMA_F16_CASE(DKQ, DV, (ncols)/16, 16); \ + +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 64, 8) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 80, 8) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 96, 8) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 112, 8) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128, 8) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 8) + +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 64, 16) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 80, 16) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 96, 16) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 112, 16) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128, 16) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 16) + +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 64, 32) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 80, 32) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 96, 32) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 112, 32) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128, 32) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 32) + +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 64, 64) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 80, 64) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 96, 64) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 112, 64) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128, 64) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 64) + +// The number of viable configurations for Deepseek is very limited: +extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16); +extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16); +extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16); diff --git a/ggml/src/ggml-cuda/fattn-tile-f16.cu b/ggml/src/ggml-cuda/fattn-tile-f16.cu index e0039e1755c..9283560d5c4 100644 --- a/ggml/src/ggml-cuda/fattn-tile-f16.cu +++ b/ggml/src/ggml-cuda/fattn-tile-f16.cu @@ -307,7 +307,7 @@ void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * constexpr int nwarps = 8; constexpr size_t nbytes_shared = 0; fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16; - launch_fattn + launch_fattn (ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F16, true, true, false); } break; case 128: { @@ -315,7 +315,7 @@ void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * constexpr int nwarps = 8; constexpr size_t nbytes_shared = 0; fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16; - launch_fattn + launch_fattn (ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F16, true, true, false); } break; default: { diff --git a/ggml/src/ggml-cuda/fattn-tile-f32.cu b/ggml/src/ggml-cuda/fattn-tile-f32.cu index fcb6f848fe0..32673adb57f 100644 --- a/ggml/src/ggml-cuda/fattn-tile-f32.cu +++ b/ggml/src/ggml-cuda/fattn-tile-f32.cu @@ -318,7 +318,7 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * constexpr int nwarps = 8; constexpr size_t nbytes_shared = 0; fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32; - launch_fattn + launch_fattn (ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F32, true, true, false); } break; case 128: { @@ -326,7 +326,7 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * constexpr int nwarps = 8; constexpr size_t nbytes_shared = 0; fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32; - launch_fattn + launch_fattn (ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F32, true, true, false); } break; default: { diff --git a/ggml/src/ggml-cuda/fattn-vec-f16.cuh b/ggml/src/ggml-cuda/fattn-vec-f16.cuh index e17d2d0e4fb..ef0addc1dbc 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f16.cuh @@ -315,7 +315,7 @@ void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx, constexpr bool need_f16_K = D != 128; constexpr bool need_f16_V = D != 128 && D != 64; constexpr size_t nbytes_shared = 0; - launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false); + launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false); } template diff --git a/ggml/src/ggml-cuda/fattn-vec-f32.cuh b/ggml/src/ggml-cuda/fattn-vec-f32.cuh index d42ddca49f6..7064675d5ab 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f32.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f32.cuh @@ -310,7 +310,7 @@ void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx, constexpr bool need_f16_K = D != 128; constexpr bool need_f16_V = D != 128 && D != 64; constexpr size_t nbytes_shared = 0; - launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false); + launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false); } template diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ggml/src/ggml-cuda/fattn-wmma-f16.cu index bc21b27a0cc..c5668adb152 100644 --- a/ggml/src/ggml-cuda/fattn-wmma-f16.cu +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cu @@ -490,7 +490,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm fattn_kernel = flash_attn_ext_f16< D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), KQ_acc_t, use_logit_softcap>; } - launch_fattn(ctx, dst, fattn_kernel, nwarps, 0, FATTN_KQ_STRIDE, true, true, false, warp_size); + launch_fattn(ctx, dst, fattn_kernel, nwarps, 0, FATTN_KQ_STRIDE, true, true, false, warp_size); } void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 7a2d1e45365..9c5c803d02b 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -8,58 +8,32 @@ #include "fattn-wmma-f16.cuh" #include "fattn.cuh" -template +template static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * Q = dst->src[0]; - if (Q->ne[1] <= 8/ncols2) { - ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); - return; + if constexpr (ncols2 <= 8) { + if (Q->ne[1] <= 8/ncols2) { + ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); + return; + } } if (Q->ne[1] <= 16/ncols2) { - ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); + ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); return; } if (Q->ne[1] <= 32/ncols2) { - ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); + ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); return; } - ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); -} - -template -static void ggml_cuda_flash_attn_ext_mma_f16_switch_hs(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * Q = dst->src[0]; - - switch (Q->ne[0]) { - case 64: - ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 64, ncols2>(ctx, dst); - break; - case 80: - ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 80, ncols2>(ctx, dst); - break; - case 96: - ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 96, ncols2>(ctx, dst); - break; - case 112: - ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<112, ncols2>(ctx, dst); - break; - case 128: - ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<128, ncols2>(ctx, dst); - break; - case 256: - ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<256, ncols2>(ctx, dst); - break; - default: - GGML_ABORT("fatal error"); - break; - } + ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); } -static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { +template +static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * KQV = dst; const ggml_tensor * Q = dst->src[0]; const ggml_tensor * K = dst->src[1]; @@ -68,27 +42,79 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg float max_bias = 0.0f; memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float)); - const float use_gqa_opt = mask && max_bias == 0.0f; + const bool use_gqa_opt = mask && max_bias == 0.0f; GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); const int gqa_ratio = Q->ne[2] / K->ne[2]; if (use_gqa_opt && gqa_ratio % 8 == 0) { - ggml_cuda_flash_attn_ext_mma_f16_switch_hs<8>(ctx, dst); + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst); return; } - if (use_gqa_opt && gqa_ratio == 4) { - ggml_cuda_flash_attn_ext_mma_f16_switch_hs<4>(ctx, dst); + if (use_gqa_opt && gqa_ratio % 4 == 0) { + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst); return; } - if (use_gqa_opt && gqa_ratio == 2) { - ggml_cuda_flash_attn_ext_mma_f16_switch_hs<2>(ctx, dst); + if (use_gqa_opt && gqa_ratio % 2 == 0) { + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst); return; } - ggml_cuda_flash_attn_ext_mma_f16_switch_hs<1>(ctx, dst); + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst); +} + +static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * KQV = dst; + const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * V = dst->src[2]; + const ggml_tensor * mask = dst->src[3]; + + switch (Q->ne[0]) { + case 64: + GGML_ASSERT(V->ne[0] == 64); + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2< 64, 64>(ctx, dst); + break; + case 80: + GGML_ASSERT(V->ne[0] == 80); + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2< 80, 80>(ctx, dst); + break; + case 96: + GGML_ASSERT(V->ne[0] == 96); + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2< 96, 96>(ctx, dst); + break; + case 112: + GGML_ASSERT(V->ne[0] == 112); + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<112, 112>(ctx, dst); + break; + case 128: + GGML_ASSERT(V->ne[0] == 128); + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<128, 128>(ctx, dst); + break; + case 256: + GGML_ASSERT(V->ne[0] == 256); + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<256, 256>(ctx, dst); + break; + case 576: { + // For Deepseek, go straight to the ncols1 switch to avoid compiling unnecessary kernels. + GGML_ASSERT(V->ne[0] == 512); + float max_bias = 0.0f; + memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float)); + + const bool use_gqa_opt = mask && max_bias == 0.0f; + GGML_ASSERT(use_gqa_opt); + + GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); + const int gqa_ratio = Q->ne[2] / K->ne[2]; + GGML_ASSERT(gqa_ratio % 16 == 0); + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst); + } break; + default: + GGML_ABORT("fatal error"); + break; + } } #define FATTN_VEC_F16_CASE(D, type_K, type_V) \ @@ -299,7 +325,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst const bool gqa_opt_applies = ((Q->ne[2] / K->ne[2]) % 2 == 0) && mask; // The mma-based kernels have GQA-specific optimizations const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16; const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && cc < GGML_CUDA_CC_ADA_LOVELACE && !mma_needs_data_conversion; - const bool can_use_vector_kernel = Q->ne[0] % (2*warp_size) == 0; + const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % (2*warp_size) == 0; if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) { if (prec == GGML_PREC_DEFAULT) { ggml_cuda_flash_attn_ext_vec_f16(ctx, dst); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 42302e4ecc6..7643c4b8bfa 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3215,16 +3215,16 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g return false; #endif // FLASH_ATTN_AVAILABLE if (op->src[1]->ne[0] != op->src[2]->ne[0]) { - // different head sizes of K and V are not supported yet - return false; + const int cc = ggml_cuda_info().devices[dev_ctx->device].cc; + if (!new_mma_available(cc) || cc < GGML_CUDA_CC_AMPERE) { + return false; + } + const int gqa_ratio = op->src[0]->ne[2] / op->src[1]->ne[2]; + return op->src[1]->ne[0] == 576 && op->src[2]->ne[0] == 512 && op->src[3] && gqa_ratio % 16 == 0; } if (op->src[0]->ne[0] == 192) { return false; } - if (op->src[0]->ne[0] == 576) { - // DeepSeek MLA - return false; - } if (op->src[0]->ne[3] != 1) { return false; } diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu new file mode 100644 index 00000000000..fb26abeb0da --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu index 80108615ac8..dc16829021f 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu @@ -2,9 +2,9 @@ #include "../fattn-mma-f16.cuh" -DECL_FATTN_MMA_F16_CASE(64, 1, 8); -DECL_FATTN_MMA_F16_CASE(80, 1, 8); -DECL_FATTN_MMA_F16_CASE(96, 1, 8); -DECL_FATTN_MMA_F16_CASE(112, 1, 8); -DECL_FATTN_MMA_F16_CASE(128, 1, 8); -DECL_FATTN_MMA_F16_CASE(256, 1, 8); +DECL_FATTN_MMA_F16_CASE(64, 64, 1, 8); +DECL_FATTN_MMA_F16_CASE(80, 80, 1, 8); +DECL_FATTN_MMA_F16_CASE(96, 96, 1, 8); +DECL_FATTN_MMA_F16_CASE(112, 112, 1, 8); +DECL_FATTN_MMA_F16_CASE(128, 128, 1, 8); +DECL_FATTN_MMA_F16_CASE(256, 256, 1, 8); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu index 66161c0abeb..9d3cfd8edf7 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu @@ -2,9 +2,9 @@ #include "../fattn-mma-f16.cuh" -DECL_FATTN_MMA_F16_CASE(64, 16, 1); -DECL_FATTN_MMA_F16_CASE(80, 16, 1); -DECL_FATTN_MMA_F16_CASE(96, 16, 1); -DECL_FATTN_MMA_F16_CASE(112, 16, 1); -DECL_FATTN_MMA_F16_CASE(128, 16, 1); -DECL_FATTN_MMA_F16_CASE(256, 16, 1); +DECL_FATTN_MMA_F16_CASE(64, 64, 16, 1); +DECL_FATTN_MMA_F16_CASE(80, 80, 16, 1); +DECL_FATTN_MMA_F16_CASE(96, 96, 16, 1); +DECL_FATTN_MMA_F16_CASE(112, 112, 16, 1); +DECL_FATTN_MMA_F16_CASE(128, 128, 16, 1); +DECL_FATTN_MMA_F16_CASE(256, 256, 16, 1); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu index ee88c72aa04..2e1883af40e 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu @@ -2,9 +2,9 @@ #include "../fattn-mma-f16.cuh" -DECL_FATTN_MMA_F16_CASE(64, 16, 2); -DECL_FATTN_MMA_F16_CASE(80, 16, 2); -DECL_FATTN_MMA_F16_CASE(96, 16, 2); -DECL_FATTN_MMA_F16_CASE(112, 16, 2); -DECL_FATTN_MMA_F16_CASE(128, 16, 2); -DECL_FATTN_MMA_F16_CASE(256, 16, 2); +DECL_FATTN_MMA_F16_CASE(64, 64, 16, 2); +DECL_FATTN_MMA_F16_CASE(80, 80, 16, 2); +DECL_FATTN_MMA_F16_CASE(96, 96, 16, 2); +DECL_FATTN_MMA_F16_CASE(112, 112, 16, 2); +DECL_FATTN_MMA_F16_CASE(128, 128, 16, 2); +DECL_FATTN_MMA_F16_CASE(256, 256, 16, 2); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu index d888a5a423a..2074e954a32 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu @@ -2,9 +2,9 @@ #include "../fattn-mma-f16.cuh" -DECL_FATTN_MMA_F16_CASE(64, 16, 4); -DECL_FATTN_MMA_F16_CASE(80, 16, 4); -DECL_FATTN_MMA_F16_CASE(96, 16, 4); -DECL_FATTN_MMA_F16_CASE(112, 16, 4); -DECL_FATTN_MMA_F16_CASE(128, 16, 4); -DECL_FATTN_MMA_F16_CASE(256, 16, 4); +DECL_FATTN_MMA_F16_CASE(64, 64, 16, 4); +DECL_FATTN_MMA_F16_CASE(80, 80, 16, 4); +DECL_FATTN_MMA_F16_CASE(96, 96, 16, 4); +DECL_FATTN_MMA_F16_CASE(112, 112, 16, 4); +DECL_FATTN_MMA_F16_CASE(128, 128, 16, 4); +DECL_FATTN_MMA_F16_CASE(256, 256, 16, 4); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu new file mode 100644 index 00000000000..f011a208cd2 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu index d93a2d08ed7..24c64cf000f 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu @@ -2,9 +2,9 @@ #include "../fattn-mma-f16.cuh" -DECL_FATTN_MMA_F16_CASE(64, 2, 4); -DECL_FATTN_MMA_F16_CASE(80, 2, 4); -DECL_FATTN_MMA_F16_CASE(96, 2, 4); -DECL_FATTN_MMA_F16_CASE(112, 2, 4); -DECL_FATTN_MMA_F16_CASE(128, 2, 4); -DECL_FATTN_MMA_F16_CASE(256, 2, 4); +DECL_FATTN_MMA_F16_CASE(64, 64, 2, 4); +DECL_FATTN_MMA_F16_CASE(80, 80, 2, 4); +DECL_FATTN_MMA_F16_CASE(96, 96, 2, 4); +DECL_FATTN_MMA_F16_CASE(112, 112, 2, 4); +DECL_FATTN_MMA_F16_CASE(128, 128, 2, 4); +DECL_FATTN_MMA_F16_CASE(256, 256, 2, 4); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu index 617464c9456..163b1d939e4 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu @@ -2,9 +2,9 @@ #include "../fattn-mma-f16.cuh" -DECL_FATTN_MMA_F16_CASE(64, 2, 8); -DECL_FATTN_MMA_F16_CASE(80, 2, 8); -DECL_FATTN_MMA_F16_CASE(96, 2, 8); -DECL_FATTN_MMA_F16_CASE(112, 2, 8); -DECL_FATTN_MMA_F16_CASE(128, 2, 8); -DECL_FATTN_MMA_F16_CASE(256, 2, 8); +DECL_FATTN_MMA_F16_CASE(64, 64, 2, 8); +DECL_FATTN_MMA_F16_CASE(80, 80, 2, 8); +DECL_FATTN_MMA_F16_CASE(96, 96, 2, 8); +DECL_FATTN_MMA_F16_CASE(112, 112, 2, 8); +DECL_FATTN_MMA_F16_CASE(128, 128, 2, 8); +DECL_FATTN_MMA_F16_CASE(256, 256, 2, 8); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu index 970d2b68696..0543532ea34 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu @@ -2,9 +2,9 @@ #include "../fattn-mma-f16.cuh" -DECL_FATTN_MMA_F16_CASE(64, 32, 1); -DECL_FATTN_MMA_F16_CASE(80, 32, 1); -DECL_FATTN_MMA_F16_CASE(96, 32, 1); -DECL_FATTN_MMA_F16_CASE(112, 32, 1); -DECL_FATTN_MMA_F16_CASE(128, 32, 1); -DECL_FATTN_MMA_F16_CASE(256, 32, 1); +DECL_FATTN_MMA_F16_CASE(64, 64, 32, 1); +DECL_FATTN_MMA_F16_CASE(80, 80, 32, 1); +DECL_FATTN_MMA_F16_CASE(96, 96, 32, 1); +DECL_FATTN_MMA_F16_CASE(112, 112, 32, 1); +DECL_FATTN_MMA_F16_CASE(128, 128, 32, 1); +DECL_FATTN_MMA_F16_CASE(256, 256, 32, 1); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu index 65cd377c395..407b6cf4c70 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu @@ -2,9 +2,9 @@ #include "../fattn-mma-f16.cuh" -DECL_FATTN_MMA_F16_CASE(64, 32, 2); -DECL_FATTN_MMA_F16_CASE(80, 32, 2); -DECL_FATTN_MMA_F16_CASE(96, 32, 2); -DECL_FATTN_MMA_F16_CASE(112, 32, 2); -DECL_FATTN_MMA_F16_CASE(128, 32, 2); -DECL_FATTN_MMA_F16_CASE(256, 32, 2); +DECL_FATTN_MMA_F16_CASE(64, 64, 32, 2); +DECL_FATTN_MMA_F16_CASE(80, 80, 32, 2); +DECL_FATTN_MMA_F16_CASE(96, 96, 32, 2); +DECL_FATTN_MMA_F16_CASE(112, 112, 32, 2); +DECL_FATTN_MMA_F16_CASE(128, 128, 32, 2); +DECL_FATTN_MMA_F16_CASE(256, 256, 32, 2); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu new file mode 100644 index 00000000000..f5fd0e2369c --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu index f4a8bf34899..5e46685024b 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu @@ -2,9 +2,9 @@ #include "../fattn-mma-f16.cuh" -DECL_FATTN_MMA_F16_CASE(64, 4, 2); -DECL_FATTN_MMA_F16_CASE(80, 4, 2); -DECL_FATTN_MMA_F16_CASE(96, 4, 2); -DECL_FATTN_MMA_F16_CASE(112, 4, 2); -DECL_FATTN_MMA_F16_CASE(128, 4, 2); -DECL_FATTN_MMA_F16_CASE(256, 4, 2); +DECL_FATTN_MMA_F16_CASE(64, 64, 4, 2); +DECL_FATTN_MMA_F16_CASE(80, 80, 4, 2); +DECL_FATTN_MMA_F16_CASE(96, 96, 4, 2); +DECL_FATTN_MMA_F16_CASE(112, 112, 4, 2); +DECL_FATTN_MMA_F16_CASE(128, 128, 4, 2); +DECL_FATTN_MMA_F16_CASE(256, 256, 4, 2); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu index de191a8ab66..1ada657f194 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu @@ -2,9 +2,9 @@ #include "../fattn-mma-f16.cuh" -DECL_FATTN_MMA_F16_CASE(64, 4, 4); -DECL_FATTN_MMA_F16_CASE(80, 4, 4); -DECL_FATTN_MMA_F16_CASE(96, 4, 4); -DECL_FATTN_MMA_F16_CASE(112, 4, 4); -DECL_FATTN_MMA_F16_CASE(128, 4, 4); -DECL_FATTN_MMA_F16_CASE(256, 4, 4); +DECL_FATTN_MMA_F16_CASE(64, 64, 4, 4); +DECL_FATTN_MMA_F16_CASE(80, 80, 4, 4); +DECL_FATTN_MMA_F16_CASE(96, 96, 4, 4); +DECL_FATTN_MMA_F16_CASE(112, 112, 4, 4); +DECL_FATTN_MMA_F16_CASE(128, 128, 4, 4); +DECL_FATTN_MMA_F16_CASE(256, 256, 4, 4); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu index e8cb0e1b312..bad296b4141 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu @@ -2,9 +2,9 @@ #include "../fattn-mma-f16.cuh" -DECL_FATTN_MMA_F16_CASE(64, 4, 8); -DECL_FATTN_MMA_F16_CASE(80, 4, 8); -DECL_FATTN_MMA_F16_CASE(96, 4, 8); -DECL_FATTN_MMA_F16_CASE(112, 4, 8); -DECL_FATTN_MMA_F16_CASE(128, 4, 8); -DECL_FATTN_MMA_F16_CASE(256, 4, 8); +DECL_FATTN_MMA_F16_CASE(64, 64, 4, 8); +DECL_FATTN_MMA_F16_CASE(80, 80, 4, 8); +DECL_FATTN_MMA_F16_CASE(96, 96, 4, 8); +DECL_FATTN_MMA_F16_CASE(112, 112, 4, 8); +DECL_FATTN_MMA_F16_CASE(128, 128, 4, 8); +DECL_FATTN_MMA_F16_CASE(256, 256, 4, 8); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu index a532e96296b..0d7a9c72853 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu @@ -2,9 +2,9 @@ #include "../fattn-mma-f16.cuh" -DECL_FATTN_MMA_F16_CASE(64, 64, 1); -DECL_FATTN_MMA_F16_CASE(80, 64, 1); -DECL_FATTN_MMA_F16_CASE(96, 64, 1); -DECL_FATTN_MMA_F16_CASE(112, 64, 1); -DECL_FATTN_MMA_F16_CASE(128, 64, 1); -DECL_FATTN_MMA_F16_CASE(256, 64, 1); +DECL_FATTN_MMA_F16_CASE(64, 64, 64, 1); +DECL_FATTN_MMA_F16_CASE(80, 80, 64, 1); +DECL_FATTN_MMA_F16_CASE(96, 96, 64, 1); +DECL_FATTN_MMA_F16_CASE(112, 112, 64, 1); +DECL_FATTN_MMA_F16_CASE(128, 128, 64, 1); +DECL_FATTN_MMA_F16_CASE(256, 256, 64, 1); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu index bf25181aa76..9d5a9976f0e 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu @@ -2,9 +2,9 @@ #include "../fattn-mma-f16.cuh" -DECL_FATTN_MMA_F16_CASE(64, 8, 1); -DECL_FATTN_MMA_F16_CASE(80, 8, 1); -DECL_FATTN_MMA_F16_CASE(96, 8, 1); -DECL_FATTN_MMA_F16_CASE(112, 8, 1); -DECL_FATTN_MMA_F16_CASE(128, 8, 1); -DECL_FATTN_MMA_F16_CASE(256, 8, 1); +DECL_FATTN_MMA_F16_CASE(64, 64, 8, 1); +DECL_FATTN_MMA_F16_CASE(80, 80, 8, 1); +DECL_FATTN_MMA_F16_CASE(96, 96, 8, 1); +DECL_FATTN_MMA_F16_CASE(112, 112, 8, 1); +DECL_FATTN_MMA_F16_CASE(128, 128, 8, 1); +DECL_FATTN_MMA_F16_CASE(256, 256, 8, 1); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu index 378c132e658..a6e6f093dcb 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu @@ -2,9 +2,9 @@ #include "../fattn-mma-f16.cuh" -DECL_FATTN_MMA_F16_CASE(64, 8, 2); -DECL_FATTN_MMA_F16_CASE(80, 8, 2); -DECL_FATTN_MMA_F16_CASE(96, 8, 2); -DECL_FATTN_MMA_F16_CASE(112, 8, 2); -DECL_FATTN_MMA_F16_CASE(128, 8, 2); -DECL_FATTN_MMA_F16_CASE(256, 8, 2); +DECL_FATTN_MMA_F16_CASE(64, 64, 8, 2); +DECL_FATTN_MMA_F16_CASE(80, 80, 8, 2); +DECL_FATTN_MMA_F16_CASE(96, 96, 8, 2); +DECL_FATTN_MMA_F16_CASE(112, 112, 8, 2); +DECL_FATTN_MMA_F16_CASE(128, 128, 8, 2); +DECL_FATTN_MMA_F16_CASE(256, 256, 8, 2); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu index 372641be9a0..86d4ffae27c 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu @@ -2,9 +2,9 @@ #include "../fattn-mma-f16.cuh" -DECL_FATTN_MMA_F16_CASE(64, 8, 4); -DECL_FATTN_MMA_F16_CASE(80, 8, 4); -DECL_FATTN_MMA_F16_CASE(96, 8, 4); -DECL_FATTN_MMA_F16_CASE(112, 8, 4); -DECL_FATTN_MMA_F16_CASE(128, 8, 4); -DECL_FATTN_MMA_F16_CASE(256, 8, 4); +DECL_FATTN_MMA_F16_CASE(64, 64, 8, 4); +DECL_FATTN_MMA_F16_CASE(80, 80, 8, 4); +DECL_FATTN_MMA_F16_CASE(96, 96, 8, 4); +DECL_FATTN_MMA_F16_CASE(112, 112, 8, 4); +DECL_FATTN_MMA_F16_CASE(128, 128, 8, 4); +DECL_FATTN_MMA_F16_CASE(256, 256, 8, 4); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu index 9ff5968b6ab..680a13ca6de 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu @@ -2,9 +2,9 @@ #include "../fattn-mma-f16.cuh" -DECL_FATTN_MMA_F16_CASE(64, 8, 8); -DECL_FATTN_MMA_F16_CASE(80, 8, 8); -DECL_FATTN_MMA_F16_CASE(96, 8, 8); -DECL_FATTN_MMA_F16_CASE(112, 8, 8); -DECL_FATTN_MMA_F16_CASE(128, 8, 8); -DECL_FATTN_MMA_F16_CASE(256, 8, 8); +DECL_FATTN_MMA_F16_CASE(64, 64, 8, 8); +DECL_FATTN_MMA_F16_CASE(80, 80, 8, 8); +DECL_FATTN_MMA_F16_CASE(96, 96, 8, 8); +DECL_FATTN_MMA_F16_CASE(112, 112, 8, 8); +DECL_FATTN_MMA_F16_CASE(128, 128, 8, 8); +DECL_FATTN_MMA_F16_CASE(256, 256, 8, 8); diff --git a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py index dd373a09d26..3428113dc8f 100755 --- a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +++ b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py @@ -18,7 +18,7 @@ """ -SOURCE_FATTN_MMA_CASE = "DECL_FATTN_MMA_F16_CASE({head_size}, {ncols1}, {ncols2});\n" +SOURCE_FATTN_MMA_CASE = "DECL_FATTN_MMA_F16_CASE({head_size_kq}, {head_size_v}, {ncols1}, {ncols2});\n" TYPES_MMQ = [ "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0", @@ -57,18 +57,21 @@ def get_head_sizes(type_k, type_v): with open(f"fattn-vec-f{vkq_size}-instance-hs{head_size}-{get_short_name(type_k)}-{get_short_name(type_v)}.cu", "w") as f: f.write(SOURCE_FATTN_VEC.format(vkq_size=vkq_size, head_size=head_size, type_k=type_k, type_v=type_v)) -for ncols in [8, 16, 32, 64, 128]: - for ncols2 in [1, 2, 4, 8]: +for ncols in [8, 16, 32, 64]: + for ncols2 in [1, 2, 4, 8, 16]: + if ncols2 > ncols: + continue ncols1 = ncols // ncols2 - if ncols == 128: - continue # Too much register pressure. with open(f"fattn-mma-f16-instance-ncols1_{ncols1}-ncols2_{ncols2}.cu", "w") as f: f.write(SOURCE_FATTN_MMA_START) - for head_size in [64, 80, 96, 112, 128, 256]: - if ncols == 128 and head_size == 256: - continue # Needs too much shared memory. - f.write(SOURCE_FATTN_MMA_CASE.format(ncols1=ncols1, ncols2=ncols2, head_size=head_size)) + for head_size_kq in [64, 80, 96, 112, 128, 256, 576]: + if head_size_kq != 576 and ncols2 == 16: + continue + if head_size_kq == 576 and ncols2 != 16: + continue + head_size_v = head_size_kq if head_size_kq != 576 else 512 + f.write(SOURCE_FATTN_MMA_CASE.format(ncols1=ncols1, ncols2=ncols2, head_size_kq=head_size_kq, head_size_v=head_size_v)) for type in TYPES_MMQ: with open(f"mmq-instance-{get_short_name(type)}.cu", "w") as f: From b493e03b90ac24f00cfa21feb0f0548ac5745930 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alberto=20Cabrera=20P=C3=A9rez?= Date: Fri, 9 May 2025 16:34:08 +0100 Subject: [PATCH 06/21] sycl : implementation of reordered Q4_0 MMVQ for Intel GPUs (llama/12858) * sycl : Implemented reorder Q4_0 mmvq Signed-off-by: Alberto Cabrera * sycl : Fixed mmvq being called when reorder is disabled * sycl : Improved comments in the quants header Signed-off-by: Alberto Cabrera * Use static_assert * safe_div -> ceil_div * Clarify qi comment * change the reorder tensor from init to execute OP * dbg * Undo changes to test-backend-ops * Refactor changes on top of q4_0 reorder fix * Missing Reverts * Refactored opt_for_reorder logic to simplify code path * Explicit inlining and unroll * Renamed mul_mat_algo enum for consistency --------- Signed-off-by: Alberto Cabrera Co-authored-by: romain.biessy --- ggml/src/ggml-sycl/backend.hpp | 17 ++- ggml/src/ggml-sycl/common.hpp | 1 + ggml/src/ggml-sycl/ggml-sycl.cpp | 120 +++++++++++---- ggml/src/ggml-sycl/mmvq.cpp | 247 ++++++++++++++++++++----------- ggml/src/ggml-sycl/quants.hpp | 61 ++++++++ ggml/src/ggml-sycl/vecdotq.hpp | 71 +++++++-- 6 files changed, 383 insertions(+), 134 deletions(-) create mode 100644 ggml/src/ggml-sycl/quants.hpp diff --git a/ggml/src/ggml-sycl/backend.hpp b/ggml/src/ggml-sycl/backend.hpp index de814ef91a0..f78a36ddf8f 100644 --- a/ggml/src/ggml-sycl/backend.hpp +++ b/ggml/src/ggml-sycl/backend.hpp @@ -14,23 +14,24 @@ #define GGML_SYCL_BACKEND_HPP #include "binbcast.hpp" -#include "concat.hpp" #include "common.hpp" +#include "concat.hpp" #include "conv.hpp" #include "convert.hpp" +#include "cpy.hpp" #include "dequantize.hpp" #include "dmmv.hpp" +#include "element_wise.hpp" +#include "gla.hpp" +#include "im2col.hpp" #include "mmq.hpp" #include "mmvq.hpp" -#include "rope.hpp" #include "norm.hpp" +#include "outprod.hpp" +#include "quants.hpp" +#include "rope.hpp" #include "softmax.hpp" #include "tsembd.hpp" -#include "im2col.hpp" #include "wkv.hpp" -#include "outprod.hpp" -#include "element_wise.hpp" -#include "cpy.hpp" -#include "gla.hpp" -#endif // GGML_SYCL_BACKEND_HPP +#endif // GGML_SYCL_BACKEND_HPP diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp index 69aad938e88..60909dde7d0 100644 --- a/ggml/src/ggml-sycl/common.hpp +++ b/ggml/src/ggml-sycl/common.hpp @@ -42,6 +42,7 @@ void ggml_sycl_host_free(void* ptr); extern int g_ggml_sycl_debug; extern int g_ggml_sycl_disable_optimize; +extern int g_ggml_sycl_prioritize_dmmv; #define GGML_SYCL_DEBUG(...) \ do { \ diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 68a26fa481d..0ea729948ec 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -49,6 +49,7 @@ static bool g_sycl_loaded = false; int g_ggml_sycl_debug = 0; int g_ggml_sycl_disable_optimize = 0; int g_ggml_sycl_disable_graph = 0; +int g_ggml_sycl_prioritize_dmmv = 0; static ggml_sycl_device_info ggml_sycl_init() { ggml_sycl_device_info info = {}; @@ -195,11 +196,13 @@ static void ggml_check_sycl() try { g_ggml_sycl_debug = get_sycl_env("GGML_SYCL_DEBUG", 0); g_ggml_sycl_disable_optimize= get_sycl_env("GGML_SYCL_DISABLE_OPT", 1); g_ggml_sycl_disable_graph = get_sycl_env("GGML_SYCL_DISABLE_GRAPH", 1); + g_ggml_sycl_prioritize_dmmv = get_sycl_env("GGML_SYCL_PRIORITIZE_DMMV", 0); GGML_SYCL_DEBUG("[SYCL] call ggml_check_sycl\n"); GGML_LOG_INFO("Running with Environment Variables:\n"); GGML_LOG_INFO(" GGML_SYCL_DEBUG: %d\n", g_ggml_sycl_debug); GGML_LOG_INFO(" GGML_SYCL_DISABLE_OPT: %d\n", g_ggml_sycl_disable_optimize); GGML_LOG_INFO(" GGML_SYCL_DISABLE_GRAPH: %d\n", g_ggml_sycl_disable_graph); + GGML_LOG_INFO(" GGML_SYCL_PRIORITIZE_DMMV: %d\n", g_ggml_sycl_prioritize_dmmv); GGML_LOG_INFO("Build with Macros:\n"); #if defined(GGML_SYCL_FORCE_MMQ) GGML_LOG_INFO(" GGML_SYCL_FORCE_MMQ: yes\n"); @@ -2822,12 +2825,45 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons std::exit(1); } +enum class mul_mat_algo { + DMMV = 0, + MMVQ = 1, + MUL_MAT_SYCL = 2, +}; + inline bool ggml_sycl_supports_mmq(enum ggml_type type) { // TODO: accuracy issues in MMQ GGML_UNUSED(type); return false; } +inline bool ggml_sycl_supports_reorder_mul_mat_sycl(enum ggml_type type) { + switch (type) { + case GGML_TYPE_Q4_0: + return true; + default: + return false; + } +} + +inline bool ggml_sycl_supports_reorder_dmmv(enum ggml_type type) { + switch (type) { + case GGML_TYPE_Q4_0: + return true; + default: + return false; + } +} + +inline bool ggml_sycl_supports_reorder_mmvq(enum ggml_type type) { + switch (type) { + case GGML_TYPE_Q4_0: + return true; + default: + return false; + } +} + static bool ggml_sycl_supports_dmmv(enum ggml_type type) { switch (type) { case GGML_TYPE_Q4_0: @@ -2856,7 +2892,7 @@ static void reorder_qw(char *data_device, const int ncols, const int nrows, GGML_ASSERT((size % sizeof(block_q4_0) == 0)); GGML_ASSERT((offset % sizeof(block_q4_0) == 0)); int offset_blks = offset / sizeof(block_q4_0); - auto qs_ptr = (uint8_t*)data_device + offset_blks * QK4_0 / 2;; + auto qs_ptr = (uint8_t*)data_device + offset_blks * QK4_0 / 2; auto d_ptr = (sycl::half*)(qs_ptr + ncols * nrows / 2) + offset_blks; stream->parallel_for( @@ -2884,25 +2920,44 @@ static void reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) { reorder_qw(data_device, ncols, nrows, size, 0, stream); } -/* -* This function could be called when the OP (mul_mat) function support reorder optimizition. -*/ -static void opt_for_reorder(ggml_backend_sycl_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, - ggml_tensor * dst) { - if (!g_ggml_sycl_disable_optimize && //allow optimize, controlled by $GGML_SYCL_DISABLE_OPT - ctx->opt_feature.reorder && //allow this device due to good perf, skip the devices with bad perf. - dst->op == GGML_OP_MUL_MAT && //limit to some supported cases of Q4_0, to do for more cases. - src0->type == GGML_TYPE_Q4_0 && - src1->ne[2]==1 && src1->ne[3]==1) { +static bool should_reorder_tensor(ggml_backend_sycl_context& ctx, const ggml_tensor * dst) { + return !g_ggml_sycl_disable_optimize && //allow optimize, controlled by $GGML_SYCL_DISABLE_OPT + ctx.opt_feature.reorder && //allow this device due to good perf, skip the devices with bad perf. + dst->op == GGML_OP_MUL_MAT && //limit to some supported cases of Q4_0, to do for more cases. + dst->src[1]->ne[2]==1 && dst->src[1]->ne[3]==1; +} - ggml_tensor_extra_gpu* extra = (ggml_tensor_extra_gpu*)src0->extra; - if (!extra) return; //only happen in CI/UT permute case. +static void opt_for_reorder(ggml_backend_sycl_context * ctx, const ggml_tensor * src0, const ggml_tensor * /* src1 */, + ggml_tensor * dst, mul_mat_algo mm_algorithm) { + if (!should_reorder_tensor(*ctx, dst)) { + return; + } - if (extra->optimized_feature.reorder) return; //skip the tensor which is handled for reorder. + ggml_tensor_extra_gpu * extra = static_cast(src0->extra); + if (!extra || extra->optimized_feature.reorder) { + return; // Skip permutations and already reordered tensors + } - reorder_qw(src0, ctx->stream()); - extra->optimized_feature.reorder = true; //used to decode/dequan in next steps. + switch (mm_algorithm) { + case mul_mat_algo::DMMV: + if (!ggml_sycl_supports_reorder_dmmv(src0->type)) { + return; + } + break; + case mul_mat_algo::MMVQ: + if (!ggml_sycl_supports_reorder_mmvq(src0->type)) { + return; + } + break; + case mul_mat_algo::MUL_MAT_SYCL: + if (!ggml_sycl_supports_reorder_mul_mat_sycl(src0->type)) { + return; + } + break; } + + reorder_qw(src0, ctx->stream()); + extra->optimized_feature.reorder = true; // Used to decode/dequan in next steps and avoid re-reordering } static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { @@ -2911,7 +2966,8 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor int64_t min_compute_capability = INT_MAX; if (split) { - ggml_backend_sycl_split_buffer_type_context * buft_ctx = (ggml_backend_sycl_split_buffer_type_context *) src0->buffer->buft->context; + ggml_backend_sycl_split_buffer_type_context * buft_ctx = + (ggml_backend_sycl_split_buffer_type_context *) src0->buffer->buft->context; auto & tensor_split = buft_ctx->tensor_split; for (int id = 0; id < ggml_sycl_info().device_count; ++id) { // skip devices that are not going to do any work: @@ -2924,7 +2980,7 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor } } } else { - min_compute_capability = ggml_sycl_info().devices[ctx.device].cc; + min_compute_capability = ggml_sycl_info().devices[ctx.device].cc; } // check data types and tensor shapes for custom matrix multiplication kernels: @@ -2946,9 +3002,15 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor use_mul_mat_q = use_mul_mat_q && (src1->ne[1] <= MMQ_MAX_BATCH_SIZE); #endif // SYCL_USE_XMX + // mmvq path is faster in the CUDA backend. - if (ctx.stream()->get_backend() == sycl::backend::ext_oneapi_cuda) + if (!g_ggml_sycl_prioritize_dmmv && (ctx.stream()->get_backend() == sycl::backend::ext_oneapi_cuda + // Dispatch becomes obscure with the reorder, MMVQ when the reorder optimization + // is enabled takes precedence over DMMV, the current if-else implementation + // requires disabling DMMV if both conditions are met + || (should_reorder_tensor(ctx, dst) && ggml_sycl_supports_reorder_mmvq(src0->type)))) { use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q; + } if (!split && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) { // TODO: Refactor and cleanup of mul mat dispatching. @@ -2967,17 +3029,23 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor // KQ + KQV multi-batch ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst); } else if (use_dequantize_mul_mat_vec) { - opt_for_reorder(&ctx, src0, src1, dst); //the OP function in this branch support reorder. - ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, false); - // save_tensor_txt("1/dst_1.txt", (float*) dst->data, src0->ne[1], sizeof(float), ctx.stream()); + constexpr bool convert_src1_to_q8_1 = false; + opt_for_reorder(&ctx, src0, src1, dst, mul_mat_algo::DMMV); + ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, convert_src1_to_q8_1); } else if (use_mul_mat_vec_q) { - ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, true); + constexpr bool convert_src1_to_q8_1 = true; + opt_for_reorder(&ctx, src0, src1, dst, mul_mat_algo::MMVQ); + ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, convert_src1_to_q8_1); } else if (use_mul_mat_q) { - ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_q, true); + constexpr bool convert_src1_to_q8_1 = true; + ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_q, convert_src1_to_q8_1); } else { - opt_for_reorder(&ctx, src0, src1, dst); //the OP function in this branch support reorder. - ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false); + constexpr bool convert_src1_to_q8_1 = false; + // MUL_MAT_SYCL supports reorder + opt_for_reorder(&ctx, src0, src1, dst, mul_mat_algo::MUL_MAT_SYCL); + ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl, convert_src1_to_q8_1); } + GGML_SYCL_DEBUG("call %s done\n", __func__); } diff --git a/ggml/src/ggml-sycl/mmvq.cpp b/ggml/src/ggml-sycl/mmvq.cpp index 1b92ba2d604..3cade1a42a6 100644 --- a/ggml/src/ggml-sycl/mmvq.cpp +++ b/ggml/src/ggml-sycl/mmvq.cpp @@ -1,6 +1,60 @@ #include "mmvq.hpp" + +#include "ggml.h" +#include "common.hpp" +#include "quants.hpp" #include "vecdotq.hpp" -#include + +template +static void mul_mat_vec_q_reorder(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, + const int ncols, const int nrows, const sycl::nd_item<3> & nd_item) { + using block_type = ggml_sycl_reordered::block_q_t; + using block_traits = typename block_type::traits; + + const auto sg = nd_item.get_sub_group(); + const int sg_range = sg.get_group_linear_range(); + const int workgroup_id = nd_item.get_group_linear_id(); + const int sg_id = sg.get_group_linear_id(); + const int row = workgroup_id * sg_range + sg_id; + + if (row >= nrows) { + return; + } + + const int blocks_per_row = ncols / block_traits::qk; + constexpr int blocks_per_subgroup = ceil_div(block_traits::vdr_mmvq * WARP_SIZE, block_traits::qi); + constexpr int block_elements_per_subgroup = block_traits::qi / block_traits::vdr_mmvq; + + static_assert(blocks_per_subgroup > 0); + static_assert(block_elements_per_subgroup > 0); + + const block_q8_1 * y = (const block_q8_1 *) vy; + + float partial_sum = 0.0f; + for (int i = sg.get_local_linear_id() / block_elements_per_subgroup; i < blocks_per_row; i += blocks_per_subgroup) { + const int ibx = row * blocks_per_row + i; // x block index + // TODO: Generalize offsets, right now only works for quantizations that don't split high and low bits + const int bx_offset = block_type::get_block_offset(ibx); + const int d_offset = block_type::get_d_offset(nrows, ncols, ibx); + + // Y block index that aligns with ibx + const int iby = i * block_type::block_to_q8_1_ratio(); + +#pragma unroll + for (int elem = 0; elem < block_elements_per_subgroup; elem += WARP_SIZE) { + // x block quant index when casting the quants to int + const int iqs = elem + block_traits::vdr_mmvq * (sg.get_local_linear_id() % block_elements_per_subgroup); + + partial_sum += reorder_vec_dot_q_sycl()(vx, bx_offset, d_offset, &y[iby], iqs); + } + } + + auto sum = sycl::reduce_over_group(nd_item.get_sub_group(), partial_sum, std::plus<>()); + + if (sg.leader()) { + dst[row] = sum; + } +} template static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, @@ -480,26 +534,39 @@ static void mul_mat_vec_q_iq4_xs_q8_1(const void *__restrict__ vx, } } -static void mul_mat_vec_q4_0_q8_1_sycl(const void *vx, const void *vy, - float *dst, const int ncols, - const int nrows, +static void reorder_mul_mat_vec_q4_0_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, + const int nrows, dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK4_0 == 0); + const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y); + constexpr size_t num_subgroups = 16; + GGML_ASSERT(block_num_y % num_subgroups == 0); + + const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, (block_num_y * WARP_SIZE)); + const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE); + + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size), + [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_reorder>(vx, vy, dst, ncols, nrows, + nd_item); + }); + }); +} + +static void mul_mat_vec_q4_0_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, dpct::queue_ptr stream) { GGML_ASSERT(ncols % QK4_0 == 0); const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); - { - - stream->submit([&](sycl::handler &cgh) { - cgh.parallel_for( - sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) - [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - mul_mat_vec_q( - vx, vy, dst, ncols, nrows, item_ct1); - }); + { + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q( + vx, vy, dst, ncols, nrows, item_ct1); + }); }); } } @@ -916,93 +983,95 @@ static void mul_mat_vec_iq4_xs_q8_1_sycl(const void *vx, const void *vy, } } -void ggml_sycl_op_mul_mat_vec_q( - ggml_backend_sycl_context & ctx, - const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, - const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i, - float *dst_dd_i, const int64_t row_low, const int64_t row_high, - const int64_t src1_ncols, const int64_t src1_padded_col_size, - const dpct::queue_ptr &stream) { - +void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, + ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, + const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, + const int64_t row_high, const int64_t src1_ncols, const int64_t src1_padded_col_size, + const dpct::queue_ptr & stream) { const int64_t ne10 = src1->ne[0]; GGML_ASSERT(ne10 % QK8_1 == 0); - const int64_t ne00 = src0->ne[0]; + const int64_t ne00 = src0->ne[0]; const int64_t row_diff = row_high - row_low; int id; - SYCL_CHECK( - CHECK_TRY_ERROR(id = get_current_device_id())); + SYCL_CHECK(CHECK_TRY_ERROR(id = get_current_device_id())); const size_t q8_1_ts = sizeof(block_q8_1); const size_t q8_1_bs = QK8_1; // the main device has a larger memory buffer to hold the results from all GPUs // nrows_dst == nrows of the matrix that the kernel writes into - for (int i = 0; i < src1_ncols; i++) - { + for (int i = 0; i < src1_ncols; i++) { const size_t src1_ddq_i_offset = i * src1_padded_col_size * q8_1_ts / q8_1_bs; - const char* src1_ddq_i_bs = src1_ddq_i + src1_ddq_i_offset; - float* dst_dd_i_bs = dst_dd_i + i * dst->ne[0]; + const char * src1_ddq_i_bs = src1_ddq_i + src1_ddq_i_offset; + float * dst_dd_i_bs = dst_dd_i + i * dst->ne[0]; switch (src0->type) { - case GGML_TYPE_Q4_0: - mul_mat_vec_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - break; - case GGML_TYPE_Q4_1: - mul_mat_vec_q4_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - break; - case GGML_TYPE_Q5_0: - mul_mat_vec_q5_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - break; - case GGML_TYPE_Q5_1: - mul_mat_vec_q5_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - break; - case GGML_TYPE_Q8_0: - mul_mat_vec_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - break; - case GGML_TYPE_Q2_K: - mul_mat_vec_q2_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - break; - case GGML_TYPE_Q3_K: - mul_mat_vec_q3_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - break; - case GGML_TYPE_Q4_K: - mul_mat_vec_q4_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - break; - case GGML_TYPE_Q5_K: - mul_mat_vec_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - break; - case GGML_TYPE_Q6_K: - mul_mat_vec_q6_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - break; - case GGML_TYPE_IQ1_S: - mul_mat_vec_iq1_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - break; - case GGML_TYPE_IQ1_M: - mul_mat_vec_iq1_m_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - break; - case GGML_TYPE_IQ2_XXS: - mul_mat_vec_iq2_xxs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - break; - case GGML_TYPE_IQ2_XS: - mul_mat_vec_iq2_xs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - break; - case GGML_TYPE_IQ2_S: - mul_mat_vec_iq2_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - break; - case GGML_TYPE_IQ3_XXS: - mul_mat_vec_iq3_xxs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - break; - case GGML_TYPE_IQ3_S: - mul_mat_vec_iq3_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - break; - case GGML_TYPE_IQ4_NL: - mul_mat_vec_iq4_nl_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - break; - case GGML_TYPE_IQ4_XS: - mul_mat_vec_iq4_xs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - break; - default: - GGML_ABORT("fatal error"); + case GGML_TYPE_Q4_0: + if ((ggml_tensor_extra_gpu *) dst->src[0]->extra && + ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { + GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q4_0_q8_1_sycl\n"); + reorder_mul_mat_vec_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } else { + GGML_SYCL_DEBUG("Calling mul_mat_vec_q4_0_q8_1_sycl\n"); + mul_mat_vec_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } + break; + case GGML_TYPE_Q4_1: + mul_mat_vec_q4_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + case GGML_TYPE_Q5_0: + mul_mat_vec_q5_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + case GGML_TYPE_Q5_1: + mul_mat_vec_q5_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + case GGML_TYPE_Q8_0: + mul_mat_vec_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + case GGML_TYPE_Q2_K: + mul_mat_vec_q2_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + case GGML_TYPE_Q3_K: + mul_mat_vec_q3_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + case GGML_TYPE_Q4_K: + mul_mat_vec_q4_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + case GGML_TYPE_Q5_K: + mul_mat_vec_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + case GGML_TYPE_Q6_K: + mul_mat_vec_q6_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + case GGML_TYPE_IQ1_S: + mul_mat_vec_iq1_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + case GGML_TYPE_IQ1_M: + mul_mat_vec_iq1_m_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + case GGML_TYPE_IQ2_XXS: + mul_mat_vec_iq2_xxs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + case GGML_TYPE_IQ2_XS: + mul_mat_vec_iq2_xs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + case GGML_TYPE_IQ2_S: + mul_mat_vec_iq2_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + case GGML_TYPE_IQ3_XXS: + mul_mat_vec_iq3_xxs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + case GGML_TYPE_IQ3_S: + mul_mat_vec_iq3_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + case GGML_TYPE_IQ4_NL: + mul_mat_vec_iq4_nl_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + case GGML_TYPE_IQ4_XS: + mul_mat_vec_iq4_xs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + default: + GGML_ABORT("fatal error"); } } GGML_UNUSED(src1); diff --git a/ggml/src/ggml-sycl/quants.hpp b/ggml/src/ggml-sycl/quants.hpp new file mode 100644 index 00000000000..a74e30526c1 --- /dev/null +++ b/ggml/src/ggml-sycl/quants.hpp @@ -0,0 +1,61 @@ +// +// MIT license +// Copyright (C) 2025 Codeplay Software Ltd. +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: MIT +// + +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// + +#ifndef GGML_SYCL_QUANTS_HPP +#define GGML_SYCL_QUANTS_HPP + +#include "ggml-common.h" +#include "ggml.h" + +namespace ggml_sycl_reordered { + + +// The reordered block moves quants (qs) and scales(d) to two +// uniform regions of memory that is contiguous in the same tensor. +// What this means is that instead of having: +// [d0, qs0] [d1, qs1] [d2, qs2] ... [dN, qsN] +// We have: +// [qs0, qs1, qs2, ..., qsN] [d0, d1, d2, ..., dN] +// +// Notes: out-of-bounds qs will run into d values +// Aligment relies on the allocated size of qs + +template struct block_q_t; + + +// qk number of weights / quants in a block +// qr number of weights in a byte (described as 'before dequantization') +// for quantization types that has low and high bits split, qr is calculated with +// using the lower bits, e.g for Q6 quants QR6 is 2 +// qi number of 32 bit integers needed to represent all the quants from a block (`qs` field) +// See ggml-common.h to see how these are calculated +template <> struct block_q_t { + struct traits { + static constexpr uint32_t qk = QK4_0; + static constexpr uint32_t qi = QI4_0; + static constexpr uint32_t qr = QR4_0; + static constexpr uint32_t vdr_mmvq = 2; + }; + + static constexpr int get_block_offset(const int block_index) { return block_index * (traits::qk / traits::qr); } + + static constexpr int get_d_offset(int nrows, int ncols, const int block_index) { + return (ncols / traits::qr * nrows) + block_index * sizeof(ggml_half); + } + + static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; } +}; + +} // namespace ggml_sycl_reordered + +#endif // GGML_SYCL_QUANTS_HPP diff --git a/ggml/src/ggml-sycl/vecdotq.hpp b/ggml/src/ggml-sycl/vecdotq.hpp index c5942008adf..cbf664fcf28 100644 --- a/ggml/src/ggml-sycl/vecdotq.hpp +++ b/ggml/src/ggml-sycl/vecdotq.hpp @@ -1,6 +1,6 @@ // // MIT license -// Copyright (C) 2024 Intel Corporation +// Copyright (C) 2025 Intel Corporation // SPDX-License-Identifier: MIT // @@ -14,8 +14,11 @@ #define GGML_SYCL_VECDOTQ_HPP #include "dpct/helper.hpp" +#include "ggml.h" +#include "quants.hpp" -typedef float (*vec_dot_q_sycl_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs); +typedef float (*vec_dot_q_sycl_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, + const int & iqs); static __dpct_inline__ int get_int_from_int8(const int8_t* x8, const int& i32) { const uint16_t* x16 = @@ -252,13 +255,60 @@ vec_dot_q6_K_q8_1_impl_mmvq(const int &vl, const int &vh, // VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called // MMVQ = mul_mat_vec_q, MMQ = mul_mat_q +template struct reorder_vec_dot_q_sycl { + static_assert(T != T, "ggml_type for reorder vecdot not implemented"); +}; + +template <> struct reorder_vec_dot_q_sycl { + static constexpr ggml_type gtype = GGML_TYPE_Q4_0; + + using q4_0_block = ggml_sycl_reordered::block_q_t; + using q4_0_traits = typename q4_0_block::traits; + + __dpct_inline__ float vec_dot_q4_0_q8_1_impl(const int * v, const int * u, const float & d4, const sycl::half2 & ds8) { + int sumi = 0; + +#pragma unroll + for (size_t i = 0; i < q4_0_traits::vdr_mmvq; ++i) { + const int vi0 = (v[i] >> 0) & 0x0F0F0F0F; + const int vi1 = (v[i] >> 4) & 0x0F0F0F0F; + + // SIMD dot product of quantized values + sumi = dpct::dp4a(vi0, u[2 * i + 0], sumi); + sumi = dpct::dp4a(vi1, u[2 * i + 1], sumi); + } + + const sycl::float2 ds8f = ds8.convert(); + + // second part effectively subtracts 8 from each quant value + return d4 * (sumi * ds8f.x() - (8 * q4_0_traits::vdr_mmvq / q4_0_traits::qi) * ds8f.y()); + } + + __dpct_inline__ float operator()(const void * __restrict__ vbq, const int ibx_offset, const int d_offset, + const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + const uint8_t * bq4_0 = static_cast(vbq) + ibx_offset; + const ggml_half d = *(reinterpret_cast(static_cast(vbq) + d_offset)); + int v[q4_0_traits::vdr_mmvq]; + int u[2 * q4_0_traits::vdr_mmvq]; + +#pragma unroll + + for (size_t i = 0; i < q4_0_traits::vdr_mmvq; ++i) { + v[i] = get_int_from_uint8(bq4_0, iqs + i); + u[2 * i + 0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i); + u[2 * i + 1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + q4_0_traits::qi); + } + + return vec_dot_q4_0_q8_1_impl(v, u, d, bq8_1->ds); + }; +}; + #define VDR_Q4_0_Q8_1_MMVQ 2 #define VDR_Q4_0_Q8_1_MMQ 4 template -static __dpct_inline__ float vec_dot_q4_0_q8_1_impl(const int *v, const int *u, - const float &d4, - const sycl::half2 &ds8) { +static __dpct_inline__ float vec_dot_q4_0_q8_1_impl(const int * v, const int * u, const float & d4, + const sycl::half2 & ds8) { int sumi = 0; #pragma unroll for (int i = 0; i < vdr; ++i) { @@ -270,8 +320,7 @@ static __dpct_inline__ float vec_dot_q4_0_q8_1_impl(const int *v, const int *u, sumi = dpct::dp4a(vi1, u[2 * i + 1], sumi); } - const sycl::float2 ds8f = - ds8.convert(); + const sycl::float2 ds8f = ds8.convert(); // second part effectively subtracts 8 from each quant value return d4 * (sumi * ds8f.x() - (8 * vdr / QI4_0) * ds8f.y()); @@ -456,13 +505,13 @@ vec_dot_q4_0_q8_1(const void *__restrict__ vbq, const block_q4_0 * bq4_0 = (const block_q4_0 *) vbq; int v[VDR_Q4_0_Q8_1_MMVQ]; - int u[2*VDR_Q4_0_Q8_1_MMVQ]; + int u[2 * VDR_Q4_0_Q8_1_MMVQ]; #pragma unroll for (int i = 0; i < VDR_Q4_0_Q8_1_MMVQ; ++i) { - v[i] = get_int_from_uint8(bq4_0->qs, iqs + i); - u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i); - u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI4_0); + v[i] = get_int_from_uint8(bq4_0->qs, iqs + i); + u[2 * i + 0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i); + u[2 * i + 1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI4_0); } return vec_dot_q4_0_q8_1_impl(v, u, bq4_0->d, bq8_1->ds); From 22f4997dd8170a9e93d81e9dd222132a939d91c9 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Fri, 9 May 2025 23:07:07 -0700 Subject: [PATCH 07/21] vulkan: scalar flash attention implementation (llama/13324) * vulkan: scalar flash attention implementation * vulkan: always use fp32 for scalar flash attention * vulkan: use vector loads in scalar flash attention shader * vulkan: remove PV matrix, helps with register usage * vulkan: reduce register usage in scalar FA, but perf may be slightly worse * vulkan: load each Q value once. optimize O reduction. more tuning * vulkan: support q4_0/q8_0 KV in scalar FA * CI: increase timeout to accommodate newly-supported tests * vulkan: for scalar FA, select between 1 and 8 rows * vulkan: avoid using Float16 capability in scalar FA --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 243 +++++---- .../vulkan-shaders/flash_attn.comp | 483 ++++++++++++++++++ .../vulkan-shaders/vulkan-shaders-gen.cpp | 12 +- 3 files changed, 645 insertions(+), 93 deletions(-) create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 2dc2883a70c..e2b357fdc15 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -275,6 +275,7 @@ struct vk_device_struct { bool prefer_host_memory; bool float_controls_rte_fp16; bool subgroup_add; + bool subgroup_shuffle; bool integer_dot_product; @@ -402,12 +403,20 @@ struct vk_device_struct { vk_pipeline pipeline_conv2d_dw_cwhn_f32; // [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned} + vk_pipeline pipeline_flash_attn_f32_f16_D64_cm2[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_D80_cm2[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_D96_cm2[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_D112_cm2[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_D128_cm2[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_D256_cm2[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2]; vk_pipeline pipeline_flash_attn_f32_f16_D80[GGML_TYPE_COUNT][2][2][2]; vk_pipeline pipeline_flash_attn_f32_f16_D96[GGML_TYPE_COUNT][2][2][2]; vk_pipeline pipeline_flash_attn_f32_f16_D112[GGML_TYPE_COUNT][2][2][2]; vk_pipeline pipeline_flash_attn_f32_f16_D128[GGML_TYPE_COUNT][2][2][2]; vk_pipeline pipeline_flash_attn_f32_f16_D256[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_split_k_reduce; std::unordered_map pipelines; @@ -1581,13 +1590,29 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector&& events // number of rows/cols for flash attention shader static constexpr uint32_t flash_attention_num_small_rows = 32; -static std::array fa_rows_cols(uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) { +static constexpr uint32_t scalar_flash_attention_num_small_rows = 1; +static constexpr uint32_t scalar_flash_attention_num_large_rows = 8; + +static uint32_t get_fa_num_small_rows(bool scalar) { + return scalar ? scalar_flash_attention_num_small_rows : flash_attention_num_small_rows; +} + +static std::array fa_rows_cols(bool scalar, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) { GGML_UNUSED(clamp); + if (scalar) { + if (small_rows) { + return {scalar_flash_attention_num_small_rows, 64}; + } else { + return {scalar_flash_attention_num_large_rows, 32}; + } + } + // small rows, large cols if (small_rows) { - return {flash_attention_num_small_rows, 64}; + return {get_fa_num_small_rows(scalar), 32}; } + // small cols to reduce register count if (ggml_is_quantized(type) || D == 256) { return {64, 32}; @@ -1882,65 +1907,66 @@ static void ggml_vk_load_shaders(vk_device& device) { parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size)); }; -#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) - if (device->coopmat2) { - - auto const &fa_wg_denoms = [&](uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::array { - return {fa_rows_cols(D, clamp, type, small_rows)[0], 1, 1}; - }; + auto const &fa_wg_denoms = [&](bool scalar, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::array { + return {fa_rows_cols(scalar, D, clamp, type, small_rows)[0], 1, 1}; + }; - auto const &fa_spec_constants = [&](uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector { - // For large number of rows, 128 invocations seems to work best. - // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we - // can't use 256 for D==80. - uint32_t wg_size = (small_rows && (D % 32) == 0) ? 256 : 128; - auto rows_cols = fa_rows_cols(D, clamp, type, small_rows); - // mask dim1 is padded to 64, we rely on this to avoid clamping mask loads - GGML_ASSERT((GGML_KQ_MASK_PAD % rows_cols[0]) == 0); - return {wg_size, rows_cols[0], rows_cols[1], (D), clamp}; - }; + auto const &fa_spec_constants = [&](bool scalar, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector { + // For large number of rows, 128 invocations seems to work best. + // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we + // can't use 256 for D==80. + // For scalar, use 128 (arbitrary) + uint32_t wg_size = scalar ? 128 : ((small_rows && (D % 32) == 0) ? 256 : 128); + auto rows_cols = fa_rows_cols(scalar, D, clamp, type, small_rows); + + // D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it. + // D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader. + const uint32_t D_lsb = D ^ (D & (D-1)); + uint32_t D_split = std::min(std::min(device->subgroup_size, 8u), D_lsb / 4); + + // mask dim1 is padded to 64, we rely on this to avoid clamping mask loads + GGML_ASSERT((GGML_KQ_MASK_PAD % rows_cols[0]) == 0); + return {wg_size, rows_cols[0], rows_cols[1], (D), clamp, D_split}; + }; -#define CREATE_FA2(TYPE, NAMELC, D) \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][0][0], "flash_attn_f32_f16_D" #D "_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,false), fa_spec_constants(D,1,TYPE,false), 1); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][0][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,false), fa_spec_constants(D,0,TYPE,false), fa_rows_cols(D,0,TYPE,false)[1]); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][0][0], "flash_attn_f32_f16_D" #D "_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,false), fa_spec_constants(D,1,TYPE,false), 1); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][0][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,false), fa_spec_constants(D,0,TYPE,false), fa_rows_cols(D,0,TYPE,false)[1]); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][1][0], "flash_attn_f32_f16_D" #D "_f16acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,true), fa_spec_constants(D,1,TYPE,true), 1); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][1][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,true), fa_spec_constants(D,0,TYPE,true), fa_rows_cols(D,0,TYPE,true)[1]); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][1][0], "flash_attn_f32_f16_D" #D "_f32acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,true), fa_spec_constants(D,1,TYPE,true), 1); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][1][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,true), fa_spec_constants(D,0,TYPE,true), fa_rows_cols(D,0,TYPE,true)[1]); \ - -#define CREATE_FA(TYPE, NAMELC) \ - CREATE_FA2(TYPE, NAMELC, 64) \ - CREATE_FA2(TYPE, NAMELC, 80) \ - CREATE_FA2(TYPE, NAMELC, 96) \ - CREATE_FA2(TYPE, NAMELC, 112) \ - CREATE_FA2(TYPE, NAMELC, 128) \ - CREATE_FA2(TYPE, NAMELC, 256) - - CREATE_FA(GGML_TYPE_F16, f16) - CREATE_FA(GGML_TYPE_Q4_0, q4_0) - CREATE_FA(GGML_TYPE_Q4_1, q4_1) - CREATE_FA(GGML_TYPE_Q5_0, q5_0) - CREATE_FA(GGML_TYPE_Q5_1, q5_1) - CREATE_FA(GGML_TYPE_Q8_0, q8_0) - // K dequants currently disabled because D dimension is rounded up to 256 and runs inefficiently - //CREATE_FA(GGML_TYPE_Q2_K, q2_k) - //CREATE_FA(GGML_TYPE_Q3_K, q3_k) - //CREATE_FA(GGML_TYPE_Q4_K, q4_k) - //CREATE_FA(GGML_TYPE_Q5_K, q5_k) - //CREATE_FA(GGML_TYPE_Q6_K, q6_k) - //CREATE_FA(GGML_TYPE_IQ1_S, iq1_s) - //CREATE_FA(GGML_TYPE_IQ1_M, iq1_m) - //CREATE_FA(GGML_TYPE_IQ2_XXS, iq2_xxs) - //CREATE_FA(GGML_TYPE_IQ2_XS, iq2_xs) - //CREATE_FA(GGML_TYPE_IQ2_S, iq2_s) - //CREATE_FA(GGML_TYPE_IQ3_XXS, iq3_xxs) - //CREATE_FA(GGML_TYPE_IQ3_S, iq3_s) - //CREATE_FA(GGML_TYPE_IQ4_XS, iq4_xs) - CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl) +#define CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, D) \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][0], "flash_attn_f32_f16_D" #D "_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,1,TYPE,false), fa_spec_constants(SCALAR, D,1,TYPE,false), 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,0,TYPE,false), fa_spec_constants(SCALAR, D,0,TYPE,false), fa_rows_cols(SCALAR,D,0,TYPE,false)[1], true); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][0], "flash_attn_f32_f16_D" #D "_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,1,TYPE,false), fa_spec_constants(SCALAR, D,1,TYPE,false), 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,0,TYPE,false), fa_spec_constants(SCALAR, D,0,TYPE,false), fa_rows_cols(SCALAR,D,0,TYPE,false)[1], true); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][0], "flash_attn_f32_f16_D" #D "_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,1,TYPE,true), fa_spec_constants(SCALAR, D,1,TYPE,true), 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,0,TYPE,true), fa_spec_constants(SCALAR, D,0,TYPE,true), fa_rows_cols(SCALAR,D,0,TYPE,true)[1], true); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][0], "flash_attn_f32_f16_D" #D "_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,1,TYPE,true), fa_spec_constants(SCALAR, D,1,TYPE,true), 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,0,TYPE,true), fa_spec_constants(SCALAR, D,0,TYPE,true), fa_rows_cols(SCALAR,D,0,TYPE,true)[1], true); \ + +#define CREATE_FA(TYPE, NAMELC, SCALAR, SUFFIX) \ + CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 64) \ + CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 80) \ + CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 96) \ + CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 112) \ + CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 128) \ + CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 256) + + CREATE_FA(GGML_TYPE_F16, f16, true, ) + CREATE_FA(GGML_TYPE_Q4_0, q4_0, true, ) + CREATE_FA(GGML_TYPE_Q8_0, q8_0, true, ) +#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + if (device->coopmat2) { + CREATE_FA(GGML_TYPE_F16, f16, false, _cm2) + CREATE_FA(GGML_TYPE_Q4_0, q4_0, false, _cm2) + CREATE_FA(GGML_TYPE_Q4_1, q4_1, false, _cm2) + CREATE_FA(GGML_TYPE_Q5_0, q5_0, false, _cm2) + CREATE_FA(GGML_TYPE_Q5_1, q5_1, false, _cm2) + CREATE_FA(GGML_TYPE_Q8_0, q8_0, false, _cm2) + CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, false, _cm2) + } +#endif +#undef CREATE_FA2 #undef CREATE_FA +#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + if (device->coopmat2) { + // Create 6 variants, {s,m,l}x{unaligned,aligned} #define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \ @@ -2837,6 +2863,9 @@ static vk_device ggml_vk_get_device(size_t idx) { device->subgroup_add = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) && (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eArithmetic); + device->subgroup_shuffle = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) && + (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eShuffle); + const bool force_disable_f16 = getenv("GGML_VK_DISABLE_F16") != nullptr; device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute; @@ -5709,20 +5738,57 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx assert(q->type == GGML_TYPE_F32); assert(k->type == v->type); + bool scalar = !ctx->device->coopmat2; + + uint32_t gqa_ratio = 1; + uint32_t qk_ratio = neq2 / nek2; + uint32_t workgroups_x = (uint32_t)neq1; + uint32_t workgroups_y = (uint32_t)neq2; + uint32_t workgroups_z = (uint32_t)neq3; + + // For scalar FA, we can use the "large" size to accommodate qga. + // For coopmat FA, we always use the small size (which is still pretty large for gqa). + const uint32_t max_gqa = scalar ? scalar_flash_attention_num_large_rows : get_fa_num_small_rows(false); + + if (N == 1 && qk_ratio > 1 && qk_ratio <= max_gqa && + qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1) { + // grouped query attention - make the N dimension equal to gqa_ratio, reduce + // workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1 + // and change addressing calculations to index Q's dimension 2. + gqa_ratio = qk_ratio; + N = gqa_ratio; + workgroups_y /= N; + } + vk_pipeline *pipelines; // XXX TODO other backends may be changing accumulator precision to default to f32 soon - bool f32acc = dst->op_params[3] == GGML_PREC_F32; - bool small_rows = N <= flash_attention_num_small_rows; - switch (D) { - case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64[k->type][f32acc][small_rows][0]; break; - case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80[k->type][f32acc][small_rows][0]; break; - case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96[k->type][f32acc][small_rows][0]; break; - case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112[k->type][f32acc][small_rows][0]; break; - case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128[k->type][f32acc][small_rows][0]; break; - case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256[k->type][f32acc][small_rows][0]; break; - default: - assert(!"unsupported D value"); - return; + bool f32acc = scalar || dst->op_params[3] == GGML_PREC_F32; + bool small_rows = N <= get_fa_num_small_rows(scalar); + + if (scalar) { + switch (D) { + case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64[k->type][f32acc][small_rows][0]; break; + case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80[k->type][f32acc][small_rows][0]; break; + case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96[k->type][f32acc][small_rows][0]; break; + case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112[k->type][f32acc][small_rows][0]; break; + case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128[k->type][f32acc][small_rows][0]; break; + case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256[k->type][f32acc][small_rows][0]; break; + default: + GGML_ASSERT(!"unsupported D value"); + return; + } + } else { + switch (D) { + case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64_cm2[k->type][f32acc][small_rows][0]; break; + case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80_cm2[k->type][f32acc][small_rows][0]; break; + case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96_cm2[k->type][f32acc][small_rows][0]; break; + case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112_cm2[k->type][f32acc][small_rows][0]; break; + case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128_cm2[k->type][f32acc][small_rows][0]; break; + case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256_cm2[k->type][f32acc][small_rows][0]; break; + default: + GGML_ASSERT(!"unsupported D value"); + return; + } } assert(pipelines); @@ -5740,27 +5806,14 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx vk_pipeline pipeline = pipelines[aligned]; assert(pipeline); - uint32_t gqa_ratio = 1; - uint32_t qk_ratio = neq2 / nek2; - uint32_t workgroups_x = (uint32_t)neq1; - uint32_t workgroups_y = (uint32_t)neq2; - uint32_t workgroups_z = (uint32_t)neq3; - - if (N == 1 && qk_ratio > 1 && gqa_ratio <= flash_attention_num_small_rows && - qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1) { - // grouped query attention - make the N dimension equal to gqa_ratio, reduce - // workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1 - // and change addressing calculations to index Q's dimension 2. - gqa_ratio = qk_ratio; - N = gqa_ratio; - workgroups_y /= N; - } - uint32_t split_kv = KV; uint32_t split_k = 1; + // Use a placeholder core count if one isn't available. split_k is a big help for perf. + const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count : 16; + // Try to use split_k when KV is large enough to be worth the overhead - if (workgroups_x == 1 && ctx->device->shader_core_count > 0 && KV >= 512) { + if (workgroups_x == 1 && shader_core_count > 0 && KV >= 512) { // Try to run two workgroups per SM. split_k = ctx->device->shader_core_count * 2 / workgroups_y; if (split_k > 1) { @@ -9530,9 +9583,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_FLASH_ATTN_EXT: { ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; - if (!ggml_vk_get_device(ctx->device)->coopmat2) { - return false; - } + auto device = ggml_vk_get_device(ctx->device); + bool coopmat2 = device->coopmat2; switch (op->src[0]->ne[0]) { case 64: case 80: @@ -9540,7 +9592,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case 112: case 128: case 256: - case 575: // DeepSeek MLA break; default: return false; @@ -9566,10 +9617,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm switch (op->src[1]->type) { case GGML_TYPE_F16: case GGML_TYPE_Q4_0: + case GGML_TYPE_Q8_0: + // supported in scalar and coopmat2 paths + break; case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: - case GGML_TYPE_Q8_0: // K dequants currently disabled because D dimension is rounded up to 256 and runs inefficiently //case GGML_TYPE_Q2_K: //case GGML_TYPE_Q3_K: @@ -9585,10 +9638,18 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm //case GGML_TYPE_IQ3_S: //case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: + // currently supported only in coopmat2 path + if (!coopmat2) { + return false; + } break; default: return false; } + if (!coopmat2 && !device->subgroup_shuffle) { + // scalar FA uses subgroupShuffle + return false; + } return true; } case GGML_OP_GET_ROWS: diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp new file mode 100644 index 00000000000..e6545160d53 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -0,0 +1,483 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable +#extension GL_EXT_shader_16bit_storage : require + +#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#extension GL_KHR_shader_subgroup_shuffle : enable + +#include "types.comp" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (constant_id = 1) const uint32_t Br = 1; +layout (constant_id = 2) const uint32_t Bc = 32; +layout (constant_id = 3) const uint32_t D = 32; + +layout (constant_id = 5) const uint32_t D_split = 16; +const uint32_t D_per_thread = D / D_split; + +const uint32_t cols_per_iter = gl_WorkGroupSize.x / D_split; +const uint32_t cols_per_thread = Bc / cols_per_iter; + +layout (push_constant) uniform parameter { + uint32_t N; + uint32_t KV; + + uint32_t ne1; + uint32_t ne2; + uint32_t ne3; + + uint32_t neq2; + uint32_t neq3; + uint32_t nek2; + uint32_t nek3; + uint32_t nev2; + uint32_t nev3; + uint32_t nem1; + + uint32_t nb01; + uint32_t nb02; + uint32_t nb03; + uint32_t nb11; + uint32_t nb12; + uint32_t nb13; + uint32_t nb21; + uint32_t nb22; + uint32_t nb23; + uint32_t nb31; + + float scale; + float max_bias; + float logit_softcap; + + uint32_t mask; + uint32_t n_head_log2; + float m0; + float m1; + + uint32_t gqa_ratio; + uint32_t split_kv; + uint32_t k_num; +} p; + +layout (binding = 0) readonly buffer Q {float data_q[];}; +layout (binding = 0) readonly buffer QV4 {vec4 data_qv4[];}; +layout (binding = 1) readonly buffer K {float16_t data_k[];}; +layout (binding = 1) readonly buffer KV4 {f16vec4 data_kv4[];}; +layout (binding = 2) readonly buffer V {float16_t data_v[];}; +layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];}; +layout (binding = 3) readonly buffer M {float16_t data_m[];}; +layout (binding = 4) writeonly buffer O {D_TYPE data_o[];}; + +#if defined(A_TYPE_PACKED16) +#define BINDING_IDX_K 0 +#define BINDING_IDX_V 1 +layout (binding = 1) readonly buffer KV_PACKED16 {A_TYPE_PACKED16 data_packed16[];} kv_packed[2]; +#endif + +#if defined(DATA_A_Q4_0) +#define BLOCK_BYTE_SIZE 18 + +vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { + uint vui_lo = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); + uint vui_hi = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); + uint shift = (iqs & 0x10) >> 2; + vui_lo >>= shift; + vui_hi >>= shift; + + return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f); +} +#endif + +#if defined(DATA_A_Q8_0) +#define BLOCK_BYTE_SIZE 34 +vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { + const i8vec2 v0 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147 + const i8vec2 v1 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy; + + return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y); +} +#endif + +#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b)) + +// Store the output when doing grouped query attention. +// Rows index by Q's dimension 2, and the first N rows are valid. +D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) +{ + uint32_t offset = (iq2 + r) * D + c; + data_o[o_offset + offset] = D_TYPE(elem); + return elem; +} + +// Store column zero. This is used to save per-row m and L values for split_k. +ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) +{ + if (r < N && c == 0) { + uint32_t offset = iq2 + r; + data_o[o_offset + offset] = D_TYPE(elem); + } + return elem; +} + +// Load the slope matrix, indexed by Q's dimension 2. +ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2) +{ + const uint32_t h = iq2 + (r % p.gqa_ratio); + + const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1); + const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1); + + return ACC_TYPE(pow(base, ACC_TYPE(exph))); +} + +shared FLOAT_TYPE tmpsh[gl_WorkGroupSize.x]; +shared vec4 tmpshv4[gl_WorkGroupSize.x]; + +shared float masksh[Bc][Br]; +shared vec4 Qf[Br][D / 4]; + +void main() { +#ifdef NEEDS_INIT_IQ_SHMEM + init_iq_shmem(gl_WorkGroupSize); +#endif + + const uint32_t tid = gl_LocalInvocationIndex; + const uint32_t N = p.N; + const uint32_t KV = p.KV; + + const uint32_t d_tid = gl_LocalInvocationIndex % D_split; + const uint32_t col_tid = gl_LocalInvocationIndex / D_split; + + uint32_t i = gl_WorkGroupID.x; + uint32_t split_k_index = 0; + + if (p.k_num > 1) { + i = 0; + split_k_index = gl_WorkGroupID.x; + } + + const uint32_t Tr = CEIL_DIV(N, Br); + + const uint32_t start_j = split_k_index * p.split_kv / Bc; + const uint32_t end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc); + + // When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y. + // When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2. + const uint32_t iq2 = gl_WorkGroupID.y * p.gqa_ratio; + const uint32_t iq3 = gl_WorkGroupID.z; + + // broadcast factors + const uint32_t rk2 = p.neq2/p.nek2; + const uint32_t rk3 = p.neq3/p.nek3; + + const uint32_t rv2 = p.neq2/p.nev2; + const uint32_t rv3 = p.neq3/p.nev3; + + // k indices + const uint32_t ik3 = iq3 / rk3; + const uint32_t ik2 = iq2 / rk2; + + // v indices + const uint32_t iv3 = iq3 / rv3; + const uint32_t iv2 = iq2 / rv2; + + // nb?1 are already divided by the type size and are in units of elements. + // When using grouped query attention, Q is indexed by iq2, so the stride + // should be nb02 (which is in bytes). + uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01; + uint32_t k_stride = p.nb11; + uint32_t v_stride = p.nb21; + // When using grouped query attention, all rows use the same mask (stride 0). + // "p.gqa_ratio >> 16" is just a roundabout way of writing zero + // that prevents the compiler from folding the "&" through the select + // and breaking the alignment detection. + uint32_t m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV; + + uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4; + + [[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) { + uint32_t d = (idx + tid) % (D / 4); + uint32_t r = (idx + tid) / (D / 4); + if (r < Br && d < D / 4 && + i * Br + r < N) { + Qf[r][d] = vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d]) * p.scale; + } + } + barrier(); + + vec4 Of[Br][D_per_thread / 4]; + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + Of[r][d] = vec4(0.0); + } + } + + float Lf[Br], Mf[Br]; + + // Use -FLT_MAX/2 rather than -inf to reduce the possibility of NaNs, e.g. when computing Mold-M. + const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF); + + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + Lf[r] = 0; + Mf[r] = NEG_FLT_MAX_OVER_2; + } + + float slope[Br]; + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + slope[r] = 1.0; + } + + // ALiBi + if (p.max_bias > 0.0f) { + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + slope[r] = perElemOpComputeSlope(r, col_tid, ACC_TYPE(0), iq2); + } + } + +#if BLOCK_SIZE > 1 + uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / BLOCK_BYTE_SIZE; + uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / BLOCK_BYTE_SIZE; +#else + uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2; + uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2; +#endif + + [[dont_unroll]] + for (uint32_t j = start_j; j < end_j; ++j) { + + float Sf[Br][cols_per_thread]; + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + Sf[r][c] = 0.0; + } + } + + + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { +#if BLOCK_SIZE > 1 + uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); + uint ib = coord / BLOCK_SIZE; + uint iqs = (coord % BLOCK_SIZE); + vec4 K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K); +#else + vec4 K_Tf = vec4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]); +#endif + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + Sf[r][c] += dot(Qf[r][d * D_split + d_tid], K_Tf); + } + } + } + + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + // Compute sum across the D_split + [[unroll]] for (uint s = D_split / 2; s > 0; s >>= 1) { + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + Sf[r][c] += subgroupShuffleXor(Sf[r][c], s); + } + } + } + + if (p.logit_softcap != 0.0f) { + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + Sf[r][c] = p.logit_softcap * tanh(Sf[r][c]); + } + } + } + + if (p.mask != 0) { + + [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) { + uint32_t c = (idx + tid) % Bc; + uint32_t r = (idx + tid) / Bc; + if (idx + tid < Bc * Br) { + masksh[c][r] = float(data_m[(i * Br + r) * m_stride + (j * Bc + c)]); + } + } + barrier(); + + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + float mvf = masksh[c * cols_per_iter + col_tid][r]; + + Sf[r][c] += slope[r]*mvf; + } + } + barrier(); + } + + float rowmaxf[Br], Pf[Br][cols_per_thread], rowsumf[Br], eMf[Br], Moldf[Br]; + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + rowmaxf[r] = Sf[r][0]; + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + rowmaxf[r] = max(rowmaxf[r], Sf[r][c]); + } + Moldf[r] = Mf[r]; + + // M = max(rowmax, Mold) + // P = e^(S - M) + // eM = e^(Mold - M) + Mf[r] = max(rowmaxf[r], Moldf[r]); + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + Pf[r][c] = exp(Sf[r][c] - Mf[r]); + } + eMf[r] = exp(Moldf[r] - Mf[r]); + + // Compute sum across row of P + rowsumf[r] = 0.0; + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + rowsumf[r] += Pf[r][c]; + } + + Lf[r] = eMf[r]*Lf[r] + rowsumf[r]; + } + + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + Of[r][d] = eMf[r] * Of[r][d]; + } + } + + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { +#if BLOCK_SIZE > 1 + uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); + uint ib = coord / BLOCK_SIZE; + uint iqs = (coord % BLOCK_SIZE); + vec4 Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V); +#else + vec4 Vf = vec4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]); +#endif + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + Of[r][d] += Pf[r][c] * Vf; + } + } + } + + barrier(); + } + + // reduce across threads + + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + float rowmaxf, eMf; + + tmpsh[tid] = Mf[r]; + // Compute max across the row + barrier(); + [[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) { + if (tid < s) { + tmpsh[tid] = max(tmpsh[tid], tmpsh[tid + s]); + } + barrier(); + } + rowmaxf = tmpsh[d_tid]; + barrier(); + + float Moldf = Mf[r]; + + // M = max(rowmax, Mold) + // eM = e^(Mold - M) + Mf[r] = max(rowmaxf, Moldf); + eMf = exp(Moldf - Mf[r]); + + Lf[r] = eMf*Lf[r]; + + tmpsh[tid] = Lf[r]; + + // Compute sum across the row + barrier(); + [[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) { + if (tid < s) { + tmpsh[tid] = tmpsh[tid] + tmpsh[tid + s]; + } + barrier(); + } + Lf[r] = tmpsh[d_tid]; + barrier(); + + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + + Of[r][d] = eMf * Of[r][d]; + tmpshv4[tid] = Of[r][d]; + + barrier(); + [[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) { + if (tid < s) { + Of[r][d] += tmpshv4[tid + s]; + tmpshv4[tid] = Of[r][d]; + } + barrier(); + } + Of[r][d] = tmpshv4[d_tid]; + barrier(); + } + } + + + // If there is split_k, then the split_k resolve shader does the final + // division by L. Store the intermediate O value and per-row m and L values. + if (p.k_num > 1) { + uint32_t o_offset = D * p.ne1 * split_k_index; + + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + if (r < N) { + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { + perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N); + } + } + } + } + + o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2; + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + if (r < N) { + perElemOpStoreCol0(r, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N); + perElemOpStoreCol0(r, 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N); + } + } + + return; + } + + float Lfrcp[Br]; + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + Lfrcp[r] = 1.0 / Lf[r]; + } + + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + Of[r][d] *= Lfrcp[r]; + } + } + + uint32_t o_offset = iq3*p.ne2*p.ne1; + + if (p.gqa_ratio > 1) { + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + if (r < N) { + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { + perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N); + } + } + } + } + } else { + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + if (i * Br + r < N) { + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { + data_o[o_offset + iq2 * D + (i * Br + r) * p.ne1 * D + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]); + } + } + } + } + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 382f190f287..d196137eb29 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -421,7 +421,6 @@ void process_shaders() { #endif } -#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) // flash attention for (const auto& f16acc : {false, true}) { std::string acctype = f16acc ? "float16_t" : "float"; @@ -432,6 +431,7 @@ void process_shaders() { } if (tname == "bf16") continue; +#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) if (tname == "f16") { string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp", merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, false, true, f16acc); @@ -440,9 +440,17 @@ void process_shaders() { string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc); } +#endif + if (tname == "f16") { + string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", + merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, false, false, f16acc); + } else if (tname == "q4_0" || tname == "q8_0") { + std::string data_a_key = "DATA_A_" + to_uppercase(tname); + string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", + merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, false, f16acc); + } } } -#endif for (const auto& tname : type_names) { // mul mat vec From 04445664b4c5f5db7d273e4cff9465f4829d97a0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Sat, 10 May 2025 09:16:52 +0200 Subject: [PATCH 08/21] CUDA: fix FlashAttention on Turing (llama/13415) --- ggml/src/ggml-cuda/fattn-mma-f16.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index 2b6bdc30c02..b2f95fa3f00 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -546,7 +546,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( const int i0_stop = i0_start + 2*c::nbatch_V2 < DV ? i0_start + 2*c::nbatch_V2 : DV; const int i0_diff = i0_stop - i0_start; - if (nstages == 1) { + if (nstages <= 1) { constexpr bool use_cp_async = nstages == 1; flash_attn_ext_f16_load_tile (V_h2 + k_VKQ_0*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V); From 86dece9c7cb5f46976334bb1312835535d7ae981 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Sat, 10 May 2025 22:22:48 +0200 Subject: [PATCH 09/21] CUDA: fix race conditions FlashAttention kernels (llama/13438) --- ggml/src/ggml-cuda/fattn-mma-f16.cuh | 2 ++ ggml/src/ggml-cuda/fattn-vec-f16.cuh | 1 + 2 files changed, 3 insertions(+) diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index b2f95fa3f00..9873ea755a5 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -874,6 +874,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( } } + __syncthreads(); + // Write back combined meta data: #pragma unroll for (int imeta = 0; imeta < nmeta; ++imeta) { diff --git a/ggml/src/ggml-cuda/fattn-vec-f16.cuh b/ggml/src/ggml-cuda/fattn-vec-f16.cuh index ef0addc1dbc..d96e3921298 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f16.cuh @@ -168,6 +168,7 @@ static __global__ void flash_attn_vec_ext_f16( for (int j = 0; j < ncols; ++j) { KQ[j*D + tid] = -HALF_MAX_HALF; } + __syncthreads(); half2 VKQ[ncols] = {{0.0f, 0.0f}}; From 0b1962a181dbfdb3ee9a9ba94d178ff633d6b2f3 Mon Sep 17 00:00:00 2001 From: David Huang <1969802+hjc4869@users.noreply.github.com> Date: Sun, 11 May 2025 20:18:39 +0800 Subject: [PATCH 10/21] Add `--no-op-offload` to improve `-ot` pp perf in MoE models like llama4 400B (llama/13386) --- ggml/include/ggml-backend.h | 4 ++-- ggml/src/ggml-backend.cpp | 8 ++++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h index ea2c1a402cc..778927f6821 100644 --- a/ggml/include/ggml-backend.h +++ b/ggml/include/ggml-backend.h @@ -248,7 +248,7 @@ extern "C" { // preferrably to run on the same backend as the buffer ggml_backend_buffer_set_usage(buf_weights, GGML_BACKEND_BUFFER_USAGE_WEIGHTS); - sched = ggml_backend_sched_new({backend_gpu, backend_gpu2, backend_cpu}, NULL, num_backends, GGML_DEFAULT_GRAPH_SIZE, false); + sched = ggml_backend_sched_new({backend_gpu, backend_gpu2, backend_cpu}, NULL, num_backends, GGML_DEFAULT_GRAPH_SIZE, false, true); // initialize buffers from a max size graph (optional) reserve_graph = build_graph(sched, max_batch_size); @@ -289,7 +289,7 @@ extern "C" { typedef bool (*ggml_backend_sched_eval_callback)(struct ggml_tensor * t, bool ask, void * user_data); // Initialize a backend scheduler, backends with low index are given priority over backends with high index - GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size, bool parallel); + GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size, bool parallel, bool op_offload); GGML_API void ggml_backend_sched_free(ggml_backend_sched_t sched); // Initialize backend buffers from a measure graph diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index c36b5abfb74..6f69d895f17 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -674,6 +674,8 @@ struct ggml_backend_sched { char * context_buffer; size_t context_buffer_size; + bool op_offload; + int debug; }; @@ -766,7 +768,7 @@ static int ggml_backend_sched_backend_id_from_cur(ggml_backend_sched_t sched, st if (tensor->op != GGML_OP_ROPE && src->buffer != NULL && src->buffer->usage == GGML_BACKEND_BUFFER_USAGE_WEIGHTS) { int src_backend_id = ggml_backend_sched_backend_from_buffer(sched, src, tensor); // check if a backend with higher prio wants to offload the op - if (src_backend_id == sched->n_backends - 1 && ggml_backend_buffer_is_host(src->buffer)) { + if (sched->op_offload && src_backend_id == sched->n_backends - 1 && ggml_backend_buffer_is_host(src->buffer)) { for (int b = 0; b < src_backend_id; b++) { if (ggml_backend_supports_op(sched->backends[b], tensor) && ggml_backend_offload_op(sched->backends[b], tensor)) { SET_CAUSE(tensor, "1.off"); @@ -1452,7 +1454,8 @@ ggml_backend_sched_t ggml_backend_sched_new( ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size, - bool parallel) { + bool parallel, + bool op_offload) { GGML_ASSERT(n_backends > 0); GGML_ASSERT(n_backends <= GGML_SCHED_MAX_BACKENDS); GGML_ASSERT(ggml_backend_dev_type(ggml_backend_get_device(backends[n_backends - 1])) == GGML_BACKEND_DEVICE_TYPE_CPU); @@ -1497,6 +1500,7 @@ ggml_backend_sched_t ggml_backend_sched_new( } sched->galloc = ggml_gallocr_new_n(sched->bufts, n_backends); + sched->op_offload = op_offload; ggml_backend_sched_reset(sched); From c426829771dd731acbe2a329fedc6291c2bd447d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Sun, 11 May 2025 16:09:33 +0200 Subject: [PATCH 11/21] CUDA: fix crash with partial offloading of MoE (llama/13439) --- ggml/src/ggml-cuda/ggml-cuda.cu | 10 ++++++++-- ggml/src/ggml-cuda/mmq.cu | 4 ++-- ggml/src/ggml-cuda/mmvq.cu | 4 ++-- 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 7643c4b8bfa..b4b85abcda9 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -1909,13 +1909,19 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft); + // If src0 is a temporary compute buffer it may have some padding that needs to be cleared for mul_mat_vec_q or mul_mat_q. + // But if src0 is also a view of another tensor then this cannot be done safely because it may overwrite valid tensor data. + // Therefore, in such cases use cuBLAS. + const bool bad_padding_clear = ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE + && ggml_nbytes(src0) != ggml_backend_buffer_get_alloc_size(src0->buffer, src0) && src0->view_src; + bool use_mul_mat_vec = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && src0->ne[0] % 2 == 0 && src1->ne[1] == 1; - bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) + bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && !bad_padding_clear && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE; - bool use_mul_mat_q = ggml_is_quantized(src0->type) + bool use_mul_mat_q = ggml_is_quantized(src0->type) && !bad_padding_clear && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32; bool any_gpus_with_slow_fp16 = false; diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu index b4962e6a51c..e1cf843de1a 100644 --- a/ggml/src/ggml-cuda/mmq.cu +++ b/ggml/src/ggml-cuda/mmq.cu @@ -91,11 +91,11 @@ void ggml_cuda_mul_mat_q( // If src0 is a temporary compute buffer, clear any potential padding. if (ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE) { - GGML_ASSERT(ggml_is_contiguously_allocated(src0)); - GGML_ASSERT(!src0->view_src); const size_t size_data = ggml_nbytes(src0); const size_t size_alloc = ggml_backend_buffer_get_alloc_size(src0->buffer, src0); if (size_alloc > size_data) { + GGML_ASSERT(ggml_is_contiguously_allocated(src0)); + GGML_ASSERT(!src0->view_src); CUDA_CHECK(cudaMemsetAsync((char *) src0->data + size_data, 0, size_alloc - size_data, stream)); } } diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index 3b313ea2953..dc7adf509fa 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -515,11 +515,11 @@ void ggml_cuda_mul_mat_vec_q( // If src0 is a temporary compute buffer, clear any potential padding. if (ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE) { - GGML_ASSERT(ggml_is_contiguously_allocated(src0)); - GGML_ASSERT(!src0->view_src); const size_t size_data = ggml_nbytes(src0); const size_t size_alloc = ggml_backend_buffer_get_alloc_size(src0->buffer, src0); if (size_alloc > size_data) { + GGML_ASSERT(ggml_is_contiguously_allocated(src0)); + GGML_ASSERT(!src0->view_src); CUDA_CHECK(cudaMemsetAsync((char *) src0->data + size_data, 0, size_alloc - size_data, stream)); } } From 882d97572977773f993839a150fc4f40b0d9fe17 Mon Sep 17 00:00:00 2001 From: Atharva Dubey Date: Mon, 12 May 2025 06:15:32 +0100 Subject: [PATCH 12/21] enable dpcpp nightly builds with libraries (llama/13406) --- ggml/src/ggml-sycl/CMakeLists.txt | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-sycl/CMakeLists.txt b/ggml/src/ggml-sycl/CMakeLists.txt index 6699b70bad0..231fb71dab5 100644 --- a/ggml/src/ggml-sycl/CMakeLists.txt +++ b/ggml/src/ggml-sycl/CMakeLists.txt @@ -52,9 +52,8 @@ target_compile_options(ggml-sycl PRIVATE "-Wno-narrowing") find_package(DNNL) set(GGML_SYCL_DNNL 0) if(DNNL_FOUND) - if (DEFINED ENV{ONEAPI_ROOT} AND NOT DEFINED DNNL_GPU_VENDOR) - # Assuming oneDNN packaged with oneapi release is used which - # supports only intel target + if (NOT DEFINED DNNL_GPU_VENDOR) + # default to intel target set(DNNL_GPU_VENDOR "INTEL") if(NOT "${GGML_SYCL_TARGET}" STREQUAL "INTEL") message(WARNING "oneDNN builds bundled with oneapi release only support INTEL target") @@ -108,6 +107,9 @@ endif() if (GGML_SYCL_TARGET STREQUAL "INTEL") # Intel devices use Intel oneMKL directly instead of oneMath to avoid the limitation of linking Intel oneMKL statically # See https://github.com/uxlfoundation/oneMath/issues/654 + if (CMAKE_CXX_COMPILER_ID STREQUAL "Clang") + set(SYCL_COMPILER ON) + endif() find_package(MKL REQUIRED) target_link_libraries(ggml-sycl PRIVATE MKL::MKL_SYCL::BLAS) target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_USE_INTEL_ONEMKL) From 8264872b5d9f62b2b8e34e3eb1be80d4c89e8279 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Mon, 12 May 2025 10:51:21 +0200 Subject: [PATCH 13/21] CUDA: fix misaligned synchronization in FA (llama/13469) --- ggml/src/ggml-cuda/fattn-mma-f16.cuh | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index 9873ea755a5..491780abd40 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -895,6 +895,11 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols; dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs); } + } else if (np > 1) { + // Warps with threadIdx.y % np == 0 execute a __syncthreads() in the if branch. + // Therefore, all other warps also need to execute a __syncthreads(). + // Otherwise the points at which warps synchronize with each other would become misaligned. + __syncthreads(); } #pragma unroll From cb90cb09925e196755350119a0d747105bed1aa7 Mon Sep 17 00:00:00 2001 From: Dan Johansson Date: Mon, 12 May 2025 13:06:19 +0200 Subject: [PATCH 14/21] ggml-cpu: Integrate fp32=bf16xbf16 SME KleidiAI kernel (llama/13053) * ggml-cpu: Integrate fp32=bf16xbf16 SME KleidiAI kernel Signed-off-by: Dan Johansson * * code review fixes Signed-off-by: Dan Johansson * * adds a comment that clarifies barrier usage Signed-off-by: Dan Johansson --------- Signed-off-by: Dan Johansson Co-authored-by: Charles Xu --- ggml/src/ggml-cpu/CMakeLists.txt | 29 ++- ggml/src/ggml-cpu/kleidiai/kernels.cpp | 93 ++++++- ggml/src/ggml-cpu/kleidiai/kernels.h | 58 ++++- ggml/src/ggml-cpu/kleidiai/kleidiai.cpp | 331 +++++++++++++++++++----- 4 files changed, 414 insertions(+), 97 deletions(-) diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt index 9a3085befc4..bdaec2881dd 100644 --- a/ggml/src/ggml-cpu/CMakeLists.txt +++ b/ggml/src/ggml-cpu/CMakeLists.txt @@ -428,6 +428,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name) ${KLEIDIAI_SRC}/kai/ukernels/ ${KLEIDIAI_SRC}/kai/ukernels/matmul/ ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/ + ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/ ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/) set(ARCH_FLAGS_TEMP "${ARCH_FLAGS}") @@ -438,17 +439,19 @@ function(ggml_add_cpu_backend_variant_impl tag_name) string(FIND "${ARCH_FLAGS_TEMP}" "+i8mm" I8MM_ENABLED) string(FIND "${ARCH_FLAGS_TEMP}" "+sme" SME_ENABLED) - set(PRIVATE_ARCH_FLAGS ${ARCH_FLAGS}) + set(PRIVATE_ARCH_FLAGS ${ARCH_FLAGS_TEMP}) - list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32.c) - list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.c) - list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32_neon.c) - list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.c) + list(APPEND GGML_KLEIDIAI_SOURCES + ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32.c + ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.c + ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32_neon.c + ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.c) if (NOT DOTPROD_ENABLED MATCHES -1) - list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c) - list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.c) - list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.c) + list(APPEND GGML_KLEIDIAI_SOURCES + ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c + ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.c + ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.c) endif() if (NOT I8MM_ENABLED MATCHES -1) @@ -456,9 +459,13 @@ function(ggml_add_cpu_backend_variant_impl tag_name) endif() if (NOT SME_ENABLED MATCHES -1) - list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c) - list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c) - set(PRIVATE_ARCH_FLAGS "${PRIVATE_ARCH_FLAGS}+sve+sve2") + list(APPEND GGML_KLEIDIAI_SOURCES + ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c + ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c + ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.c + ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_pack_bf16p2vlx2_f32_sme.c + ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.c) + set(PRIVATE_ARCH_FLAGS "-fno-tree-vectorize;${PRIVATE_ARCH_FLAGS}+sve+sve2") endif() set_source_files_properties(${GGML_KLEIDIAI_SOURCES} PROPERTIES COMPILE_OPTIONS "${PRIVATE_ARCH_FLAGS}") diff --git a/ggml/src/ggml-cpu/kleidiai/kernels.cpp b/ggml/src/ggml-cpu/kleidiai/kernels.cpp index aacc2bb5ee0..910fd0ee4e7 100644 --- a/ggml/src/ggml-cpu/kleidiai/kernels.cpp +++ b/ggml/src/ggml-cpu/kleidiai/kernels.cpp @@ -4,16 +4,22 @@ // KleidiAI micro-kernels #include "kai_matmul_clamp_f32_qsi8d32p_qsi4c32p_interface.h" -#include "kai_lhs_quant_pack_qsi8d32p_f32.h" -#include "kai_lhs_quant_pack_qsi8d32p_f32_neon.h" -#include "kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.h" -#include "kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.h" #include "kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h" #include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.h" #include "kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.h" #include "kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm.h" #include "kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.h" #include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.h" +#include "kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.h" + +#include "kai_lhs_pack_bf16p2vlx2_f32_sme.h" +#include "kai_lhs_quant_pack_qsi8d32p_f32.h" +#include "kai_lhs_quant_pack_qsi8d32p_f32_neon.h" + +#include "kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.h" +#include "kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.h" +#include "kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.h" + #include "kai_common.h" #include "kernels.h" @@ -61,6 +67,53 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon, }, /* .required_cpu = */ CPU_FEATURE_SME, + /* .lhs_type = */ GGML_TYPE_F32, + /* .rhs_type = */ GGML_TYPE_Q4_0, + /* .op_type = */ GGML_TYPE_F32, + }, + { + /* SME GEMM */ + /* .kern_info = */ { + /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, + /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, + /* .get_mr = */ kai_get_mr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, + /* .get_nr = */ kai_get_nr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, + /* .get_kr = */ kai_get_kr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, + /* .get_sr = */ kai_get_sr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, + /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, + /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, + /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, + /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, + /* .run_kernel = */ kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, + }, + /* SME GEMV */ + /* .kern_info = */ { + /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, + /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, + /* .get_mr = */ kai_get_mr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, + /* .get_nr = */ kai_get_nr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, + /* .get_kr = */ kai_get_kr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, + /* .get_sr = */ kai_get_sr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, + /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, + /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, + /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, + /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, + /* .run_kernel = */ kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, + }, + /* .lhs_info = */ { + /* .get_offset = */ kai_get_lhs_offset_lhs_pack_bf16p2vlx2_f32_sme, + /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_pack_bf16p2vlx2_f32_sme, + /* .packed_size = */ kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme, + /* .pack_func = */ kai_run_lhs_pack_bf16p2vlx2_f32_sme, + }, + /* .rhs_info = */ { + /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme, + /* .pack_func = */ kai_run_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme, + }, + /* .required_cpu = */ CPU_FEATURE_SME, + /* .lhs_type = */ GGML_TYPE_F32, + /* .rhs_type = */ GGML_TYPE_F16, + /* .op_type = */ GGML_TYPE_F32, }, #endif #if defined(__APPLE__) @@ -105,6 +158,9 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, }, /* .required_cpu = */ CPU_FEATURE_DOTPROD, + /* .lhs_type = */ GGML_TYPE_F32, + /* .rhs_type = */ GGML_TYPE_Q4_0, + /* .op_type = */ GGML_TYPE_F32, }, #endif #if defined(__ARM_FEATURE_MATMUL_INT8) @@ -148,6 +204,9 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, }, /* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM, + /* .lhs_type = */ GGML_TYPE_F32, + /* .rhs_type = */ GGML_TYPE_Q4_0, + /* .op_type = */ GGML_TYPE_F32, }, #endif #else @@ -192,6 +251,9 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, }, /* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM, + /* .lhs_type = */ GGML_TYPE_F32, + /* .rhs_type = */ GGML_TYPE_Q4_0, + /* .op_type = */ GGML_TYPE_F32, }, #endif #if defined(__ARM_FEATURE_DOTPROD) @@ -235,12 +297,33 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, }, /* .required_cpu = */ CPU_FEATURE_DOTPROD, + /* .lhs_type = */ GGML_TYPE_F32, + /* .rhs_type = */ GGML_TYPE_Q4_0, + /* .op_type = */ GGML_TYPE_F32, }, #endif #endif }; -ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature features) { +ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, const ggml_tensor * tensor) { + ggml_kleidiai_kernels * kernel = nullptr; + + if (tensor->op == GGML_OP_MUL_MAT && tensor->src[0] != nullptr && tensor->src[1] != nullptr) { + for (size_t i = 0; i < NELEMS(gemm_gemv_kernels); ++i) { + if ((cpu_features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu && + gemm_gemv_kernels[i].lhs_type == tensor->src[1]->type && + gemm_gemv_kernels[i].rhs_type == tensor->src[0]->type && + gemm_gemv_kernels[i].op_type == tensor->type) { + kernel = &gemm_gemv_kernels[i]; + break; + } + } + } + + return kernel; +} + +ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q4_0(cpu_feature features) { ggml_kleidiai_kernels * kernels = nullptr; for (size_t i = 0; i < NELEMS(gemm_gemv_kernels); ++i) { diff --git a/ggml/src/ggml-cpu/kleidiai/kernels.h b/ggml/src/ggml-cpu/kleidiai/kernels.h index 2ffe97eb42f..5ac02bda7c0 100644 --- a/ggml/src/ggml-cpu/kleidiai/kernels.h +++ b/ggml/src/ggml-cpu/kleidiai/kernels.h @@ -4,6 +4,9 @@ #pragma once +#include +#include "ggml.h" + enum cpu_feature { CPU_FEATURE_NONE = 0, CPU_FEATURE_DOTPROD = 1, @@ -26,26 +29,53 @@ struct kernel_info { size_t (*get_nr)(void); size_t (*get_kr)(void); size_t (*get_sr)(void); - size_t (*get_lhs_offset)(size_t m_idx, size_t k, size_t bl); - size_t (*get_rhs_packed_offset)(size_t n_idx, size_t k, size_t bl); + std::variant< + std::function, + std::function + > get_lhs_offset; + std::variant< + std::function, + std::function + > get_rhs_packed_offset; size_t (*get_dst_offset)(size_t m_idx, size_t n_idx, size_t stride); size_t (*get_dst_size)(size_t m, size_t n); - void (*run_kernel)(size_t m, size_t n, size_t k, size_t bl, const void* lhs_packed, const void* rhs_packed, - float* dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max); + std::variant< + std::function, + std::function + > run_kernel; }; struct lhs_packing_info { size_t (*get_offset)(size_t m_idx, size_t lhs_stride); - size_t (*get_packed_offset)(size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr); - size_t (*packed_size)(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr); - void (*pack_func)(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const float* lhs, - size_t lhs_stride, void* lhs_packed); + std::variant< + std::function, + std::function + > get_packed_offset; + std::variant< + std::function, + std::function + > packed_size; + std::variant< + std::function, + std::function + > pack_func; }; struct rhs_packing_info { - size_t (*packed_size)(size_t n, size_t k, size_t nr, size_t kr, size_t bl); - void (*pack_func)(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl, const uint8_t* rhs, - const float* bias, void* rhs_packed, size_t extra_bytes, const struct kai_rhs_pack_qs4cxs1s0_param* params); + std::variant< + std::function, + std::function + > packed_size; + std::variant< + std::function, + std::function + > pack_func; }; struct ggml_kleidiai_kernels { @@ -55,6 +85,10 @@ struct ggml_kleidiai_kernels { rhs_packing_info rhs_info; cpu_feature required_cpu; + ggml_type lhs_type; + ggml_type rhs_type; + ggml_type op_type; }; -ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features); +ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, const ggml_tensor * tensor); +ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q4_0(cpu_feature features); diff --git a/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp b/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp index 4e89ca0faa2..f3dffdd6bf5 100644 --- a/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +++ b/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp @@ -34,8 +34,9 @@ #include "ggml-common.h" struct ggml_kleidiai_context { + cpu_feature features; ggml_kleidiai_kernels * kernels; -} static ctx = { NULL }; +} static ctx = { CPU_FEATURE_NONE, NULL }; static void init_kleidiai_context(void) { @@ -47,18 +48,18 @@ static void init_kleidiai_context(void) { const char *env_var = getenv("GGML_KLEIDIAI_SME"); int sme_enabled = 0; - cpu_feature features = (ggml_cpu_has_dotprod() ? CPU_FEATURE_DOTPROD : CPU_FEATURE_NONE) | - (ggml_cpu_has_matmul_int8() ? CPU_FEATURE_I8MM : CPU_FEATURE_NONE) | - (ggml_cpu_has_sve() ? CPU_FEATURE_SVE : CPU_FEATURE_NONE); + ctx.features = (ggml_cpu_has_dotprod() ? CPU_FEATURE_DOTPROD : CPU_FEATURE_NONE) | + (ggml_cpu_has_matmul_int8() ? CPU_FEATURE_I8MM : CPU_FEATURE_NONE) | + (ggml_cpu_has_sve() ? CPU_FEATURE_SVE : CPU_FEATURE_NONE); if (env_var) { sme_enabled = atoi(env_var); } if (sme_enabled != 0) { - features |= ggml_cpu_has_sme() ? CPU_FEATURE_SME : CPU_FEATURE_NONE; + ctx.features |= ggml_cpu_has_sme() ? CPU_FEATURE_SME : CPU_FEATURE_NONE; } - ctx.kernels = ggml_kleidiai_select_kernels(features); + ctx.kernels = ggml_kleidiai_select_kernels_q4_0(ctx.features); } ggml_critical_section_end(); } @@ -68,95 +69,275 @@ static inline int64_t ggml_ne(const ggml_tensor * tensor, int dim) { return tensor->ne[dim]; } +template +static Ret variant_call(const Variant & var, Args&&... args) { + return std::visit([&](auto&& func) -> Ret { + if constexpr (std::is_invocable_r_v) { + return func(std::forward(args)...); + } else { + throw std::runtime_error("Invalid function type in variant_call"); + } + }, var); +} + namespace ggml::cpu::kleidiai { + +static size_t round_down(size_t x, size_t y) { + return y == 0 ? x : x - (x % y); +} + +static void transpose_f32kxn_f16nxk(size_t n, size_t k, float * dst, const uint16_t * src, size_t rhs_stride) { + size_t src_stride = rhs_stride / sizeof(uint16_t); + size_t dst_stride = n; + + for (size_t k_idx = 0; k_idx < k; ++k_idx) { + for (size_t n_idx = 0; n_idx < n; ++n_idx) { + uint16_t v = *(src + k_idx + n_idx * src_stride); + *(dst + n_idx + k_idx * dst_stride) = kai_cast_f32_f16(v); + } + } +} + class tensor_traits : public ggml::cpu::tensor_traits { bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override { - GGML_ASSERT(ctx.kernels); - kernel_info * kernel = op->src[1]->ne[1] == 1 ? &ctx.kernels->gemv : &ctx.kernels->gemm; + ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, op); + GGML_ASSERT(kernels); + kernel_info * kernel = op->src[1]->ne[1] == 1 ? &kernels->gemv : &kernels->gemm; size_t k = op->src[0]->ne[0]; + size_t n = op->src[0]->ne[1]; size_t m = op->src[1]->ne[1]; size_t mr = kernel->get_mr(); size_t kr = kernel->get_kr(); size_t sr = kernel->get_sr(); - size = ctx.kernels->lhs_info.packed_size(m, k, QK4_0, mr, kr, sr); + if (kernels->rhs_type == GGML_TYPE_Q4_0) { + size = variant_call(kernels->lhs_info.packed_size, m, k, QK4_0, mr, kr, sr); + } else if (kernels->rhs_type == GGML_TYPE_F16) { + size = variant_call(kernels->lhs_info.packed_size, m, k, mr, kr, sr) + + variant_call(kernels->rhs_info.packed_size, n, k) + + k * n * sizeof(float) + n * sizeof(float); + } else { + GGML_ASSERT(false); + } return true; } + bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * dst) override { if (dst->op == GGML_OP_MUL_MAT) { - const ggml_tensor * src0 = dst->src[0]; - const ggml_tensor * src1 = dst->src[1]; + if (dst->src[0]->type == GGML_TYPE_Q4_0) { + return compute_forward_q4_0(params, dst); + } else if (dst->src[0]->type == GGML_TYPE_F16) { + return compute_forward_kv_cache(params, dst); + } + } + return false; + } - GGML_TENSOR_BINARY_OP_LOCALS + bool compute_forward_kv_cache(ggml_compute_params * params, struct ggml_tensor * dst) { + static std::atomic_flag first_to_arrive = ATOMIC_FLAG_INIT; - GGML_ASSERT(ctx.kernels); - kernel_info * kernel = src1->ne[1] == 1 ? &ctx.kernels->gemv : &ctx.kernels->gemm; - lhs_packing_info * lhs_info = &ctx.kernels->lhs_info; + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; - GGML_ASSERT(kernel); + GGML_TENSOR_BINARY_OP_LOCALS - const int ith = params->ith; - const int nth = params->nth; + ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst); + GGML_ASSERT(kernels); - const size_t k = ne00; - const size_t m = ne11; - const size_t n = ne01; + kernel_info * kernel = src1->ne[1] == 1 ? &kernels->gemv : &kernels->gemm; + GGML_ASSERT(kernel); - const size_t n_step = kernel->get_n_step(); - const size_t num_n_per_thread = kai_roundup(kai_roundup(n, nth) / nth, n_step); - const size_t n_start = ith * num_n_per_thread; + const int nth = params->nth; + const int ith = params->ith; - size_t n_to_process = num_n_per_thread; - if ((n_start + n_to_process) > n) { - n_to_process = n - n_start; - } + const int64_t lhs_batch_size0 = ne12; + const int64_t rhs_batch_size0 = ne02; + const int64_t batch_size = rhs_batch_size0; + + const int64_t r = lhs_batch_size0 / rhs_batch_size0; + + const int64_t m = ne11 * r; + const int64_t n = ne01; + const int64_t k = ne00; + + const size_t lhs_stride = src1->nb[1]; + const size_t rhs_stride = src0->nb[1]; + const size_t dst_stride = dst->nb[1]; + + const int64_t mr = static_cast(kernel->get_mr()); + const int64_t nr = static_cast(kernel->get_nr()); + const int64_t kr = static_cast(kernel->get_kr()); + const int64_t sr = static_cast(kernel->get_sr()); + + const size_t lhs_packed_size = variant_call(kernels->lhs_info.packed_size, m, k, mr, kr, sr); + const size_t rhs_packed_size = variant_call(kernels->rhs_info.packed_size, n, k); + const size_t kxn_size = k * n * sizeof(float); + const size_t bias_size = n * sizeof(float); + + const size_t wsize_required = lhs_packed_size + rhs_packed_size + kxn_size + bias_size; + GGML_ASSERT(wsize_required <= params->wsize); + + uint8_t * lhs_packed = static_cast(params->wdata); + uint8_t * rhs_packed = lhs_packed + lhs_packed_size; + uint8_t * rhs_kxn = rhs_packed + rhs_packed_size; + uint8_t * bias = rhs_kxn + kxn_size; + + for (int64_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) { + const uint8_t * lhs_batch = static_cast(src1->data) + batch_idx * m * lhs_stride; + const uint8_t * rhs_batch = static_cast(src0->data) + batch_idx * n * rhs_stride; + uint8_t * dst_batch = static_cast(dst->data) + batch_idx * m * dst_stride; - const uint8_t * lhs = static_cast(src1->data); - uint8_t * lhs_packed = (uint8_t*)params->wdata; - const uint8_t * rhs_packed = static_cast(src0->data); + // LHS packing + { + const int64_t m_roundup_mr = kai_roundup(m, mr); + const int64_t num_threads = KAI_MIN(m_roundup_mr / mr, nth); - size_t mr = kernel->get_mr(); - size_t kr = kernel->get_kr(); - size_t sr = kernel->get_sr(); + if (ith < num_threads) { + const int64_t num_m_per_thread0 = round_down(m_roundup_mr / num_threads, mr); + const int64_t num_m_per_threadN_1 = m - (num_threads - 1) * num_m_per_thread0; - // Calculate number of columns to be processed per thread - const size_t num_m_per_thread = kai_roundup(m, mr * nth) / nth; - const size_t m_start = ith * num_m_per_thread; - size_t m_to_process = num_m_per_thread; - if ((m_start + m_to_process) > m) { - m_to_process = m - m_start; + const int64_t m_start = ith * num_m_per_thread0; + const int64_t num_m_per_thread = (ith == num_threads - 1) ? num_m_per_threadN_1 : num_m_per_thread0; + + const size_t lhs_offset = variant_call(kernels->gemm.get_lhs_offset, m_start, lhs_stride); + const size_t lhs_packed_offset = variant_call(kernels->lhs_info.get_packed_offset, m_start, k, mr, kr, sr); + + const void * src_ptr = static_cast(lhs_batch) + lhs_offset; + void * dst_ptr = static_cast(lhs_packed) + lhs_packed_offset; + + variant_call(kernels->lhs_info.pack_func, num_m_per_thread, k, mr, kr, sr, 0, src_ptr, lhs_stride, dst_ptr); + } } - if(m_start < m) { - // Transform LHS - const size_t src_stride = src1->nb[1]; - const float * src_ptr = reinterpret_cast(lhs + lhs_info->get_offset(m_start, dst->src[1]->nb[1])); - const size_t lhs_packed_offset = lhs_info->get_packed_offset(m_start, k, QK4_0, mr, kr, sr); - void * lhs_packed_ptr = static_cast(lhs_packed + lhs_packed_offset); + // RHS packing + if (first_to_arrive.test_and_set(std::memory_order_acquire) == false) { + // First thread to reach this point handles RHS packing + memset(bias, 0, n * sizeof(float)); + transpose_f32kxn_f16nxk(n, k, reinterpret_cast(rhs_kxn), + reinterpret_cast(rhs_batch), rhs_stride); - lhs_info->pack_func(m_to_process, k, QK4_0, mr, kr, sr, 0, src_ptr, src_stride, lhs_packed_ptr); + variant_call(kernels->rhs_info.pack_func, 1, n, k, nr, kr, sr, n * sizeof(float), + rhs_kxn, bias, nullptr, rhs_packed, 0, nullptr); } ggml_barrier(params->threadpool); - // Perform the operation - const size_t dst_stride = dst->nb[1]; - const size_t lhs_packed_offset = lhs_info->get_packed_offset(0, k, QK4_0, mr, kr, sr); - const size_t rhs_packed_offset = kernel->get_rhs_packed_offset(n_start, k, QK4_0); - const size_t dst_offset = kernel->get_dst_offset(0, n_start, dst_stride); - const void * rhs_ptr = static_cast(rhs_packed + rhs_packed_offset); - const void* lhs_ptr = (const void*)((const char *)lhs_packed + lhs_packed_offset); - float *dst_ptr = reinterpret_cast(static_cast(dst->data) + dst_offset); - - kernel->run_kernel(m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr, - dst_stride, sizeof(float), -FLT_MAX, FLT_MAX); - return true; + first_to_arrive.clear(std::memory_order_release); + + // Perform the matmul + { + const int64_t m_to_process = m; + const int64_t m_start = 0; + + const int64_t n_step = static_cast(kernel->get_n_step()); + const int64_t num_threads = KAI_MIN(n / n_step, nth); + + if (ith < num_threads) { + const int64_t num_n_per_thread0 = round_down(n / num_threads, n_step); + const int64_t num_n_per_threadN_1 = n - (num_threads - 1) * num_n_per_thread0; + + const int64_t n_start = ith * num_n_per_thread0; + const int64_t n_to_process = (ith == num_threads - 1) ? num_n_per_threadN_1 : num_n_per_thread0; + + const size_t lhs_packed_offset = variant_call(kernel->get_lhs_offset, m_start, k); + const size_t rhs_packed_offset = variant_call(kernel->get_rhs_packed_offset, n_start, k); + const size_t dst_offset = kernel->get_dst_offset(m_start, n_start, dst_stride); + + const void * lhs_ptr = lhs_packed + lhs_packed_offset; + const void * rhs_ptr = rhs_packed + rhs_packed_offset; + float * dst_ptr = reinterpret_cast(dst_batch + dst_offset); + + variant_call(kernel->run_kernel, m_to_process, n_to_process, k, lhs_ptr, rhs_ptr, dst_ptr, dst_stride, sizeof(float), -FLT_MAX, FLT_MAX); + } + } + + if (batch_idx != batch_size - 1) { + // This barrier is necessary when the batch size is larger than 1. While processing a batch, + // the work data buffer (params->wdata) is used as temporary storage which means that only + // a single batch can be processed at any given time. No barrier is needed for the last + // batch since GGML inserts a barrier between the execution of every operator. + ggml_barrier(params->threadpool); + } } - return false; + + return true; + } + + bool compute_forward_q4_0(struct ggml_compute_params * params, struct ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + + GGML_TENSOR_BINARY_OP_LOCALS + + ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst); + GGML_ASSERT(kernels); + + kernel_info * kernel = src1->ne[1] == 1 ? &kernels->gemv : &kernels->gemm; + lhs_packing_info * lhs_info = &kernels->lhs_info; + + GGML_ASSERT(kernel); + + const int ith = params->ith; + const int nth = params->nth; + + const size_t k = ne00; + const size_t m = ne11; + const size_t n = ne01; + + size_t mr = kernel->get_mr(); + size_t kr = kernel->get_kr(); + size_t sr = kernel->get_sr(); + + const uint8_t * lhs = static_cast(src1->data); + uint8_t * lhs_packed = (uint8_t*)params->wdata; + const uint8_t * rhs_packed = static_cast(src0->data); + + const size_t n_step = kernel->get_n_step(); + const size_t num_n_per_thread = kai_roundup(kai_roundup(n, nth) / nth, n_step); + const size_t n_start = ith * num_n_per_thread; + + size_t n_to_process = num_n_per_thread; + if ((n_start + n_to_process) > n) { + n_to_process = n - n_start; + } + + // Calculate number of columns to be processed per thread + const size_t num_m_per_thread = kai_roundup(m, mr * nth) / nth; + const size_t m_start = ith * num_m_per_thread; + size_t m_to_process = num_m_per_thread; + if ((m_start + m_to_process) > m) { + m_to_process = m - m_start; + } + + if (m_start < m) { + // Transform LHS + const size_t src_stride = src1->nb[1]; + const float * src_ptr = reinterpret_cast(lhs + lhs_info->get_offset(m_start, dst->src[1]->nb[1])); + const size_t lhs_packed_offset = variant_call(lhs_info->get_packed_offset, m_start, k, QK4_0, mr, kr, sr); + void * lhs_packed_ptr = static_cast(lhs_packed + lhs_packed_offset); + + variant_call(lhs_info->pack_func, m_to_process, k, QK4_0, mr, kr, sr, 0, src_ptr, src_stride, lhs_packed_ptr); + } + + ggml_barrier(params->threadpool); + + // Perform the operation + const size_t dst_stride = dst->nb[1]; + const size_t lhs_packed_offset = variant_call(lhs_info->get_packed_offset, 0, k, QK4_0, mr, kr, sr); + const size_t rhs_packed_offset = variant_call(kernel->get_rhs_packed_offset, n_start, k, QK4_0); + const size_t dst_offset = kernel->get_dst_offset(0, n_start, dst_stride); + const void * rhs_ptr = static_cast(rhs_packed + rhs_packed_offset); + const void* lhs_ptr = (const void*)((const char *)lhs_packed + lhs_packed_offset); + float *dst_ptr = reinterpret_cast(static_cast(dst->data) + dst_offset); + + variant_call(kernel->run_kernel, m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride, + sizeof(float), -FLT_MAX, FLT_MAX); + + return true; } public: @@ -169,13 +350,13 @@ class tensor_traits : public ggml::cpu::tensor_traits { size_t sr = ctx.kernels->gemm.get_sr(); #ifndef NDEBUG - const size_t repacked_size = ctx.kernels->rhs_info.packed_size(n, k, nr, kr, QK4_0); + const size_t repacked_size = variant_call(ctx.kernels->rhs_info.packed_size, n, k, nr, kr, QK4_0); GGML_ASSERT(repacked_size <= data_size && "repacked size larger than the packed size!"); #endif struct kai_rhs_pack_qs4cxs1s0_param params; params.lhs_zero_point = 1; params.rhs_zero_point = 8; - ctx.kernels->rhs_info.pack_func(1, n, k, nr, kr, sr, QK4_0, (const uint8_t *)data, NULL, tensor->data, 0, ¶ms); + variant_call(ctx.kernels->rhs_info.pack_func, 1, n, k, nr, kr, sr, QK4_0, (const uint8_t*)data, nullptr, tensor->data, 0, ¶ms); return 0; @@ -189,7 +370,7 @@ static ggml::cpu::tensor_traits * get_tensor_traits(ggml_backend_buffer_t, struc } } // namespace ggml::cpu::kleidiai -GGML_API enum ggml_status ggml_backend_cpu_kleidiai_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) { +static enum ggml_status ggml_backend_cpu_kleidiai_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) { tensor->extra = (void *) ggml::cpu::kleidiai::get_tensor_traits(buffer, tensor); GGML_UNUSED(buffer); @@ -238,12 +419,11 @@ static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alignment(ggml_backend_b namespace ggml::cpu::kleidiai { class extra_buffer_type : ggml::cpu::extra_buffer_type { bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override { - if ( op->op == GGML_OP_MUL_MAT && - op->src[0]->type == GGML_TYPE_Q4_0 && - op->src[0]->buffer && - (ggml_n_dims(op->src[0]) == 2) && - op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() && ctx.kernels - ) { + if (op->op == GGML_OP_MUL_MAT && + op->src[0]->type == GGML_TYPE_Q4_0 && + op->src[0]->buffer && + (ggml_n_dims(op->src[0]) == 2) && + op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() && ctx.kernels) { if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) { return false; } @@ -260,6 +440,19 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type { if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) { return (ggml::cpu::tensor_traits *) op->src[0]->extra; } + else if (ggml_kleidiai_select_kernels(ctx.features, op) && + op->src[0]->op == GGML_OP_VIEW && + (op->src[1]->op == GGML_OP_PERMUTE || op->src[1]->op == GGML_OP_SOFT_MAX) && + op->src[1]->ne[1] > 1) { + if ((op->src[0]->nb[0] != 2) || + (op->src[1]->nb[0] != 4) || + (op->src[0]->nb[1] * op->src[0]->ne[1] != op->src[0]->nb[2]) || + (op->src[1]->nb[1] * op->src[1]->ne[1] != op->src[1]->nb[2])) { + return nullptr; + } + + return ggml::cpu::kleidiai::get_tensor_traits(NULL, NULL); + } } return nullptr; } From fe0d52b9a20558aa4b2b64fea6a2c948cc4192cc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Mon, 12 May 2025 14:44:49 +0200 Subject: [PATCH 15/21] llama/ggml: add LLM training support (llama/10544) * llama/ggml: add LLM training support more compact progress bar llama_save_model_to_file llama_opt_param_filter ggml_graph_dup force_grads refactor ggml_opt, fix test-opt * remove logits_all * refactor CUDA implementation for ACC * reset graph at beginning of opt period --- ggml/include/ggml-opt.h | 75 +++-- ggml/include/ggml.h | 13 +- ggml/src/ggml-backend.cpp | 2 +- ggml/src/ggml-cuda/acc.cu | 66 +++-- ggml/src/ggml-cuda/sum.cu | 2 +- ggml/src/ggml-opt.cpp | 558 +++++++++++++++++++++++++------------- ggml/src/ggml.c | 41 +-- 7 files changed, 486 insertions(+), 271 deletions(-) diff --git a/ggml/include/ggml-opt.h b/ggml/include/ggml-opt.h index eb5eab9de67..da0c24b46fe 100644 --- a/ggml/include/ggml-opt.h +++ b/ggml/include/ggml-opt.h @@ -37,13 +37,16 @@ extern "C" { // ====== Dataset ====== GGML_API ggml_opt_dataset_t ggml_opt_dataset_init( - int64_t ne_datapoint, // number of elements per datapoint - int64_t ne_label, // number of elements per label - int64_t ndata, // total number of datapoints/labels - int64_t ndata_shard); // number of datapoints/labels per shard (unit at which the dataset is shuffled/copied) + enum ggml_type type_data, // the type for the internal data tensor + enum ggml_type type_label, // the type for the internal labels tensor + int64_t ne_datapoint, // number of elements per datapoint + int64_t ne_label, // number of elements per label + int64_t ndata, // total number of datapoints/labels + int64_t ndata_shard); // number of datapoints/labels per shard (unit at which the dataset is shuffled/copied) GGML_API void ggml_opt_dataset_free(ggml_opt_dataset_t dataset); // get underlying tensors that store the data + GGML_API int64_t ggml_opt_dataset_ndata (ggml_opt_dataset_t dataset); GGML_API struct ggml_tensor * ggml_opt_dataset_data (ggml_opt_dataset_t dataset); // shape = [ne_datapoint, ndata] GGML_API struct ggml_tensor * ggml_opt_dataset_labels(ggml_opt_dataset_t dataset); // shape = [nd_label, ndata] @@ -56,13 +59,19 @@ extern "C" { struct ggml_tensor * data_batch, // shape = [ne_datapoint, ndata_batch] struct ggml_tensor * labels_batch, // shape = [ne_label, ndata_batch] int64_t ibatch); + GGML_API void ggml_opt_dataset_get_batch_host( + ggml_opt_dataset_t dataset, + void * data_batch, + size_t nb_data_batch, + void * labels_batch, + int64_t ibatch); // ====== Model / Context ====== enum ggml_opt_build_type { - GGML_OPT_BUILD_TYPE_FORWARD, - GGML_OPT_BUILD_TYPE_GRAD, - GGML_OPT_BUILD_TYPE_OPT, + GGML_OPT_BUILD_TYPE_FORWARD = 10, + GGML_OPT_BUILD_TYPE_GRAD = 20, + GGML_OPT_BUILD_TYPE_OPT = 30, }; // parameters that control which optimizer is used and how said optimizer tries to find the minimal loss @@ -81,20 +90,22 @@ extern "C" { // userdata can be used to pass arbitrary data typedef struct ggml_opt_optimizer_params (*ggml_opt_get_optimizer_params)(void * userdata); - // returns the default optimizer params (constant) + // returns the default optimizer params (constant, hard-coded values) // userdata is not used GGML_API struct ggml_opt_optimizer_params ggml_opt_get_default_optimizer_params(void * userdata); + // casts userdata to ggml_opt_optimizer_params and returns it + GGML_API struct ggml_opt_optimizer_params ggml_opt_get_constant_optimizer_params(void * userdata); + // parameters for initializing a new optimization context struct ggml_opt_params { ggml_backend_sched_t backend_sched; // defines which backends are used to construct the compute graphs - struct ggml_context * ctx_compute; // created in user code, holds non-static tensors - - // the forward graph is defined by inputs and outputs - // those tensors and all tensors inbetween are not intended to be reusable between multiple optimization contexts - struct ggml_tensor * inputs; - struct ggml_tensor * outputs; + // by default the forward graph needs to be reconstructed for each eval + // if ctx_compute, inputs, and outputs are set the graphs are instead allocated statically + struct ggml_context * ctx_compute; + struct ggml_tensor * inputs; + struct ggml_tensor * outputs; enum ggml_opt_loss_type loss_type; enum ggml_opt_build_type build_type; @@ -107,12 +118,9 @@ extern "C" { // get parameters for an optimization context with defaults set where possible // parameters for which no sensible defaults exist are supplied as arguments to this function - GGML_API ggml_opt_params ggml_opt_default_params( - ggml_backend_sched_t backend_sched, - struct ggml_context * ctx_compute, - struct ggml_tensor * inputs, - struct ggml_tensor * outputs, - enum ggml_opt_loss_type loss_type); + GGML_API struct ggml_opt_params ggml_opt_default_params( + ggml_backend_sched_t backend_sched, + enum ggml_opt_loss_type loss_type); GGML_API ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params); GGML_API void ggml_opt_free(ggml_opt_context_t opt_ctx); @@ -121,6 +129,7 @@ extern "C" { GGML_API void ggml_opt_reset(ggml_opt_context_t opt_ctx, bool optimizer); // get underlying tensors that store data + // if not using static graphs these pointers become invalid with the next call to ggml_opt_alloc GGML_API struct ggml_tensor * ggml_opt_inputs( ggml_opt_context_t opt_ctx); // forward graph input tensor GGML_API struct ggml_tensor * ggml_opt_outputs( ggml_opt_context_t opt_ctx); // forward graph output tensor GGML_API struct ggml_tensor * ggml_opt_labels( ggml_opt_context_t opt_ctx); // labels to compare outputs against @@ -128,11 +137,12 @@ extern "C" { GGML_API struct ggml_tensor * ggml_opt_pred( ggml_opt_context_t opt_ctx); // predictions made by outputs GGML_API struct ggml_tensor * ggml_opt_ncorrect(ggml_opt_context_t opt_ctx); // number of matching predictions between outputs and labels + // get the gradient accumulator for a node from the forward graph GGML_API struct ggml_tensor * ggml_opt_grad_acc(ggml_opt_context_t opt_ctx, struct ggml_tensor * node); // ====== Optimization Result ====== - GGML_API ggml_opt_result_t ggml_opt_result_init(); + GGML_API ggml_opt_result_t ggml_opt_result_init(void); GGML_API void ggml_opt_result_free(ggml_opt_result_t result); GGML_API void ggml_opt_result_reset(ggml_opt_result_t result); @@ -144,11 +154,20 @@ extern "C" { // ====== Computation ====== - // do forward pass, increment result if not NULL - GGML_API void ggml_opt_forward(ggml_opt_context_t opt_ctx, ggml_opt_result_t result); + // if not using static graphs, this function must be called prior to ggml_opt_alloc + GGML_API void ggml_opt_prepare_alloc( + ggml_opt_context_t opt_ctx, + struct ggml_context * ctx_compute, + struct ggml_cgraph * gf, + struct ggml_tensor * inputs, + struct ggml_tensor * outputs); + + // allocate the next graph for evaluation, either forward or forward + backward + // must be called exactly once prior to calling ggml_opt_eval + GGML_API void ggml_opt_alloc(ggml_opt_context_t opt_ctx, bool backward); - // do forward pass, increment result if not NULL, do backward pass - GGML_API void ggml_opt_forward_backward(ggml_opt_context_t opt_ctx, ggml_opt_result_t result); + // do forward pass, increment result if not NULL, do backward pass if allocated + GGML_API void ggml_opt_eval(ggml_opt_context_t opt_ctx, ggml_opt_result_t result); // ############################################################################ // ## The high-level functions start here. They do not depend on any private ## @@ -200,9 +219,9 @@ extern "C" { // fit model defined by inputs and outputs to dataset GGML_API void ggml_opt_fit( ggml_backend_sched_t backend_sched, // backend scheduler for constructing the compute graphs - ggml_context * ctx_compute, // context with temporarily allocated tensors to calculate the outputs - ggml_tensor * inputs, // input tensor with shape [ne_datapoint, ndata_batch] - ggml_tensor * outputs, // output tensor, must have shape [ne_label, ndata_batch] if labels are used + struct ggml_context * ctx_compute, // context with temporarily allocated tensors to calculate the outputs + struct ggml_tensor * inputs, // input tensor with shape [ne_datapoint, ndata_batch] + struct ggml_tensor * outputs, // output tensor, must have shape [ne_label, ndata_batch] if labels are used ggml_opt_dataset_t dataset, // dataset with data and optionally also labels enum ggml_opt_loss_type loss_type, // loss to minimize ggml_opt_get_optimizer_params get_opt_pars, // callback to get optimizer params, userdata is pointer to epoch (of type int64_t) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index c518366d58a..e91dedf14a1 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -768,7 +768,7 @@ extern "C" { // Tensor flags GGML_API void ggml_set_input(struct ggml_tensor * tensor); GGML_API void ggml_set_output(struct ggml_tensor * tensor); - GGML_API void ggml_set_param(struct ggml_context * ctx, struct ggml_tensor * tensor); + GGML_API void ggml_set_param(struct ggml_tensor * tensor); GGML_API void ggml_set_loss(struct ggml_tensor * tensor); // @@ -938,7 +938,7 @@ extern "C" { GGML_API struct ggml_tensor * ggml_repeat_back( struct ggml_context * ctx, struct ggml_tensor * a, - struct ggml_tensor * b); + struct ggml_tensor * b); // sum up values that are adjacent in dims > 0 instead of repeated with same stride // concat a and b along dim // used in stable-diffusion @@ -2049,15 +2049,14 @@ extern "C" { GGML_API void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor); GGML_API void ggml_build_backward_expand( - struct ggml_context * ctx_static, // context for static gradients (loss + gradient accumulation) - struct ggml_context * ctx_compute, // context for gradient computation - struct ggml_cgraph * cgraph, - bool accumulate); // whether or not gradients should be accumulated, requires static allocation of tensors in ctx_static + struct ggml_context * ctx, // context for gradient computation + struct ggml_cgraph * cgraph, + struct ggml_tensor ** grad_accs); // graph allocation in a context GGML_API struct ggml_cgraph * ggml_new_graph (struct ggml_context * ctx); // size = GGML_DEFAULT_GRAPH_SIZE, grads = false GGML_API struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t size, bool grads); - GGML_API struct ggml_cgraph * ggml_graph_dup (struct ggml_context * ctx, struct ggml_cgraph * cgraph); + GGML_API struct ggml_cgraph * ggml_graph_dup (struct ggml_context * ctx, struct ggml_cgraph * cgraph, bool force_grads); GGML_API void ggml_graph_cpy (struct ggml_cgraph * src, struct ggml_cgraph * dst); GGML_API void ggml_graph_reset (struct ggml_cgraph * cgraph); // set regular grads + optimizer momenta to 0, set loss grad to 1 GGML_API void ggml_graph_clear (struct ggml_cgraph * cgraph); diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index 6f69d895f17..b30b4cb386f 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -1111,7 +1111,7 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg const int node_backend_id = tensor_backend_id(node); - assert(node_backend_id != -1); // all nodes should be assigned by now + assert(node_backend_id != -1); // all nodes should be assigned by now, this can happen if there is no CPU fallback // check if we should start a new split based on the sources of the current node bool need_new_split = false; diff --git a/ggml/src/ggml-cuda/acc.cu b/ggml/src/ggml-cuda/acc.cu index 96bfe1c9d81..e084607c029 100644 --- a/ggml/src/ggml-cuda/acc.cu +++ b/ggml/src/ggml-cuda/acc.cu @@ -1,47 +1,61 @@ #include "acc.cuh" -static __global__ void acc_f32(const float * x, const float * y, float * dst, const int ne, - const int ne10, const int ne11, const int ne12, - const int nb1, const int nb2, int offset) { - const int i = blockDim.x * blockIdx.x + threadIdx.x; +static __global__ void acc_f32(const float * x, const float * y, float * dst, const int64_t ne, + const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13, + const int64_t s11, const int64_t s12, const int64_t s13, const int64_t offset) { + const int64_t i = blockDim.x * blockIdx.x + threadIdx.x; + if (i >= ne) { return; } - int src1_idx = i - offset; - int oz = src1_idx / nb2; - int oy = (src1_idx - (oz * nb2)) / nb1; - int ox = src1_idx % nb1; - if (src1_idx >= 0 && ox < ne10 && oy < ne11 && oz < ne12) { - dst[i] = x[i] + y[ox + oy * ne10 + oz * ne10 * ne11]; - } else { - dst[i] = x[i]; + + int64_t src1_idx = i - offset; + + int64_t tmp = src1_idx; + const int64_t i13 = tmp / s13; + tmp -= i13 * s13; + const int64_t i12 = tmp / s12; + tmp -= i12 * s12; + const int64_t i11 = tmp / s11; + tmp -= i11 * s11; + const int64_t i10 = tmp; + + float val = x[i]; + if (src1_idx >= 0 && i10 < ne10 && i11 < ne11 && i12 < ne12 && i13 < ne13) { + val += y[((i13*ne12 + i12) * ne11 + i11) * ne10 + i10]; } + dst[i] = val; } -static void acc_f32_cuda(const float * x, const float * y, float * dst, const int n_elements, - const int ne10, const int ne11, const int ne12, - const int nb1, const int nb2, const int offset, cudaStream_t stream) { - int num_blocks = (n_elements + CUDA_ACC_BLOCK_SIZE - 1) / CUDA_ACC_BLOCK_SIZE; - acc_f32<<>>(x, y, dst, n_elements, ne10, ne11, ne12, nb1, nb2, offset); +static void acc_f32_cuda(const float * x, const float * y, float * dst, const int64_t n_elements, + const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13, + const int64_t s1, const int64_t s2, const int64_t s3, const int64_t offset, cudaStream_t stream) { + const int num_blocks = (n_elements + CUDA_ACC_BLOCK_SIZE - 1) / CUDA_ACC_BLOCK_SIZE; + acc_f32<<>>(x, y, dst, n_elements, ne10, ne11, ne12, ne13, s1, s2, s3, offset); } void ggml_cuda_op_acc(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; - const float * src0_d = (const float *)src0->data; - const float * src1_d = (const float *)src1->data; - float * dst_d = (float *)dst->data; + + const float * src0_d = (const float *) src0->data; + const float * src1_d = (const float *) src1->data; + float * dst_d = (float *) dst->data; + cudaStream_t stream = ctx.stream(); GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(src1->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); - GGML_ASSERT(dst->ne[3] == 1); // just 3D tensors supported - int nb1 = dst->op_params[0] / 4; // 4 bytes of float32 - int nb2 = dst->op_params[1] / 4; // 4 bytes of float32 - // int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused - int offset = dst->op_params[3] / 4; // offset in bytes + GGML_ASSERT(ggml_is_contiguous(src1)); + GGML_ASSERT(dst->nb[0] == ggml_element_size(dst)); + GGML_ASSERT(ggml_is_contiguously_allocated(dst)); + + const int64_t s1 = dst->op_params[0] / sizeof(float); + const int64_t s2 = dst->op_params[1] / sizeof(float); + const int64_t s3 = dst->op_params[2] / sizeof(float); + const int64_t offset = dst->op_params[3] / sizeof(float); - acc_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(dst), src1->ne[0], src1->ne[1], src1->ne[2], nb1, nb2, offset, stream); + acc_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(dst), src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], s1, s2, s3, offset, stream); } diff --git a/ggml/src/ggml-cuda/sum.cu b/ggml/src/ggml-cuda/sum.cu index f9589080a0c..eb3d7cdba98 100644 --- a/ggml/src/ggml-cuda/sum.cu +++ b/ggml/src/ggml-cuda/sum.cu @@ -31,7 +31,7 @@ void ggml_cuda_op_sum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); - GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguously_allocated(src0)); const float * src0_d = (const float *) src0->data; float * dst_d = (float *) dst->data; diff --git a/ggml/src/ggml-opt.cpp b/ggml/src/ggml-opt.cpp index 7c3e24103a2..58d77578f45 100644 --- a/ggml/src/ggml-opt.cpp +++ b/ggml/src/ggml-opt.cpp @@ -28,16 +28,19 @@ struct ggml_opt_dataset { }; struct ggml_opt_context { - ggml_backend_sched_t backend_sched = nullptr; - ggml_cgraph * allocated_graph = nullptr; - ggml_cgraph * allocated_graph_copy = nullptr; - struct ggml_context * ctx_static = nullptr; - struct ggml_context * ctx_static_cpu = nullptr; - struct ggml_context * ctx_compute = nullptr; - struct ggml_context * ctx_copy = nullptr; - ggml_backend_buffer_t buf_static = nullptr; - ggml_backend_buffer_t buf_static_cpu = nullptr; - std::mt19937 rng; + ggml_backend_sched_t backend_sched = nullptr; + ggml_cgraph * allocated_graph = nullptr; + ggml_cgraph * allocated_graph_copy = nullptr; + struct ggml_context * ctx_static = nullptr; + struct ggml_context * ctx_cpu = nullptr; + struct ggml_context * ctx_compute = nullptr; + struct ggml_context * ctx_copy = nullptr; + ggml_backend_buffer_t buf_static = nullptr; + ggml_backend_buffer_t buf_cpu = nullptr; + std::mt19937 rng; + enum ggml_opt_loss_type loss_type; + enum ggml_opt_build_type build_type; + enum ggml_opt_build_type build_type_alloc; struct ggml_tensor * inputs = nullptr; struct ggml_tensor * outputs = nullptr; @@ -50,6 +53,11 @@ struct ggml_opt_context { struct ggml_cgraph * gf = nullptr; struct ggml_cgraph * gb_grad = nullptr; struct ggml_cgraph * gb_opt = nullptr; + bool static_graphs = false; + bool eval_ready = false; + std::vector grad_accs; + std::vector grad_m; + std::vector grad_v; int64_t iter = 1; int32_t opt_period = 1; @@ -73,7 +81,13 @@ struct ggml_opt_result { // ====== Dataset ====== -ggml_opt_dataset_t ggml_opt_dataset_init(int64_t ne_datapoint, int64_t ne_label, int64_t ndata, int64_t ndata_shard) { +ggml_opt_dataset_t ggml_opt_dataset_init( + enum ggml_type type_data, + enum ggml_type type_label, + int64_t ne_datapoint, + int64_t ne_label, + int64_t ndata, + int64_t ndata_shard) { GGML_ASSERT(ne_datapoint > 0); GGML_ASSERT(ne_label >= 0); GGML_ASSERT(ndata > 0); @@ -92,11 +106,11 @@ ggml_opt_dataset_t ggml_opt_dataset_init(int64_t ne_datapoint, int64_t ne_label, result->ctx = ggml_init(params); } - result->data = ggml_new_tensor_2d(result->ctx, GGML_TYPE_F32, ne_datapoint, ndata); + result->data = ggml_new_tensor_2d(result->ctx, type_data, ne_datapoint, ndata); result->nbs_data = ggml_nbytes(result->data) * ndata_shard/ndata; if (ne_label > 0) { - result->labels = ggml_new_tensor_2d(result->ctx, GGML_TYPE_F32, ne_label, ndata); + result->labels = ggml_new_tensor_2d(result->ctx, type_label, ne_label, ndata); result->nbs_labels = ggml_nbytes(result->labels) * ndata_shard/ndata; } else { result->labels = nullptr; @@ -119,6 +133,10 @@ void ggml_opt_dataset_free(ggml_opt_dataset_t dataset) { delete dataset; } +int64_t ggml_opt_dataset_ndata(ggml_opt_dataset_t dataset) { + return dataset->ndata; +} + struct ggml_tensor * ggml_opt_dataset_data(ggml_opt_dataset_t dataset) { return dataset->data; } @@ -144,6 +162,8 @@ void ggml_opt_dataset_get_batch(ggml_opt_dataset_t dataset, struct ggml_tensor * GGML_ASSERT( data_batch && ggml_is_contiguous(data_batch)); GGML_ASSERT(!labels_batch || ggml_is_contiguous(labels_batch)); GGML_ASSERT((labels_batch == nullptr) == (dataset->labels == nullptr)); + GGML_ASSERT( data_batch->type == dataset->data->type); + GGML_ASSERT(!labels_batch || labels_batch->type == dataset->labels->type); const size_t nb_data_batch = ggml_nbytes(data_batch); GGML_ASSERT(nb_data_batch % dataset->nbs_data == 0); @@ -171,6 +191,31 @@ void ggml_opt_dataset_get_batch(ggml_opt_dataset_t dataset, struct ggml_tensor * } } +void ggml_opt_dataset_get_batch_host(ggml_opt_dataset_t dataset, void * data_batch, size_t nb_data_batch, void * labels_batch, int64_t ibatch) { + GGML_ASSERT((labels_batch == nullptr) == (dataset->labels == nullptr)); + GGML_ASSERT(nb_data_batch % dataset->nbs_data == 0); + + const int64_t shards_per_batch = nb_data_batch / dataset->nbs_data; + + GGML_ASSERT((ibatch + 1)*shards_per_batch <= int64_t(dataset->permutation.size())); + + for (int64_t ishard_batch = 0; ishard_batch < shards_per_batch; ++ishard_batch) { + const int64_t ishard = dataset->permutation[ibatch*shards_per_batch + ishard_batch]; + + const char * ptr_data = (const char *) dataset->data->data + ishard *dataset->nbs_data; + char * ptr_data_batch = (char *) data_batch + ishard_batch*dataset->nbs_data; + memcpy(ptr_data_batch, ptr_data, dataset->nbs_data); + + if (!labels_batch) { + continue; + } + + const char * ptr_labels = (const char *) dataset->labels->data + ishard *dataset->nbs_labels; + char * ptr_labels_batch = (char *) labels_batch + ishard_batch*dataset->nbs_labels; + memcpy(ptr_labels_batch, ptr_labels, dataset->nbs_labels); + } +} + // ====== Model / Context ====== struct ggml_opt_optimizer_params ggml_opt_get_default_optimizer_params(void * userdata) { @@ -187,17 +232,18 @@ struct ggml_opt_optimizer_params ggml_opt_get_default_optimizer_params(void * us return result; } +struct ggml_opt_optimizer_params ggml_opt_get_constant_optimizer_params(void * userdata) { + return *((struct ggml_opt_optimizer_params *) userdata); +} + struct ggml_opt_params ggml_opt_default_params( ggml_backend_sched_t backend_sched, - struct ggml_context * ctx_compute, - struct ggml_tensor * inputs, - struct ggml_tensor * outputs, enum ggml_opt_loss_type loss_type) { return { /*backend_sched =*/ backend_sched, - /*ctx_compute =*/ ctx_compute, - /*inputs =*/ inputs, - /*logits =*/ outputs, + /*ctx_compute =*/ nullptr, + /*inputs =*/ nullptr, + /*logits =*/ nullptr, /*loss_type =*/ loss_type, /*build_type =*/ GGML_OPT_BUILD_TYPE_OPT, /*opt_period =*/ 1, @@ -266,195 +312,246 @@ static ggml_cgraph * dup_graph(ggml_context * ctx, ggml_cgraph * src) { return dst; } -static void ggml_opt_alloc_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph) { - GGML_ASSERT(graph); - if (opt_ctx->allocated_graph == graph) { - return; - } - - ggml_backend_sched_reset(opt_ctx->backend_sched); // clear allocation of previous graph - - { - ggml_init_params params = { - /*.mem_size =*/ ggml_tensor_overhead() * GGML_DEFAULT_GRAPH_SIZE, - /*.mem_buffer =*/ nullptr, - /*.no_alloc =*/ true, - }; - ggml_free(opt_ctx->ctx_copy); - opt_ctx->ctx_copy = ggml_init(params); - } - - opt_ctx->allocated_graph_copy = dup_graph(opt_ctx->ctx_copy, graph); - - ggml_backend_sched_alloc_graph(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy); - opt_ctx->allocated_graph = graph; -} - -ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params) { - ggml_opt_context_t result = new struct ggml_opt_context; - result->backend_sched = params.backend_sched; - result->ctx_compute = params.ctx_compute; - result->inputs = params.inputs; - result->outputs = params.outputs; - result->opt_period = params.opt_period; - result->get_opt_pars = params.get_opt_pars; - result->get_opt_pars_ud = params.get_opt_pars_ud; - - GGML_ASSERT(result->inputs->data && "the inputs must be allocated statically"); - GGML_ASSERT(result->opt_period >= 1); - - const bool accumulate = params.build_type == GGML_OPT_BUILD_TYPE_GRAD || - (params.build_type == GGML_OPT_BUILD_TYPE_OPT && result->opt_period > 1); +static void ggml_opt_build(ggml_opt_context_t opt_ctx) { + GGML_ASSERT(opt_ctx->ctx_compute && "no compute context set, either use static graphs or set one with ggml_opt_prepare_alloc"); + GGML_ASSERT((!opt_ctx->static_graphs || opt_ctx->inputs->data) && "when using static graphs the inputs must be allocated statically"); - ggml_set_input(result->inputs); - ggml_set_output(result->outputs); + const bool accumulate = opt_ctx->build_type_alloc >= GGML_OPT_BUILD_TYPE_GRAD && + !(opt_ctx->static_graphs && opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_OPT && opt_ctx->opt_period == 1); - result->gf = ggml_new_graph_custom(result->ctx_compute, GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true); // Forward pass. - ggml_build_forward_expand(result->gf, result->outputs); + ggml_set_input(opt_ctx->inputs); + ggml_set_output(opt_ctx->outputs); int n_param = 0; - for (int i = 0; i < result->gf->n_nodes; ++i) { - if (result->gf->nodes[i]->flags & GGML_TENSOR_FLAG_PARAM) { + for (int i = 0; i < opt_ctx->gf->n_nodes; ++i) { + const struct ggml_tensor * node = opt_ctx->gf->nodes[i]; + if (node->flags & GGML_TENSOR_FLAG_PARAM) { n_param++; } + GGML_ASSERT(!(node->flags & GGML_TENSOR_FLAG_LOSS) && "support for extra loss terms not implemented"); } - { + if (!opt_ctx->ctx_static) { // The static context is used for: - // - gradients (1 tensor per param if using gradient accumulation) + // - gradients (1 per loss, 1 tensor per param if using gradient accumulation) // - optimizer momenta (2 tensors per param) - // - labels - // - loss + its gradient (up to 5 tensors) - // - pred - // - ncorrect (2 tensors). - const size_t tensors_per_param = (accumulate ? 1 : 0) + (params.build_type == GGML_OPT_BUILD_TYPE_OPT ? 2 : 0); - const size_t size_meta = (tensors_per_param*n_param + 9) * ggml_tensor_overhead(); + // - labels (if using static graphs) + // - loss (if using static graphs, up to 5 tensors) + // - pred (if using static graphs) + // - ncorrect (if using static graphs, 2 tensors). + constexpr size_t n_loss = 1; + const size_t tensors_per_param = (accumulate ? 1 : 0) + + (opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_OPT ? 2 : 0); + const size_t tensors_const = opt_ctx->static_graphs ? 9 : 0; + const size_t size_meta = (n_loss + tensors_per_param*n_param + tensors_const) * ggml_tensor_overhead(); struct ggml_init_params params = { /*.mem_size =*/ size_meta, /*.mem_buffer =*/ nullptr, /*.no_alloc =*/ true, }; - result->ctx_static = ggml_init(params); + opt_ctx->ctx_static = ggml_init(params); } + GGML_ASSERT(opt_ctx->build_type <= opt_ctx->build_type_alloc); + { - // The static cpu context is used for: - // - optimizer parameters (1 for the entire context) + // The cpu context is allocated statically if using static graphs, dynamically otherwise. + // It is used for: + // - optimizer parameters (1 shared for all optimizer invocations) const size_t size_meta = 1 * ggml_tensor_overhead(); struct ggml_init_params params = { /*.mem_size =*/ size_meta, /*.mem_buffer =*/ nullptr, /*.no_alloc =*/ true, }; - result->ctx_static_cpu = ggml_init(params); + ggml_free(opt_ctx->ctx_cpu); + opt_ctx->ctx_cpu = ggml_init(params); + + ggml_backend_buffer_free(opt_ctx->buf_cpu); + opt_ctx->buf_cpu = nullptr; } + struct ggml_context * ctx_results = opt_ctx->static_graphs ? opt_ctx->ctx_static : opt_ctx->ctx_compute; - switch (params.loss_type) { + switch (opt_ctx->loss_type) { case GGML_OPT_LOSS_TYPE_MEAN: { - result->loss = ggml_sum(result->ctx_static, result->outputs); - ggml_set_name(result->loss, "loss_sum"); - const float scale = 1.0f / (result->opt_period * ggml_nelements(result->outputs)); - result->loss = ggml_scale(result->ctx_static, result->loss, scale); - ggml_set_name(result->loss, "loss_mean"); - result->loss_per_datapoint = true; + opt_ctx->loss = ggml_sum(ctx_results, opt_ctx->outputs); + ggml_set_name(opt_ctx->loss, "loss_sum"); + const float scale = 1.0f / (opt_ctx->opt_period * ggml_nelements(opt_ctx->outputs)); + opt_ctx->loss = ggml_scale(ctx_results, opt_ctx->loss, scale); + ggml_set_name(opt_ctx->loss, "loss_mean"); + opt_ctx->loss_per_datapoint = true; break; } case GGML_OPT_LOSS_TYPE_SUM: { - result->loss = ggml_sum(result->ctx_static, result->outputs); - ggml_set_name(result->loss, "loss_sum"); - result->loss_per_datapoint = false; + opt_ctx->loss = ggml_sum(ctx_results, opt_ctx->outputs); + ggml_set_name(opt_ctx->loss, "loss_sum"); + opt_ctx->loss_per_datapoint = false; break; } case GGML_OPT_LOSS_TYPE_CROSS_ENTROPY: { - result->labels = ggml_dup_tensor(result->ctx_static, result->outputs); - ggml_set_input(result->labels); - ggml_set_name(result->labels, "labels"); - result->loss = ggml_cross_entropy_loss(result->ctx_static, result->outputs, result->labels); - ggml_set_name(result->loss, "loss_cross_entropy"); - if (result->opt_period > 1) { - result->loss = ggml_scale(result->ctx_static, result->loss, 1.0f / result->opt_period); - ggml_set_name(result->loss, "loss_cross_entropy_scaled"); + opt_ctx->labels = ggml_dup_tensor(ctx_results, opt_ctx->outputs); + ggml_set_input(opt_ctx->labels); + ggml_set_name(opt_ctx->labels, "labels"); + opt_ctx->loss = ggml_cross_entropy_loss(ctx_results, opt_ctx->outputs, opt_ctx->labels); + ggml_set_name(opt_ctx->loss, "loss_cross_entropy"); + if (opt_ctx->opt_period > 1) { + opt_ctx->loss = ggml_scale(ctx_results, opt_ctx->loss, 1.0f / opt_ctx->opt_period); + ggml_set_name(opt_ctx->loss, "loss_cross_entropy_scaled"); } - result->loss_per_datapoint = true; + opt_ctx->loss_per_datapoint = true; break; } case GGML_OPT_LOSS_TYPE_MEAN_SQUARED_ERROR: { - result->labels = ggml_dup_tensor(result->ctx_static, result->outputs); - ggml_set_input(result->labels); - ggml_set_name(result->labels, "labels"); - result->loss = ggml_sub(result->ctx_static, result->outputs, result->labels); - ggml_set_name(result->loss, "loss_error"); - result->loss = ggml_sqr(result->ctx_static, result->loss); - ggml_set_name(result->loss, "loss_squared_error"); - result->loss = ggml_sum(result->ctx_static, result->loss); - ggml_set_name(result->loss, "loss_sum_squared_error"); - const float scale = 1.0f / (result->opt_period * ggml_nelements(result->outputs)); - result->loss = ggml_scale(result->ctx_static, result->loss, scale); - ggml_set_name(result->loss, "loss_mean_squared_error"); - result->loss_per_datapoint = true; + opt_ctx->labels = ggml_dup_tensor(ctx_results, opt_ctx->outputs); + ggml_set_input(opt_ctx->labels); + ggml_set_name(opt_ctx->labels, "labels"); + opt_ctx->loss = ggml_sub(ctx_results, opt_ctx->outputs, opt_ctx->labels); + ggml_set_name(opt_ctx->loss, "loss_error"); + opt_ctx->loss = ggml_sqr(ctx_results, opt_ctx->loss); + ggml_set_name(opt_ctx->loss, "loss_squared_error"); + opt_ctx->loss = ggml_sum(ctx_results, opt_ctx->loss); + ggml_set_name(opt_ctx->loss, "loss_sum_squared_error"); + const float scale = 1.0f / (opt_ctx->opt_period * ggml_nelements(opt_ctx->outputs)); + opt_ctx->loss = ggml_scale(ctx_results, opt_ctx->loss, scale); + ggml_set_name(opt_ctx->loss, "loss_mean_squared_error"); + opt_ctx->loss_per_datapoint = true; break; } } - ggml_set_output(result->loss); - ggml_set_loss(result->loss); - ggml_build_forward_expand(result->gf, result->loss); - - result->pred = ggml_argmax(result->ctx_static, result->outputs); - ggml_set_name(result->pred, "pred"); - ggml_set_output(result->pred); - ggml_build_forward_expand(result->gf, result->pred); + ggml_set_output(opt_ctx->loss); + ggml_set_loss(opt_ctx->loss); + ggml_build_forward_expand(opt_ctx->gf, opt_ctx->loss); + + if (opt_ctx->loss_type == GGML_OPT_LOSS_TYPE_CROSS_ENTROPY) { + opt_ctx->pred = ggml_argmax(ctx_results, opt_ctx->outputs); + ggml_set_name(opt_ctx->pred, "pred"); + ggml_set_output(opt_ctx->pred); + ggml_build_forward_expand(opt_ctx->gf, opt_ctx->pred); + + opt_ctx->ncorrect = ggml_count_equal(ctx_results, opt_ctx->pred, ggml_argmax(ctx_results, opt_ctx->labels)); + ggml_set_name(opt_ctx->ncorrect, "ncorrect"); + ggml_set_output(opt_ctx->ncorrect); + ggml_build_forward_expand(opt_ctx->gf, opt_ctx->ncorrect); + } - if (result->labels) { - result->ncorrect = ggml_count_equal(result->ctx_static, result->pred, ggml_argmax(result->ctx_static, result->labels)); - ggml_set_name(result->ncorrect, "ncorrect"); - ggml_set_output(result->ncorrect); - ggml_build_forward_expand(result->gf, result->ncorrect); - } else { - result->ncorrect = nullptr; + if (opt_ctx->buf_static) { + if (opt_ctx->build_type == GGML_OPT_BUILD_TYPE_FORWARD) { + return; + } + } else if (opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_FORWARD) { + opt_ctx->buf_static = ggml_backend_alloc_ctx_tensors( + opt_ctx->ctx_static, ggml_backend_sched_get_backend(opt_ctx->backend_sched, 0)); + return; } - if (params.build_type == GGML_OPT_BUILD_TYPE_FORWARD) { - result->buf_static = ggml_backend_alloc_ctx_tensors(result->ctx_static, ggml_backend_sched_get_backend(result->backend_sched, 0)); - return result; + if (opt_ctx->grad_accs.empty()) { + GGML_ASSERT(opt_ctx->build_type_alloc >= GGML_OPT_BUILD_TYPE_GRAD); + + const int n_nodes = opt_ctx->gf->n_nodes; + opt_ctx->grad_accs.resize(n_nodes); + for (int i = 0; i < n_nodes; ++i) { + ggml_tensor * node = opt_ctx->gf->nodes[i]; + if ((accumulate && (node->flags & GGML_TENSOR_FLAG_PARAM)) || (node->flags & GGML_TENSOR_FLAG_LOSS)) { + opt_ctx->grad_accs[i] = ggml_new_tensor(opt_ctx->ctx_static, GGML_TYPE_F32, GGML_MAX_DIMS, node->ne); + } else { + opt_ctx->grad_accs[i] = nullptr; + } + } + + if (opt_ctx->build_type_alloc >= GGML_OPT_BUILD_TYPE_OPT) { + opt_ctx->grad_m.resize(n_nodes); + opt_ctx->grad_v.resize(n_nodes); + for (int i = 0; i < n_nodes; ++i) { + ggml_tensor * node = opt_ctx->gf->nodes[i]; + if (node->flags & GGML_TENSOR_FLAG_PARAM) { + opt_ctx->grad_m[i] = ggml_new_tensor(opt_ctx->ctx_static, GGML_TYPE_F32, GGML_MAX_DIMS, node->ne); + opt_ctx->grad_v[i] = ggml_new_tensor(opt_ctx->ctx_static, GGML_TYPE_F32, GGML_MAX_DIMS, node->ne); + } else { + opt_ctx->grad_m[i] = nullptr; + opt_ctx->grad_v[i] = nullptr; + } + } + } } // gb_grad == graph backward gradients, forward pass, then backward pass to calculate gradients. - result->gb_grad = ggml_graph_dup(result->ctx_compute, result->gf); - ggml_build_backward_expand(result->ctx_static, result->ctx_compute, result->gb_grad, accumulate); + opt_ctx->gb_grad = ggml_graph_dup(opt_ctx->ctx_compute, opt_ctx->gf, /*force_grads =*/ true); + ggml_build_backward_expand(opt_ctx->ctx_compute, opt_ctx->gb_grad, opt_ctx->grad_accs.data()); - if (params.build_type == GGML_OPT_BUILD_TYPE_GRAD) { - result->buf_static = ggml_backend_alloc_ctx_tensors(result->ctx_static, ggml_backend_sched_get_backend(result->backend_sched, 0)); - ggml_graph_reset(result->gb_grad); - return result; + if (opt_ctx->buf_static) { + if (opt_ctx->build_type == GGML_OPT_BUILD_TYPE_GRAD) { + return; + } + } else if (opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_GRAD) { + opt_ctx->buf_static = ggml_backend_alloc_ctx_tensors(opt_ctx->ctx_static, ggml_backend_sched_get_backend(opt_ctx->backend_sched, 0)); + ggml_graph_reset(opt_ctx->gb_grad); } - GGML_ASSERT(params.build_type == GGML_OPT_BUILD_TYPE_OPT); + GGML_ASSERT(opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_OPT); // gb_opt == graph backward optimize, forward pass, then backward pass to calculate gradients, then optimizer step. - result->gb_opt = ggml_graph_dup(result->ctx_compute, result->gb_grad); + opt_ctx->gb_opt = ggml_graph_dup(opt_ctx->ctx_compute, opt_ctx->gb_grad, /*force_grads =*/ true); - result->adamw_params = ggml_new_tensor_1d(result->ctx_static_cpu, GGML_TYPE_F32, 7); - ggml_set_input(result->adamw_params); - ggml_set_name(result->adamw_params, "adamw_params"); + opt_ctx->adamw_params = ggml_new_tensor_1d(opt_ctx->ctx_cpu, GGML_TYPE_F32, 7); + ggml_set_input(opt_ctx->adamw_params); + ggml_set_name(opt_ctx->adamw_params, "adamw_params"); - for (int i = result->gf->n_nodes-1; i >= 0; --i) { - struct ggml_tensor * node = result->gb_opt->nodes[i]; - struct ggml_tensor * grad = ggml_graph_get_grad(result->gb_opt, node); + for (int i = opt_ctx->gf->n_nodes-1; i >= 0; --i) { + struct ggml_tensor * node = opt_ctx->gb_opt->nodes[i]; + struct ggml_tensor * grad = ggml_graph_get_grad(opt_ctx->gb_opt, node); - if (node->flags & GGML_TENSOR_FLAG_PARAM) { - struct ggml_tensor * m = ggml_dup_tensor(result->ctx_static, node); - struct ggml_tensor * v = ggml_dup_tensor(result->ctx_static, node); - struct ggml_tensor * opt_step = ggml_opt_step_adamw(result->ctx_compute, node, grad, m, v, result->adamw_params); - ggml_build_forward_expand(result->gb_opt, opt_step); + if (grad && (node->flags & GGML_TENSOR_FLAG_PARAM)) { + struct ggml_tensor * m = opt_ctx->grad_m[i]; + struct ggml_tensor * v = opt_ctx->grad_v[i]; + struct ggml_tensor * opt_step = ggml_opt_step_adamw(opt_ctx->ctx_compute, node, grad, m, v, opt_ctx->adamw_params); + + ggml_set_name(m, (std::string("AdamW m for ") + std::string(node->name)).c_str()); + ggml_set_name(v, (std::string("AdamW v for ") + std::string(node->name)).c_str()); + ggml_set_name(opt_step, (std::string("AdamW step for ") + std::string(node->name)).c_str()); + + ggml_build_forward_expand(opt_ctx->gb_opt, opt_step); } } - result->buf_static = ggml_backend_alloc_ctx_tensors( - result->ctx_static, ggml_backend_sched_get_backend(result->backend_sched, 0)); + if (!opt_ctx->buf_static) { + opt_ctx->buf_static = ggml_backend_alloc_ctx_tensors( + opt_ctx->ctx_static, ggml_backend_sched_get_backend(opt_ctx->backend_sched, 0)); + ggml_graph_reset(opt_ctx->gb_opt); + } - result->buf_static_cpu = ggml_backend_alloc_ctx_tensors_from_buft(result->ctx_static_cpu, ggml_backend_cpu_buffer_type()); + opt_ctx->buf_cpu = ggml_backend_alloc_ctx_tensors_from_buft(opt_ctx->ctx_cpu, ggml_backend_cpu_buffer_type()); +} - ggml_graph_reset(result->gb_opt); +ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params) { + ggml_opt_context_t result = new struct ggml_opt_context; + result->backend_sched = params.backend_sched; + result->ctx_compute = params.ctx_compute; + result->loss_type = params.loss_type; + result->build_type = params.build_type; + result->build_type_alloc = params.build_type; + result->inputs = params.inputs; + result->outputs = params.outputs; + result->opt_period = params.opt_period; + result->get_opt_pars = params.get_opt_pars; + result->get_opt_pars_ud = params.get_opt_pars_ud; + + GGML_ASSERT(result->opt_period >= 1); + + result->static_graphs = result->ctx_compute; + + if (!result->static_graphs) { + GGML_ASSERT(!result->inputs); + GGML_ASSERT(!result->outputs); + return result; + } + + GGML_ASSERT(result->inputs); + GGML_ASSERT(result->outputs); + + result->gf = ggml_new_graph_custom(result->ctx_compute, GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true); // Forward pass. + ggml_build_forward_expand(result->gf, result->outputs); + + ggml_opt_build(result); return result; } @@ -464,9 +561,9 @@ void ggml_opt_free(ggml_opt_context_t opt_ctx) { return; } ggml_backend_buffer_free(opt_ctx->buf_static); - ggml_backend_buffer_free(opt_ctx->buf_static_cpu); + ggml_backend_buffer_free(opt_ctx->buf_cpu); ggml_free(opt_ctx->ctx_static); - ggml_free(opt_ctx->ctx_static_cpu); + ggml_free(opt_ctx->ctx_cpu); delete opt_ctx; } @@ -582,8 +679,79 @@ void ggml_opt_result_accuracy(ggml_opt_result_t result, double * accuracy, doubl // ====== Computation ====== -static void ggml_opt_eval_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph, ggml_opt_result * result) { - if (graph != opt_ctx->gf) { +void ggml_opt_prepare_alloc( + ggml_opt_context_t opt_ctx, + struct ggml_context * ctx_compute, + struct ggml_cgraph * gf, + struct ggml_tensor * inputs, + struct ggml_tensor * outputs) { + GGML_ASSERT(!opt_ctx->static_graphs); + opt_ctx->ctx_compute = ctx_compute; + opt_ctx->gf = gf; + opt_ctx->inputs = inputs; + opt_ctx->outputs = outputs; +} + +void ggml_opt_alloc(ggml_opt_context_t opt_ctx, bool backward) { + GGML_ASSERT(!opt_ctx->eval_ready); + if (opt_ctx->build_type == GGML_OPT_BUILD_TYPE_OPT && opt_ctx->opt_period > 1 && opt_ctx->opt_i == 0) { + ggml_graph_reset(opt_ctx->gb_grad); + } + if (backward) { + const int32_t opt_i_next = (opt_ctx->opt_i + 1) % opt_ctx->opt_period; + opt_ctx->build_type = opt_i_next == 0 ? GGML_OPT_BUILD_TYPE_OPT : GGML_OPT_BUILD_TYPE_GRAD; + } else { + opt_ctx->build_type = GGML_OPT_BUILD_TYPE_FORWARD; + } + + if (!opt_ctx->static_graphs) { + ggml_opt_build(opt_ctx); + } + + struct ggml_cgraph * graph = nullptr; + switch (opt_ctx->build_type) { + case GGML_OPT_BUILD_TYPE_FORWARD: { + graph = opt_ctx->gf; + } break; + case GGML_OPT_BUILD_TYPE_GRAD: { + graph = opt_ctx->gb_grad; + } break; + case GGML_OPT_BUILD_TYPE_OPT: { + graph = opt_ctx->gb_opt; + } break; + } + GGML_ASSERT(graph); + + if (opt_ctx->allocated_graph == graph) { + opt_ctx->eval_ready = true; + return; + } + + ggml_backend_sched_reset(opt_ctx->backend_sched); // clear allocation of previous graph + + if (opt_ctx->static_graphs) { + ggml_init_params params = { + /*.mem_size =*/ graph->size*ggml_tensor_overhead() + ggml_graph_overhead_custom(graph->size, graph->grads), + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + ggml_free(opt_ctx->ctx_copy); + opt_ctx->ctx_copy = ggml_init(params); + + opt_ctx->allocated_graph_copy = dup_graph(opt_ctx->ctx_copy, graph); + } else { + opt_ctx->allocated_graph_copy = graph; + } + + ggml_backend_sched_alloc_graph(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy); + opt_ctx->allocated_graph = graph; + + opt_ctx->eval_ready = true; +} + +void ggml_opt_eval(ggml_opt_context_t opt_ctx, ggml_opt_result_t result) { + GGML_ASSERT(opt_ctx->eval_ready); + if (opt_ctx->allocated_graph == opt_ctx->gb_opt) { struct ggml_opt_optimizer_params opt_pars = opt_ctx->get_opt_pars(opt_ctx->get_opt_pars_ud); GGML_ASSERT(opt_pars.adamw.alpha > 0.0f); @@ -609,9 +777,19 @@ static void ggml_opt_eval_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph, adamw_par_data[6] = beta2h; } - ggml_opt_alloc_graph(opt_ctx, graph); ggml_backend_sched_graph_compute(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy); opt_ctx->iter += opt_ctx->allocated_graph == opt_ctx->gb_opt; + opt_ctx->opt_i = (opt_ctx->opt_i + 1) % opt_ctx->opt_period; + + if (!opt_ctx->static_graphs) { + opt_ctx->gf = nullptr; + opt_ctx->gb_grad = nullptr; + opt_ctx->gb_opt = nullptr; + opt_ctx->allocated_graph = nullptr; + opt_ctx->allocated_graph_copy = nullptr; + } + + opt_ctx->eval_ready = false; if (!result) { return; @@ -635,12 +813,14 @@ static void ggml_opt_eval_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph, ggml_backend_tensor_get(opt_ctx->loss, &loss, 0, ggml_nbytes(opt_ctx->loss)); result->loss.push_back(loss); - GGML_ASSERT(opt_ctx->pred->type == GGML_TYPE_I32); - std::vector pred(ndata); - ggml_backend_tensor_get(opt_ctx->pred, pred.data(), 0, ggml_nbytes(opt_ctx->pred)); - result->pred.insert(result->pred.end(), pred.begin(), pred.end()); + if (opt_ctx->pred) { + GGML_ASSERT(opt_ctx->pred->type == GGML_TYPE_I32); + std::vector pred(ndata); + ggml_backend_tensor_get(opt_ctx->pred, pred.data(), 0, ggml_nbytes(opt_ctx->pred)); + result->pred.insert(result->pred.end(), pred.begin(), pred.end()); + } - if (!opt_ctx->labels || result->ncorrect < 0) { + if (!opt_ctx->ncorrect || result->ncorrect < 0) { result->ncorrect = -1; return; } @@ -652,26 +832,6 @@ static void ggml_opt_eval_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph, result->ncorrect += ncorrect; } -void ggml_opt_forward(ggml_opt_context_t opt_ctx, ggml_opt_result * result) { - ggml_opt_eval_graph(opt_ctx, opt_ctx->gf, result); -} - -void ggml_opt_forward_backward(ggml_opt_context_t opt_ctx, ggml_opt_result * result) { - if (opt_ctx->opt_period == 1) { - ggml_opt_eval_graph(opt_ctx, opt_ctx->gb_opt, result); - return; - } - - const int32_t opt_i_next = (opt_ctx->opt_i + 1) % opt_ctx->opt_period; - if (opt_i_next == 0) { - ggml_opt_eval_graph(opt_ctx, opt_ctx->gb_opt, result); - ggml_opt_reset(opt_ctx, /*optimizer =*/ false); - } else { - ggml_opt_eval_graph(opt_ctx, opt_ctx->gb_grad, result); - } - opt_ctx->opt_i = opt_i_next; -} - // ====== High-Level Functions ====== void ggml_opt_epoch( @@ -700,16 +860,18 @@ void ggml_opt_epoch( int64_t ibatch = 0; int64_t t_loop_start = ggml_time_us(); for (; ibatch < ibatch_split; ++ibatch) { + ggml_opt_alloc(opt_ctx, /*backward =*/ true); ggml_opt_dataset_get_batch(dataset, inputs, labels, ibatch); - ggml_opt_forward_backward(opt_ctx, result_train); + ggml_opt_eval(opt_ctx, result_train); if (callback_train) { callback_train(true, opt_ctx, dataset, result_train, ibatch+1, ibatch_split, t_loop_start); } } t_loop_start = ggml_time_us(); for (; ibatch < nbatches; ++ibatch) { + ggml_opt_alloc(opt_ctx, /*backward =*/ false); ggml_opt_dataset_get_batch(dataset, inputs, labels, ibatch); - ggml_opt_forward(opt_ctx, result_eval); + ggml_opt_eval(opt_ctx, result_eval); if (callback_eval) { callback_eval(false, opt_ctx, dataset, result_eval, ibatch+1-ibatch_split, nbatches-ibatch_split, t_loop_start); } @@ -726,13 +888,26 @@ void ggml_opt_epoch_callback_progress_bar( int64_t t_start_us) { fprintf(stderr, "%s[", train ? "train: " : "val: "); - constexpr int64_t bar_length = 25; + // The progress bar consists of partially filled blocks, unicode has 8 separate fill levels. + constexpr int64_t bar_length = 8; + const int64_t ibatch8 = 8 * ibatch; for (int64_t j = 0; j < bar_length; ++j) { - const int64_t ibatch_j = ibatch_max * j/bar_length; - if (ibatch_j < ibatch) { - fprintf(stderr, "="); - } else if (ibatch_max * (j - 1)/bar_length < ibatch) { - fprintf(stderr, ">"); + if (ibatch_max * (8*j + 8) / bar_length < ibatch8) { + fprintf(stderr, "\u2588"); // full block + } else if (ibatch_max * (8*j + 7) / bar_length < ibatch8) { + fprintf(stderr, "\u2589"); // 7/8 filled + } else if (ibatch_max * (8*j + 6) / bar_length < ibatch8) { + fprintf(stderr, "\u258A"); // 6/8 filled + } else if (ibatch_max * (8*j + 5) / bar_length < ibatch8) { + fprintf(stderr, "\u258B"); // 5/8 filled + } else if (ibatch_max * (8*j + 4) / bar_length < ibatch8) { + fprintf(stderr, "\u258C"); // 4/8 filled + } else if (ibatch_max * (8*j + 3) / bar_length < ibatch8) { + fprintf(stderr, "\u258D"); // 3/8 filled + } else if (ibatch_max * (8*j + 2) / bar_length < ibatch8) { + fprintf(stderr, "\u258E"); // 2/8 filled + } else if (ibatch_max * (8*j + 1) / bar_length < ibatch8) { + fprintf(stderr, "\u258F"); // 1/8 filled } else { fprintf(stderr, " "); } @@ -764,8 +939,8 @@ void ggml_opt_epoch_callback_progress_bar( const int64_t t_eta_m = t_eta_s / 60; t_eta_s -= t_eta_m * 60; - fprintf(stderr, "| data=%06" PRId64 "/%06" PRId64 ", loss=%.6lf+-%.6lf, accuracy=%.2lf+-%.2lf%%, " - "t=%02" PRId64 ":%02" PRId64 ":%02" PRId64 ", ETA=%02" PRId64 ":%02" PRId64 ":%02" PRId64 "]\r", + fprintf(stderr, "] data=%07" PRId64 "/%07" PRId64 " loss=%.5lf±%.5lf acc=%.2lf±%.2lf%% " + "t=%02" PRId64 ":%02" PRId64 ":%02" PRId64 " ETA=%02" PRId64 ":%02" PRId64 ":%02" PRId64 " \r", idata, idata_max, loss, loss_unc, 100.0*accuracy, 100.0*accuracy_unc, t_ibatch_h, t_ibatch_m, t_ibatch_s, t_eta_h, t_eta_m, t_eta_s); if (ibatch == ibatch_max) { @@ -806,7 +981,10 @@ void ggml_opt_fit( int64_t epoch = 1; - ggml_opt_params params = ggml_opt_default_params(backend_sched, ctx_compute, inputs, outputs, loss_type); + ggml_opt_params params = ggml_opt_default_params(backend_sched, loss_type); + params.ctx_compute = ctx_compute; + params.inputs = inputs; + params.outputs = outputs; params.opt_period = opt_period; params.get_opt_pars = get_opt_pars; params.get_opt_pars_ud = &epoch; diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index ee4fe9f723d..01f7e05b287 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -5499,7 +5499,7 @@ static void ggml_compute_backward( // tensor = src0 * 1 + src1 * 0 if (src0_needs_grads) { // dsrc0 = dtensor * 1 - ggml_add_or_set(ctx, cgraph, isrc0, grad); + ggml_add_or_set(ctx, cgraph, isrc0, ggml_reshape(ctx, grad, src0)); } if (src1_needs_grads) { // dsrc1 = dtensor * 0 -> noop @@ -5780,10 +5780,9 @@ void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * } void ggml_build_backward_expand( - struct ggml_context * ctx_static, - struct ggml_context * ctx_compute, - struct ggml_cgraph * cgraph, - bool accumulate) { + struct ggml_context * ctx, + struct ggml_cgraph * cgraph, + struct ggml_tensor ** grad_accs) { GGML_ASSERT(cgraph->n_nodes > 0); GGML_ASSERT(cgraph->grads); GGML_ASSERT(cgraph->grad_accs); @@ -5856,21 +5855,24 @@ void ggml_build_backward_expand( GGML_ASSERT(!node->view_src || node->op == GGML_OP_CPY || node->op == GGML_OP_VIEW || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_TRANSPOSE); - const size_t igrad = ggml_hash_find(&cgraph->visited_hash_set, node); - GGML_ASSERT(igrad != GGML_HASHSET_FULL); - GGML_ASSERT(ggml_bitset_get(cgraph->visited_hash_set.used, igrad)); - if ((accumulate && (node->flags & GGML_TENSOR_FLAG_PARAM)) || (node->flags & GGML_TENSOR_FLAG_LOSS)) { - cgraph->grad_accs[igrad] = ggml_dup_tensor(ctx_static, node); - cgraph->grads[igrad] = cgraph->grad_accs[igrad]; - ggml_format_name(cgraph->grad_accs[igrad], "grad acc for %s", node->name); + const size_t ihash = ggml_hash_find(&cgraph->visited_hash_set, node); + GGML_ASSERT(ihash != GGML_HASHSET_FULL); + GGML_ASSERT(ggml_bitset_get(cgraph->visited_hash_set.used, ihash)); + if (grad_accs && grad_accs[i]) { + cgraph->grad_accs[ihash] = grad_accs[i]; + cgraph->grads[ihash] = cgraph->grad_accs[ihash]; + } else if (node->flags & GGML_TENSOR_FLAG_LOSS) { + // loss tensors always need a gradient accumulator + cgraph->grad_accs[ihash] = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, node->ne); + cgraph->grads[ihash] = cgraph->grad_accs[ihash]; } - grads_needed[igrad] = true; + grads_needed[ihash] = true; } for (int i = n_nodes_f - 1; i >= 0; --i) { // inplace operations to add gradients are not created by ggml_compute_backward except for gradient accumulation // use allocator to automatically make inplace operations - ggml_compute_backward(ctx_compute, cgraph, i, grads_needed); + ggml_compute_backward(ctx, cgraph, i, grads_needed); } free(grads_needed); @@ -6016,8 +6018,8 @@ void ggml_graph_cpy(struct ggml_cgraph * src, struct ggml_cgraph * dst) { } } -struct ggml_cgraph * ggml_graph_dup(struct ggml_context * ctx, struct ggml_cgraph * cgraph) { - struct ggml_cgraph * result = ggml_new_graph_custom(ctx, cgraph->size, cgraph->grads != NULL); +struct ggml_cgraph * ggml_graph_dup(struct ggml_context * ctx, struct ggml_cgraph * cgraph, bool force_grads) { + struct ggml_cgraph * result = ggml_new_graph_custom(ctx, cgraph->size, cgraph->grads || force_grads); ggml_graph_cpy(cgraph, result); return result; } @@ -6036,6 +6038,9 @@ struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor) { } void ggml_graph_reset(struct ggml_cgraph * cgraph) { + if (!cgraph) { + return; + } GGML_ASSERT(cgraph->grads != NULL); for (int i = 0; i < cgraph->n_nodes; i++) { @@ -6345,8 +6350,8 @@ void ggml_set_output(struct ggml_tensor * tensor) { tensor->flags |= GGML_TENSOR_FLAG_OUTPUT; } -void ggml_set_param(struct ggml_context * ctx, struct ggml_tensor * tensor) { - GGML_UNUSED(ctx); // TODO: remove this parameter +void ggml_set_param(struct ggml_tensor * tensor) { + GGML_ASSERT(tensor->op == GGML_OP_NONE); tensor->flags |= GGML_TENSOR_FLAG_PARAM; } From 43a59eccf65ec228340fdcfc551107762d8ff8ba Mon Sep 17 00:00:00 2001 From: lhez Date: Mon, 12 May 2025 13:13:49 -0700 Subject: [PATCH 16/21] opencl: remove unnecessary assert for `add` (llama/13257) --- ggml/src/ggml-opencl/ggml-opencl.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 05a2f4e630a..58694604838 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -4855,8 +4855,6 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor if (!any_on_device) { return false; } - GGML_ASSERT(ggml_is_contiguous(src0)); - GGML_ASSERT(ggml_is_contiguous(src1)); func = ggml_cl_add; break; case GGML_OP_MUL: From 926e06dbfd8cf0c9362658b32b861c1da93b1b04 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 13 May 2025 13:09:20 +0300 Subject: [PATCH 17/21] metal : optimize MoE for large batches (llama/13388) --- ggml/src/ggml-metal/ggml-metal-impl.h | 43 +- ggml/src/ggml-metal/ggml-metal.m | 733 ++++++++++++++++++++------ ggml/src/ggml-metal/ggml-metal.metal | 373 +++++++------ ggml/src/ggml.c | 4 +- 4 files changed, 822 insertions(+), 331 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index 8721b272de6..ea8cf45863a 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -299,21 +299,42 @@ typedef struct { } ggml_metal_kargs_mul_mv_ext; typedef struct { - int32_t nei0; - int32_t nei1; - uint64_t nbi1; + int32_t ne10; + int32_t ne11; // n_expert_used (bcast) + uint64_t nb11; + uint64_t nb12; + int32_t neh11; // n_tokens + uint64_t nbh11; + int32_t ne20; // n_expert_used + uint64_t nb21; +} ggml_metal_kargs_mul_mm_id_map0; + +typedef struct { + int32_t ne20; // n_expert_used + int32_t neh0; + int32_t neh1; + uint64_t nbh1; + uint64_t nbh2; + int32_t ne0; + uint64_t nb1; + uint64_t nb2; +} ggml_metal_kargs_mul_mm_id_map1; + +typedef struct { int32_t ne00; int32_t ne02; uint64_t nb01; uint64_t nb02; - int32_t ne11; - int32_t ne12; - int32_t ne13; - uint64_t nb10; - uint64_t nb11; - uint64_t nb12; - int32_t ne0; - int32_t ne1; + uint64_t nb03; + int32_t neh12; + uint64_t nbh10; + uint64_t nbh11; + uint64_t nbh12; + uint64_t nbh13; + int32_t neh0; + int32_t neh1; + int16_t r2; + int16_t r3; } ggml_metal_kargs_mul_mm_id; typedef struct { diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index 266d8af4693..943f0fe722a 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -44,8 +44,8 @@ // note: assumes single GPU device - the default one // TODO: support multiple GPU devices static struct ggml_backend_metal_device_context { - id mtl_device; - int mtl_device_ref_count; + id mtl_device; + int mtl_device_ref_count; id mtl_library; bool has_simdgroup_reduction; @@ -306,28 +306,30 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16, GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, @@ -490,7 +492,264 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte GGML_METAL_KERNEL_TYPE_COUNT }; +// +// ggml_metal_heap +// + +struct ggml_metal_heap { + // number of times the heap was unused + int n_unused; + + // total number of buffer allocations in this heap across all computes + int64_t n_alloc; + + // current offset in the heap - we reset this after each node in order to reuse the memory + size_t offs; + + // the currently allocated MTLBuffer objects in this heap + id obj; + + NSMutableArray * bufs; +}; + +static struct ggml_metal_heap * ggml_metal_heap_init(id device, size_t size) { + struct ggml_metal_heap * heap = calloc(1, sizeof(struct ggml_metal_heap)); + + MTLHeapDescriptor * desc = [[MTLHeapDescriptor alloc] init]; + desc.storageMode = MTLStorageModePrivate; + desc.cpuCacheMode = MTLCPUCacheModeDefaultCache; + desc.type = MTLHeapTypePlacement; + desc.size = size; + + heap->n_unused = 0; + heap->n_alloc = 0; + + heap->obj = [device newHeapWithDescriptor:desc]; + if (!heap->obj) { + GGML_LOG_ERROR("%s: error: failed to create MTLHeap with size %zu\n", __func__, size); + + free(heap); + + return false; + } + + [desc release]; + + heap->bufs = [[NSMutableArray alloc] init]; + + return heap; +} + +static void ggml_metal_heap_reset(struct ggml_metal_heap * heap) { + heap->offs = 0; + + // count how many graph computes the heap ended up being unused + if ([heap->bufs count] > 0) { + heap->n_unused = 0; + } else { + heap->n_unused++; + } + + for (id buf in heap->bufs) { + [buf release]; + } + [heap->bufs removeAllObjects]; + + // tell the OS that it can reuse this memory if needed + // ref: https://developer.apple.com/documentation/metal/mtlpurgeablestate?language=objc + [heap->obj setPurgeableState:MTLPurgeableStateVolatile]; +} + +static void ggml_metal_heap_free(struct ggml_metal_heap * heap) { + if (heap == nil) { + return; + } + + ggml_metal_heap_reset(heap); + + [heap->obj release]; + [heap->bufs release]; + + free(heap); +} + +@interface ggml_metal_heap_ptr : NSObject + +@property (nonatomic, assign) struct ggml_metal_heap * data; + +@end + +@implementation ggml_metal_heap_ptr +@end + +// +// ggml_metal_mem_pool +// + +struct ggml_metal_mem_pool { + id device; + + int n_heaps; // total number of heaps ever created (including those that were removed) + + NSMutableArray * heaps; + NSMutableArray * heaps_to_remove; +}; + +static struct ggml_metal_mem_pool * ggml_metal_mem_pool_init(void) { + struct ggml_metal_mem_pool * mem_pool = calloc(1, sizeof(struct ggml_metal_mem_pool)); + + mem_pool->n_heaps = 0; + + mem_pool->heaps = [[NSMutableArray alloc] init]; + mem_pool->heaps_to_remove = [[NSMutableArray alloc] init]; + + return mem_pool; +} + +static void ggml_metal_mem_pool_free(struct ggml_metal_mem_pool * mem_pool) { + GGML_LOG_DEBUG("%s: freeing memory pool, num heaps = %zu (total = %d)\n", __func__, [mem_pool->heaps count], mem_pool->n_heaps); + + size_t size_all = 0; + size_t size_cur = 0; + + for (ggml_metal_heap_ptr * ptr in mem_pool->heaps) { + GGML_LOG_DEBUG("%s: heap: %p\n", __func__, (void *) ptr.data); + GGML_LOG_DEBUG("%s: n_alloc: %" PRId64 "\n", __func__, ptr.data->n_alloc); + GGML_LOG_DEBUG("%s: n_unused: %d\n", __func__, ptr.data->n_unused); + GGML_LOG_DEBUG("%s: size: %.2f MiB\n", __func__, [ptr.data->obj size] / 1024.0 / 1024.0); + GGML_LOG_DEBUG("%s: bufs: %zu\n", __func__, [ptr.data->bufs count]); + + if ([ptr.data->bufs count] > 0) { + size_cur += [ptr.data->obj size]; + } + size_all += [ptr.data->obj size]; + + ggml_metal_heap_free(ptr.data); + [ptr release]; + } + [mem_pool->heaps release]; + [mem_pool->heaps_to_remove release]; + + if (size_all > 0) { + GGML_LOG_DEBUG("%s: size_all: %.2f MiB\n", __func__, size_all / 1024.0 / 1024.0); + GGML_LOG_DEBUG("%s: size_cur: %.2f MiB\n", __func__, size_cur / 1024.0 / 1024.0); + } + + free(mem_pool); +} + +static void ggml_metal_mem_pool_reset(struct ggml_metal_mem_pool * mem_pool) { + for (NSUInteger i = 0; i < [mem_pool->heaps count]; i++) { + ggml_metal_heap_ptr * ptr = [mem_pool->heaps objectAtIndex:i]; + + struct ggml_metal_heap * heap = ptr.data; + ggml_metal_heap_reset(heap); + + // if the heap hasn't been used for a while, remove it + if (heap->n_unused >= 128) { + [mem_pool->heaps_to_remove addObject:@(i)]; + } + } + + if (mem_pool->heaps_to_remove.count > 0) { + // remove in reverse order + for (NSUInteger i = [mem_pool->heaps_to_remove count] - 1; ; --i) { + NSUInteger index = [[mem_pool->heaps_to_remove objectAtIndex:i] intValue]; + ggml_metal_heap_ptr * ptr = [mem_pool->heaps objectAtIndex:index]; + + struct ggml_metal_heap * heap = ptr.data; + ggml_metal_heap_free(heap); + + [mem_pool->heaps removeObjectAtIndex:index]; + [ptr release]; + + if (i == 0) { + break; + } + } + + [mem_pool->heaps_to_remove removeAllObjects]; + } +} + +static void ggml_metal_mem_pool_clear(struct ggml_metal_mem_pool * mem_pool) { + for (ggml_metal_heap_ptr * ptr in mem_pool->heaps) { + ptr.data->offs = 0; + } +} + +static id ggml_metal_mem_pool_alloc(struct ggml_metal_mem_pool * mem_pool, size_t size) { + const size_t alignment = 256; + + const size_t size_aligned = GGML_PAD(size, alignment); + + // try one of the existing heaps + for (ggml_metal_heap_ptr * ptr in mem_pool->heaps) { + struct ggml_metal_heap * heap = ptr.data; + if (heap->offs + size_aligned <= [heap->obj size]) { + // if this is the first buffer in the heap for the current command buffer, tell the OS that + // it cannot free the memory used by the heap + // ref: https://developer.apple.com/documentation/metal/mtlpurgeablestate?language=objc + if ([heap->bufs count] == 0) { + [heap->obj setPurgeableState:MTLPurgeableStateNonVolatile]; + } + + id buf = [heap->obj newBufferWithLength:size_aligned options:MTLResourceStorageModePrivate offset:heap->offs]; + if (buf == nil) { + GGML_LOG_ERROR("%s: error: failed to create MTLBuffer with size %zu\n", __func__, size_aligned); + return nil; + } + + heap->n_alloc++; + heap->offs += size_aligned; + + [heap->bufs addObject:buf]; + + return buf; + } + } + + // create a new heap that can fit this buffer + ggml_metal_heap_ptr * heap_ptr = [ggml_metal_heap_ptr new]; + + struct ggml_metal_heap * heap = ggml_metal_heap_init(mem_pool->device, size_aligned); + if (heap == NULL) { + GGML_LOG_ERROR("%s: error: failed to create heap of size %zu\n", __func__, size_aligned); + return NULL; + } + + //GGML_LOG_DEBUG("%s: creating new heap of size %zu, got %zu\n", __func__, size_aligned, [heap->obj size]); + + heap_ptr.data = heap; + ggml_metal_heap_reset(heap); + + [heap->obj setPurgeableState:MTLPurgeableStateNonVolatile]; + id buf = [heap->obj newBufferWithLength:size_aligned options:MTLResourceStorageModePrivate offset:heap->offs]; + if (buf == nil) { + GGML_LOG_ERROR("%s: error: failed to create MTLBuffer with size %zu\n", __func__, size_aligned); + return NULL; + } + + heap->n_alloc++; + heap->offs += size_aligned; + + [heap->bufs addObject:buf]; + + [mem_pool->heaps addObject:heap_ptr]; + mem_pool->n_heaps++; + + return buf; +} + +struct ggml_metal_command_buffer { + id obj; + + // each command buffer has a memory pool from which it can allocate temporary buffers during the compute + struct ggml_metal_mem_pool * mem_pool; +}; + struct ggml_backend_metal_context { + id device; id queue; dispatch_queue_t d_queue; @@ -515,7 +774,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte void (^encode_async)(size_t ith); // n_cb command buffers + 1 used by the main thread - id command_buffers[GGML_METAL_MAX_COMMAND_BUFFERS + 1]; + struct ggml_metal_command_buffer cmd_bufs[GGML_METAL_MAX_COMMAND_BUFFERS + 1]; // abort ggml_metal_graph_compute if callback returns true ggml_abort_callback abort_callback; @@ -705,9 +964,11 @@ @implementation GGMLMetalClass struct ggml_backend_metal_device_context * ctx_dev = dev->context; id device = ggml_backend_metal_device_acq(ctx_dev); + GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]); - ctx->queue = [device newCommandQueue]; + ctx->device = device; + ctx->queue = [device newCommandQueue]; if (ctx->queue == nil) { GGML_LOG_ERROR("%s: error: failed to create command queue\n", __func__); return NULL; @@ -768,7 +1029,10 @@ @implementation GGMLMetalClass ctx->gf = nil; ctx->encode_async = nil; for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) { - ctx->command_buffers[i] = nil; + ctx->cmd_bufs[i].obj = nil; + + ctx->cmd_bufs[i].mem_pool = ggml_metal_mem_pool_init(); + ctx->cmd_bufs[i].mem_pool->device = device; } #if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15) @@ -985,28 +1249,30 @@ @implementation GGMLMetalClass GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32, mul_mm_id_bf16_f32, has_simdgroup_mm && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16, mul_mm_id_map0_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32, mul_mm_id_map1_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16, mul_mm_id_f32_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16, mul_mm_id_f16_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16, mul_mm_id_bf16_f16, has_simdgroup_mm && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F16, mul_mm_id_q4_0_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F16, mul_mm_id_q4_1_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16, mul_mm_id_q5_0_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16, mul_mm_id_q5_1_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16, mul_mm_id_q8_0_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16, mul_mm_id_q2_K_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16, mul_mm_id_q3_K_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16, mul_mm_id_q4_K_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F16, mul_mm_id_q5_K_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F16, mul_mm_id_q6_K_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F16, mul_mm_id_iq2_xxs_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F16, mul_mm_id_iq2_xs_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F16, mul_mm_id_iq3_xxs_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F16, mul_mm_id_iq3_s_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F16, mul_mm_id_iq2_s_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F16, mul_mm_id_iq1_s_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F16, mul_mm_id_iq1_m_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F16, mul_mm_id_iq4_nl_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16, mul_mm_id_iq4_xs_f16, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true); @@ -1181,6 +1447,12 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) { [ctx->queue release]; + for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) { + // ctx->cmd_bufs[i].obj is auto released + + ggml_metal_mem_pool_free(ctx->cmd_bufs[i].mem_pool); + } + dispatch_release(ctx->d_queue); free(ctx); @@ -1486,10 +1758,11 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex } } -static void ggml_metal_encode_node( +static bool ggml_metal_encode_node( ggml_backend_t backend, int idx, - id encoder) { + id encoder, + struct ggml_metal_mem_pool * mem_pool) { struct ggml_backend_metal_context * ctx = backend->context; struct ggml_backend_metal_device_context * ctx_dev = backend->device->context; @@ -1505,7 +1778,7 @@ static void ggml_metal_encode_node( struct ggml_tensor * dst = node; if (ggml_is_empty(dst)) { - return; + return true; } switch (dst->op) { @@ -1516,7 +1789,7 @@ static void ggml_metal_encode_node( case GGML_OP_PERMUTE: { // noop -> next node - } return; + } return true; default: { } break; @@ -1527,6 +1800,8 @@ static void ggml_metal_encode_node( GGML_ABORT("unsupported op"); } + ggml_metal_mem_pool_clear(mem_pool); + const int64_t ne00 = src0 ? src0->ne[0] : 0; const int64_t ne01 = src0 ? src0->ne[1] : 0; const int64_t ne02 = src0 ? src0->ne[2] : 0; @@ -2173,26 +2448,76 @@ static void ggml_metal_encode_node( const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); - ggml_metal_kargs_soft_max args = { +// use this branch to test the ggml_metal_mem_pool functionality +#if 0 + // cpy to tmp buffer in MTLHeap + + id h_src0 = h_src0 = ggml_metal_mem_pool_alloc(mem_pool, ggml_nbytes(src0)); + if (!h_src0) { + GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, ggml_nbytes(src0)); + return false; + } + + offs_src0 = 0; + + ggml_metal_kargs_cpy args_cpy = { /*.ne00 =*/ ne00, /*.ne01 =*/ ne01, /*.ne02 =*/ ne02, - /*.scale =*/ scale, - /*.max_bias =*/ max_bias, - /*.m0 =*/ m0, - /*.m1 =*/ m1, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne00, + /*.ne1 =*/ ne01, + /*.ne2 =*/ ne02, + /*.ne3 =*/ ne03, + /*.nb0 =*/ nb00, + /*.nb1 =*/ nb01, + /*.nb2 =*/ nb02, + /*.nb3 =*/ nb03, + }; + + if (src0->type == GGML_TYPE_F16) { + [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline]; + } else { + [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline]; + } + [encoder setBytes:&args_cpy length:sizeof(args_cpy) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:h_src0 offset:0 atIndex:2]; + + GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0); + int nth_cpy = MIN(1024, ne00 / ggml_blck_size(src0->type)); + + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth_cpy, 1, 1)]; + +#else + id h_src0 = id_src0; +#endif + // softmax + + ggml_metal_kargs_soft_max args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.scale =*/ scale, + /*.max_bias =*/ max_bias, + /*.m0 =*/ m0, + /*.m1 =*/ m1, /*.n_head_log2 =*/ n_head_log2, }; [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:h_src0 offset:offs_src0 atIndex:0]; if (id_src1) { - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; } else { - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:h_src0 offset:offs_src0 atIndex:1]; } - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&args length:sizeof(args) atIndex:3]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&args length:sizeof(args) atIndex:3]; [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; @@ -2683,7 +3008,7 @@ static void ggml_metal_encode_node( [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; [encoder setThreadgroupMemoryLength:8192 atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; } else { id pipeline = nil; @@ -2903,8 +3228,6 @@ static void ggml_metal_encode_node( } break; case GGML_OP_MUL_MAT_ID: { - const int n_as = src0->ne[2]; - // src2 = ids const enum ggml_type src2t = src2->type; GGML_UNUSED(src2t); @@ -2918,24 +3241,21 @@ static void ggml_metal_encode_node( GGML_ASSERT(ne03 == 1); GGML_ASSERT(ne13 == 1); + const uint32_t r2 = 1; + const uint32_t r3 = 1; + // find the break-even point where the matrix-matrix kernel becomes more efficient compared // to the matrix-vector kernel // ne20 = n_used_experts - // ne21 = n_rows - const int dst_rows = ne20*ne21; - const int dst_rows_min = n_as; - const int dst_rows_max = (device.maxThreadgroupMemoryLength/2 - 8192)/4; - - // max size of the rowids array in the kernel shared buffer - //GGML_ASSERT(dst_rows <= dst_rows_max); + // ne21 = n_rows (batch size) + const int ne21_mm_id_min = 32; // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel if ([device supportsFamily:MTLGPUFamilyApple7] && ne00 % 32 == 0 && ne00 >= 64 && - //ne01 / ne02 >= 512 && // NOTE: this is based on Mixtral shapes, might need adjustments - dst_rows > dst_rows_min && - dst_rows <= dst_rows_max) { + (ne21 >= ne21_mm_id_min)) { + GGML_ASSERT(ne00 % 4 == 0); // some Metal matrix data types require aligned pointers // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5) @@ -2946,62 +3266,169 @@ static void ggml_metal_encode_node( default: break; } - id pipeline = nil; + const int64_t neh10 = ne10; // n_embd + const int64_t neh11 = ne21; // n_tokens + const int64_t neh12 = ne02; // n_expert - switch (src0->type) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32 ].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32 ].pipeline; break; - case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32 ].pipeline; break; - case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32 ].pipeline; break; - case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32 ].pipeline; break; - case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32 ].pipeline; break; - case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32 ].pipeline; break; - case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32 ].pipeline; break; - case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32 ].pipeline; break; - case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32 ].pipeline; break; - case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32 ].pipeline; break; - case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32 ].pipeline; break; - case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32 ].pipeline; break; - case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break; - case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break; - case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32].pipeline; break; - case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32 ].pipeline; break; - case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32 ].pipeline; break; - case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32 ].pipeline; break; - case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32 ].pipeline; break; - case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break; - case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32 ].pipeline; break; - default: GGML_ABORT("MUL_MAT_ID not implemented"); + const uint64_t nbh10 = ggml_type_size(GGML_TYPE_F16); + const uint64_t nbh11 = nbh10*neh10; + const uint64_t nbh12 = nbh11*neh11; + const uint64_t nbh13 = nbh12*neh12; + + const size_t s_src1 = ggml_type_size(GGML_TYPE_F16)*neh10*neh11*neh12; + id h_src1 = ggml_metal_mem_pool_alloc(mem_pool, s_src1); + if (!h_src1) { + GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_src1); + return false; } - ggml_metal_kargs_mul_mm_id args = { - /*.nei0 =*/ ne20, - /*.nei1 =*/ ne21, - /*.nbi1 =*/ nb21, - /*.ne00 =*/ ne00, - /*.ne02 =*/ ne02, - /*.nb01 =*/ nb01, - /*.nb02 =*/ nb02, - /*.ne11 =*/ ne11, - /*.ne12 =*/ ne12, - /*.ne13 =*/ ne13, - /*.nb10 =*/ nb10, - /*.nb11 =*/ nb11, - /*.nb12 =*/ nb12, - /*.ne0 =*/ ne0, - /*.ne1 =*/ ne1, - }; + const int64_t neh0 = ne0; + const int64_t neh1 = ne21; + const int64_t neh2 = ne02; - [encoder setComputePipelineState:pipeline]; - [encoder setBytes:&args length:sizeof(args) atIndex:0]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:4]; + const uint64_t nbh0 = ggml_type_size(GGML_TYPE_F32); + const uint64_t nbh1 = nbh0*neh0; + const uint64_t nbh2 = nbh1*neh1; + //const uint64_t nbh3 = nbh2*neh2; + + const size_t s_dst = ggml_type_size(GGML_TYPE_F32)*neh0*neh1*neh2; + id h_dst = ggml_metal_mem_pool_alloc(mem_pool, s_dst); + if (!h_dst) { + GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_dst); + return false; + } - [encoder setThreadgroupMemoryLength:GGML_PAD(8192 + dst_rows*4/*sizeof(ushort2)*/, 16) atIndex:0]; + // tokens per expert + const size_t s_tpe = ggml_type_size(GGML_TYPE_I32)*ne02; + id h_tpe = ggml_metal_mem_pool_alloc(mem_pool, s_tpe); + if (!h_tpe) { + GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_tpe); + return false; + } + + // id map + // [n_expert_used, n_tokens] + const size_t s_ids = ggml_type_size(GGML_TYPE_I32)*ne20*ne21; + id h_ids = ggml_metal_mem_pool_alloc(mem_pool, s_ids); + if (!h_ids) { + GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_ids); + return false; + } + + { + const int nth = MIN(1024, ne10/4); + + ggml_metal_kargs_mul_mm_id_map0 args = { + ne10, + ne11, // n_expert_used (bcast) + nb11, + nb12, + neh11, // n_tokens + nbh11, + ne20, // n_expert_used + nb21, + }; + + id pipeline = nil; + + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; + [encoder setBuffer: h_src1 offset:0 atIndex:3]; + [encoder setBuffer: h_tpe offset:0 atIndex:4]; + [encoder setBuffer: h_ids offset:0 atIndex:5]; + + [encoder dispatchThreadgroups:MTLSizeMake(ne02, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } + + { + id pipeline = nil; + + switch (src0->type) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16 ].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16 ].pipeline; break; + case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16 ].pipeline; break; + case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F16 ].pipeline; break; + case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F16 ].pipeline; break; + case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16 ].pipeline; break; + case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16 ].pipeline; break; + case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16 ].pipeline; break; + case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16 ].pipeline; break; + case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16 ].pipeline; break; + case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16 ].pipeline; break; + case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F16 ].pipeline; break; + case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F16 ].pipeline; break; + case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F16].pipeline; break; + case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F16 ].pipeline; break; + case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F16].pipeline; break; + case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F16 ].pipeline; break; + case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F16 ].pipeline; break; + case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F16 ].pipeline; break; + case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F16 ].pipeline; break; + case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F16 ].pipeline; break; + case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16 ].pipeline; break; + default: GGML_ABORT("MUL_MAT_ID not implemented"); + } + + ggml_metal_kargs_mul_mm_id args = { + /*.ne00 =*/ ne00, + /*.ne02 =*/ ne02, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.neh12 =*/ neh12, + /*.nbh10 =*/ nbh10, + /*.nbh11 =*/ nbh11, + /*.nbh12 =*/ nbh12, + /*.nbh13 =*/ nbh13, + /*.neh0 =*/ neh0, + /*.neh1 =*/ neh1, + /*.r2 =*/ r2, + /*.r3 =*/ r3, + }; + + [encoder setComputePipelineState:pipeline]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer: h_src1 offset:0 atIndex:2]; + [encoder setBuffer: h_tpe offset:0 atIndex:3]; + [encoder setBuffer: h_dst offset:0 atIndex:4]; + + [encoder setThreadgroupMemoryLength:8192 atIndex:0]; + [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, ne02) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; + } + + { + GGML_ASSERT(ne0 % 4 == 0); + + const int nth = MIN(1024, ne0/4); - [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, n_as) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; + ggml_metal_kargs_mul_mm_id_map1 args = { + ne20, // n_expert_used + neh0, + neh1, + nbh1, + nbh2, + ne0, + nb1, + nb2, + }; + + id pipeline = nil; + + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer: h_dst offset:0 atIndex:1]; + [encoder setBuffer: h_ids offset:0 atIndex:2]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; + + [encoder dispatchThreadgroups:MTLSizeMake(ne20, ne21, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } } else { id pipeline = nil; @@ -3195,7 +3622,7 @@ static void ggml_metal_encode_node( [encoder setBuffer:id_src2 offset:offs_src2 atIndex:4]; const int64_t _ne1 = 1; - const int64_t ne123 = dst_rows; + const int64_t ne123 = ne20*ne21; if (smem > 0) { [encoder setThreadgroupMemoryLength:smem atIndex:0]; @@ -4601,6 +5028,8 @@ static void ggml_metal_encode_node( GGML_ABORT("fatal error"); } } + + return true; } static enum ggml_status ggml_metal_graph_compute( @@ -4654,25 +5083,25 @@ static enum ggml_status ggml_metal_graph_compute( } // the main thread commits the first few commands immediately - // command_buffer[n_cb] + // cmd_buf[n_cb] { - id command_buffer = [ctx->queue commandBufferWithUnretainedReferences]; - ctx->command_buffers[n_cb] = command_buffer; + id cmd_buf = [ctx->queue commandBufferWithUnretainedReferences]; + ctx->cmd_bufs[n_cb].obj = cmd_buf; - [command_buffer enqueue]; + [cmd_buf enqueue]; ctx->encode_async(n_cb); } // prepare the rest of the command buffers asynchronously - // command_buffer[0.. n_cb) + // cmd_buf[0.. n_cb) for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) { - id command_buffer = [ctx->queue commandBufferWithUnretainedReferences]; - ctx->command_buffers[cb_idx] = command_buffer; + id cmd_buf = [ctx->queue commandBufferWithUnretainedReferences]; + ctx->cmd_bufs[cb_idx].obj = cmd_buf; // always enqueue the first two command buffers // enqueue all of the command buffers if we don't need to abort if (cb_idx < 2 || ctx->abort_callback == NULL) { - [command_buffer enqueue]; + [cmd_buf enqueue]; } } @@ -4681,14 +5110,14 @@ static enum ggml_status ggml_metal_graph_compute( // wait for completion and check status of each command buffer // needed to detect if the device ran out-of-memory for example (#1881) { - id command_buffer = ctx->command_buffers[n_cb]; - [command_buffer waitUntilCompleted]; + id cmd_buf = ctx->cmd_bufs[n_cb].obj; + [cmd_buf waitUntilCompleted]; - MTLCommandBufferStatus status = [command_buffer status]; + MTLCommandBufferStatus status = [cmd_buf status]; if (status != MTLCommandBufferStatusCompleted) { GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, n_cb, status); if (status == MTLCommandBufferStatusError) { - GGML_LOG_INFO("error: %s\n", [[command_buffer error].localizedDescription UTF8String]); + GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]); } return GGML_STATUS_FAILED; @@ -4696,20 +5125,20 @@ static enum ggml_status ggml_metal_graph_compute( } for (int i = 0; i < n_cb; ++i) { - id command_buffer = ctx->command_buffers[i]; - [command_buffer waitUntilCompleted]; + id cmd_buf = ctx->cmd_bufs[i].obj; + [cmd_buf waitUntilCompleted]; - MTLCommandBufferStatus status = [command_buffer status]; + MTLCommandBufferStatus status = [cmd_buf status]; if (status != MTLCommandBufferStatusCompleted) { GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status); if (status == MTLCommandBufferStatusError) { - GGML_LOG_INFO("error: %s\n", [[command_buffer error].localizedDescription UTF8String]); + GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]); } return GGML_STATUS_FAILED; } - id next_buffer = (i + 1 < n_cb ? ctx->command_buffers[i + 1] : nil); + id next_buffer = (i + 1 < n_cb ? ctx->cmd_bufs[i + 1].obj : nil); if (!next_buffer) { continue; } @@ -5092,8 +5521,9 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) { const int n_nodes_per_cb = ctx->n_nodes_per_cb; - id command_buffer = ctx->command_buffers[cb_idx]; - id encoder = [command_buffer computeCommandEncoder]; + id cmd_buf = ctx->cmd_bufs[cb_idx].obj; + + id encoder = [cmd_buf computeCommandEncoder]; int node_start = 0; int node_end = n_nodes_0; @@ -5105,22 +5535,29 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) { const bool should_capture = ctx->capture_next_compute; + struct ggml_metal_mem_pool * mem_pool = ctx->cmd_bufs[cb_idx].mem_pool; + ggml_metal_mem_pool_reset(mem_pool); + for (int idx = node_start; idx < node_end; ++idx) { if (should_capture) { [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]]; } - ggml_metal_encode_node(backend, idx, encoder); + const bool res = ggml_metal_encode_node(backend, idx, encoder, mem_pool); if (should_capture) { [encoder popDebugGroup]; } + + if (!res) { + break; + } } [encoder endEncoding]; if (cb_idx < 2 || ctx->abort_callback == NULL) { - [command_buffer commit]; + [cmd_buf commit]; } }); } diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 9f4147e9397..a808e79c227 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -6336,127 +6336,219 @@ kernel void kernel_mul_mm( } } -// same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in rowids -// TODO: this kernel needs to be reimplemented from scratch for better performance -template -void kernel_mul_mm_id_impl( - int32_t ne00, - int32_t ne02, - uint64_t nb01, - uint64_t nb02, - int32_t ne11, - int32_t ne12, - uint64_t nb10, - uint64_t nb11, - uint64_t nb12, - int32_t ne0, - int32_t ne1, - int64_t ne0ne1, - device const char * src0, - device const char * src1, - threadgroup ushort2 * rowids, - device char * dst, - threadgroup char * shmem, +template +kernel void kernel_mul_mm_id_map0( + constant ggml_metal_kargs_mul_mm_id_map0 & args, + device const char * src1, + device const char * src2, + device char * hsrc1, + device char * htpe, + device char * hids, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int ide = tgpig[0]; // expert id + + int n_all = 0; + + device int32_t * ids_i32 = (device int32_t *) (hids); + + for (int i21 = 0; i21 < args.neh11; i21++) { // n_tokens + device const int32_t * src2_i32 = (device const int32_t *) (src2 + i21*args.nb21); + + for (int i20 = 0; i20 < args.ne20; i20++) { // n_expert_used + if (src2_i32[i20] != ide) { + continue; + } + + device const float4 * src1_f32x4 = (device const float4 *) ( src1 + i21*args.nb12 + (i20%args.ne11)*args.nb11); + device T4 * hsrc1_f32x4 = (device T4 *) (hsrc1 + (ide*args.neh11 + n_all)*args.nbh11); + + for (int64_t i00 = tpitg.x; i00 < args.ne10/4; i00 += ntg.x) { + hsrc1_f32x4[i00] = (T4) (src1_f32x4[i00]); + } + + if (tpitg.x == 0) { + ids_i32[i21*args.ne20 + i20] = ide*args.neh11 + n_all; + } + + ++n_all; + } + } + + if (tpitg.x == 0) { + device int32_t * tpe_i32 = (device int32_t *) (htpe); + tpe_i32[ide] = n_all; + } +} + +typedef decltype(kernel_mul_mm_id_map0) kernel_mul_mm_id_map0_t; + +template [[host_name("kernel_mul_mm_id_map0_f16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0; + +template +kernel void kernel_mul_mm_id_map1( + constant ggml_metal_kargs_mul_mm_id_map1 & args, + device const char * hdst, + device const char * hids, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int i20 = tgpig[0]; // used expert + const int i21 = tgpig[1]; // token + + device const int32_t * ids_i32 = (device const int32_t *) (hids); + device float4 * dst_f32x4 = (device float4 *) (dst + i20*args.nb1 + i21*args.nb2); + + const int id = ids_i32[i21*args.ne20 + i20]; + + const int ide = id / args.neh1; + const int idt = id % args.neh1; + + device const float4 * hdst_f32x4 = (device const float4 *) (hdst + idt*args.nbh1 + ide*args.nbh2); + + for (int64_t i0 = tpitg.x; i0 < args.neh0/4; i0 += ntg.x) { + dst_f32x4[i0] = hdst_f32x4[i0]; + } +} + +typedef decltype(kernel_mul_mm_id_map1) kernel_mul_mm_id_map1_t; + +template [[host_name("kernel_mul_mm_id_map1_f32")]] kernel kernel_mul_mm_id_map1_t kernel_mul_mm_id_map1; + +template +kernel void kernel_mul_mm_id( + constant ggml_metal_kargs_mul_mm_id & args, + device const char * src0, + device const char * src1, + device const char * tpe, + device char * dst, + threadgroup char * shmem [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], ushort tiitg[[thread_index_in_threadgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - threadgroup half * sa = (threadgroup half *)(shmem); - threadgroup float * sb = (threadgroup float *)(shmem + 4096); + threadgroup T * sa = (threadgroup T *)(shmem); + threadgroup half * sb = (threadgroup half *)(shmem + 4096); const int r0 = tgpig.y; const int r1 = tgpig.x; + const int im = tgpig.z; + + device const int32_t * tpe_i32 = (device const int32_t *) (tpe); + + const int neh1 = tpe_i32[im]; - if (r1*BLOCK_SIZE_N >= ne1) return; + if (r1*BLOCK_SIZE_N >= neh1) { + return; + } // if this block is of 64x32 shape or smaller - short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M; - short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N; + const short n_rows = (args.neh0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (args.neh0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M; + const short n_cols = ( neh1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? ( neh1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N; // a thread shouldn't load data outside of the matrix - short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1; - short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1; + const short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1; + const short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1; - simdgroup_half8x8 ma[4]; - simdgroup_float8x8 mb[2]; + simdgroup_T8x8 ma[4]; + simdgroup_half8x8 mb[2]; simdgroup_float8x8 mc[8]; - for (int i = 0; i < 8; i++){ + + for (short i = 0; i < 8; i++){ mc[i] = make_filled_simdgroup_matrix(0.f); } + short il = (tiitg % THREAD_PER_ROW); - ushort offset1 = il/nl; + const int i12 = im%args.neh12; + const int i13 = im/args.neh12; - threadgroup const auto & id = rowids[r1 * BLOCK_SIZE_N + thread_col]; + const uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const short offset1 = il/nl; - device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01) + offset1; - device const float * y = (device const float *)(src1 - + nb12 * id[1] - + nb11 * (id[0] % ne11) - + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL))); + device const block_q * x = (device const block_q *)(src0 + + args.nb01*(r0*BLOCK_SIZE_M + thread_row) + offset0) + offset1; - for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) { + device const half * y = (device const half *)(src1 + + args.nbh13*i13 + + args.nbh12*i12 + + args.nbh11*(r1*BLOCK_SIZE_N + thread_col) + + args.nbh10*(BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL))); + + for (int loop_k = 0; loop_k < args.ne00; loop_k += BLOCK_SIZE_K) { // load data and store to threadgroup memory - half4x4 temp_a; + T4x4 temp_a; dequantize_func(x, il, temp_a); + threadgroup_barrier(mem_flags::mem_threadgroup); - for (int i = 0; i < 16; i++) { - *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \ - + (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \ - + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4]; + #pragma unroll(16) + for (short i = 0; i < 16; i++) { + *(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \ + + (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \ + + (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = temp_a[i/4][i%4]; } - *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y); + *(threadgroup half2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = *((device half2x4 *) y); il = (il + 2 < nl) ? il + 2 : il % 2; - x = (il < 2) ? x + (2+nl-1)/nl : x; + x = (il < 2) ? x + (2 + nl - 1)/nl : x; y += BLOCK_SIZE_K; threadgroup_barrier(mem_flags::mem_threadgroup); // load matrices from threadgroup memory and conduct outer products - threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2)); - threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2)); + threadgroup const T * lsma = (sa + THREAD_MAT_M*SG_MAT_SIZE*(sgitg%2)); + threadgroup const half * lsmb = (sb + THREAD_MAT_N*SG_MAT_SIZE*(sgitg/2)); - #pragma unroll(BLOCK_SIZE_K/8) - for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) { + #pragma unroll(4) + for (short ik = 0; ik < BLOCK_SIZE_K/8; ik++) { #pragma unroll(4) - for (int i = 0; i < 4; i++) { + for (short i = 0; i < 4; i++) { simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i); } + simdgroup_barrier(mem_flags::mem_none); + #pragma unroll(2) - for (int i = 0; i < 2; i++) { + for (short i = 0; i < 2; i++) { simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i); } - lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE; - lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE; - #pragma unroll(8) - for (int i = 0; i < 8; i++){ + for (short i = 0; i < 8; i++){ simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]); } + + lsma += (BLOCK_SIZE_M/SG_MAT_ROW)*SG_MAT_SIZE; + lsmb += (BLOCK_SIZE_N/SG_MAT_ROW)*SG_MAT_SIZE; } } - { + if ((r0 + 1) * BLOCK_SIZE_M <= args.neh0 && (r1 + 1) * BLOCK_SIZE_N <= neh1) { + device float * C = (device float *) dst + + (BLOCK_SIZE_M * r0 + 32*(sgitg & 1)) + \ + (BLOCK_SIZE_N * r1 + 16*(sgitg >> 1)) * args.neh0 + im*args.neh1*args.neh0; + + for (short i = 0; i < 8; i++) { + simdgroup_store(mc[i], C + 8 * (i%4) + 8 * args.neh0 * (i/4), args.neh0); + } + } else { + // block is smaller than 64x32, we should avoid writing data outside of the matrix threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup float * temp_str = ((threadgroup float *) shmem) \ - + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M; - for (int i = 0; i < 8; i++) { - simdgroup_store(mc[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M); + + 32*(sgitg&1) + (16*(sgitg >> 1))*BLOCK_SIZE_M; + for (short i = 0; i < 8; i++) { + simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*BLOCK_SIZE_M*(i/4), BLOCK_SIZE_M); } threadgroup_barrier(mem_flags::mem_threadgroup); if (sgitg == 0) { for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) { - threadgroup const auto & jid = rowids[r1 * BLOCK_SIZE_N + j]; - int64_t joff = jid[0]*ne0 + jid[1]*ne0ne1; - - device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + joff; + device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + (r1*BLOCK_SIZE_N + j)*args.neh0 + im*args.neh1*args.neh0; device float4 * D4 = (device float4 *) D; threadgroup float * C = temp_str + (j*BLOCK_SIZE_M); @@ -6476,66 +6568,6 @@ void kernel_mul_mm_id_impl( } } -template -kernel void kernel_mul_mm_id( - constant ggml_metal_kargs_mul_mm_id & args, - device const char * src0s, - device const char * src1, - device char * dst, - device const char * ids, - threadgroup char * shmem [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - ushort tiitg[[thread_index_in_threadgroup]], - ushort sgitg[[simdgroup_index_in_threadgroup]]) { - - const int32_t i02 = tgpig.z; - - tgpig.z = 0; - - device const char * src0 = src0s + i02*args.nb02; - - // row indices - threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shmem + 8192); - - // TODO: parallelize this loop - int32_t _ne1 = 0; - for (ushort ii1 = 0; ii1 < args.nei1; ii1++) { - for (ushort ii0 = 0; ii0 < args.nei0; ii0++) { - int32_t id = ((device int32_t *) (ids + ii1*args.nbi1))[ii0]; - if (id == i02) { - if (tiitg == 0) { - rowids[_ne1] = ushort2(ii0, ii1); - } - _ne1++; - } - } - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - kernel_mul_mm_id_impl( - args.ne00, - args.ne02, - args.nb01, - args.nb02, - args.ne11, - args.ne12, - args.nb10, - args.nb11, - args.nb12, - args.ne0, - _ne1, - (int64_t)args.ne0*args.ne1, - src0, - src1, - rowids, - dst, - shmem, - tgpig, - tiitg, - sgitg); -} - #define QK_NL 16 // @@ -6576,63 +6608,64 @@ template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_q_t kernel_get // matrix-matrix multiplication // -typedef decltype(kernel_mul_mm) mat_mm_t; +typedef decltype(kernel_mul_mm) mul_mm_t; -template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_f32_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_f16_f32")]] kernel mul_mm_t kernel_mul_mm; #if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mul_mm_t kernel_mul_mm; #endif -template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mul_mm_t kernel_mul_mm; // // indirect matrix-matrix multiplication // -typedef decltype(kernel_mul_mm_id) mat_mm_id_t; +typedef decltype(kernel_mul_mm_id) mul_mm_id; -template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_f32_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_f16_f16")]] kernel mul_mm_id kernel_mul_mm_id; #if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_mul_mm_id_bf16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_bf16_f16")]] kernel mul_mm_id kernel_mul_mm_id; #endif -template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_iq3_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_iq1_m_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q4_0_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q4_1_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q5_0_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q5_1_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q8_0_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q2_K_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q3_K_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q4_K_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q5_K_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q6_K_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq2_xxs_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq2_xs_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq3_xxs_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq3_s_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq2_s_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq1_s_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq1_m_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq4_nl_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq4_xs_f16")]] kernel mul_mm_id kernel_mul_mm_id; + // // matrix-vector multiplication diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 01f7e05b287..8a6546240f4 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -2732,11 +2732,11 @@ void ggml_mul_mat_set_prec( c = ggml_mul_mat_id(ctx, as, b, ids); as -> [cols, rows, n_expert] - ids -> [n_experts_used, n_tokens] (i32) b -> [cols, n_expert_used, n_tokens] + ids -> [n_expert_used, n_tokens] (i32) c -> [rows, n_expert_used, n_tokens] - in b, n_experts_used can be broadcasted to match the n_expert_used of ids + in b, n_expert_used can be broadcasted to match the n_expert_used of ids c ~= as[:,:,i] @ b[:,i%r,t], i = ids[e,t] for all e,t in ids */ From 79fb43e252e0da61c278cdfedaaefce6cfb7e965 Mon Sep 17 00:00:00 2001 From: Xuan-Son Nguyen Date: Tue, 13 May 2025 13:10:08 +0300 Subject: [PATCH 18/21] ggml : add mrope kernel for metal (llama/13457) --- ggml/src/ggml-metal/ggml-metal-impl.h | 4 + ggml/src/ggml-metal/ggml-metal.m | 58 +++++++--- ggml/src/ggml-metal/ggml-metal.metal | 146 ++++++++++++++++++++++++++ 3 files changed, 192 insertions(+), 16 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index ea8cf45863a..17eab976f3a 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -207,6 +207,10 @@ typedef struct { float attn_factor; float beta_fast; float beta_slow; + int32_t sect_0; + int32_t sect_1; + int32_t sect_2; + int32_t sect_3; } ggml_metal_kargs_rope; typedef struct { diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index 943f0fe722a..576f9581bda 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -332,6 +332,10 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16, GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, + GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32, + GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16, + GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32, + GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16, GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, GGML_METAL_KERNEL_TYPE_IM2COL_F16, @@ -1275,6 +1279,10 @@ @implementation GGMLMetalClass GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16, mul_mm_id_iq4_xs_f16, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32, rope_multi_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16, rope_multi_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32, rope_vision_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16, rope_vision_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true); @@ -1637,16 +1645,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex case GGML_OP_NORM: return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0])); case GGML_OP_ROPE: - { - const int mode = ((const int32_t *) op->op_params)[2]; - if (mode & GGML_ROPE_TYPE_MROPE) { - return false; - } - if (mode & GGML_ROPE_TYPE_VISION) { - return false; - } - return true; - } + return true; case GGML_OP_IM2COL: return op->src[0]->type == GGML_TYPE_F16; case GGML_OP_POOL_1D: @@ -3826,6 +3825,7 @@ static bool ggml_metal_encode_node( } break; case GGML_OP_ROPE: { + // make sure we have one or more position id(ne10) per token(ne02) GGML_ASSERT(ne10 % ne02 == 0); GGML_ASSERT(ne10 >= ne02); @@ -3852,20 +3852,42 @@ static bool ggml_metal_encode_node( memcpy(&beta_fast, (const int32_t *) dst->op_params + 9, sizeof(float)); memcpy(&beta_slow, (const int32_t *) dst->op_params + 10, sizeof(float)); - const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; + const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; + const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; + const bool is_vision = mode == GGML_ROPE_TYPE_VISION; + + // mrope + const int sect_0 = ((const int32_t *) dst->op_params)[11]; + const int sect_1 = ((const int32_t *) dst->op_params)[12]; + const int sect_2 = ((const int32_t *) dst->op_params)[13]; + const int sect_3 = ((const int32_t *) dst->op_params)[14]; id pipeline = nil; - if (!is_neox) { + if (is_neox) { switch (src0->type) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break; + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break; + default: GGML_ABORT("fatal error"); + }; + } else if (is_mrope && !is_vision) { + GGML_ASSERT(ne10*4 >= ne02); // need at least 4 pos per token + switch (src0->type) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16].pipeline; break; + default: GGML_ABORT("fatal error"); + }; + } else if (is_vision) { + GGML_ASSERT(ne10*4 >= ne02); // need at least 4 pos per token + switch (src0->type) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16].pipeline; break; default: GGML_ABORT("fatal error"); }; } else { switch (src0->type) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break; + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break; default: GGML_ABORT("fatal error"); }; } @@ -3896,6 +3918,10 @@ static bool ggml_metal_encode_node( /*.attn_factor =*/ attn_factor, /*.beta_fast =*/ beta_fast, /*.beta_slow =*/ beta_slow, + /* sect_0 =*/ sect_0, + /* sect_1 =*/ sect_1, + /* sect_2 =*/ sect_2, + /* sect_3 =*/ sect_3, }; [encoder setComputePipelineState:pipeline]; diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index a808e79c227..9cfddf4503a 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -2713,8 +2713,148 @@ kernel void kernel_rope_neox( } } +template +kernel void kernel_rope_multi( + constant ggml_metal_kargs_rope & args, + device const char * src0, + device const char * src1, + device const char * src2, + device char * dst, + ushort tiitg[[thread_index_in_threadgroup]], + ushort3 tptg [[threads_per_threadgroup]], + uint3 tgpig[[threadgroup_position_in_grid]]) { + const int i3 = tgpig[2]; + const int i2 = tgpig[1]; + const int i1 = tgpig[0]; + + float corr_dims[2]; + rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims); + + device const int32_t * pos = (device const int32_t *) src1; + + const float inv_ndims = -1.f/args.n_dims; + + float cos_theta; + float sin_theta; + + for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) { + if (i0 < args.n_dims) { + const int ic = i0/2; + + // mrope theta calculations + // note: the rest is the same as kernel_rope_neox + const int sect_dims = args.sect_0 + args.sect_1 + args.sect_2 + args.sect_3; + const int sec_w01 = args.sect_0 + args.sect_1; // end of section 1 + const int sec_w012 = args.sect_0 + args.sect_1 + args.sect_2; // end of section 2 + const int sector = ic % sect_dims; + + float theta_base; + if (sector < args.sect_0) { + theta_base = (float) pos[i2]; + } else if (sector < sec_w01) { + theta_base = (float) pos[i2 + args.ne02]; + } else if (sector < sec_w012) { + theta_base = (float) pos[i2 + args.ne02 * 2]; + } else { + theta_base = (float) pos[i2 + args.ne02 * 3]; + } + // end of mrope + + const float theta = theta_base * pow(args.freq_base, inv_ndims*i0); + + const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f; + + rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta); + + device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00); + device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + ic*args.nb0); + + const float x0 = src[0]; + const float x1 = src[args.n_dims/2]; + + dst_data[0] = x0*cos_theta - x1*sin_theta; + dst_data[args.n_dims/2] = x0*sin_theta + x1*cos_theta; + } else { + device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00); + device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; + } + } +} + +template +kernel void kernel_rope_vision( + constant ggml_metal_kargs_rope & args, + device const char * src0, + device const char * src1, + device const char * src2, + device char * dst, + ushort tiitg[[thread_index_in_threadgroup]], + ushort3 tptg [[threads_per_threadgroup]], + uint3 tgpig[[threadgroup_position_in_grid]]) { + const int i3 = tgpig[2]; + const int i2 = tgpig[1]; + const int i1 = tgpig[0]; + + float corr_dims[2]; + rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims); + + device const int32_t * pos = (device const int32_t *) src1; + + const float inv_ndims = -1.f/args.n_dims; + + float cos_theta; + float sin_theta; + + for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) { + if (i0 < 2*args.n_dims) { // different from kernel_rope_multi + const int ic = i0/2; + + // mrope theta calculations (only support 2 dimensions) + const int sect_dims = args.sect_0 + args.sect_1; + const int sector = ic % sect_dims; + + float p; + float theta_base; + if (sector < args.sect_1) { + p = (float) sector; + theta_base = (float) pos[i2]; + } else { + p = (float) sector - args.sect_0; + theta_base = (float) pos[i2 + args.ne02]; + } + + const float theta = theta_base * pow(args.freq_base, 2.0f * inv_ndims * p); + // end of mrope + + const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f; + + rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta); + + device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00); + device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + ic*args.nb0); + + const float x0 = src[0]; + const float x1 = src[args.n_dims]; // different from kernel_rope_multi + + dst_data[0] = x0*cos_theta - x1*sin_theta; + dst_data[args.n_dims] = x0*sin_theta + x1*cos_theta; // different from kernel_rope_multi + } else { + device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00); + device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; + } + } +} + typedef decltype(kernel_rope_norm) kernel_rope_norm_t; typedef decltype(kernel_rope_neox) kernel_rope_neox_t; +typedef decltype(kernel_rope_multi) kernel_rope_multi_t; +typedef decltype(kernel_rope_vision) kernel_rope_vision_t; template [[host_name("kernel_rope_norm_f32")]] kernel kernel_rope_norm_t kernel_rope_norm; template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_rope_norm; @@ -2722,6 +2862,12 @@ template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_ template [[host_name("kernel_rope_neox_f32")]] kernel kernel_rope_neox_t kernel_rope_neox; template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_rope_neox; +template [[host_name("kernel_rope_multi_f32")]] kernel kernel_rope_multi_t kernel_rope_multi; +template [[host_name("kernel_rope_multi_f16")]] kernel kernel_rope_multi_t kernel_rope_multi; + +template [[host_name("kernel_rope_vision_f32")]] kernel kernel_rope_vision_t kernel_rope_vision; +template [[host_name("kernel_rope_vision_f16")]] kernel kernel_rope_vision_t kernel_rope_vision; + typedef void (im2col_t)( device const float * x, device char * dst, From 89970b9aaaca3034b76443c5672599dd966e4e04 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 13 May 2025 13:10:17 +0300 Subject: [PATCH 19/21] sync : ggml ggml-ci --- scripts/sync-ggml.last | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index 0e73fa41205..2ad2ea1c651 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -726dbd0636ae954ba3805c6e1fe6ecc69f851c5b +148b286332db1259dcd299c04047a1fd31b02713 From 69753804edf8e161397147b297bf7190d7f7654c Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 13 May 2025 13:11:24 +0300 Subject: [PATCH 20/21] whisper : update to ggml-backend changes (#0) ggml-ci --- src/whisper.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/whisper.cpp b/src/whisper.cpp index bba91f200bf..ad4e7a12d71 100644 --- a/src/whisper.cpp +++ b/src/whisper.cpp @@ -597,7 +597,7 @@ static bool whisper_sched_graph_init(struct whisper_sched & allocr, std::vector< auto & sched = allocr.sched; auto & meta = allocr.meta; - sched = ggml_backend_sched_new(backends.data(), nullptr, backends.size(), WHISPER_MAX_NODES, false); + sched = ggml_backend_sched_new(backends.data(), nullptr, backends.size(), WHISPER_MAX_NODES, false, true); meta.resize(ggml_tensor_overhead()*WHISPER_MAX_NODES + ggml_graph_overhead()); From bff8dc248ad954341cc0f89fa708d24275f83887 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 13 May 2025 13:20:19 +0300 Subject: [PATCH 21/21] talk-llama : sync llama.cpp ggml-ci --- examples/talk-llama/CMakeLists.txt | 1 + examples/talk-llama/llama-adapter.cpp | 6 + examples/talk-llama/llama-batch.cpp | 6 +- examples/talk-llama/llama-batch.h | 3 +- examples/talk-llama/llama-chat.cpp | 24 +- examples/talk-llama/llama-chat.h | 1 + examples/talk-llama/llama-context.cpp | 884 ++++----- examples/talk-llama/llama-context.h | 76 +- examples/talk-llama/llama-cparams.h | 1 + examples/talk-llama/llama-graph.cpp | 58 +- examples/talk-llama/llama-graph.h | 20 +- examples/talk-llama/llama-kv-cache.cpp | 1888 ++++++++++++++++---- examples/talk-llama/llama-kv-cache.h | 352 +++- examples/talk-llama/llama-memory.h | 12 +- examples/talk-llama/llama-model-loader.cpp | 24 +- examples/talk-llama/llama-model-saver.cpp | 281 +++ examples/talk-llama/llama-model-saver.h | 37 + examples/talk-llama/llama-model.cpp | 142 +- examples/talk-llama/llama-model.h | 8 +- examples/talk-llama/llama-quant.cpp | 4 +- examples/talk-llama/llama-sampling.cpp | 24 +- examples/talk-llama/llama-vocab.cpp | 46 +- examples/talk-llama/llama-vocab.h | 6 + examples/talk-llama/llama.cpp | 13 + examples/talk-llama/llama.h | 63 +- 25 files changed, 2851 insertions(+), 1129 deletions(-) create mode 100644 examples/talk-llama/llama-model-saver.cpp create mode 100644 examples/talk-llama/llama-model-saver.h diff --git a/examples/talk-llama/CMakeLists.txt b/examples/talk-llama/CMakeLists.txt index 3e3971a7ae1..e060ba7bfc8 100644 --- a/examples/talk-llama/CMakeLists.txt +++ b/examples/talk-llama/CMakeLists.txt @@ -20,6 +20,7 @@ if (WHISPER_SDL2) llama-memory.cpp llama-mmap.cpp llama-model-loader.cpp + llama-model-saver.cpp llama-model.cpp llama-quant.cpp llama-sampling.cpp diff --git a/examples/talk-llama/llama-adapter.cpp b/examples/talk-llama/llama-adapter.cpp index 7ac54d2391f..8d94034aed9 100644 --- a/examples/talk-llama/llama-adapter.cpp +++ b/examples/talk-llama/llama-adapter.cpp @@ -253,6 +253,9 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_ std::vector buft_extra; { auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + if (!cpu_dev) { + throw std::runtime_error(format("%s: no CPU backend found", __func__)); + } auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev); auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t) @@ -291,6 +294,9 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_ LLAMA_LOG_WARN("%s: lora for '%s' cannot use buft '%s', fallback to CPU\n", __func__, model_tensor->name, ggml_backend_buft_name(buft)); auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + if (!cpu_dev) { + throw std::runtime_error(format("%s: no CPU backend found", __func__)); + } buft = ggml_backend_dev_buffer_type(cpu_dev); break; diff --git a/examples/talk-llama/llama-batch.cpp b/examples/talk-llama/llama-batch.cpp index 01d5ca57fd8..a88b2fe3082 100644 --- a/examples/talk-llama/llama-batch.cpp +++ b/examples/talk-llama/llama-batch.cpp @@ -189,7 +189,7 @@ llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) { return ubatch; } -void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) { +llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) { GGML_ASSERT(batch.n_tokens >= 0); this->batch = &batch; this->n_embd = n_embd; @@ -203,6 +203,7 @@ void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool sim for (size_t i = 0; i < n_tokens; ++i) { ids[i] = i; } + if (simple_split) { seq.resize(1); llama_sbatch_seq & s = seq[0]; @@ -212,6 +213,7 @@ void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool sim s.length = n_tokens; return; } + std::sort(ids.begin(), ids.end(), [&batch](size_t a, size_t b) { int32_t n_seq_a = batch.n_seq_id ? batch.n_seq_id[a] : 1; @@ -239,6 +241,7 @@ void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool sim return n_seq_a > n_seq_b; } ); + // init seq llama_sbatch_seq * last_seq = nullptr; @@ -262,6 +265,7 @@ void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool sim seq.push_back(new_seq); last_seq = &seq.back(); } + // keep shared prompts first at the end, then sort by length descending. std::sort(seq.begin(), seq.end(), [](llama_sbatch_seq & a, llama_sbatch_seq & b) { diff --git a/examples/talk-llama/llama-batch.h b/examples/talk-llama/llama-batch.h index f1df40d2708..6305051b62b 100644 --- a/examples/talk-llama/llama-batch.h +++ b/examples/talk-llama/llama-batch.h @@ -70,7 +70,8 @@ struct llama_sbatch { // sequence-wise split llama_ubatch split_seq(size_t n_ubatch); - void from_batch(const llama_batch & batch, size_t n_embd, bool simple_split = false, bool logits_all = false); + llama_sbatch() = default; + llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split = false, bool logits_all = false); }; // temporary allocate memory for the input batch if needed diff --git a/examples/talk-llama/llama-chat.cpp b/examples/talk-llama/llama-chat.cpp index 735d2619c92..d12743e6b9a 100644 --- a/examples/talk-llama/llama-chat.cpp +++ b/examples/talk-llama/llama-chat.cpp @@ -35,6 +35,7 @@ static const std::map LLM_CHAT_TEMPLATES = { { "mistral-v3", LLM_CHAT_TEMPLATE_MISTRAL_V3 }, { "mistral-v3-tekken", LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN }, { "mistral-v7", LLM_CHAT_TEMPLATE_MISTRAL_V7 }, + { "mistral-v7-tekken", LLM_CHAT_TEMPLATE_MISTRAL_V7_TEKKEN }, { "phi3", LLM_CHAT_TEMPLATE_PHI_3 }, { "phi4", LLM_CHAT_TEMPLATE_PHI_4 }, { "falcon3", LLM_CHAT_TEMPLATE_FALCON_3 }, @@ -202,19 +203,20 @@ int32_t llm_chat_apply_template( if (add_ass) { ss << "<|im_start|>assistant\n"; } - } else if (tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V7) { + } else if (tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V7 || tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V7_TEKKEN) { // Official mistral 'v7' template // See: https://huggingface.co/mistralai/Mistral-Large-Instruct-2411#basic-instruct-template-v7 + // https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503#basic-instruct-template-v7-tekken + const char * trailing_space = tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V7 ? " " : ""; for (auto message : chat) { std::string role(message->role); std::string content(message->content); if (role == "system") { - ss << "[SYSTEM_PROMPT] " << content << "[/SYSTEM_PROMPT]"; + ss << "[SYSTEM_PROMPT]" << trailing_space << content << "[/SYSTEM_PROMPT]"; } else if (role == "user") { - ss << "[INST] " << content << "[/INST]"; - } - else { - ss << " " << content << ""; + ss << "[INST]" << trailing_space << content << "[/INST]"; + } else { + ss << trailing_space << content << ""; } } } else if (tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V1 @@ -447,8 +449,16 @@ int32_t llm_chat_apply_template( if (add_ass) { ss << "<|assistant|>"; } - } else if (tmpl == LLM_CHAT_TEMPLATE_CHATGLM_4 || tmpl == LLM_CHAT_TEMPLATE_GLMEDGE) { + } else if (tmpl == LLM_CHAT_TEMPLATE_CHATGLM_4) { ss << "[gMASK]" << ""; + for (auto message : chat) { + std::string role(message->role); + ss << "<|" << role << "|>" << "\n" << message->content; + } + if (add_ass) { + ss << "<|assistant|>\n"; + } + } else if (tmpl == LLM_CHAT_TEMPLATE_GLMEDGE) { for (auto message : chat) { std::string role(message->role); ss << "<|" << role << "|>" << "\n" << message->content; diff --git a/examples/talk-llama/llama-chat.h b/examples/talk-llama/llama-chat.h index 3f5843466d0..db24ade21e2 100644 --- a/examples/talk-llama/llama-chat.h +++ b/examples/talk-llama/llama-chat.h @@ -14,6 +14,7 @@ enum llm_chat_template { LLM_CHAT_TEMPLATE_MISTRAL_V3, LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN, LLM_CHAT_TEMPLATE_MISTRAL_V7, + LLM_CHAT_TEMPLATE_MISTRAL_V7_TEKKEN, LLM_CHAT_TEMPLATE_PHI_3, LLM_CHAT_TEMPLATE_PHI_4, LLM_CHAT_TEMPLATE_FALCON_3, diff --git a/examples/talk-llama/llama-context.cpp b/examples/talk-llama/llama-context.cpp index 5a2eef9b784..62246c10dab 100644 --- a/examples/talk-llama/llama-context.cpp +++ b/examples/talk-llama/llama-context.cpp @@ -6,11 +6,9 @@ #include "llama-model.h" #include "llama-kv-cache.h" -#include #include #include #include -#include // // llama_context @@ -95,6 +93,7 @@ llama_context::llama_context( } cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch); + cparams.op_offload = params.op_offload; const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max; @@ -118,8 +117,6 @@ llama_context::llama_context( __func__, n_ctx_per_seq, hparams.n_ctx_train); } - logits_all = params.logits_all; - if (!hparams.vocab_only) { // GPU backends for (auto * dev : model.devices) { @@ -177,44 +174,13 @@ llama_context::llama_context( } // init the memory module - // TODO: for now, always create a unified KV cache if (!hparams.vocab_only) { - kv_self.reset(static_cast(model.create_memory())); - - LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx); - - cparams.n_ctx = GGML_PAD(cparams.n_ctx, kv_self->get_padding(cparams)); - - LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx); - - uint32_t kv_size = cparams.n_ctx; - ggml_type type_k = params.type_k; - ggml_type type_v = params.type_v; - - if (llama_model_is_recurrent(&model)) { - // Mamba needs at least as many KV cells as there are sequences kept at any time - kv_size = std::max((uint32_t) 1, params.n_seq_max); - // it's probably best to keep as much precision as possible for the states - type_k = GGML_TYPE_F32; // required by ggml_ssm_conv for Mamba's conv_states - type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_states - } - - GGML_ASSERT(hparams.n_embd_head_k % ggml_blck_size(type_k) == 0); - GGML_ASSERT(hparams.n_embd_head_v % ggml_blck_size(type_v) == 0); - - if (!kv_self->init(model, cparams, type_k, type_v, kv_size, cparams.offload_kqv)) { - throw std::runtime_error("failed to initialize self-attention cache"); - } + llama_memory_params params_mem = { + /*.type_k =*/ params.type_k, + /*.type_v =*/ params.type_v, + }; - { - const size_t memory_size_k = kv_self->size_k_bytes(); - const size_t memory_size_v = kv_self->size_v_bytes(); - - LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__, - (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), - ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f), - ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f)); - } + memory.reset(model.create_memory(params_mem, cparams)); } // init backends @@ -278,7 +244,7 @@ llama_context::llama_context( } } - sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, pipeline_parallel)); + sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, pipeline_parallel, cparams.op_offload)); if (pipeline_parallel) { LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, ggml_backend_sched_get_n_copies(sched.get())); @@ -286,7 +252,7 @@ llama_context::llama_context( } // reserve worst-case graph - if (!hparams.vocab_only) { + if (!hparams.vocab_only && memory) { const uint32_t n_seqs = 1; // TODO: worst-case number of sequences const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); @@ -305,7 +271,9 @@ llama_context::llama_context( int n_nodes_tg = -1; // simulate full KV cache - kv_self->n = kv_self->size; + llama_kv_cache * kv_self = static_cast(memory.get()); + + kv_self->set_full(); cross.v_embd.clear(); @@ -391,7 +359,9 @@ llama_context::llama_context( } } -llama_context::~llama_context() = default; +llama_context::~llama_context() { + ggml_opt_free(opt_ctx); +} void llama_context::synchronize() { ggml_backend_sched_synchronize(sched.get()); @@ -427,6 +397,18 @@ const llama_model & llama_context::get_model() const { return model; } +const llama_cparams & llama_context::get_cparams() const { + return cparams; +} + +ggml_backend_sched_t llama_context::get_sched() const { + return sched.get(); +} + +ggml_context * llama_context::get_ctx_compute() const { + return ctx_compute.get(); +} + uint32_t llama_context::n_ctx() const { return cparams.n_ctx; } @@ -456,337 +438,21 @@ uint32_t llama_context::n_threads_batch() const { } llama_kv_cache * llama_context::get_kv_self() { - return kv_self.get(); + llama_kv_cache * kv_self = static_cast(memory.get()); + return kv_self; } const llama_kv_cache * llama_context::get_kv_self() const { - return kv_self.get(); -} - -ggml_tensor * llama_context::build_rope_shift( - ggml_context * ctx0, - ggml_tensor * cur, - ggml_tensor * shift, - ggml_tensor * factors, - float freq_base, - float freq_scale) const { - const auto & n_ctx_orig = cparams.n_ctx_orig_yarn; - - const auto & yarn_ext_factor = cparams.yarn_ext_factor; - const auto & yarn_beta_fast = cparams.yarn_beta_fast; - const auto & yarn_beta_slow = cparams.yarn_beta_slow; - - const auto & hparams = model.hparams; - - const auto & n_rot = hparams.n_rot; - const auto & rope_type = hparams.rope_type; - - // See llm_build_deepseek2() for why attn_factor has to be scaled for YaRN RoPE to work correctly. - // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation. - const float yarn_attn_factor = model.arch == LLM_ARCH_DEEPSEEK2 ? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)) : cparams.yarn_attn_factor; - - ggml_tensor * tmp; - - if (ggml_is_quantized(cur->type)) { - // dequantize to f32 -> RoPE -> quantize back - tmp = ggml_cast(ctx0, cur, GGML_TYPE_F32); - - tmp = ggml_rope_ext(ctx0, tmp, - shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow); - - tmp = ggml_cpy(ctx0, tmp, cur); - } else { - // we rotate only the first n_rot dimensions - tmp = ggml_rope_ext_inplace(ctx0, cur, - shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow); - } - - return tmp; -} - -class llm_graph_input_k_shift : public llm_graph_input_i { -public: - llm_graph_input_k_shift(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {} - virtual ~llm_graph_input_k_shift() = default; - - void set_input(const llama_ubatch * ubatch) override; - - ggml_tensor * k_shift; // I32 [kv_size] - - const llama_kv_cache_unified * kv_self; -}; - -void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) { - GGML_UNUSED(ubatch); - - if (k_shift) { - assert(ggml_backend_buffer_is_host(k_shift->buffer)); - - int32_t * data = (int32_t *) k_shift->data; - - for (uint32_t i = 0; i < kv_self->size; ++i) { - data[i] = kv_self->cells[i].delta; - } - } -} - -llm_graph_result_ptr llama_context::build_kv_self_shift( - ggml_context * ctx0, - ggml_cgraph * gf) const { - auto res = std::make_unique(); - - const auto & hparams = model.hparams; - - const auto & n_layer = hparams.n_layer; - - const auto & n_embd_head_k = hparams.n_embd_head_k; - //const auto & n_embd_head_v = hparams.n_embd_head_v; - - //GGML_ASSERT(kv_self->size == n_ctx); - - auto inp = std::make_unique(kv_self.get()); - - inp->k_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, cparams.n_ctx); - ggml_set_input(inp->k_shift); - - for (uint32_t il = 0; il < n_layer; ++il) { - const int64_t n_head_kv = hparams.n_head_kv(il); - const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); - - const bool is_swa = hparams.is_swa(il); - - // note: the swa rope params could become part of the cparams in the future - // if we decide to make them configurable, like the non-sliding ones - const float freq_base_l = is_swa ? hparams.rope_freq_base_train_swa : cparams.rope_freq_base; - const float freq_scale_l = is_swa ? hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale; - - ggml_tensor * rope_factors = kv_self->cbs.get_rope_factors(n_ctx_per_seq(), il); - - ggml_tensor * k = - ggml_view_3d(ctx0, kv_self->k_l[il], - n_embd_head_k, n_head_kv, kv_self->size, - ggml_row_size(kv_self->k_l[il]->type, n_embd_head_k), - ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa), - 0); - - ggml_tensor * cur = build_rope_shift(ctx0, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l); - - ggml_build_forward_expand(gf, cur); - } - - res->add_input(std::move(inp)); - - return res; -} - -llm_graph_result_ptr llama_context::build_kv_self_defrag( - ggml_context * ctx0, - ggml_cgraph * gf) const { - auto res = std::make_unique(); - - const auto & hparams = model.hparams; - - const auto & ids = kv_self->defrag_info.ids; - -#if 0 - // CPU defrag - // - // TODO: optimizations are possible: - // - multiple threads - // - avoid copying to the host memory when already there - // - // likely not worth the effort, as we have ggml_graph based defrag - // - - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(); - - const uint32_t kv_size = size; - - std::vector buf_k; - std::vector buf_v; - - for (uint32_t il = 0; il < n_layer; ++il) { - const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa); - const size_t k_size = ggml_row_size(k_l[il]->type, n_embd_k_gqa*kv_size); - - const size_t v_size_el = ggml_type_size(v_l[il]->type); - const size_t v_size = ggml_row_size (v_l[il]->type, n_embd_v_gqa*kv_size); - - buf_k.resize(k_size); - buf_v.resize(v_size); - - ggml_backend_tensor_get(k_l[il], buf_k.data(), 0, buf_k.size()); - ggml_backend_tensor_get(v_l[il], buf_v.data(), 0, buf_v.size()); - - // batch move [i, i+nm) to [id, id+nm) - // note: cells can move only to a lower index - for (uint32_t i = 0; i < n_kv; ++i) { - const uint32_t id = ids[i]; - - if (i == id || id == n_kv) { - continue; - } - - uint32_t nm = 1; - - while (i + nm < n_kv && ids[i + nm] == id + nm) { - nm++; - } - - // move keys - { - const int64_t os = i*k_size_row; - const int64_t od = id*k_size_row; - - memcpy(buf_k.data() + od, buf_k.data() + os, nm*k_size_row); - } - - // move values (note: they are transposed) - { - const int64_t os = i; - const int64_t od = id; - - for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { - memcpy(buf_v.data() + (od + j*kv_size)*v_size_el, buf_v.data() + (os + j*kv_size)*v_size_el, nm*v_size_el); - } - } - - i += nm - 1; - } - - ggml_backend_tensor_set(k_l[il], buf_k.data(), 0, buf_k.size()); - ggml_backend_tensor_set(v_l[il], buf_v.data(), 0, buf_v.size()); - } -#else - for (uint32_t i = 0; i < ids.size(); ++i) { - const uint32_t id = ids[i]; - - if (i == id || id == ids.size()) { - continue; - } - - uint32_t nm = 1; - - while (i + nm < ids.size() && ids[i + nm] == id + nm) { - nm++; - } - - for (uint32_t il = 0; il < hparams.n_layer; ++il) { // NOLINT - const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); - const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); - - ggml_tensor * view_k_src = ggml_view_2d(ctx0, kv_self->k_l[il], - n_embd_k_gqa, nm, - ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa), - ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa*i)); - - ggml_tensor * view_k_dst = ggml_view_2d(ctx0, kv_self->k_l[il], - n_embd_k_gqa, nm, - ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa), - ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa*id)); - - ggml_tensor * view_v_src; - ggml_tensor * view_v_dst; - - if (cparams.flash_attn) { - // NOTE: the V cache is not transposed when using flash attention - view_v_src = ggml_view_2d(ctx0, kv_self->v_l[il], - n_embd_v_gqa, nm, - ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa), - ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa*i)); - - view_v_dst = ggml_view_2d(ctx0, kv_self->v_l[il], - n_embd_v_gqa, nm, - ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa), - ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa*id)); - } else { - view_v_src = ggml_view_2d(ctx0, kv_self->v_l[il], - nm, n_embd_v_gqa, - ggml_row_size(kv_self->v_l[il]->type, kv_self->size), - ggml_row_size(kv_self->v_l[il]->type, i)); - - view_v_dst = ggml_view_2d(ctx0, kv_self->v_l[il], - nm, n_embd_v_gqa, - ggml_row_size(kv_self->v_l[il]->type, kv_self->size), - ggml_row_size(kv_self->v_l[il]->type, id)); - } - - ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_k_src, view_k_dst)); - ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_v_src, view_v_dst)); - } - - i += nm - 1; - } - - //LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes); -#endif - - return res; + llama_kv_cache * kv_self = static_cast(memory.get()); + return kv_self; } void llama_context::kv_self_update() { - auto & kv = kv_self; - bool need_reserve = false; - if (kv->has_shift) { - if (!kv->get_can_shift()) { - GGML_ABORT("The current context does not support K-shift"); - } - - LLAMA_LOG_DEBUG("%s: applying K-shift\n", __func__); - - // apply K-shift if needed - if (model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE) { - ggml_backend_sched_reset(sched.get()); - - auto * gf = graph_init(); - - auto res = build_kv_self_shift(ctx_compute.get(), gf); - - ggml_backend_sched_alloc_graph(sched.get(), gf); - - res->set_inputs(nullptr); - - graph_compute(gf, false); + llama_kv_cache * kv_self = static_cast(memory.get()); - need_reserve = true; - } - - { - kv->has_shift = false; - - for (uint32_t i = 0; i < kv->size; ++i) { - kv->cells[i].delta = 0; - } - } - } - - // defragment the KV cache if needed - if (kv->do_defrag) { - LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__); - - if (kv->defrag_prepare(graph_max_nodes())) { - ggml_backend_sched_reset(sched.get()); - - auto * gf = graph_init(); - - auto res = build_kv_self_defrag(ctx_compute.get(), gf); - - ggml_backend_sched_alloc_graph(sched.get(), gf); - - res->set_inputs(nullptr); - - graph_compute(gf, false); - - need_reserve = true; - } - - kv->do_defrag = false; - } + need_reserve = kv_self->update(*this); // reserve a worst case graph if needed if (need_reserve) { @@ -797,7 +463,7 @@ void llama_context::kv_self_update() { uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); // simulate full KV cache - kv_self->n = kv_self->size; + kv_self->set_full(); llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; @@ -818,9 +484,6 @@ enum llama_pooling_type llama_context::pooling_type() const { } float * llama_context::get_logits() { - // reorder logits for backward compatibility - output_reorder(); - return logits; } @@ -863,9 +526,6 @@ float * llama_context::get_logits_ith(int32_t i) { } float * llama_context::get_embeddings() { - // reorder embeddings for backward compatibility - output_reorder(); - return embd; } @@ -1017,8 +677,8 @@ int llama_context::encode(llama_batch & inp_batch) { } // temporary allocate memory for the input batch if needed - // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences - llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1); + // note: during encode, we always pass the full sequence starting from pos = 0 + llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : 0); const llama_batch & batch = batch_allocr.batch; const int32_t n_tokens = batch.n_tokens; @@ -1043,11 +703,13 @@ int llama_context::encode(llama_batch & inp_batch) { t_compute_start_us = ggml_time_us(); } + embd_seq.clear(); + n_queued_tokens += n_tokens; const int64_t n_embd = hparams.n_embd; - sbatch.from_batch(batch, n_embd, /* simple_split */ true, /* logits_all */ true); + llama_sbatch sbatch = llama_sbatch(batch, n_embd, /* simple_split */ true, /* logits_all */ true); const llama_ubatch ubatch = sbatch.split_simple(n_tokens); @@ -1104,12 +766,12 @@ int llama_context::encode(llama_batch & inp_batch) { ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd); GGML_ASSERT(backend_embd != nullptr); - GGML_ASSERT(embd != nullptr); - switch (cparams.pooling_type) { case LLAMA_POOLING_TYPE_NONE: { // extract token embeddings + GGML_ASSERT(embd != nullptr); + GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_size); ggml_backend_tensor_get_async(backend_embd, t_embd, embd, 0, n_tokens*n_embd*sizeof(float)); } break; @@ -1134,11 +796,18 @@ int llama_context::encode(llama_batch & inp_batch) { } break; case LLAMA_POOLING_TYPE_RANK: { - // TODO: this likely should be the same logic as in llama_decoder_internal, but better to - // wait for an encoder model that requires this pooling type in order to test it - // https://github.com/ggerganov/llama.cpp/pull/9510 - GGML_ABORT("RANK pooling not implemented yet"); - } + // extract the rerank score - a single float per sequence + auto & embd_seq_out = embd_seq; + + for (uint32_t s = 0; s < ubatch.n_seqs; ++s) { + const llama_seq_id seq_id = ubatch.seq_id[s][0]; + if (embd_seq_out.find(seq_id) != embd_seq_out.end()) { + continue; + } + embd_seq_out[seq_id].resize(1); + ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (seq_id)*sizeof(float), sizeof(float)); + } + } break; case LLAMA_POOLING_TYPE_UNSPECIFIED: { GGML_ABORT("unknown pooling type"); @@ -1176,14 +845,21 @@ int llama_context::encode(llama_batch & inp_batch) { } int llama_context::decode(llama_batch & inp_batch) { + if (!memory) { + LLAMA_LOG_WARN("%s: cannot decode batches with this context (use llama_encode() instead)\n", __func__); + return encode(inp_batch); + } + if (inp_batch.n_tokens == 0) { LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__); return -1; } + llama_kv_cache * kv_self = static_cast(memory.get()); + // temporary allocate memory for the input batch if needed - // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences - llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1); + // TODO: this is incorrect for multiple sequences because get_pos_max() is the maximum across all sequences + llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->get_pos_max() + 1); const llama_batch & batch = batch_allocr.batch; @@ -1195,7 +871,7 @@ int llama_context::decode(llama_batch & inp_batch) { const int64_t n_tokens_all = batch.n_tokens; const int64_t n_embd = hparams.n_embd; - llama_kv_cache_guard kv_guard(kv_self.get()); + llama_kv_cache_guard kv_guard(kv_self); GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT @@ -1229,18 +905,14 @@ int llama_context::decode(llama_batch & inp_batch) { for (uint32_t i = 0; i < n_tokens_all; ++i) { n_outputs_all += batch.logits[i] != 0; } - } else if (logits_all || embd_pooled) { + } else if (embd_pooled) { n_outputs_all = n_tokens_all; } else { // keep last output only n_outputs_all = 1; } - const bool logits_all = n_outputs_all == n_tokens_all; - - sbatch.from_batch(batch, n_embd, - /* simple_split */ !kv_self->recurrent, - /* logits_all */ logits_all); + llama_sbatch sbatch = kv_self->sbatch_init(batch, /* logits_all */ n_outputs_all == n_tokens_all); // reserve output buffer if (output_reserve(n_outputs_all) < n_outputs_all) { @@ -1254,22 +926,7 @@ int llama_context::decode(llama_batch & inp_batch) { int64_t n_outputs_prev = 0; while (sbatch.n_tokens > 0) { - llama_ubatch ubatch = llama_ubatch(); - - const auto & n_ubatch = cparams.n_ubatch; - - if (kv_self->recurrent) { - if (embd_pooled) { - // Pooled embeddings cannot be split across ubatches (yet) - ubatch = sbatch.split_seq(cparams.n_ubatch); - } else { - // recurrent model architectures are easier to implement - // with equal-length sequences - ubatch = sbatch.split_equal(cparams.n_ubatch); - } - } else { - ubatch = sbatch.split_simple(n_ubatch); - } + llama_ubatch ubatch = kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled); // count the outputs in this u_batch { @@ -1289,24 +946,12 @@ int llama_context::decode(llama_batch & inp_batch) { } // find KV slot - { - if (!kv_self->find_slot(ubatch)) { - LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens); - - return 1; - } + if (!kv_self->find_slot(ubatch)) { + LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens); - if (!kv_self->recurrent) { - // a heuristic, to avoid attending the full cache if it is not yet utilized - // after enough generations, the benefit from this heuristic disappears - // if we start defragmenting the cache, the benefit from this will be more important - const uint32_t pad = kv_self->get_padding(cparams); - kv_self->n = std::min(kv_self->size, std::max(pad, GGML_PAD(kv_self->cell_max(), pad))); - } + return 1; } - //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self->n, kv_self->used, kv_self->head); - ggml_backend_sched_reset(sched.get()); ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data); @@ -1420,43 +1065,68 @@ int llama_context::decode(llama_batch & inp_batch) { // finalize the batch processing kv_guard.commit(); + // set to total number of outputs in the batch, for use in llama_get_logits_ith + n_outputs = n_outputs_all; + // set output mappings { bool sorted_output = true; - GGML_ASSERT(sbatch.out_ids.size() == (size_t) n_outputs_all); + auto & out_ids = sbatch.out_ids; + + GGML_ASSERT(out_ids.size() == (size_t) n_outputs_all); for (int64_t i = 0; i < n_outputs_all; ++i) { - int64_t out_id = sbatch.out_ids[i]; + int64_t out_id = out_ids[i]; output_ids[out_id] = i; if (out_id != i) { sorted_output = false; } } - if (sorted_output) { - sbatch.out_ids.clear(); + // make the outputs have the same order they had in the user-provided batch + // note: this is mostly relevant for recurrent models atm + if (!sorted_output) { + const uint32_t n_vocab = model.vocab.n_tokens(); + const uint32_t n_embd = model.hparams.n_embd; + + GGML_ASSERT((size_t) n_outputs == out_ids.size()); + + // TODO: is there something more efficient which also minimizes swaps? + // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort) + for (int32_t i = 0; i < n_outputs - 1; ++i) { + int32_t j_min = i; + for (int32_t j = i + 1; j < n_outputs; ++j) { + if (out_ids[j] < out_ids[j_min]) { + j_min = j; + } + } + if (j_min == i) { continue; } + std::swap(out_ids[i], out_ids[j_min]); + if (logits_size > 0) { + for (uint32_t k = 0; k < n_vocab; k++) { + std::swap(logits[i*n_vocab + k], logits[j_min*n_vocab + k]); + } + } + if (embd_size > 0) { + for (uint32_t k = 0; k < n_embd; k++) { + std::swap(embd[i*n_embd + k], embd[j_min*n_embd + k]); + } + } + } + std::fill(output_ids.begin(), output_ids.end(), -1); + for (int32_t i = 0; i < n_outputs; ++i) { + output_ids[out_ids[i]] = i; + } } } - // set to total number of outputs in the batch, for use in llama_get_logits_ith - n_outputs = n_outputs_all; - // wait for the computation to finish (automatically done when obtaining the model output) //synchronize(); // decide if we need to defrag the kv cache - if (cparams.causal_attn && cparams.defrag_thold > 0.0f) { - // - do not defrag small contexts (i.e. < 2048 tokens) - // - count the padding towards the number of used tokens - const float fragmentation = kv_self->n >= 2048 ? std::max(0.0f, 1.0f - float(kv_self->used + kv_self->get_padding(cparams))/float(kv_self->n)) : 0.0f; - - // queue defragmentation for next llama_kv_cache_update - if (fragmentation > cparams.defrag_thold) { - LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation); - - kv_self->defrag(); - } + if (cparams.defrag_thold > 0.0f) { + kv_self->defrag_sched(cparams.defrag_thold); } // Reset state for the next token before backend sync, to allow the CPU activities in the reset to @@ -1542,44 +1212,6 @@ int32_t llama_context::output_reserve(int32_t n_outputs) { return n_outputs_max; } -void llama_context::output_reorder() { - auto & out_ids = sbatch.out_ids; - if (!out_ids.empty()) { - const uint32_t n_vocab = model.vocab.n_tokens(); - const uint32_t n_embd = model.hparams.n_embd; - - GGML_ASSERT((size_t) n_outputs == out_ids.size()); - - // TODO: is there something more efficient which also minimizes swaps? - // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort) - for (int32_t i = 0; i < n_outputs - 1; ++i) { - int32_t j_min = i; - for (int32_t j = i + 1; j < n_outputs; ++j) { - if (out_ids[j] < out_ids[j_min]) { - j_min = j; - } - } - if (j_min == i) { continue; } - std::swap(out_ids[i], out_ids[j_min]); - if (logits_size > 0) { - for (uint32_t k = 0; k < n_vocab; k++) { - std::swap(logits[i*n_vocab + k], logits[j_min*n_vocab + k]); - } - } - if (embd_size > 0) { - for (uint32_t k = 0; k < n_embd; k++) { - std::swap(embd[i*n_embd + k], embd[j_min*n_embd + k]); - } - } - } - std::fill(output_ids.begin(), output_ids.end(), -1); - for (int32_t i = 0; i < n_outputs; ++i) { - output_ids[out_ids[i]] = i; - } - out_ids.clear(); - } -} - // // graph // @@ -1616,7 +1248,7 @@ llm_graph_result_ptr llama_context::graph_build( /*.backend_cpu =*/ backend_cpu, /*.cvec =*/ &cvec, /*.loras =*/ &loras, - /*.memory =*/ kv_self.get(), + /*.memory =*/ memory.get(), /*.cross =*/ &cross, /*.n_outputs =*/ n_outputs, /*.cb =*/ graph_get_cb(), @@ -2020,8 +1652,6 @@ size_t llama_context::state_write_data(llama_io_write_i & io) { { LLAMA_LOG_DEBUG("%s: - writing output ids\n", __func__); - output_reorder(); - const auto n_outputs = this->n_outputs; const auto & output_ids = this->output_ids; @@ -2075,6 +1705,8 @@ size_t llama_context::state_write_data(llama_io_write_i & io) { } LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__); + llama_kv_cache * kv_self = static_cast(memory.get()); + kv_self->state_write(io); return io.n_bytes(); @@ -2158,8 +1790,13 @@ size_t llama_context::state_read_data(llama_io_read_i & io) { } } - LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__); - kv_self->state_read(io); + if (memory) { + LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__); + + llama_kv_cache * kv_self = static_cast(memory.get()); + + kv_self->state_read(io); + } return io.n_bytes(); } @@ -2167,7 +1804,11 @@ size_t llama_context::state_read_data(llama_io_read_i & io) { size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id) { GGML_UNUSED(seq_id); - kv_self->state_write(io, seq_id); + if (memory) { + llama_kv_cache * kv_self = static_cast(memory.get()); + + kv_self->state_write(io, seq_id); + } return io.n_bytes(); } @@ -2175,7 +1816,11 @@ size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id s size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id) { GGML_UNUSED(seq_id); - kv_self->state_read(io, seq_id); + if (memory) { + llama_kv_cache * kv_self = static_cast(memory.get()); + + kv_self->state_read(io, seq_id); + } return io.n_bytes(); } @@ -2203,6 +1848,215 @@ void llama_context::perf_reset() { t_p_eval_us = n_p_eval = 0; } +// +// training +// + +static void llama_set_param(struct ggml_tensor * tensor, llama_opt_param_filter param_filter, void * userdata) { + if (!tensor || tensor->type != GGML_TYPE_F32) { + return; + } + if (!param_filter(tensor, userdata)) { + return; + } + if (strcmp(tensor->name, "token_embd.weight") == 0) { + return; // FIXME + } + if (strcmp(tensor->name, "rope_freqs.weight") == 0) { + return; // FIXME + } + ggml_set_param(tensor); +} + +void llama_context::opt_init(struct llama_model * model, struct llama_opt_params lopt_params) { + GGML_ASSERT(!opt_ctx); + model->hparams.n_ctx_train = lopt_params.n_ctx_train > 0 ? lopt_params.n_ctx_train : n_ctx(); + const uint32_t n_batch = std::min(this->n_batch(), model->hparams.n_ctx_train); + const uint32_t n_ubatch = std::min(this->n_ubatch(), n_batch); + GGML_ASSERT(model->hparams.n_ctx_train % n_batch == 0); + GGML_ASSERT(n_batch % n_ubatch == 0); + + ggml_opt_params opt_params = ggml_opt_default_params(sched.get(), GGML_OPT_LOSS_TYPE_CROSS_ENTROPY); + opt_params.opt_period = n_batch / n_ubatch; + opt_params.get_opt_pars = lopt_params.get_opt_pars; + opt_params.get_opt_pars_ud = lopt_params.get_opt_pars_ud; + + opt_ctx = ggml_opt_init(opt_params); + + llama_opt_param_filter param_filter = lopt_params.param_filter; + void * param_filter_ud = lopt_params.param_filter_ud; + + //llama_set_param(model->tok_embd, param_filter, param_filter_ud); // FIXME + llama_set_param(model->type_embd, param_filter, param_filter_ud); + llama_set_param(model->pos_embd, param_filter, param_filter_ud); + llama_set_param(model->tok_norm, param_filter, param_filter_ud); + llama_set_param(model->tok_norm_b, param_filter, param_filter_ud); + llama_set_param(model->output_norm, param_filter, param_filter_ud); + llama_set_param(model->output_norm_b, param_filter, param_filter_ud); + llama_set_param(model->output, param_filter, param_filter_ud); + llama_set_param(model->output_b, param_filter, param_filter_ud); + llama_set_param(model->output_norm_enc, param_filter, param_filter_ud); + llama_set_param(model->cls, param_filter, param_filter_ud); + llama_set_param(model->cls_b, param_filter, param_filter_ud); + llama_set_param(model->cls_out, param_filter, param_filter_ud); + llama_set_param(model->cls_out_b, param_filter, param_filter_ud); + + for (struct llama_layer & layer : model->layers) { + for (size_t i = 0; i < sizeof(layer)/sizeof(struct ggml_tensor *); ++i) { + llama_set_param(reinterpret_cast(&layer)[i], param_filter, param_filter_ud); + } + } +} + +void llama_context::opt_epoch_iter( + ggml_opt_dataset_t dataset, + ggml_opt_result_t result, + const std::vector & tokens, + const std::vector & labels_sparse, + llama_batch & batch, + ggml_opt_epoch_callback callback, + bool train, + int64_t idata_in_loop, + int64_t ndata_in_loop, + int64_t t_loop_start) { + GGML_ASSERT(opt_ctx); + const uint32_t n_ctx = llama_model_n_ctx_train(&model); + const uint32_t n_batch = std::min(this->n_batch(), n_ctx); + const uint32_t n_ubatch = std::min(this->n_ubatch(), n_batch); + + llama_kv_cache * kv_self = static_cast(memory.get()); + + kv_self->clear(); + llama_kv_cache_guard kv_guard(kv_self); + + for (uint32_t pos_ctx = 0; pos_ctx < n_ctx; pos_ctx += n_batch) { + batch.n_tokens = n_batch; + for (uint32_t pos_batch = 0; pos_batch < n_batch; ++pos_batch) { + batch.token [pos_batch] = tokens[pos_ctx + pos_batch]; + batch.pos [pos_batch] = pos_ctx + pos_batch; + batch.n_seq_id[pos_batch] = 1; + batch.seq_id [pos_batch][0] = 0; + batch.logits [pos_batch] = true; + } + + const auto n_tokens_all = batch.n_tokens; + + n_queued_tokens += n_tokens_all; + + // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens + const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE; + + embd_seq.clear(); + + int64_t n_outputs_all = n_tokens_all; + + llama_sbatch sbatch = kv_self->sbatch_init(batch, /*logits_all =*/ true); + + // reserve output buffer + if (output_reserve(n_outputs_all) < n_outputs_all) { + LLAMA_LOG_ERROR("%s: could not reserve space for batch with %" PRId64 " outputs\n", __func__, n_outputs_all); + GGML_ABORT("TODO: handle this error"); + }; + + for (uint32_t pos_batch = 0; pos_batch < n_batch; pos_batch += n_ubatch) { + llama_ubatch ubatch = kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled); + + n_outputs = ubatch.n_tokens; + + // TODO: not sure if this is needed + if (!kv_self->find_slot(ubatch)) { + LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens); + + GGML_ABORT("TODO: handle this error"); + } + + auto * gf = graph_init(); + auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT); + + struct ggml_context * ctx_compute_opt; + { + const size_t size_gf = ggml_graph_size(gf); + const size_t size_meta = 4*size_gf*ggml_tensor_overhead() + 2*ggml_graph_overhead_custom(size_gf, /*grads = */ true); + struct ggml_init_params params = { + /*.mem_size =*/ size_meta, + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + ctx_compute_opt = ggml_init(params); + } + ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_tokens(), res->get_logits()); + ggml_opt_alloc(opt_ctx, train); + res->set_inputs(&ubatch); + { + struct ggml_tensor * labels = ggml_opt_labels(opt_ctx); + GGML_ASSERT(labels->ne[1] == n_ubatch); + ggml_set_zero(labels); + const float onef = 1.0f; + for (uint32_t pos_ubatch = 0; pos_ubatch < n_ubatch; ++pos_ubatch) { + const uint32_t ilabel = pos_ctx + pos_batch + pos_ubatch; + GGML_ASSERT(labels_sparse[ilabel] < labels->ne[0]); + ggml_backend_tensor_set(labels, &onef, (pos_ubatch*labels->ne[0] + labels_sparse[ilabel])*sizeof(float), sizeof(float)); + } + } + ggml_opt_eval(opt_ctx, result); + if (callback) { + callback(train, opt_ctx, dataset, result, idata_in_loop + (pos_ctx + pos_batch)/n_ubatch + 1, ndata_in_loop, t_loop_start); + } + ggml_free(ctx_compute_opt); + } + } + + kv_guard.commit(); +} + +void llama_context::opt_epoch( + ggml_opt_dataset_t dataset, + ggml_opt_result_t result_train, + ggml_opt_result_t result_eval, + int64_t idata_split, + ggml_opt_epoch_callback callback_train, + ggml_opt_epoch_callback callback_eval) { + const uint32_t n_ctx = this->n_ctx(); + const uint32_t n_batch = std::min(cparams.n_batch, n_ctx); + const uint32_t n_ubatch = std::min(cparams.n_ubatch, n_batch); + const int64_t ndata = ggml_opt_dataset_ndata(dataset); + + GGML_ASSERT(idata_split >= 0); + GGML_ASSERT(idata_split <= ndata); + + const uint32_t ubatch_per_ctx = n_ctx / n_ubatch; + + struct llama_batch batch = llama_batch_init(n_batch, 0, 1); + std::vector tokens(n_ctx); + std::vector labels_sparse(n_ctx); + + int64_t idata = 0; + + int64_t t_loop_start = ggml_time_us(); + int64_t ndata_in_loop = idata_split*ubatch_per_ctx; + for (; idata < idata_split; ++idata) { + constexpr bool train = true; + const int64_t idata_in_loop = idata*ubatch_per_ctx; + + ggml_opt_dataset_get_batch_host(dataset, tokens.data(), n_ctx*sizeof(llama_token), labels_sparse.data(), idata); + opt_epoch_iter(dataset, result_train, tokens, labels_sparse, batch, + callback_train, train, idata_in_loop, ndata_in_loop, t_loop_start); + } + + t_loop_start = ggml_time_us(); + ndata_in_loop = (ndata - idata_split)*ubatch_per_ctx; + for (; idata < ndata; ++idata) { + constexpr bool train = false; + const int64_t idata_in_loop = (idata - idata_split)*ubatch_per_ctx; + + ggml_opt_dataset_get_batch_host(dataset, tokens.data(), n_ctx*sizeof(llama_token), labels_sparse.data(), idata); + opt_epoch_iter(dataset, result_eval, tokens, labels_sparse, batch, + callback_eval, train, idata_in_loop, ndata_in_loop, t_loop_start); + } + + llama_batch_free(batch); +} + // // interface implementation // @@ -2230,13 +2084,13 @@ llama_context_params llama_context_default_params() { /*.cb_eval_user_data =*/ nullptr, /*.type_k =*/ GGML_TYPE_F16, /*.type_v =*/ GGML_TYPE_F16, - /*.logits_all =*/ false, + /*.abort_callback =*/ nullptr, + /*.abort_callback_data =*/ nullptr, /*.embeddings =*/ false, /*.offload_kqv =*/ true, /*.flash_attn =*/ false, /*.no_perf =*/ true, - /*.abort_callback =*/ nullptr, - /*.abort_callback_data =*/ nullptr, + /*.op_offload =*/ true, }; return result; @@ -2530,7 +2384,7 @@ void llama_kv_cache_seq_cp( llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { - return llama_kv_self_seq_cp(ctx, seq_id_src, seq_id_dst, p0, p1); + llama_kv_self_seq_cp(ctx, seq_id_src, seq_id_dst, p0, p1); } void llama_kv_self_seq_cp( @@ -2544,14 +2398,14 @@ void llama_kv_self_seq_cp( return; } - return kv->seq_cp(seq_id_src, seq_id_dst, p0, p1); + kv->seq_cp(seq_id_src, seq_id_dst, p0, p1); } // deprecated void llama_kv_cache_seq_keep( llama_context * ctx, llama_seq_id seq_id) { - return llama_kv_self_seq_keep(ctx, seq_id); + llama_kv_self_seq_keep(ctx, seq_id); } void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) { @@ -2560,7 +2414,7 @@ void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) { return; } - return kv->seq_keep(seq_id); + kv->seq_keep(seq_id); } // deprecated @@ -2570,7 +2424,7 @@ void llama_kv_cache_seq_add( llama_pos p0, llama_pos p1, llama_pos delta) { - return llama_kv_self_seq_add(ctx, seq_id, p0, p1, delta); + llama_kv_self_seq_add(ctx, seq_id, p0, p1, delta); } void llama_kv_self_seq_add( @@ -2584,7 +2438,7 @@ void llama_kv_self_seq_add( return; } - return kv->seq_add(seq_id, p0, p1, delta); + kv->seq_add(seq_id, p0, p1, delta); } // deprecated @@ -2594,7 +2448,7 @@ void llama_kv_cache_seq_div( llama_pos p0, llama_pos p1, int d) { - return llama_kv_self_seq_div(ctx, seq_id, p0, p1, d); + llama_kv_self_seq_div(ctx, seq_id, p0, p1, d); } void llama_kv_self_seq_div( @@ -2608,7 +2462,7 @@ void llama_kv_self_seq_div( return; } - return kv->seq_div(seq_id, p0, p1, d); + kv->seq_div(seq_id, p0, p1, d); } // deprecated @@ -2627,7 +2481,7 @@ llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) { // deprecated void llama_kv_cache_defrag(llama_context * ctx) { - return llama_kv_self_defrag(ctx); + llama_kv_self_defrag(ctx); } void llama_kv_self_defrag(llama_context * ctx) { @@ -2636,7 +2490,8 @@ void llama_kv_self_defrag(llama_context * ctx) { return; } - return kv->defrag(); + // force defrag + kv->defrag_sched(-1.0f); } // deprecated @@ -2820,3 +2675,34 @@ void llama_perf_context_print(const llama_context * ctx) { void llama_perf_context_reset(llama_context * ctx) { ctx->perf_reset(); } + +// +// training +// + +bool llama_opt_param_filter_all(const struct ggml_tensor * tensor, void * userdata) { + GGML_UNUSED(tensor); + GGML_UNUSED(userdata); + return true; +} + +void llama_opt_init(struct llama_context * ctx, struct llama_model * model, struct llama_opt_params lopt_params) { + ctx->opt_init(model, lopt_params); +} + +void llama_opt_epoch( + struct llama_context * ctx, + ggml_opt_dataset_t dataset, + ggml_opt_result_t result_train, + ggml_opt_result_t result_eval, + int64_t idata_split, + ggml_opt_epoch_callback callback_train, + ggml_opt_epoch_callback callback_eval) { + ctx->opt_epoch( + dataset, + result_train, + result_eval, + idata_split, + callback_train, + callback_eval); +} diff --git a/examples/talk-llama/llama-context.h b/examples/talk-llama/llama-context.h index 5457f077c15..c0ceacb10ce 100644 --- a/examples/talk-llama/llama-context.h +++ b/examples/talk-llama/llama-context.h @@ -7,6 +7,7 @@ #include "llama-adapter.h" #include "ggml-cpp.h" +#include "ggml-opt.h" #include #include @@ -27,7 +28,12 @@ struct llama_context { void synchronize(); - const llama_model & get_model() const; + const llama_model & get_model() const; + const llama_cparams & get_cparams() const; + + ggml_backend_sched_t get_sched() const; + + ggml_context * get_ctx_compute() const; uint32_t n_ctx() const; uint32_t n_ctx_per_seq() const; @@ -128,6 +134,32 @@ struct llama_context { llama_perf_context_data perf_get_data() const; void perf_reset(); + // + // training + // + + void opt_init(struct llama_model * model, struct llama_opt_params lopt_params); + + void opt_epoch( + ggml_opt_dataset_t dataset, + ggml_opt_result_t result_train, + ggml_opt_result_t result_eval, + int64_t idata_split, + ggml_opt_epoch_callback callback_train, + ggml_opt_epoch_callback callback_eval); + + void opt_epoch_iter( + ggml_opt_dataset_t dataset, + ggml_opt_result_t result, + const std::vector & tokens, + const std::vector & labels_sparse, + llama_batch & batch, + ggml_opt_epoch_callback callback, + bool train, + int64_t idata_in_loop, + int64_t ndata_in_loop, + int64_t t_loop_start); + private: // // output @@ -137,49 +169,30 @@ struct llama_context { // Returns max number of outputs for which space was reserved. int32_t output_reserve(int32_t n_outputs); - // make the outputs have the same order they had in the user-provided batch - // TODO: maybe remove this - void output_reorder(); - // // graph // +public: int32_t graph_max_nodes() const; // zero-out inputs and create the ctx_compute for the compute graph ggml_cgraph * graph_init(); + // returns the result of ggml_backend_sched_graph_compute_async execution + ggml_status graph_compute( + ggml_cgraph * gf, + bool batched); + +private: llm_graph_result_ptr graph_build( ggml_context * ctx, ggml_cgraph * gf, const llama_ubatch & ubatch, llm_graph_type gtype); - // returns the result of ggml_backend_sched_graph_compute_async execution - ggml_status graph_compute( - ggml_cgraph * gf, - bool batched); - llm_graph_cb graph_get_cb() const; - // used by kv_self_update() - ggml_tensor * build_rope_shift( - ggml_context * ctx0, - ggml_tensor * cur, - ggml_tensor * shift, - ggml_tensor * factors, - float freq_base, - float freq_scale) const; - - llm_graph_result_ptr build_kv_self_shift( - ggml_context * ctx0, - ggml_cgraph * gf) const; - - llm_graph_result_ptr build_kv_self_defrag( - ggml_context * ctx0, - ggml_cgraph * gf) const; - // TODO: read/write lora adapters and cvec size_t state_write_data(llama_io_write_i & io); size_t state_read_data (llama_io_read_i & io); @@ -196,14 +209,10 @@ struct llama_context { llama_cparams cparams; llama_adapter_cvec cvec; llama_adapter_loras loras; - llama_sbatch sbatch; llama_cross cross; // TODO: tmp for handling cross-attention - need something better probably - std::unique_ptr kv_self; - - // TODO: remove - bool logits_all = false; + std::unique_ptr memory; // decode output (2-dimensional array: [n_outputs][n_vocab]) size_t logits_size = 0; // capacity (of floats) for logits @@ -230,6 +239,9 @@ struct llama_context { ggml_context_ptr ctx_compute; + // training + ggml_opt_context_t opt_ctx = nullptr; + ggml_threadpool_t threadpool = nullptr; ggml_threadpool_t threadpool_batch = nullptr; diff --git a/examples/talk-llama/llama-cparams.h b/examples/talk-llama/llama-cparams.h index 30e550f023a..246fa5777de 100644 --- a/examples/talk-llama/llama-cparams.h +++ b/examples/talk-llama/llama-cparams.h @@ -30,6 +30,7 @@ struct llama_cparams { bool flash_attn; bool no_perf; bool warmup; + bool op_offload; enum llama_pooling_type pooling_type; diff --git a/examples/talk-llama/llama-graph.cpp b/examples/talk-llama/llama-graph.cpp index fabb9ca2376..b0e3f63597a 100644 --- a/examples/talk-llama/llama-graph.cpp +++ b/examples/talk-llama/llama-graph.cpp @@ -284,24 +284,7 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) { // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n for (uint32_t i = 0; i < n_kv; ++i) { - const uint32_t cell_id = i + kv_self->head; - - ////////////////////////////////////////////// - // TODO: this should not mutate the KV cache ! - llama_kv_cell & kv_cell = const_cast(kv_self)->cells[i]; - - // prevent out-of-bound sources - if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self->size) { - kv_cell.src = cell_id; - } - - data[i] = kv_cell.src; - - // TODO: do not mutate the KV cache - // ensure copy only happens once - if (kv_cell.src != (int32_t) cell_id) { - kv_cell.src = cell_id; - } + data[i] = kv_self->s_copy(i); } } } @@ -317,18 +300,7 @@ void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) { // clear unused states for (int i = 0; i < n_kv; ++i) { - const uint32_t cell_id = i + kv_self->head; - - ////////////////////////////////////////////// - // TODO: this should not mutate the KV cache ! - llama_kv_cell & kv_cell = const_cast(kv_self)->cells[i]; - - data[i] = (float) (kv_cell.src >= 0); - - // only clear once - if (kv_cell.src < 0) { - kv_cell.src = cell_id; - } + data[i] = kv_self->s_mask(i); } } } @@ -810,7 +782,7 @@ ggml_tensor * llm_graph_context::build_ffn( } break; } - if (type_gate == LLM_FFN_PAR) { + if (gate && type_gate == LLM_FFN_PAR) { cur = ggml_mul(ctx0, cur, tmp); cb(cur, "ffn_gate_par", il); } @@ -999,6 +971,7 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const { inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens); //cb(inp->tokens, "inp_tokens", -1); ggml_set_input(inp->tokens); + res->t_tokens = inp->tokens; cur = ggml_get_rows(ctx0, tok_embd, inp->tokens); @@ -1105,7 +1078,7 @@ ggml_tensor * llm_graph_context::build_inp_cls() const { } ggml_tensor * llm_graph_context::build_inp_s_copy() const { - const llama_kv_cache_unified * kv_self = static_cast(memory); + const llama_kv_cache_recurrent * kv_self = static_cast(memory); auto inp = std::make_unique(kv_self); @@ -1122,7 +1095,7 @@ ggml_tensor * llm_graph_context::build_inp_s_copy() const { } ggml_tensor * llm_graph_context::build_inp_s_mask() const { - const llama_kv_cache_unified * kv_self = static_cast(memory); + const llama_kv_cache_recurrent * kv_self = static_cast(memory); auto inp = std::make_unique(kv_self); @@ -1255,8 +1228,19 @@ ggml_tensor * llm_graph_context::build_attn_mha( ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32); if (v_mla) { +#if 0 + // v_mla can be applied as a matrix-vector multiplication with broadcasting across dimension 3 == n_tokens. + // However, the code is optimized for dimensions 0 and 1 being large, so this is ineffient. cur = ggml_reshape_4d(ctx0, cur, v_mla->ne[0], 1, n_head, n_tokens); cur = ggml_mul_mat(ctx0, v_mla, cur); +#else + // It's preferable to do the calculation as a matrix-matrix multiplication with n_tokens in dimension 1. + // The permutations are noops and only change how the tensor data is interpreted. + cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); + cur = ggml_mul_mat(ctx0, v_mla, cur); + cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); + cur = ggml_cont(ctx0, cur); // Needed because ggml_reshape_2d expects contiguous inputs. +#endif } cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens); @@ -1436,8 +1420,6 @@ ggml_tensor * llm_graph_context::build_attn( // store to KV cache { - GGML_ASSERT(!kv_self->recurrent); - const auto kv_head = kv_self->head; GGML_ASSERT(kv_self->size == n_ctx); @@ -1587,7 +1569,7 @@ ggml_tensor * llm_graph_context::build_copy_mask_state( ggml_tensor * state_mask, int32_t n_state, int32_t n_seqs) const { - const llama_kv_cache_unified * kv_self = static_cast(memory); + const llama_kv_cache_recurrent * kv_self = static_cast(memory); const auto n_kv = kv_self->n; const auto kv_head = kv_self->head; @@ -1619,7 +1601,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load( ggml_tensor * state_mask, const llama_ubatch & ubatch, int il) const { - const llama_kv_cache_unified * kv_self = static_cast(memory); + const llama_kv_cache_recurrent * kv_self = static_cast(memory); const auto token_shift_count = hparams.token_shift_count; @@ -1640,7 +1622,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store( ggml_tensor * token_shift, const llama_ubatch & ubatch, int il) const { - const llama_kv_cache_unified * kv_self = static_cast(memory); + const llama_kv_cache_recurrent * kv_self = static_cast(memory); const auto token_shift_count = hparams.token_shift_count; const auto n_embd = hparams.n_embd; diff --git a/examples/talk-llama/llama-graph.h b/examples/talk-llama/llama-graph.h index d0c8d321927..832a8c09f2b 100644 --- a/examples/talk-llama/llama-graph.h +++ b/examples/talk-llama/llama-graph.h @@ -19,6 +19,7 @@ struct llama_cparams; class llama_memory_i; class llama_kv_cache_unified; +class llama_kv_cache_recurrent; // certain models (typically multi-modal) can produce different types of graphs enum llm_graph_type { @@ -186,26 +187,26 @@ class llm_graph_input_cls : public llm_graph_input_i { class llm_graph_input_s_copy : public llm_graph_input_i { public: - llm_graph_input_s_copy(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {} + llm_graph_input_s_copy(const llama_kv_cache_recurrent * kv_self) : kv_self(kv_self) {} virtual ~llm_graph_input_s_copy() = default; void set_input(const llama_ubatch * ubatch) override; ggml_tensor * s_copy; // I32 [kv_size] - const llama_kv_cache_unified * kv_self; + const llama_kv_cache_recurrent * kv_self; }; class llm_graph_input_s_mask : public llm_graph_input_i { public: - llm_graph_input_s_mask(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {} + llm_graph_input_s_mask(const llama_kv_cache_recurrent * kv_self) : kv_self(kv_self) {} virtual ~llm_graph_input_s_mask() = default; void set_input(const llama_ubatch * ubatch) override; ggml_tensor * s_mask; // F32 [1, n_kv] - const llama_kv_cache_unified * kv_self; + const llama_kv_cache_recurrent * kv_self; }; class llm_graph_input_cross_embd : public llm_graph_input_i { @@ -297,6 +298,7 @@ class llm_graph_result_i { public: virtual ~llm_graph_result_i() = default; + virtual ggml_tensor * get_tokens() = 0; virtual ggml_tensor * get_logits() = 0; virtual ggml_tensor * get_embd() = 0; virtual ggml_tensor * get_embd_pooled() = 0; @@ -311,6 +313,7 @@ class llm_graph_result : public llm_graph_result_i { public: virtual ~llm_graph_result() = default; + ggml_tensor * get_tokens() override { return t_tokens; } ggml_tensor * get_logits() override { return t_logits; } ggml_tensor * get_embd() override { return t_embd; } ggml_tensor * get_embd_pooled() override { return t_embd_pooled; } @@ -327,6 +330,7 @@ class llm_graph_result : public llm_graph_result_i { } // important graph nodes + ggml_tensor * t_tokens = nullptr; ggml_tensor * t_logits = nullptr; ggml_tensor * t_embd = nullptr; ggml_tensor * t_embd_pooled = nullptr; @@ -350,8 +354,8 @@ struct llm_graph_params { const llama_cparams & cparams; const llama_ubatch & ubatch; - ggml_backend_sched * sched; - ggml_backend * backend_cpu; + ggml_backend_sched_t sched; + ggml_backend_t backend_cpu; const llama_adapter_cvec * cvec; const llama_adapter_loras * loras; @@ -402,9 +406,9 @@ struct llm_graph_context { ggml_context * ctx0 = nullptr; - ggml_backend_sched * sched; + ggml_backend_sched_t sched; - ggml_backend * backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove? + ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove? const llama_adapter_cvec * cvec; const llama_adapter_loras * loras; diff --git a/examples/talk-llama/llama-kv-cache.cpp b/examples/talk-llama/llama-kv-cache.cpp index 7c9d46d8119..3dcad65bb6a 100644 --- a/examples/talk-llama/llama-kv-cache.cpp +++ b/examples/talk-llama/llama-kv-cache.cpp @@ -4,33 +4,41 @@ #include "llama-batch.h" #include "llama-cparams.h" #include "llama-model.h" +#include "llama-context.h" #include #include +#include #include #include #include -llama_kv_cache_unified::llama_kv_cache_unified(const llama_hparams & hparams, callbacks cbs) : hparams(hparams), cbs(std::move(cbs)) { +// +// llama_kv_cache_unified +// + +uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) { + // the FA kernels require padding to avoid extra runtime boundary checks + return cparams.flash_attn ? 256u : 32u; } -bool llama_kv_cache_unified::init( +llama_kv_cache_unified::llama_kv_cache_unified( const llama_model & model, - const llama_cparams & cparams, ggml_type type_k, ggml_type type_v, + bool v_trans, + bool offload, uint32_t kv_size, - bool offload) { + uint32_t padding) : model(model), hparams(model.hparams), v_trans(v_trans), padding(padding) { const int32_t n_layer = hparams.n_layer; has_shift = false; + can_shift = true; - recurrent = llama_model_is_recurrent(&model); - v_trans = !recurrent && !cparams.flash_attn; - can_shift = !recurrent; + LLAMA_LOG_INFO("%s: kv_size = %d, type_k = '%s', type_v = '%s', n_layer = %d, can_shift = %d, padding = %d\n", + __func__, kv_size, ggml_type_name(type_k), ggml_type_name(type_v), n_layer, can_shift, padding); - LLAMA_LOG_INFO("%s: kv_size = %d, offload = %d, type_k = '%s', type_v = '%s', n_layer = %d, can_shift = %d\n", - __func__, kv_size, offload, ggml_type_name(type_k), ggml_type_name(type_v), n_layer, can_shift); + GGML_ASSERT(kv_size % padding == 0 && "kv_size must be a multiple of padding"); head = 0; size = kv_size; @@ -76,23 +84,20 @@ bool llama_kv_cache_unified::init( const char * dev_name = "CPU"; - ggml_backend_buffer_type_t buft; + ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type(); + if (offload) { auto * dev = model.dev_layer(i); buft = ggml_backend_dev_buffer_type(dev); dev_name = ggml_backend_dev_name(dev); - } else { - buft = ggml_backend_cpu_buffer_type(); } - LLAMA_LOG_DEBUG("%s: layer %3d: n_embd_k_gqa = %d, n_embd_v_gqa = %d, dev = %s\n", __func__, - i, n_embd_k_gqa, n_embd_v_gqa, dev_name); + LLAMA_LOG_DEBUG("%s: layer %3d: dev = %s\n", __func__, i, dev_name); ggml_context * ctx = ctx_for_buft(buft); if (!ctx) { - LLAMA_LOG_ERROR("%s: failed to create ggml context for kv cache\n", __func__); - return false; + throw std::runtime_error("failed to create ggml context for kv cache"); } ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size); @@ -110,55 +115,28 @@ bool llama_kv_cache_unified::init( ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); if (!buf) { - LLAMA_LOG_ERROR("%s: failed to allocate buffer for kv cache\n", __func__); - return false; + throw std::runtime_error("failed to allocate buffer for kv cache"); } ggml_backend_buffer_clear(buf, 0); LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0); bufs.emplace_back(buf); } - return true; -} - -int32_t llama_kv_cache_unified::get_n_tokens() const { - int32_t result = 0; - - for (uint32_t i = 0; i < size; i++) { - result += cells[i].seq_id.size(); - } - - return result; -} - -int32_t llama_kv_cache_unified::get_used_cells() const { - return used; -} - -size_t llama_kv_cache_unified::total_size() const { - size_t size = 0; - for (const auto & buf : bufs) { - size += ggml_backend_buffer_get_size(buf.get()); - } - - return size; -} + { + const size_t memory_size_k = size_k_bytes(); + const size_t memory_size_v = size_v_bytes(); -llama_pos llama_kv_cache_unified::pos_max() const { - llama_pos pos_max = -1; - for (const auto & cell : cells) { - pos_max = std::max(pos_max, cell.pos); + LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__, + (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), + ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f), + ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f)); } - - return pos_max; } void llama_kv_cache_unified::clear() { for (int32_t i = 0; i < (int32_t) size; ++i) { cells[i].pos = -1; cells[i].seq_id.clear(); - cells[i].src = -1; - cells[i].tail = -1; } head = 0; used = 0; @@ -179,35 +157,6 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1 = std::numeric_limits::max(); } - // models like Mamba or RWKV can't have a state partially erased - if (recurrent) { - if (seq_id >= (int64_t) size) { - // could be fatal - return false; - } - if (0 <= seq_id) { - int32_t & tail_id = cells[seq_id].tail; - if (tail_id >= 0) { - const llama_kv_cell & cell = cells[tail_id]; - // partial intersection is invalid - if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) { - return false; - } - // invalidate tails which will be cleared - if (p0 <= cell.pos && cell.pos < p1) { - tail_id = -1; - } - } - } else { - // seq_id is negative, then the range should include everything or nothing - if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits::max())) { - return false; - } - } - - return true; - } - for (uint32_t i = 0; i < size; ++i) { if (cells[i].pos >= p0 && cells[i].pos < p1) { if (seq_id < 0) { @@ -224,7 +173,6 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos } cells[i].pos = -1; - cells[i].src = -1; if (new_head == size) { new_head = i; @@ -254,34 +202,6 @@ void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id p1 = std::numeric_limits::max(); } - if (recurrent) { - if ((uint32_t) seq_id_dst < size && (uint32_t) seq_id_src < size) { - llama_kv_cell & tail_src = cells[seq_id_src]; - llama_kv_cell & tail_dst = cells[seq_id_dst]; - if (tail_dst.tail >= 0) { - // clear destination seq_id if it wasn't empty - llama_kv_cell & cell_dst = cells[tail_dst.tail]; - - cell_dst.seq_id.erase(seq_id_dst); - tail_dst.tail = -1; - if (cell_dst.seq_id.empty()) { - cell_dst.pos = -1; - cell_dst.delta = -1; - cell_dst.src = -1; - used -= 1; - } - } - if (tail_src.tail >= 0) { - llama_kv_cell & cell_src = cells[tail_src.tail]; - - cell_src.seq_id.insert(seq_id_dst); - tail_dst.tail = tail_src.tail; - } - } - - return; - } - // otherwise, this is the KV of a Transformer-like model head = 0; @@ -296,17 +216,12 @@ void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) { uint32_t new_head = size; for (uint32_t i = 0; i < size; ++i) { - if (recurrent && (llama_seq_id) i != seq_id) { - cells[i].tail = -1; - } - if (!cells[i].has_seq_id(seq_id)) { if (cells[i].pos >= 0) { used--; } cells[i].pos = -1; - cells[i].src = -1; cells[i].seq_id.clear(); if (new_head == size){ @@ -344,20 +259,6 @@ void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_po return; } - if (recurrent) { - // for Mamba-like or RWKV models, only the pos needs to be shifted - if (0 <= seq_id && seq_id < (int64_t) size) { - const int32_t tail_id = cells[seq_id].tail; - if (tail_id >= 0) { - llama_kv_cell & cell = cells[tail_id]; - if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) { - cell.pos += delta; - } - } - } - return; - } - for (uint32_t i = 0; i < size; ++i) { if (cells[i].has_seq_id(seq_id) && cells[i].pos >= p0 && cells[i].pos < p1) { has_shift = true; @@ -400,21 +301,6 @@ void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_po return; } - if (recurrent) { - // for Mamba-like or RWKV models, only the pos needs to be changed - if (0 <= seq_id && seq_id < (int64_t) size) { - const int32_t tail_id = cells[seq_id].tail; - if (tail_id >= 0) { - llama_kv_cell & cell = cells[tail_id]; - if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) { - cell.pos /= d; - } - } - } - - return; - } - for (uint32_t i = 0; i < size; ++i) { if (cells[i].has_seq_id(seq_id) && cells[i].pos >= p0 && cells[i].pos < p1) { has_shift = true; @@ -440,23 +326,11 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const { return result; } -void llama_kv_cache_unified::defrag() { - if (!recurrent) { - do_defrag = true; - } -} - void llama_kv_cache_unified::restore() { if (pending.ranges.empty()) { return; } - // TODO: tmp - move to llama_kv_cache_recurrent - if (recurrent) { - seq_rm(-1, -1, -1); - return; - } - uint32_t new_head = size; for (auto & range : pending.ranges) { @@ -469,7 +343,6 @@ void llama_kv_cache_unified::restore() { } cells[i].pos = -1; - cells[i].src = -1; } new_head = std::min(new_head, range.c0); @@ -481,11 +354,6 @@ void llama_kv_cache_unified::restore() { } void llama_kv_cache_unified::commit() { - // TODO: tmp - move to llama_kv_cache_recurrent - if (recurrent) { - return; - } - if (pending.ranges.empty()) { LLAMA_LOG_WARN("%s: no pending KV cache updates to commit - might indicate a bug (ref: %s)\n", __func__, "https://github.com/ggml-org/llama.cpp/pull/12695"); @@ -495,183 +363,110 @@ void llama_kv_cache_unified::commit() { pending.ranges.clear(); } -bool llama_kv_cache_unified::get_can_shift() const { - return can_shift; -} +bool llama_kv_cache_unified::update(llama_context & lctx) { + bool need_reserve = false; -bool llama_kv_cache_unified::find_slot( - const llama_ubatch & ubatch) { - const uint32_t n_tokens = ubatch.n_tokens; - const uint32_t n_seqs = ubatch.n_seqs; - const uint32_t n_seq_tokens = ubatch.n_seq_tokens; + auto * sched = lctx.get_sched(); - // if we have enough unused cells before the current head -> - // better to start searching from the beginning of the cache, hoping to fill it - if (head > used + 2*ubatch.n_tokens) { - head = 0; - } + if (has_shift) { + if (!get_can_shift()) { + GGML_ABORT("The current KV cache / model configuration does not support K-shift"); + } - if (recurrent) { - // For recurrent state architectures (like Mamba or RWKV), - // each cache cell can store the state for a whole sequence. - // A slot should be always be contiguous. + LLAMA_LOG_DEBUG("%s: applying K-shift\n", __func__); - // can only process batches with an equal number of new tokens in each sequence - GGML_ASSERT(ubatch.equal_seqs); + // apply K-shift if needed + if (hparams.rope_type != LLAMA_ROPE_TYPE_NONE) { + ggml_backend_sched_reset(sched); - int32_t min = size - 1; - int32_t max = 0; + auto * gf = lctx.graph_init(); - // everything should fit if all seq_ids are smaller than the max - for (uint32_t s = 0; s < n_seqs; ++s) { - const uint32_t n_seq_id = ubatch.n_seq_id[s]; - for (uint32_t j = 0; j < n_seq_id; ++j) { - const llama_seq_id seq_id = ubatch.seq_id[s][j]; + auto res = build_graph_shift(lctx.get_cparams(), lctx.get_ctx_compute(), gf); - if (seq_id < 0 || (uint32_t) seq_id >= size) { - // too big seq_id - // TODO: would it be possible to resize the cache instead? - LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, size); - return false; - } - if (j > 0) { - llama_kv_cell & seq = cells[seq_id]; - if (seq.tail >= 0) { - llama_kv_cell & cell = cells[seq.tail]; - // clear cells from seq_ids that become shared - // (should not normally happen, but let's handle it anyway) - cell.seq_id.erase(seq_id); - seq.tail = -1; - if (cell.seq_id.empty()) { - cell.pos = -1; - cell.src = -1; - used -= 1; - } - } - } - } + ggml_backend_sched_alloc_graph(sched, gf); + + res->set_inputs(nullptr); + + lctx.graph_compute(gf, false); + + need_reserve = true; } -#ifndef NDEBUG { - std::vector tails_verif; - tails_verif.assign(size, -1); - for (uint32_t i = 0; i < size; ++i) { - llama_kv_cell & cell = cells[i]; - for (llama_seq_id seq_id : cell.seq_id) { - if (tails_verif[seq_id] != -1) { - LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tails_verif[seq_id]); - } - tails_verif[seq_id] = i; - } - } + has_shift = false; + for (uint32_t i = 0; i < size; ++i) { - if (tails_verif[i] != cells[i].tail) { - LLAMA_LOG_ERROR("%s: wrong tail for seq_id %d, (%d instead of %d)\n", __func__, i, cells[i].tail, tails_verif[i]); - } + cells[i].delta = 0; } } -#endif + } - // find next empty cell - uint32_t next_empty_cell = head; + if (do_defrag) { + LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__); - for (uint32_t i = 0; i < size; ++i) { - if (next_empty_cell >= size) { next_empty_cell -= size; } - llama_kv_cell & cell = cells[next_empty_cell]; - if (cell.is_empty()) { break; } - next_empty_cell += 1; - } + if (defrag_prepare(lctx.graph_max_nodes())) { + ggml_backend_sched_reset(sched); - // find usable cell range - for (uint32_t s = 0; s < n_seqs; ++s) { - const llama_seq_id seq_id = ubatch.seq_id[s][0]; - llama_kv_cell & seq_meta = cells[seq_id]; - bool has_cell = false; - if (seq_meta.tail >= 0) { - llama_kv_cell & cell = cells[seq_meta.tail]; - GGML_ASSERT(cell.has_seq_id(seq_id)); - // does this seq_id "own" the cell? - if (cell.seq_id.size() == 1) { has_cell = true; } - } - if (!has_cell) { - llama_kv_cell & empty_cell = cells[next_empty_cell]; - GGML_ASSERT(empty_cell.is_empty()); - // copy old tail into the empty cell - if (seq_meta.tail >= 0) { - llama_kv_cell & orig_cell = cells[seq_meta.tail]; - empty_cell.pos = orig_cell.pos; - empty_cell.src = orig_cell.src; - orig_cell.seq_id.erase(seq_id); - empty_cell.seq_id.insert(seq_id); // will be overwritten - } - seq_meta.tail = next_empty_cell; - // find next empty cell - if (s + 1 < n_seqs) { - next_empty_cell += 1; - for (uint32_t i = 0; i < size; ++i) { - if (next_empty_cell >= size) { next_empty_cell -= size; } - llama_kv_cell & cell = cells[next_empty_cell]; - if (cell.is_empty()) { break; } - next_empty_cell += 1; - } - } - } - if (min > seq_meta.tail) { min = seq_meta.tail; } - if (max < seq_meta.tail) { max = seq_meta.tail; } - } + auto * gf = lctx.graph_init(); - // gather and re-order - for (uint32_t s = 0; s < n_seqs; ++s) { - int32_t dst_id = s + min; - int32_t src_id = cells[ubatch.seq_id[s][0]].tail; - if (dst_id != src_id) { - llama_kv_cell & dst_cell = cells[dst_id]; - llama_kv_cell & src_cell = cells[src_id]; + auto res = build_graph_defrag(lctx.get_cparams(), lctx.get_ctx_compute(), gf); - std::swap(dst_cell.pos, src_cell.pos); - std::swap(dst_cell.src, src_cell.src); - std::swap(dst_cell.seq_id, src_cell.seq_id); + ggml_backend_sched_alloc_graph(sched, gf); - // swap tails (assuming they NEVER overlap) - for (const llama_seq_id seq_id : src_cell.seq_id) { - cells[seq_id].tail = src_id; - } - for (const llama_seq_id seq_id : dst_cell.seq_id) { - cells[seq_id].tail = dst_id; - } - } - } + res->set_inputs(nullptr); - // update the pos of the used seqs - for (uint32_t s = 0; s < n_seqs; ++s) { - const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1]; - int32_t cell_id = s + min; - llama_kv_cell & cell = cells[cell_id]; + lctx.graph_compute(gf, false); - if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) { - // What should happen when the pos backtracks or skips a value? - // Clearing the state mid-batch would require special-casing which isn't done. - LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n", - __func__, last_pos, cell.pos, ubatch.seq_id[s][0], n_seq_tokens); - } - cell.pos = last_pos; - cell.seq_id.clear(); - for (int32_t j = 0; j < ubatch.n_seq_id[s]; ++j) { - const llama_seq_id seq_id = ubatch.seq_id[s][j]; - cell.seq_id.insert(seq_id); - cells[seq_id].tail = cell_id; - } + need_reserve = true; } - // allow getting the range of used cells, from head to head + n - head = min; - n = max - min + 1; - used = std::count_if(cells.begin(), cells.end(), - [](const llama_kv_cell& cell){ return !cell.is_empty(); }); + do_defrag = false; + } + + return need_reserve; +} + +void llama_kv_cache_unified::defrag_sched(float thold) { + // - do not defrag small contexts (i.e. < 2048 tokens) + // - count the padding towards the number of used tokens + const float fragmentation = n >= 2048 ? std::max(0.0f, 1.0f - (float(used + padding)/n)) : 0.0f; + + // queue defragmentation for next llama_kv_cache_update + if (fragmentation > thold) { + LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation); + + do_defrag = true; + } +} + +void llama_kv_cache_unified::set_full() { + n = size; +} + +llama_sbatch llama_kv_cache_unified::sbatch_init( + const llama_batch & batch, + bool logits_all) { + return llama_sbatch(batch, hparams.n_embd, true, logits_all); +} + +llama_ubatch llama_kv_cache_unified::ubatch_next( + llama_sbatch & sbatch, + uint32_t n_ubatch, + bool embd_pooled) const { + GGML_UNUSED(embd_pooled); + return sbatch.split_simple(n_ubatch); +} + +bool llama_kv_cache_unified::find_slot( + const llama_ubatch & ubatch) { + const uint32_t n_tokens = ubatch.n_tokens; + const uint32_t n_seqs = ubatch.n_seqs; + const uint32_t n_seq_tokens = ubatch.n_seq_tokens; - // sanity check - return n >= n_seqs; + // if we have enough unused cells before the current head -> + // better to start searching from the beginning of the cache, hoping to fill it + if (head > used + 2*ubatch.n_tokens) { + head = 0; } // otherwise, one cell per token. @@ -725,24 +520,50 @@ bool llama_kv_cache_unified::find_slot( pending.ranges.push_back({head, head + n_tokens}); + // a heuristic, to avoid attending the full cache if it is not yet utilized + // after enough generations, the benefit from this heuristic disappears + // if we start defragmenting the cache, the benefit from this will be more important + n = std::min(size, std::max(padding, GGML_PAD(cell_max(), padding))); + + //printf("n = %5d, used = %5d, head = %5d\n", n, used, head); + return true; } -uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) const { - // the FA kernels require padding to avoid extra runtime boundary checks - return cparams.flash_attn ? 256u : 32u; +int32_t llama_kv_cache_unified::get_n_tokens() const { + int32_t result = 0; + + for (uint32_t i = 0; i < size; i++) { + result += cells[i].seq_id.size(); + } + + return result; } -uint32_t llama_kv_cache_unified::cell_max() const { - for (uint32_t i = size; i > 0; --i) { - const llama_kv_cell & cell = cells[i - 1]; +int32_t llama_kv_cache_unified::get_used_cells() const { + return used; +} - if (cell.pos >= 0 && !cell.is_empty()) { - return i; - } +bool llama_kv_cache_unified::get_can_shift() const { + return can_shift; +} + +llama_pos llama_kv_cache_unified::get_pos_max() const { + llama_pos pos_max = -1; + for (const auto & cell : cells) { + pos_max = std::max(pos_max, cell.pos); } - return 0; + return pos_max; +} + +size_t llama_kv_cache_unified::total_size() const { + size_t size = 0; + for (const auto & buf : bufs) { + size += ggml_backend_buffer_get_size(buf.get()); + } + + return size; } size_t llama_kv_cache_unified::size_k_bytes() const { @@ -765,68 +586,331 @@ size_t llama_kv_cache_unified::size_v_bytes() const { return size_v_bytes; } -bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) { - const uint32_t n_layer = hparams.n_layer; +ggml_tensor * llama_kv_cache_unified::build_rope_shift( + const llama_cparams & cparams, + ggml_context * ctx, + ggml_tensor * cur, + ggml_tensor * shift, + ggml_tensor * factors, + float freq_base, + float freq_scale) const { + const auto & n_ctx_orig = cparams.n_ctx_orig_yarn; - const uint32_t n_kv = cell_max(); - const uint32_t n_used = used; + const auto & yarn_ext_factor = cparams.yarn_ext_factor; + const auto & yarn_beta_fast = cparams.yarn_beta_fast; + const auto & yarn_beta_slow = cparams.yarn_beta_slow; - assert(n_used <= n_kv); + const auto & n_rot = hparams.n_rot; + const auto & rope_type = hparams.rope_type; - //const int64_t t_start = ggml_time_us(); + // See llm_build_deepseek2() for why attn_factor has to be scaled for YaRN RoPE to work correctly. + // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation. + const float yarn_attn_factor = model.arch == LLM_ARCH_DEEPSEEK2 ? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)) : cparams.yarn_attn_factor; - // number of cells moved - uint32_t n_moves = 0; + ggml_tensor * tmp; - // each move requires 6*n_layer tensors (see graph_build_kv_self_defrag) - // - source view, destination view, copy operation - // - x2 for keys and values - //const uint32_t max_moves = max_nodes()/(6*n_layer); - // TODO: tmp fix https://github.com/ggerganov/llama.cpp/issues/6685#issuecomment-2057579516 - const uint32_t max_moves = (n_max_nodes - 2*n_layer)/(6*n_layer); + if (ggml_is_quantized(cur->type)) { + // dequantize to f32 -> RoPE -> quantize back + tmp = ggml_cast(ctx, cur, GGML_TYPE_F32); - // determine which KV cells to move where - // - // cell i moves to ids[i] - // - // if ids[i] == i || ids[i] == n_kv, then cell i is not moved - // - auto & ids = defrag_info.ids; + tmp = ggml_rope_ext(ctx, tmp, + shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow); - ids.clear(); - ids.resize(n_kv, n_kv); + tmp = ggml_cpy(ctx, tmp, cur); + } else { + // we rotate only the first n_rot dimensions + tmp = ggml_rope_ext_inplace(ctx, cur, + shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow); + } - for (uint32_t i0 = 0; i0 < n_used; ++i0) { - const auto & cell0 = cells[i0]; + return tmp; +} - if (!cell0.is_empty()) { - ids[i0] = i0; +class llm_graph_input_k_shift : public llm_graph_input_i { +public: + llm_graph_input_k_shift(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {} + virtual ~llm_graph_input_k_shift() = default; - continue; - } + void set_input(const llama_ubatch * ubatch) override; - // found a hole - fill it with data from the end of the cache + ggml_tensor * k_shift; // I32 [kv_size] - uint32_t nh = 1; + const llama_kv_cache_unified * kv_self; +}; - // determine the size of the hole - while (i0 + nh < n_used && cells[i0 + nh].is_empty()) { - nh++; +void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) { + GGML_UNUSED(ubatch); + + if (k_shift) { + assert(ggml_backend_buffer_is_host(k_shift->buffer)); + + int32_t * data = (int32_t *) k_shift->data; + + for (uint32_t i = 0; i < kv_self->size; ++i) { + data[i] = kv_self->cells[i].delta; } + } +} - uint32_t nf = 0; - uint32_t is = n_kv - 1; +llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift( + const llama_cparams & cparams, + ggml_context * ctx, + ggml_cgraph * gf) const { + auto res = std::make_unique(); - // starting from the end, find nh non-empty cells - for (; is > i0; --is) { - const auto & cell1 = cells[is]; + const auto & n_layer = hparams.n_layer; - if (cell1.is_empty() || ids[is] != n_kv) { - continue; - } + const auto & n_embd_head_k = hparams.n_embd_head_k; + //const auto & n_embd_head_v = hparams.n_embd_head_v; - // non-empty cell which is not yet moved - nf++; + const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max; + + //GGML_ASSERT(kv_self->size == n_ctx); + + auto inp = std::make_unique(this); + + inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, cparams.n_ctx); + ggml_set_input(inp->k_shift); + + for (uint32_t il = 0; il < n_layer; ++il) { + const int64_t n_head_kv = hparams.n_head_kv(il); + const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); + + const bool is_swa = hparams.is_swa(il); + + // note: the swa rope params could become part of the cparams in the future + // if we decide to make them configurable, like the non-sliding ones + const float freq_base_l = is_swa ? hparams.rope_freq_base_train_swa : cparams.rope_freq_base; + const float freq_scale_l = is_swa ? hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale; + + ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il); + + ggml_tensor * k = + ggml_view_3d(ctx, k_l[il], + n_embd_head_k, n_head_kv, size, + ggml_row_size(k_l[il]->type, n_embd_head_k), + ggml_row_size(k_l[il]->type, n_embd_k_gqa), + 0); + + ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l); + + ggml_build_forward_expand(gf, cur); + } + + res->add_input(std::move(inp)); + + return res; +} + +llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag( + const llama_cparams & cparams, + ggml_context * ctx, + ggml_cgraph * gf) const { + auto res = std::make_unique(); + + const auto & ids = defrag_info.ids; + +#if 0 + // CPU defrag + // + // TODO: optimizations are possible: + // - multiple threads + // - avoid copying to the host memory when already there + // + // likely not worth the effort, as we have ggml_graph based defrag + // + + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(); + + const uint32_t kv_size = size; + + std::vector buf_k; + std::vector buf_v; + + for (uint32_t il = 0; il < n_layer; ++il) { + const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa); + const size_t k_size = ggml_row_size(k_l[il]->type, n_embd_k_gqa*kv_size); + + const size_t v_size_el = ggml_type_size(v_l[il]->type); + const size_t v_size = ggml_row_size (v_l[il]->type, n_embd_v_gqa*kv_size); + + buf_k.resize(k_size); + buf_v.resize(v_size); + + ggml_backend_tensor_get(k_l[il], buf_k.data(), 0, buf_k.size()); + ggml_backend_tensor_get(v_l[il], buf_v.data(), 0, buf_v.size()); + + // batch move [i, i+nm) to [id, id+nm) + // note: cells can move only to a lower index + for (uint32_t i = 0; i < n_kv; ++i) { + const uint32_t id = ids[i]; + + if (i == id || id == n_kv) { + continue; + } + + uint32_t nm = 1; + + while (i + nm < n_kv && ids[i + nm] == id + nm) { + nm++; + } + + // move keys + { + const int64_t os = i*k_size_row; + const int64_t od = id*k_size_row; + + memcpy(buf_k.data() + od, buf_k.data() + os, nm*k_size_row); + } + + // move values (note: they are transposed) + { + const int64_t os = i; + const int64_t od = id; + + for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { + memcpy(buf_v.data() + (od + j*kv_size)*v_size_el, buf_v.data() + (os + j*kv_size)*v_size_el, nm*v_size_el); + } + } + + i += nm - 1; + } + + ggml_backend_tensor_set(k_l[il], buf_k.data(), 0, buf_k.size()); + ggml_backend_tensor_set(v_l[il], buf_v.data(), 0, buf_v.size()); + } +#else + for (uint32_t i = 0; i < ids.size(); ++i) { + const uint32_t id = ids[i]; + + if (i == id || id == ids.size()) { + continue; + } + + uint32_t nm = 1; + + while (i + nm < ids.size() && ids[i + nm] == id + nm) { + nm++; + } + + for (uint32_t il = 0; il < hparams.n_layer; ++il) { // NOLINT + const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); + const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); + + ggml_tensor * view_k_src = ggml_view_2d(ctx, k_l[il], + n_embd_k_gqa, nm, + ggml_row_size(k_l[il]->type, n_embd_k_gqa), + ggml_row_size(k_l[il]->type, n_embd_k_gqa*i)); + + ggml_tensor * view_k_dst = ggml_view_2d(ctx, k_l[il], + n_embd_k_gqa, nm, + ggml_row_size(k_l[il]->type, n_embd_k_gqa), + ggml_row_size(k_l[il]->type, n_embd_k_gqa*id)); + + ggml_tensor * view_v_src; + ggml_tensor * view_v_dst; + + if (cparams.flash_attn) { + // NOTE: the V cache is not transposed when using flash attention + view_v_src = ggml_view_2d(ctx, v_l[il], + n_embd_v_gqa, nm, + ggml_row_size(v_l[il]->type, n_embd_v_gqa), + ggml_row_size(v_l[il]->type, n_embd_v_gqa*i)); + + view_v_dst = ggml_view_2d(ctx, v_l[il], + n_embd_v_gqa, nm, + ggml_row_size(v_l[il]->type, n_embd_v_gqa), + ggml_row_size(v_l[il]->type, n_embd_v_gqa*id)); + } else { + view_v_src = ggml_view_2d(ctx, v_l[il], + nm, n_embd_v_gqa, + ggml_row_size(v_l[il]->type, size), + ggml_row_size(v_l[il]->type, i)); + + view_v_dst = ggml_view_2d(ctx, v_l[il], + nm, n_embd_v_gqa, + ggml_row_size(v_l[il]->type, size), + ggml_row_size(v_l[il]->type, id)); + } + + ggml_build_forward_expand(gf, ggml_cpy(ctx, view_k_src, view_k_dst)); + ggml_build_forward_expand(gf, ggml_cpy(ctx, view_v_src, view_v_dst)); + } + + i += nm - 1; + } + + //LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes); +#endif + + return res; +} + +bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) { + const uint32_t n_layer = hparams.n_layer; + + const uint32_t n_kv = cell_max(); + const uint32_t n_used = used; + + assert(n_used <= n_kv); + + //const int64_t t_start = ggml_time_us(); + + // number of cells moved + uint32_t n_moves = 0; + + // each move requires 6*n_layer tensors (see graph_build_kv_self_defrag) + // - source view, destination view, copy operation + // - x2 for keys and values + //const uint32_t max_moves = max_nodes()/(6*n_layer); + // TODO: tmp fix https://github.com/ggerganov/llama.cpp/issues/6685#issuecomment-2057579516 + const uint32_t max_moves = (n_max_nodes - 2*n_layer)/(6*n_layer); + + // determine which KV cells to move where + // + // cell i moves to ids[i] + // + // if ids[i] == i || ids[i] == n_kv, then cell i is not moved + // + auto & ids = defrag_info.ids; + + ids.clear(); + ids.resize(n_kv, n_kv); + + for (uint32_t i0 = 0; i0 < n_used; ++i0) { + const auto & cell0 = cells[i0]; + + if (!cell0.is_empty()) { + ids[i0] = i0; + + continue; + } + + // found a hole - fill it with data from the end of the cache + + uint32_t nh = 1; + + // determine the size of the hole + while (i0 + nh < n_used && cells[i0 + nh].is_empty()) { + nh++; + } + + uint32_t nf = 0; + uint32_t is = n_kv - 1; + + // starting from the end, find nh non-empty cells + for (; is > i0; --is) { + const auto & cell1 = cells[is]; + + if (cell1.is_empty() || ids[is] != n_kv) { + continue; + } + + // non-empty cell which is not yet moved + nf++; if (nf == nh) { break; @@ -867,7 +951,7 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) { cells[i0 + nf] = cell1; // clear the old cell and move the head there - cell1 = llama_kv_cell(); + cell1 = kv_cell(); head = n_used; if (!cont) { @@ -895,13 +979,25 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) { return false; } - LLAMA_LOG_DEBUG("(tmp log) KV defrag cell moves: %u\n", n_moves); + LLAMA_LOG_DEBUG("%s: (tmp log) KV defrag cell moves: %u\n", __func__, n_moves); - LLAMA_LOG_DEBUG("expected gf nodes: %u\n", 6*n_moves*n_layer); + LLAMA_LOG_DEBUG("%s: expected gf nodes: %u\n", __func__, 6*n_moves*n_layer); return true; } +uint32_t llama_kv_cache_unified::cell_max() const { + for (uint32_t i = size; i > 0; --i) { + const kv_cell & cell = cells[i - 1]; + + if (cell.pos >= 0 && !cell.is_empty()) { + return i; + } + } + + return 0; +} + void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id) const { std::vector> cell_ranges; // ranges, from inclusive, to exclusive uint32_t cell_count = 0; @@ -1110,7 +1206,7 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell clear(); for (uint32_t i = 0; i < cell_count; ++i) { - llama_kv_cell & cell = cells[i]; + kv_cell & cell = cells[i]; llama_pos pos; uint32_t n_seq_id; @@ -1133,15 +1229,6 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell } cell.seq_id.insert(seq_id); - - if (recurrent) { - int32_t & tail = cells[seq_id].tail; - if (tail != -1) { - LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tail); - return false; - } - tail = i; - } } } @@ -1149,14 +1236,6 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell used = cell_count; } - if (recurrent) { - for (uint32_t i = 0; i < cell_count; ++i) { - uint32_t cell_id = head + i; - // make sure the recurrent states will keep their restored state - cells[cell_id].src = cell_id; - } - } - return true; } @@ -1174,7 +1253,1034 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, size); return false; } - if (v_trans != (bool) v_trans) { + if (this->v_trans != (bool) v_trans) { + LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__); + return false; + } + + // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block + for (uint32_t il = 0; il < n_layer; ++il) { + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); + + // Read type of key + int32_t k_type_i_ref; + io.read_to(&k_type_i_ref, sizeof(k_type_i_ref)); + const int32_t k_type_i = (int32_t) k_l[il]->type; + if (k_type_i != k_type_i_ref) { + LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il); + return false; + } + + // Read row size of key + uint64_t k_size_row_ref; + io.read_to(&k_size_row_ref, sizeof(k_size_row_ref)); + const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa); + if (k_size_row != k_size_row_ref) { + LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il); + return false; + } + + if (cell_count) { + // Read and set the keys for the whole cell range + ggml_backend_tensor_set(k_l[il], io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row); + } + } + + if (!this->v_trans) { + for (uint32_t il = 0; il < n_layer; ++il) { + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + + // Read type of value + int32_t v_type_i_ref; + io.read_to(&v_type_i_ref, sizeof(v_type_i_ref)); + const int32_t v_type_i = (int32_t)v_l[il]->type; + if (v_type_i != v_type_i_ref) { + LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); + return false; + } + + // Read row size of value + uint64_t v_size_row_ref; + io.read_to(&v_size_row_ref, sizeof(v_size_row_ref)); + const size_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa); + if (v_size_row != v_size_row_ref) { + LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il); + return false; + } + + if (cell_count) { + // Read and set the values for the whole cell range + ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row); + } + } + } else { + // For each layer, read the values for each cell (transposed) + for (uint32_t il = 0; il < n_layer; ++il) { + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + + // Read type of value + int32_t v_type_i_ref; + io.read_to(&v_type_i_ref, sizeof(v_type_i_ref)); + const int32_t v_type_i = (int32_t)v_l[il]->type; + if (v_type_i != v_type_i_ref) { + LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); + return false; + } + + // Read element size of value + uint32_t v_size_el_ref; + io.read_to(&v_size_el_ref, sizeof(v_size_el_ref)); + const size_t v_size_el = ggml_type_size(v_l[il]->type); + if (v_size_el != v_size_el_ref) { + LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il); + return false; + } + + // Read GQA embedding size + uint32_t n_embd_v_gqa_ref; + io.read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref)); + if (n_embd_v_gqa != n_embd_v_gqa_ref) { + LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il); + return false; + } + + if (cell_count) { + // For each row in the transposed matrix, read the values for the whole cell range + for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { + const size_t dst_offset = (head + j * size) * v_size_el; + ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el); + } + } + } + } + + return true; +} + +// +// llama_kv_cache_recurrent +// + +llama_kv_cache_recurrent::llama_kv_cache_recurrent( + const llama_model & model, + ggml_type type_k, + ggml_type type_v, + bool offload, + uint32_t kv_size) : hparams(model.hparams) { + const int32_t n_layer = hparams.n_layer; + + LLAMA_LOG_INFO("%s: kv_size = %d, type_k = '%s', type_v = '%s', n_layer = %d\n", + __func__, kv_size, ggml_type_name(type_k), ggml_type_name(type_v), n_layer); + + head = 0; + size = kv_size; + used = 0; + + this->type_k = type_k; + this->type_v = type_v; + + cells.clear(); + cells.resize(kv_size); + + // create a context for each buffer type + std::map ctx_map; + auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * { + auto it = ctx_map.find(buft); + if (it == ctx_map.end()) { + ggml_init_params params = { + /*.mem_size =*/ size_t(2u*n_layer*ggml_tensor_overhead()), + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + + ggml_context * ctx = ggml_init(params); + if (!ctx) { + return nullptr; + } + + ctx_map[buft] = ctx; + ctxs.emplace_back(ctx); + + return ctx; + } + + return it->second; + }; + + k_l.reserve(n_layer); + v_l.reserve(n_layer); + + for (int i = 0; i < n_layer; i++) { + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(); + + const char * dev_name = "CPU"; + + ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type(); + + if (offload) { + auto * dev = model.dev_layer(i); + buft = ggml_backend_dev_buffer_type(dev); + + dev_name = ggml_backend_dev_name(dev); + } + + LLAMA_LOG_DEBUG("%s, layer %3d: dev = %s\n", __func__, i, dev_name); + + ggml_context * ctx = ctx_for_buft(buft); + if (!ctx) { + throw std::runtime_error("failed to create ggml context for kv cache"); + } + + ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size); + ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size); + ggml_format_name(k, "cache_k_l%d", i); + ggml_format_name(v, "cache_v_l%d", i); + k_l.push_back(k); + v_l.push_back(v); + } + + // allocate tensors and initialize the buffers to avoid NaNs in the padding + for (auto it : ctx_map) { + auto * buft = it.first; + auto * ctx = it.second; + + ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); + if (!buf) { + throw std::runtime_error("failed to allocate buffer for kv cache"); + } + ggml_backend_buffer_clear(buf, 0); + LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0); + bufs.emplace_back(buf); + } + + { + const size_t memory_size_k = size_k_bytes(); + const size_t memory_size_v = size_v_bytes(); + + LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__, + (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), + ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f), + ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f)); + } +} + +void llama_kv_cache_recurrent::clear() { + for (int32_t i = 0; i < (int32_t) size; ++i) { + cells[i].pos = -1; + cells[i].seq_id.clear(); + cells[i].src = -1; + cells[i].tail = -1; + } + head = 0; + used = 0; + + for (auto & buf : bufs) { + ggml_backend_buffer_clear(buf.get(), 0); + } +} + +bool llama_kv_cache_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + uint32_t new_head = size; + + if (p0 < 0) { + p0 = 0; + } + + if (p1 < 0) { + p1 = std::numeric_limits::max(); + } + + // models like Mamba or RWKV can't have a state partially erased + if (seq_id >= (int64_t) size) { + // could be fatal + return false; + } + if (0 <= seq_id) { + int32_t & tail_id = cells[seq_id].tail; + if (tail_id >= 0) { + const kv_cell & cell = cells[tail_id]; + // partial intersection is invalid + if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) { + return false; + } + // invalidate tails which will be cleared + if (p0 <= cell.pos && cell.pos < p1) { + tail_id = -1; + } + } + } else { + // seq_id is negative, then the range should include everything or nothing + if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits::max())) { + return false; + } + } + + for (uint32_t i = 0; i < size; ++i) { + if (cells[i].pos >= p0 && cells[i].pos < p1) { + if (seq_id < 0) { + cells[i].seq_id.clear(); + } else if (cells[i].has_seq_id(seq_id)) { + cells[i].seq_id.erase(seq_id); + } else { + continue; + } + if (cells[i].is_empty()) { + // keep count of the number of used cells + if (cells[i].pos >= 0) { + used--; + } + cells[i].pos = -1; + cells[i].src = -1; + if (new_head == size) { + new_head = i; + } + } + } + } + + // If we freed up a slot, set head to it so searching can start there. + if (new_head != size && new_head < head) { + head = new_head; + } + + return true; +} + +void llama_kv_cache_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { + if (seq_id_src == seq_id_dst) { + return; + } + + if (p0 < 0) { + p0 = 0; + } + + if (p1 < 0) { + p1 = std::numeric_limits::max(); + } + + if ((uint32_t) seq_id_dst < size && (uint32_t) seq_id_src < size) { + kv_cell & tail_src = cells[seq_id_src]; + kv_cell & tail_dst = cells[seq_id_dst]; + if (tail_dst.tail >= 0) { + // clear destination seq_id if it wasn't empty + kv_cell & cell_dst = cells[tail_dst.tail]; + + cell_dst.seq_id.erase(seq_id_dst); + tail_dst.tail = -1; + if (cell_dst.seq_id.empty()) { + cell_dst.pos = -1; + cell_dst.src = -1; + used -= 1; + } + } + if (tail_src.tail >= 0) { + kv_cell & cell_src = cells[tail_src.tail]; + + cell_src.seq_id.insert(seq_id_dst); + tail_dst.tail = tail_src.tail; + } + } +} + +void llama_kv_cache_recurrent::seq_keep(llama_seq_id seq_id) { + uint32_t new_head = size; + + for (uint32_t i = 0; i < size; ++i) { + if ((llama_seq_id) i != seq_id) { + cells[i].tail = -1; + } + + if (!cells[i].has_seq_id(seq_id)) { + if (cells[i].pos >= 0) { + used--; + } + + cells[i].pos = -1; + cells[i].src = -1; + cells[i].seq_id.clear(); + + if (new_head == size){ + new_head = i; + } + } else { + cells[i].seq_id.clear(); + cells[i].seq_id.insert(seq_id); + } + } + + // If we freed up a slot, set head to it so searching can start there. + if (new_head != size && new_head < head) { + head = new_head; + } +} + +void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) { + if (delta == 0) { + return; + } + + if (p0 < 0) { + p0 = 0; + } + + if (p1 < 0) { + p1 = std::numeric_limits::max(); + } + + // If there is no range then return early to avoid looping over the + if (p0 == p1) { + return; + } + + // for Mamba-like or RWKV models, only the pos needs to be shifted + if (0 <= seq_id && seq_id < (int64_t) size) { + const int32_t tail_id = cells[seq_id].tail; + if (tail_id >= 0) { + kv_cell & cell = cells[tail_id]; + if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) { + cell.pos += delta; + } + } + } +} + +void llama_kv_cache_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { + if (d == 1) { + return; + } + + if (p0 < 0) { + p0 = 0; + } + + if (p1 < 0) { + p1 = std::numeric_limits::max(); + } + + // If there is no range then return early to avoid looping over the cache. + if (p0 == p1) { + return; + } + + // for Mamba-like or RWKV models, only the pos needs to be changed + if (0 <= seq_id && seq_id < (int64_t) size) { + const int32_t tail_id = cells[seq_id].tail; + if (tail_id >= 0) { + kv_cell & cell = cells[tail_id]; + if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) { + cell.pos /= d; + } + } + } +} + +llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const { + llama_pos result = 0; + + for (uint32_t i = 0; i < size; ++i) { + if (cells[i].has_seq_id(seq_id)) { + result = std::max(result, cells[i].pos); + } + } + + return result; +} + +void llama_kv_cache_recurrent::restore() { + if (pending.ranges.empty()) { + return; + } + + seq_rm(-1, -1, -1); +} + +void llama_kv_cache_recurrent::commit() { + pending.ranges.clear(); +} + +bool llama_kv_cache_recurrent::update(llama_context & lctx) { + GGML_UNUSED(lctx); + return false; +} + +void llama_kv_cache_recurrent::defrag_sched(float thold) { + GGML_UNUSED(thold); + // noop +} + +void llama_kv_cache_recurrent::set_full() { + n = size; +} + +llama_sbatch llama_kv_cache_recurrent::sbatch_init( + const llama_batch & batch, + bool logits_all) { + return llama_sbatch(batch, hparams.n_embd, false, logits_all); +} + +llama_ubatch llama_kv_cache_recurrent::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const { + if (embd_pooled) { + // Pooled embeddings cannot be split across ubatches (yet) + return sbatch.split_seq(n_ubatch); + } + + return sbatch.split_equal(n_ubatch); +} + +bool llama_kv_cache_recurrent::find_slot( + const llama_ubatch & ubatch) { + const uint32_t n_tokens = ubatch.n_tokens; + const uint32_t n_seqs = ubatch.n_seqs; + + const uint32_t n_seq_tokens = ubatch.n_seq_tokens; + + // if we have enough unused cells before the current head -> + // better to start searching from the beginning of the cache, hoping to fill it + if (head > used + 2*n_tokens) { + head = 0; + } + + // For recurrent state architectures (like Mamba or RWKV), + // each cache cell can store the state for a whole sequence. + // A slot should be always be contiguous. + + // can only process batches with an equal number of new tokens in each sequence + GGML_ASSERT(ubatch.equal_seqs); + + int32_t min = size - 1; + int32_t max = 0; + + // everything should fit if all seq_ids are smaller than the max + for (uint32_t s = 0; s < n_seqs; ++s) { + const uint32_t n_seq_id = ubatch.n_seq_id[s]; + for (uint32_t j = 0; j < n_seq_id; ++j) { + const llama_seq_id seq_id = ubatch.seq_id[s][j]; + + if (seq_id < 0 || (uint32_t) seq_id >= size) { + // too big seq_id + // TODO: would it be possible to resize the cache instead? + LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, size); + return false; + } + if (j > 0) { + kv_cell & seq = cells[seq_id]; + if (seq.tail >= 0) { + kv_cell & cell = cells[seq.tail]; + // clear cells from seq_ids that become shared + // (should not normally happen, but let's handle it anyway) + cell.seq_id.erase(seq_id); + seq.tail = -1; + if (cell.seq_id.empty()) { + cell.pos = -1; + cell.src = -1; + used -= 1; + } + } + } + } + } + +#ifndef NDEBUG + { + std::vector tails_verif; + tails_verif.assign(size, -1); + for (uint32_t i = 0; i < size; ++i) { + kv_cell & cell = cells[i]; + for (llama_seq_id seq_id : cell.seq_id) { + if (tails_verif[seq_id] != -1) { + LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tails_verif[seq_id]); + } + tails_verif[seq_id] = i; + } + } + for (uint32_t i = 0; i < size; ++i) { + if (tails_verif[i] != cells[i].tail) { + LLAMA_LOG_ERROR("%s: wrong tail for seq_id %d, (%d instead of %d)\n", __func__, i, cells[i].tail, tails_verif[i]); + } + } + } +#endif + + // find next empty cell + uint32_t next_empty_cell = head; + + for (uint32_t i = 0; i < size; ++i) { + if (next_empty_cell >= size) { next_empty_cell -= size; } + kv_cell & cell = cells[next_empty_cell]; + if (cell.is_empty()) { break; } + next_empty_cell += 1; + } + + // find usable cell range + for (uint32_t s = 0; s < n_seqs; ++s) { + const llama_seq_id seq_id = ubatch.seq_id[s][0]; + kv_cell & seq_meta = cells[seq_id]; + bool has_cell = false; + if (seq_meta.tail >= 0) { + kv_cell & cell = cells[seq_meta.tail]; + GGML_ASSERT(cell.has_seq_id(seq_id)); + // does this seq_id "own" the cell? + if (cell.seq_id.size() == 1) { has_cell = true; } + } + if (!has_cell) { + kv_cell & empty_cell = cells[next_empty_cell]; + GGML_ASSERT(empty_cell.is_empty()); + // copy old tail into the empty cell + if (seq_meta.tail >= 0) { + kv_cell & orig_cell = cells[seq_meta.tail]; + empty_cell.pos = orig_cell.pos; + empty_cell.src = orig_cell.src; + orig_cell.seq_id.erase(seq_id); + empty_cell.seq_id.insert(seq_id); // will be overwritten + } + seq_meta.tail = next_empty_cell; + // find next empty cell + if (s + 1 < n_seqs) { + next_empty_cell += 1; + for (uint32_t i = 0; i < size; ++i) { + if (next_empty_cell >= size) { next_empty_cell -= size; } + kv_cell & cell = cells[next_empty_cell]; + if (cell.is_empty()) { break; } + next_empty_cell += 1; + } + } + } + if (min > seq_meta.tail) { min = seq_meta.tail; } + if (max < seq_meta.tail) { max = seq_meta.tail; } + } + + // gather and re-order + for (uint32_t s = 0; s < n_seqs; ++s) { + int32_t dst_id = s + min; + int32_t src_id = cells[ubatch.seq_id[s][0]].tail; + if (dst_id != src_id) { + kv_cell & dst_cell = cells[dst_id]; + kv_cell & src_cell = cells[src_id]; + + std::swap(dst_cell.pos, src_cell.pos); + std::swap(dst_cell.src, src_cell.src); + std::swap(dst_cell.seq_id, src_cell.seq_id); + + // swap tails (assuming they NEVER overlap) + for (const llama_seq_id seq_id : src_cell.seq_id) { + cells[seq_id].tail = src_id; + } + for (const llama_seq_id seq_id : dst_cell.seq_id) { + cells[seq_id].tail = dst_id; + } + } + } + + // update the pos of the used seqs + for (uint32_t s = 0; s < n_seqs; ++s) { + const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1]; + int32_t cell_id = s + min; + kv_cell & cell = cells[cell_id]; + + if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) { + // What should happen when the pos backtracks or skips a value? + // Clearing the state mid-batch would require special-casing which isn't done. + LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n", + __func__, last_pos, cell.pos, ubatch.seq_id[s][0], n_seq_tokens); + } + cell.pos = last_pos; + cell.seq_id.clear(); + for (int32_t j = 0; j < ubatch.n_seq_id[s]; ++j) { + const llama_seq_id seq_id = ubatch.seq_id[s][j]; + cell.seq_id.insert(seq_id); + cells[seq_id].tail = cell_id; + } + } + + // allow getting the range of used cells, from head to head + n + head = min; + n = max - min + 1; + used = std::count_if(cells.begin(), cells.end(), + [](const kv_cell & cell){ return !cell.is_empty(); }); + + // sanity check + return n >= n_seqs; +} + +int32_t llama_kv_cache_recurrent::get_n_tokens() const { + int32_t result = 0; + + for (uint32_t i = 0; i < size; i++) { + result += cells[i].seq_id.size(); + } + + return result; +} + +int32_t llama_kv_cache_recurrent::get_used_cells() const { + return used; +} + +llama_pos llama_kv_cache_recurrent::get_pos_max() const { + llama_pos pos_max = -1; + for (const auto & cell : cells) { + pos_max = std::max(pos_max, cell.pos); + } + + return pos_max; +} + +bool llama_kv_cache_recurrent::get_can_shift() const { + return false; +} + +int32_t llama_kv_cache_recurrent::s_copy(int i) const { + const uint32_t cell_id = i + head; + + ////////////////////////////////////////////// + // TODO: this should not mutate the KV cache ! + kv_cell & cell = const_cast(cells[cell_id]); + + // prevent out-of-bound sources + if (cell.src < 0 || (uint32_t) cell.src >= size) { + cell.src = cell_id; + } + + int32_t res = cell.src; + + // TODO: do not mutate the KV cache + // ensure copy only happens once + if (cell.src != (int32_t) cell_id) { + cell.src = cell_id; + } + + return res; +} + +float llama_kv_cache_recurrent::s_mask(int i) const { + const uint32_t cell_id = i + head; + + ////////////////////////////////////////////// + // TODO: this should not mutate the KV cache ! + kv_cell & cell = const_cast(cells[cell_id]); + + float res = (float) (cell.src >= 0); + + // only clear once + if (cell.src < 0) { + cell.src = cell_id; + } + + return res; +} + +uint32_t llama_kv_cache_recurrent::cell_max() const { + for (uint32_t i = size; i > 0; --i) { + const kv_cell & cell = cells[i - 1]; + + if (cell.pos >= 0 && !cell.is_empty()) { + return i; + } + } + + return 0; +} + +size_t llama_kv_cache_recurrent::total_size() const { + size_t size = 0; + for (const auto & buf : bufs) { + size += ggml_backend_buffer_get_size(buf.get()); + } + + return size; +} + +size_t llama_kv_cache_recurrent::size_k_bytes() const { + size_t size_k_bytes = 0; + + for (const auto & k : k_l) { + size_k_bytes += ggml_nbytes(k); + } + + return size_k_bytes; +} + +size_t llama_kv_cache_recurrent::size_v_bytes() const { + size_t size_v_bytes = 0; + + for (const auto & v : v_l) { + size_v_bytes += ggml_nbytes(v); + } + + return size_v_bytes; +} + +void llama_kv_cache_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id) const { + std::vector> cell_ranges; // ranges, from inclusive, to exclusive + uint32_t cell_count = 0; + + // Count the number of cells with the specified seq_id + // Find all the ranges of cells with this seq id (or all, when -1) + uint32_t cell_range_begin = size; + for (uint32_t i = 0; i < size; ++i) { + const auto & cell = cells[i]; + if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) { + ++cell_count; + if (cell_range_begin == size) { + cell_range_begin = i; + } + } else { + if (cell_range_begin != size) { + cell_ranges.emplace_back(cell_range_begin, i); + cell_range_begin = size; + } + } + } + if (cell_range_begin != size) { + cell_ranges.emplace_back(cell_range_begin, size); + } + + // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count + uint32_t cell_count_check = 0; + for (const auto & range : cell_ranges) { + cell_count_check += range.second - range.first; + } + GGML_ASSERT(cell_count == cell_count_check); + + io.write(&cell_count, sizeof(cell_count)); + + state_write_meta(io, cell_ranges, seq_id); + state_write_data(io, cell_ranges); +} + +void llama_kv_cache_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id) { + uint32_t cell_count; + io.read_to(&cell_count, sizeof(cell_count)); + + bool res = true; + res = res && state_read_meta(io, cell_count, seq_id); + res = res && state_read_data(io, cell_count); + + if (!res) { + if (seq_id == -1) { + clear(); + } else { + seq_rm(seq_id, -1, -1); + } + throw std::runtime_error("failed to restore kv cache"); + } +} + +void llama_kv_cache_recurrent::state_write_meta(llama_io_write_i & io, const std::vector> & cell_ranges, llama_seq_id seq_id) const { + for (const auto & range : cell_ranges) { + for (uint32_t i = range.first; i < range.second; ++i) { + const auto & cell = cells[i]; + const llama_pos pos = cell.pos; + const uint32_t n_seq_id = seq_id == -1 ? cell.seq_id.size() : 0; + + io.write(&pos, sizeof(pos)); + io.write(&n_seq_id, sizeof(n_seq_id)); + + if (n_seq_id) { + for (auto seq_id : cell.seq_id) { + io.write(&seq_id, sizeof(seq_id)); + } + } + } + } +} + +void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std::vector> & cell_ranges) const { + const uint32_t v_trans = 0; + const uint32_t n_layer = hparams.n_layer; + + io.write(&v_trans, sizeof(v_trans)); + io.write(&n_layer, sizeof(n_layer)); + + std::vector tmp_buf; + + // Iterate and write all the keys first, each row is a cell + // Get whole range at a time + for (uint32_t il = 0; il < n_layer; ++il) { + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); + + // Write key type + const int32_t k_type_i = (int32_t)k_l[il]->type; + io.write(&k_type_i, sizeof(k_type_i)); + + // Write row size of key + const uint64_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa); + io.write(&k_size_row, sizeof(k_size_row)); + + // Read each range of cells of k_size length each into tmp_buf and write out + for (const auto & range : cell_ranges) { + const size_t range_size = range.second - range.first; + const size_t buf_size = range_size * k_size_row; + io.write_tensor(k_l[il], range.first * k_size_row, buf_size); + } + } + + if (!v_trans) { + for (uint32_t il = 0; il < n_layer; ++il) { + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + + // Write value type + const int32_t v_type_i = (int32_t)v_l[il]->type; + io.write(&v_type_i, sizeof(v_type_i)); + + // Write row size of value + const uint64_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa); + io.write(&v_size_row, sizeof(v_size_row)); + + // Read each range of cells of v_size length each into tmp_buf and write out + for (const auto & range : cell_ranges) { + const size_t range_size = range.second - range.first; + const size_t buf_size = range_size * v_size_row; + io.write_tensor(v_l[il], range.first * v_size_row, buf_size); + } + } + } else { + // When v is transposed, we also need the element size and get the element ranges from each row + const uint32_t kv_size = size; + for (uint32_t il = 0; il < n_layer; ++il) { + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + + // Write value type + const int32_t v_type_i = (int32_t)v_l[il]->type; + io.write(&v_type_i, sizeof(v_type_i)); + + // Write element size + const uint32_t v_size_el = ggml_type_size(v_l[il]->type); + io.write(&v_size_el, sizeof(v_size_el)); + + // Write GQA embedding size + io.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa)); + + // For each row, we get the element values of each cell + for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { + // Read each range of cells of v_size_el length each into tmp_buf and write out + for (const auto & range : cell_ranges) { + const size_t range_size = range.second - range.first; + const size_t src_offset = (range.first + j * kv_size) * v_size_el; + const size_t buf_size = range_size * v_size_el; + io.write_tensor(v_l[il], src_offset, buf_size); + } + } + } + } +} + +bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) { + if (dest_seq_id != -1) { + // single sequence + + seq_rm(dest_seq_id, -1, -1); + + llama_sbatch sbatch; + llama_ubatch batch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false); + + batch.n_tokens = cell_count; + batch.n_seq_tokens = cell_count; + batch.n_seqs = 1; + + for (uint32_t i = 0; i < cell_count; ++i) { + llama_pos pos; + uint32_t n_seq_id; + + io.read_to(&pos, sizeof(pos)); + io.read_to(&n_seq_id, sizeof(n_seq_id)); + + if (n_seq_id != 0) { + LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__); + return false; + } + + batch.pos[i] = pos; + } + batch.n_seq_id[0] = 1; + batch.seq_id[0] = &dest_seq_id; + if (!find_slot(batch)) { + LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__); + return false; + } + commit(); + + // DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values) + // Assume that this is one contiguous block of cells + GGML_ASSERT(head + cell_count <= size); + GGML_ASSERT(cells[head].pos == batch.pos[0]); + GGML_ASSERT(cells[head + cell_count - 1].pos == batch.pos[cell_count - 1]); + GGML_ASSERT(cells[head].has_seq_id(dest_seq_id)); + GGML_ASSERT(cells[head + cell_count - 1].has_seq_id(dest_seq_id)); + } else { + // whole KV cache restore + + if (cell_count > size) { + LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__); + return false; + } + + clear(); + + for (uint32_t i = 0; i < cell_count; ++i) { + kv_cell & cell = cells[i]; + + llama_pos pos; + uint32_t n_seq_id; + + io.read_to(&pos, sizeof(pos)); + io.read_to(&n_seq_id, sizeof(n_seq_id)); + + cell.pos = pos; + + for (uint32_t j = 0; j < n_seq_id; ++j) { + llama_seq_id seq_id; + io.read_to(&seq_id, sizeof(seq_id)); + + // TODO: llama_kv_cache_recurrent should have a notion of max sequences + //if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { + if (seq_id < 0) { + //LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx)); + LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, inf)\n", __func__, seq_id); + return false; + } + + cell.seq_id.insert(seq_id); + + int32_t & tail = cells[seq_id].tail; + if (tail != -1) { + LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tail); + return false; + } + tail = i; + } + } + + head = 0; + used = cell_count; + } + + for (uint32_t i = 0; i < cell_count; ++i) { + uint32_t cell_id = head + i; + // make sure the recurrent states will keep their restored state + cells[cell_id].src = cell_id; + } + + return true; +} + +bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell_count) { + uint32_t v_trans; + uint32_t n_layer; + io.read_to(&v_trans, sizeof(v_trans)); + io.read_to(&n_layer, sizeof(n_layer)); + + if (n_layer != hparams.n_layer) { + LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, hparams.n_layer); + return false; + } + if (cell_count > size) { + LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, size); + return false; + } + if (false != (bool) v_trans) { LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__); return false; } @@ -1326,7 +2432,7 @@ void llama_kv_cache_view_update(llama_kv_cache_view * view, const llama_kv_cache view->cells_sequences = (llama_seq_id *)p; } - const std::vector & kv_cells = kvu->cells; + const std::vector & kv_cells = kvu->cells; llama_kv_cache_view_cell * c_curr = view->cells; llama_seq_id * cs_curr = view->cells_sequences; int32_t used_cells = 0; diff --git a/examples/talk-llama/llama-kv-cache.h b/examples/talk-llama/llama-kv-cache.h index 56c74035ae1..bf3b4b6a443 100644 --- a/examples/talk-llama/llama-kv-cache.h +++ b/examples/talk-llama/llama-kv-cache.h @@ -2,32 +2,72 @@ #include "llama.h" #include "llama-io.h" +#include "llama-graph.h" #include "llama-memory.h" #include "ggml-cpp.h" -#include #include #include struct llama_cparams; struct llama_hparams; struct llama_ubatch; +struct llama_sbatch; +struct llama_model; +struct llama_context; struct llama_kv_cache : public llama_memory_i { - using llama_memory_i::llama_memory_i; + virtual ~llama_kv_cache() = default; - virtual void restore() = 0; // call if batch processing fails - restores the cache state - virtual void commit() = 0; // call after successful batch processing - clears any pending state + // call if batch processing fails - restores the cache state + virtual void restore() = 0; - virtual int32_t get_n_tokens() const = 0; - virtual int32_t get_used_cells() const = 0; // TODO: remove, this is too-specific to the unified cache + // call after successful batch processing - clears any pending state + virtual void commit() = 0; - virtual bool get_can_shift() const = 0; + // process any pending defrag/shift/etc. operations + // optionally call once before processing a new batch + virtual bool update(llama_context & lctx) = 0; + + // schedule a defrag if the fragmentation threshold is exceeded. otherwise, do nothing + virtual void defrag_sched(float thold) = 0; + + // simulate full cache, used for allocating worst-case compute buffers + virtual void set_full() = 0; + + // + // batch processing + // + + virtual llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) = 0; + + // different KV caches require different batch splitting strategies + virtual llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const = 0; + + // find an empty slot of size "n_tokens" in the cache + virtual bool find_slot(const llama_ubatch & batch) = 0; + + // getters + virtual int32_t get_n_tokens() const = 0; + virtual int32_t get_used_cells() const = 0; // TODO: remove, this is too-specific to the unified cache + virtual llama_pos get_pos_max() const = 0; + virtual bool get_can_shift() const = 0; bool get_can_edit() const override { return get_can_shift(); } + + // + // state write/read + // + + virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const = 0; + virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) = 0; }; +// +// llama_kv_cache_guard +// + struct llama_kv_cache_guard { llama_kv_cache_guard(llama_kv_cache * kv) : kv(kv) {} @@ -43,65 +83,50 @@ struct llama_kv_cache_guard { llama_kv_cache * kv; }; -struct llama_kv_cell { - llama_pos pos = -1; - llama_pos delta = 0; - int32_t src = -1; // used by recurrent state models to copy states - int32_t tail = -1; +// +// llama_kv_cache_unified +// - std::set seq_id; +// TODO: add notion of max sequences +class llama_kv_cache_unified : public llama_kv_cache { +public: + struct kv_cell { + llama_pos pos = -1; + llama_pos delta = 0; - bool has_seq_id(const llama_seq_id & id) const { - return seq_id.find(id) != seq_id.end(); - } + std::set seq_id; - bool is_empty() const { - return seq_id.empty(); - } + bool has_seq_id(const llama_seq_id & id) const { + return seq_id.find(id) != seq_id.end(); + } - bool is_same_seq(const llama_kv_cell & other) const { - return seq_id == other.seq_id; - } -}; + bool is_empty() const { + return seq_id.empty(); + } -// ring-buffer of cached KV data -// TODO: pimpl -// TODO: add notion of max sequences -class llama_kv_cache_unified : public llama_kv_cache { -public: - // can be used to query data from the model if needed - struct callbacks { - std::function get_rope_factors; + bool is_same_seq(const kv_cell & other) const { + return seq_id == other.seq_id; + } }; - llama_kv_cache_unified( - const llama_hparams & hparams, - callbacks cbs); - - virtual ~llama_kv_cache_unified() = default; + static uint32_t get_padding(const llama_cparams & cparams); - // TODO: become constructor - bool init( - const llama_model & model, // TODO: do not reference the model - const llama_cparams & cparams, + llama_kv_cache_unified( + const llama_model & model, ggml_type type_k, ggml_type type_v, + bool v_trans, + bool offload, uint32_t kv_size, - bool offload); - - int32_t get_n_tokens() const override; - int32_t get_used_cells() const override; + uint32_t padding); - size_t total_size() const; + ~llama_kv_cache_unified() = default; - // TODO: better data structures to reduce the cost of this operation - llama_pos pos_max() const; + // + // llama_memory_i + // void clear() override; - void defrag() override; - - virtual void restore() override; - virtual void commit() override; bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; @@ -111,25 +136,76 @@ class llama_kv_cache_unified : public llama_kv_cache { llama_pos seq_pos_max(llama_seq_id seq_id) const override; - bool get_can_shift() const override; + // + // llama_kv_cache + // + + void restore() override; + void commit() override; + + bool update(llama_context & ctx) override; + + void defrag_sched(float thold) override; + + void set_full() override; + + llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override; + + llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override; - // find an empty slot of size "n_tokens" in the cache // updates the cache head // Note: On success, it's important that cache.head points // to the first cell of the slot. - bool find_slot(const llama_ubatch & batch); + bool find_slot(const llama_ubatch & batch) override; - // TODO: maybe not needed - uint32_t get_padding(const llama_cparams & cparams) const; + int32_t get_n_tokens() const override; + int32_t get_used_cells() const override; - // find how many cells are currently in use - uint32_t cell_max() const; + // TODO: better data structures to reduce the cost of this operation + llama_pos get_pos_max() const override; - size_t size_k_bytes() const; - size_t size_v_bytes() const; + bool get_can_shift() const override; - // defrag + // state write/load + + void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override; + void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override; + + // Note: The value of head isn't only used to optimize searching + // for a free KV slot. llama_decode_impl also uses it, so it + // cannot be freely changed after a slot has been allocated. + uint32_t head = 0; + uint32_t size = 0; + uint32_t used = 0; // used cells (i.e. at least one seq_id) + + // computed before each graph build + uint32_t n = 0; + + std::vector cells; + + std::vector k_l; // per layer + std::vector v_l; + +private: + const llama_model & model; + const llama_hparams & hparams; + + bool has_shift = false; + bool do_defrag = false; + + bool v_trans = true; // the value tensor is transposed + bool can_shift = false; + + // required padding + uint32_t padding = 1; + + ggml_type type_k = GGML_TYPE_F16; + ggml_type type_v = GGML_TYPE_F16; + + std::vector ctxs; + std::vector bufs; + // defrag struct { std::vector ids; } defrag_info; @@ -138,7 +214,6 @@ class llama_kv_cache_unified : public llama_kv_cache { bool defrag_prepare(int32_t n_max_nodes); // commit/restore cache - struct slot_range { uint32_t c0 = 0; // note: these are cell indices, not sequence positions uint32_t c1 = 0; @@ -149,25 +224,124 @@ class llama_kv_cache_unified : public llama_kv_cache { std::vector ranges; } pending; - // state write/load + // find how many cells are currently in use + uint32_t cell_max() const; - void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const; - void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1); + size_t total_size() const; - // members + size_t size_k_bytes() const; + size_t size_v_bytes() const; - const llama_hparams & hparams; + ggml_tensor * build_rope_shift( + const llama_cparams & cparams, + ggml_context * ctx, + ggml_tensor * cur, + ggml_tensor * shift, + ggml_tensor * factors, + float freq_base, + float freq_scale) const; + + llm_graph_result_ptr build_graph_shift( + const llama_cparams & cparams, + ggml_context * ctx, + ggml_cgraph * gf) const; + + llm_graph_result_ptr build_graph_defrag( + const llama_cparams & cparams, + ggml_context * ctx, + ggml_cgraph * gf) const; - callbacks cbs; + void state_write_meta(llama_io_write_i & io, const std::vector> & cell_ranges, llama_seq_id seq_id = -1) const; + void state_write_data(llama_io_write_i & io, const std::vector> & cell_ranges) const; - bool has_shift = false; - bool do_defrag = false; + bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1); + bool state_read_data(llama_io_read_i & io, uint32_t cell_count); +}; - // TODO: remove this and implement llama_kv_cache_recurrent instead - bool recurrent = false; // with recurrent state models, a cell can hold the state for more than one past token +// +// llama_kv_cache_recurrent +// - bool v_trans = true; // the value tensor is transposed - bool can_shift = false; +class llama_kv_cache_recurrent : public llama_kv_cache { +public: + struct kv_cell { + llama_pos pos = -1; + int32_t src = -1; // used to copy states + int32_t tail = -1; + + std::set seq_id; + + bool has_seq_id(const llama_seq_id & id) const { + return seq_id.find(id) != seq_id.end(); + } + + bool is_empty() const { + return seq_id.empty(); + } + + bool is_same_seq(const kv_cell & other) const { + return seq_id == other.seq_id; + } + }; + + llama_kv_cache_recurrent( + const llama_model & model, + ggml_type type_k, + ggml_type type_v, + bool offload, + uint32_t kv_size); + + ~llama_kv_cache_recurrent() = default; + + // + // llama_memory_i + // + + void clear() override; + + bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; + void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; + void seq_keep(llama_seq_id seq_id) override; + void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override; + void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override; + + llama_pos seq_pos_max(llama_seq_id seq_id) const override; + + // + // llama_kv_cache + // + + void restore() override; + void commit() override; + + bool update(llama_context & lctx) override; + + void defrag_sched(float thold) override; + + void set_full() override; + + llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override; + + llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override; + + bool find_slot(const llama_ubatch & batch) override; + + int32_t get_n_tokens() const override; + int32_t get_used_cells() const override; + + // TODO: better data structures to reduce the cost of this operation + llama_pos get_pos_max() const override; + + bool get_can_shift() const override; + + // TODO: temporary methods - they are not really const as they do const_cast<>, fix this + int32_t s_copy(int i) const; + float s_mask(int i) const; + + // state write/load + + void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override; + void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override; // Note: The value of head isn't only used to optimize searching // for a free KV slot. llama_decode_impl also uses it, so it @@ -179,18 +353,41 @@ class llama_kv_cache_unified : public llama_kv_cache { // computed before each graph build uint32_t n = 0; - std::vector cells; + std::vector cells; std::vector k_l; // per layer std::vector v_l; private: + //const llama_model & model; + const llama_hparams & hparams; + + // commit/restore cache + // TODO: rework for recurrent cache + struct slot_range { + uint32_t c0 = 0; // note: these are cell indices, not sequence positions + uint32_t c1 = 0; + }; + + // pending cell updates that are not yet committed + struct { + std::vector ranges; + } pending; + ggml_type type_k = GGML_TYPE_F16; ggml_type type_v = GGML_TYPE_F16; std::vector ctxs; std::vector bufs; + // find how many cells are currently in use + uint32_t cell_max() const; + + size_t total_size() const; + + size_t size_k_bytes() const; + size_t size_v_bytes() const; + void state_write_meta(llama_io_write_i & io, const std::vector> & cell_ranges, llama_seq_id seq_id = -1) const; void state_write_data(llama_io_write_i & io, const std::vector> & cell_ranges) const; @@ -198,11 +395,6 @@ class llama_kv_cache_unified : public llama_kv_cache { bool state_read_data(llama_io_read_i & io, uint32_t cell_count); }; -// TODO: temporary reusing llama_kv_cache_unified -- implement recurrent cache and simplify llama_kv_cache_unified -//class llama_kv_cache_recurrent : public llama_kv_cache_unified { -//public: -// using llama_kv_cache_unified::llama_kv_cache_unified; -//}; // // kv cache view diff --git a/examples/talk-llama/llama-memory.h b/examples/talk-llama/llama-memory.h index dfa8c4e90fc..c7412d5911e 100644 --- a/examples/talk-llama/llama-memory.h +++ b/examples/talk-llama/llama-memory.h @@ -2,12 +2,22 @@ #include "llama.h" +struct llama_memory_params { + // kv cache + ggml_type type_k; + ggml_type type_v; + + // parameters for other types of memory + // ... +}; + // general concept of LLM memory // the KV cache is a type of LLM memory, but there can be other types class llama_memory_i { public: + virtual ~llama_memory_i() = default; + virtual void clear() = 0; - virtual void defrag() = 0; virtual bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) = 0; virtual void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0; diff --git a/examples/talk-llama/llama-model-loader.cpp b/examples/talk-llama/llama-model-loader.cpp index ea73a8a7ba9..4cce51668b4 100644 --- a/examples/talk-llama/llama-model-loader.cpp +++ b/examples/talk-llama/llama-model-loader.cpp @@ -301,12 +301,12 @@ namespace GGUFMeta { GGUFMeta::GKV::get_kv(meta.get(), kid); switch (arr_info.gt) { - case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same::value)); break; - case GGUF_TYPE_INT32: GGML_ASSERT( - (std::is_same::value) || - (std::is_same::value)); break; + case GGUF_TYPE_UINT32: + case GGUF_TYPE_INT32: GGML_ASSERT((std::is_same::value) || + (std::is_same::value)); break; + case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same::value)); break; default: - throw std::runtime_error(format("%s is not a float32, int32 array", key.c_str())); + throw std::runtime_error(format("%s is not a float32/uint32/int32 array", key.c_str())); } result.resize(arr_info.length); @@ -330,12 +330,12 @@ namespace GGUFMeta { GGUFMeta::GKV::get_kv(meta.get(), kid); switch (arr_info.gt) { - case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same::value)); break; - case GGUF_TYPE_INT32: GGML_ASSERT( - (std::is_same::value) || - (std::is_same::value)); break; + case GGUF_TYPE_UINT32: + case GGUF_TYPE_INT32: GGML_ASSERT((std::is_same::value) || + (std::is_same::value)); break; + case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same::value)); break; default: - throw std::runtime_error(format("%s is not a float32, int32 array", key.c_str())); + throw std::runtime_error(format("%s is not a float32/uint32/int32 array", key.c_str())); } if (arr_info.length > N_MAX) { @@ -823,6 +823,10 @@ void llama_model_loader::init_mappings(bool prefetch, llama_mlocks * mlock_mmaps mmaps_used.reserve(files.size()); for (const auto & file : files) { auto * reg = ggml_backend_dev_backend_reg(ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU)); + if (!reg) { + throw std::runtime_error(format("%s: no CPU backend found", __func__)); + } + auto * is_numa_fn = (decltype(ggml_is_numa) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_is_numa"); std::unique_ptr mapping = std::make_unique(file.get(), prefetch ? -1 : 0, is_numa_fn()); mmaps_used.emplace_back(mapping->size(), 0); diff --git a/examples/talk-llama/llama-model-saver.cpp b/examples/talk-llama/llama-model-saver.cpp new file mode 100644 index 00000000000..a70b9892347 --- /dev/null +++ b/examples/talk-llama/llama-model-saver.cpp @@ -0,0 +1,281 @@ +#include "llama-model-saver.h" + +#include "gguf.h" + +#include "llama.h" +#include "llama-hparams.h" +#include "llama-model.h" +#include "llama-vocab.h" + +#include + +llama_model_saver::llama_model_saver(const struct llama_model & model) : model(model), llm_kv(model.arch) { + gguf_ctx = gguf_init_empty(); +} + +llama_model_saver::~llama_model_saver() { + gguf_free(gguf_ctx); +} + +void llama_model_saver::add_kv(const enum llm_kv key, const uint32_t value) { + gguf_set_val_u32(gguf_ctx, llm_kv(key).c_str(), value); +} + +void llama_model_saver::add_kv(const enum llm_kv key, const int32_t value) { + gguf_set_val_i32(gguf_ctx, llm_kv(key).c_str(), value); +} + +void llama_model_saver::add_kv(const enum llm_kv key, const float value) { + gguf_set_val_f32(gguf_ctx, llm_kv(key).c_str(), value); +} + +void llama_model_saver::add_kv(const enum llm_kv key, const bool value) { + gguf_set_val_bool(gguf_ctx, llm_kv(key).c_str(), value); +} + +void llama_model_saver::add_kv(const enum llm_kv key, const char * value) { + gguf_set_val_str(gguf_ctx, llm_kv(key).c_str(), value); +} + +[[noreturn]] +void llama_model_saver::add_kv(const enum llm_kv key, const char value) { + GGML_UNUSED(key); + GGML_UNUSED(value); + GGML_ABORT("fatal error"); // this should never be called, only needed to make the template below compile +} + +template +void llama_model_saver::add_kv(const enum llm_kv key, const Container & value, const bool per_layer) { + const size_t n_values = per_layer ? size_t(model.hparams.n_layer) : value.size(); + GGML_ASSERT(n_values <= value.size()); + + if (n_values == 0) { + return; + } + + if (per_layer) { + bool all_values_the_same = true; + for (size_t i = 1; i < n_values; ++i) { + if (value[i] != value[0]) { + all_values_the_same = false; + break; + } + } + if (all_values_the_same) { + add_kv(key, value[0]); + return; + } + } + + if (std::is_same::value) { + gguf_set_arr_data(gguf_ctx, llm_kv(key).c_str(), GGUF_TYPE_UINT8, value.data(), n_values); + } else if (std::is_same::value) { + gguf_set_arr_data(gguf_ctx, llm_kv(key).c_str(), GGUF_TYPE_INT8, value.data(), n_values); + } else if (std::is_same::value) { + gguf_set_arr_data(gguf_ctx, llm_kv(key).c_str(), GGUF_TYPE_UINT32, value.data(), n_values); + } else if (std::is_same::value) { + gguf_set_arr_data(gguf_ctx, llm_kv(key).c_str(), GGUF_TYPE_INT32, value.data(), n_values); + } else if (std::is_same::value) { + gguf_set_arr_data(gguf_ctx, llm_kv(key).c_str(), GGUF_TYPE_FLOAT32, value.data(), n_values); + } else if (std::is_same::value) { + gguf_set_val_str(gguf_ctx, llm_kv(key).c_str(), reinterpret_cast(value.data())); + } else { + GGML_ABORT("fatal error"); + } +} + +void llama_model_saver::add_kv(const enum llm_kv key, const std::vector & value) { + std::vector tmp(value.size()); + for (size_t i = 0; i < value.size(); ++i) { + tmp[i] = value[i].c_str(); + } + gguf_set_arr_str(gguf_ctx, llm_kv(key).c_str(), tmp.data(), tmp.size()); +} + +void llama_model_saver::add_tensor(const struct ggml_tensor * tensor) { + if (!tensor) { + return; + } + if (gguf_find_tensor(gguf_ctx, tensor->name) >= 0) { + GGML_ASSERT(std::string(tensor->name) == "rope_freqs.weight"); // FIXME + return; + } + gguf_add_tensor(gguf_ctx, tensor); +} + +void llama_model_saver::add_kv_from_model() { + const llama_hparams & hparams = model.hparams; + const llama_vocab & vocab = model.vocab; + + const int32_t n_vocab = vocab.n_tokens(); + std::vector tokens(n_vocab); + std::vector scores(n_vocab); + std::vector token_types(n_vocab); + + for (int32_t id = 0; id < n_vocab; ++id) { + const llama_vocab::token_data & token_data = vocab.get_token_data(id); + + tokens[id] = token_data.text; + scores[id] = token_data.score; + + switch(token_data.attr) { + case LLAMA_TOKEN_ATTR_UNKNOWN: token_types[id] = LLAMA_TOKEN_TYPE_UNKNOWN; break; + case LLAMA_TOKEN_ATTR_UNUSED: token_types[id] = LLAMA_TOKEN_TYPE_UNUSED; break; + case LLAMA_TOKEN_ATTR_NORMAL: token_types[id] = LLAMA_TOKEN_TYPE_NORMAL; break; + case LLAMA_TOKEN_ATTR_CONTROL: token_types[id] = LLAMA_TOKEN_TYPE_CONTROL; break; + case LLAMA_TOKEN_ATTR_USER_DEFINED: token_types[id] = LLAMA_TOKEN_TYPE_USER_DEFINED; break; + case LLAMA_TOKEN_ATTR_BYTE: token_types[id] = LLAMA_TOKEN_TYPE_BYTE; break; + case LLAMA_TOKEN_ATTR_UNDEFINED: + default: token_types[id] = LLAMA_TOKEN_TYPE_UNDEFINED; break; + } + } + + // add_kv(LLM_KV_GENERAL_TYPE, ???); + add_kv(LLM_KV_GENERAL_ARCHITECTURE, model.arch_name()); + // add_kv(LLM_KV_GENERAL_QUANTIZATION_VERSION, ???); + // add_kv(LLM_KV_GENERAL_ALIGNMENT, ???); + add_kv(LLM_KV_GENERAL_NAME, model.name); + // add_kv(LLM_KV_GENERAL_AUTHOR, ???); + // add_kv(LLM_KV_GENERAL_VERSION, ???); + // add_kv(LLM_KV_GENERAL_URL, ???); + // add_kv(LLM_KV_GENERAL_DESCRIPTION, ???); + // add_kv(LLM_KV_GENERAL_LICENSE, ???); + // add_kv(LLM_KV_GENERAL_SOURCE_URL, ???); + // add_kv(LLM_KV_GENERAL_SOURCE_HF_REPO, ???); + + add_kv(LLM_KV_VOCAB_SIZE, vocab.n_tokens()); + add_kv(LLM_KV_CONTEXT_LENGTH, hparams.n_ctx_train); + add_kv(LLM_KV_EMBEDDING_LENGTH, hparams.n_embd); + add_kv(LLM_KV_BLOCK_COUNT, hparams.n_layer); + add_kv(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); + add_kv(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, true); + add_kv(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + add_kv(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + add_kv(LLM_KV_USE_PARALLEL_RESIDUAL, hparams.use_par_res); + // add_kv(LLM_KV_TENSOR_DATA_LAYOUT, ???); + add_kv(LLM_KV_EXPERT_COUNT, hparams.n_expert); + add_kv(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used); + add_kv(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + add_kv(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); + add_kv(LLM_KV_POOLING_TYPE, uint32_t(hparams.pooling_type)); + add_kv(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); + add_kv(LLM_KV_DECODER_START_TOKEN_ID, hparams.dec_start_token_id); + add_kv(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping); + add_kv(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping); + add_kv(LLM_KV_SWIN_NORM, hparams.swin_norm); + add_kv(LLM_KV_RESCALE_EVERY_N_LAYERS, hparams.rescale_every_n_layers); + add_kv(LLM_KV_TIME_MIX_EXTRA_DIM, hparams.time_mix_extra_dim); + add_kv(LLM_KV_TIME_DECAY_EXTRA_DIM, hparams.time_decay_extra_dim); + add_kv(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale); + add_kv(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale); + + add_kv(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, true); + add_kv(LLM_KV_ATTENTION_HEAD_COUNT_KV, hparams.n_head_kv_arr, true); + add_kv(LLM_KV_ATTENTION_MAX_ALIBI_BIAS, hparams.f_max_alibi_bias); + add_kv(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv); + add_kv(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k); + add_kv(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v); + add_kv(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + add_kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + add_kv(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); + add_kv(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q); + add_kv(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); + add_kv(LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, hparams.n_rel_attn_bkts); + add_kv(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + add_kv(LLM_KV_ATTENTION_SCALE, hparams.f_attention_scale); + + const float rope_scaling_factor = hparams.rope_freq_scale_train == 1.0f ? 0.0f : 1.0f/hparams.rope_freq_scale_train; + + add_kv(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot); + add_kv(LLM_KV_ROPE_FREQ_BASE, hparams.rope_freq_base_train); + // add_kv(LLM_KV_ROPE_SCALE_LINEAR, rope_scaling_factor); // old name + add_kv(LLM_KV_ROPE_SCALING_TYPE, llama_rope_scaling_type_name(hparams.rope_scaling_type_train)); + add_kv(LLM_KV_ROPE_SCALING_FACTOR, rope_scaling_factor); + add_kv(LLM_KV_ROPE_SCALING_ATTN_FACTOR, hparams.rope_attn_factor); + add_kv(LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, hparams.n_ctx_orig_yarn); + add_kv(LLM_KV_ROPE_SCALING_FINETUNED, hparams.rope_finetuned); + add_kv(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul); + + // TODO: implement split file support + // add_kv(LLM_KV_SPLIT_NO, ???); + // add_kv(LLM_KV_SPLIT_COUNT, ???); + // add_kv(LLM_KV_SPLIT_TENSORS_COUNT, ???); + + add_kv(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + add_kv(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + add_kv(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + add_kv(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + add_kv(LLM_KV_SSM_DT_B_C_RMS, hparams.ssm_dt_b_c_rms); + + add_kv(LLM_KV_WKV_HEAD_SIZE, hparams.wkv_head_size); + + add_kv(LLM_KV_TOKENIZER_MODEL, vocab.get_tokenizer_model()); + add_kv(LLM_KV_TOKENIZER_PRE, vocab.get_tokenizer_pre()); + add_kv(LLM_KV_TOKENIZER_LIST, tokens); + add_kv(LLM_KV_TOKENIZER_TOKEN_TYPE, token_types); + add_kv(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, vocab.n_token_types()); + add_kv(LLM_KV_TOKENIZER_SCORES, scores); + add_kv(LLM_KV_TOKENIZER_MERGES, vocab.get_bpe_merges()); + // FIXME llama_token is type i32 but when reading in a GGUF file u32 is expected, not an issue for writing though + add_kv(LLM_KV_TOKENIZER_BOS_ID, uint32_t(vocab.token_bos())); + add_kv(LLM_KV_TOKENIZER_EOS_ID, uint32_t(vocab.token_eos())); + add_kv(LLM_KV_TOKENIZER_EOT_ID, uint32_t(vocab.token_eot())); + add_kv(LLM_KV_TOKENIZER_EOM_ID, uint32_t(vocab.token_eom())); + add_kv(LLM_KV_TOKENIZER_UNK_ID, uint32_t(vocab.token_unk())); + add_kv(LLM_KV_TOKENIZER_SEP_ID, uint32_t(vocab.token_sep())); + add_kv(LLM_KV_TOKENIZER_PAD_ID, uint32_t(vocab.token_pad())); + // add_kv(LLM_KV_TOKENIZER_CLS_ID, uint32_t(vocab.token_bos())); // deprecated + // add_kv(LLM_KV_TOKENIZER_MASK_ID, ???); + add_kv(LLM_KV_TOKENIZER_ADD_BOS, vocab.get_add_bos()); + add_kv(LLM_KV_TOKENIZER_ADD_EOS, vocab.get_add_eos()); + add_kv(LLM_KV_TOKENIZER_ADD_PREFIX, vocab.get_add_space_prefix()); + add_kv(LLM_KV_TOKENIZER_REMOVE_EXTRA_WS, vocab.get_remove_extra_whitespaces()); + add_kv(LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, vocab.get_precompiled_charsmap()); + // add_kv(LLM_KV_TOKENIZER_HF_JSON, ???); + // add_kv(LLM_KV_TOKENIZER_RWKV, ???); + add_kv(LLM_KV_TOKENIZER_FIM_PRE_ID, uint32_t(vocab.token_fim_pre())); + add_kv(LLM_KV_TOKENIZER_FIM_SUF_ID, uint32_t(vocab.token_fim_suf())); + add_kv(LLM_KV_TOKENIZER_FIM_MID_ID, uint32_t(vocab.token_fim_mid())); + add_kv(LLM_KV_TOKENIZER_FIM_PAD_ID, uint32_t(vocab.token_fim_pad())); + add_kv(LLM_KV_TOKENIZER_FIM_REP_ID, uint32_t(vocab.token_fim_rep())); + add_kv(LLM_KV_TOKENIZER_FIM_SEP_ID, uint32_t(vocab.token_fim_sep())); + + // TODO: implement LoRA support + // add_kv(LLM_KV_ADAPTER_TYPE, ???); + // add_kv(LLM_KV_ADAPTER_LORA_ALPHA, ???); + + // deprecated + // add_kv(LLM_KV_TOKENIZER_PREFIX_ID, ???); + // add_kv(LLM_KV_TOKENIZER_SUFFIX_ID, ???); + // add_kv(LLM_KV_TOKENIZER_MIDDLE_ID, ???); +} + +void llama_model_saver::add_tensors_from_model() { + if (std::string(model.output->name) != std::string(model.tok_embd->name)) { + add_tensor(model.tok_embd); // some models use the same tensor for tok_embd and output + } + add_tensor(model.type_embd); + add_tensor(model.pos_embd); + add_tensor(model.tok_norm); + add_tensor(model.tok_norm_b); + add_tensor(model.output_norm); + add_tensor(model.output_norm_b); + add_tensor(model.output); + add_tensor(model.output_b); + add_tensor(model.output_norm_enc); + add_tensor(model.cls); + add_tensor(model.cls_b); + add_tensor(model.cls_out); + add_tensor(model.cls_out_b); + + for (const struct llama_layer & layer : model.layers) { + for (size_t i = 0; i < sizeof(layer)/sizeof(struct ggml_tensor *); ++i) { + add_tensor(reinterpret_cast(&layer)[i]); + } + } +} + +void llama_model_saver::save(const std::string & path_model) { + gguf_write_to_file(gguf_ctx, path_model.c_str(), false); +} + diff --git a/examples/talk-llama/llama-model-saver.h b/examples/talk-llama/llama-model-saver.h new file mode 100644 index 00000000000..a5a434c3069 --- /dev/null +++ b/examples/talk-llama/llama-model-saver.h @@ -0,0 +1,37 @@ +#pragma once + +#include "llama.h" +#include "llama-arch.h" + +#include + +struct llama_model_saver { + struct gguf_context * gguf_ctx = nullptr; + const struct llama_model & model; + const struct LLM_KV llm_kv; + + llama_model_saver(const struct llama_model & model); + ~llama_model_saver(); + + void add_kv(enum llm_kv key, uint32_t value); + void add_kv(enum llm_kv key, int32_t value); + void add_kv(enum llm_kv key, float value); + void add_kv(enum llm_kv key, bool value); + void add_kv(enum llm_kv key, const char * value); + + [[noreturn]] + void add_kv(enum llm_kv key, char value); // needed to make the template below compile + + template + void add_kv(enum llm_kv key, const Container & value, bool per_layer = false); + + void add_kv(enum llm_kv key, const std::vector & value); + + void add_tensor(const struct ggml_tensor * tensor); + + void add_kv_from_model(); + + void add_tensors_from_model(); + + void save(const std::string & path_model); +}; diff --git a/examples/talk-llama/llama-model.cpp b/examples/talk-llama/llama-model.cpp index 51092a128c5..3a4e72a36b0 100644 --- a/examples/talk-llama/llama-model.cpp +++ b/examples/talk-llama/llama-model.cpp @@ -80,6 +80,7 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_236B: return "236B"; case LLM_TYPE_290B: return "290B"; case LLM_TYPE_314B: return "314B"; + case LLM_TYPE_405B: return "405B"; case LLM_TYPE_671B: return "671B"; case LLM_TYPE_SMALL: return "0.1B"; case LLM_TYPE_MEDIUM: return "0.4B"; @@ -116,6 +117,10 @@ static const std::map LLAMA_ROPE_SCALING_ { LLAMA_ROPE_SCALING_TYPE_LONGROPE, "longrope" }, }; +std::string llama_rope_scaling_type_name(llama_rope_scaling_type rope_scaling_type) { + return LLAMA_ROPE_SCALING_TYPES.at(rope_scaling_type); +} + static llama_rope_scaling_type llama_rope_scaling_type_from_string(const std::string & name) { for (const auto & kv : LLAMA_ROPE_SCALING_TYPES) { if (kv.second == name) { @@ -298,6 +303,10 @@ static buft_list_t make_cpu_buft_list(const std::vector & de // add extra buffer types, only if no GPU device is present // ref: https://github.com/ggml-org/llama.cpp/issues/12481#issuecomment-2743136094 auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + if (cpu_dev == nullptr) { + throw std::runtime_error(format("%s: no CPU backend found", __func__)); + } + auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev); auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t) ggml_backend_reg_get_proc_address(cpu_reg, "ggml_backend_dev_get_extra_bufts"); @@ -582,6 +591,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { switch (hparams.n_layer) { case 32: type = LLM_TYPE_7B; break; case 80: type = LLM_TYPE_70B; break; + case 162: type = LLM_TYPE_405B; break; default: type = LLM_TYPE_UNKNOWN; } } break; @@ -773,6 +783,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { // fall through case LLM_ARCH_QWEN2: { + ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { case 24: type = hparams.n_embd == 1024 ? LLM_TYPE_0_5B : LLM_TYPE_1B; break; @@ -1481,6 +1492,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } ggml_backend_dev_t cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + if (cpu_dev == nullptr) { + throw std::runtime_error(format("%s: no CPU backend found", __func__)); + } const int i_gpu_start = std::max((int) hparams.n_layer - n_gpu_layers, (int) 0); const int act_gpu_layers = devices.empty() ? 0 : std::min(n_gpu_layers, (int)n_layer + 1); auto get_layer_buft_list = [&](int il) -> llama_model::impl::layer_dev { @@ -1648,8 +1662,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) { for (const auto * overrides = ml.tensor_buft_overrides; overrides->pattern != nullptr; ++overrides) { std::regex pattern(overrides->pattern); if (std::regex_search(tensor_name, pattern)) { - LLAMA_LOG_DEBUG("tensor %s buffer type overriden to %s\n", tensor_name.c_str(), ggml_backend_buft_name(overrides->buft)); buft = overrides->buft; + LLAMA_LOG_DEBUG("tensor %s (%zu MiB %s) buffer type overridden to %s\n", + tensor_name.c_str(), + ggml_nbytes(t_meta) / 1024 / 1024, ggml_type_name(t_meta->type), + ggml_backend_buft_name(buft)); break; } } @@ -1666,6 +1683,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { auto * buft_dev = ggml_backend_buft_get_device(buft); if (ml.use_mmap && buft_dev && buft == ggml_backend_dev_host_buffer_type(buft_dev)) { auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + if (!cpu_dev) { + throw std::runtime_error("no CPU backend found"); + } buft = ggml_backend_dev_buffer_type(cpu_dev); } @@ -1847,7 +1867,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + if (n_ff > 0) { + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + } if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) { layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); @@ -1857,9 +1879,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); } - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + if (n_ff > 0) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } // optional MLP bias layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); @@ -3503,7 +3527,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) { // output output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; @@ -4108,6 +4136,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { if (!dev) { // FIXME: workaround for CPU backend buft having a NULL device dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + if (!dev) { + throw std::runtime_error(format("%s: no CPU backend found", __func__)); + } } ggml_backend_dev_props props; ggml_backend_dev_get_props(dev, &props); @@ -4237,7 +4268,7 @@ uint64_t llama_model::n_elements() const { } void llama_model::print_info() const { - const char * rope_scaling_type = LLAMA_ROPE_SCALING_TYPES.at(hparams.rope_scaling_type_train); + const std::string rope_scaling_type = llama_rope_scaling_type_name(hparams.rope_scaling_type_train); auto print_f = [](const std::function & f, uint32_t n) { bool is_var = false; @@ -4298,7 +4329,7 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: causal attn = %d\n", __func__, hparams.causal_attn); LLAMA_LOG_INFO("%s: pooling type = %d\n", __func__, hparams.pooling_type); LLAMA_LOG_INFO("%s: rope type = %d\n", __func__, hparams.rope_type); - LLAMA_LOG_INFO("%s: rope scaling = %s\n", __func__, rope_scaling_type); + LLAMA_LOG_INFO("%s: rope scaling = %s\n", __func__, rope_scaling_type.c_str()); LLAMA_LOG_INFO("%s: freq_base_train = %.1f\n", __func__, hparams.rope_freq_base_train); LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train); LLAMA_LOG_INFO("%s: n_ctx_orig_yarn = %u\n", __func__, hparams.n_ctx_orig_yarn); @@ -4445,6 +4476,19 @@ const ggml_tensor * llama_model::get_tensor(const char * name) const { return it->second; } +ggml_tensor * llama_model::get_rope_factors(uint32_t n_ctx_per_seq, int il) const { + // choose long/short freq factors based on the context size + if (layers[il].rope_freqs != nullptr) { + return layers[il].rope_freqs; + } + + if (n_ctx_per_seq > hparams.n_ctx_orig_yarn) { + return layers[il].rope_long; + } + + return layers[il].rope_short; +} + struct llm_build_llama : public llm_graph_context { llm_build_llama(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; @@ -4485,7 +4529,7 @@ struct llm_build_llama : public llm_graph_context { // self-attention { // rope freq factors for llama3; may return nullptr for llama2 and other models - ggml_tensor * rope_factors = static_cast(memory)->cbs.get_rope_factors(n_ctx_per_seq, il); + ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il); // compute Q and K and RoPE them ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); @@ -4691,6 +4735,7 @@ struct llm_build_deci : public llm_graph_context { ggml_tensor * inpSA = inpL; const int64_t n_head_kv = hparams.n_head_kv(il); const int64_t n_head = hparams.n_head(il); + const int64_t n_ff = hparams.n_ff(il); if (n_head == 0) { // attention-free layer of Llama-3_1-Nemotron-51B @@ -4710,7 +4755,7 @@ struct llm_build_deci : public llm_graph_context { } else if (n_head > 0) { // self-attention // rope freq factors for llama3; may return nullptr for llama2 and other models - ggml_tensor * rope_factors = static_cast(memory)->cbs.get_rope_factors(n_ctx_per_seq, il); + ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il); // compute Q and K and RoPE them ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); @@ -4766,6 +4811,11 @@ struct llm_build_deci : public llm_graph_context { inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } + // FFN-free layer of Llama-3_1-Nemotron-Ultra-253B + if (n_ff == 0) { + continue; + } + // For Granite architecture if (hparams.f_residual_scale) { cur = ggml_scale(ctx0, cur, hparams.f_residual_scale); @@ -7192,7 +7242,7 @@ struct llm_build_phi3 : public llm_graph_context { // self-attention { // rope freq factors for 128k context - ggml_tensor * rope_factors = static_cast(memory)->cbs.get_rope_factors(n_ctx_per_seq, il); + ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il); ggml_tensor* attn_norm_output = build_norm(inpL, model.layers[il].attn_norm, @@ -7944,7 +7994,7 @@ struct llm_build_minicpm3 : public llm_graph_context { for (int il = 0; il < n_layer; ++il) { ggml_tensor * inpSA = inpL; - ggml_tensor * rope_factors = static_cast(memory)->cbs.get_rope_factors(n_ctx_per_seq, il); + ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il); // norm cur = build_norm(inpL, @@ -8711,7 +8761,7 @@ struct llm_build_mamba : public llm_graph_context { ggml_tensor * state_mask, const llama_ubatch & ubatch, int il) const { - const llama_kv_cache_unified * kv_self = static_cast(memory); + const llama_kv_cache_recurrent * kv_self = static_cast(memory); const auto kv_head = kv_self->head; @@ -9012,7 +9062,7 @@ struct llm_build_cohere2 : public llm_graph_context { // self-attention { // rope freq factors for 128k context - ggml_tensor * rope_factors = static_cast(memory)->cbs.get_rope_factors(n_ctx_per_seq, il); + ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il); // compute Q and K and RoPE them ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); @@ -9950,7 +10000,7 @@ struct llm_build_deepseek : public llm_graph_context { // self-attention { // rope freq factors for llama3; may return nullptr for llama2 and other models - ggml_tensor * rope_factors = static_cast(memory)->cbs.get_rope_factors(n_ctx_per_seq, il); + ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il); // compute Q and K and RoPE them ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); @@ -11314,7 +11364,7 @@ struct llm_build_exaone : public llm_graph_context { // self-attention { // rope freq factors for llama3; may return nullptr for llama2 and other models - ggml_tensor * rope_factors = static_cast(memory)->cbs.get_rope_factors(n_ctx_per_seq, il); + ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il); // compute Q and K and RoPE them ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); @@ -11459,7 +11509,7 @@ struct llm_build_rwkv6_base : public llm_graph_context { ggml_tensor * state_mask, const llama_ubatch & ubatch, int il) const { - const llama_kv_cache_unified * kv_self = static_cast(memory); + const llama_kv_cache_recurrent * kv_self = static_cast(memory); const auto n_tokens = ubatch.n_tokens; const auto n_seqs = ubatch.n_seqs; @@ -11855,7 +11905,7 @@ struct llm_build_rwkv7_base : public llm_graph_context { ggml_tensor *& first_layer_value, const llama_ubatch & ubatch, int il) const { - const llama_kv_cache_unified * kv_self = static_cast(memory); + const llama_kv_cache_recurrent * kv_self = static_cast(memory); const auto n_tokens = ubatch.n_tokens; const auto n_seqs = ubatch.n_seqs; @@ -12695,7 +12745,7 @@ struct llm_build_bailingmoe : public llm_graph_context { // self-attention { // rope freq factors for llama3; may return nullptr for llama2 and other models - ggml_tensor * rope_factors = static_cast(memory)->cbs.get_rope_factors(n_ctx_per_seq, il); + ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il); // compute Q and K and RoPE them ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); @@ -12815,36 +12865,46 @@ struct llm_build_bailingmoe : public llm_graph_context { } }; -llama_memory_i * llama_model::create_memory() const { +llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_cparams & cparams) const { llama_memory_i * res; switch (arch) { + case LLM_ARCH_BERT: + case LLM_ARCH_JINA_BERT_V2: + case LLM_ARCH_NOMIC_BERT: + case LLM_ARCH_NOMIC_BERT_MOE: + { + res = nullptr; + } break; case LLM_ARCH_MAMBA: case LLM_ARCH_RWKV6: case LLM_ARCH_RWKV6QWEN2: case LLM_ARCH_RWKV7: case LLM_ARCH_ARWKV7: { - res = new llama_kv_cache_unified(hparams, { - /*.get_rope_factors =*/ nullptr - }); + res = new llama_kv_cache_recurrent( + *this, + GGML_TYPE_F32, + GGML_TYPE_F32, + cparams.offload_kqv, + std::max((uint32_t) 1, cparams.n_seq_max)); } break; default: { - res = new llama_kv_cache_unified(hparams, { - /*.get_rope_factors =*/ [this](uint32_t n_ctx_per_seq, int il) { - // choose long/short freq factors based on the context size - if (layers[il].rope_freqs != nullptr) { - return layers[il].rope_freqs; - } + const auto padding = llama_kv_cache_unified::get_padding(cparams); - if (n_ctx_per_seq > hparams.n_ctx_orig_yarn) { - return layers[il].rope_long; - } + cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding); - return layers[il].rope_short; - } - }); + LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx); + + res = new llama_kv_cache_unified( + *this, + params.type_k, + params.type_v, + !cparams.flash_attn, + cparams.offload_kqv, + cparams.n_ctx, + padding); } } @@ -13226,8 +13286,6 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_DECI: case LLM_ARCH_BAICHUAN: case LLM_ARCH_STARCODER: - case LLM_ARCH_PLAMO: - case LLM_ARCH_ORION: case LLM_ARCH_INTERNLM2: case LLM_ARCH_MINICPM: case LLM_ARCH_XVERSE: @@ -13265,6 +13323,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_PHI2: case LLM_ARCH_PHI3: case LLM_ARCH_PHIMOE: + case LLM_ARCH_PLAMO: case LLM_ARCH_GEMMA: case LLM_ARCH_GEMMA2: case LLM_ARCH_GEMMA3: @@ -13272,6 +13331,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_OPENELM: case LLM_ARCH_GPTNEOX: case LLM_ARCH_CODESHELL: + case LLM_ARCH_ORION: case LLM_ARCH_NEMOTRON: case LLM_ARCH_EXAONE: case LLM_ARCH_MINICPM3: @@ -13344,6 +13404,14 @@ const char * llama_model_chat_template(const llama_model * model, const char * n : LLM_KV(model->arch)(LLM_KV_TOKENIZER_CHAT_TEMPLATE); const auto & it = model->gguf_kv.find(key); if (it == model->gguf_kv.end()) { + // one-off fix for very popular models (so we are not flooded with issues) + // do not extend this list unless absolutely necessary + // Mistral-Small-2503 does not have built-in chat template + llama_vocab_pre_type pre_type = model->vocab.get_pre_type(); + if (pre_type == LLAMA_VOCAB_PRE_TYPE_TEKKEN && model->layers.size() == 40) { + return "mistral-v7-tekken"; + } + return nullptr; } diff --git a/examples/talk-llama/llama-model.h b/examples/talk-llama/llama-model.h index 34aac337cff..6bdec263b70 100644 --- a/examples/talk-llama/llama-model.h +++ b/examples/talk-llama/llama-model.h @@ -76,6 +76,7 @@ enum llm_type { LLM_TYPE_236B, LLM_TYPE_290B, LLM_TYPE_314B, + LLM_TYPE_405B, LLM_TYPE_671B, LLM_TYPE_SMALL, LLM_TYPE_MEDIUM, @@ -95,6 +96,8 @@ enum llm_type { LLM_TYPE_235B_A22B, }; +std::string llama_rope_scaling_type_name(llama_rope_scaling_type rope_scaling_type); + struct llama_layer_posnet { // resnet struct ggml_tensor * norm1 = nullptr; @@ -395,8 +398,11 @@ struct llama_model { const struct ggml_tensor * get_tensor(const char * name) const; + ggml_tensor * get_rope_factors(uint32_t n_ctx_per_seq, int il) const; + + // note: can mutate `cparams` // TODO: move this to new llm_arch_model_i interface - llama_memory_i * create_memory() const; // TODO: params + llama_memory_i * create_memory(const llama_memory_params & params, llama_cparams & cparams) const; // TODO: move this to new llm_arch_model_i interface llm_graph_result_ptr build_graph( diff --git a/examples/talk-llama/llama-quant.cpp b/examples/talk-llama/llama-quant.cpp index 7dc54227631..820d5128e29 100644 --- a/examples/talk-llama/llama-quant.cpp +++ b/examples/talk-llama/llama-quant.cpp @@ -519,7 +519,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: nthread = std::thread::hardware_concurrency(); } - // mmap consistently increases speed Linux, and also increases speed on Windows with + // mmap consistently increases speed on Linux, and also increases speed on Windows with // hot cache. It may cause a slowdown on macOS, possibly related to free memory. #if defined(__linux__) || defined(_WIN32) constexpr bool use_mmap = true; @@ -529,7 +529,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: llama_model_kv_override * kv_overrides = nullptr; if (params->kv_overrides) { - auto v = (std::vector*)params->kv_overrides; + auto * v = (std::vector*)params->kv_overrides; kv_overrides = v->data(); } diff --git a/examples/talk-llama/llama-sampling.cpp b/examples/talk-llama/llama-sampling.cpp index c0a5f9340d5..804b11e0a94 100644 --- a/examples/talk-llama/llama-sampling.cpp +++ b/examples/talk-llama/llama-sampling.cpp @@ -1750,23 +1750,35 @@ static const char * llama_sampler_top_n_sigma_name(const struct llama_sampler * static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { const auto * ctx = (llama_sampler_top_n_sigma *) smpl->ctx; + if (ctx->n <= 0.0f || cur_p->size <= 1) { + return; + } + // find max logit and calculate mean float max = cur_p->data[0].logit; float logits_sum = 0; + size_t valid_count = 0; for (size_t i = 0; i < cur_p->size; ++i) { - if (cur_p->data[i].logit > max) { - max = cur_p->data[i].logit; + // Only count non-negative infinity values + if (cur_p->data[i].logit != -INFINITY) { + if (cur_p->data[i].logit > max) { + max = cur_p->data[i].logit; + } + logits_sum += cur_p->data[i].logit; + valid_count++; } - logits_sum += cur_p->data[i].logit; } - float mean = logits_sum/cur_p->size; + float mean = valid_count > 0 ? logits_sum/valid_count : 0; // calculate standard deviation float acc = 0; for (size_t i = 0; i < cur_p->size; ++i) { - acc += pow(cur_p->data[i].logit - mean, 2); + // Skip -infinity in std calculation + if (cur_p->data[i].logit != -INFINITY) { + acc += pow(cur_p->data[i].logit - mean, 2); + } } - float std = sqrt(acc/cur_p->size); + float std = valid_count > 0 ? sqrt(acc/valid_count) : 0; //apply mask for (size_t i = 0; i < cur_p->size; ++i) { diff --git a/examples/talk-llama/llama-vocab.cpp b/examples/talk-llama/llama-vocab.cpp index 50ded286f3f..9389ca805a5 100644 --- a/examples/talk-llama/llama-vocab.cpp +++ b/examples/talk-llama/llama-vocab.cpp @@ -1,5 +1,7 @@ #include "llama-vocab.h" +#include "ggml.h" +#include "gguf.h" #include "llama-impl.h" #include "llama-model-loader.h" @@ -415,6 +417,13 @@ struct llm_tokenizer_bpe : llm_tokenizer { "'(?:[sSdDmMtT]|[lL][lL]|[vV][eE]|[rR][eE])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]|\\s+(?!\\S)|\\s+", }; break; + case LLAMA_VOCAB_PRE_TYPE_SEED_CODER: + regex_exprs = { + // original regex from tokenizer.json + // "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1}| ?[^\\s\\p{L}\\p{N}\r\n]+|\\s*[\r\n]+|\\s+(?!\\S)|\\s+" + "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1}| ?[^\\s\\p{L}\\p{N}\\r\\n]+|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + }; + break; default: // default regex for BPE tokenization pre-processing regex_exprs = { @@ -1227,6 +1236,9 @@ struct fragment_buffer_variant { struct llama_vocab::impl { uint32_t n_token_types = 0; // for BERT-style token types + std::string tokenizer_model; + std::string tokenizer_pre; + enum llama_vocab_type type = LLAMA_VOCAB_TYPE_SPM; enum llama_vocab_pre_type pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT; @@ -1362,9 +1374,6 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { // determine vocab type { - std::string tokenizer_model; - std::string tokenizer_pre; - ml.get_key(LLM_KV_TOKENIZER_MODEL, tokenizer_model); ml.get_key(LLM_KV_TOKENIZER_PRE, tokenizer_pre, false); @@ -1459,7 +1468,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { const int precompiled_charsmap_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP).c_str()); if (precompiled_charsmap_keyidx != -1) { - size_t n_precompiled_charsmap = gguf_get_arr_n(ctx, precompiled_charsmap_keyidx); + const gguf_type pc_type = gguf_get_arr_type(ctx, precompiled_charsmap_keyidx); + GGML_ASSERT(pc_type == GGUF_TYPE_INT8 || pc_type == GGUF_TYPE_UINT8); + + const size_t n_precompiled_charsmap = gguf_get_arr_n(ctx, precompiled_charsmap_keyidx); const char * pc = (const char *) gguf_get_arr_data(ctx, precompiled_charsmap_keyidx); precompiled_charsmap.assign(pc, pc + n_precompiled_charsmap); #ifdef IS_BIG_ENDIAN @@ -1634,6 +1646,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "bailingmoe") { pre_type = LLAMA_VOCAB_PRE_TYPE_BAILINGMOE; clean_spaces = false; + } else if ( + tokenizer_pre == "seed-coder") { + pre_type = LLAMA_VOCAB_PRE_TYPE_SEED_CODER; + clean_spaces = false; } else { throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str())); } @@ -2778,6 +2794,14 @@ void llama_vocab::load(llama_model_loader & ml, const LLM_KV & kv) { pimpl->load(ml, kv); } +std::string llama_vocab::get_tokenizer_model() const { + return pimpl->tokenizer_model; +} + +std::string llama_vocab::get_tokenizer_pre() const { + return pimpl->tokenizer_pre; +} + enum llama_vocab_type llama_vocab::get_type() const { return pimpl->type; } @@ -3000,6 +3024,20 @@ int llama_vocab::find_bpe_rank(const std::string & token_left, const std::string return it->second; } +std::vector llama_vocab::get_bpe_merges() const { + std::vector result(pimpl->bpe_ranks.size()); + + for (const auto & pair : pimpl->bpe_ranks) { + result[pair.second] = pair.first.first + " " + pair.first.second; + } + + return result; +} + +std::vector llama_vocab::get_precompiled_charsmap() const { + return pimpl->precompiled_charsmap; +} + int32_t llama_vocab::tokenize( const char * text, int32_t text_len, diff --git a/examples/talk-llama/llama-vocab.h b/examples/talk-llama/llama-vocab.h index 5ce35521434..daa6cf3082f 100644 --- a/examples/talk-llama/llama-vocab.h +++ b/examples/talk-llama/llama-vocab.h @@ -21,6 +21,9 @@ struct llama_vocab { void load(llama_model_loader & ml, const LLM_KV & kv); + std::string get_tokenizer_model() const; + std::string get_tokenizer_pre() const; + enum llama_vocab_type get_type() const; enum llama_vocab_pre_type get_pre_type() const; @@ -80,6 +83,9 @@ struct llama_vocab { int max_token_len() const; int find_bpe_rank(const std::string & token_left, const std::string & token_right) const; + std::vector get_bpe_merges() const; + + std::vector get_precompiled_charsmap() const; int32_t tokenize( const char * text, diff --git a/examples/talk-llama/llama.cpp b/examples/talk-llama/llama.cpp index c52bbe5f280..9fdddf7b071 100644 --- a/examples/talk-llama/llama.cpp +++ b/examples/talk-llama/llama.cpp @@ -4,6 +4,7 @@ #include "llama-mmap.h" #include "llama-vocab.h" #include "llama-model-loader.h" +#include "llama-model-saver.h" #include "llama-model.h" #include "ggml.h" @@ -16,6 +17,10 @@ #include #include +#if defined(_MSC_VER) +#pragma warning(disable: 4244 4267) // possible loss of data +#endif + // // interface implementation // @@ -249,6 +254,13 @@ struct llama_model * llama_model_load_from_splits( return llama_model_load_from_file_impl(splits.front(), splits, params); } +void llama_model_save_to_file(const struct llama_model * model, const char * path_model) { + llama_model_saver ms(*model); + ms.add_kv_from_model(); + ms.add_tensors_from_model(); + ms.save(path_model); +} + // // chat templates // @@ -334,3 +346,4 @@ const char * llama_print_system_info(void) { return s.c_str(); } + diff --git a/examples/talk-llama/llama.h b/examples/talk-llama/llama.h index 06c56395c13..99e5fba244f 100644 --- a/examples/talk-llama/llama.h +++ b/examples/talk-llama/llama.h @@ -4,6 +4,7 @@ #include "ggml.h" #include "ggml-cpu.h" #include "ggml-backend.h" +#include "ggml-opt.h" #include #include @@ -112,6 +113,7 @@ extern "C" { LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32, LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33, LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 34, + LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 35, }; enum llama_rope_type { @@ -343,7 +345,7 @@ extern "C" { float yarn_beta_fast; // YaRN low correction dim float yarn_beta_slow; // YaRN high correction dim uint32_t yarn_orig_ctx; // YaRN original context size - float defrag_thold; // defragment the KV cache if holes/size > thold, < 0 disabled (default) + float defrag_thold; // defragment the KV cache if holes/size > thold, <= 0 disabled (default) ggml_backend_sched_eval_callback cb_eval; void * cb_eval_user_data; @@ -351,19 +353,18 @@ extern "C" { enum ggml_type type_k; // data type for K cache [EXPERIMENTAL] enum ggml_type type_v; // data type for V cache [EXPERIMENTAL] - // Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value. - // TODO: move at the end of the struct - bool logits_all; // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead) - bool embeddings; // if true, extract embeddings (together with logits) - bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU - bool flash_attn; // whether to use flash attention [EXPERIMENTAL] - bool no_perf; // whether to measure performance timings - // Abort callback // if it returns true, execution of llama_decode() will be aborted // currently works only with CPU execution ggml_abort_callback abort_callback; void * abort_callback_data; + + // Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value. + bool embeddings; // if true, extract embeddings (together with logits) + bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU + bool flash_attn; // whether to use flash attention [EXPERIMENTAL] + bool no_perf; // whether to measure performance timings + bool op_offload; // whether to offload host tensor operations to device }; // model quantization parameters @@ -445,6 +446,10 @@ extern "C" { size_t n_paths, struct llama_model_params params); + LLAMA_API void llama_model_save_to_file( + const struct llama_model * model, + const char * path_model); + DEPRECATED(LLAMA_API void llama_free_model(struct llama_model * model), "use llama_model_free instead"); @@ -924,14 +929,19 @@ extern "C" { // Frees a batch of tokens allocated with llama_batch_init() LLAMA_API void llama_batch_free(struct llama_batch batch); - // Processes a batch of tokens with the ecoder part of the encoder-decoder model. - // Stores the encoder output internally for later use by the decoder cross-attention layers. + // Process a batch of tokens. + // In contrast to llama_decode() - this call does not use KV cache. + // For encode-decoder contexts, processes the batch using the encoder. + // Can store the encoder output internally for later use by the decoder's cross-attention layers. // 0 - success // < 0 - error. the KV cache state is restored to the state before this call LLAMA_API int32_t llama_encode( struct llama_context * ctx, struct llama_batch batch); + // Process a batch of tokens. + // Requires KV cache. + // For encode-decoder contexts, processes the batch using the decoder. // Positive return values does not mean a fatal error, but rather a warning. // 0 - success // 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context) @@ -1428,6 +1438,37 @@ extern "C" { LLAMA_API void llama_perf_sampler_print(const struct llama_sampler * chain); LLAMA_API void llama_perf_sampler_reset( struct llama_sampler * chain); + // + // training + // + + // function that returns whether or not a given tensor contains trainable parameters + typedef bool (*llama_opt_param_filter)(const struct ggml_tensor * tensor, void * userdata); + + // always returns true + LLAMA_API bool llama_opt_param_filter_all(const struct ggml_tensor * tensor, void * userdata); + + struct llama_opt_params { + uint32_t n_ctx_train; // assumed context size post training, use context size specified in llama_context if 0 + + llama_opt_param_filter param_filter; // callback for determining which tensors contain trainable parameters + void * param_filter_ud; // userdata for determining which tensors contain trainable parameters + + ggml_opt_get_optimizer_params get_opt_pars; // callback for calculating optimizer parameters + void * get_opt_pars_ud; // userdata for calculating optimizer parameters + }; + + LLAMA_API void llama_opt_init(struct llama_context * lctx, struct llama_model * model, struct llama_opt_params lopt_params); + + LLAMA_API void llama_opt_epoch( + struct llama_context * lctx, + ggml_opt_dataset_t dataset, + ggml_opt_result_t result_train, + ggml_opt_result_t result_eval, + int64_t idata_split, + ggml_opt_epoch_callback callback_train, + ggml_opt_epoch_callback callback_eval); + #ifdef __cplusplus } #endif