diff --git a/bindings/ruby/ext/options.rb b/bindings/ruby/ext/options.rb index be63de108a6..679b74d133a 100644 --- a/bindings/ruby/ext/options.rb +++ b/bindings/ruby/ext/options.rb @@ -142,6 +142,7 @@ def configure bool "GGML_RV_ZFH" pending "GGML_SCCACHE_FOUND" string "GGML_SCHED_MAX_COPIES" + bool "GGML_SSE42" ignored "GGML_STATIC" bool "GGML_SYCL" string "GGML_SYCL_DEVICE_ARCH" diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index d33f843b417..61fe15a15f0 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -107,6 +107,7 @@ message(DEBUG "INS_ENB : ${INS_ENB}") option(GGML_CPU_HBM "ggml: use memkind for CPU HBM" OFF) option(GGML_CPU_AARCH64 "ggml: use runtime weight conversion of Q4_0 to Q4_X_X" ON) option(GGML_CPU_KLEIDIAI "ggml: use KleidiAI optimized kernels if applicable" OFF) +option(GGML_SSE42 "ggml: enable SSE 4.2" ${INS_ENB}) option(GGML_AVX "ggml: enable AVX" ${INS_ENB}) option(GGML_AVX_VNNI "ggml: enable AVX-VNNI" OFF) option(GGML_AVX2 "ggml: enable AVX2" ${INS_ENB}) @@ -170,7 +171,6 @@ option(GGML_HIP "ggml: use HIP" option(GGML_HIP_GRAPHS "ggml: use HIP graph, experimental, slow" OFF) option(GGML_HIP_NO_VMM "ggml: do not try to use HIP VMM" ON) option(GGML_HIP_ROCWMMA_FATTN "ggml: enable rocWMMA for FlashAttention" OFF) -option(GGML_HIP_UMA "ggml: use HIP unified memory architecture" OFF) option(GGML_VULKAN "ggml: use Vulkan" OFF) option(GGML_VULKAN_CHECK_RESULTS "ggml: run Vulkan op checks" OFF) option(GGML_VULKAN_DEBUG "ggml: enable Vulkan debug output" OFF) diff --git a/ggml/include/ggml-rpc.h b/ggml/include/ggml-rpc.h index 4e0d210f8ec..c8b6097f7e5 100644 --- a/ggml/include/ggml-rpc.h +++ b/ggml/include/ggml-rpc.h @@ -7,6 +7,9 @@ extern "C" { #endif +#define RPC_PROTO_MAJOR_VERSION 1 +#define RPC_PROTO_MINOR_VERSION 0 +#define RPC_PROTO_PATCH_VERSION 0 #define GGML_RPC_MAX_SERVERS 16 // backend API diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 452c967b0a6..51aa5b3a0ab 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -481,6 +481,7 @@ extern "C" { GGML_OP_CONV_TRANSPOSE_1D, GGML_OP_IM2COL, GGML_OP_IM2COL_BACK, + GGML_OP_CONV_2D_DW, GGML_OP_CONV_TRANSPOSE_2D, GGML_OP_POOL_1D, GGML_OP_POOL_2D, @@ -507,17 +508,12 @@ extern "C" { GGML_OP_UNARY, - GGML_OP_MAP_UNARY, - GGML_OP_MAP_BINARY, - - GGML_OP_MAP_CUSTOM1_F32, - GGML_OP_MAP_CUSTOM2_F32, - GGML_OP_MAP_CUSTOM3_F32, - GGML_OP_MAP_CUSTOM1, GGML_OP_MAP_CUSTOM2, GGML_OP_MAP_CUSTOM3, + GGML_OP_CUSTOM, + GGML_OP_CROSS_ENTROPY_LOSS, GGML_OP_CROSS_ENTROPY_LOSS_BACK, GGML_OP_OPT_STEP_ADAMW, @@ -682,6 +678,9 @@ extern "C" { GGML_API bool ggml_is_contiguous_1(const struct ggml_tensor * tensor); // contiguous for dims >= 1 GGML_API bool ggml_is_contiguous_2(const struct ggml_tensor * tensor); // contiguous for dims >= 2 + // true for tensor that is stored in memory as CxWxHxN and has been permuted to WxHxCxN + GGML_API bool ggml_is_contiguous_channels(const struct ggml_tensor * tensor); + GGML_API bool ggml_are_same_shape (const struct ggml_tensor * t0, const struct ggml_tensor * t1); GGML_API bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tensor * t1); @@ -1665,7 +1664,7 @@ extern "C" { struct ggml_tensor * a, struct ggml_tensor * b); - // depthwise + // depthwise (via im2col and mul_mat) GGML_API struct ggml_tensor * ggml_conv_2d_dw( struct ggml_context * ctx, struct ggml_tensor * a, // convolution kernel @@ -1677,6 +1676,22 @@ extern "C" { int d0, // dilation dimension 0 int d1); // dilation dimension 1 + // Depthwise 2D convolution + // may be faster than 
ggml_conv_2d_dw, but not available in all backends + // a: KW KH 1 C convolution kernel + // b: W H C N input data + // res: W_out H_out C N + GGML_API struct ggml_tensor * ggml_conv_2d_dw_direct( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int stride0, + int stride1, + int pad0, + int pad1, + int dilation0, + int dilation1); + GGML_API struct ggml_tensor * ggml_conv_transpose_2d_p0( struct ggml_context * ctx, struct ggml_tensor * a, @@ -1722,24 +1737,29 @@ extern "C" { float p0, float p1); - // nearest interpolate + enum ggml_scale_mode { + GGML_SCALE_MODE_NEAREST = 0, + GGML_SCALE_MODE_BILINEAR = 1, + }; + + // interpolate // multiplies ne0 and ne1 by scale factor - // used in stable-diffusion GGML_API struct ggml_tensor * ggml_upscale( struct ggml_context * ctx, struct ggml_tensor * a, - int scale_factor); + int scale_factor, + enum ggml_scale_mode mode); - // nearest interpolate - // nearest interpolate to specified dimensions - // used in tortoise.cpp + // interpolate + // interpolate scale to specified dimensions GGML_API struct ggml_tensor * ggml_upscale_ext( struct ggml_context * ctx, struct ggml_tensor * a, int ne0, int ne1, int ne2, - int ne3); + int ne3, + enum ggml_scale_mode mode); // pad each dimension with zeros: [x, ..., x] -> [x, ..., x, 0, ..., 0] GGML_API struct ggml_tensor * ggml_pad( @@ -1916,83 +1936,6 @@ extern "C" { // custom operators - typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *); - typedef void (*ggml_binary_op_f32_t)(const int, float *, const float *, const float *); - - typedef void (*ggml_custom1_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *); - typedef void (*ggml_custom2_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *); - typedef void (*ggml_custom3_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *); - - GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_unary_f32( - struct ggml_context * ctx, - struct ggml_tensor * a, - ggml_unary_op_f32_t fun), - "use ggml_map_custom1 instead"); - - GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_unary_inplace_f32( - struct ggml_context * ctx, - struct ggml_tensor * a, - ggml_unary_op_f32_t fun), - "use ggml_map_custom1_inplace instead"); - - GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_binary_f32( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - ggml_binary_op_f32_t fun), - "use ggml_map_custom2 instead"); - - GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_binary_inplace_f32( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - ggml_binary_op_f32_t fun), - "use ggml_map_custom2_inplace instead"); - - GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom1_f32( - struct ggml_context * ctx, - struct ggml_tensor * a, - ggml_custom1_op_f32_t fun), - "use ggml_map_custom1 instead"); - - GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom1_inplace_f32( - struct ggml_context * ctx, - struct ggml_tensor * a, - ggml_custom1_op_f32_t fun), - "use ggml_map_custom1_inplace instead"); - - GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom2_f32( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - ggml_custom2_op_f32_t fun), - "use ggml_map_custom2 instead"); - - GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom2_inplace_f32( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct 
ggml_tensor * b, - ggml_custom2_op_f32_t fun), - "use ggml_map_custom2_inplace instead"); - - GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom3_f32( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - struct ggml_tensor * c, - ggml_custom3_op_f32_t fun), - "use ggml_map_custom3 instead"); - - GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom3_inplace_f32( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - struct ggml_tensor * c, - ggml_custom3_op_f32_t fun), - "use ggml_map_custom3_inplace instead"); - - // custom operators v2 - typedef void (*ggml_custom1_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth, void * userdata); typedef void (*ggml_custom2_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, const struct ggml_tensor * b, int ith, int nth, void * userdata); typedef void (*ggml_custom3_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, const struct ggml_tensor * b, const struct ggml_tensor * c, int ith, int nth, void * userdata); @@ -2048,6 +1991,30 @@ extern "C" { int n_tasks, void * userdata); + typedef void (*ggml_custom_op_t)(struct ggml_tensor * dst , int ith, int nth, void * userdata); + + GGML_API struct ggml_tensor * ggml_custom_4d( + struct ggml_context * ctx, + enum ggml_type type, + int64_t ne0, + int64_t ne1, + int64_t ne2, + int64_t ne3, + struct ggml_tensor ** args, + int n_args, + ggml_custom_op_t fun, + int n_tasks, + void * userdata); + + GGML_API struct ggml_tensor * ggml_custom_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor ** args, + int n_args, + ggml_custom_op_t fun, + int n_tasks, + void * userdata); + // loss function GGML_API struct ggml_tensor * ggml_cross_entropy_loss( diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index f00700da71f..43d9fc4fe25 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -267,6 +267,7 @@ function(ggml_add_cpu_backend_variant tag_name) set(GGML_CPU_TAG_NAME ${tag_name}) # other: OPENMP LLAMAFILE CPU_HBM foreach (feat NATIVE + SSE42 AVX AVX2 BMI2 AVX_VNNI FMA F16C AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16 AMX_TILE AMX_INT8 AMX_BF16) @@ -286,14 +287,16 @@ if (GGML_CPU_ALL_VARIANTS) if (NOT GGML_BACKEND_DL) message(FATAL_ERROR "GGML_CPU_ALL_VARIANTS requires GGML_BACKEND_DL") endif() - ggml_add_cpu_backend_variant(sandybridge AVX) - ggml_add_cpu_backend_variant(haswell AVX F16C AVX2 BMI2 FMA) - ggml_add_cpu_backend_variant(skylakex AVX F16C AVX2 BMI2 FMA AVX512) - ggml_add_cpu_backend_variant(icelake AVX F16C AVX2 BMI2 FMA AVX512 AVX512_VBMI AVX512_VNNI) - ggml_add_cpu_backend_variant(alderlake AVX F16C AVX2 BMI2 FMA AVX_VNNI) + ggml_add_cpu_backend_variant(x64) + ggml_add_cpu_backend_variant(sse42 SSE42) + ggml_add_cpu_backend_variant(sandybridge SSE42 AVX) + ggml_add_cpu_backend_variant(haswell SSE42 AVX F16C AVX2 BMI2 FMA) + ggml_add_cpu_backend_variant(skylakex SSE42 AVX F16C AVX2 BMI2 FMA AVX512) + ggml_add_cpu_backend_variant(icelake SSE42 AVX F16C AVX2 BMI2 FMA AVX512 AVX512_VBMI AVX512_VNNI) + ggml_add_cpu_backend_variant(alderlake SSE42 AVX F16C AVX2 BMI2 FMA AVX_VNNI) if (NOT MSVC) # MSVC doesn't support AMX - ggml_add_cpu_backend_variant(sapphirerapids AVX F16C AVX2 BMI2 FMA AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16 AMX_TILE AMX_INT8) + ggml_add_cpu_backend_variant(sapphirerapids SSE42 AVX F16C AVX2 BMI2 FMA AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16 AMX_TILE AMX_INT8) endif() elseif (GGML_CPU) 
ggml_add_cpu_backend_variant_impl("") diff --git a/ggml/src/ggml-cann/acl_tensor.cpp b/ggml/src/ggml-cann/acl_tensor.cpp index d120ce6acf8..f5462c5a18e 100644 --- a/ggml/src/ggml-cann/acl_tensor.cpp +++ b/ggml/src/ggml-cann/acl_tensor.cpp @@ -41,6 +41,8 @@ aclDataType ggml_cann_type_mapping(ggml_type type) { return ACL_INT4; case GGML_TYPE_Q8_0: return ACL_INT8; + case GGML_TYPE_I64: + return ACL_INT64; default: return ACL_DT_UNDEFINED; } @@ -54,9 +56,7 @@ aclTensor* ggml_cann_create_tensor(const ggml_tensor* tensor, int64_t* ne, // added. int64_t acl_ne[GGML_MAX_DIMS * 2], acl_stride[GGML_MAX_DIMS * 2]; - int64_t acl_storage_len = 0; if (ne == nullptr) { - acl_storage_len = ggml_nbytes(tensor); for (int i = 0; i < GGML_MAX_DIMS; i++) { acl_ne[i] = tensor->ne[i]; // The step size of acl is in elements. @@ -65,14 +65,18 @@ aclTensor* ggml_cann_create_tensor(const ggml_tensor* tensor, int64_t* ne, } else { // With bcast for (int i = 0; i < dims; i++) { - acl_storage_len += (ne[i] - 1) * nb[i]; acl_ne[i] = ne[i]; acl_stride[i] = nb[i] / ggml_element_size(tensor); } } - // Reverse ne and stride. int64_t final_dims = (dims == 0 ? GGML_MAX_DIMS : dims); + int64_t acl_storage_len = 1; + for (int i = 0; i < final_dims; i++) { + acl_storage_len += (acl_ne[i] - 1) * acl_stride[i]; + } + + // Reverse ne and stride. std::reverse(acl_ne, acl_ne + final_dims); std::reverse(acl_stride, acl_stride + final_dims); diff --git a/ggml/src/ggml-cann/acl_tensor.h b/ggml/src/ggml-cann/acl_tensor.h index 4734a9cb8c3..93f09937efb 100644 --- a/ggml/src/ggml-cann/acl_tensor.h +++ b/ggml/src/ggml-cann/acl_tensor.h @@ -101,14 +101,14 @@ aclTensor* ggml_cann_create_tensor(void* data_ptr, aclDataType dtype, tmp_stride[i] = nb[i] / type_size; } - std::reverse(tmp_ne, tmp_ne + dims); - std::reverse(tmp_stride, tmp_stride + dims); - - int64_t acl_storage_len = 0; + int64_t acl_storage_len = 1; for (int i = 0; i < dims; i++) { - acl_storage_len += (ne[i] - 1) * nb[i]; + acl_storage_len += (tmp_ne[i] - 1) * tmp_stride[i]; } + std::reverse(tmp_ne, tmp_ne + dims); + std::reverse(tmp_stride, tmp_stride + dims); + aclTensor* acl_tensor = aclCreateTensor(tmp_ne, dims, dtype, tmp_stride, offset / type_size, format, &acl_storage_len, 1, data_ptr); diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index 8482bb53761..67c0223c010 100644 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -28,7 +28,6 @@ #include #include #include -#include #include #include #include @@ -45,12 +44,27 @@ #include #include #include -#include #include #include #include #include #include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include #include #include @@ -64,6 +78,34 @@ #include "../ggml-common.h" +void bcast_shape(ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst, aclTensor ** acl_src0, + aclTensor ** acl_src1, aclTensor ** acl_dst) { + GGML_ASSERT(ggml_are_same_shape(src0, dst) && ggml_can_repeat(src1, src0)); + // Need bcast + if (!ggml_are_same_shape(src0, src1) && ggml_cann_need_bcast(src0, src1)) { + BCAST_SHAPE(src0, src1) + *acl_src0 = ggml_cann_create_tensor(src0, BCAST_PARAM(src0)); + *acl_src1 = ggml_cann_create_tensor(src1, BCAST_PARAM(src1)); + *acl_dst = ggml_cann_create_tensor(dst, BCAST_PARAM(src0)); + } else { + *acl_src0 = ggml_cann_create_tensor(src0); + *acl_src1 = ggml_cann_create_tensor(src1); + *acl_dst = ggml_cann_create_tensor(dst); + } +} + +void 
ggml_cann_unary_op( + std::function unary_op, + ggml_backend_cann_context& ctx, ggml_tensor* dst) { + ggml_tensor* src = dst->src[0]; + + aclTensor* acl_src = ggml_cann_create_tensor(src); + aclTensor* acl_dst = ggml_cann_create_tensor(dst); + + unary_op(ctx, acl_src, acl_dst); + ggml_cann_release_resources(ctx, acl_src, acl_dst); +} + /** * @brief Repeats elements of a tensor along each dimension according to the * specified repeat array. @@ -79,53 +121,26 @@ static void aclnn_repeat(ggml_backend_cann_context& ctx, aclTensor* acl_src, // repeat tensor along each dim with repeat_array aclIntArray* repeats = aclCreateIntArray(repeat_array, GGML_MAX_DIMS); - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - - ACL_CHECK(aclnnRepeatGetWorkspaceSize(acl_src, repeats, acl_dst, - &workspaceSize, &executor)); - - if (workspaceSize > 0) { - // Memory from allocator will "free" immediately, and this memory - // will be alloced to other pointers, but it won't access before - // this async task end because all tasks in same stream will execute - // in queue. - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - ACL_CHECK( - aclnnRepeat(workspaceAddr, workspaceSize, executor, ctx.stream())); - ACL_CHECK(aclDestroyIntArray(repeats)); + GGML_CANN_CALL_ACLNN_OP(ctx, Repeat, acl_src, repeats, acl_dst); + ggml_cann_release_resources(ctx, repeats); } /** - * @brief Casts the elements of a tensor to a specified data type using the CANN backend. + * @brief Casts the data type of a source tensor to a destination tensor. * - * @details This function performs a type conversion on the elements of the input tensor `acl_src` - * and stores the results in the destination tensor `acl_dst`. The conversion type is - * determined based on the `dst` tensor's data type. + * This function casts the data type of the source tensor `acl_src` to the + * specified data type `cast_data_type` and stores the result in the destination + * tensor `acl_dst`. * * @param ctx The context for the CANN backend operations. - * @param acl_src The source tensor whose elements will be cast. - * @param acl_dst The destination tensor that will store the casted elements. - * @param dst The ggml tensor specifying the target data type. + * @param acl_src The source tensor whose data type will be casted. + * @param acl_dst The destination tensor where the casted result will be stored. + * @param cast_data_type The target data type to which the source tensor will be + * casted. 
*/ static void aclnn_cast(ggml_backend_cann_context& ctx, aclTensor* acl_src, - aclTensor* acl_dst, ggml_tensor* dst) { - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - ACL_CHECK(aclnnCastGetWorkspaceSize(acl_src, - ggml_cann_type_mapping(dst->type), - acl_dst, &workspaceSize, &executor)); - - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - - ACL_CHECK(aclnnCast(workspaceAddr, workspaceSize, executor, ctx.stream())); + aclTensor* acl_dst, aclDataType cast_data_type) { + GGML_CANN_CALL_ACLNN_OP(ctx, Cast, acl_src, cast_data_type, acl_dst); } void ggml_cann_repeat(ggml_backend_cann_context& ctx, ggml_tensor* dst) { @@ -139,73 +154,78 @@ void ggml_cann_repeat(ggml_backend_cann_context& ctx, ggml_tensor* dst) { dst->ne[1] / src->ne[1], dst->ne[0] / src->ne[0]}; aclnn_repeat(ctx, acl_src, acl_dst, repeatsArray); - ACL_CHECK(aclDestroyTensor(acl_src)); - ACL_CHECK(aclDestroyTensor(acl_dst)); + ggml_cann_release_resources(ctx, acl_src, acl_dst); } -/** - * @brief Adds two tensors element-wise and stores the result in a destination - * tensor. - * - * This function performs the operation: - * \f[ - * dst = acl\_src0 + alpha \times acl\_src1 - * \f] - * where alpha is a scalar value and defaults to 1.0f. - * - * @param ctx The context for the CANN backend operations. - * @param acl_src0 The first source tensor. - * @param acl_src1 The second source tensor. - * @param acl_dst The destination tensor where the result will be stored. - */ -static void aclnn_add(ggml_backend_cann_context& ctx, aclTensor* acl_src0, +void aclnn_add(ggml_backend_cann_context& ctx, aclTensor* acl_src0, aclTensor* acl_src1, aclTensor* acl_dst) { - aclScalar* alpha = nullptr; float alphaValue = 1.0f; - alpha = aclCreateScalar(&alphaValue, aclDataType::ACL_FLOAT); - - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - - ACL_CHECK(aclnnAddGetWorkspaceSize(acl_src0, acl_src1, alpha, acl_dst, - &workspaceSize, &executor)); - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - - ACL_CHECK(aclnnAdd(workspaceAddr, workspaceSize, executor, ctx.stream())); + aclScalar* alpha = aclCreateScalar(&alphaValue, aclDataType::ACL_FLOAT); + if (acl_dst != nullptr) + GGML_CANN_CALL_ACLNN_OP(ctx, Add, acl_src0, acl_src1, alpha, acl_dst); + else + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdd, acl_src0, acl_src1, alpha); + ggml_cann_release_resources(ctx, alpha); +} - ACL_CHECK(aclDestroyScalar(alpha)); +void aclnn_sub(ggml_backend_cann_context& ctx, aclTensor* acl_src0, + aclTensor* acl_src1, aclTensor* acl_dst) { + float alphaValue = 1.0f; + aclScalar* alpha = aclCreateScalar(&alphaValue, aclDataType::ACL_FLOAT); + if (acl_dst != nullptr) + GGML_CANN_CALL_ACLNN_OP(ctx, Sub, acl_src0, acl_src1, alpha, acl_dst); + else + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceSub, acl_src0, acl_src1, alpha); + ggml_cann_release_resources(ctx, alpha); } -void ggml_cann_add(ggml_backend_cann_context& ctx, ggml_tensor* dst) { - ggml_tensor* src0 = dst->src[0]; - ggml_tensor* src1 = dst->src[1]; - GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst)); +void aclnn_mul(ggml_backend_cann_context& ctx, aclTensor* acl_src, + aclTensor* acl_other, aclTensor* acl_dst) { + if (acl_dst != nullptr) + GGML_CANN_CALL_ACLNN_OP(ctx, Mul, acl_src, acl_other, acl_dst); + else + 
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMul, acl_src, acl_other); +} - aclTensor* acl_src0; - aclTensor* acl_src1; - aclTensor* acl_dst; +void aclnn_div(ggml_backend_cann_context& ctx, aclTensor* acl_src, + aclTensor* acl_other, aclTensor* acl_dst) { + if (acl_dst != nullptr) + GGML_CANN_CALL_ACLNN_OP(ctx, Div, acl_src, acl_other, acl_dst); + else + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceDiv, acl_src, acl_other); +} - // Need bcast - if (!ggml_are_same_shape(src0, src1) && ggml_cann_need_bcast(src0, src1)) { - BCAST_SHAPE(src0, src1) - acl_src0 = ggml_cann_create_tensor(src0, BCAST_PARAM(src0)); - acl_src1 = ggml_cann_create_tensor(src1, BCAST_PARAM(src1)); - acl_dst = ggml_cann_create_tensor(dst, BCAST_PARAM(src0)); +/** + * @brief Multiplies elements of a tensor by a scalar value, optionally + * in-place. + * + * This function multiplies each element of the source tensor `acl_src` by the + * scalar `scale` and stores the result in the destination tensor `acl_dst`. If + * `inplace` is true, `acl_dst` will not be used and the operation is performed + * in-place on `acl_src`. + * The operation is defined as: + * \f[ + * \text {acl_dst }_i=\text {acl_src }_i \times \text {scale} + * \f] + * + * @param ctx The context for the CANN backend operations. + * @param acl_src The source tensor whose elements will be multiplied. + * @param scale The scalar value by which each element of `acl_src` will be + * multiplied. + * @param acl_dst The destination tensor where the result will be stored if + * `inplace` is false. + * @param inplace Flag indicating whether to perform the operation in-place on + * `acl_src`. + */ +static void aclnn_muls(ggml_backend_cann_context& ctx, aclTensor* acl_src, + float scale, aclTensor* acl_dst, bool inplace) { + aclScalar* acl_scale = aclCreateScalar(&scale, aclDataType::ACL_FLOAT); + if (inplace) { + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_src, acl_scale); } else { - acl_src0 = ggml_cann_create_tensor(src0); - acl_src1 = ggml_cann_create_tensor(src1); - acl_dst = ggml_cann_create_tensor(dst); + GGML_CANN_CALL_ACLNN_OP(ctx, Muls, acl_src, acl_scale, acl_dst); } - - aclnn_add(ctx, acl_src0, acl_src1, acl_dst); - - ACL_CHECK(aclDestroyTensor(acl_src0)); - ACL_CHECK(aclDestroyTensor(acl_src1)); - ACL_CHECK(aclDestroyTensor(acl_dst)); + ggml_cann_release_resources(ctx, acl_scale); } void ggml_cann_leaky_relu(ggml_backend_cann_context& ctx, ggml_tensor* dst) { @@ -222,23 +242,8 @@ void ggml_cann_leaky_relu(ggml_backend_cann_context& ctx, ggml_tensor* dst) { aclScalar* acl_negative_slope = aclCreateScalar(&negative_slope, aclDataType::ACL_FLOAT); - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - - ACL_CHECK(aclnnLeakyReluGetWorkspaceSize( - acl_src, acl_negative_slope, acl_dst, &workspaceSize, &executor)); - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - - ACL_CHECK( - aclnnLeakyRelu(workspaceAddr, workspaceSize, executor, ctx.stream())); - - ACL_CHECK(aclDestroyScalar(acl_negative_slope)); - ACL_CHECK(aclDestroyTensor(acl_src)); - ACL_CHECK(aclDestroyTensor(acl_dst)); + GGML_CANN_CALL_ACLNN_OP(ctx, LeakyRelu, acl_src, acl_negative_slope, acl_dst); + ggml_cann_release_resources(ctx, acl_negative_slope, acl_src, acl_dst); } /** @@ -254,18 +259,7 @@ void ggml_cann_leaky_relu(ggml_backend_cann_context& ctx, ggml_tensor* dst) { static void aclnn_concat(ggml_backend_cann_context& ctx, aclTensorList* tensorList, aclTensor* acl_dst, int64_t 
concat_dim) { - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - - ACL_CHECK(aclnnCatGetWorkspaceSize(tensorList, concat_dim, acl_dst, - &workspaceSize, &executor)); - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - - ACL_CHECK(aclnnCat(workspaceAddr, workspaceSize, executor, ctx.stream())); + GGML_CANN_CALL_ACLNN_OP(ctx, Cat, tensorList, concat_dim, acl_dst); } void ggml_cann_concat(ggml_backend_cann_context& ctx, ggml_tensor* dst) { @@ -281,11 +275,10 @@ void ggml_cann_concat(ggml_backend_cann_context& ctx, ggml_tensor* dst) { int32_t acl_dim = 3 - dim; aclTensor* tensors[] = {acl_src0, acl_src1}; - aclTensorList* tensorList = aclCreateTensorList(tensors, 2); - aclnn_concat(ctx, tensorList, acl_dst, acl_dim); + aclTensorList* tensor_list = aclCreateTensorList(tensors, 2); + aclnn_concat(ctx, tensor_list, acl_dst, acl_dim); - ACL_CHECK(aclDestroyTensorList(tensorList)); - ACL_CHECK(aclDestroyTensor(acl_dst)); + ggml_cann_release_resources(ctx, tensor_list, acl_dst); } /** @@ -311,27 +304,12 @@ static void aclnn_arange(ggml_backend_cann_context& ctx, aclTensor* acl_dst, int64_t steps = (int64_t)std::ceil((stop - start) / step); GGML_ASSERT(n_elements == steps); - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - aclScalar* acl_start = aclCreateScalar(&start, aclDataType::ACL_FLOAT); aclScalar* acl_end = aclCreateScalar(&stop, aclDataType::ACL_FLOAT); aclScalar* acl_step = aclCreateScalar(&step, aclDataType::ACL_FLOAT); - ACL_CHECK(aclnnArangeGetWorkspaceSize(acl_start, acl_end, acl_step, acl_dst, - &workspaceSize, &executor)); - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - - ACL_CHECK( - aclnnArange(workspaceAddr, workspaceSize, executor, ctx.stream())); - - ACL_CHECK(aclDestroyScalar(acl_start)); - ACL_CHECK(aclDestroyScalar(acl_end)); - ACL_CHECK(aclDestroyScalar(acl_step)); + GGML_CANN_CALL_ACLNN_OP(ctx, Arange, acl_start, acl_end, acl_step, acl_dst); + ggml_cann_release_resources(ctx, acl_start, acl_end, acl_step); } void ggml_cann_arange(ggml_backend_cann_context& ctx, ggml_tensor* dst) { @@ -348,18 +326,11 @@ void ggml_cann_arange(ggml_backend_cann_context& ctx, ggml_tensor* dst) { memcpy(&step, (float*)dst->op_params + 2, sizeof(float)); aclnn_arange(ctx, acl_dst, start, stop, step, n_elements); - ACL_CHECK(aclDestroyTensor(acl_dst)); -} - -void ggml_cann_sqr(ggml_backend_cann_context& ctx, ggml_tensor* dst) { - dst->src[1] = dst->src[0]; - ggml_cann_mul_div(ctx, dst); + ggml_cann_release_resources(ctx, acl_dst); } void ggml_cann_clamp(ggml_backend_cann_context& ctx, ggml_tensor* dst) { ggml_tensor* src = dst->src[0]; - GGML_ASSERT(src->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); float min; float max; @@ -372,23 +343,8 @@ void ggml_cann_clamp(ggml_backend_cann_context& ctx, ggml_tensor* dst) { aclScalar* acl_min = aclCreateScalar(&min, aclDataType::ACL_FLOAT); aclScalar* acl_max = aclCreateScalar(&max, aclDataType::ACL_FLOAT); - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - - ACL_CHECK(aclnnClampGetWorkspaceSize(acl_src, acl_min, acl_max, acl_dst, - &workspaceSize, &executor)); - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - - 
ACL_CHECK(aclnnClamp(workspaceAddr, workspaceSize, executor, ctx.stream())); - - ACL_CHECK(aclDestroyScalar(acl_min)); - ACL_CHECK(aclDestroyScalar(acl_max)); - ACL_CHECK(aclDestroyTensor(acl_src)); - ACL_CHECK(aclDestroyTensor(acl_dst)); + GGML_CANN_CALL_ACLNN_OP(ctx, Clamp, acl_src, acl_min, acl_max, acl_dst); + ggml_cann_release_resources(ctx, acl_min, acl_max, acl_src, acl_dst); } void ggml_cann_scale(ggml_backend_cann_context& ctx, ggml_tensor* dst) { @@ -402,22 +358,8 @@ void ggml_cann_scale(ggml_backend_cann_context& ctx, ggml_tensor* dst) { aclTensor* acl_src = ggml_cann_create_tensor(src); aclTensor* acl_dst = ggml_cann_create_tensor(dst); - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - - ACL_CHECK(aclnnMulsGetWorkspaceSize(acl_src, scale, acl_dst, &workspaceSize, - &executor)); - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - - ACL_CHECK(aclnnMuls(workspaceAddr, workspaceSize, executor, ctx.stream())); - - ACL_CHECK(aclDestroyScalar(scale)); - ACL_CHECK(aclDestroyTensor(acl_src)); - ACL_CHECK(aclDestroyTensor(acl_dst)); + GGML_CANN_CALL_ACLNN_OP(ctx, Muls, acl_src, scale, acl_dst); + ggml_cann_release_resources(ctx, scale, acl_src, acl_dst); } void ggml_cann_argsort(ggml_backend_cann_context& ctx, ggml_tensor* dst) { @@ -432,36 +374,10 @@ void ggml_cann_argsort(ggml_backend_cann_context& ctx, ggml_tensor* dst) { aclTensor* tmp_tensor = ggml_cann_create_tensor(buffer, ACL_INT64, ggml_type_size(dst->type), dst->ne, dst->nb, GGML_MAX_DIMS); - - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - - ACL_CHECK(aclnnArgsortGetWorkspaceSize( - acl_src, -1, (order == GGML_SORT_ORDER_DESC ? true : false), tmp_tensor, - &workspaceSize, &executor)); - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - - ACL_CHECK( - aclnnArgsort(workspaceAddr, workspaceSize, executor, ctx.stream())); - - workspaceSize = 0; - ACL_CHECK(aclnnCastGetWorkspaceSize(tmp_tensor, - ggml_cann_type_mapping(dst->type), - acl_dst, &workspaceSize, &executor)); - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - - ACL_CHECK(aclnnCast(workspaceAddr, workspaceSize, executor, ctx.stream())); - - ACL_CHECK(aclDestroyTensor(acl_src)); - ACL_CHECK(aclDestroyTensor(tmp_tensor)); - ACL_CHECK(aclDestroyTensor(acl_dst)); + GGML_CANN_CALL_ACLNN_OP(ctx, Argsort, acl_src, -1, (order == GGML_SORT_ORDER_DESC ? 
true : false), + tmp_tensor); + GGML_CANN_CALL_ACLNN_OP(ctx, Cast, tmp_tensor, ggml_cann_type_mapping(dst->type), acl_dst); + ggml_cann_release_resources(ctx, acl_src, tmp_tensor, acl_dst); } void ggml_cann_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) { @@ -473,27 +389,11 @@ void ggml_cann_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) { float eps; memcpy(&eps, dst->op_params, sizeof(float)); - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - std::vector normData = {dst->ne[0]}; aclIntArray* norm = aclCreateIntArray(normData.data(), normData.size()); - ACL_CHECK(aclnnLayerNormGetWorkspaceSize(acl_src, norm, nullptr, nullptr, - eps, acl_dst, nullptr, nullptr, - &workspaceSize, &executor)); - - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - - ACL_CHECK( - aclnnLayerNorm(workspaceAddr, workspaceSize, executor, ctx.stream())); - - ACL_CHECK(aclDestroyIntArray(norm)); - ACL_CHECK(aclDestroyTensor(acl_src)); - ACL_CHECK(aclDestroyTensor(acl_dst)); + GGML_CANN_CALL_ACLNN_OP(ctx, LayerNorm, acl_src, norm, nullptr, nullptr, + eps, acl_dst, nullptr, nullptr); + ggml_cann_release_resources(ctx, norm, acl_src, acl_dst); } void ggml_cann_group_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) { @@ -507,10 +407,6 @@ void ggml_cann_group_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) { float eps; memcpy(&eps, dst->op_params + 1, sizeof(float)); - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - int64_t N = src->ne[3]; int64_t C = src->ne[2]; int64_t HxW = src->ne[1] * src->ne[0]; @@ -527,22 +423,9 @@ void ggml_cann_group_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) { aclTensor* acl_rstd_out = ggml_cann_create_tensor( (char*)buffer + n_bytes, ACL_FLOAT, type_size, ne, nb, ACL_FORMAT_ND); - ACL_CHECK(aclnnGroupNormGetWorkspaceSize( - acl_src, nullptr, nullptr, N, C, HxW, n_groups, eps, acl_dst, - acl_mean_out, acl_rstd_out, &workspaceSize, &executor)); - - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - - ACL_CHECK( - aclnnGroupNorm(workspaceAddr, workspaceSize, executor, ctx.stream())); - - ACL_CHECK(aclDestroyTensor(acl_src)); - ACL_CHECK(aclDestroyTensor(acl_dst)); - ACL_CHECK(aclDestroyTensor(acl_mean_out)); - ACL_CHECK(aclDestroyTensor(acl_rstd_out)); + GGML_CANN_CALL_ACLNN_OP(ctx, GroupNorm, acl_src, nullptr, nullptr, N, C, HxW, n_groups, eps, + acl_dst, acl_mean_out, acl_rstd_out); + ggml_cann_release_resources(ctx, acl_src, acl_dst, acl_mean_out, acl_rstd_out); } void ggml_cann_acc(ggml_backend_cann_context& ctx, ggml_tensor* dst) { @@ -565,68 +448,52 @@ void ggml_cann_acc(ggml_backend_cann_context& ctx, ggml_tensor* dst) { float alphaValue = 1.0f; alpha = aclCreateScalar(&alphaValue, aclDataType::ACL_FLOAT); - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - if (!inplace) { size_t cpy_size = ggml_nbytes(dst); - ACL_CHECK(aclrtMemcpyAsync(dst->data, cpy_size, src0->data, cpy_size, - ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream())); + ggml_cann_async_memcpy(ctx, dst->data, src0->data, cpy_size, + ACL_MEMCPY_DEVICE_TO_DEVICE); aclTensor* acl_src0 = ggml_cann_create_tensor( src0, src1->ne, src0->nb, GGML_MAX_DIMS, ACL_FORMAT_ND, offset); - ACL_CHECK(aclnnAddGetWorkspaceSize(acl_src0, acl_src1, alpha, acl_dst, - &workspaceSize, &executor)); - if 
(workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - ACL_CHECK( - aclnnAdd(workspaceAddr, workspaceSize, executor, ctx.stream())); - ACL_CHECK(aclDestroyTensor(acl_src0)); + + GGML_CANN_CALL_ACLNN_OP(ctx, Add, acl_src0, acl_src1, alpha, acl_dst); + ggml_cann_release_resources(ctx, acl_src0); } else { - ACL_CHECK(aclnnInplaceAddGetWorkspaceSize(acl_dst, acl_src1, alpha, - &workspaceSize, &executor)); - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - ACL_CHECK(aclnnInplaceAdd(workspaceAddr, workspaceSize, executor, - ctx.stream())); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdd, acl_dst, acl_src1, alpha); } - - ACL_CHECK(aclDestroyTensor(acl_src1)); - ACL_CHECK(aclDestroyTensor(acl_dst)); + ggml_cann_release_resources(ctx, acl_src1, acl_dst); } -void ggml_cann_sum_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) { +/** + * @brief Performs sum reduction on a given tensor along specified dimensions. + * + * This function reduces the input tensor by summing along the specified dimensions. + * + * @param ctx The context for the CANN backend operations. + * @param dst The destination tensor where the reduced result will be stored. + * @param dim An array of dimension indices. + * @param dim_size The number of dimensions. + */ +static void aclnn_reduce_sum(ggml_backend_cann_context& ctx, ggml_tensor* dst, + int64_t* dim, size_t dim_size) { + GGML_ASSERT(dst->ne[0] == 1); ggml_tensor* src = dst->src[0]; - aclTensor* acl_src = ggml_cann_create_tensor(src); - - GGML_ASSERT(dst->ne[0] == 1); aclTensor* acl_dst = ggml_cann_create_tensor(dst); + aclIntArray* reduce_dims = aclCreateIntArray(dim, dim_size); - int64_t reduce_dims_host[] = {3}; - aclIntArray* reduce_dims = aclCreateIntArray(reduce_dims_host, 1); - - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - - ACL_CHECK(aclnnReduceSumGetWorkspaceSize( - acl_src, reduce_dims, true, ggml_cann_type_mapping(src->type), acl_dst, - &workspaceSize, &executor)); - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } + GGML_CANN_CALL_ACLNN_OP(ctx, ReduceSum, acl_src, reduce_dims, true, + ggml_cann_type_mapping(dst->type), acl_dst); + ggml_cann_release_resources(ctx, acl_src, acl_dst, reduce_dims); +} - ACL_CHECK( - aclnnReduceSum(workspaceAddr, workspaceSize, executor, ctx.stream())); +void ggml_cann_sum_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) { + int64_t reduce_dims[] = {3}; + aclnn_reduce_sum(ctx, dst, reduce_dims, 1); +} - ACL_CHECK(aclDestroyTensor(acl_src)); - ACL_CHECK(aclDestroyTensor(acl_dst)); +void ggml_cann_sum(ggml_backend_cann_context& ctx, ggml_tensor* dst) { + int64_t reduce_dims[] = {0, 1, 2, 3}; + aclnn_reduce_sum(ctx, dst, reduce_dims, 4); } void ggml_cann_upsample_nearest2d(ggml_backend_cann_context& ctx, @@ -640,23 +507,8 @@ void ggml_cann_upsample_nearest2d(ggml_backend_cann_context& ctx, std::vector output_size{dst->ne[1], dst->ne[0]}; auto output_size_array = aclCreateIntArray(output_size.data(), 2); - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - - ACL_CHECK(aclnnUpsampleNearest2dGetWorkspaceSize( - acl_src, output_size_array, acl_dst, &workspaceSize, &executor)); - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); 
- workspaceAddr = workspace_allocator.get(); - } - - ACL_CHECK(aclnnUpsampleNearest2d(workspaceAddr, workspaceSize, executor, - ctx.stream())); - - ACL_CHECK(aclDestroyIntArray(output_size_array)); - ACL_CHECK(aclDestroyTensor(acl_src)); - ACL_CHECK(aclDestroyTensor(acl_dst)); + GGML_CANN_CALL_ACLNN_OP(ctx, UpsampleNearest2d, acl_src, output_size_array, acl_dst); + ggml_cann_release_resources(ctx, acl_src, acl_dst, output_size_array); } /** @@ -679,23 +531,8 @@ static void aclnn_pad(ggml_backend_cann_context& ctx, aclTensor* acl_src, aclIntArray* acl_pad = aclCreateIntArray(paddings, GGML_MAX_DIMS * 2); aclScalar* acl_value = aclCreateScalar(&value, aclDataType::ACL_FLOAT); - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - - ACL_CHECK(aclnnConstantPadNdGetWorkspaceSize( - acl_src, acl_pad, acl_value, acl_dst, &workspaceSize, &executor)); - - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - - ACL_CHECK(aclnnConstantPadNd(workspaceAddr, workspaceSize, executor, - ctx.stream())); - - ACL_CHECK(aclDestroyIntArray(acl_pad)); - ACL_CHECK(aclDestroyScalar(acl_value)); + GGML_CANN_CALL_ACLNN_OP(ctx, ConstantPadNd, acl_src, acl_pad, acl_value, acl_dst); + ggml_cann_release_resources(ctx, acl_pad, acl_value); } void ggml_cann_pad(ggml_backend_cann_context& ctx, ggml_tensor* dst) { @@ -711,9 +548,7 @@ void ggml_cann_pad(ggml_backend_cann_context& ctx, ggml_tensor* dst) { 0, dst->ne[0] - src->ne[0], 0, dst->ne[1] - src->ne[1], 0, dst->ne[2] - src->ne[2], 0, dst->ne[3] - src->ne[3]}; aclnn_pad(ctx, acl_src, acl_dst, paddings); - - ACL_CHECK(aclDestroyTensor(acl_dst)); - ACL_CHECK(aclDestroyTensor(acl_src)); + ggml_cann_release_resources(ctx, acl_src, acl_dst); } /** @@ -759,28 +594,15 @@ static void ggml_cann_avg_pool2d(ggml_backend_cann_context& ctx, bool count_include_pad = true; int64_t divisor_override = 0; int8_t cube_math_type = 0; +#ifdef ASCEND_310P + cube_math_type = 1; +#endif - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - - ACL_CHECK(aclnnAvgPool2dGetWorkspaceSize( - acl_src, kernel_size, strides, paddings_avg, ceil_mode, - count_include_pad, divisor_override, cube_math_type, acl_dst, - &workspaceSize, &executor)); - - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - ACL_CHECK( - aclnnAvgPool2d(workspaceAddr, workspaceSize, executor, ctx.stream())); - - ACL_CHECK(aclDestroyTensor(acl_src)); - ACL_CHECK(aclDestroyTensor(acl_dst)); - ACL_CHECK(aclDestroyIntArray(kernel_size)); - ACL_CHECK(aclDestroyIntArray(strides)); - ACL_CHECK(aclDestroyIntArray(paddings_avg)); + GGML_CANN_CALL_ACLNN_OP(ctx, AvgPool2d, acl_src, kernel_size, strides, paddings_avg, + ceil_mode, count_include_pad, divisor_override, + cube_math_type, acl_dst); + ggml_cann_release_resources(ctx, acl_src, acl_dst, kernel_size, strides, + paddings_avg); } /** @@ -848,29 +670,10 @@ static void ggml_cann_max_pool2d(ggml_backend_cann_context& ctx, bool ceil_mode = false; int64_t auto_pads = 0; - - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - - ACL_CHECK(aclnnMaxPoolGetWorkspaceSize( - tmp_tensor, kernel_size, strides, auto_pads, paddings_max, dilations, - ceil_mode, acl_dst, &workspaceSize, &executor)); - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - 
workspaceAddr = workspace_allocator.get(); - } - - ACL_CHECK( - aclnnMaxPool(workspaceAddr, workspaceSize, executor, ctx.stream())); - - ACL_CHECK(aclDestroyTensor(acl_src)); - ACL_CHECK(aclDestroyTensor(acl_dst)); - ACL_CHECK(aclDestroyTensor(tmp_tensor)); - ACL_CHECK(aclDestroyIntArray(kernel_size)); - ACL_CHECK(aclDestroyIntArray(strides)); - ACL_CHECK(aclDestroyIntArray(paddings_max)); - ACL_CHECK(aclDestroyIntArray(dilations)); + GGML_CANN_CALL_ACLNN_OP(ctx, MaxPool, tmp_tensor, kernel_size, strides, auto_pads, + paddings_max, dilations, ceil_mode, acl_dst); + ggml_cann_release_resources(ctx, acl_src, acl_dst, tmp_tensor, kernel_size, + strides, paddings_max, dilations); } void ggml_cann_pool2d(ggml_backend_cann_context& ctx, ggml_tensor* dst) { @@ -901,20 +704,7 @@ void ggml_cann_pool2d(ggml_backend_cann_context& ctx, ggml_tensor* dst) { */ static void cann_copy(ggml_backend_cann_context& ctx, aclTensor* acl_src, aclTensor* acl_dst) { - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - - ACL_CHECK(aclnnInplaceCopyGetWorkspaceSize(acl_dst, acl_src, &workspaceSize, - &executor)); - - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - - ACL_CHECK( - aclnnInplaceCopy(workspaceAddr, workspaceSize, executor, ctx.stream())); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceCopy, acl_dst, acl_src); } void ggml_cann_dup(ggml_backend_cann_context& ctx, ggml_tensor* dst) { @@ -926,15 +716,14 @@ void ggml_cann_dup(ggml_backend_cann_context& ctx, ggml_tensor* dst) { if (dst->type == src0->type) { cann_copy(ctx, acl_src, acl_dst); } else { - aclnn_cast(ctx, acl_src, acl_dst, dst); + aclnn_cast(ctx, acl_src, acl_dst, ggml_cann_type_mapping(dst->type)); } } else { if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst)) { if (dst->type == src0->type) { size_t cpy_size = ggml_nbytes(dst); - ACL_CHECK(aclrtMemcpyAsync( - dst->data, cpy_size, src0->data, cpy_size, - ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream())); + ggml_cann_async_memcpy(ctx, dst->data, src0->data, cpy_size, + ACL_MEMCPY_DEVICE_TO_DEVICE); return; } else { ggml_cann_pool_alloc src_buffer_allocator( @@ -951,12 +740,11 @@ void ggml_cann_dup(ggml_backend_cann_context& ctx, ggml_tensor* dst) { ggml_type_size(dst->type), src0->ne, src_trans_nb, GGML_MAX_DIMS); - aclnn_cast(ctx, acl_src, src_trans_tensor, dst); + aclnn_cast(ctx, acl_src, src_trans_tensor, ggml_cann_type_mapping(dst->type)); size_t cpy_size = ggml_nbytes(dst); - ACL_CHECK(aclrtMemcpyAsync( - dst->data, cpy_size, src_trans_buffer, cpy_size, - ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream())); - ACL_CHECK(aclDestroyTensor(src_trans_tensor)); + ggml_cann_async_memcpy(ctx, dst->data, src_trans_buffer, cpy_size, + ACL_MEMCPY_DEVICE_TO_DEVICE); + ggml_cann_release_resources(ctx, src_trans_tensor); return; } } else if (ggml_is_contiguous(dst)) { @@ -973,37 +761,19 @@ void ggml_cann_dup(ggml_backend_cann_context& ctx, ggml_tensor* dst) { ggml_type_size(dst->type), src0->ne, src_trans_nb, GGML_MAX_DIMS); - aclnn_cast(ctx, acl_src, src_trans_tensor, dst); + aclnn_cast(ctx, acl_src, src_trans_tensor, ggml_cann_type_mapping(dst->type)); size_t cpy_size = ggml_nbytes(dst); - ACL_CHECK(aclrtMemcpyAsync(dst->data, cpy_size, src_trans_buffer, - cpy_size, ACL_MEMCPY_DEVICE_TO_DEVICE, - ctx.stream())); - ACL_CHECK(aclDestroyTensor(src_trans_tensor)); + ggml_cann_async_memcpy(ctx, dst->data, src_trans_buffer, cpy_size, + ACL_MEMCPY_DEVICE_TO_DEVICE); + 
ggml_cann_release_resources(ctx, src_trans_tensor); return; } else { GGML_ABORT("Unsupport dst is not tontiguous."); } } - - ACL_CHECK(aclDestroyTensor(acl_src)); - ACL_CHECK(aclDestroyTensor(acl_dst)); -} - -#ifdef __cplusplus -extern "C" { -#endif -aclnnStatus aclnnRmsNormGetWorkspaceSize(const aclTensor* x, - const aclTensor* gamma, double epsilon, - const aclTensor* yOut, - const aclTensor* rstdOout, - uint64_t* workspaceSize, - aclOpExecutor** executor); -aclnnStatus aclnnRmsNorm(void* workspace, uint64_t workspaceSize, - aclOpExecutor* executor, aclrtStream stream); -#ifdef __cplusplus + ggml_cann_release_resources(ctx, acl_src, acl_dst); } -#endif /** * @brief Creates an ACL tensor initialized with zeros using a provided buffer. @@ -1030,7 +800,7 @@ static aclTensor* aclnn_zero(ggml_backend_cann_context& ctx, void* buffer, nb[i] = nb[i - 1] * ne[i - 1]; } - ACL_CHECK(aclrtMemsetAsync(buffer, n_bytes, 0, n_bytes, ctx.stream())); + ggml_cann_async_memset(ctx, buffer, n_bytes, 0); aclTensor* zero = ggml_cann_create_tensor(buffer, type, type_size, ne, nb, dims); return zero; @@ -1063,21 +833,7 @@ static aclTensor* aclnn_values(ggml_backend_cann_context& ctx, void* buffer, float alpha_host = 1.0f; aclScalar* alpha = aclCreateScalar(&alpha_host, aclDataType::ACL_FLOAT); aclScalar* other = aclCreateScalar(&value, aclDataType::ACL_FLOAT); - - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - - ACL_CHECK(aclnnInplaceAddsGetWorkspaceSize(acl_tensor, other, alpha, - &workspaceSize, &executor)); - - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - ACL_CHECK( - aclnnInplaceAdds(workspaceAddr, workspaceSize, executor, ctx.stream())); - + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdds, acl_tensor, other, alpha); return acl_tensor; } @@ -1089,13 +845,6 @@ void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) { float eps; memcpy(&eps, dst->op_params, sizeof(float)); - - GGML_ASSERT(eps > 0.0f); - - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - size_t one_tensor_n_bytes = src->ne[0] * ggml_element_size(src); ggml_cann_pool_alloc one_tensor_allocator(ctx.pool(), one_tensor_n_bytes); @@ -1110,22 +859,8 @@ void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) { aclnn_zero(ctx, zero_tensor_allocator.get(), zero_tensor_n_bytes, src->ne, GGML_MAX_DIMS, ggml_cann_type_mapping(src->type), ggml_element_size(src)); - - ACL_CHECK(aclnnRmsNormGetWorkspaceSize( - acl_src, acl_gamma, eps, acl_dst, acl_rstd, &workspaceSize, &executor)); - - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - - ACL_CHECK( - aclnnRmsNorm(workspaceAddr, workspaceSize, executor, ctx.stream())); - - ACL_CHECK(aclDestroyTensor(acl_src)); - ACL_CHECK(aclDestroyTensor(acl_dst)); - ACL_CHECK(aclDestroyTensor(acl_gamma)); - ACL_CHECK(aclDestroyTensor(acl_rstd)); + GGML_CANN_CALL_ACLNN_OP(ctx, RmsNorm, acl_src, acl_gamma, eps, acl_dst, acl_rstd); + ggml_cann_release_resources(ctx, acl_src, acl_dst, acl_gamma, acl_rstd); } // TODO: performace is low. 
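
Reviewer note on the pattern above: throughout ggml-cann/aclnn_ops.cpp this change collapses the hand-written aclnnXxxGetWorkspaceSize / workspace-pool allocation / aclnnXxx / aclDestroy* sequences into a GGML_CANN_CALL_ACLNN_OP call plus a single ggml_cann_release_resources call. As a hedged sketch only (not part of the diff), a simple unary op written against the new helpers would look roughly like the following; the helper names are the ones used in the hunks above, while the function names aclnn_cos/example_cos and the choice of the Cos operator are illustrative:

    // sketch: dst = cos(src) using the refactored CANN helpers
    static void aclnn_cos(ggml_backend_cann_context& ctx, aclTensor* acl_src,
                          aclTensor* acl_dst) {
        // the macro wraps the GetWorkspaceSize/execute pair and draws the
        // workspace from the context pool, as in the call sites above
        GGML_CANN_CALL_ACLNN_OP(ctx, Cos, acl_src, acl_dst);
    }

    static void example_cos(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
        ggml_tensor* src = dst->src[0];
        aclTensor* acl_src = ggml_cann_create_tensor(src);
        aclTensor* acl_dst = ggml_cann_create_tensor(dst);
        aclnn_cos(ctx, acl_src, acl_dst);
        // centralized cleanup replaces the per-object ACL_CHECK(aclDestroyTensor(...)) calls
        ggml_cann_release_resources(ctx, acl_src, acl_dst);
    }
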
@@ -1147,75 +882,14 @@ void ggml_cann_diag_mask(ggml_backend_cann_context& ctx, ggml_tensor* dst, src->ne, GGML_MAX_DIMS, ggml_cann_type_mapping(src->type), ggml_element_size(src), value); - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - - ACL_CHECK(aclnnInplaceTriuGetWorkspaceSize(mask_tensor, n_past + 1, - &workspaceSize, &executor)); - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - - ACL_CHECK( - aclnnInplaceTriu(workspaceAddr, workspaceSize, executor, ctx.stream())); - - ACL_CHECK(aclnnTrilGetWorkspaceSize(acl_src, n_past + 1, acl_dst, - &workspaceSize, &executor)); - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - - ACL_CHECK(aclnnTril(workspaceAddr, workspaceSize, executor, ctx.stream())); - aclScalar* alpha = nullptr; float alphaValue = 1.0f; alpha = aclCreateScalar(&alphaValue, aclDataType::ACL_FLOAT); - ACL_CHECK(aclnnInplaceAddGetWorkspaceSize(acl_dst, mask_tensor, alpha, - &workspaceSize, &executor)); - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - ACL_CHECK( - aclnnInplaceAdd(workspaceAddr, workspaceSize, executor, ctx.stream())); - - ACL_CHECK(aclDestroyScalar(alpha)); - ACL_CHECK(aclDestroyTensor(mask_tensor)); - ACL_CHECK(aclDestroyTensor(acl_src)); - ACL_CHECK(aclDestroyTensor(acl_dst)); -} - -/** - * @brief Casts the data type of a source tensor to a destination tensor. - * - * This function casts the data type of the source tensor `acl_src` to the - * specified data type `cast_data_type` and stores the result in the destination - * tensor `acl_dst`. - * - * @param ctx The context for the CANN backend operations. - * @param acl_src The source tensor whose data type will be casted. - * @param acl_dst The destination tensor where the casted result will be stored. - * @param cast_data_type The target data type to which the source tensor will be - * casted. 
- */ -static void aclnn_cast(ggml_backend_cann_context& ctx, aclTensor* acl_src, - aclTensor* acl_dst, aclDataType cast_data_type) { - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - - ACL_CHECK(aclnnCastGetWorkspaceSize(acl_src, cast_data_type, acl_dst, - &workspaceSize, &executor)); - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - - ACL_CHECK(aclnnCast(workspaceAddr, workspaceSize, executor, ctx.stream())); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceTriu, mask_tensor, n_past + 1); + GGML_CANN_CALL_ACLNN_OP(ctx, Tril, acl_src, n_past + 1, acl_dst); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdd, acl_dst, mask_tensor, alpha); + ggml_cann_release_resources(ctx, alpha, acl_src, acl_dst, mask_tensor); } /** @@ -1236,39 +910,9 @@ static void aclnn_cast(ggml_backend_cann_context& ctx, aclTensor* acl_src, static void aclnn_permute(ggml_backend_cann_context& ctx, aclTensor* acl_src, aclTensor* acl_dst, int64_t* new_dim, uint64_t dims) { aclIntArray* acl_dims = aclCreateIntArray(new_dim, dims); - - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - - ACL_CHECK(aclnnPermuteGetWorkspaceSize(acl_src, acl_dims, acl_dst, - &workspaceSize, &executor)); - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - - ACL_CHECK( - aclnnPermute(workspaceAddr, workspaceSize, executor, ctx.stream())); - - ACL_CHECK(aclDestroyIntArray(acl_dims)); -} - -#ifdef __cplusplus -extern "C" { -#endif -aclnnStatus aclnnIm2colGetWorkspaceSize(const aclTensor* self, - const aclIntArray* kernelSize, - const aclIntArray* dilation, - const aclIntArray* padding, - const aclIntArray* stride, - aclTensor* out, uint64_t* workspaceSize, - aclOpExecutor** executor); -aclnnStatus aclnnIm2col(void* workspace, uint64_t workspaceSize, - aclOpExecutor* executor, aclrtStream stream); -#ifdef __cplusplus + GGML_CANN_CALL_ACLNN_OP(ctx, Permute, acl_src, acl_dims, acl_dst); + ggml_cann_release_resources(ctx, acl_dims); } -#endif static void ggml_cann_im2col_2d_post_process(ggml_backend_cann_context& ctx, ggml_tensor* dst, @@ -1288,8 +932,7 @@ static void ggml_cann_im2col_2d_post_process(ggml_backend_cann_context& ctx, aclnn_permute(ctx, tmp_im2col_tensor, acl_dst, permute_dim, 3); } - // release - ACL_CHECK(aclDestroyTensor(acl_dst)); + ggml_cann_release_resources(ctx, acl_dst); } static void ggml_cann_im2col_1d_post_process( @@ -1311,7 +954,6 @@ static void ggml_cann_im2col_1d_post_process( // Permute: [N, IC * KH * KW, OW * OH] -> // [N, OW * OH * n_bytes_factor, IC * KH * KW] - aclTensor* tmp_permute_tensor = nullptr; ggml_cann_pool_alloc tmp_permute_allocator(ctx.pool()); tmp_permute_allocator.alloc(ggml_nbytes(dst) * n_bytes_factor); void* tmp_permute_buffer = tmp_permute_allocator.get(); @@ -1323,7 +965,7 @@ static void ggml_cann_im2col_1d_post_process( tmp_permute_nb[i] = tmp_permute_nb[i - 1] * tmp_permute_ne[i - 1]; } - tmp_permute_tensor = ggml_cann_create_tensor( + aclTensor* tmp_permute_tensor = ggml_cann_create_tensor( tmp_permute_buffer, ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type), tmp_permute_ne, tmp_permute_nb, GGML_MAX_DIMS - 1, ACL_FORMAT_ND); @@ -1353,9 +995,8 @@ static void ggml_cann_im2col_1d_post_process( c * KH * KW * n_step_w * ggml_type_size(dst->type); for (int i = 0; i < n_step_w; i++) { - ACL_CHECK(aclrtMemcpyAsync( - cur_dst_buffer, 
size_cpy, cur_permute_buffer, size_cpy, - ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream())); + ggml_cann_async_memcpy(ctx, cur_dst_buffer, cur_permute_buffer, size_cpy, + ACL_MEMCPY_DEVICE_TO_DEVICE); cur_dst_buffer = (char*)cur_dst_buffer + KH * KW * ggml_type_size(dst->type); cur_permute_buffer = (char*)cur_permute_buffer + @@ -1365,13 +1006,11 @@ static void ggml_cann_im2col_1d_post_process( } else { offset = KH * KW * n_step_w * ggml_type_size(dst->type); // equal to ggml_nbytes(dst) - ACL_CHECK(aclrtMemcpyAsync(dst->data, offset, - (char*)tmp_permute_buffer + offset, offset, - ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream())); + ggml_cann_async_memcpy(ctx, dst->data, (char*)tmp_permute_buffer + offset, offset, + ACL_MEMCPY_DEVICE_TO_DEVICE); } - // release - ACL_CHECK(aclDestroyTensor(tmp_permute_tensor)); + ggml_cann_release_resources(ctx, tmp_permute_tensor); } void ggml_cann_im2col(ggml_backend_cann_context& ctx, ggml_tensor* dst) { @@ -1433,23 +1072,8 @@ void ggml_cann_im2col(ggml_backend_cann_context& ctx, ggml_tensor* dst) { auto* dilations = aclCreateIntArray(dilation_size.data(), 2); auto* paddings = aclCreateIntArray(padding_dims.data(), 2); auto* strides = aclCreateIntArray(stride_dims.data(), 2); - - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - - ACL_CHECK(aclnnIm2colGetWorkspaceSize(acl_src1, kernel_size, dilations, - paddings, strides, tmp_im2col_tensor, - &workspaceSize, &executor)); - - ggml_cann_pool_alloc workspace_allocator(ctx.pool()); - if (workspaceSize > 0) { - workspace_allocator.alloc(workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - - ACL_CHECK( - aclnnIm2col(workspaceAddr, workspaceSize, executor, ctx.stream())); + GGML_CANN_CALL_ACLNN_OP(ctx, Im2col, acl_src1, kernel_size, dilations, + paddings, strides, tmp_im2col_tensor); // Cast if dst is f16. aclTensor* tmp_cast_tensor = nullptr; @@ -1461,331 +1085,56 @@ void ggml_cann_im2col(ggml_backend_cann_context& ctx, ggml_tensor* dst) { size_t temp_cast_nb[GGML_MAX_DIMS - 1]; temp_cast_nb[0] = ggml_type_size(dst->type); for (int i = 1; i < GGML_MAX_DIMS - 1; i++) { - temp_cast_nb[i] = temp_cast_nb[i - 1] * tmp_im2col_ne[i - 1]; - } - - tmp_cast_tensor = ggml_cann_create_tensor( - tmp_cast_buffer, ggml_cann_type_mapping(dst->type), - ggml_type_size(dst->type), tmp_im2col_ne, temp_cast_nb, - GGML_MAX_DIMS - 1, ACL_FORMAT_ND); - aclnn_cast(ctx, tmp_im2col_tensor, tmp_cast_tensor, - ggml_cann_type_mapping(dst->type)); - } - - // post-processing - if (is_2D) { - ggml_cann_im2col_2d_post_process(ctx, dst, src1, tmp_cast_tensor, - tmp_im2col_tensor); - } else { - std::vector im2col_op_params = { - KH, KW, IW, IC, N, OH, OW, s0, p0, d0, n_bytes_factor}; - ggml_cann_im2col_1d_post_process(ctx, dst, src1, tmp_cast_tensor, - tmp_im2col_tensor, im2col_op_params); - } - - // release - ACL_CHECK(aclDestroyTensor(acl_src1)); - ACL_CHECK(aclDestroyTensor(tmp_im2col_tensor)); - ACL_CHECK(aclDestroyTensor(tmp_cast_tensor)); - ACL_CHECK(aclDestroyIntArray(kernel_size)); - ACL_CHECK(aclDestroyIntArray(dilations)); - ACL_CHECK(aclDestroyIntArray(paddings)); - ACL_CHECK(aclDestroyIntArray(strides)); -} - -/** - * @brief Applies element-wise exponential function to the elements of a tensor. - * - * This function computes the exponential of each element in the source tensor - * `acl_src` and stores the result back into the same tensor. - * The operation is defined as: - * \f[ - * \text {acl_src }_i=e^{acl\_src_i} - * \f] - * - * @param ctx The context for the CANN backend operations. 
- * @param acl_src The tensor on which the exponential function will be applied. - */ -static void aclnn_exp(ggml_backend_cann_context& ctx, aclTensor* acl_src) { - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - - ACL_CHECK( - aclnnInplaceExpGetWorkspaceSize(acl_src, &workspaceSize, &executor)); - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - - ACL_CHECK( - aclnnInplaceExp(workspaceAddr, workspaceSize, executor, ctx.stream())); -} - -/** - * @brief Multiplies elements of a tensor by a scalar value, optionally - * in-place. - * - * This function multiplies each element of the source tensor `acl_src` by the - * scalar `scale` and stores the result in the destination tensor `acl_dst`. If - * `inplace` is true, `acl_dst` will not be used and the operation is performed - * in-place on `acl_src`. - * The operation is defined as: - * \f[ - * \text {acl_dst }_i=\text {acl_src }_i \times \text {scale} - * \f] - * - * @param ctx The context for the CANN backend operations. - * @param acl_src The source tensor whose elements will be multiplied. - * @param scale The scalar value by which each element of `acl_src` will be - * multiplied. - * @param acl_dst The destination tensor where the result will be stored if - * `inplace` is false. - * @param inplace Flag indicating whether to perform the operation in-place on - * `acl_src`. - */ -static void aclnn_muls(ggml_backend_cann_context& ctx, aclTensor* acl_src, - float scale, aclTensor* acl_dst, bool inplace) { - aclScalar* acl_scale = aclCreateScalar(&scale, aclDataType::ACL_FLOAT); - - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - - if (inplace) { - ACL_CHECK(aclnnInplaceMulsGetWorkspaceSize(acl_src, acl_scale, - &workspaceSize, &executor)); - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - - ACL_CHECK(aclnnInplaceMuls(workspaceAddr, workspaceSize, executor, - ctx.stream())); - } else { - ACL_CHECK(aclnnMulsGetWorkspaceSize(acl_src, acl_scale, acl_dst, - &workspaceSize, &executor)); - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - - ACL_CHECK( - aclnnMuls(workspaceAddr, workspaceSize, executor, ctx.stream())); - } - - ACL_CHECK(aclDestroyScalar(acl_scale)); -} - -/** - * @brief Performs an in-place element-wise multiplication of two tensors. - * - * This function performs an element-wise multiplication of the tensors - * `acl_src` and `acl_other` and stores the result in `acl_src`. - * The operation is defined as: - * \f[ - * \text {acl_src }_i=\text {acl_src }_i \times \text {acl_other }_i - * \f] - * - * @param ctx The context for the CANN backend operations. - * @param acl_src The source tensor where the multiplication result will be - * stored. - * @param acl_other The tensor whose elements will be multiplied with `acl_src`. 
- */ -static void aclnn_inplace_mul(ggml_backend_cann_context& ctx, - aclTensor* acl_src, aclTensor* acl_other) { - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - - ACL_CHECK(aclnnInplaceMulGetWorkspaceSize(acl_src, acl_other, - &workspaceSize, &executor)); - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - - ACL_CHECK( - aclnnInplaceMul(workspaceAddr, workspaceSize, executor, ctx.stream())); -} + temp_cast_nb[i] = temp_cast_nb[i - 1] * tmp_im2col_ne[i - 1]; + } -/** - * @brief Performs element-wise multiplication of two tensors and stores the - * result in a destination tensor. - * - * This function performs element-wise multiplication of the tensors `acl_src` - * and `acl_other` and stores the result in the destination tensor `acl_dst`. - * The operation is defined as: - * \f[ - * \text {acl_dst }_i=\text {acl_src }_i \times \text {acl_other }_i - * \f] - * - * @param ctx The context for the CANN backend operations. - * @param acl_src The first tensor for element-wise multiplication. - * @param acl_other The second tensor for element-wise multiplication. - * @param acl_dst The destination tensor where the result will be stored. - */ -static void aclnn_mul(ggml_backend_cann_context& ctx, aclTensor* acl_src, - aclTensor* acl_other, aclTensor* acl_dst) { - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - - ACL_CHECK(aclnnMulGetWorkspaceSize(acl_src, acl_other, acl_dst, - &workspaceSize, &executor)); - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); + tmp_cast_tensor = ggml_cann_create_tensor( + tmp_cast_buffer, ggml_cann_type_mapping(dst->type), + ggml_type_size(dst->type), tmp_im2col_ne, temp_cast_nb, + GGML_MAX_DIMS - 1, ACL_FORMAT_ND); + aclnn_cast(ctx, tmp_im2col_tensor, tmp_cast_tensor, ggml_cann_type_mapping(dst->type)); } - ACL_CHECK(aclnnMul(workspaceAddr, workspaceSize, executor, ctx.stream())); -} - -/** - * @brief Applies element-wise cosine function to the elements of a tensor. - * - * This function computes the cosine of each element in the source tensor - * `acl_src` and stores the result in the destination tensor `acl_dst`. The - * operation is defined as: \f[ \text {acl_dst }_i=\cos \left(\text {acl_src - * }_i\right) \f] - * - * @param ctx The context for the CANN backend operations. - * @param acl_src The source tensor on which the cosine function will be - * applied. - * @param acl_dst The destination tensor where the cosine results will be - * stored. 
- */ -static void aclnn_cos(ggml_backend_cann_context& ctx, aclTensor* acl_src, - aclTensor* acl_dst) { - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - - ACL_CHECK( - aclnnCosGetWorkspaceSize(acl_src, acl_dst, &workspaceSize, &executor)); - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); + // post-processing + if (is_2D) { + ggml_cann_im2col_2d_post_process(ctx, dst, src1, tmp_cast_tensor, + tmp_im2col_tensor); + } else { + std::vector im2col_op_params = { + KH, KW, IW, IC, N, OH, OW, s0, p0, d0, n_bytes_factor}; + ggml_cann_im2col_1d_post_process(ctx, dst, src1, tmp_cast_tensor, + tmp_im2col_tensor, im2col_op_params); } - ACL_CHECK(aclnnCos(workspaceAddr, workspaceSize, executor, ctx.stream())); + ggml_cann_release_resources(ctx, acl_src1, tmp_im2col_tensor, tmp_cast_tensor, + kernel_size, dilations, paddings, strides); } /** - * @brief Applies element-wise sine function to the elements of a tensor. + * @brief Applies element-wise exponential function to the elements of a tensor. * - * This function computes the sine of each element in the source tensor - `acl_src` - * and stores the result in the destination tensor `acl_dst`. + * This function computes the exponential of each element in the source tensor + * `acl_src` and stores the result back into the same tensor. * The operation is defined as: * \f[ - * \text {acl_dst }_i=\sin \left(\text {acl_src }_i\right) + * \text {acl_src }_i=e^{acl\_src_i} * \f] - - * @param ctx The context for the CANN backend operations. - * @param acl_src The source tensor on which the sine function will be applied. - * @param acl_dst The destination tensor where the sine results will be stored. - */ -static void aclnn_sin(ggml_backend_cann_context& ctx, aclTensor* acl_src, - aclTensor* acl_dst) { - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - - ACL_CHECK( - aclnnSinGetWorkspaceSize(acl_src, acl_dst, &workspaceSize, &executor)); - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - - ACL_CHECK(aclnnSin(workspaceAddr, workspaceSize, executor, ctx.stream())); -} - -/** - * @brief Performs element-wise division of tensor1 by tensor2 , multiplies the - result by the scalar value and adds it to self . * - * Performs element-wise division of tensor1 by tensor2, - * multiplies the result by the scalar value and adds it to self . - * The operation is defined as: - * \f[ - * \text{out}_i = \text{selft}_i + \text{value} \times - \frac{\text{tensor1}_i}{\text{tensor2}_i} - * \f] - * @param ctx The context for the CANN backend operations. - * @param acl_self The source tensor on which the addcdiv function will be - applied. - * @param tensor1 Numerator tensor. - * @param tensor2 Denominator tensor. - * @param value The value to be used for coefficient. + * @param acl_src The tensor on which the exponential function will be applied. 
*/ -static void aclnn_inplace_addcdiv(ggml_backend_cann_context& ctx, - aclTensor* acl_self, aclTensor* tensor1, - aclTensor* tensor2, float value) { - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - aclScalar* acl_value = aclCreateScalar(&value, aclDataType::ACL_FLOAT); - - ACL_CHECK(aclnnInplaceAddcdivGetWorkspaceSize( - acl_self, tensor1, tensor2, acl_value, &workspaceSize, &executor)); - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - - ACL_CHECK(aclnnInplaceAddcdiv(workspaceAddr, workspaceSize, executor, - ctx.stream())); +static void aclnn_exp(ggml_backend_cann_context& ctx, aclTensor* acl_src) { + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceExp, acl_src); } -/** - * @brief Matrix division, optionally in-place. - * - * This function division each element of the source tensor `acl_src` by the - * tensor `acl_other` and stores the result in the destination tensor `acl_dst`. - * If `inplace` is true, `acl_dst` will not be used and the operation is - * performed in-place on `acl_src`. The operation is defined as: \f[ - * \text{dst}_i = \frac{\text{acl_src}_i}{\text{acl_other}_i} - * \f] - * - * @param ctx The context for the CANN backend operations. - * @param acl_src Numerator tensor.. - * @param acl_other Denominator tensor. - * @param acl_dst The destination tensor where the result will be stored if - * `inplace` is false. - * @param inplace Flag indicating whether to perform the operation in-place on - * `acl_src`. - */ -static void aclnn_div_tensor(ggml_backend_cann_context& ctx, aclTensor* acl_src, - aclTensor* acl_other, aclTensor* acl_dst, - bool inplace) { - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - - if (inplace) { - ACL_CHECK(aclnnInplaceDivGetWorkspaceSize(acl_src, acl_other, - &workspaceSize, &executor)); - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - - ACL_CHECK(aclnnInplaceDiv(workspaceAddr, workspaceSize, executor, - ctx.stream())); - } else { - ACL_CHECK(aclnnDivGetWorkspaceSize(acl_src, acl_other, acl_dst, - &workspaceSize, &executor)); - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } +void aclnn_cos(ggml_backend_cann_context& ctx, aclTensor* acl_src, + aclTensor* acl_dst) { + GGML_CANN_CALL_ACLNN_OP(ctx, Cos, acl_src, acl_dst); +} - ACL_CHECK( - aclnnDiv(workspaceAddr, workspaceSize, executor, ctx.stream())); - } +void aclnn_sin(ggml_backend_cann_context& ctx, aclTensor* acl_src, + aclTensor* acl_dst) { + GGML_CANN_CALL_ACLNN_OP(ctx, Sin, acl_src, acl_dst); } void ggml_cann_timestep_embedding(ggml_backend_cann_context& ctx, @@ -1834,13 +1183,13 @@ void ggml_cann_timestep_embedding(ggml_backend_cann_context& ctx, ggml_cann_pool_alloc permute_allocator(ctx.pool(), ggml_nbytes(src)); void* tmp_permute_buffer = permute_allocator.get(); - aclTensor* tmp_permute_tenosr = ggml_cann_create_tensor( + aclTensor* tmp_permute_tensor = ggml_cann_create_tensor( tmp_permute_buffer, ggml_cann_type_mapping(src->type), ggml_type_size(src->type), tmp_permute_ne, tmp_permute_nb, GGML_MAX_DIMS, ACL_FORMAT_ND); int64_t permute_dim[] = {0, 1, 3, 2}; int64_t num_dims = 4; - aclnn_permute(ctx, acl_src, tmp_permute_tenosr, permute_dim, num_dims); + aclnn_permute(ctx, acl_src, tmp_permute_tensor, permute_dim, num_dims); 
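    // The remaining steps implement the sinusoidal timestep embedding:
    // scale the permuted timesteps by the frequency vector, take cos() and
    // sin() of the products, and concatenate the two results along the last
    // dimension into dst.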
// timestep * freq int64_t tmp_mul_ne[] = {src->ne[1] * half, src->ne[0], src->ne[2], @@ -1861,7 +1210,7 @@ void ggml_cann_timestep_embedding(ggml_backend_cann_context& ctx, tmp_mul_buffer, ggml_cann_type_mapping(src->type), ggml_type_size(src->type), tmp_mul_ne, tmp_mul_nb, GGML_MAX_DIMS, ACL_FORMAT_ND); - aclnn_mul(ctx, tmp_permute_tenosr, tmp_arange_tensor, tmp_mul_tensor); + aclnn_mul(ctx, tmp_permute_tensor, tmp_arange_tensor, tmp_mul_tensor); // cos ggml_cann_pool_alloc cos_allocator( @@ -1889,17 +1238,13 @@ void ggml_cann_timestep_embedding(ggml_backend_cann_context& ctx, int64_t concat_dim = 3; aclTensor* acl_dst = ggml_cann_create_tensor(dst); aclTensor* tensors[] = {tmp_cos_tensor, tmp_sin_tensor}; - aclTensorList* tensorList = aclCreateTensorList(tensors, 2); - aclnn_concat(ctx, tensorList, acl_dst, concat_dim); + aclTensorList* tensor_list = aclCreateTensorList(tensors, 2); + aclnn_concat(ctx, tensor_list, acl_dst, concat_dim); // release // segmentation fault when delete both tensorList and his elements. - ACL_CHECK(aclDestroyTensorList(tensorList)); - ACL_CHECK(aclDestroyTensor(acl_src)); - ACL_CHECK(aclDestroyTensor(tmp_arange_tensor)); - ACL_CHECK(aclDestroyTensor(tmp_permute_tenosr)); - ACL_CHECK(aclDestroyTensor(tmp_mul_tensor)); - ACL_CHECK(aclDestroyTensor(acl_dst)); + ggml_cann_release_resources(ctx, tensor_list, acl_src, tmp_arange_tensor, + tmp_permute_tensor, tmp_mul_tensor, acl_dst); } /** @@ -1915,21 +1260,8 @@ void ggml_cann_timestep_embedding(ggml_backend_cann_context& ctx, static void aclnn_fill_scalar(ggml_backend_cann_context& ctx, float scalar, aclTensor* acl_dst) { auto acl_scalar = aclCreateScalar(&scalar, aclDataType::ACL_FLOAT); - - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - - ACL_CHECK(aclnnInplaceFillScalarGetWorkspaceSize( - acl_dst, acl_scalar, &workspaceSize, &executor)); - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - - ACL_CHECK(aclnnInplaceFillScalar(workspaceAddr, workspaceSize, executor, - ctx.stream())); - ACL_CHECK(aclDestroyScalar(acl_scalar)); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceFillScalar, acl_dst, acl_scalar); + ggml_cann_release_resources(ctx, acl_scalar); } /** @@ -1950,19 +1282,7 @@ static void aclnn_fill_scalar(ggml_backend_cann_context& ctx, float scalar, */ static void aclnn_pow_tensor_tensor(ggml_backend_cann_context& ctx, aclTensor* acl_dst, aclTensor* acl_exp) { - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - - ACL_CHECK(aclnnInplacePowTensorTensorGetWorkspaceSize( - acl_dst, acl_exp, &workspaceSize, &executor)); - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - - ACL_CHECK(aclnnInplacePowTensorTensor(workspaceAddr, workspaceSize, - executor, ctx.stream())); + GGML_CANN_CALL_ACLNN_OP(ctx, InplacePowTensorTensor, acl_dst, acl_exp); } /** @@ -2114,56 +1434,15 @@ static void aclnn_alibi(ggml_backend_cann_context& ctx, aclTensor* acl_src, // add aclnn_add(ctx, tmp_output_tensor, acl_src, acl_dst); - - ACL_CHECK(aclDestroyTensor(tmp_arange1_tensor)); - ACL_CHECK(aclDestroyTensor(tmp_arange2_tensor)); - ACL_CHECK(aclDestroyTensor(tmp_mk_base1_tensor)); - ACL_CHECK(aclDestroyTensor(tmp_mk_base2_tensor)); - ACL_CHECK(aclDestroyTensor(tmp_mk_base_tensor)); - ACL_CHECK(aclDestroyTensor(tmp_arange_tensor)); - ACL_CHECK(aclDestroyTensor(tmp_mk_tensor)); 
- ACL_CHECK(aclDestroyTensor(tmp_output_tensor)); + ggml_cann_release_resources(ctx, tmp_arange1_tensor, tmp_arange2_tensor, + tmp_mk_base1_tensor, tmp_mk_base2_tensor, tmp_mk_base_tensor, + tmp_arange_tensor, tmp_mk_tensor, tmp_output_tensor); } void ggml_cann_cpy(ggml_backend_cann_context& ctx, ggml_tensor* dst) { ggml_cann_dup(ctx, dst); } -/** - * @brief Performs element-wise addition of two tensors in place. - * - * This function adds the source tensor `acl_src` to the destination tensor - * `acl_dst` element-wise and stores the result in the destination tensor - * `acl_dst`. - * - * @param ctx The context for the CANN backend operations. - * @param acl_src The source tensor to be added. - * @param acl_dst The destination tensor which will hold the result of the - * addition. - */ -static void aclnn_inplace_add(ggml_backend_cann_context& ctx, - aclTensor* acl_src, aclTensor* acl_dst) { - aclScalar* alpha = nullptr; - float alphaValue = 1.0f; - alpha = aclCreateScalar(&alphaValue, aclDataType::ACL_FLOAT); - - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - - ACL_CHECK(aclnnInplaceAddGetWorkspaceSize(acl_dst, acl_src, alpha, - &workspaceSize, &executor)); - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - - ACL_CHECK( - aclnnInplaceAdd(workspaceAddr, workspaceSize, executor, ctx.stream())); - - ACL_CHECK(aclDestroyScalar(alpha)); -} - /** * @brief Applies the softmax function to a tensor along a specified dimension. * @@ -2180,20 +1459,7 @@ static void aclnn_inplace_add(ggml_backend_cann_context& ctx, */ static void aclnn_softmax(ggml_backend_cann_context& ctx, aclTensor* acl_src, int64_t dim, aclTensor* acl_dst) { - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - - ACL_CHECK(aclnnSoftmaxGetWorkspaceSize(acl_src, dim, acl_dst, - &workspaceSize, &executor)); - - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - - aclrtStream stream = ctx.stream(); - ACL_CHECK(aclnnSoftmax(workspaceAddr, workspaceSize, executor, stream)); + GGML_CANN_CALL_ACLNN_OP(ctx, Softmax, acl_src, dim, acl_dst); } void ggml_cann_softmax(ggml_backend_cann_context& ctx, ggml_tensor* dst) { @@ -2243,8 +1509,7 @@ void ggml_cann_softmax(ggml_backend_cann_context& ctx, ggml_tensor* dst) { src1_fp32_nb, GGML_MAX_DIMS); aclTensor* acl_src1 = ggml_cann_create_tensor(src1); aclnn_cast(ctx, acl_src1, acl_src1_fp32_tensor, ACL_FLOAT); - - ACL_CHECK(aclDestroyTensor(acl_src1)); + ggml_cann_release_resources(ctx, acl_src1); } else { acl_src1_fp32_tensor = ggml_cann_create_tensor(src1); } @@ -2297,17 +1562,13 @@ void ggml_cann_softmax(ggml_backend_cann_context& ctx, ggml_tensor* dst) { // softmax aclnn_softmax(ctx, alibi_output_tensor, 3, acl_dst); - ACL_CHECK(aclDestroyTensor(alibi_output_tensor)); + ggml_cann_release_resources(ctx, alibi_output_tensor); } else { aclnn_softmax(ctx, acl_input_mul_scale_tensor, 3, acl_dst); } - ACL_CHECK(aclDestroyTensor(acl_src0)); - ACL_CHECK(aclDestroyTensor(acl_src1_fp32_tensor)); - ACL_CHECK(aclDestroyTensor(acl_dst)); - ACL_CHECK(aclDestroyScalar(acl_scale)); - ACL_CHECK(aclDestroyTensor(acl_input_mul_scale_tensor)); - ACL_CHECK(aclDestroyTensor(tmp_mask_tensor)); + ggml_cann_release_resources(ctx, acl_src0, acl_src1_fp32_tensor, acl_dst, + acl_scale, acl_input_mul_scale_tensor, tmp_mask_tensor); } /** @@ -2354,26 
+1615,8 @@ static void aclnn_embedding_4d(ggml_backend_cann_context& ctx, void* src_buffer, (char*)dst->data + i * dst->nb[3] + j * dst->nb[2], ggml_cann_type_mapping(dst->type), ggml_element_size(dst), acl_out_ne, acl_out_nb, 2); - - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - - ACL_CHECK(aclnnEmbeddingGetWorkspaceSize( - acl_src_tensor, acl_index, acl_out, &workspaceSize, &executor)); - - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), - workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - - ACL_CHECK(aclnnEmbedding(workspaceAddr, workspaceSize, executor, - ctx.stream())); - - ACL_CHECK(aclDestroyTensor(acl_src_tensor)); - ACL_CHECK(aclDestroyTensor(acl_index)); - ACL_CHECK(aclDestroyTensor(acl_out)); + GGML_CANN_CALL_ACLNN_OP(ctx, Embedding, acl_src_tensor, acl_index, acl_out); + ggml_cann_release_resources(ctx, acl_src_tensor, acl_index, acl_out); } } } @@ -2401,11 +1644,10 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) { aclTensor* src_trans_tensor = ggml_cann_create_tensor( src_trans_buffer, ACL_FLOAT, ggml_type_size(dst->type), src0->ne, src_trans_nb, GGML_MAX_DIMS); - aclnn_cast(ctx, acl_src0, src_trans_tensor, dst); + aclnn_cast(ctx, acl_src0, src_trans_tensor, ggml_cann_type_mapping(dst->type)); aclnn_embedding_4d(ctx, src_trans_buffer, src0->ne, src_trans_nb, src1, dst); - ACL_CHECK(aclDestroyTensor(acl_src0)); - ACL_CHECK(aclDestroyTensor(src_trans_tensor)); + ggml_cann_release_resources(ctx, acl_src0, src_trans_tensor); break; } case GGML_TYPE_Q8_0: { @@ -2451,7 +1693,7 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) { src0->data, ACL_INT8, sizeof(int8_t), weight_ne, weight_nb, GGML_MAX_DIMS + 1); aclTensor* acl_scale_tensor = ggml_cann_create_tensor( - src0->data, ACL_FLOAT16, sizeof(float16_t), scale_ne, scale_nb, + src0->data, ACL_FLOAT16, sizeof(uint16_t), scale_ne, scale_nb, GGML_MAX_DIMS + 1, ACL_FORMAT_ND, scale_offset); aclTensor* dequant_tensor = ggml_cann_create_tensor( dequant_buffer_allocator.get(), ACL_FLOAT, sizeof(float_t), @@ -2467,7 +1709,7 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) { aclnn_embedding_4d(ctx, dequant_buffer_allocator.get(), dequant_ne, dequant_nb, src1, dst); - ACL_CHECK(aclDestroyTensor(dequant_tensor)); + ggml_cann_release_resources(ctx, dequant_tensor); break; } default: @@ -2495,133 +1737,8 @@ static void aclnn_repeat_interleave(ggml_backend_cann_context& ctx, aclTensor* acl_src, aclTensor* acl_dst, int64_t dim, int64_t repeats, int64_t output_size) { - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - - ACL_CHECK(aclnnRepeatInterleaveIntWithDimGetWorkspaceSize( - acl_src, repeats, dim, output_size, acl_dst, &workspaceSize, - &executor)); - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - - ACL_CHECK(aclnnRepeatInterleaveIntWithDim(workspaceAddr, workspaceSize, - executor, ctx.stream())); -} - -/** - * @brief Performs matrix multiplication of two tensors. - * - * This function computes the matrix multiplication of the input tensor - * `acl_input` and the weight tensor `acl_weight`, and stores the result in the - * destination tensor `acl_dst`. - * The operation is defined as: - * \f[ - * \text {acl_dst}=\text {acl_input@acl_weight} - * \f] - * - * @param ctx The context for the CANN backend operations. 
- * @param acl_input The input tensor for the matrix multiplication. - * @param acl_weight The weight tensor for the matrix multiplication. - * @param acl_dst The destination tensor where the result of the matrix - * multiplication will be stored. - */ -static void aclnn_mat_mul(ggml_backend_cann_context& ctx, aclTensor* acl_input, - aclTensor* acl_weight, aclTensor* acl_dst) { - int8_t cube_math_type = 1; // ALLOW_FP32_DOWN_PRECISION, when input is - // fp32, atlas a2 will transpose it to HFLOAT32. - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - - ACL_CHECK(aclnnMatmulGetWorkspaceSize(acl_input, acl_weight, acl_dst, - cube_math_type, &workspaceSize, - &executor)); - - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - - ACL_CHECK( - aclnnMatmul(workspaceAddr, workspaceSize, executor, ctx.stream())); -} - -/** - * @brief Performs matrix multiplication of two 2D tensors. - * - * This function computes the matrix multiplication of the input tensor - * `acl_input` and the weight tensor `acl_weight`, and stores the result in the - * destination tensor `acl_dst`. - * The operation is defined as: - * \f[ - * \text {acl_dst}=\text {acl_input@acl_weight} - * \f] - * - * @param ctx The context for the CANN backend operations. - * @param acl_input The input tensor for the matrix multiplication. - * @param acl_weight The weight tensor for the matrix multiplication. - * @param acl_dst The destination tensor where the result of the matrix - * multiplication will be stored. - */ -static void aclnn_mat_mul_2d(ggml_backend_cann_context& ctx, - aclTensor* acl_input, aclTensor* acl_weight, - aclTensor* acl_dst) { - int8_t cube_math_type = 2; - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - - ACL_CHECK(aclnnMmGetWorkspaceSize(acl_input, acl_weight, acl_dst, - cube_math_type, &workspaceSize, - &executor)); - - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - - ACL_CHECK(aclnnMm(workspaceAddr, workspaceSize, executor, ctx.stream())); -} - -/** - * @brief Performs matrix multiplication of two 3D tensors. - * - * This function computes the matrix multiplication of the input tensor - * `acl_input` and the weight tensor `acl_weight`, and stores the result in the - * destination tensor `acl_dst`. - * The operation is defined as: - * \f[ - * \text {acl_dst}=\text {acl_input@acl_weight} - * \f] - * - * @param ctx The context for the CANN backend operations. - * @param acl_input The input tensor for the matrix multiplication. - * @param acl_weight The weight tensor for the matrix multiplication. - * @param acl_dst The destination tensor where the result of the matrix - * multiplication will be stored. 
- */ -static void aclnn_mat_mul_3d(ggml_backend_cann_context& ctx, - aclTensor* acl_input, aclTensor* acl_weight, - aclTensor* acl_dst) { - int8_t cube_math_type = 2; - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - - ACL_CHECK(aclnnBatchMatMulGetWorkspaceSize(acl_input, acl_weight, acl_dst, - cube_math_type, &workspaceSize, - &executor)); - - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - - ACL_CHECK( - aclnnBatchMatMul(workspaceAddr, workspaceSize, executor, ctx.stream())); + GGML_CANN_CALL_ACLNN_OP(ctx, RepeatInterleaveIntWithDim, acl_src, repeats, dim, + output_size, acl_dst); } /** @@ -2669,19 +1786,19 @@ static void ggml_cann_mat_mul_fp(ggml_backend_cann_context& ctx, switch (n_dims) { case 2: - aclnn_mat_mul_2d(ctx, acl_input_tensor, acl_weight_tensor, acl_dst); + GGML_CANN_CALL_ACLNN_OP(ctx, Mm, acl_input_tensor, acl_weight_tensor, acl_dst, 2); break; case 3: - aclnn_mat_mul_3d(ctx, acl_input_tensor, acl_weight_tensor, acl_dst); + GGML_CANN_CALL_ACLNN_OP(ctx, BatchMatMul, acl_input_tensor, acl_weight_tensor, acl_dst, 2); break; default: - aclnn_mat_mul(ctx, acl_input_tensor, acl_weight_tensor, acl_dst); + // ALLOW_FP32_DOWN_PRECISION, when input is + // fp32, atlas a2 will transpose it to HFLOAT32. + GGML_CANN_CALL_ACLNN_OP(ctx, Matmul, acl_input_tensor, acl_weight_tensor, acl_dst, 1); break; } - ACL_CHECK(aclDestroyTensor(acl_weight_tensor)); - ACL_CHECK(aclDestroyTensor(acl_input_tensor)); - ACL_CHECK(aclDestroyTensor(acl_dst)); + ggml_cann_release_resources(ctx, acl_weight_tensor, acl_input_tensor, acl_dst); } /** @@ -2751,9 +1868,7 @@ static void ggml_cann_mul_mat_quant(ggml_backend_cann_context& ctx, input_buffer, ACL_FLOAT16, input_elem_size, input_cast_ne, input_cast_nb, GGML_MAX_DIMS); aclnn_cast(ctx, acl_src1_tensor, acl_input_tensor, ACL_FLOAT16); - - ACL_CHECK(aclDestroyTensor(acl_input_tensor)); - ACL_CHECK(aclDestroyTensor(acl_src1_tensor)); + ggml_cann_release_resources(ctx, acl_input_tensor, acl_src1_tensor); } // output @@ -2768,9 +1883,6 @@ static void ggml_cann_mul_mat_quant(ggml_backend_cann_context& ctx, int64_t max_elem_size = 65535; int64_t split_size = (src0->ne[1] / max_elem_size) + 1; ggml_cann_pool_alloc workspace_allocator(ctx.pool()); - aclOpExecutor* executor = nullptr; - uint64_t workspaceSize = 0; - void* workspaceAddr = nullptr; for (int64_t n1 = 0; n1 < src1->ne[3]; n1++) { for (int64_t c1 = 0; c1 < src1->ne[2]; c1++) { int64_t n0 = n1 / (src1->ne[3] / src0->ne[3]); @@ -2809,20 +1921,11 @@ static void ggml_cann_mul_mat_quant(ggml_backend_cann_context& ctx, if (src0->ne[0] > QK8_0) { antiquantGroupSize = QK8_0; } - - ACL_CHECK(aclnnWeightQuantBatchMatmulV2GetWorkspaceSize( - acl_input_tensor, acl_weight_tensor, acl_scale_tensor, nullptr, - nullptr, nullptr, nullptr, antiquantGroupSize, - acl_output_tensor, &workspaceSize, &executor)); - if (workspaceAddr == nullptr) { - workspaceAddr = workspace_allocator.alloc(workspaceSize); - } - ACL_CHECK(aclnnWeightQuantBatchMatmulV2( - workspaceAddr, workspaceSize, executor, ctx.stream())); - - ACL_CHECK(aclDestroyTensor(acl_weight_tensor)); - ACL_CHECK(aclDestroyTensor(acl_scale_tensor)); - ACL_CHECK(aclDestroyTensor(acl_output_tensor)); + GGML_CANN_CALL_ACLNN_OP(ctx, WeightQuantBatchMatmulV2, acl_input_tensor, + acl_weight_tensor, acl_scale_tensor, nullptr, + nullptr, nullptr, nullptr, antiquantGroupSize, + acl_output_tensor); + ggml_cann_release_resources(ctx, 
acl_weight_tensor, acl_scale_tensor, acl_output_tensor); // other splits for (int64_t split = 1; split < split_size; split++) { @@ -2849,20 +1952,14 @@ static void ggml_cann_mul_mat_quant(ggml_backend_cann_context& ctx, (char*)output_buffer + batch1 * output_stride, ACL_FLOAT16, output_elem_size, output_ne, output_nb, 2, ACL_FORMAT_ND, output_ne_offset); - - ACL_CHECK(aclnnWeightQuantBatchMatmulV2GetWorkspaceSize( - acl_input_tensor, acl_weight_tensor, acl_scale_tensor, - nullptr, nullptr, nullptr, nullptr, antiquantGroupSize, - acl_output_tensor, &workspaceSize, &executor)); - ACL_CHECK(aclnnWeightQuantBatchMatmulV2( - workspaceAddr, workspaceSize, executor, ctx.stream())); - - ACL_CHECK(aclDestroyTensor(acl_weight_tensor)); - ACL_CHECK(aclDestroyTensor(acl_scale_tensor)); - ACL_CHECK(aclDestroyTensor(acl_output_tensor)); + GGML_CANN_CALL_ACLNN_OP(ctx, WeightQuantBatchMatmulV2, acl_input_tensor, + acl_weight_tensor, acl_scale_tensor, nullptr, + nullptr, nullptr, nullptr, antiquantGroupSize, + acl_output_tensor); + ggml_cann_release_resources(ctx, acl_weight_tensor, acl_scale_tensor, acl_output_tensor); } - ACL_CHECK(aclDestroyTensor(acl_input_tensor)); + ggml_cann_release_resources(ctx, acl_input_tensor); } } @@ -2879,11 +1976,9 @@ static void ggml_cann_mul_mat_quant(ggml_backend_cann_context& ctx, output_buffer, ACL_FLOAT16, output_elem_size, output_cast_ne, output_cast_nb, GGML_MAX_DIMS); aclTensor* acl_dst_tensor = ggml_cann_create_tensor(dst); - aclnn_cast(ctx, acl_output_tensor, acl_dst_tensor, - ggml_cann_type_mapping(dst->type)); + aclnn_cast(ctx, acl_output_tensor, acl_dst_tensor, ggml_cann_type_mapping(dst->type)); - ACL_CHECK(aclDestroyTensor(acl_output_tensor)); - ACL_CHECK(aclDestroyTensor(acl_dst_tensor)); + ggml_cann_release_resources(ctx, acl_output_tensor, acl_dst_tensor); } } @@ -2899,7 +1994,7 @@ void ggml_cann_mul_mat(ggml_backend_cann_context& ctx, ggml_tensor* dst) { ggml_cann_mul_mat_quant(ctx, dst, type); break; default: - GGML_ABORT("fatal error"); + GGML_ABORT("Unsupported type for mul_mat"); break; } } @@ -2924,22 +2019,8 @@ static void aclnn_roll(ggml_backend_cann_context& ctx, aclTensor* acl_src, aclTensor* acl_dst, int64_t* shifts, int64_t* dims) { aclIntArray* acl_shifts = aclCreateIntArray(shifts, 1); aclIntArray* acl_dims = aclCreateIntArray(dims, 1); - - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - - ACL_CHECK(aclnnRollGetWorkspaceSize(acl_src, acl_shifts, acl_dims, acl_dst, - &workspaceSize, &executor)); - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); - } - - ACL_CHECK(aclnnRoll(workspaceAddr, workspaceSize, executor, ctx.stream())); - - ACL_CHECK(aclDestroyIntArray(acl_shifts)); - ACL_CHECK(aclDestroyIntArray(acl_dims)); + GGML_CANN_CALL_ACLNN_OP(ctx, Roll, acl_src, acl_shifts, acl_dims, acl_dst); + ggml_cann_release_resources(ctx, acl_shifts, acl_dims); } /** @@ -2961,23 +2042,8 @@ static void aclnn_index_fill_tensor(ggml_backend_cann_context& ctx, float value) { aclIntArray* acl_index = aclCreateIntArray(index, index_num); aclScalar* acl_value = aclCreateScalar(&value, aclDataType::ACL_FLOAT); - - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - - ACL_CHECK(aclnnInplaceIndexFillTensorGetWorkspaceSize( - acl_src, dim, acl_index, acl_value, &workspaceSize, &executor)); - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - 
workspaceAddr = workspace_allocator.get(); - } - - ACL_CHECK(aclnnInplaceIndexFillTensor(workspaceAddr, workspaceSize, - executor, ctx.stream())); - - ACL_CHECK(aclDestroyIntArray(acl_index)); - ACL_CHECK(aclDestroyScalar(acl_value)); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceIndexFillTensor, acl_src, dim, acl_index, acl_value); + ggml_cann_release_resources(ctx, acl_index, acl_value); } static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst, @@ -2992,37 +2058,30 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst, ggml_tensor* src1 = dst->src[1]; // position ggml_tensor* src2 = dst->src[2]; // freq_factors - // arange, [0,1,...,ne0/2] - int64_t arange_length = src0->ne[0] / 2; - ggml_cann_pool_alloc arange_allocator(ctx.pool(), - arange_length * sizeof(float_t)); - void* arange_buffer = arange_allocator.get(); - int64_t arange_ne[] = {arange_length, 1, 1, 1}; - size_t arange_nb[] = {sizeof(float_t), sizeof(float_t), sizeof(float_t), - arange_length * sizeof(float_t)}; - - aclTensor* acl_arange_tensor = - ggml_cann_create_tensor(arange_buffer, ACL_FLOAT, sizeof(float_t), - arange_ne, arange_nb, GGML_MAX_DIMS); + GGML_TENSOR_BINARY_OP_LOCALS + + // theta_scale arange, [0,1,...,ne00/2 - 1] + int64_t theta_scale_length = ne00 / 2; + ggml_cann_pool_alloc theta_scale_allocator(ctx.pool(), + theta_scale_length * sizeof(float_t)); + void* theta_scale_buffer = theta_scale_allocator.get(); + int64_t theta_scale_ne[] = {theta_scale_length, 1, 1, 1}; + size_t theta_scale_nb[] = {sizeof(float_t), sizeof(float_t), sizeof(float_t), + theta_scale_length * sizeof(float_t)}; + + aclTensor* acl_theta_scale_tensor = + ggml_cann_create_tensor(theta_scale_buffer, ACL_FLOAT, sizeof(float_t), + theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS); float start = 0; float step = 1; - float stop = src0->ne[0] / 2; - float n_elements = src0->ne[0] / 2; - aclnn_arange(ctx, acl_arange_tensor, start, stop, step, n_elements); + float stop = ne00 / 2; + float n_elements = ne00 / 2; + aclnn_arange(ctx, acl_theta_scale_tensor, start, stop, step, n_elements); // power - // aclnnPowScalarTensor(): @param self is tensor which should be scalar, so - // use aclnn_pow_tensor_tensor() until fixed. 
aclScalar* acl_theta_scale = - // aclCreateScalar(&theta_scale, aclDataType::ACL_FLOAT); - // aclnn_power_scalar_tensor(ctx, acl_theta_scale, acl_arange_tensor, - // acl_power_tensor); - ggml_cann_pool_alloc theta_scale_allocator(ctx.pool(), - arange_length * sizeof(float_t)); - void* theta_scale_buffer = theta_scale_allocator.get(); - aclTensor* acl_theta_scale_tensor = aclnn_values( - ctx, theta_scale_buffer, arange_length * sizeof(float_t), arange_ne, - GGML_MAX_DIMS, ACL_FLOAT, sizeof(float_t), theta_scale); - aclnn_pow_tensor_tensor(ctx, acl_theta_scale_tensor, acl_arange_tensor); + aclScalar* acl_theta_scale = aclCreateScalar(&theta_scale, aclDataType::ACL_FLOAT); + GGML_CANN_CALL_ACLNN_OP(ctx, PowScalarTensor, acl_theta_scale, acl_theta_scale_tensor, + acl_theta_scale_tensor); // freq_scale if (freq_scale != 1) { @@ -3033,29 +2092,27 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst, if (src2) { aclTensor* acl_freq_factors_tensor = ggml_cann_create_tensor( src2->data, ggml_cann_type_mapping(src2->type), - ggml_type_size(src2->type), arange_ne, arange_nb, GGML_MAX_DIMS); - aclnn_div_tensor(ctx, acl_theta_scale_tensor, acl_freq_factors_tensor, - nullptr, true); - ACL_CHECK(aclDestroyTensor(acl_freq_factors_tensor)); + ggml_type_size(src2->type), theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS); + aclnn_div(ctx, acl_theta_scale_tensor, acl_freq_factors_tensor); + ggml_cann_release_resources(ctx, acl_freq_factors_tensor); } // position GGML_ASSERT(src1->type == GGML_TYPE_I32); int64_t position_length = src1->ne[0]; - int64_t position_ne[] = {1, position_length, 1, 1}; - size_t position_nb[] = {sizeof(int32_t), sizeof(int32_t), - sizeof(int32_t) * position_length, + int64_t position_ne[] = {1, 1, position_length, 1}; + size_t position_nb[] = {sizeof(int32_t), sizeof(int32_t), sizeof(int32_t), sizeof(int32_t) * position_length}; aclTensor* acl_position_tensor = ggml_cann_create_tensor( src1->data, ggml_cann_type_mapping(src1->type), ggml_type_size(src1->type), position_ne, position_nb, GGML_MAX_DIMS); // power * position - int64_t theta_length = arange_length * position_length; + int64_t theta_length = theta_scale_length * position_length; ggml_cann_pool_alloc theta_allocator(ctx.pool(), theta_length * sizeof(float_t)); void* theta_buffer = theta_allocator.get(); - int64_t theta_ne[] = {arange_length, position_length, 1, 1}; + int64_t theta_ne[] = {theta_scale_length, 1, position_length, 1}; size_t theta_nb[GGML_MAX_DIMS]; theta_nb[0] = sizeof(float_t); for (int i = 1; i < GGML_MAX_DIMS; i++) { @@ -3067,40 +2124,22 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst, aclnn_mul(ctx, acl_position_tensor, acl_theta_scale_tensor, acl_theta_tensor); - // permute: [0,1,2,3]->[0,2,1,3] - int64_t permute_ne[] = {arange_length, 1, position_length, 1}; - size_t permute_nb[GGML_MAX_DIMS]; - permute_nb[0] = sizeof(float_t); - for (int i = 1; i < GGML_MAX_DIMS; i++) { - permute_nb[i] = permute_nb[i - 1] * permute_ne[i - 1]; - } - ggml_cann_pool_alloc permute_allocator(ctx.pool(), - theta_length * sizeof(float_t)); - void* permute_buffer = permute_allocator.get(); - aclTensor* acl_permute_tensor = ggml_cann_create_tensor( - permute_buffer, ACL_FLOAT, sizeof(float_t), permute_ne, permute_nb, - GGML_MAX_DIMS, ACL_FORMAT_ND); - int64_t permute_dim[] = {0, 2, 1, 3}; - int64_t num_dims = 4; - aclnn_permute(ctx, acl_theta_tensor, acl_permute_tensor, permute_dim, - num_dims); - // sin/cos ggml_cann_pool_alloc sin_allocator(ctx.pool(), theta_length * 
sizeof(float_t)); void* sin_buffer = sin_allocator.get(); aclTensor* acl_sin_tensor = ggml_cann_create_tensor( - sin_buffer, ACL_FLOAT, sizeof(float_t), permute_ne, permute_nb, + sin_buffer, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb, GGML_MAX_DIMS, ACL_FORMAT_ND); - aclnn_sin(ctx, acl_permute_tensor, acl_sin_tensor); + aclnn_sin(ctx, acl_theta_tensor, acl_sin_tensor); ggml_cann_pool_alloc cos_allocator(ctx.pool(), theta_length * sizeof(float_t)); void* cos_buffer = cos_allocator.get(); aclTensor* acl_cos_tensor = ggml_cann_create_tensor( - cos_buffer, ACL_FLOAT, sizeof(float_t), permute_ne, permute_nb, + cos_buffer, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb, GGML_MAX_DIMS, ACL_FORMAT_ND); - aclnn_cos(ctx, acl_permute_tensor, acl_cos_tensor); + aclnn_cos(ctx, acl_theta_tensor, acl_cos_tensor); // attn_factor if (attn_factor != 1) { @@ -3116,7 +2155,7 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst, } else { int64_t num_repeats = 2; int64_t dim = 3; - int64_t output_size = arange_length * num_repeats; + int64_t output_size = theta_scale_length * num_repeats; aclnn_repeat_interleave(ctx, acl_sin_tensor, acl_sin_repeat_tensor, dim, num_repeats, output_size); aclnn_repeat_interleave(ctx, acl_cos_tensor, acl_cos_repeat_tensor, dim, @@ -3124,13 +2163,8 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst, } // release - ACL_CHECK(aclDestroyTensor(acl_arange_tensor)); - ACL_CHECK(aclDestroyTensor(acl_theta_scale_tensor)); - ACL_CHECK(aclDestroyTensor(acl_position_tensor)); - ACL_CHECK(aclDestroyTensor(acl_theta_tensor)); - ACL_CHECK(aclDestroyTensor(acl_permute_tensor)); - ACL_CHECK(aclDestroyTensor(acl_sin_tensor)); - ACL_CHECK(aclDestroyTensor(acl_cos_tensor)); + ggml_cann_release_resources(ctx, acl_theta_scale_tensor, acl_position_tensor, + acl_theta_tensor, acl_sin_tensor, acl_cos_tensor, acl_theta_scale); } #ifdef __cplusplus @@ -3152,7 +2186,6 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) { // TODO: use ascendc // Only test with LLAMA model. 
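    // High-level flow of this function: aclnn_cache_init() builds the
    // per-position sin/cos caches, after which the rotation is applied either
    // by the manual roll/scale path below or by a single
    // aclnnRotaryPositionEmbedding call (F16 inputs are cast through F32
    // buffers for that call).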
ggml_tensor* src0 = dst->src[0]; // input - ggml_tensor* src2 = dst->src[2]; // freq_factors // param float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; @@ -3187,13 +2220,13 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) { // init cos/sin cache ggml_cann_pool_alloc sin_allocator( - ctx.pool(), src0->ne[0] * src0->ne[2] * sizeof(float_t)); + ctx.pool(), ne00 * ne02 * sizeof(float_t)); ggml_cann_pool_alloc cos_allocator( - ctx.pool(), src0->ne[0] * src0->ne[2] * sizeof(float_t)); + ctx.pool(), ne00 * ne02 * sizeof(float_t)); void* sin_buffer = sin_allocator.get(); void* cos_buffer = cos_allocator.get(); - int64_t sin_reshape_ne[4] = {src0->ne[0], 1, src0->ne[2], 1}; + int64_t sin_reshape_ne[4] = {ne00, 1, ne02, 1}; size_t sin_reshape_nb[GGML_MAX_DIMS]; sin_reshape_nb[0] = sizeof(float_t); for (int i = 1; i < GGML_MAX_DIMS; i++) { @@ -3206,7 +2239,7 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) { ggml_cann_create_tensor(cos_buffer, ACL_FLOAT, sizeof(float_t), sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS); aclnn_cache_init(ctx, dst, acl_cos_reshape_tensor, acl_sin_reshape_tensor, - theta_scale, freq_scale, attn_factor, is_neox); + theta_scale, freq_scale, attn_factor, is_neox); aclTensor* acl_src = ggml_cann_create_tensor(src0); aclTensor* acl_dst = ggml_cann_create_tensor(dst); @@ -3243,8 +2276,7 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) { int64_t shifts[] = {1}; int64_t dims[] = {3}; aclnn_roll(ctx, acl_input_tensor, acl_input_roll_tensor, shifts, dims); - ACL_CHECK(aclDestroyTensor(acl_input_roll_tensor)); - ACL_CHECK(aclDestroyTensor(acl_input_tensor)); + ggml_cann_release_resources(ctx, acl_input_roll_tensor, acl_input_tensor); // init [-1, 1, -1, 1, ...] minus_one_scale_buffer = minus_one_scale_allocator.get(); @@ -3280,8 +2312,7 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) { int64_t dims[] = {3}; aclnn_roll(ctx, acl_input_tensor, acl_input_roll_tensor, shifts, dims); - ACL_CHECK(aclDestroyTensor(acl_input_roll_tensor)); - ACL_CHECK(aclDestroyTensor(acl_input_tensor)); + ggml_cann_release_resources(ctx, acl_input_roll_tensor, acl_input_tensor); // init [-1, -1, -1, 1, 1,1,...] 
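        // (-1 over the first half of each row, +1 over the second half)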
minus_one_scale_buffer = minus_one_scale_allocator.get(); int64_t minus_one_ne[4] = {src0->ne[0], 1, 1, 1}; @@ -3306,7 +2337,7 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) { bool inplace = true; float scale = -1; aclnn_muls(ctx, acl_first_half_tensor, scale, nullptr, inplace); - ACL_CHECK(aclDestroyTensor(acl_first_half_tensor)); + ggml_cann_release_resources(ctx, acl_first_half_tensor); } // TODO: n_dims < ne0 @@ -3334,8 +2365,8 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) { // output void* output_fp32_buffer; if (src0->type == GGML_TYPE_F32) { - aclnn_inplace_mul(ctx, acl_src, acl_cos_reshape_tensor); - aclnn_inplace_mul(ctx, acl_input_roll_mul_scale_tensor, + aclnn_mul(ctx, acl_src, acl_cos_reshape_tensor); + aclnn_mul(ctx, acl_input_roll_mul_scale_tensor, acl_sin_reshape_tensor); aclnn_add(ctx, acl_src, acl_input_roll_mul_scale_tensor, acl_dst); // TODO: ne0 != n_dims in mode2 @@ -3371,76 +2402,188 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) { output_fp32_tensor); aclnn_cast(ctx, output_fp32_tensor, acl_dst, ACL_FLOAT16); - ACL_CHECK(aclDestroyTensor(input_fp32_tensor1)); - ACL_CHECK(aclDestroyTensor(input_fp32_tensor2)); - ACL_CHECK(aclDestroyTensor(output_fp32_tensor)); - ACL_CHECK(aclDestroyTensor(acl_sin_reshape_tensor)); - ACL_CHECK(aclDestroyTensor(acl_minus_one_tensor)); - ACL_CHECK(aclDestroyTensor(acl_input_roll_mul_scale_tensor)); - ACL_CHECK(aclDestroyTensor(acl_input_roll_reshape_tensor)); - ACL_CHECK(aclDestroyTensor(acl_src)); + ggml_cann_release_resources(ctx, input_fp32_tensor1, input_fp32_tensor2, + output_fp32_tensor, acl_sin_reshape_tensor, + acl_minus_one_tensor, acl_input_roll_mul_scale_tensor, + acl_input_roll_reshape_tensor, acl_src); } return; #endif - // src0 == GGML_TYPE_F16 - // TODO: optimization this `if` code - if (src0->type == GGML_TYPE_F16) { - ggml_cann_pool_alloc sin_final_allocator( - ctx.pool(), src0->ne[0] * src0->ne[2] * ggml_type_size(src0->type)); - ggml_cann_pool_alloc cos_final_allocator( - ctx.pool(), src0->ne[0] * src0->ne[2] * ggml_type_size(src0->type)); - void* sin_final_buffer = sin_final_allocator.get(); - void* cos_final_buffer = cos_final_allocator.get(); - - int64_t sin_final_ne[4] = {src0->ne[0], 1, src0->ne[2], 1}; - size_t sin_final_nb[GGML_MAX_DIMS]; - sin_final_nb[0] = ggml_type_size(src0->type); - for (int i = 1; i < GGML_MAX_DIMS; i++) { - sin_final_nb[i] = sin_final_nb[i - 1] * sin_final_ne[i - 1]; + // ggml_mode = 0 --> aclnn_model = 1 + int64_t acl_mode = mode == 0 ? 
1 : mode; + + switch (src0->type) { + case GGML_TYPE_F32: { + GGML_CANN_CALL_ACLNN_OP(ctx, RotaryPositionEmbedding, acl_src, + acl_cos_reshape_tensor, acl_sin_reshape_tensor, acl_mode, acl_dst); + break; } - aclTensor* acl_sin_final_tensor = ggml_cann_create_tensor( - sin_final_buffer, ggml_cann_type_mapping(src0->type), - ggml_type_size(src0->type), sin_final_ne, sin_final_nb, - GGML_MAX_DIMS); - aclTensor* acl_cos_final_tensor = ggml_cann_create_tensor( - cos_final_buffer, ggml_cann_type_mapping(src0->type), - ggml_type_size(src0->type), sin_final_ne, sin_final_nb, - GGML_MAX_DIMS); + case GGML_TYPE_F16: { + ggml_cann_pool_alloc src_trans_allocator( + ctx.pool(), ggml_nelements(src0) * sizeof(float)); + void* src_trans_buffer = src_trans_allocator.get(); + ggml_cann_pool_alloc dst_trans_allocator( + ctx.pool(), ggml_nelements(dst) * sizeof(float)); + void* dst_trans_buffer = dst_trans_allocator.get(); - aclnn_cast(ctx, acl_sin_reshape_tensor, acl_sin_final_tensor, - ggml_cann_type_mapping(src0->type)); - aclnn_cast(ctx, acl_cos_reshape_tensor, acl_cos_final_tensor, - ggml_cann_type_mapping(src0->type)); - ACL_CHECK(aclDestroyTensor(acl_cos_reshape_tensor)); - ACL_CHECK(aclDestroyTensor(acl_sin_reshape_tensor)); - acl_sin_reshape_tensor = acl_sin_final_tensor; - acl_cos_reshape_tensor = acl_cos_final_tensor; - } + size_t src_trans_nb[GGML_MAX_DIMS]; + src_trans_nb[0] = sizeof(float); + for (int i = 1; i < GGML_MAX_DIMS; i++) { + src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1]; + } + + aclTensor* acl_src_trans_tensor = ggml_cann_create_tensor( + src_trans_buffer, ACL_FLOAT, sizeof(float), src0->ne, src_trans_nb, + GGML_MAX_DIMS); + aclTensor* acl_dst_trans_tensor = ggml_cann_create_tensor( + dst_trans_buffer, ACL_FLOAT, sizeof(float), dst->ne, src_trans_nb, + GGML_MAX_DIMS); + + aclnn_cast(ctx, acl_src, acl_src_trans_tensor, ACL_FLOAT); - uint64_t workspaceSize = 0; - aclOpExecutor* executor; + GGML_CANN_CALL_ACLNN_OP(ctx, RotaryPositionEmbedding, acl_src_trans_tensor, + acl_cos_reshape_tensor, acl_sin_reshape_tensor, acl_mode, + acl_dst_trans_tensor); - void* workspaceAddr = nullptr; + aclnn_cast(ctx, acl_dst_trans_tensor, acl_dst, ACL_FLOAT16); - int acl_mode = mode; - if (mode == 0) { - acl_mode = 1; + ggml_cann_release_resources(ctx, acl_src_trans_tensor, + acl_dst_trans_tensor); + break; + } + default: + GGML_ABORT("Unsupported tensor type for GGML_OP_ROPE"); + break; } + ggml_cann_release_resources(ctx, acl_cos_reshape_tensor, + acl_sin_reshape_tensor, acl_src, acl_dst); +} + + + void ggml_cann_argmax(ggml_backend_cann_context& ctx, ggml_tensor* dst){ + ggml_tensor * src0 = dst->src[0]; + + aclTensor* acl_src = ggml_cann_create_tensor(src0); + aclTensor* acl_dst = ggml_cann_create_tensor(dst, dst->ne, dst->nb, 3); + + GGML_CANN_CALL_ACLNN_OP(ctx, ArgMax, acl_src, 3, false, acl_dst); + + ggml_cann_release_resources(ctx, acl_src, acl_dst); +} + +void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* dst){ + ggml_tensor * src0 = dst->src[0]; + ggml_tensor * src1 = dst->src[1]; + + // stride + int64_t s0 = ((const int32_t*)(dst->op_params))[0]; + + aclTensor* acl_input = ggml_cann_create_tensor(src1, src1->ne, src1->nb, 3, ACL_FORMAT_NCL); + aclTensor* acl_weight = ggml_cann_create_tensor(src0, src0->ne, src0->nb, 3, ACL_FORMAT_NCL); + aclTensor* acl_dst = ggml_cann_create_tensor(dst, dst->ne, dst->nb, 3, ACL_FORMAT_NCL); + + int64_t strideVal[1]; + strideVal[0] = s0; + aclIntArray *stride = aclCreateIntArray(strideVal, 1); + int64_t paddingVal[] = {0}; + 
aclIntArray *padding = aclCreateIntArray(paddingVal, 1); + int64_t dilationVal[] = {1}; + aclIntArray *dilation = aclCreateIntArray(dilationVal, 1); + bool transposed = true; + int64_t groups = 1; + int8_t cubeMathType = 0; + +#ifdef ASCEND_310P + cubeMathType = 1; +#endif + + GGML_CANN_CALL_ACLNN_OP(ctx, Convolution, acl_input, acl_weight, nullptr, stride, + padding, dilation, transposed, padding, groups, acl_dst, cubeMathType); + + ggml_cann_release_resources(ctx, acl_weight, acl_dst, stride, padding, dilation); +} + +void ggml_cann_elu(ggml_backend_cann_context& ctx, ggml_tensor* dst){ + ggml_tensor * src0 = dst->src[0]; + + aclTensor* acl_input = ggml_cann_create_tensor(src0); + aclTensor* acl_dst = ggml_cann_create_tensor(dst); + + float alphaValue = 1.0f; + aclScalar* alpha = nullptr; + alpha = aclCreateScalar(&alphaValue, aclDataType::ACL_FLOAT); + + GGML_CANN_CALL_ACLNN_OP(ctx, Elu, acl_input, alpha, alpha, alpha, + acl_dst); + + ggml_cann_release_resources(ctx, acl_input, acl_dst, alpha); +} + +void ggml_cann_mean(ggml_backend_cann_context& ctx, ggml_tensor* dst){ + ggml_tensor * src0 = dst->src[0]; + + aclTensor* acl_src = ggml_cann_create_tensor(src0); + aclTensor* acl_dst = ggml_cann_create_tensor(dst); + + int64_t reduceDimValue[] = {3}; + aclIntArray* reduceDim = aclCreateIntArray(reduceDimValue, 1); + bool keepDim = true; + + GGML_CANN_CALL_ACLNN_OP(ctx, Mean, acl_src, reduceDim, keepDim, ACL_FLOAT, acl_dst); + + ggml_cann_release_resources(ctx, acl_src, acl_dst, reduceDim); +} + +void ggml_cann_pad_reflect_1d(ggml_backend_cann_context& ctx, ggml_tensor* dst){ + ggml_tensor * src0 = dst->src[0]; + int32_t *opts = (int32_t *) dst->op_params; + int64_t paddingsArray[2] = {opts[0], opts[1]}; + aclIntArray* paddings = aclCreateIntArray(paddingsArray, 2); + + for (int64_t i = 0; i < src0->ne[3]; i++) { + aclTensor* acl_src = ggml_cann_create_tensor( + (char*)src0->data + i * src0->ne[3], + ggml_cann_type_mapping(src0->type), ggml_element_size(src0), + src0->ne, src0->nb, 3); - ACL_CHECK(aclnnRotaryPositionEmbeddingGetWorkspaceSize( - acl_src, acl_cos_reshape_tensor, acl_sin_reshape_tensor, acl_mode, - acl_dst, &workspaceSize, &executor)); - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); + aclTensor* acl_dst = ggml_cann_create_tensor( + (char*)dst->data + i * src0->ne[3], + ggml_cann_type_mapping(dst->type), ggml_element_size(dst), + dst->ne, dst->nb, 3); + + GGML_CANN_CALL_ACLNN_OP(ctx, ReflectionPad1d, acl_src, paddings, acl_dst); + + ggml_cann_release_resources(ctx, acl_src, acl_dst); } + ggml_cann_release_resources(ctx, paddings); +} + +void ggml_cann_count_equal(ggml_backend_cann_context& ctx, ggml_tensor* dst){ + ggml_tensor * src0 = dst->src[0]; + ggml_tensor * src1 = dst->src[1]; + + aclTensor* acl_self = ggml_cann_create_tensor(src0); + aclTensor* acl_other = ggml_cann_create_tensor(src1); + + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceEqTensor, acl_self, acl_other); + + ggml_cann_sum(ctx, dst); + + ggml_cann_release_resources(ctx, acl_self, acl_other); +} + +void ggml_cann_step(ggml_backend_cann_context& ctx, ggml_tensor* dst){ + ggml_tensor * src0 = dst->src[0]; + + aclTensor* acl_src = ggml_cann_create_tensor(src0); + aclTensor* acl_dst = ggml_cann_create_tensor(dst); + + float alphaValue = 0.0f; + aclScalar* alpha = nullptr; + alpha = aclCreateScalar(&alphaValue, aclDataType::ACL_FLOAT); - ACL_CHECK(aclnnRotaryPositionEmbedding(workspaceAddr, workspaceSize, - executor, 
ctx.stream())); + GGML_CANN_CALL_ACLNN_OP(ctx, GtScalar, acl_src, alpha, acl_dst); - ACL_CHECK(aclDestroyTensor(acl_src)); - ACL_CHECK(aclDestroyTensor(acl_cos_reshape_tensor)); - ACL_CHECK(aclDestroyTensor(acl_sin_reshape_tensor)); - ACL_CHECK(aclDestroyTensor(acl_dst)); + ggml_cann_release_resources(ctx, acl_src, acl_dst, alpha); } diff --git a/ggml/src/ggml-cann/aclnn_ops.h b/ggml/src/ggml-cann/aclnn_ops.h index 680129c76de..462351542e5 100644 --- a/ggml/src/ggml-cann/aclnn_ops.h +++ b/ggml/src/ggml-cann/aclnn_ops.h @@ -1,15 +1,4 @@ -#ifndef CANN_ACLNN_OPS -#define CANN_ACLNN_OPS - /** - * @file acl_tensor - * @brief This file contains related functions of ggml_tensor and acl_tensor. - * Contains conversion from ggml_tensor to acl_tensor, broadcast and other - * functions. - * @author hipudding - * @author wangshuai09 <391746016@qq.com> - * @date July 15, 2024 - * * Copyright (c) 2023-2024 The ggml authors * * Permission is hereby granted, free of charge, to any person obtaining a copy @@ -31,20 +20,31 @@ * IN THE SOFTWARE. */ -#include +#ifndef CANN_ACLNN_OPS +#define CANN_ACLNN_OPS + +#include +#include +#include +#include #include #include #include #include -#include #include +#include +#include #include #include #include -#include #include #include #include +#include +#include +#include +#include +#include #include "acl_tensor.h" #include "common.h" @@ -63,23 +63,6 @@ */ void ggml_cann_repeat(ggml_backend_cann_context& ctx, ggml_tensor* dst); -/** - * @brief Adds two ggml tensors using the CANN backend. - * - * @details This function performs an element-wise addition of two tensors. In - * case the tensors do not have the same shape, one or both tensors - * will be broadcasted to match the shape of the other before the - * addition is performed.The formula for the operation is given by: - * \f[ - * \text{dst} = \text{acl_src0} + \alpha \cdot \text{acl_src1} - * \f] - * - * @param ctx The CANN context used for operations. - * @param dst The ggml tensor representing the destination, result of the - * addition is stored at dst->data, and dst->op is `GGML_OP_ADD` - */ -void ggml_cann_add(ggml_backend_cann_context& ctx, ggml_tensor* dst); - /** * @brief Applies the Leaky ReLU activation function to a tensor using the CANN * backend. @@ -131,19 +114,6 @@ void ggml_cann_concat(ggml_backend_cann_context& ctx, ggml_tensor* dst); */ void ggml_cann_arange(ggml_backend_cann_context& ctx, ggml_tensor* dst); -/** - * @brief Computes the square of the elements of a ggml tensor using the CANN - * backend. - * @details The function sets the second source tensor of the destination - * tensor `dst` to be equal to the first source tensor. This is - * effectively squaring the elements since the multiplication becomes - * `element * element`. - * @param ctx The CANN context used for operations. - * @param dst The destination tensor where the squared values will be stored, - * which dst->op is `GGML_OP_SQR`. - */ -void ggml_cann_sqr(ggml_backend_cann_context& ctx, ggml_tensor* dst); - /** * @brief Applies a clamp operation to the elements of a ggml tensor using the * CANN backend. @@ -275,6 +245,20 @@ void ggml_cann_acc(ggml_backend_cann_context& ctx, ggml_tensor* dst); */ void ggml_cann_sum_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst); +/** + * @brief Computes the sum of elements in a ggml tensor. + * + * @details This function performs a reduction sum operation along the last + * dimension of the input tensor `src`. The result of the sum is stored + * in the destination tensor `dst`. 
+ * + * @param ctx The CANN context used for operations. + * @param dst The destination tensor where the reduced values will be stored。 + * + */ + +void ggml_cann_sum(ggml_backend_cann_context& ctx, ggml_tensor* dst); + /** * @brief Upsamples a ggml tensor using nearest neighbor interpolation using * the CANN backend. @@ -484,109 +468,616 @@ void ggml_cann_mul_mat(ggml_backend_cann_context& ctx, ggml_tensor* dst); */ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst); -template -void ggml_cann_mul_div(ggml_backend_cann_context& ctx, ggml_tensor* dst) { - ggml_tensor* src0 = dst->src[0]; - ggml_tensor* src1 = dst->src[1]; - GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst)); +/** + * @brief Computes the index of the maximum value along the specified dimension + * of a ggml tensor using the CANN backend. + * + * @details This function performs an argmax operation on the input tensor. + * It finds the index of the maximum value along the specified axis + * and stores these indices in the destination tensor `dst`. The + * operation is executed using the CANN backend for optimized performance. + * + * @param ctx The CANN context used for operations. + * @param dst The destination tensor where the indices of the maximum values will + * be stored. dst->op is `GGML_OP_ARGMAX`. + */ +void ggml_cann_argmax(ggml_backend_cann_context& ctx, ggml_tensor* dst); - aclTensor* acl_src0; - aclTensor* acl_src1; - aclTensor* acl_dst; +/** + * @brief Adds two tensors element-wise and stores the result in a destination + * tensor. + * + * This function performs the operation: + * \f[ + * dst = acl\_src0 + alpha \times acl\_src1 + * \f] + * where alpha is a scalar value and defaults to 1.0f. + * + * @param ctx The context for the CANN backend operations. + * @param acl_src0 The first source tensor. + * @param acl_src1 The second source tensor. + * @param acl_dst The destination tensor where the result will be stored. + */ +void aclnn_add(ggml_backend_cann_context& ctx, aclTensor* acl_src0, + aclTensor* acl_src1, aclTensor* acl_dst = nullptr); - // Need bcast - if (!ggml_are_same_shape(src0, src1) && ggml_cann_need_bcast(src0, src1)) { - BCAST_SHAPE(src0, src1) - acl_src0 = ggml_cann_create_tensor(src0, BCAST_PARAM(src0)); - acl_src1 = ggml_cann_create_tensor(src1, BCAST_PARAM(src1)); - acl_dst = ggml_cann_create_tensor(dst, BCAST_PARAM(src0)); - } else { - acl_src0 = ggml_cann_create_tensor(src0); - acl_src1 = ggml_cann_create_tensor(src1); - acl_dst = ggml_cann_create_tensor(dst); +/** + * @brief Sub two tensors element-wise and stores the result in a destination + * tensor. + * + * This function performs the operation: + * \f[ + * dst = acl\_src0 - alpha \times acl\_src1 + * \f] + * where alpha is a scalar value and defaults to 1.0f. + * + * @param ctx The context for the CANN backend operations. + * @param acl_src0 The first source tensor. + * @param acl_src1 The second source tensor. + * @param acl_dst The destination tensor where the result will be stored. + */ +void aclnn_sub(ggml_backend_cann_context& ctx, aclTensor* acl_src0, + aclTensor* acl_src1, aclTensor* acl_dst = nullptr); + +/** + * @brief Performs element-wise multiplication of two tensors and stores the + * result in a destination tensor. + * + * This function performs element-wise multiplication of the tensors `acl_src` + * and `acl_other` and stores the result in the destination tensor `acl_dst`. 
+ * The operation is defined as:
+ * \f[
+ * \text {acl_dst }_i=\text {acl_src }_i \times \text {acl_other }_i
+ * \f]
+ *
+ * @param ctx The context for the CANN backend operations.
+ * @param acl_src The first tensor for element-wise multiplication.
+ * @param acl_other The second tensor for element-wise multiplication.
+ * @param acl_dst The destination tensor where the result will be stored.
+ */
+void aclnn_mul(ggml_backend_cann_context& ctx, aclTensor* acl_src,
+               aclTensor* acl_other, aclTensor* acl_dst = nullptr);
+
+/**
+ * @brief Element-wise division of two tensors, optionally in-place.
+ *
+ * This function divides each element of the source tensor `acl_src` by the
+ * tensor `acl_other` and stores the result in the destination tensor `acl_dst`.
+ * If `acl_dst` is omitted (nullptr), the operation is
+ * performed in-place on `acl_src`. The operation is defined as: \f[
+ * \text{dst}_i = \frac{\text{acl_src}_i}{\text{acl_other}_i}
+ * \f]
+ *
+ * @param ctx The context for the CANN backend operations.
+ * @param acl_src Numerator tensor.
+ * @param acl_other Denominator tensor.
+ * @param acl_dst The destination tensor where the result will be stored;
+ * pass nullptr to divide `acl_src` in place.
+ */
+void aclnn_div(ggml_backend_cann_context& ctx, aclTensor* acl_src,
+               aclTensor* acl_other, aclTensor* acl_dst = nullptr);
+
+/**
+ * @brief Applies element-wise cosine function to the elements of a tensor.
+ *
+ * This function computes the cosine of each element in the source tensor
+ * `acl_src` and stores the result in the destination tensor `acl_dst`. The
+ * operation is defined as: \f[ \text {acl_dst }_i=\cos \left(\text {acl_src
+ * }_i\right) \f]
+ *
+ * @param ctx The context for the CANN backend operations.
+ * @param acl_src The source tensor on which the cosine function will be
+ * applied.
+ * @param acl_dst The destination tensor where the cosine results will be
+ * stored.
+ */
+void aclnn_cos(ggml_backend_cann_context& ctx, aclTensor* acl_src,
+               aclTensor* acl_dst);
+
+/**
+ * @brief Applies element-wise sine function to the elements of a tensor.
+ *
+ * This function computes the sine of each element in the source tensor
+ * `acl_src` and stores the result in the destination tensor `acl_dst`.
+ * The operation is defined as:
+ * \f[
+ * \text {acl_dst }_i=\sin \left(\text {acl_src }_i\right)
+ * \f]
+ *
+ * @param ctx The context for the CANN backend operations.
+ * @param acl_src The source tensor on which the sine function will be applied.
+ * @param acl_dst The destination tensor where the sine results will be stored.
+ */
+void aclnn_sin(ggml_backend_cann_context& ctx, aclTensor* acl_src,
+               aclTensor* acl_dst);
+
+/**
+ * @brief Prepares broadcast-compatible ACL tensors for two input tensors and one
+ * output tensor.
+ *
+ * This function checks whether broadcasting is needed between `src0` and `src1`.
+ * If broadcasting is required, it calculates the proper shapes and creates
+ * ACL tensors with broadcast parameters. Otherwise, it directly creates ACL tensors
+ * based on the original tensor shapes.
+ *
+ * @param src0 The first input tensor (reference shape).
+ * @param src1 The second input tensor (possibly broadcasted).
+ * @param dst The destination/output tensor.
+ * @param acl_src0 Output pointer to the created ACL tensor corresponding to src0.
+ * @param acl_src1 Output pointer to the created ACL tensor corresponding to src1.
+ * @param acl_dst Output pointer to the created ACL tensor corresponding to dst. + */ +void bcast_shape(ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst, + aclTensor ** acl_src0, aclTensor ** acl_src1, aclTensor ** acl_dst); + +/** + * @brief Computes the 1D transposed convolution (deconvolution) of a ggml + * tensor using the CANN backend. + * + * @details This function performs a 1D transposed convolution (also known as + * deconvolution) operation on the input tensor. The computed result is stored + * in the destination tensor `dst`. The operation is optimized using the CANN + * backend for improved performance. + * + * @param ctx The CANN context used for operations. + * @param dst The destination tensor where the transposed convolution result + * will be stored. dst->op is `GGML_OP_CONV_TRANSPOSE_1D`. + */ +void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* dst); + +/** + * @brief Applies the ELU (Exponential Linear Unit) activation to a ggml tensor + * using the CANN backend. + * + * @details This function performs an element-wise ELU activation on the input + * tensor. + * The result is written to the destination tensor `dst` in-place. + * The ELU function is defined as: + * + * \text{ELU}(x) = + * \begin{cases} + * x, & \text{if } x > 0 \\ + * \alpha \left( \exp(x) - 1 \right), & \text{if } x \leq 0 + * \end{cases} + * + * where α (alpha) is a hyperparameter, typically set to 1.0. + * This operation is optimized using the CANN backend for high-performance + * inference or training. + * + * @param ctx The CANN context used for operations. + * @param dst The destination tensor where the ELU-activated result will be stored. + * dst->op is expected to be `GGML_OP_ELU`. + */ +void ggml_cann_elu(ggml_backend_cann_context& ctx, ggml_tensor* dst); + +/** + * @brief Computes the mean of a ggml tensor element-wise using the CANN backend. + * + * @details This function calculates the element-wise mean of the input tensor. + * The result is written to the destination tensor `dst`. + * The mean is computed by averaging the values across the entire tensor. + * + * This operation is optimized using the CANN backend for high-performance inference or training. + * + * @param ctx The CANN context used for operations. + * @param dst The destination tensor where the mean result will be stored. + * dst->op is expected to be `GGML_OP_MEAN`. + */ +void ggml_cann_mean(ggml_backend_cann_context& ctx, ggml_tensor* dst); + +/** + * @brief Applies 1D reflect padding to a ggml tensor using the CANN backend. + * + * @details This function performs 1D reflect padding on the input tensor. + * The amount of padding on each side is specified by parameters stored in `dst->op_params`. + * The operation reflects the values at the borders of the tensor to generate the padded output. + * + * This operation is optimized using the CANN backend for high-performance inference or training. + * + * @param ctx The CANN context used for operations. + * @param dst The destination tensor where the padded result will be stored. + * dst->op is expected to be `GGML_OP_PAD_REFLECT_1D`. + */ +void ggml_cann_pad_reflect_1d(ggml_backend_cann_context& ctx, ggml_tensor* dst); + +/** + * @brief Counts the number of equal elements in two ggml tensors using the CANN backend. + * + * @details This function performs an element-wise comparison between two input tensors, + * and counts the number of positions where the elements are equal. The result is + * stored in the destination tensor `dst` as a scalar. 
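The ELU definition quoted in the comment above maps to a very small CPU reference. The sketch below is only for clarity and is unrelated to the actual CANN kernel dispatched by `ggml_cann_elu`:

```cpp
#include <cmath>
#include <cstddef>

// CPU reference for ELU(x) = x (x > 0), alpha * (exp(x) - 1) (x <= 0),
// with alpha typically 1.0, matching the formula in the comment above.
static void elu_ref(const float* src, float* dst, size_t n, float alpha = 1.0f) {
    for (size_t i = 0; i < n; ++i) {
        const float x = src[i];
        dst[i] = x > 0.0f ? x : alpha * std::expm1(x);  // expm1(x) == exp(x) - 1
    }
}
```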
+ * + * The operation is optimized using the CANN backend, making it suitable for + * high-performance inference or training scenarios. + * + * @param ctx The CANN context used for operations. + * @param dst The destination tensor where the result will be stored. + * dst->op is expected to be `GGML_OP_COUNT_EQUAL`. + */ +void ggml_cann_count_equal(ggml_backend_cann_context& ctx, ggml_tensor* dst); + +/** + * @brief Applies the Step activation function to a ggml tensor using the CANN backend. + * + * @details This function applies a step function element-wise to the input tensor, where + * each element is transformed to 1.0 if it is greater than 0, and 0.0 otherwise. + * The result is stored in the destination tensor `dst`. + * + * This operation is accelerated using the CANN backend to improve runtime performance. + * + * @param ctx The CANN context used for operations. + * @param dst The destination tensor where the result will be stored. + * dst->op is expected to be `GGML_OP_UNARY` with unary op `GGML_UNARY_OP_STEP`. + */ +void ggml_cann_step(ggml_backend_cann_context& ctx, ggml_tensor* dst); + +/** + * @brief A generic wrapper for ACL resources with custom deleter support. + */ +using any_acl_resource = std::unique_ptr<void, std::function<void(void*)>>; + +/** + * @brief Trait structure used to define how to destroy a given ACL resource type. + * + * @tparam T ACL resource type. + */ +template<typename T> +struct acl_resource_traits; + +/** + * @brief Specialization for aclTensor, defines how to destroy an aclTensor resource. + */ +template<> +struct acl_resource_traits<aclTensor> { + static void destroy(void* p) { + ACL_CHECK(aclDestroyTensor(static_cast<aclTensor*>(p))); } +}; - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; +/** + * @brief Specialization for aclIntArray, defines how to destroy an aclIntArray resource. + */ +template<> +struct acl_resource_traits<aclIntArray> { + static void destroy(void* p) { + ACL_CHECK(aclDestroyIntArray(static_cast<aclIntArray*>(p))); + } +}; - ACL_CHECK(getWorkspaceSize(acl_src0, acl_src1, acl_dst, &workspaceSize, - &executor)); - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); +/** + * @brief Specialization for aclScalar, defines how to destroy an aclScalar resource. + */ +template<> +struct acl_resource_traits<aclScalar> { + static void destroy(void* p) { + ACL_CHECK(aclDestroyScalar(static_cast<aclScalar*>(p))); } +}; - aclrtStream main_stream = ctx.stream(); - ACL_CHECK(execute(workspaceAddr, workspaceSize, executor, main_stream)); +/** + * @brief Specialization for aclTensorList, defines how to destroy an aclTensorList resource. + */ +template<> +struct acl_resource_traits<aclTensorList> { + static void destroy(void* p) { + ACL_CHECK(aclDestroyTensorList(static_cast<aclTensorList*>(p))); + } +}; - ACL_CHECK(aclDestroyTensor(acl_src0)); - ACL_CHECK(aclDestroyTensor(acl_src1)); - ACL_CHECK(aclDestroyTensor(acl_dst)); +/** + * @brief Creates a generic ACL resource wrapper with proper destruction logic. + * + * @tparam T ACL resource type. + * @param ptr Raw pointer to ACL resource. + * @return any_acl_resource Smart pointer that handles destruction. + */ +template<typename T> +any_acl_resource make_acl_resource(T* ptr) { + return any_acl_resource( + static_cast<void*>(ptr), + [](void* p) { + acl_resource_traits<T>::destroy(p); + } + ); } -// Activation functions template. -template -void ggml_cann_activation(ggml_backend_cann_context& ctx, ggml_tensor* dst) { - ggml_tensor* src = dst->src[0]; +/** + * @brief Registers multiple ACL resources into a vector for lifetime management.
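The trait/`make_acl_resource` pair above turns any ACL handle into a type-erased smart pointer. Extending it to another handle type only needs one more specialization; the type and destroy function below are placeholders for illustration, not real ACL APIs:

```cpp
// Hypothetical extension of the acl_resource_traits pattern shown above.
// "acl_foo_t" and "aclDestroyFoo" are placeholders, not part of the patch.
struct acl_foo_t;
void aclDestroyFoo(acl_foo_t* p);

template<>
struct acl_resource_traits<acl_foo_t> {
    static void destroy(void* p) {
        aclDestroyFoo(static_cast<acl_foo_t*>(p));
    }
};

// any_acl_resource res = make_acl_resource(foo_ptr);  // released when res dies
```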
+ * + * @tparam Args Variadic list of ACL resource types. + * @param vec Target vector to hold ACL resources. + * @param args Raw pointers to ACL resources. + */ +template +void register_acl_resources(std::vector& vec, Args*... args) { + (vec.emplace_back(make_acl_resource(args)), ...); +} - GGML_ASSERT(src->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); +/** + * @brief Task class that wraps the execution of an aclnn function call. + */ +class aclnn_task : public cann_task { + public: + aclnn_task(aclnn_func_t aclnn_func, void * workspace_addr, + uint64_t workspace_size, aclOpExecutor * executor, + aclrtStream stream) : + aclnn_func_(aclnn_func), + workspace_addr_(workspace_addr), + workspace_size_(workspace_size), + executor_(executor), + stream_(stream) {} + virtual void run_task() override { + ACL_CHECK(aclnn_func_(workspace_addr_, workspace_size_, executor_, stream_)); + } + private: + aclnn_func_t aclnn_func_; + void * workspace_addr_; + uint64_t workspace_size_; + aclOpExecutor * executor_; + aclrtStream stream_; +}; - aclTensor* acl_src = ggml_cann_create_tensor(src); - aclTensor* acl_dst = ggml_cann_create_tensor(dst); +/** + * @brief Task class that releases ACL resources after usage. + */ +class release_resource_task : public cann_task { +public: + release_resource_task(std::vector&& resources){ + resource_ = std::move(resources); + } - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; + virtual void run_task() override { + resource_.clear(); + } +private: + std::vector resource_; +}; + +/** + * @brief Task class for performing asynchronous memory copy operations. + */ +class async_memcpy_task : public cann_task { +public: + async_memcpy_task(void* dst, const void* src, size_t size, + aclrtMemcpyKind kind, aclrtStream stream) + : dst_(dst), src_(src), size_(size), kind_(kind), stream_(stream) {} - ACL_CHECK(getWorkspaceSize(acl_src, acl_dst, &workspaceSize, &executor)); - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); + virtual void run_task() override { + ACL_CHECK(aclrtMemcpyAsync(dst_, size_, src_, size_, kind_, stream_)); } +private: + void* dst_; + const void* src_; + size_t size_; + aclrtMemcpyKind kind_; + aclrtStream stream_; +}; - aclrtStream main_stream = ctx.stream(); - ACL_CHECK(execute(workspaceAddr, workspaceSize, executor, main_stream)); +/** + * @brief Task class for performing asynchronous memory set operations. + */ +class async_memset_task : public cann_task { + public: + async_memset_task(void* buffer, size_t size, int32_t value, aclrtStream stream) + : buffer_(buffer), size_(size), value_(value), stream_(stream) {} - ACL_CHECK(aclDestroyTensor(acl_src)); - ACL_CHECK(aclDestroyTensor(acl_dst)); -} + virtual void run_task() override { + ACL_CHECK(aclrtMemsetAsync(buffer_, size_, value_, size_, stream_)); + } + private: + void* buffer_; + size_t size_; + int32_t value_; + aclrtStream stream_; +}; -// Activation functions template for const aclTensors. -template -void ggml_cann_activation(ggml_backend_cann_context& ctx, ggml_tensor* dst) { - ggml_tensor* src = dst->src[0]; +/** + * @brief Launches an asynchronous task using the memory allocator. + * + * This macro submit an asynchronous task on the specified stream. + * The task uses memory allocated by the allocator. 
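The task classes above all follow the same shape: capture what the work needs in the constructor, then do it in `run_task()` on the worker thread. A hypothetical extra task type (not part of the patch) would look like the sketch below, assuming the `cann_task` base class from common.h later in this patch:

```cpp
#include <memory>
#include <string>

// Illustrative only: a task that logs a message from the CANN worker thread.
class log_task : public cann_task {
  public:
    explicit log_task(std::string msg) : msg_(std::move(msg)) {}
    void run_task() override {
        GGML_LOG_INFO("cann task: %s\n", msg_.c_str());
    }
  private:
    std::string msg_;
};

// Submission mirrors the existing tasks:
//   ctx.task_queue.submit_task(std::make_unique<log_task>("graph done"));
```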
It is guaranteed + * that the memory will not be accessed by other tasks until this task + * completes, due to the sequential execution order within the same stream. + * + * @param OP_NAME aclnn operator name. + * @param args Additional arguments required by the task. + * + * @note + * Memory from the allocator will be "freed" immediately and can be + * reallocated to other pointers. However, it won't be accessed by any + * other task before this asynchronous task ends, because all tasks in the + * same stream are executed in queue order. + */ - GGML_ASSERT(src->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); +#define GGML_CANN_CALL_ACLNN_OP(CTX, OP_NAME, ...) \ + do { \ + uint64_t workspaceSize = 0; \ + aclOpExecutor * executor; \ + void * workspaceAddr = nullptr; \ + ACL_CHECK(aclnn##OP_NAME##GetWorkspaceSize(__VA_ARGS__, &workspaceSize, &executor));\ + /* workspace should alloced in main thread to keep malloc order when using vmm. */ \ + if (workspaceSize > 0) { \ + ggml_cann_pool_alloc workspace_allocator(CTX.pool(), workspaceSize); \ + workspaceAddr = workspace_allocator.get(); \ + } \ + if (CTX.async_mode) { \ + auto task = \ + std::make_unique(aclnn##OP_NAME, workspaceAddr, workspaceSize, \ + executor, CTX.stream()); \ + CTX.task_queue.submit_task(std::move(task)); \ + } else { \ + ACL_CHECK(aclnn##OP_NAME(workspaceAddr, workspaceSize, executor, CTX.stream()));\ + } \ + } while (0) - aclTensor* acl_src = ggml_cann_create_tensor(src); - aclTensor* acl_dst = ggml_cann_create_tensor(dst); +/** + * @brief Registers and releases multiple ACL resources, optionally deferring the release + * using a task. + * + * @tparam Args Types of the ACL resources. + * @param ctx Backend context which manages task submission and async mode. + * @param args Pointers to ACL resources to be released. + */ +template +void ggml_cann_release_resources(ggml_backend_cann_context & ctx, Args &&... args) { + std::vector resources; + register_acl_resources(resources, std::forward(args)...); + if(ctx.async_mode) { + auto task = std::make_unique(std::move(resources)); + ctx.task_queue.submit_task(std::move(task)); + } +} + +/** + * @brief Performs an asynchronous memory copy operation, optionally deferred via task submission. + * + * @param ctx Backend context containing stream and async configuration. + * @param dst Destination memory address. + * @param src Source memory address. + * @param len Size of memory to copy (in bytes). + * @param kind Type of memory copy (host-to-device, device-to-host, etc). 
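Putting the pieces together, a typical element-wise op built on these helpers follows the pattern below. This is a sketch assuming the declarations in this header (`ggml_cann_create_tensor`, the call macro, `ggml_cann_release_resources`); `Sqrt` is used because the dispatch later in this patch already routes `GGML_OP_SQRT` to that ACLNN kernel:

```cpp
// Sketch of an op implementation using the helpers above (not part of the patch).
static void ggml_cann_sqrt_sketch(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
    ggml_tensor* src = dst->src[0];

    aclTensor* acl_src = ggml_cann_create_tensor(src);
    aclTensor* acl_dst = ggml_cann_create_tensor(dst);

    // Expands to GetWorkspaceSize + execute, or to an aclnn_task submission
    // when GGML_CANN_ASYNC_MODE is enabled.
    GGML_CANN_CALL_ACLNN_OP(ctx, Sqrt, acl_src, acl_dst);

    // Handles are destroyed immediately, or after the queued task runs.
    ggml_cann_release_resources(ctx, acl_src, acl_dst);
}
```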
+ */ +inline void ggml_cann_async_memcpy(ggml_backend_cann_context & ctx, void * dst, + const void * src, size_t len, aclrtMemcpyKind kind) { + if (ctx.async_mode) { + auto task = std::make_unique(dst, const_cast(src), len, kind, ctx.stream()); + ctx.task_queue.submit_task(std::move(task)); + } else { + ACL_CHECK(aclrtMemcpyAsync(dst, len, src, len, kind, ctx.stream())); + } +} - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; +inline void ggml_cann_async_memcpy(ggml_backend_cann_context * ctx, void * dst, + const void * src, size_t len, aclrtMemcpyKind kind) { + if (ctx->async_mode) { + auto task = std::make_unique(dst, const_cast(src), len, kind, ctx->stream()); + ctx->task_queue.submit_task(std::move(task)); + } else { + ACL_CHECK(aclrtMemcpyAsync(dst, len, src, len, kind, ctx->stream())); + } +} - ACL_CHECK(getWorkspaceSize(acl_src, acl_dst, &workspaceSize, &executor)); - if (workspaceSize > 0) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); - workspaceAddr = workspace_allocator.get(); +/** + * @brief Performs an asynchronous memory set operation, optionally deferred via task submission. + * + * @param ctx Backend context containing stream and async configuration. + * @param buffer Memory buffer to be set. + * @param size Size of the memory buffer (in bytes). + * @param value Value to set in the buffer. + */ +inline void ggml_cann_async_memset(ggml_backend_cann_context & ctx, void * buffer, + size_t size, int value) { + if (ctx.async_mode) { + auto task = std::make_unique(buffer, size, value, ctx.stream()); + ctx.task_queue.submit_task(std::move(task)); + } else { + ACL_CHECK(aclrtMemsetAsync(buffer, size, value, size, ctx.stream())); } +} - aclrtStream main_stream = ctx.stream(); - ACL_CHECK(execute(workspaceAddr, workspaceSize, executor, main_stream)); +/** + * @brief Applies a element-wise operation to two input tensors using the CANN + * backend. + * + * This templated function takes a binary operator and applies it to two source + * tensors + * associated with the destination tensor. The function handles broadcasting as + * needed. + * + * @tparam binary_op A callable object (e.g., lambda or function pointer) representing + * the binary operation to be performed. It must take three arguments: + * (ggml_backend_cann_context&, aclTensor*, aclTensor*, aclTensor*). + * + * @param ctx The CANN backend context used to manage execution and resources. + * @param dst The destination tensor. + */ +template +void ggml_cann_binary_op(ggml_backend_cann_context& ctx, ggml_tensor* dst) { + ggml_tensor* src0 = dst->src[0]; + ggml_tensor* src1 = dst->src[1]; - ACL_CHECK(aclDestroyTensor(acl_src)); - ACL_CHECK(aclDestroyTensor(acl_dst)); + aclTensor* acl_src0; + aclTensor* acl_src1; + aclTensor* acl_dst; + + // Need bcast + bcast_shape(src0, src1, dst, &acl_src0, &acl_src1, &acl_dst); + binary_op(ctx, acl_src0, acl_src1, acl_dst); + + ggml_cann_release_resources(ctx, acl_src0, acl_src1, acl_dst); } + +/** + * @brief Applies a unary operation to an input tensor using the CANN backend. + * + * This templated function applies a unary operator to the source tensor of `dst` + * and stores the result in the destination tensor. + * + * @tparam unary_op A callable with the signature: + * void(ggml_backend_cann_context&, aclTensor*, aclTensor*) + * where the first aclTensor is the source and the second is the destination. + * @param ctx The CANN backend context for managing resources and execution. + * @param dst The destination tensor. 
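For reference, this is how the dispatch in ggml-cann.cpp typically instantiates `ggml_cann_binary_op` above. The exact template arguments are not visible in this view of the patch (the calls below appear as bare `ggml_cann_binary_op(ctx, dst)`), so treat this as a reconstruction rather than a verbatim quote:

```cpp
// Sketch: mapping ggml ops onto ggml_cann_binary_op<...>, where the template
// argument is one of the aclnn_* wrappers declared earlier in this header.
switch (dst->op) {
    case GGML_OP_ADD:
        ggml_cann_binary_op<aclnn_add>(ctx, dst);
        break;
    case GGML_OP_SUB:
        ggml_cann_binary_op<aclnn_sub>(ctx, dst);
        break;
    case GGML_OP_MUL:
        ggml_cann_binary_op<aclnn_mul>(ctx, dst);
        break;
    case GGML_OP_DIV:
        ggml_cann_binary_op<aclnn_div>(ctx, dst);
        break;
    default:
        break;
}
```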
Its src[0] is treated as the input tensor. + */ +template + void ggml_cann_unary_op(ggml_backend_cann_context& ctx, ggml_tensor* dst) { + ggml_tensor* src = dst->src[0]; + + aclTensor* acl_src = ggml_cann_create_tensor(src); + aclTensor* acl_dst = ggml_cann_create_tensor(dst); + + unary_op(ctx, acl_src, acl_dst); + ggml_cann_release_resources(ctx, acl_src, acl_dst); +} + +/** + * @brief Applies a unary operation to a ggml tensor using the CANN backend. + * + * @details This function performs a unary operation on the input tensor using + * a user-provided lambda or callable object `unary_op`, which accepts the CANN + * context and two ACL tensors (source and destination). Internally, this function + * creates ACL representations of the ggml tensors and invokes the unary operation. + * The result is stored in the destination tensor `dst`. This utility abstracts the + * common boilerplate of tensor conversion and cleanup when implementing unary ops. + * + * @param unary_op A callable that performs the unary operation using CANN APIs. + * @param ctx The CANN context used for operations. + * @param dst The destination tensor where the result will be stored. + * The source tensor is retrieved from `dst->src[0]`. + */ +void ggml_cann_unary_op( + std::function unary_op, + ggml_backend_cann_context& ctx, ggml_tensor* dst); + +/** + * @brief Helper macro to invoke a unary ACL operation using ggml_cann_unary_op. + * + * This macro defines an inline lambda wrapping a specific ACL operation name, + * and passes it to the templated ggml_cann_unary_op function. It simplifies + * calling unary ops by hiding the lambda boilerplate. + * + * Internally, the lambda will call: + * @code + * GGML_CANN_CALL_ACLNN_OP(ctx, OP_NAME, acl_src, acl_dst); + * @endcode + * + * @param OP_NAME The name of the ACL unary operator to invoke via GGML_CANN_CALL_ACLNN_OP. + * + * @see ggml_cann_unary_op + * @see GGML_CANN_CALL_ACLNN_OP + */ +#define GGML_CANN_CALL_UNARY_OP(OP_NAME) \ + do { \ + auto lambda = [](ggml_backend_cann_context& ctx, \ + aclTensor* acl_src, \ + aclTensor* acl_dst) { \ + GGML_CANN_CALL_ACLNN_OP(ctx, OP_NAME, acl_src, acl_dst); \ + }; \ + ggml_cann_unary_op(lambda, ctx, dst); \ + } \ + while (0) #endif // CANN_ACLNN_OPS diff --git a/ggml/src/ggml-cann/common.h b/ggml/src/ggml-cann/common.h index 5164cb74ec9..7ef80a47933 100644 --- a/ggml/src/ggml-cann/common.h +++ b/ggml/src/ggml-cann/common.h @@ -31,9 +31,16 @@ #include #include #include +#include +#include +#include +#include +#include +#include #include "../include/ggml-cann.h" #include "../include/ggml.h" +#include "../ggml-impl.h" #define MATRIX_ROW_PADDING 512 #define GGML_CANN_MAX_STREAMS 8 @@ -205,6 +212,127 @@ struct ggml_cann_pool_alloc { ggml_cann_pool_alloc& operator=(ggml_cann_pool_alloc&&) = delete; }; +/** + * @brief Function pointer type for ACLNN operator calls. + */ +using aclnn_func_t = aclnnStatus (*)(void*, uint64_t, aclOpExecutor*, aclrtStream); + +/** + * @brief Base class for all CANN tasks to be submitted to the task queue. + * + * Users should override the run_task() method with actual task logic. + */ +class cann_task { +public: + virtual void run_task() {} +}; + +/** + * @brief A lock-free ring-buffer based task queue for asynchronously executing cann_task instances. + */ +class cann_task_queue { +public: + /** + * @brief Constructs a task queue with a fixed power-of-two capacity for a specific device. + * + * @param capacity Queue capacity. Must be a power of 2. + * @param device Target device ID (used for context setting). 
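Two usage styles follow from the declarations above, sketched here under the assumption that they run inside the backend's op dispatch where `ctx` and `dst` are in scope:

```cpp
// 1) Single-kernel unary ops go through the convenience macro, which wraps
//    GGML_CANN_CALL_ACLNN_OP in a lambda and forwards it to ggml_cann_unary_op:
GGML_CANN_CALL_UNARY_OP(Exp);

// 2) Kernels that need extra arguments provide their own lambda, as the
//    GELU_QUICK case in ggml-cann.cpp does with GeluV2 (0 selects its mode):
ggml_cann_unary_op(
    [](ggml_backend_cann_context& ctx, aclTensor* acl_src, aclTensor* acl_dst) {
        GGML_CANN_CALL_ACLNN_OP(ctx, GeluV2, acl_src, 0, acl_dst);
    },
    ctx, dst);
```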
+ */ + explicit cann_task_queue(size_t capacity, int32_t device) + : buffer_(capacity), capacity_(capacity), head_(0), tail_(0), + running_(false), device_(device) { + GGML_ASSERT((capacity & (capacity - 1)) == 0 && "capacity must be power of 2"); + mask_ = capacity_ - 1; + } + + /** + * @brief Attempts to enqueue a task into the queue. + * + * @param item Unique pointer to the task. + * @return true if the task was successfully enqueued, false if the queue was full. + */ + bool enqueue(std::unique_ptr&& item) { + size_t next_tail = (tail_ + 1) & mask_; + + if (next_tail == head_) { + return false; + } + + buffer_[tail_] = std::move(item); + std::atomic_thread_fence(std::memory_order_release); + tail_ = next_tail; + + return true; + } + + /** + * @brief Submits a task to the queue, and starts the worker thread if not already running. + * + * @param task Task to be submitted. + */ + void submit_task(std::unique_ptr&& task) { + while(!enqueue(std::move(task))) { + std::this_thread::yield(); + continue; + } + + if (!running_) { + running_ = true; + thread_ = std::thread(&cann_task_queue::execute, this); + } + + } + + /** + * @brief Waits until the queue is completely empty and no tasks are being processed. + */ + void wait() { + while (running_ && head_ != tail_) { + std::this_thread::yield(); + continue; + } + } + + /** + * @brief Stops the task queue and joins the worker thread. + */ + void stop() { + running_ = false; + if (thread_.joinable()) { + thread_.join(); + } + } + +private: + /** + * @brief Worker thread function that continuously dequeues and executes tasks. + */ + void execute() { + ggml_cann_set_device(device_); + + while (running_) { + if(head_ == tail_) { + std::this_thread::yield(); + continue; + } + + std::atomic_thread_fence(std::memory_order_acquire); + buffer_[head_]->run_task(); + buffer_[head_].reset(); + head_ = (head_ + 1) & mask_; + } + } + + std::vector> buffer_; + const size_t capacity_; + size_t mask_; + size_t head_; + size_t tail_; + bool running_; + std::thread thread_; + int32_t device_; +}; + /** * @brief Context for managing CANN backend operations. */ @@ -213,6 +341,8 @@ struct ggml_backend_cann_context { std::string name; /**< Name of the device. */ std::string description; /**< Description of the device. */ aclrtEvent copy_event = nullptr; /**< Event for managing copy operations. */ + cann_task_queue task_queue; + bool async_mode; aclrtStream streams[GGML_CANN_MAX_STREAMS] = {nullptr}; /**< Array of streams for the device. */ @@ -221,9 +351,12 @@ struct ggml_backend_cann_context { * @param device Device ID. */ explicit ggml_backend_cann_context(int device) - : device(device), name("CANN" + std::to_string(device)) { + : device(device), name("CANN" + std::to_string(device)), task_queue(1024, device) { ggml_cann_set_device(device); description = aclrtGetSocName(); + async_mode = (getenv("GGML_CANN_ASYNC_MODE") != nullptr); + GGML_LOG_INFO("%s: device %d async operator submission is %s\n", __func__, + device, async_mode ? 
"ON" : "OFF"); } /** @@ -231,6 +364,7 @@ struct ggml_backend_cann_context { */ ~ggml_backend_cann_context() { ggml_cann_set_device(device); + task_queue.stop(); if (copy_event != nullptr) { ACL_CHECK(aclrtDestroyEvent(copy_event)); } diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index da75f77f511..e2617b06e9c 100644 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -29,6 +29,8 @@ #include #include #include +#include +#include #include "ggml-impl.h" #include "ggml-backend-impl.h" @@ -119,9 +121,10 @@ static ggml_cann_device_info ggml_cann_init() { prop.location.type = ACL_MEM_LOCATION_TYPE_DEVICE; prop.location.id = id; prop.reserve = 0; - ACL_CHECK(aclrtMemGetAllocationGranularity( + err = aclrtMemGetAllocationGranularity( &prop, ACL_RT_MEM_ALLOC_GRANULARITY_RECOMMENDED, - &info.devices[id].vmm_granularity)); + &info.devices[id].vmm_granularity); + info.devices[id].vmm = err == ACL_SUCCESS; size_t free, total; ggml_backend_cann_get_device_memory(id, &free, &total); @@ -148,11 +151,223 @@ const ggml_cann_device_info& ggml_cann_info() { //#define DEBUG_CANN_MALLOC /** - * @brief A pool of CANN buffers(legacy). + * @brief A pool of CANN buffers(priority segment buffer). * * This class manages a pool of CANN buffers for a specific device. */ -struct ggml_cann_pool_leg : public ggml_cann_pool { +struct ggml_cann_pool_buf_prio : public ggml_cann_pool { + /** + * @brief The maximum reuse margin for a buffer. + */ + static const size_t max_reuse_margin = 1ull << 22; // 4MB + + /** + * @brief The minimum free margin for a buffer. + */ + static const size_t min_free_margin = 1ull << 20; // 1MB + + /** + * @brief The alignment for buffer allocation. + */ + static const size_t alignment = 128; + + /** + * @brief The device ID associated with this buffer pool. + */ + int device; + + /** + * @brief Whether to disable clean during buffer allocation. + */ + bool disable_clean = false; + + /** + * @brief Structure representing a CANN buffer. + */ + struct ggml_cann_buffer { + void* ptr = nullptr; ///< Pointer to the buffer. + size_t size = 0; ///< Size of the buffer. + std::chrono::steady_clock::time_point last_used; ///< Last used time. + + bool operator>(const ggml_cann_buffer& other) const { + return size > other.size; + } + }; + + /** + * @brief Array of CANN buffers in the pool. + */ + std::unordered_map buffer_pool; + std::priority_queue, + std::greater<>> free_buffers ; + + /** + * @brief Total size of all buffers in the pool. + */ + size_t pool_size = 0; + + /** + * @brief Constructor to initialize the buffer pool for a specific device. + * + * @param device The device ID to associate with this buffer pool. + */ + explicit ggml_cann_pool_buf_prio(int device) : device(device) { + disable_clean = getenv("GGML_CANN_DISABLE_BUF_POOL_CLEAN") != nullptr; + } + + /** + * @brief Destructor to free all buffers in the pool. + */ + ~ggml_cann_pool_buf_prio() { + ggml_cann_set_device(device); + for (auto& [b_ptr, b_size] : buffer_pool) { + aclrtFree(b_ptr); + pool_size -= b_size; + } + buffer_pool.clear(); + GGML_ASSERT(pool_size == 0); + } + + /** + * @brief Allocate a buffer of the given size. + * + * @param size The size of the buffer to allocate. + * @param actual_size A pointer to a variable to receive the actual size of + * the allocated buffer. + * @return A pointer to the allocated buffer. 
+ */ + void* alloc(size_t size, size_t* actual_size) override { + size = GGML_PAD(size, alignment); + if (size == 0) { + size = alignment; + } + + void* ptr = nullptr; + auto now = std::chrono::steady_clock::now(); + + std::vector free_buffers_rest; + free_buffers_rest.reserve(free_buffers.size()); + while (!free_buffers.empty()) { + auto b = free_buffers.top(); + free_buffers.pop(); + + if (b.size >= size) { + // reuse the buffer if the size is enough + const size_t margin = b.size - size; + if (margin <= max_reuse_margin) { + *actual_size = b.size; + ptr = b.ptr; +#ifdef DEBUG_CANN_MALLOC + GGML_LOG_INFO( + "cann pool[%d]: reused %p, " + "pool_size = %5u MB, " + "size = %5u MB, " + "margin = %5u MB\n", + device, b.ptr, + (uint32_t)(GGML_PAD(pool_size, 1048576) / 1048576), + (uint32_t)(GGML_PAD(size, 1048576) / 1048576), + (uint32_t)(GGML_PAD(margin, 1048576) / 1048576)); +#endif + break; + } + } + + bool should_clean = !disable_clean && + b.size > min_free_margin && + std::chrono::duration_cast(now - b.last_used).count() > 100; + if (should_clean) { + // free the buffer if the size is needed to be freed + ACL_CHECK(aclrtFree(b.ptr)); + pool_size -= b.size; + buffer_pool.erase(b.ptr); +#ifdef DEBUG_CANN_MALLOC + GGML_LOG_INFO( + "cann pool[%d]: clean %p, " + "pool_size = %5u MB, " + "size = %5u MB\n", + device, b.ptr, + (uint32_t)(GGML_PAD(pool_size, 1048576) / 1048576), + (uint32_t)(GGML_PAD(b.size, 1048576) / 1048576)); +#endif + continue; + } + free_buffers_rest.push_back(b); + } + for (ggml_cann_buffer &b : free_buffers_rest) { + free_buffers.push(std::move(b)); + } + +#ifdef DEBUG_CANN_MALLOC + GGML_LOG_INFO("cann pool[%d] free pool_size = %5u MB\n\n", device, (uint32_t)(GGML_PAD(pool_size, 1048576) / 1048576)); +#endif + if (ptr != nullptr) { + return ptr; + } + + // allocate a new buffer if no buffer can be reused + ggml_cann_set_device(device); + ACL_CHECK(aclrtMalloc(&ptr, size, ACL_MEM_MALLOC_HUGE_FIRST)); + *actual_size = size; + pool_size += size; +#ifdef DEBUG_CANN_MALLOC + GGML_LOG_INFO( + "cann pool[%d]: allocate %p, " + "pool_size = %5u MB, " + "size = %5u MB\n", + device, ptr, (uint32_t)(GGML_PAD(pool_size, 1048576) / 1048576), + (uint32_t)(GGML_PAD(size, 1048576) / 1048576)); +#endif + buffer_pool.emplace(ptr, size); + return ptr; + } + + /** + * @brief Free a buffer and return it to the pool. + * + * @param ptr Pointer to the buffer to free. + * @param size Size of the buffer to free. + */ + void free(void* ptr, size_t size) override { + GGML_UNUSED(size); + auto it = buffer_pool.find(ptr); + if (it == buffer_pool.end()) { + GGML_ABORT("cann pool[%d]: buffer %p not found in pool\n", device, ptr); + } + + auto now = std::chrono::steady_clock::now(); + free_buffers.emplace(ggml_cann_buffer{ptr, it->second, now}); +#ifdef DEBUG_CANN_MALLOC + GGML_LOG_INFO( + "cann pool[%d]: return %p, " + "pool_size = %5u MB\n", + device, ptr, + (uint32_t)(GGML_PAD(pool_size, 1048576) / 1048576)); +#endif + } +}; + +/** + * @brief A pool of CANN buffers(segment buffer). + * + * This class manages a pool of CANN buffers for a specific device. + */ +struct ggml_cann_pool_buf : public ggml_cann_pool { + /** + * @brief The maximum reuse margin for a buffer. + */ + static const size_t max_reuse_margin = 1ull << 22; // 4MB + + /** + * @brief The minimum free margin for a buffer. + */ + static const size_t min_free_margin = 1ull << 20; // 1MB + + /** + * @brief The alignment for buffer allocation. 
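The allocation path above boils down to two thresholds: reuse a cached buffer when the leftover margin is small enough, and free cached buffers that are both large and stale. A standalone sketch of that policy, with the constants copied from the patch:

```cpp
#include <chrono>
#include <cstddef>

// Not the actual pool code; just the reuse/clean decision it implements.
struct cached_buf {
    size_t size;
    std::chrono::steady_clock::time_point last_used;
};

static const size_t max_reuse_margin = 1ull << 22;  // 4MB
static const size_t min_free_margin  = 1ull << 20;  // 1MB

static bool can_reuse(const cached_buf& b, size_t request) {
    // big enough, but not wastefully bigger than the request
    return b.size >= request && (b.size - request) <= max_reuse_margin;
}

static bool should_clean(const cached_buf& b,
                         std::chrono::steady_clock::time_point now) {
    // only non-trivial buffers that have sat idle for more than 100 ms
    return b.size > min_free_margin &&
           std::chrono::duration_cast<std::chrono::milliseconds>(
               now - b.last_used).count() > 100;
}
```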
+ */ + static const size_t alignment = 128; + /** * @brief The maximum number of buffers in the pool. */ @@ -163,12 +378,19 @@ struct ggml_cann_pool_leg : public ggml_cann_pool { */ int device; + /** + * @brief Whether to disable clean during buffer allocation. + */ + bool disable_clean = false; + /** * @brief Structure representing a CANN buffer. */ struct ggml_cann_buffer { void* ptr = nullptr; ///< Pointer to the buffer memory. size_t size = 0; ///< Size of the buffer. + bool used = false; ///< Whether the buffer is currently in use. + std::chrono::steady_clock::time_point last_used; ///< Last used time. }; /** @@ -186,17 +408,19 @@ struct ggml_cann_pool_leg : public ggml_cann_pool { * * @param device The device ID to associate with this buffer pool. */ - explicit ggml_cann_pool_leg(int device) : device(device) {} + explicit ggml_cann_pool_buf(int device) : device(device) { + disable_clean = getenv("GGML_CANN_DISABLE_BUF_POOL_CLEAN") != nullptr; + } /** * @brief Destructor to free all buffers in the pool. */ - ~ggml_cann_pool_leg() { + ~ggml_cann_pool_buf() { ggml_cann_set_device(device); for (int i = 0; i < MAX_BUFFERS; ++i) { ggml_cann_buffer& b = buffer_pool[i]; if (b.ptr != nullptr) { - ACL_CHECK(aclrtFree(b.ptr)); + aclrtFree(b.ptr); pool_size -= b.size; } } @@ -212,63 +436,93 @@ struct ggml_cann_pool_leg : public ggml_cann_pool { * @return A pointer to the allocated buffer. */ void* alloc(size_t size, size_t* actual_size) override { - const size_t alignment = 128; size = GGML_PAD(size, alignment); if (size == 0) { size = alignment; } -#ifdef DEBUG_CANN_MALLOC - int nnz = 0; - size_t max_size = 0; -#endif - size_t best_diff = 1ull << 36; - int ibest = -1; - for (int i = 0; i < MAX_BUFFERS; ++i) { + + void* ptr = nullptr; + auto now = std::chrono::steady_clock::now(); + + int i = 0; + for (; i < MAX_BUFFERS; ++i) { ggml_cann_buffer& b = buffer_pool[i]; - if (b.ptr != nullptr) { + if (b.ptr == nullptr) { + break; + } + if (b.used) { + continue; + } + if (b.size >= size) { + // reuse the buffer if the size is enough + const size_t margin = b.size - size; + if (margin <= max_reuse_margin) { + *actual_size = b.size; + b.used = true; + ptr = b.ptr; #ifdef DEBUG_CANN_MALLOC - ++nnz; - if (b.size > max_size) max_size = b.size; + GGML_LOG_INFO( + "cann pool[%d]: reused %p, " + "pool_size = %5u MB, " + "size = %5u MB, " + "margin = %5u MB\n", + device, b.ptr, + (uint32_t)(GGML_PAD(pool_size, 1048576) / 1048576), + (uint32_t)(GGML_PAD(size, 1048576) / 1048576), + (uint32_t)(GGML_PAD(margin, 1048576) / 1048576)); #endif - if (b.size >= size) { - size_t diff = b.size - size; - if (diff < best_diff) { - best_diff = diff; - ibest = i; - if (!best_diff) { - void* ptr = b.ptr; - *actual_size = b.size; - b.ptr = nullptr; - b.size = 0; - return ptr; - } - } + break; } } + + bool should_clean = !disable_clean && + b.size > min_free_margin && + std::chrono::duration_cast(now - b.last_used).count() > 100; + if (should_clean) { + // free the buffer if the size is needed to be freed + ACL_CHECK(aclrtFree(b.ptr)); + pool_size -= b.size; +#ifdef DEBUG_CANN_MALLOC + GGML_LOG_INFO( + "cann pool[%d]: clean %p, " + "pool_size = %5u MB, " + "size = %5u MB\n", + device, b.ptr, + (uint32_t)(GGML_PAD(pool_size, 1048576) / 1048576), + (uint32_t)(GGML_PAD(b.size, 1048576) / 1048576)); +#endif + b.ptr = nullptr; + } } - if (ibest >= 0) { - ggml_cann_buffer& b = buffer_pool[ibest]; - void* ptr = b.ptr; - *actual_size = b.size; - b.ptr = nullptr; - b.size = 0; + if (ptr != nullptr) { return ptr; } - void* ptr; - 
ggml_cann_set_device(device); - ACL_CHECK( - aclrtMalloc(&ptr, size, ACL_MEM_MALLOC_HUGE_FIRST)); - *actual_size = size; - pool_size += size; + + if (i < MAX_BUFFERS) { + // allocate a new buffer if no buffer can be reused + ggml_cann_buffer& b = buffer_pool[i]; + ggml_cann_set_device(device); + ACL_CHECK(aclrtMalloc(&b.ptr, size, ACL_MEM_MALLOC_HUGE_FIRST)); + pool_size += size; + *actual_size = size; + b.size = size; + b.used = true; + if (i >= MAX_BUFFERS - 8) { + GGML_LOG_WARN("cann pool[%d]: slots almost full\n", device); + } #ifdef DEBUG_CANN_MALLOC - GGML_LOG_INFO( - "%s[%d]: %d buffers, max_size = %u MB, pool_size = %u MB, " - "requested %u MB\n", - __func__, device, nnz, (uint32_t)(max_size / 1024 / 1024), - (uint32_t)(pool_size / 1024 / 1024), - (uint32_t)(size / 1024 / 1024)); + GGML_LOG_INFO( + "cann pool[%d]: allocate %p, " + "pool_size = %5u MB, " + "size = %5u MB\n", + device, b.ptr, + (uint32_t)(GGML_PAD(pool_size, 1048576) / 1048576), + (uint32_t)(GGML_PAD(b.size, 1048576) / 1048576)); #endif - return ptr; + return b.ptr; + } + + GGML_ABORT("cann pool[%d]: slots full\n", device); } /** @@ -278,18 +532,24 @@ struct ggml_cann_pool_leg : public ggml_cann_pool { * @param size Size of the buffer to free. */ void free(void* ptr, size_t size) override { + GGML_UNUSED(size); for (int i = 0; i < MAX_BUFFERS; ++i) { ggml_cann_buffer& b = buffer_pool[i]; - if (b.ptr == nullptr) { - b.ptr = ptr; - b.size = size; - return; + if (b.ptr != ptr) { + continue; } + b.used = false; + b.last_used = std::chrono::steady_clock::now(); +#ifdef DEBUG_CANN_MALLOC + GGML_LOG_INFO( + "cann pool[%d]: return %p, " + "pool_size = %5u MB\n", + device, b.ptr, + (uint32_t)(GGML_PAD(pool_size, 1048576) / 1048576)); +#endif + return; } - // memory should always buffered. these memory may still needed by - // tasks in stream. - // TODO, fix me. - GGML_ABORT("Cann buffer pool full, increase MAX_CANN_BUFFERS\n"); + GGML_ABORT("cann pool[%d]: slots full\n", device); } }; @@ -347,8 +607,7 @@ struct ggml_cann_pool_vmm : public ggml_cann_pool { * @param device The device ID to associate with this buffer pool. */ explicit ggml_cann_pool_vmm(int device) - : device(device), - granularity(ggml_cann_info().devices[device].vmm_granularity) { + : device(device) { auto dev = ggml_cann_info().devices[device]; granularity = dev.vmm_granularity; max_size = dev.total_vram; @@ -471,7 +730,18 @@ struct ggml_cann_pool_vmm : public ggml_cann_pool { */ std::unique_ptr ggml_backend_cann_context::new_pool_for_device( int device) { - return std::unique_ptr(new ggml_cann_pool_vmm(device)); + bool disable_vmm = (getenv("GGML_CANN_DISABLE_VMM_POOL") != nullptr); + if (!disable_vmm && ggml_cann_info().devices[device].vmm) { + GGML_LOG_INFO("%s: device %d use vmm pool\n", __func__, device); + return std::unique_ptr(new ggml_cann_pool_vmm(device)); + } + bool enable_buf_prio = (getenv("GGML_CANN_ENABLE_BUF_PRIO_POOL") != nullptr); + if (enable_buf_prio) { + GGML_LOG_INFO("%s: device %d use buffer pool with priority queue\n", __func__, device); + return std::unique_ptr(new ggml_cann_pool_buf_prio(device)); + } + GGML_LOG_INFO("%s: device %d use buffer pool\n", __func__, device); + return std::unique_ptr(new ggml_cann_pool_buf(device)); } // cann buffer @@ -803,7 +1073,7 @@ static enum ggml_status ggml_backend_cann_buffer_init_tensor( return GGML_STATUS_SUCCESS; } - // TODO: can backend doesn't support quantized yet. Just leave the code + // TODO: cann backend doesn't support quantized yet. Just leave the code // here. 
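`new_pool_for_device` above now picks between three pool implementations, controlled by device capability and two environment variables. A compact sketch of that precedence (returning labels instead of pool objects; illustrative only):

```cpp
#include <cstdlib>
#include <string>

static std::string select_cann_pool(bool device_supports_vmm) {
    const bool disable_vmm = std::getenv("GGML_CANN_DISABLE_VMM_POOL") != nullptr;
    if (!disable_vmm && device_supports_vmm) {
        return "vmm";       // ggml_cann_pool_vmm
    }
    if (std::getenv("GGML_CANN_ENABLE_BUF_PRIO_POOL") != nullptr) {
        return "buf_prio";  // ggml_cann_pool_buf_prio
    }
    return "buf";           // ggml_cann_pool_buf (default fallback)
}
```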
if (ggml_is_quantized(tensor->type)) { // Initialize padding to 0 to avoid possible NaN values @@ -1020,8 +1290,11 @@ ggml_backend_cann_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, ggml_cann_set_device(buft_ctx->device); - size = std::max(size, (size_t)1); - + const size_t alignment = 128; + size = GGML_PAD(size, alignment); + if (size == 0) { + size = alignment; + } void* dev_ptr; aclError err = aclrtMalloc(&dev_ptr, size, ACL_MEM_MALLOC_HUGE_FIRST); if (err != ACL_SUCCESS) { @@ -1300,47 +1573,69 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx, ggml_cann_dup(ctx, dst); break; case GGML_OP_ADD: - ggml_cann_add(ctx, dst); + case GGML_OP_ADD1: + ggml_cann_binary_op(ctx, dst); + break; + case GGML_OP_SUB: + ggml_cann_binary_op(ctx, dst); break; case GGML_OP_ACC: ggml_cann_acc(ctx, dst); break; case GGML_OP_MUL: - ggml_cann_mul_div(ctx, dst); + ggml_cann_binary_op(ctx, dst); break; case GGML_OP_DIV: - ggml_cann_mul_div(ctx, dst); + ggml_cann_binary_op(ctx, dst); break; case GGML_OP_UNARY: switch (ggml_get_unary_op(dst)) { + case GGML_UNARY_OP_ABS: + GGML_CANN_CALL_UNARY_OP(Abs); + break; + case GGML_UNARY_OP_NEG: + GGML_CANN_CALL_UNARY_OP(Neg); + break; case GGML_UNARY_OP_GELU: - ggml_cann_activation( - ctx, dst); + GGML_CANN_CALL_UNARY_OP(Gelu); break; case GGML_UNARY_OP_SILU: - ggml_cann_activation( - ctx, dst); - break; - // TODO: Use faster gelu?? - case GGML_UNARY_OP_GELU_QUICK: - ggml_cann_activation( - ctx, dst); + GGML_CANN_CALL_UNARY_OP(Silu); break; + case GGML_UNARY_OP_GELU_QUICK: { + auto lambda = [](ggml_backend_cann_context& ctx, + aclTensor* acl_src, + aclTensor* acl_dst) { + GGML_CANN_CALL_ACLNN_OP(ctx, GeluV2, acl_src, 0, acl_dst); + }; + ggml_cann_unary_op(lambda, ctx, dst); + } break; case GGML_UNARY_OP_TANH: - ggml_cann_activation( - ctx, dst); + GGML_CANN_CALL_UNARY_OP(Tanh); break; case GGML_UNARY_OP_RELU: - ggml_cann_activation( - ctx, dst); + GGML_CANN_CALL_UNARY_OP(Relu); + break; + case GGML_UNARY_OP_SIGMOID: + GGML_CANN_CALL_UNARY_OP(Sigmoid); break; case GGML_UNARY_OP_HARDSIGMOID: - ggml_cann_activation(ctx, dst); + GGML_CANN_CALL_UNARY_OP(Hardsigmoid); break; case GGML_UNARY_OP_HARDSWISH: - ggml_cann_activation(ctx, dst); + GGML_CANN_CALL_UNARY_OP(Hardswish); + break; + case GGML_UNARY_OP_EXP: + GGML_CANN_CALL_UNARY_OP(Exp); + break; + case GGML_UNARY_OP_ELU: + ggml_cann_elu(ctx, dst); + break; + case GGML_UNARY_OP_SGN: + GGML_CANN_CALL_UNARY_OP(Sign); + break; + case GGML_UNARY_OP_STEP: + ggml_cann_step(ctx, dst); break; default: return false; @@ -1382,7 +1677,12 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx, ggml_cann_scale(ctx, dst); break; case GGML_OP_SQR: - ggml_cann_sqr(ctx, dst); + GGML_ASSERT(dst->src[1] == nullptr); + dst->src[1] = dst->src[0]; + ggml_cann_binary_op(ctx, dst); + break; + case GGML_OP_SQRT: + GGML_CANN_CALL_UNARY_OP(Sqrt); break; case GGML_OP_CLAMP: ggml_cann_clamp(ctx, dst); @@ -1414,12 +1714,39 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx, case GGML_OP_POOL_2D: ggml_cann_pool2d(ctx, dst); break; + case GGML_OP_SUM: + ggml_cann_sum(ctx, dst); + break; case GGML_OP_SUM_ROWS: ggml_cann_sum_rows(ctx, dst); break; case GGML_OP_ARGSORT: ggml_cann_argsort(ctx, dst); break; + case GGML_OP_ARGMAX: + ggml_cann_argmax(ctx, dst); + break; + case GGML_OP_COS: + ggml_cann_unary_op(ctx, dst); + break; + case GGML_OP_SIN: + ggml_cann_unary_op(ctx, dst); + break; + case GGML_OP_CONV_TRANSPOSE_1D: + ggml_cann_conv_transpose_1d(ctx, dst); + break; + case GGML_OP_LOG: 
+ GGML_CANN_CALL_UNARY_OP(Log); + break; + case GGML_OP_MEAN: + ggml_cann_mean(ctx, dst); + break; + case GGML_OP_PAD_REFLECT_1D: + ggml_cann_pad_reflect_1d(ctx, dst); + break; + case GGML_OP_COUNT_EQUAL: + ggml_cann_count_equal(ctx, dst); + break; default: return false; } @@ -1458,21 +1785,15 @@ static void ggml_backend_cann_free(ggml_backend_t backend) { ACL_CHECK(aclrtSynchronizeDevice()); ACL_CHECK(aclrtResetDevice(cann_ctx->device)); - // finalize when last backend freed. - if (cann_ctx->device == ggml_backend_cann_get_device_count() - 1) { - ACL_CHECK(aclFinalize()); - } - delete cann_ctx; delete backend; } + /** * @brief Sets tensor data asynchronously in the CANN backend. * - * This function asynchronously sets tensor data in the CANN backend. Depending - * on the tensor type, it may perform data transformations before copying data - * to the device. + * This function asynchronously sets tensor data in the CANN backend. * * @param backend Pointer to the CANN backend structure. * @param tensor Pointer to the tensor structure to set data for. @@ -1487,23 +1808,28 @@ static void ggml_backend_cann_set_tensor_async(ggml_backend_t backend, size_t size) { ggml_backend_cann_context *cann_ctx = (ggml_backend_cann_context *)backend->context; + ggml_backend_buffer_t buf = + tensor->view_src ? tensor->view_src->buffer : tensor->buffer; - if (!need_transform(tensor->type)) { - ACL_CHECK(aclrtMemcpyAsync((char *)tensor->data + offset, size, data, - size, ACL_MEMCPY_HOST_TO_DEVICE, - cann_ctx->stream())); - } else { - void *transform_buffer = malloc(size); - ggml_backend_cann_transform(tensor, data, transform_buffer); + GGML_ASSERT(buf->buft == ggml_backend_cann_buffer_type(cann_ctx->device) && + "unsupported buffer type"); + GGML_ASSERT(!ggml_is_quantized(tensor->type)); - ACL_CHECK(aclrtMemcpyAsync( - (char *)tensor->data + offset, size, transform_buffer, size, - ACL_MEMCPY_HOST_TO_DEVICE, cann_ctx->stream())); - ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream())); - free(transform_buffer); - } + ggml_cann_async_memcpy(cann_ctx, (char *)tensor->data + offset, data, size, + ACL_MEMCPY_HOST_TO_DEVICE); } +/** + * @brief Gets tensor data asynchronously in the CANN backend. + * + * This function asynchronously gets tensor data in the CANN backend. + * + * @param backend Pointer to the CANN backend structure. + * @param tensor Pointer to the tensor structure to get data from. + * @param data Pointer to the host data to copy from the tensor. + * @param offset Offset in bytes within the host data. + * @param size Size of the data to copy in bytes. 
+ */ static void ggml_backend_cann_get_tensor_async( ggml_backend_t backend, const ggml_tensor *tensor, void *data, size_t offset, size_t size) { @@ -1514,20 +1840,11 @@ static void ggml_backend_cann_get_tensor_async( GGML_ASSERT(buf->buft == ggml_backend_cann_buffer_type(cann_ctx->device) && "unsupported buffer type"); + GGML_ASSERT(!ggml_is_quantized(tensor->type)); + + ggml_cann_async_memcpy(cann_ctx, data, (char *)tensor->data + offset, size, + ACL_MEMCPY_DEVICE_TO_HOST); - if (!need_transform(tensor->type)) { - ACL_CHECK(aclrtMemcpyAsync(data, size, (char *)tensor->data + offset, - size, ACL_MEMCPY_DEVICE_TO_HOST, - cann_ctx->stream())); - } else { - void *transform_buffer = malloc(size); - ACL_CHECK(aclrtMemcpyAsync( - transform_buffer, size, (char *)tensor->data + offset, size, - ACL_MEMCPY_DEVICE_TO_HOST, cann_ctx->stream())); - ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream())); - ggml_backend_cann_transform_back(tensor, transform_buffer, data); - free(transform_buffer); - } } /** @@ -1587,6 +1904,8 @@ static bool ggml_backend_cann_cpy_tensor_async( ggml_cann_set_device(cann_ctx_src->device); ACL_CHECK(aclrtDeviceEnablePeerAccess(cann_ctx_dst->device, 0)); + // wait for task_queue empty to keep task order. + cann_ctx_src->task_queue.wait(); ACL_CHECK(aclrtMemcpyAsync(dst->data, copy_size, src->data, copy_size, ACL_MEMCPY_DEVICE_TO_DEVICE, cann_ctx_src->stream())); @@ -1614,9 +1933,8 @@ static bool ggml_backend_cann_cpy_tensor_async( static void ggml_backend_cann_synchronize(ggml_backend_t backend) { ggml_backend_cann_context* cann_ctx = (ggml_backend_cann_context*)backend->context; - + cann_ctx->task_queue.wait(); ggml_cann_set_device(cann_ctx->device); - ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream())); } @@ -1675,24 +1993,38 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, switch (op->op) { case GGML_OP_UNARY: switch (ggml_get_unary_op(op)) { + case GGML_UNARY_OP_ABS: + case GGML_UNARY_OP_NEG: case GGML_UNARY_OP_GELU: case GGML_UNARY_OP_SILU: case GGML_UNARY_OP_RELU: + case GGML_UNARY_OP_SIGMOID: case GGML_UNARY_OP_HARDSIGMOID: case GGML_UNARY_OP_HARDSWISH: case GGML_UNARY_OP_GELU_QUICK: case GGML_UNARY_OP_TANH: + case GGML_UNARY_OP_EXP: + case GGML_UNARY_OP_ELU: + case GGML_UNARY_OP_SGN: + case GGML_UNARY_OP_STEP: return true; default: return false; } case GGML_OP_MUL_MAT: { switch (op->src[0]->type) { - case GGML_TYPE_Q8_0: case GGML_TYPE_F16: case GGML_TYPE_F32: - case GGML_TYPE_Q4_0: return true; + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q4_0: +#ifdef ASCEND_310P + // Q4 && Q8 per group is not suppor on 310p device + return false; +#endif + // only support contiguous for quantized types. 
+ return ggml_is_contiguous(op->src[0]) && + ggml_is_contiguous(op->src[1]); default: return false; } @@ -1738,13 +2070,14 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, } case GGML_OP_ROPE: { // TODO: with ops-test v == 1 - float * ext_factor = (float*)((int32_t*)op->op_params + 7); + float ext_factor = 0.0f; + memcpy(&ext_factor, (const float *) op->op_params + 7, sizeof(float)); // TODO: n_dims <= ne0 if (op->src[0]->ne[0] != op->op_params[1]) { return false; } // TODO: ext_factor != 0 - if (*ext_factor != 0) { + if (ext_factor != 0) { return false; } @@ -1756,6 +2089,9 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, return false; } + if(!ggml_is_contiguous(op->src[0])){ + return false; + } return true; } case GGML_OP_UPSCALE: { @@ -1764,8 +2100,28 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, if (op->src[0]->ne[2] * op->ne[3] != op->src[0]->ne[3] * op->ne[2]) { return false; } + if (op->op_params[0] != GGML_SCALE_MODE_NEAREST) { + return false; + } return true; } + case GGML_OP_POOL_2D: { + const int32_t * opts = (const int32_t *) op->op_params; +#ifdef ASCEND_310P + enum ggml_op_pool opt = static_cast(opts[0]); + if(opt == GGML_OP_POOL_MAX){ + return false; + } +#endif + const int k0 = opts[1]; + const int k1 = opts[2]; + const int p0 = opts[5]; + const int p1 = opts[6]; + // value of paddingH should be at most half of kernelH + // value of paddingW should be at most half of kernelW + return (p0 <= (k0 / 2)) && (p1 <= (k1 / 2)); + } + case GGML_OP_SUM: case GGML_OP_DUP: case GGML_OP_IM2COL: case GGML_OP_CONCAT: @@ -1777,15 +2133,17 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, case GGML_OP_TRANSPOSE: case GGML_OP_NORM: case GGML_OP_ADD: + case GGML_OP_ADD1: + case GGML_OP_SUB: case GGML_OP_MUL: case GGML_OP_DIV: case GGML_OP_RMS_NORM: case GGML_OP_SCALE: case GGML_OP_SQR: + case GGML_OP_SQRT: case GGML_OP_CLAMP: case GGML_OP_DIAG_MASK_INF: case GGML_OP_SOFT_MAX: - case GGML_OP_POOL_2D: case GGML_OP_SUM_ROWS: case GGML_OP_ARGSORT: case GGML_OP_ACC: @@ -1794,6 +2152,14 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, case GGML_OP_ARANGE: case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_LEAKY_RELU: + case GGML_OP_ARGMAX: + case GGML_OP_COS: + case GGML_OP_SIN: + case GGML_OP_CONV_TRANSPOSE_1D: + case GGML_OP_LOG: + case GGML_OP_MEAN: + case GGML_OP_PAD_REFLECT_1D: + case GGML_OP_COUNT_EQUAL: return true; default: return false; diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt index e73a3b69b5d..6a652738c10 100644 --- a/ggml/src/ggml-cpu/CMakeLists.txt +++ b/ggml/src/ggml-cpu/CMakeLists.txt @@ -222,7 +222,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name) elseif (GGML_AVX) list(APPEND ARCH_FLAGS /arch:AVX) list(APPEND ARCH_DEFINITIONS GGML_AVX) - else () + elseif (GGML_SSE42) list(APPEND ARCH_FLAGS /arch:SSE4.2) list(APPEND ARCH_DEFINITIONS GGML_SSE42) endif() @@ -237,8 +237,10 @@ function(ggml_add_cpu_backend_variant_impl tag_name) if (GGML_NATIVE) list(APPEND ARCH_FLAGS -march=native) else () - list(APPEND ARCH_FLAGS -msse4.2) - list(APPEND ARCH_DEFINITIONS GGML_SSE42) + if (GGML_SSE42) + list(APPEND ARCH_FLAGS -msse4.2) + list(APPEND ARCH_DEFINITIONS GGML_SSE42) + endif() if (GGML_F16C) list(APPEND ARCH_FLAGS -mf16c) list(APPEND ARCH_DEFINITIONS GGML_F16C) diff --git a/ggml/src/ggml-cpu/cpu-feats-x86.cpp b/ggml/src/ggml-cpu/cpu-feats-x86.cpp index 902ee434666..d775a036385 100644 --- a/ggml/src/ggml-cpu/cpu-feats-x86.cpp +++ 
b/ggml/src/ggml-cpu/cpu-feats-x86.cpp @@ -263,7 +263,7 @@ void test_x86_is() { static int ggml_backend_cpu_x86_score() { // FIXME: this does not check for OS support - int score = 0; + int score = 1; cpuid_x86 is; #ifdef GGML_FMA diff --git a/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp b/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp index 74a31abb2d6..175cba329b7 100644 --- a/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +++ b/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp @@ -183,67 +183,63 @@ static inline __m256 __avx_rearranged_f32cx8_load(ggml_fp16_t *x, __m128i arrang #if defined(__AVX2__) || defined(__AVX512F__) #if defined(__AVX512F__) -// add int16_t pairwise and return as 512 bit int vector -static inline __m512i sum_i16_pairs_int_32x16(const __m512i x) { +// add int16_t pairwise and return as 512 bit int vector, then add the accumulator +static inline __m512i sum_i16_pairs_acc_int32x16(const __m512i acc, const __m512i x) { const __m512i ones = _mm512_set1_epi16(1); - return _mm512_madd_epi16(ones, x); + return _mm512_add_epi32(acc, _mm512_madd_epi16(ones, x)); } -static inline __m512i mul_sum_us8_pairs_int32x16(const __m512i ax, const __m512i sy) { +static inline __m512i mul_sum_us8_pairs_acc_int32x16(const __m512i acc, const __m512i ax, const __m512i sy) { #if defined(__AVX512VNNI__) - const __m512i zero = _mm512_setzero_si512(); - return _mm512_dpbusd_epi32(zero, ax, sy); + return _mm512_dpbusd_epi32(acc, ax, sy); #else // Perform multiplication and create 16-bit values const __m512i dot = _mm512_maddubs_epi16(ax, sy); - return sum_i16_pairs_int_32x16(dot); + return sum_i16_pairs_acc_int32x16(acc, dot); #endif } -// multiply int8_t, add results pairwise twice and return as 512 bit int vector -static inline __m512i mul_sum_i8_pairs_int32x16(const __m512i x, const __m512i y) { +// multiply int8_t, add results pairwise twice and return as 512 bit int vector,then add the accumulator +static inline __m512i mul_sum_i8_pairs_acc_int32x16(const __m512i acc, const __m512i x, const __m512i y) { const __m512i zero = _mm512_setzero_si512(); // Get absolute values of x vectors const __m512i ax = _mm512_abs_epi8(x); // Sign the values of the y vectors __mmask64 blt0 = _mm512_movepi8_mask(x); const __m512i sy = _mm512_mask_sub_epi8(y, blt0, zero, y); - return mul_sum_us8_pairs_int32x16(ax, sy); + return mul_sum_us8_pairs_acc_int32x16(acc, ax, sy); } #endif -// add int16_t pairwise and return as 256 bit int vector -static inline __m256i sum_i16_pairs_int32x8(const __m256i x) { +// add int16_t pairwise and return as 256 bit int vector, then add the accumulator +static inline __m256i sum_i16_pairs_acc_int32x8(const __m256i acc, const __m256i x) { const __m256i ones = _mm256_set1_epi16(1); - return _mm256_madd_epi16(ones, x); + return _mm256_add_epi32(acc, _mm256_madd_epi16(ones, x)); } -static inline __m256i mul_sum_us8_pairs_int32x8(const __m256i ax, const __m256i sy) { +static inline __m256i mul_sum_us8_pairs_acc_int32x8(const __m256i acc, const __m256i ax, const __m256i sy) { #if defined(__AVX512VNNI__) && defined(__AVX512VL__) - const __m256i zero = _mm256_setzero_si256(); - return _mm256_dpbusd_epi32(zero, ax, sy); + return _mm256_dpbusd_epi32(acc, ax, sy); #elif defined(__AVXVNNI__) - const __m256i zero = _mm256_setzero_si256(); - return _mm256_dpbusd_avx_epi32(zero, ax, sy); + return _mm256_dpbusd_avx_epi32(acc, ax, sy); #else // Perform multiplication and create 16-bit values const __m256i dot = _mm256_maddubs_epi16(ax, sy); - return sum_i16_pairs_int32x8(dot); + return sum_i16_pairs_acc_int32x8(acc, dot); #endif 
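The refactor above folds the running accumulator into the dot-product helpers so that VNNI targets can use the accumulating form of `dpbusd`/`dpbssd` directly. Per 32-bit lane, each of these helpers computes the same quantity; a scalar reference (ignoring the 16-bit saturation the `maddubs` fallback can hit):

```cpp
#include <cstdint>

// Scalar model of mul_sum_i8_pairs_acc_int32x8 / _int32x16 for one lane:
// acc += sum over four packed int8 products.
static int32_t mul_sum_i8_pairs_acc_lane(int32_t acc,
                                         const int8_t x[4],
                                         const int8_t y[4]) {
    for (int i = 0; i < 4; ++i) {
        acc += (int32_t) x[i] * (int32_t) y[i];
    }
    return acc;
}
```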
} // Integer variant of the function defined in ggml-quants.c -// multiply int8_t, add results pairwise twice and return as 256 bit int vector -static inline __m256i mul_sum_i8_pairs_int32x8(const __m256i x, const __m256i y) { -#if __AVXVNNIINT8__ - const __m256i zero = _mm256_setzero_si256(); - return _mm256_dpbssd_epi32(zero, x, y); +// multiply int8_t, add results pairwise twice and return as 256 bit int vector, then add the accumulator +static inline __m256i mul_sum_i8_pairs_acc_int32x8(const __m256i acc, const __m256i x, const __m256i y) { +#if defined(__AVXVNNIINT8__) + return _mm256_dpbssd_epi32(acc, x, y); #else // Get absolute values of x vectors const __m256i ax = _mm256_sign_epi8(x, x); // Sign the values of the y vectors const __m256i sy = _mm256_sign_epi8(y, x); - return mul_sum_us8_pairs_int32x8(ax, sy); + return mul_sum_us8_pairs_acc_int32x8(acc, ax, sy); #endif } #endif @@ -1175,17 +1171,17 @@ static void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, c // ........................................................................... // B0(28-31) B4(28-31) B1(28-31) B5(28-31) B2(28-31) B6(28-31) B3(28-31) B7(28-31) with A0(28-31) - iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(rhs_vec_0123_0 ,_mm256_shuffle_epi32(rhs_vec_4567_0, 177), 170), _mm256_shuffle_epi32(lhs_vec_0, 0))); - iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_0, 177) ,rhs_vec_4567_0, 170), _mm256_shuffle_epi32(lhs_vec_0, 85))); + iacc = mul_sum_i8_pairs_acc_int32x8(iacc, _mm256_blend_epi32(rhs_vec_0123_0 ,_mm256_shuffle_epi32(rhs_vec_4567_0, 177), 170), _mm256_shuffle_epi32(lhs_vec_0, 0)); + iacc = mul_sum_i8_pairs_acc_int32x8(iacc, _mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_0, 177) ,rhs_vec_4567_0, 170), _mm256_shuffle_epi32(lhs_vec_0, 85)); - iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(rhs_vec_0123_1 ,_mm256_shuffle_epi32(rhs_vec_4567_1, 177), 170), _mm256_shuffle_epi32(lhs_vec_0, 170))); - iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_1, 177) ,rhs_vec_4567_1, 170), _mm256_shuffle_epi32(lhs_vec_0, 255))); + iacc = mul_sum_i8_pairs_acc_int32x8(iacc, _mm256_blend_epi32(rhs_vec_0123_1 ,_mm256_shuffle_epi32(rhs_vec_4567_1, 177), 170), _mm256_shuffle_epi32(lhs_vec_0, 170)); + iacc = mul_sum_i8_pairs_acc_int32x8(iacc, _mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_1, 177) ,rhs_vec_4567_1, 170), _mm256_shuffle_epi32(lhs_vec_0, 255)); - iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(rhs_vec_0123_2 ,_mm256_shuffle_epi32(rhs_vec_4567_2, 177), 170), _mm256_shuffle_epi32(lhs_vec_1, 0))); - iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_2, 177) ,rhs_vec_4567_2, 170), _mm256_shuffle_epi32(lhs_vec_1, 85))); + iacc = mul_sum_i8_pairs_acc_int32x8(iacc, _mm256_blend_epi32(rhs_vec_0123_2 ,_mm256_shuffle_epi32(rhs_vec_4567_2, 177), 170), _mm256_shuffle_epi32(lhs_vec_1, 0)); + iacc = mul_sum_i8_pairs_acc_int32x8(iacc, _mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_2, 177) ,rhs_vec_4567_2, 170), _mm256_shuffle_epi32(lhs_vec_1, 85)); - iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(rhs_vec_0123_3 ,_mm256_shuffle_epi32(rhs_vec_4567_3, 177), 170), _mm256_shuffle_epi32(lhs_vec_1, 170))); - iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_3, 
177) ,rhs_vec_4567_3, 170), _mm256_shuffle_epi32(lhs_vec_1, 255))); + iacc = mul_sum_i8_pairs_acc_int32x8(iacc, _mm256_blend_epi32(rhs_vec_0123_3 ,_mm256_shuffle_epi32(rhs_vec_4567_3, 177), 170), _mm256_shuffle_epi32(lhs_vec_1, 170)); + iacc = mul_sum_i8_pairs_acc_int32x8(iacc, _mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_3, 177) ,rhs_vec_4567_3, 170), _mm256_shuffle_epi32(lhs_vec_1, 255)); // Accumulated values multipled with appropriate scales acc_row = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc), _mm256_mul_ps(col_scale_f32, row_scale_f32), acc_row); @@ -3239,22 +3235,15 @@ static void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, c // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane // Resembles MMLAs into 2x2 matrices in ARM Version - __m512i iacc_mat_00_sp1 = - _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp1, rhs_mat_014589CD_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp1, rhs_mat_014589CD_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp1, rhs_mat_014589CD_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp1, rhs_mat_014589CD_0_sp1)); - __m512i iacc_mat_01_sp1 = - _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp1, rhs_mat_2367ABEF_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp1, rhs_mat_2367ABEF_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp1, rhs_mat_2367ABEF_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp1, rhs_mat_2367ABEF_0_sp1)); - __m512i iacc_mat_10_sp1 = - _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp1, rhs_mat_014589CD_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp1, rhs_mat_014589CD_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp1, rhs_mat_014589CD_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp1, rhs_mat_014589CD_0_sp1)); - __m512i iacc_mat_11_sp1 = - _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp1, rhs_mat_2367ABEF_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp1, rhs_mat_2367ABEF_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp1, rhs_mat_2367ABEF_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp1, rhs_mat_2367ABEF_0_sp1)); - __m512i iacc_mat_00_sp2 = - _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp2, rhs_mat_014589CD_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp2, rhs_mat_014589CD_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp2, rhs_mat_014589CD_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp2, rhs_mat_014589CD_0_sp2)); - __m512i iacc_mat_01_sp2 = - _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp2, rhs_mat_2367ABEF_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp2, rhs_mat_2367ABEF_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp2, rhs_mat_2367ABEF_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp2, rhs_mat_2367ABEF_0_sp2)); - __m512i iacc_mat_10_sp2 = - _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp2, rhs_mat_014589CD_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp2, rhs_mat_014589CD_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp2, rhs_mat_014589CD_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp2, rhs_mat_014589CD_0_sp2)); - __m512i iacc_mat_11_sp2 = - 
_mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp2, rhs_mat_2367ABEF_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp2, rhs_mat_2367ABEF_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp2, rhs_mat_2367ABEF_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp2, rhs_mat_2367ABEF_0_sp2)); + const __m512i zero = _mm512_setzero_epi32(); + __m512i iacc_mat_00_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp1, rhs_mat_014589CD_3_sp1), lhs_mat_01_2_sp1, rhs_mat_014589CD_2_sp1), lhs_mat_01_1_sp1, rhs_mat_014589CD_1_sp1), lhs_mat_01_0_sp1, rhs_mat_014589CD_0_sp1); + __m512i iacc_mat_01_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp1, rhs_mat_2367ABEF_3_sp1), lhs_mat_01_2_sp1, rhs_mat_2367ABEF_2_sp1), lhs_mat_01_1_sp1, rhs_mat_2367ABEF_1_sp1), lhs_mat_01_0_sp1, rhs_mat_2367ABEF_0_sp1); + __m512i iacc_mat_10_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp1, rhs_mat_014589CD_3_sp1), lhs_mat_23_2_sp1, rhs_mat_014589CD_2_sp1), lhs_mat_23_1_sp1, rhs_mat_014589CD_1_sp1), lhs_mat_23_0_sp1, rhs_mat_014589CD_0_sp1); + __m512i iacc_mat_11_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp1, rhs_mat_2367ABEF_3_sp1), lhs_mat_23_2_sp1, rhs_mat_2367ABEF_2_sp1), lhs_mat_23_1_sp1, rhs_mat_2367ABEF_1_sp1), lhs_mat_23_0_sp1, rhs_mat_2367ABEF_0_sp1); + __m512i iacc_mat_00_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp2, rhs_mat_014589CD_3_sp2), lhs_mat_01_2_sp2, rhs_mat_014589CD_2_sp2), lhs_mat_01_1_sp2, rhs_mat_014589CD_1_sp2), lhs_mat_01_0_sp2, rhs_mat_014589CD_0_sp2); + __m512i iacc_mat_01_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp2, rhs_mat_2367ABEF_3_sp2), lhs_mat_01_2_sp2, rhs_mat_2367ABEF_2_sp2), lhs_mat_01_1_sp2, rhs_mat_2367ABEF_1_sp2), lhs_mat_01_0_sp2, rhs_mat_2367ABEF_0_sp2); + __m512i iacc_mat_10_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp2, rhs_mat_014589CD_3_sp2), lhs_mat_23_2_sp2, rhs_mat_014589CD_2_sp2), lhs_mat_23_1_sp2, rhs_mat_014589CD_1_sp2), lhs_mat_23_0_sp2, rhs_mat_014589CD_0_sp2); + __m512i iacc_mat_11_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp2, rhs_mat_2367ABEF_3_sp2), lhs_mat_23_2_sp2, rhs_mat_2367ABEF_2_sp2), lhs_mat_23_1_sp2, rhs_mat_2367ABEF_1_sp2), lhs_mat_23_0_sp2, rhs_mat_2367ABEF_0_sp2); // Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block __m512i iacc_mat_00 = _mm512_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2); @@ -3430,22 +3419,15 @@ static void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, c // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane // Resembles MMLAs into 2x2 matrices in ARM Version - __m512i iacc_mat_00_sp1 = - 
_mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp1, rhs_mat_014589CD_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp1, rhs_mat_014589CD_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp1, rhs_mat_014589CD_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp1, rhs_mat_014589CD_0_sp1)); - __m512i iacc_mat_01_sp1 = - _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp1, rhs_mat_2367ABEF_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp1, rhs_mat_2367ABEF_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp1, rhs_mat_2367ABEF_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp1, rhs_mat_2367ABEF_0_sp1)); - __m512i iacc_mat_10_sp1 = - _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp1, rhs_mat_014589CD_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp1, rhs_mat_014589CD_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp1, rhs_mat_014589CD_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp1, rhs_mat_014589CD_0_sp1)); - __m512i iacc_mat_11_sp1 = - _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp1, rhs_mat_2367ABEF_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp1, rhs_mat_2367ABEF_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp1, rhs_mat_2367ABEF_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp1, rhs_mat_2367ABEF_0_sp1)); - __m512i iacc_mat_00_sp2 = - _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp2, rhs_mat_014589CD_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp2, rhs_mat_014589CD_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp2, rhs_mat_014589CD_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp2, rhs_mat_014589CD_0_sp2)); - __m512i iacc_mat_01_sp2 = - _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp2, rhs_mat_2367ABEF_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp2, rhs_mat_2367ABEF_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp2, rhs_mat_2367ABEF_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp2, rhs_mat_2367ABEF_0_sp2)); - __m512i iacc_mat_10_sp2 = - _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp2, rhs_mat_014589CD_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp2, rhs_mat_014589CD_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp2, rhs_mat_014589CD_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp2, rhs_mat_014589CD_0_sp2)); - __m512i iacc_mat_11_sp2 = - _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp2, rhs_mat_2367ABEF_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp2, rhs_mat_2367ABEF_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp2, rhs_mat_2367ABEF_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp2, rhs_mat_2367ABEF_0_sp2)); + const __m512i zero = _mm512_setzero_epi32(); + __m512i iacc_mat_00_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp1, rhs_mat_014589CD_3_sp1), lhs_mat_01_2_sp1, rhs_mat_014589CD_2_sp1), lhs_mat_01_1_sp1, rhs_mat_014589CD_1_sp1), lhs_mat_01_0_sp1, rhs_mat_014589CD_0_sp1); + __m512i iacc_mat_01_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp1, rhs_mat_2367ABEF_3_sp1), lhs_mat_01_2_sp1, rhs_mat_2367ABEF_2_sp1), lhs_mat_01_1_sp1, rhs_mat_2367ABEF_1_sp1), lhs_mat_01_0_sp1, rhs_mat_2367ABEF_0_sp1); + 
__m512i iacc_mat_10_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp1, rhs_mat_014589CD_3_sp1), lhs_mat_23_2_sp1, rhs_mat_014589CD_2_sp1), lhs_mat_23_1_sp1, rhs_mat_014589CD_1_sp1), lhs_mat_23_0_sp1, rhs_mat_014589CD_0_sp1); + __m512i iacc_mat_11_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp1, rhs_mat_2367ABEF_3_sp1), lhs_mat_23_2_sp1, rhs_mat_2367ABEF_2_sp1), lhs_mat_23_1_sp1, rhs_mat_2367ABEF_1_sp1), lhs_mat_23_0_sp1, rhs_mat_2367ABEF_0_sp1); + __m512i iacc_mat_00_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp2, rhs_mat_014589CD_3_sp2), lhs_mat_01_2_sp2, rhs_mat_014589CD_2_sp2), lhs_mat_01_1_sp2, rhs_mat_014589CD_1_sp2), lhs_mat_01_0_sp2, rhs_mat_014589CD_0_sp2); + __m512i iacc_mat_01_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp2, rhs_mat_2367ABEF_3_sp2), lhs_mat_01_2_sp2, rhs_mat_2367ABEF_2_sp2), lhs_mat_01_1_sp2, rhs_mat_2367ABEF_1_sp2), lhs_mat_01_0_sp2, rhs_mat_2367ABEF_0_sp2); + __m512i iacc_mat_10_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp2, rhs_mat_014589CD_3_sp2), lhs_mat_23_2_sp2, rhs_mat_014589CD_2_sp2), lhs_mat_23_1_sp2, rhs_mat_014589CD_1_sp2), lhs_mat_23_0_sp2, rhs_mat_014589CD_0_sp2); + __m512i iacc_mat_11_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp2, rhs_mat_2367ABEF_3_sp2), lhs_mat_23_2_sp2, rhs_mat_2367ABEF_2_sp2), lhs_mat_23_1_sp2, rhs_mat_2367ABEF_1_sp2), lhs_mat_23_0_sp2, rhs_mat_2367ABEF_0_sp2); // Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block __m512i iacc_mat_00 = _mm512_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2); @@ -3605,22 +3587,15 @@ static void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, c // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane // Resembles MMLAs into 2x2 matrices in ARM Version - __m256i iacc_mat_00_sp1 = - _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp1, rhs_mat_0145_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp1, rhs_mat_0145_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp1, rhs_mat_0145_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp1, rhs_mat_0145_0_sp1)); - __m256i iacc_mat_01_sp1 = - _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp1, rhs_mat_2367_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp1, rhs_mat_2367_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp1, rhs_mat_2367_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp1, rhs_mat_2367_0_sp1)); - __m256i iacc_mat_10_sp1 = - _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp1, rhs_mat_0145_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp1, rhs_mat_0145_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp1, rhs_mat_0145_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp1, rhs_mat_0145_0_sp1)); - __m256i iacc_mat_11_sp1 = - 
_mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp1, rhs_mat_2367_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp1, rhs_mat_2367_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp1, rhs_mat_2367_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp1, rhs_mat_2367_0_sp1)); - __m256i iacc_mat_00_sp2 = - _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp2, rhs_mat_0145_0_sp2)); - __m256i iacc_mat_01_sp2 = - _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp2, rhs_mat_2367_0_sp2)); - __m256i iacc_mat_10_sp2 = - _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp2, rhs_mat_0145_0_sp2)); - __m256i iacc_mat_11_sp2 = - _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp2, rhs_mat_2367_0_sp2)); + const __m256i zero = _mm256_setzero_si256(); + __m256i iacc_mat_00_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp1, rhs_mat_0145_3_sp1), lhs_mat_01_2_sp1, rhs_mat_0145_2_sp1), lhs_mat_01_1_sp1, rhs_mat_0145_1_sp1), lhs_mat_01_0_sp1, rhs_mat_0145_0_sp1); + __m256i iacc_mat_01_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp1, rhs_mat_2367_3_sp1), lhs_mat_01_2_sp1, rhs_mat_2367_2_sp1), lhs_mat_01_1_sp1, rhs_mat_2367_1_sp1), lhs_mat_01_0_sp1, rhs_mat_2367_0_sp1); + __m256i iacc_mat_10_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp1, rhs_mat_0145_3_sp1), lhs_mat_23_2_sp1, rhs_mat_0145_2_sp1), lhs_mat_23_1_sp1, rhs_mat_0145_1_sp1), lhs_mat_23_0_sp1, rhs_mat_0145_0_sp1); + __m256i iacc_mat_11_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp1, rhs_mat_2367_3_sp1), lhs_mat_23_2_sp1, rhs_mat_2367_2_sp1), lhs_mat_23_1_sp1, rhs_mat_2367_1_sp1), lhs_mat_23_0_sp1, rhs_mat_2367_0_sp1); + __m256i iacc_mat_00_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp2, rhs_mat_0145_3_sp2), lhs_mat_01_2_sp2, rhs_mat_0145_2_sp2), lhs_mat_01_1_sp2, rhs_mat_0145_1_sp2), lhs_mat_01_0_sp2, rhs_mat_0145_0_sp2); + __m256i iacc_mat_01_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp2, rhs_mat_2367_3_sp2), lhs_mat_01_2_sp2, rhs_mat_2367_2_sp2), lhs_mat_01_1_sp2, rhs_mat_2367_1_sp2), lhs_mat_01_0_sp2, rhs_mat_2367_0_sp2); + 
__m256i iacc_mat_10_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp2, rhs_mat_0145_3_sp2), lhs_mat_23_2_sp2, rhs_mat_0145_2_sp2), lhs_mat_23_1_sp2, rhs_mat_0145_1_sp2), lhs_mat_23_0_sp2, rhs_mat_0145_0_sp2); + __m256i iacc_mat_11_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp2, rhs_mat_2367_3_sp2), lhs_mat_23_2_sp2, rhs_mat_2367_2_sp2), lhs_mat_23_1_sp2, rhs_mat_2367_1_sp2), lhs_mat_23_0_sp2, rhs_mat_2367_0_sp2); // Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block __m256i iacc_mat_00 = _mm256_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2); @@ -3769,22 +3744,15 @@ static void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, c // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane // Resembles MMLAs into 2x2 matrices in ARM Version - __m256i iacc_mat_00_sp1 = - _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp1, rhs_mat_0145_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp1, rhs_mat_0145_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp1, rhs_mat_0145_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp1, rhs_mat_0145_0_sp1)); - __m256i iacc_mat_01_sp1 = - _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp1, rhs_mat_2367_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp1, rhs_mat_2367_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp1, rhs_mat_2367_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp1, rhs_mat_2367_0_sp1)); - __m256i iacc_mat_10_sp1 = - _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp1, rhs_mat_0145_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp1, rhs_mat_0145_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp1, rhs_mat_0145_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp1, rhs_mat_0145_0_sp1)); - __m256i iacc_mat_11_sp1 = - _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp1, rhs_mat_2367_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp1, rhs_mat_2367_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp1, rhs_mat_2367_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp1, rhs_mat_2367_0_sp1)); - __m256i iacc_mat_00_sp2 = - _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp2, rhs_mat_0145_0_sp2)); - __m256i iacc_mat_01_sp2 = - _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp2, rhs_mat_2367_0_sp2)); - __m256i iacc_mat_10_sp2 = - _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp2, rhs_mat_0145_0_sp2)); - __m256i iacc_mat_11_sp2 = - 
_mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp2, rhs_mat_2367_0_sp2)); + const __m256i zero = _mm256_setzero_si256(); + __m256i iacc_mat_00_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp1, rhs_mat_0145_3_sp1), lhs_mat_01_2_sp1, rhs_mat_0145_2_sp1), lhs_mat_01_1_sp1, rhs_mat_0145_1_sp1), lhs_mat_01_0_sp1, rhs_mat_0145_0_sp1); + __m256i iacc_mat_01_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp1, rhs_mat_2367_3_sp1), lhs_mat_01_2_sp1, rhs_mat_2367_2_sp1), lhs_mat_01_1_sp1, rhs_mat_2367_1_sp1), lhs_mat_01_0_sp1, rhs_mat_2367_0_sp1); + __m256i iacc_mat_10_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp1, rhs_mat_0145_3_sp1), lhs_mat_23_2_sp1, rhs_mat_0145_2_sp1), lhs_mat_23_1_sp1, rhs_mat_0145_1_sp1), lhs_mat_23_0_sp1, rhs_mat_0145_0_sp1); + __m256i iacc_mat_11_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp1, rhs_mat_2367_3_sp1), lhs_mat_23_2_sp1, rhs_mat_2367_2_sp1), lhs_mat_23_1_sp1, rhs_mat_2367_1_sp1), lhs_mat_23_0_sp1, rhs_mat_2367_0_sp1); + __m256i iacc_mat_00_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp2, rhs_mat_0145_3_sp2), lhs_mat_01_2_sp2, rhs_mat_0145_2_sp2), lhs_mat_01_1_sp2, rhs_mat_0145_1_sp2), lhs_mat_01_0_sp2, rhs_mat_0145_0_sp2); + __m256i iacc_mat_01_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp2, rhs_mat_2367_3_sp2), lhs_mat_01_2_sp2, rhs_mat_2367_2_sp2), lhs_mat_01_1_sp2, rhs_mat_2367_1_sp2), lhs_mat_01_0_sp2, rhs_mat_2367_0_sp2); + __m256i iacc_mat_10_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp2, rhs_mat_0145_3_sp2), lhs_mat_23_2_sp2, rhs_mat_0145_2_sp2), lhs_mat_23_1_sp2, rhs_mat_0145_1_sp2), lhs_mat_23_0_sp2, rhs_mat_0145_0_sp2); + __m256i iacc_mat_11_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp2, rhs_mat_2367_3_sp2), lhs_mat_23_2_sp2, rhs_mat_2367_2_sp2), lhs_mat_23_1_sp2, rhs_mat_2367_1_sp2), lhs_mat_23_0_sp2, rhs_mat_2367_0_sp2); // Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block __m256i iacc_mat_00 = _mm256_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2); @@ -4076,7 +4044,7 @@ static void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, c UNUSED(ncols_interleaved); UNUSED(blocklen); -#if defined(__AVX2__) +#if defined(__AVX2__) || defined(__AVX512F__) const block_q4_Kx8 * b_ptr_start = (const block_q4_Kx8 * ) vx; const block_q8_Kx4 * a_ptr_start = (const block_q8_Kx4 * ) vy; int64_t b_nb = n / QK_K; @@ -4086,8 +4054,748 @@ static void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, c const __m256i m4b = _mm256_set1_epi8(0x0F); // Permute mask used for easier vector processing 
at later stages __m256i requiredOrder = _mm256_set_epi32(3, 2, 1, 0, 7, 6, 5, 4); - + int64_t xstart = 0; int anr = nr - nr % 16;; // Used to align nr with boundary of 16 +#ifdef __AVX512F__ + int anc = nc - nc % 16; // Used to align nc with boundary of 16 + // Mask to mask out nibbles from packed bytes expanded to 512 bit length + const __m512i m4bexpanded = _mm512_set1_epi8(0x0F); + //Take group of four block_q8_Kx4 structures at each pass of the loop and perform dot product operation + for (; y < anr / 4; y += 4) { + + const block_q8_Kx4 * a_ptrs[4]; + + a_ptrs[0] = a_ptr_start + (y * nb); + for (int i = 0; i < 3; ++i) { + a_ptrs[i + 1] = a_ptrs[i] + nb; + } + + // Take group of eight block_q4_kx8 structures at each pass of the loop and perform dot product operation + for (int64_t x = 0; x < anc / 8; x += 2) { + + const block_q4_Kx8 * b_ptr_0 = b_ptr_start + ((x) * b_nb); + const block_q4_Kx8 * b_ptr_1 = b_ptr_start + ((x + 1) * b_nb); + + // Master FP accumulators + __m512 acc_rows[16]; + for (int i = 0; i < 16; i++) { + acc_rows[i] = _mm512_setzero_ps(); + } + + __m512 acc_min_rows[16]; + for (int i = 0; i < 16; i++) { + acc_min_rows[i] = _mm512_setzero_ps(); + } + + // For super block + for (int64_t b = 0; b < nb; b++) { + // Scale values - Load the sixteen scale values from two block_q4_kx8 structures + const __m512 col_scale_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].d, b_ptr_1[b].d); + + // dmin values - Load the sixteen dmin values from two block_q4_kx8 structures + const __m512 col_dmin_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].dmin, b_ptr_1[b].dmin); + + // Loop to iterate over the eight sub blocks of a super block - two sub blocks are processed per iteration + for (int sb = 0; sb < QK_K / 64; sb++) { + + const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + sb * 256)); + const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 32 + sb * 256)); + const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 64 + sb * 256)); + const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 96 + sb * 256)); + const __m256i rhs_raw_mat_0123_2 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 128 + sb * 256)); + const __m256i rhs_raw_mat_4567_2 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 160 + sb * 256)); + const __m256i rhs_raw_mat_0123_3 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 192 + sb * 256)); + const __m256i rhs_raw_mat_4567_3 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 224 + sb * 256)); + + const __m256i rhs_raw_mat_89AB_0 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + sb * 256)); + const __m256i rhs_raw_mat_CDEF_0 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 32 + sb * 256)); + const __m256i rhs_raw_mat_89AB_1 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 64 + sb * 256)); + const __m256i rhs_raw_mat_CDEF_1 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 96 + sb * 256)); + const __m256i rhs_raw_mat_89AB_2 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 128 + sb * 256)); + const __m256i rhs_raw_mat_CDEF_2 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 160 + sb * 256)); + const __m256i rhs_raw_mat_89AB_3 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 192 + sb * 256)); + const __m256i rhs_raw_mat_CDEF_3 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 224 + sb * 256)); + + const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, 
_mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240); + const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240); + const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240); + const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240); + const __m256i rhs_raw_mat_0145_2 = _mm256_blend_epi32(rhs_raw_mat_0123_2, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_2, requiredOrder), 240); + const __m256i rhs_raw_mat_2367_2 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_2, requiredOrder), rhs_raw_mat_4567_2, 240); + const __m256i rhs_raw_mat_0145_3 = _mm256_blend_epi32(rhs_raw_mat_0123_3, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_3, requiredOrder), 240); + const __m256i rhs_raw_mat_2367_3 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_3, requiredOrder), rhs_raw_mat_4567_3, 240); + + const __m256i rhs_raw_mat_89CD_0 = _mm256_blend_epi32(rhs_raw_mat_89AB_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_0, requiredOrder), 240); + const __m256i rhs_raw_mat_ABEF_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_0, requiredOrder), rhs_raw_mat_CDEF_0, 240); + const __m256i rhs_raw_mat_89CD_1 = _mm256_blend_epi32(rhs_raw_mat_89AB_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_1, requiredOrder), 240); + const __m256i rhs_raw_mat_ABEF_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_1, requiredOrder), rhs_raw_mat_CDEF_1, 240); + const __m256i rhs_raw_mat_89CD_2 = _mm256_blend_epi32(rhs_raw_mat_89AB_2, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_2, requiredOrder), 240); + const __m256i rhs_raw_mat_ABEF_2 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_2, requiredOrder), rhs_raw_mat_CDEF_2, 240); + const __m256i rhs_raw_mat_89CD_3 = _mm256_blend_epi32(rhs_raw_mat_89AB_3, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_3, requiredOrder), 240); + const __m256i rhs_raw_mat_ABEF_3 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_3, requiredOrder), rhs_raw_mat_CDEF_3, 240); + + const __m512i rhs_raw_mat_014589CD_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_0), rhs_raw_mat_89CD_0, 1); + const __m512i rhs_raw_mat_2367ABEF_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_0), rhs_raw_mat_ABEF_0, 1); + const __m512i rhs_raw_mat_014589CD_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_1), rhs_raw_mat_89CD_1, 1); + const __m512i rhs_raw_mat_2367ABEF_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_1), rhs_raw_mat_ABEF_1, 1); + + const __m512i rhs_raw_mat_014589CD_2 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_2), rhs_raw_mat_89CD_2, 1); + const __m512i rhs_raw_mat_2367ABEF_2 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_2), rhs_raw_mat_ABEF_2, 1); + const __m512i rhs_raw_mat_014589CD_3 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_3), rhs_raw_mat_89CD_3, 1); + const __m512i rhs_raw_mat_2367ABEF_3 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_3), rhs_raw_mat_ABEF_3, 1); + + //4-bit -> 8-bit + const __m512i rhs_mat_014589CD_00 = _mm512_and_si512(rhs_raw_mat_014589CD_0, m4bexpanded); //B00(0-7) B01(0-7) B04(0-7) B05(0-7) B08(0-7) B09(0-7) B0C(0-7) B0D(0-7) + const __m512i rhs_mat_2367ABEF_00 = 
_mm512_and_si512(rhs_raw_mat_2367ABEF_0, m4bexpanded); //B02(0-7) B03(0-7) B06(0-7) B07(0-7) B0A(0-7) B0B(0-7) B0E(0-7) B0F(0-7) + const __m512i rhs_mat_014589CD_01 = _mm512_and_si512(rhs_raw_mat_014589CD_1, m4bexpanded); //B00(8-15) B01(8-15) B04(8-15) B05(8-15) B08(8-15) B09(8-15) B0C(8-15) B0D(8-15) + const __m512i rhs_mat_2367ABEF_01 = _mm512_and_si512(rhs_raw_mat_2367ABEF_1, m4bexpanded); //B02(8-15) B03(8-15) B06(8-15) B07(8-15) B0A(8-15) B0B(8-15) B0E(8-15) B0F(8-15) + + const __m512i rhs_mat_014589CD_02 = _mm512_and_si512(rhs_raw_mat_014589CD_2, m4bexpanded); //B00(16-23) B01(16-23) B04(16-23) B05(16-23) B08(16-23) B09(16-23) B0C(16-23) B0D(16-23) + const __m512i rhs_mat_2367ABEF_02 = _mm512_and_si512(rhs_raw_mat_2367ABEF_2, m4bexpanded); //B02(16-23) B03(16-23) B06(16-23) B07(16-23) B0A(16-23) B0B(16-23) B0E(16-23) B0F(16-23) + const __m512i rhs_mat_014589CD_03 = _mm512_and_si512(rhs_raw_mat_014589CD_3, m4bexpanded); //B00(24-31) B01(24-31) B04(24-31) B05(24-31) B08(24-31) B09(24-31) B0C(24-31) B0D(24-31) + const __m512i rhs_mat_2367ABEF_03 = _mm512_and_si512(rhs_raw_mat_2367ABEF_3, m4bexpanded); //B02(24-31) B03(24-31) B06(24-31) B07(24-31) B0A(24-31) B0B(24-31) B0E(24-31) B0F(24-31) + + const __m512i rhs_mat_014589CD_10 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_0, 4), m4bexpanded); //B10(0-7) B11(0-7) B14(0-7) B15(0-7) B18(0-7) B19(0-7) B1C(0-7) B1D(0-7) + const __m512i rhs_mat_2367ABEF_10 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_0, 4), m4bexpanded); //B12(0-7) B13(0-7) B16(0-7) B17(0-7) B1A(0-7) B1B(0-7) B1E(0-7) B1F(0-7) + const __m512i rhs_mat_014589CD_11 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_1, 4), m4bexpanded); //B10(8-15) B11(8-15) B14(8-15) B15(8-15) B18(8-15) B19(8-15) B1C(8-15) B1D(8-15) + const __m512i rhs_mat_2367ABEF_11 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 4), m4bexpanded); //B12(8-15) B13(8-15) B16(8-15) B17(8-15) B1A(8-15) B1B(8-15) B1E(8-15) B1F(8-15) + + const __m512i rhs_mat_014589CD_12 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_2, 4), m4bexpanded); //B10(16-23) B11(16-23) B14(16-23) B15(16-23) B18(16-23) B19(16-23) B1C(16-23) B1D(16-23) + const __m512i rhs_mat_2367ABEF_12 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_2, 4), m4bexpanded); //B12(16-23) B13(16-23) B16(16-23) B17(16-23) B1A(16-23) B1B(16-23) B1E(16-23) B1F(16-23) + const __m512i rhs_mat_014589CD_13 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_3, 4), m4bexpanded); //B10(24-31) B11(24-31) B14(24-31) B15(24-31) B18(24-31) B19(24-31) B1C(24-31) B1D(24-31) + const __m512i rhs_mat_2367ABEF_13 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_3, 4), m4bexpanded); //B12(24-31) B13(24-31) B16(24-31) B17(24-31) B1A(24-31) B1B(24-31) B1E(24-31) B1F(24-31) + + // Shuffle pattern one - right side input + const __m512i rhs_mat_014589CD_00_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_00, (_MM_PERM_ENUM)136); //B00(0-3) B01(0-3) B00(0-3) B01(0-3) B04(0-3) B05(0-3) B04(0-3) B05(0-3) B08(0-3) B09(0-3) B08(0-3) B09(0-3) B0C(0-3) B0D(0-3) B0C(0-3) B0D(0-3) + const __m512i rhs_mat_2367ABEF_00_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_00, (_MM_PERM_ENUM)136); //B02(0-3) B03(0-3) B02(0-3) B03(0-3) B06(0-3) B07(0-3) B06(0-3) B07(0-3) B0A(0-3) B0B(0-3) B0A(0-3) B0B(0-3) B0E(0-3) B0F(0-3) B0E(0-3) B0F(0-3) + const __m512i rhs_mat_014589CD_01_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_01, (_MM_PERM_ENUM)136); //B00(8-11) B01(8-11) B00(8-11) B01(8-11) B04(8-11) B05(8-11) B04(8-11) B05(8-11) B08(8-11) 
B09(8-11) B08(8-11) B09(8-11) B0C(8-11) B0D(8-11) B0C(8-11) B0D(8-11) + const __m512i rhs_mat_2367ABEF_01_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_01, (_MM_PERM_ENUM)136); //B02(8-11) B03(8-11) B02(8-11) B03(8-11) B06(8-11) B07(8-11) B06(8-11) B07(8-11) B0A(8-11) B0B(8-11) B0A(8-11) B0B(8-11) B0E(8-11) B0F(8-11) B0E(8-11) B0F(8-11) + const __m512i rhs_mat_014589CD_02_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_02, (_MM_PERM_ENUM)136); //B00(16-19) B01(16-19) B00(16-19) B01(16-19) B04(16-19) B05(16-19) B04(16-19) B05(16-19) B08(16-19) B09(16-19) B08(16-19) B09(16-19) B0C(16-19) B0D(16-19) B0C(16-19) B0D(16-19) + const __m512i rhs_mat_2367ABEF_02_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_02, (_MM_PERM_ENUM)136); //B02(16-19) B03(16-19) B02(16-19) B03(16-19) B06(16-19) B07(16-19) B06(16-19) B07(16-19) B0A(16-19) B0B(16-19) B0A(16-19) B0B(16-19) B0E(16-19) B0F(16-19) B0E(16-19) B0F(16-19) + const __m512i rhs_mat_014589CD_03_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_03, (_MM_PERM_ENUM)136); //B00(24-27) B01(24-27) B00(24-27) B01(24-27) B04(24-27) B05(24-27) B04(24-27) B05(24-27) B08(24-27) B09(24-27) B08(24-27) B09(24-27) B0C(24-27) B0D(24-27) B0C(24-27) B0D(24-27) + const __m512i rhs_mat_2367ABEF_03_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_03, (_MM_PERM_ENUM)136); //B02(24-27) B03(24-27) B02(24-27) B03(24-27) B06(24-27) B07(24-27) B06(24-27) B07(24-27) B0A(24-27) B0B(24-27) B0A(24-27) B0B(24-27) B0E(24-27) B0F(24-27) B0E(24-27) B0F(24-27) + + const __m512i rhs_mat_014589CD_10_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_10, (_MM_PERM_ENUM)136); //B10(0-3) B11(0-3) B10(0-3) B11(0-3) B14(0-3) B15(0-3) B14(0-3) B15(0-3) B18(0-3) B19(0-3) B18(0-3) B19(0-3) B1C(0-3) B1D(0-3) B1C(0-3) B1D(0-3) + const __m512i rhs_mat_2367ABEF_10_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_10, (_MM_PERM_ENUM)136); //B12(0-3) B13(0-3) B12(0-3) B13(0-3) B16(0-3) B17(0-3) B16(0-3) B17(0-3) B1A(0-3) B1B(0-3) B1A(0-3) B1B(0-3) B1E(0-3) B1F(0-3) B1E(0-3) B1F(0-3) + const __m512i rhs_mat_014589CD_11_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_11, (_MM_PERM_ENUM)136); //B10(8-11) B11(8-11) B10(8-11) B11(8-11) B14(8-11) B15(8-11) B14(8-11) B15(8-11) B18(8-11) B19(8-11) B18(8-11) B19(8-11) B1C(8-11) B1D(8-11) B1C(8-11) B1D(8-11) + const __m512i rhs_mat_2367ABEF_11_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_11, (_MM_PERM_ENUM)136); //B12(8-11) B13(8-11) B12(8-11) B13(8-11) B16(8-11) B17(8-11) B16(8-11) B17(8-11) B1A(8-11) B1B(8-11) B1A(8-11) B1B(8-11) B1E(8-11) B1F(8-11) B1E(8-11) B1F(8-11) + const __m512i rhs_mat_014589CD_12_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_12, (_MM_PERM_ENUM)136); //B10(16-19) B11(16-19) B10(16-19) B11(16-19) B14(16-19) B15(16-19) B14(16-19) B15(16-19) B18(16-19) B19(16-19) B18(16-19) B19(16-19) B1C(16-19) B1D(16-19) B1C(16-19) B1D(16-19) + const __m512i rhs_mat_2367ABEF_12_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_12, (_MM_PERM_ENUM)136); //B12(16-19) B13(16-19) B12(16-19) B13(16-19) B16(16-19) B17(16-19) B16(16-19) B17(16-19) B1A(16-19) B1B(16-19) B1A(16-19) B1B(16-19) B1E(16-19) B1F(16-19) B1E(16-19) B1F(16-19) + const __m512i rhs_mat_014589CD_13_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_13, (_MM_PERM_ENUM)136); //B10(24-27) B11(24-27) B10(24-27) B11(24-27) B14(24-27) B15(24-27) B14(24-27) B15(24-27) B18(24-27) B19(24-27) B18(24-27) B19(24-27) B1C(24-27) B1D(24-27) B1C(24-27) B1D(24-27) + const __m512i rhs_mat_2367ABEF_13_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_13, (_MM_PERM_ENUM)136); //B12(24-27) B13(24-27) B12(24-27) B13(24-27) B16(24-27) B17(24-27) B16(24-27) B17(24-27) 
B1A(24-27) B1B(24-27) B1A(24-27) B1B(24-27) B1E(24-27) B1F(24-27) B1E(24-27) B1F(24-27) + + // Shuffle pattern two - right side input + const __m512i rhs_mat_014589CD_00_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_00, (_MM_PERM_ENUM)221); //B00(4-7) B01(4-7) B00(4-7) B01(4-7) B04(4-7) B05(4-7) B04(4-7) B05(4-7) B08(4-7) B09(4-7) B08(4-7) B09(4-7) B0C(4-7) B0D(4-7) B0C(4-7) B0D(4-7) + const __m512i rhs_mat_2367ABEF_00_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_00, (_MM_PERM_ENUM)221); //B02(4-7) B03(4-7) B02(4-7) B03(4-7) B06(4-7) B07(4-7) B06(4-7) B07(4-7) B0A(4-7) B0B(4-7) B0A(4-7) B0B(4-7) B0E(4-7) B0F(4-7) B0E(4-7) B0F(4-7) + const __m512i rhs_mat_014589CD_01_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_01, (_MM_PERM_ENUM)221); //B00(12-15) B01(12-15) B00(12-15) B01(12-15) B04(12-15) B05(12-15) B04(12-15) B05(12-15) B08(12-15) B09(12-15) B08(12-15) B09(12-15) B0C(12-15) B0D(12-15) B0C(12-15) B0D(12-15) + const __m512i rhs_mat_2367ABEF_01_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_01, (_MM_PERM_ENUM)221); //B02(12-15) B03(12-15) B02(12-15) B03(12-15) B06(12-15) B07(12-15) B06(12-15) B07(12-15) B0A(12-15) B0B(12-15) B0A(12-15) B0B(12-15) B0E(12-15) B0F(12-15) B0E(12-15) B0F(12-15) + const __m512i rhs_mat_014589CD_02_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_02, (_MM_PERM_ENUM)221); //B00(20-23) B01(20-23) B00(20-23) B01(20-23) B04(20-23) B05(20-23) B04(20-23) B05(20-23) B08(20-23) B09(20-23) B08(20-23) B09(20-23) B0C(20-23) B0D(20-23) B0C(20-23) B0D(20-23) + const __m512i rhs_mat_2367ABEF_02_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_02, (_MM_PERM_ENUM)221); //B02(20-23) B03(20-23) B02(20-23) B03(20-23) B06(20-23) B07(20-23) B06(20-23) B07(20-23) B0A(20-23) B0B(20-23) B0A(20-23) B0B(20-23) B0E(20-23) B0F(20-23) B0E(20-23) B0F(20-23) + const __m512i rhs_mat_014589CD_03_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_03, (_MM_PERM_ENUM)221); //B00(28-31) B01(28-31) B00(28-31) B01(28-31) B04(28-31) B05(28-31) B04(28-31) B05(28-31) B08(28-31) B09(28-31) B08(28-31) B09(28-31) B0C(28-31) B0D(28-31) B0C(28-31) 0BD(28-31) + const __m512i rhs_mat_2367ABEF_03_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_03, (_MM_PERM_ENUM)221); //B02(28-31) B03(28-31) B02(28-31) B03(28-31) B06(28-31) B07(28-31) B06(28-31) B07(28-31) B0A(28-31) B0B(28-31) B0A(28-31) B0B(28-31) B0E(28-31) B0F(28-31) B0E(28-31) B0F(28-31) + + const __m512i rhs_mat_014589CD_10_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_10, (_MM_PERM_ENUM)221); //B10(4-7) B11(4-7) B10(4-7) B11(4-7) B14(4-7) B15(4-7) B14(4-7) B15(4-7) B18(4-7) B19(4-7) B18(4-7) B19(4-7) B1C(4-7) B1D(4-7) B1C(4-7) B1D(4-7) + const __m512i rhs_mat_2367ABEF_10_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_10, (_MM_PERM_ENUM)221); //B12(4-7) B13(4-7) B12(4-7) B13(4-7) B16(4-7) B17(4-7) B16(4-7) B17(4-7) B1A(4-7) B1B(4-7) B1A(4-7) B1B(4-7) B1E(4-7) B1F(4-7) B1E(4-7) B1F(4-7) + const __m512i rhs_mat_014589CD_11_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_11, (_MM_PERM_ENUM)221); //B10(12-15) B11(12-15) B10(12-15) B11(12-15) B14(12-15) B15(12-15) B14(12-15) B15(12-15) B18(12-15) B19(12-15) B18(12-15) B19(12-15) B1C(12-15) B1D(12-15) B1C(12-15) B1D(12-15) + const __m512i rhs_mat_2367ABEF_11_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_11, (_MM_PERM_ENUM)221); //B12(12-15) B13(12-15) B12(12-15) B13(12-15) B16(12-15) B17(12-15) B16(12-15) B17(12-15) B1A(12-15) B1B(12-15) B1A(12-15) B1B(12-15) B1E(12-15) B1F(12-15) B1E(12-15) B1F(12-15) + const __m512i rhs_mat_014589CD_12_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_12, (_MM_PERM_ENUM)221); //B10(20-23) B11(20-23) B10(20-23) B11(20-23) 
B14(20-23) B15(20-23) B14(20-23) B15(20-23) B18(20-23) B19(20-23) B18(20-23) B19(20-23) B1C(20-23) B1D(20-23) B1C(20-23) B1D(20-23) + const __m512i rhs_mat_2367ABEF_12_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_12, (_MM_PERM_ENUM)221); //B12(20-23) B13(20-23) B12(20-23) B13(20-23) B16(20-23) B17(20-23) B16(20-23) B17(20-23) B1A(20-23) B1B(20-23) B1A(20-23) B1B(20-23) B1E(20-23) B1F(20-23) B1E(20-23) B1F(20-23) + const __m512i rhs_mat_014589CD_13_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_13, (_MM_PERM_ENUM)221); //B10(28-31) B11(28-31) B10(28-31) B11(28-31) B14(28-31) B15(28-31) B14(28-31) B15(28-31) B18(28-31) B19(28-31) B18(28-31) B19(28-31) B1C(28-31) B1D(28-31) B1C(28-31) B1D(28-31) + const __m512i rhs_mat_2367ABEF_13_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_13, (_MM_PERM_ENUM)221); //B12(28-31) B13(28-31) B12(28-31) B13(28-31) B16(28-31) B17(28-31) B16(28-31) B17(28-31) B1A(28-31) B1B(28-31) B1A(28-31) B1B(28-31) B1E(28-31) B1F(28-31) B1E(28-31) B1F(28-31) + + uint32_t utmp_00[4], utmp_01[4], utmp_10[4], utmp_11[4]; + + // Scales and Mins of corresponding sub blocks from different Q4_K structures are stored together + // For example, the below block extracts the first sub block's scales and mins from the different Q4_K structures for the sb loop + memcpy(utmp_00, b_ptr_0[b].scales + 24 * sb, 12); + utmp_00[3] = ((utmp_00[2] >> 4) & kmask2) | (((utmp_00[1] >> 6) & kmask3) << 4); + const uint32_t uaux_00 = utmp_00[1] & kmask1; + utmp_00[1] = (utmp_00[2] & kmask2) | (((utmp_00[0] >> 6) & kmask3) << 4); + utmp_00[2] = uaux_00; + utmp_00[0] &= kmask1; + + // For example, the below block extracts the second sub block's scales and mins from the different Q4_K structures for the sb loop + memcpy(utmp_01, b_ptr_0[b].scales + 12 + sb * 24, 12); + utmp_01[3] = ((utmp_01[2] >> 4) & kmask2) | (((utmp_01[1] >> 6) & kmask3) << 4); + const uint32_t uaux_01 = utmp_01[1] & kmask1; + utmp_01[1] = (utmp_01[2] & kmask2) | (((utmp_01[0] >> 6) & kmask3) << 4); + utmp_01[2] = uaux_01; + utmp_01[0] &= kmask1; + + memcpy(utmp_10, b_ptr_1[b].scales + sb * 24, 12); + utmp_10[3] = ((utmp_10[2] >> 4) & kmask2) | (((utmp_10[1] >> 6) & kmask3) << 4); + const uint32_t uaux_10 = utmp_10[1] & kmask1; + utmp_10[1] = (utmp_10[2] & kmask2) | (((utmp_10[0] >> 6) & kmask3) << 4); + utmp_10[2] = uaux_10; + utmp_10[0] &= kmask1; + + // For example, the below block extracts the second sub block's scales and mins from the different Q4_K structures for the sb loop + memcpy(utmp_11, b_ptr_1[b].scales + 12 + sb * 24, 12); + utmp_11[3] = ((utmp_11[2] >> 4) & kmask2) | (((utmp_11[1] >> 6) & kmask3) << 4); + const uint32_t uaux_11 = utmp_11[1] & kmask1; + utmp_11[1] = (utmp_11[2] & kmask2) | (((utmp_11[0] >> 6) & kmask3) << 4); + utmp_11[2] = uaux_11; + utmp_11[0] &= kmask1; + + // Scales of first sub block in the sb loop + const __m256i mins_and_scales_0 = _mm256_set_epi32(utmp_10[3], utmp_10[2], utmp_10[1], utmp_10[0], utmp_00[3], utmp_00[2], utmp_00[1], utmp_00[0]); + const __m512i scales_0 = _mm512_cvtepu8_epi16(_mm256_unpacklo_epi8(mins_and_scales_0, mins_and_scales_0)); + + // Scales of second sub block in the sb loop + const __m256i mins_and_scales_1 = _mm256_set_epi32(utmp_11[3], utmp_11[2], utmp_11[1], utmp_11[0], utmp_01[3], utmp_01[2], utmp_01[1], utmp_01[0]); + const __m512i scales_1 = _mm512_cvtepu8_epi16(_mm256_unpacklo_epi8(mins_and_scales_1, mins_and_scales_1)); + + // Mins of first and second sub block of Q4_K block are arranged side by side + const __m512i mins_01 = 
_mm512_cvtepu8_epi16(_mm256_unpacklo_epi8(_mm256_shuffle_epi32(mins_and_scales_0, 78), _mm256_shuffle_epi32(mins_and_scales_1, 78))); + + const __m512i scale_014589CD_0 = _mm512_shuffle_epi32(scales_0, (_MM_PERM_ENUM)68); + const __m512i scale_2367ABEF_0 = _mm512_shuffle_epi32(scales_0, (_MM_PERM_ENUM)238); + + const __m512i scale_014589CD_1 = _mm512_shuffle_epi32(scales_1, (_MM_PERM_ENUM)68); + const __m512i scale_2367ABEF_1 = _mm512_shuffle_epi32(scales_1, (_MM_PERM_ENUM)238); + + for (int rp = 0; rp < 4; rp++) { + + // Load the four block_q8_k quantized values interleaved with each other in chunks of eight bytes - A0,A1,A2,A3 + // Loaded as set of 128 bit vectors and repeated and stored into a 256 bit vector before again repeating into 512 bit vector + __m256i lhs_mat_ymm_0123_00 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 256 * sb))); + __m256i lhs_mat_ymm_01_00 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_00, lhs_mat_ymm_0123_00, 0); + __m256i lhs_mat_ymm_23_00 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_00, lhs_mat_ymm_0123_00, 17); + __m256i lhs_mat_ymm_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 32 + 256 * sb))); + __m256i lhs_mat_ymm_01_01 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_01, lhs_mat_ymm_0123_01, 0); + __m256i lhs_mat_ymm_23_01 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_01, lhs_mat_ymm_0123_01, 17); + __m256i lhs_mat_ymm_0123_02 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 64 + 256 * sb))); + __m256i lhs_mat_ymm_01_02 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_02, lhs_mat_ymm_0123_02, 0); + __m256i lhs_mat_ymm_23_02 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_02, lhs_mat_ymm_0123_02, 17); + __m256i lhs_mat_ymm_0123_03 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 96 + 256 * sb))); + __m256i lhs_mat_ymm_01_03 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_03, lhs_mat_ymm_0123_03, 0); + __m256i lhs_mat_ymm_23_03 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_03, lhs_mat_ymm_0123_03, 17); + __m256i lhs_mat_ymm_0123_10 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 128 + 256 * sb))); + __m256i lhs_mat_ymm_01_10 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_10, lhs_mat_ymm_0123_10, 0); + __m256i lhs_mat_ymm_23_10 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_10, lhs_mat_ymm_0123_10, 17); + __m256i lhs_mat_ymm_0123_11 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 160 + 256 * sb))); + __m256i lhs_mat_ymm_01_11 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_11, lhs_mat_ymm_0123_11, 0); + __m256i lhs_mat_ymm_23_11 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_11, lhs_mat_ymm_0123_11, 17); + __m256i lhs_mat_ymm_0123_12 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 192 + 256 * sb))); + __m256i lhs_mat_ymm_01_12 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_12, lhs_mat_ymm_0123_12, 0); + __m256i lhs_mat_ymm_23_12 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_12, lhs_mat_ymm_0123_12, 17); + __m256i lhs_mat_ymm_0123_13 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 224 + 256 * sb))); + __m256i lhs_mat_ymm_01_13 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_13, lhs_mat_ymm_0123_13, 0); + __m256i lhs_mat_ymm_23_13 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_13, lhs_mat_ymm_0123_13, 17); + + __m512i lhs_mat_01_00 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_00), lhs_mat_ymm_01_00, 1); + __m512i lhs_mat_23_00 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_00), lhs_mat_ymm_23_00, 1); + __m512i lhs_mat_01_01 = 
_mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_01), lhs_mat_ymm_01_01, 1); + __m512i lhs_mat_23_01 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_01), lhs_mat_ymm_23_01, 1); + __m512i lhs_mat_01_02 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_02), lhs_mat_ymm_01_02, 1); + __m512i lhs_mat_23_02 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_02), lhs_mat_ymm_23_02, 1); + __m512i lhs_mat_01_03 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_03), lhs_mat_ymm_01_03, 1); + __m512i lhs_mat_23_03 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_03), lhs_mat_ymm_23_03, 1); + + __m512i lhs_mat_01_10 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_10), lhs_mat_ymm_01_10, 1); + __m512i lhs_mat_23_10 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_10), lhs_mat_ymm_23_10, 1); + __m512i lhs_mat_01_11 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_11), lhs_mat_ymm_01_11, 1); + __m512i lhs_mat_23_11 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_11), lhs_mat_ymm_23_11, 1); + __m512i lhs_mat_01_12 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_12), lhs_mat_ymm_01_12, 1); + __m512i lhs_mat_23_12 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_12), lhs_mat_ymm_23_12, 1); + __m512i lhs_mat_01_13 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_13), lhs_mat_ymm_01_13, 1); + __m512i lhs_mat_23_13 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_13), lhs_mat_ymm_23_13, 1); + + // Bsums are loaded - four bsums are loaded (for two sub blocks) for the different Q8_K blocks + __m256i lhs_bsums_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].bsums + 16 * sb))); + __m256i lhs_bsums_hsum_ymm_0123_01 = _mm256_castsi128_si256(_mm_hadd_epi16(_mm256_castsi256_si128(lhs_bsums_0123_01), _mm256_extractf128_si256(lhs_bsums_0123_01, 1))); + lhs_bsums_hsum_ymm_0123_01 = _mm256_permute2x128_si256(lhs_bsums_hsum_ymm_0123_01, lhs_bsums_hsum_ymm_0123_01, 0); + __m512i lhs_bsums_hsum_0123_01 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_bsums_hsum_ymm_0123_01), lhs_bsums_hsum_ymm_0123_01, 1); + + // Shuffle pattern one - left side input + const __m512i lhs_mat_01_00_sp1 = _mm512_shuffle_epi32(lhs_mat_01_00, (_MM_PERM_ENUM)160); //A00(0-3) A00(0-3) A01(0-3) A01(0-3) A00(0-3) A00(0-3) A01(0-3) A01(0-3) A00(0-3) A00(0-3) A01(0-3) A01(0-3) A00(0-3) A00(0-3) A01(0-3) A01(0-3) + const __m512i lhs_mat_23_00_sp1 = _mm512_shuffle_epi32(lhs_mat_23_00, (_MM_PERM_ENUM)160); //A02(0-3) A02(0-3) A03(0-3) A03(0-3) A02(0-3) A02(0-3) A03(0-3) A03(0-3) A02(0-3) A02(0-3) A03(0-3) A03(0-3) A02(0-3) A02(0-3) A03(0-3) A03(0-3) + const __m512i lhs_mat_01_01_sp1 = _mm512_shuffle_epi32(lhs_mat_01_01, (_MM_PERM_ENUM)160); //A00(8-11) A00(8-11) A01(8-11) A01(8-11) A00(8-11) A00(8-11) A01(8-11) A01(8-11) A00(8-11) A00(8-11) A01(8-11) A01(8-11) A00(8-11) A00(8-11) A01(8-11) A01(8-11) + const __m512i lhs_mat_23_01_sp1 = _mm512_shuffle_epi32(lhs_mat_23_01, (_MM_PERM_ENUM)160); //A02(8-11) A02(8-11) A03(8-11) A03(8-11) A02(8-11) A02(8-11) A03(8-11) A03(8-11) A02(8-11) A02(8-11) A03(8-11) A03(8-11) A02(8-11) A02(8-11) A03(8-11) A03(8-11) + const __m512i lhs_mat_01_02_sp1 = _mm512_shuffle_epi32(lhs_mat_01_02, (_MM_PERM_ENUM)160); //A00(16-19) A00(16-19) A01(16-19) A01(16-19) A00(16-19) A00(16-19) A01(16-19) A01(16-19) A00(16-19) A00(16-19) A01(16-19) A01(16-19) A00(16-19) A00(16-19) A01(16-19) A01(16-19) + const __m512i lhs_mat_23_02_sp1 = _mm512_shuffle_epi32(lhs_mat_23_02, (_MM_PERM_ENUM)160); 
//A02(16-19) A02(16-19) A03(16-19) A03(16-19) A02(16-19) A02(16-19) A03(16-19) A03(16-19) A02(16-19) A02(16-19) A03(16-19) A03(16-19) A02(16-19) A02(16-19) A03(16-19) A03(16-19) + const __m512i lhs_mat_01_03_sp1 = _mm512_shuffle_epi32(lhs_mat_01_03, (_MM_PERM_ENUM)160); //A00(24-27) A00(24-27) A01(24-27) A01(24-27) A00(24-27) A00(24-27) A01(24-27) A01(24-27) A00(24-27) A00(24-27) A01(24-27) A01(24-27) A00(24-27) A00(24-27) A01(24-27) A01(24-27) + const __m512i lhs_mat_23_03_sp1 = _mm512_shuffle_epi32(lhs_mat_23_03, (_MM_PERM_ENUM)160); //A02(24-27) A02(24-27) A03(24-27) A03(24-27) A02(24-27) A02(24-27) A03(24-27) A03(24-27) A02(24-27) A02(24-27) A03(24-27) A03(24-27) A02(24-27) A02(24-27) A03(24-27) A03(24-27) + + const __m512i lhs_mat_01_10_sp1 = _mm512_shuffle_epi32(lhs_mat_01_10, (_MM_PERM_ENUM)160); //A10(0-3) A10(0-3) A11(0-3) A11(0-3) A10(0-3) A10(0-3) A11(0-3) A11(0-3) A10(0-3) A10(0-3) A11(0-3) A11(0-3) A10(0-3) A10(0-3) A11(0-3) A11(0-3) + const __m512i lhs_mat_23_10_sp1 = _mm512_shuffle_epi32(lhs_mat_23_10, (_MM_PERM_ENUM)160); //A12(0-3) A12(0-3) A13(0-3) A13(0-3) A12(0-3) A12(0-3) A13(0-3) A13(0-3) A12(0-3) A12(0-3) A13(0-3) A13(0-3) A12(0-3) A12(0-3) A13(0-3) A13(0-3) + const __m512i lhs_mat_01_11_sp1 = _mm512_shuffle_epi32(lhs_mat_01_11, (_MM_PERM_ENUM)160); //A10(8-11) A10(8-11) A11(8-11) A11(8-11) A10(8-11) A10(8-11) A11(8-11) A11(8-11) A10(8-11) A10(8-11) A11(8-11) A11(8-11) A10(8-11) A10(8-11) A11(8-11) A11(8-11) + const __m512i lhs_mat_23_11_sp1 = _mm512_shuffle_epi32(lhs_mat_23_11, (_MM_PERM_ENUM)160); //A12(8-11) A12(8-11) A13(8-11) A13(8-11) A12(8-11) A12(8-11) A13(8-11) A13(8-11) A12(8-11) A12(8-11) A13(8-11) A13(8-11) A12(8-11) A12(8-11) A13(8-11) A13(8-11) + const __m512i lhs_mat_01_12_sp1 = _mm512_shuffle_epi32(lhs_mat_01_12, (_MM_PERM_ENUM)160); //A10(16-19) A10(16-19) A11(16-19) A11(16-19) A10(16-19) A10(16-19) A11(16-19) A11(16-19) A10(16-19) A10(16-19) A11(16-19) A11(16-19) A10(16-19) A10(16-19) A11(16-19) A11(16-19) + const __m512i lhs_mat_23_12_sp1 = _mm512_shuffle_epi32(lhs_mat_23_12, (_MM_PERM_ENUM)160); //A12(16-19) A12(16-19) A13(16-19) A13(16-19) A12(16-19) A12(16-19) A13(16-19) A13(16-19) A12(16-19) A12(16-19) A13(16-19) A13(16-19) A12(16-19) A12(16-19) A13(16-19) A13(16-19) + const __m512i lhs_mat_01_13_sp1 = _mm512_shuffle_epi32(lhs_mat_01_13, (_MM_PERM_ENUM)160); //A10(24-27) A10(24-27) A11(24-27) A11(24-27) A10(24-27) A10(24-27) A11(24-27) A11(24-27) A10(24-27) A10(24-27) A11(24-27) A11(24-27) A10(24-27) A10(24-27) A11(24-27) A11(24-27) + const __m512i lhs_mat_23_13_sp1 = _mm512_shuffle_epi32(lhs_mat_23_13, (_MM_PERM_ENUM)160); //A12(24-27) A12(24-27) A13(24-27) A13(24-27) A12(24-27) A12(24-27) A13(24-27) A13(24-27) A12(24-27) A12(24-27) A13(24-27) A13(24-27) A12(24-27) A12(24-27) A13(24-27) A13(24-27) + + const __m512i lhs_mat_01_00_sp2 = _mm512_shuffle_epi32(lhs_mat_01_00, (_MM_PERM_ENUM)245); //A00(4-7) A00(4-7) A01(4-7) A01(4-7) A00(4-7) A00(4-7) A01(4-7) A01(4-7) A00(4-7) A00(4-7) A01(4-7) A01(4-7) A00(4-7) A00(4-7) A01(4-7) A01(4-7) + const __m512i lhs_mat_23_00_sp2 = _mm512_shuffle_epi32(lhs_mat_23_00, (_MM_PERM_ENUM)245); //A02(4-7) A02(4-7) A03(4-7) A03(4-7) A02(4-7) A02(4-7) A03(4-7) A03(4-7) A02(4-7) A02(4-7) A03(4-7) A03(4-7) A02(4-7) A02(4-7) A03(4-7) A03(4-7) + const __m512i lhs_mat_01_01_sp2 = _mm512_shuffle_epi32(lhs_mat_01_01, (_MM_PERM_ENUM)245); //A00(12-15) A00(12-15) A01(12-15) A01(12-15) A00(12-15) A00(12-15) A01(12-15) A01(12-15) A00(12-15) A00(12-15) A01(12-15) A01(12-15) A00(12-15) A00(12-15) A01(12-15) A01(12-15) + const 
__m512i lhs_mat_23_01_sp2 = _mm512_shuffle_epi32(lhs_mat_23_01, (_MM_PERM_ENUM)245); //A02(12-15) A02(12-15) A03(12-15) A03(12-15) A02(12-15) A02(12-15) A03(12-15) A03(12-15) A02(12-15) A02(12-15) A03(12-15) A03(12-15) A02(12-15) A02(12-15) A03(12-15) A03(12-15) + const __m512i lhs_mat_01_02_sp2 = _mm512_shuffle_epi32(lhs_mat_01_02, (_MM_PERM_ENUM)245); //A00(20-23) A00(20-23) A01(20-23) A01(20-23) A00(20-23) A00(20-23) A01(20-23) A01(20-23) A00(20-23) A00(20-23) A01(20-23) A01(20-23) A00(20-23) A00(20-23) A01(20-23) A01(20-23) + const __m512i lhs_mat_23_02_sp2 = _mm512_shuffle_epi32(lhs_mat_23_02, (_MM_PERM_ENUM)245); //A02(20-23) A02(20-23) A03(20-23) A03(20-23) A02(20-23) A02(20-23) A03(20-23) A03(20-23) A02(20-23) A02(20-23) A03(20-23) A03(20-23) A02(20-23) A02(20-23) A03(20-23) A03(20-23) + const __m512i lhs_mat_01_03_sp2 = _mm512_shuffle_epi32(lhs_mat_01_03, (_MM_PERM_ENUM)245); //A00(28-31) A00(28-31) A01(28-31) A01(28-31) A00(28-31) A00(28-31) A01(28-31) A01(28-31) A00(28-31) A00(28-31) A01(28-31) A01(28-31) A00(28-31) A00(28-31) A01(28-31) A01(28-31) + const __m512i lhs_mat_23_03_sp2 = _mm512_shuffle_epi32(lhs_mat_23_03, (_MM_PERM_ENUM)245); //A02(28-31) A02(28-31) A03(28-31) A03(28-31) A02(28-31) A02(28-31) A03(28-31) A03(28-31) A02(28-31) A02(28-31) A03(28-31) A03(28-31) A02(28-31) A02(28-31) A03(28-31) A03(28-31) + + const __m512i lhs_mat_01_10_sp2 = _mm512_shuffle_epi32(lhs_mat_01_10, (_MM_PERM_ENUM)245); //A10(4-7) A10(4-7) A11(4-7) A11(4-7) A10(4-7) A10(4-7) A11(4-7) A11(4-7) A10(4-7) A10(4-7) A11(4-7) A11(4-7) A10(4-7) A10(4-7) A11(4-7) A11(4-7) + const __m512i lhs_mat_23_10_sp2 = _mm512_shuffle_epi32(lhs_mat_23_10, (_MM_PERM_ENUM)245); //A12(4-7) A12(4-7) A13(4-7) A13(4-7) A12(4-7) A12(4-7) A13(4-7) A13(4-7) A12(4-7) A12(4-7) A13(4-7) A13(4-7) A12(4-7) A12(4-7) A13(4-7) A13(4-7) + const __m512i lhs_mat_01_11_sp2 = _mm512_shuffle_epi32(lhs_mat_01_11, (_MM_PERM_ENUM)245); //A10(12-15) A10(12-15) A11(12-15) A11(12-15) A10(12-15) A10(12-15) A11(12-15) A11(12-15) A10(12-15) A10(12-15) A11(12-15) A11(12-15) A10(12-15) A10(12-15) A11(12-15) A11(12-15) + const __m512i lhs_mat_23_11_sp2 = _mm512_shuffle_epi32(lhs_mat_23_11, (_MM_PERM_ENUM)245); //A12(12-15) A12(12-15) A13(12-15) A13(12-15) A12(12-15) A12(12-15) A13(12-15) A13(12-15) A12(12-15) A12(12-15) A13(12-15) A13(12-15) A12(12-15) A12(12-15) A13(12-15) A13(12-15) + const __m512i lhs_mat_01_12_sp2 = _mm512_shuffle_epi32(lhs_mat_01_12, (_MM_PERM_ENUM)245); //A10(20-23) A10(20-23) A11(20-23) A11(20-23) A10(20-23) A10(20-23) A11(20-23) A11(20-23) A10(20-23) A10(20-23) A11(20-23) A11(20-23) A10(20-23) A10(20-23) A11(20-23) A11(20-23) + const __m512i lhs_mat_23_12_sp2 = _mm512_shuffle_epi32(lhs_mat_23_12, (_MM_PERM_ENUM)245); //A12(20-23) A12(20-23) A13(20-23) A13(20-23) A12(20-23) A12(20-23) A13(20-23) A13(20-23) A12(20-23) A12(20-23) A13(20-23) A13(20-23) A12(20-23) A12(20-23) A13(20-23) A13(20-23) + const __m512i lhs_mat_01_13_sp2 = _mm512_shuffle_epi32(lhs_mat_01_13, (_MM_PERM_ENUM)245); //A10(28-31) A10(28-31) A11(28-31) A11(28-31) A10(28-31) A10(28-31) A11(28-31) A11(28-31) A10(28-31) A10(28-31) A11(28-31) A11(28-31) A10(28-31) A10(28-31) A11(28-31) A11(28-31) + const __m512i lhs_mat_23_13_sp2 = _mm512_shuffle_epi32(lhs_mat_23_13, (_MM_PERM_ENUM)245); //A12(28-31) A12(28-31) A13(28-31) A13(28-31) A12(28-31) A12(28-31) A13(28-31) A13(28-31) A12(28-31) A12(28-31) A13(28-31) A13(28-31) A12(28-31) A12(28-31) A13(28-31) A13(28-31) + + // The values arranged in shuffle patterns are operated with dot product operation within 32 bit 
lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane + __m512i iacc_mat_00_0_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_03_sp1, lhs_mat_01_03_sp1), _mm512_maddubs_epi16(rhs_mat_014589CD_02_sp1, lhs_mat_01_02_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_01_sp1, lhs_mat_01_01_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_00_sp1, lhs_mat_01_00_sp1)); + __m512i iacc_mat_01_0_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_03_sp1, lhs_mat_01_03_sp1), _mm512_maddubs_epi16(rhs_mat_2367ABEF_02_sp1, lhs_mat_01_02_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp1, lhs_mat_01_01_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp1, lhs_mat_01_00_sp1)); + __m512i iacc_mat_10_0_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_03_sp1, lhs_mat_23_03_sp1), _mm512_maddubs_epi16(rhs_mat_014589CD_02_sp1, lhs_mat_23_02_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_01_sp1, lhs_mat_23_01_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_00_sp1, lhs_mat_23_00_sp1)); + __m512i iacc_mat_11_0_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_03_sp1, lhs_mat_23_03_sp1), _mm512_maddubs_epi16(rhs_mat_2367ABEF_02_sp1, lhs_mat_23_02_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp1, lhs_mat_23_01_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp1, lhs_mat_23_00_sp1)); + __m512i iacc_mat_00_1_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_13_sp1, lhs_mat_01_13_sp1), _mm512_maddubs_epi16(rhs_mat_014589CD_12_sp1, lhs_mat_01_12_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_11_sp1, lhs_mat_01_11_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_10_sp1, lhs_mat_01_10_sp1)); + __m512i iacc_mat_01_1_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_13_sp1, lhs_mat_01_13_sp1), _mm512_maddubs_epi16(rhs_mat_2367ABEF_12_sp1, lhs_mat_01_12_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp1, lhs_mat_01_11_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp1, lhs_mat_01_10_sp1)); + __m512i iacc_mat_10_1_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_13_sp1, lhs_mat_23_13_sp1), _mm512_maddubs_epi16(rhs_mat_014589CD_12_sp1, lhs_mat_23_12_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_11_sp1, lhs_mat_23_11_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_10_sp1, lhs_mat_23_10_sp1)); + __m512i iacc_mat_11_1_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_13_sp1, lhs_mat_23_13_sp1), _mm512_maddubs_epi16(rhs_mat_2367ABEF_12_sp1, lhs_mat_23_12_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp1, lhs_mat_23_11_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp1, lhs_mat_23_10_sp1)); + + __m512i iacc_mat_00_0_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_03_sp2, lhs_mat_01_03_sp2), _mm512_maddubs_epi16(rhs_mat_014589CD_02_sp2, lhs_mat_01_02_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_01_sp2, lhs_mat_01_01_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_00_sp2, lhs_mat_01_00_sp2)); + __m512i iacc_mat_01_0_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_03_sp2, lhs_mat_01_03_sp2), _mm512_maddubs_epi16(rhs_mat_2367ABEF_02_sp2, lhs_mat_01_02_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp2, lhs_mat_01_01_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp2, 
lhs_mat_01_00_sp2));
+                        __m512i iacc_mat_10_0_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_03_sp2, lhs_mat_23_03_sp2), _mm512_maddubs_epi16(rhs_mat_014589CD_02_sp2, lhs_mat_23_02_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_01_sp2, lhs_mat_23_01_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_00_sp2, lhs_mat_23_00_sp2));
+                        __m512i iacc_mat_11_0_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_03_sp2, lhs_mat_23_03_sp2), _mm512_maddubs_epi16(rhs_mat_2367ABEF_02_sp2, lhs_mat_23_02_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp2, lhs_mat_23_01_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp2, lhs_mat_23_00_sp2));
+                        __m512i iacc_mat_00_1_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_13_sp2, lhs_mat_01_13_sp2), _mm512_maddubs_epi16(rhs_mat_014589CD_12_sp2, lhs_mat_01_12_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_11_sp2, lhs_mat_01_11_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_10_sp2, lhs_mat_01_10_sp2));
+                        __m512i iacc_mat_01_1_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_13_sp2, lhs_mat_01_13_sp2), _mm512_maddubs_epi16(rhs_mat_2367ABEF_12_sp2, lhs_mat_01_12_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp2, lhs_mat_01_11_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp2, lhs_mat_01_10_sp2));
+                        __m512i iacc_mat_10_1_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_13_sp2, lhs_mat_23_13_sp2), _mm512_maddubs_epi16(rhs_mat_014589CD_12_sp2, lhs_mat_23_12_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_11_sp2, lhs_mat_23_11_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_10_sp2, lhs_mat_23_10_sp2));
+                        __m512i iacc_mat_11_1_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_13_sp2, lhs_mat_23_13_sp2), _mm512_maddubs_epi16(rhs_mat_2367ABEF_12_sp2, lhs_mat_23_12_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp2, lhs_mat_23_11_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp2, lhs_mat_23_10_sp2));
+
+                        // Outputs of both shuffle patterns are added in order to sum the dot products of all 32 values in the block
+                        __m512i iacc_mat_00_0 = _mm512_add_epi16(iacc_mat_00_0_sp1, iacc_mat_00_0_sp2);
+                        __m512i iacc_mat_01_0 = _mm512_add_epi16(iacc_mat_01_0_sp1, iacc_mat_01_0_sp2);
+                        __m512i iacc_mat_10_0 = _mm512_add_epi16(iacc_mat_10_0_sp1, iacc_mat_10_0_sp2);
+                        __m512i iacc_mat_11_0 = _mm512_add_epi16(iacc_mat_11_0_sp1, iacc_mat_11_0_sp2);
+
+                        __m512i iacc_mat_00_1 = _mm512_add_epi16(iacc_mat_00_1_sp1, iacc_mat_00_1_sp2);
+                        __m512i iacc_mat_01_1 = _mm512_add_epi16(iacc_mat_01_1_sp1, iacc_mat_01_1_sp2);
+                        __m512i iacc_mat_10_1 = _mm512_add_epi16(iacc_mat_10_1_sp1, iacc_mat_10_1_sp2);
+                        __m512i iacc_mat_11_1 = _mm512_add_epi16(iacc_mat_11_1_sp1, iacc_mat_11_1_sp2);
+
+                        iacc_mat_00_0 = _mm512_madd_epi16(iacc_mat_00_0, scale_014589CD_0);
+                        iacc_mat_01_0 = _mm512_madd_epi16(iacc_mat_01_0, scale_2367ABEF_0);
+                        iacc_mat_10_0 = _mm512_madd_epi16(iacc_mat_10_0, scale_014589CD_0);
+                        iacc_mat_11_0 = _mm512_madd_epi16(iacc_mat_11_0, scale_2367ABEF_0);
+
+                        iacc_mat_00_1 = _mm512_madd_epi16(iacc_mat_00_1, scale_014589CD_1);
+                        iacc_mat_01_1 = _mm512_madd_epi16(iacc_mat_01_1, scale_2367ABEF_1);
+                        iacc_mat_10_1 = _mm512_madd_epi16(iacc_mat_10_1, scale_014589CD_1);
+                        iacc_mat_11_1 = _mm512_madd_epi16(iacc_mat_11_1, scale_2367ABEF_1);
+
+                        // Straighten out to make 4 row vectors (4 for each sub block which are accumulated together in the next step)
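+                        // Note: blend mask 0xCCCC keeps the low two 32-bit lanes of every group of four from the first tile and takes the
+                        // high two from the half-swapped second tile (shuffle control 78 swaps the 64-bit halves of each 128-bit lane),
+                        // interleaving the 014589CD and 2367ABEF column groups back into natural column order per row.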
+                        __m512i iacc_row_0_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00_0, _mm512_shuffle_epi32(iacc_mat_01_0, (_MM_PERM_ENUM)78));
+                        __m512i iacc_row_1_0 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00_0, (_MM_PERM_ENUM)78), iacc_mat_01_0);
+                        __m512i iacc_row_2_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10_0, _mm512_shuffle_epi32(iacc_mat_11_0, (_MM_PERM_ENUM)78));
+                        __m512i iacc_row_3_0 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10_0, (_MM_PERM_ENUM)78), iacc_mat_11_0);
+                        __m512i iacc_row_0_1 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00_1, _mm512_shuffle_epi32(iacc_mat_01_1, (_MM_PERM_ENUM)78));
+                        __m512i iacc_row_1_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00_1, (_MM_PERM_ENUM)78), iacc_mat_01_1);
+                        __m512i iacc_row_2_1 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10_1, _mm512_shuffle_epi32(iacc_mat_11_1, (_MM_PERM_ENUM)78));
+                        __m512i iacc_row_3_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10_1, (_MM_PERM_ENUM)78), iacc_mat_11_1);
+
+                        __m512i iacc_row_0 = _mm512_add_epi32(iacc_row_0_0, iacc_row_0_1);
+                        __m512i iacc_row_1 = _mm512_add_epi32(iacc_row_1_0, iacc_row_1_1);
+                        __m512i iacc_row_2 = _mm512_add_epi32(iacc_row_2_0, iacc_row_2_1);
+                        __m512i iacc_row_3 = _mm512_add_epi32(iacc_row_3_0, iacc_row_3_1);
+
+                        // Load the scale (d) values for all the 4 Q8_K blocks and repeat them across lanes
+                        const __m128 row_scale_f32_sse = _mm_load_ps(a_ptrs[rp][b].d);
+                        const __m256 row_scale_f32_ymm = _mm256_set_m128(row_scale_f32_sse, row_scale_f32_sse);
+                        const __m512 row_scale_f32 = _mm512_insertf32x8(_mm512_castps256_ps512(row_scale_f32_ymm), row_scale_f32_ymm, 1);
+
+                        // Multiply with appropriate scales and accumulate (for both d and dmin) below
+                        acc_rows[rp * 4] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_0), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[rp * 4]);
+                        acc_rows[rp * 4 + 1] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_1), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[rp * 4 + 1]);
+                        acc_rows[rp * 4 + 2] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_2), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[rp * 4 + 2]);
+                        acc_rows[rp * 4 + 3] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_3), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[rp * 4 + 3]);
+
+                        __m512i iacc_row_min_0 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_hsum_0123_01, (_MM_PERM_ENUM)0), mins_01);
+                        __m512i iacc_row_min_1 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_hsum_0123_01, (_MM_PERM_ENUM)85), mins_01);
+                        __m512i iacc_row_min_2 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_hsum_0123_01, (_MM_PERM_ENUM)170), mins_01);
+                        __m512i iacc_row_min_3 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_hsum_0123_01, (_MM_PERM_ENUM)255), mins_01);
+
+                        acc_min_rows[rp * 4] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_min_0), _mm512_mul_ps(col_dmin_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_min_rows[rp * 4]);
+                        acc_min_rows[rp * 4 + 1] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_min_1), _mm512_mul_ps(col_dmin_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_min_rows[rp * 4 + 1]);
+                        acc_min_rows[rp * 4 + 2] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_min_2), _mm512_mul_ps(col_dmin_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_min_rows[rp * 4 + 2]);
+                        acc_min_rows[rp * 4 + 3] =
_mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_min_3), _mm512_mul_ps(col_dmin_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_min_rows[rp * 4 + 3]); + } + } + } + // Store the accumulated values + for (int i = 0; i < 16; i++) { + _mm512_storeu_ps((float * )(s + ((y * 4 + i) * bs + x * 8)), _mm512_sub_ps(acc_rows[i], acc_min_rows[i])); + } + } + } + + for (; y < nr / 4; y++) { + + const block_q8_Kx4 * a_ptr = a_ptr_start + (y * nb); + + // Take group of eight block_q4_kx8 structures at each pass of the loop and perform dot product operation + for (int64_t x = 0; x < anc / 8; x += 2) { + + const block_q4_Kx8 * b_ptr_0 = b_ptr_start + ((x) * b_nb); + const block_q4_Kx8 * b_ptr_1 = b_ptr_start + ((x + 1) * b_nb); + + // Master FP accumulators + __m512 acc_rows[4]; + for (int i = 0; i < 4; i++) { + acc_rows[i] = _mm512_setzero_ps(); + } + + __m512 acc_min_rows[4]; + for (int i = 0; i < 4; i++) { + acc_min_rows[i] = _mm512_setzero_ps(); + } + + // For super block + for (int64_t b = 0; b < nb; b++) { + // Scale values - Load the sixteen scale values from two block_q4_kx8 structures + const __m512 col_scale_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].d, b_ptr_1[b].d); + + // dmin values - Load the sixteen dmin values from two block_q4_kx8 structures + const __m512 col_dmin_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].dmin, b_ptr_1[b].dmin); + + // Loop to iterate over the eight sub blocks of a super block - two sub blocks are processed per iteration + for (int sb = 0; sb < QK_K / 64; sb++) { + + const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + sb * 256)); + const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 32 + sb * 256)); + const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 64 + sb * 256)); + const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 96 + sb * 256)); + const __m256i rhs_raw_mat_0123_2 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 128 + sb * 256)); + const __m256i rhs_raw_mat_4567_2 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 160 + sb * 256)); + const __m256i rhs_raw_mat_0123_3 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 192 + sb * 256)); + const __m256i rhs_raw_mat_4567_3 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 224 + sb * 256)); + + const __m256i rhs_raw_mat_89AB_0 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + sb * 256)); + const __m256i rhs_raw_mat_CDEF_0 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 32 + sb * 256)); + const __m256i rhs_raw_mat_89AB_1 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 64 + sb * 256)); + const __m256i rhs_raw_mat_CDEF_1 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 96 + sb * 256)); + const __m256i rhs_raw_mat_89AB_2 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 128 + sb * 256)); + const __m256i rhs_raw_mat_CDEF_2 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 160 + sb * 256)); + const __m256i rhs_raw_mat_89AB_3 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 192 + sb * 256)); + const __m256i rhs_raw_mat_CDEF_3 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 224 + sb * 256)); + + const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240); + const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240); + const __m256i 
rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240); + const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240); + const __m256i rhs_raw_mat_0145_2 = _mm256_blend_epi32(rhs_raw_mat_0123_2, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_2, requiredOrder), 240); + const __m256i rhs_raw_mat_2367_2 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_2, requiredOrder), rhs_raw_mat_4567_2, 240); + const __m256i rhs_raw_mat_0145_3 = _mm256_blend_epi32(rhs_raw_mat_0123_3, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_3, requiredOrder), 240); + const __m256i rhs_raw_mat_2367_3 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_3, requiredOrder), rhs_raw_mat_4567_3, 240); + + const __m256i rhs_raw_mat_89CD_0 = _mm256_blend_epi32(rhs_raw_mat_89AB_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_0, requiredOrder), 240); + const __m256i rhs_raw_mat_ABEF_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_0, requiredOrder), rhs_raw_mat_CDEF_0, 240); + const __m256i rhs_raw_mat_89CD_1 = _mm256_blend_epi32(rhs_raw_mat_89AB_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_1, requiredOrder), 240); + const __m256i rhs_raw_mat_ABEF_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_1, requiredOrder), rhs_raw_mat_CDEF_1, 240); + const __m256i rhs_raw_mat_89CD_2 = _mm256_blend_epi32(rhs_raw_mat_89AB_2, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_2, requiredOrder), 240); + const __m256i rhs_raw_mat_ABEF_2 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_2, requiredOrder), rhs_raw_mat_CDEF_2, 240); + const __m256i rhs_raw_mat_89CD_3 = _mm256_blend_epi32(rhs_raw_mat_89AB_3, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_3, requiredOrder), 240); + const __m256i rhs_raw_mat_ABEF_3 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_3, requiredOrder), rhs_raw_mat_CDEF_3, 240); + + const __m512i rhs_raw_mat_014589CD_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_0), rhs_raw_mat_89CD_0, 1); + const __m512i rhs_raw_mat_2367ABEF_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_0), rhs_raw_mat_ABEF_0, 1); + const __m512i rhs_raw_mat_014589CD_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_1), rhs_raw_mat_89CD_1, 1); + const __m512i rhs_raw_mat_2367ABEF_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_1), rhs_raw_mat_ABEF_1, 1); + + const __m512i rhs_raw_mat_014589CD_2 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_2), rhs_raw_mat_89CD_2, 1); + const __m512i rhs_raw_mat_2367ABEF_2 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_2), rhs_raw_mat_ABEF_2, 1); + const __m512i rhs_raw_mat_014589CD_3 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_3), rhs_raw_mat_89CD_3, 1); + const __m512i rhs_raw_mat_2367ABEF_3 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_3), rhs_raw_mat_ABEF_3, 1); + + //4-bit -> 8-bit + const __m512i rhs_mat_014589CD_00 = _mm512_and_si512(rhs_raw_mat_014589CD_0, m4bexpanded); //B00(0-7) B01(0-7) B04(0-7) B05(0-7) B08(0-7) B09(0-7) B0C(0-7) B0D(0-7) + const __m512i rhs_mat_2367ABEF_00 = _mm512_and_si512(rhs_raw_mat_2367ABEF_0, m4bexpanded); //B02(0-7) B03(0-7) B06(0-7) B07(0-7) B0A(0-7) B0B(0-7) B0E(0-7) B0F(0-7) + const __m512i rhs_mat_014589CD_01 = _mm512_and_si512(rhs_raw_mat_014589CD_1, m4bexpanded); //B00(8-15) B01(8-15) B04(8-15) B05(8-15) 
B08(8-15) B09(8-15) B0C(8-15) B0D(8-15) + const __m512i rhs_mat_2367ABEF_01 = _mm512_and_si512(rhs_raw_mat_2367ABEF_1, m4bexpanded); //B02(8-15) B03(8-15) B06(8-15) B07(8-15) B0A(8-15) B0B(8-15) B0E(8-15) B0F(8-15) + + const __m512i rhs_mat_014589CD_02 = _mm512_and_si512(rhs_raw_mat_014589CD_2, m4bexpanded); //B00(16-23) B01(16-23) B04(16-23) B05(16-23) B08(16-23) B09(16-23) B0C(16-23) B0D(16-23) + const __m512i rhs_mat_2367ABEF_02 = _mm512_and_si512(rhs_raw_mat_2367ABEF_2, m4bexpanded); //B02(16-23) B03(16-23) B06(16-23) B07(16-23) B0A(16-23) B0B(16-23) B0E(16-23) B0F(16-23) + const __m512i rhs_mat_014589CD_03 = _mm512_and_si512(rhs_raw_mat_014589CD_3, m4bexpanded); //B00(24-31) B01(24-31) B04(24-31) B05(24-31) B08(24-31) B09(24-31) B0C(24-31) B0D(24-31) + const __m512i rhs_mat_2367ABEF_03 = _mm512_and_si512(rhs_raw_mat_2367ABEF_3, m4bexpanded); //B02(24-31) B03(24-31) B06(24-31) B07(24-31) B0A(24-31) B0B(24-31) B0E(24-31) B0F(24-31) + + const __m512i rhs_mat_014589CD_10 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_0, 4), m4bexpanded); //B10(0-7) B11(0-7) B14(0-7) B15(0-7) B18(0-7) B19(0-7) B1C(0-7) B1D(0-7) + const __m512i rhs_mat_2367ABEF_10 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_0, 4), m4bexpanded); //B12(0-7) B13(0-7) B16(0-7) B17(0-7) B1A(0-7) B1B(0-7) B1E(0-7) B1F(0-7) + const __m512i rhs_mat_014589CD_11 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_1, 4), m4bexpanded); //B10(8-15) B11(8-15) B14(8-15) B15(8-15) B18(8-15) B19(8-15) B1C(8-15) B1D(8-15) + const __m512i rhs_mat_2367ABEF_11 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 4), m4bexpanded); //B12(8-15) B13(8-15) B16(8-15) B17(8-15) B1A(8-15) B1B(8-15) B1E(8-15) B1F(8-15) + + const __m512i rhs_mat_014589CD_12 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_2, 4), m4bexpanded); //B10(16-23) B11(16-23) B14(16-23) B15(16-23) B18(16-23) B19(16-23) B1C(16-23) B1D(16-23) + const __m512i rhs_mat_2367ABEF_12 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_2, 4), m4bexpanded); //B12(16-23) B13(16-23) B16(16-23) B17(16-23) B1A(16-23) B1B(16-23) B1E(16-23) B1F(16-23) + const __m512i rhs_mat_014589CD_13 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_3, 4), m4bexpanded); //B10(24-31) B11(24-31) B14(24-31) B15(24-31) B18(24-31) B19(24-31) B1C(24-31) B1D(24-31) + const __m512i rhs_mat_2367ABEF_13 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_3, 4), m4bexpanded); //B12(24-31) B13(24-31) B16(24-31) B17(24-31) B1A(24-31) B1B(24-31) B1E(24-31) B1F(24-31) + + // Shuffle pattern one - right side input + const __m512i rhs_mat_014589CD_00_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_00, (_MM_PERM_ENUM)136); //B00(0-3) B01(0-3) B00(0-3) B01(0-3) B04(0-3) B05(0-3) B04(0-3) B05(0-3) B08(0-3) B09(0-3) B08(0-3) B09(0-3) B0C(0-3) B0D(0-3) B0C(0-3) B0D(0-3) + const __m512i rhs_mat_2367ABEF_00_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_00, (_MM_PERM_ENUM)136); //B02(0-3) B03(0-3) B02(0-3) B03(0-3) B06(0-3) B07(0-3) B06(0-3) B07(0-3) B0A(0-3) B0B(0-3) B0A(0-3) B0B(0-3) B0E(0-3) B0F(0-3) B0E(0-3) B0F(0-3) + const __m512i rhs_mat_014589CD_01_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_01, (_MM_PERM_ENUM)136); //B00(8-11) B01(8-11) B00(8-11) B01(8-11) B04(8-11) B05(8-11) B04(8-11) B05(8-11) B08(8-11) B09(8-11) B08(8-11) B09(8-11) B0C(8-11) B0D(8-11) B0C(8-11) B0D(8-11) + const __m512i rhs_mat_2367ABEF_01_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_01, (_MM_PERM_ENUM)136); //B02(8-11) B03(8-11) B02(8-11) B03(8-11) B06(8-11) B07(8-11) B06(8-11) B07(8-11) 
B0A(8-11) B0B(8-11) B0A(8-11) B0B(8-11) B0E(8-11) B0F(8-11) B0E(8-11) B0F(8-11) + const __m512i rhs_mat_014589CD_02_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_02, (_MM_PERM_ENUM)136); //B00(16-19) B01(16-19) B00(16-19) B01(16-19) B04(16-19) B05(16-19) B04(16-19) B05(16-19) B08(16-19) B09(16-19) B08(16-19) B09(16-19) B0C(16-19) B0D(16-19) B0C(16-19) B0D(16-19) + const __m512i rhs_mat_2367ABEF_02_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_02, (_MM_PERM_ENUM)136); //B02(16-19) B03(16-19) B02(16-19) B03(16-19) B06(16-19) B07(16-19) B06(16-19) B07(16-19) B0A(16-19) B0B(16-19) B0A(16-19) B0B(16-19) B0E(16-19) B0F(16-19) B0E(16-19) B0F(16-19) + const __m512i rhs_mat_014589CD_03_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_03, (_MM_PERM_ENUM)136); //B00(24-27) B01(24-27) B00(24-27) B01(24-27) B04(24-27) B05(24-27) B04(24-27) B05(24-27) B08(24-27) B09(24-27) B08(24-27) B09(24-27) B0C(24-27) B0D(24-27) B0C(24-27) B0D(24-27) + const __m512i rhs_mat_2367ABEF_03_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_03, (_MM_PERM_ENUM)136); //B02(24-27) B03(24-27) B02(24-27) B03(24-27) B06(24-27) B07(24-27) B06(24-27) B07(24-27) B0A(24-27) B0B(24-27) B0A(24-27) B0B(24-27) B0E(24-27) B0F(24-27) B0E(24-27) B0F(24-27) + + const __m512i rhs_mat_014589CD_10_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_10, (_MM_PERM_ENUM)136); //B10(0-3) B11(0-3) B10(0-3) B11(0-3) B14(0-3) B15(0-3) B14(0-3) B15(0-3) B18(0-3) B19(0-3) B18(0-3) B19(0-3) B1C(0-3) B1D(0-3) B1C(0-3) B1D(0-3) + const __m512i rhs_mat_2367ABEF_10_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_10, (_MM_PERM_ENUM)136); //B12(0-3) B13(0-3) B12(0-3) B13(0-3) B16(0-3) B17(0-3) B16(0-3) B17(0-3) B1A(0-3) B1B(0-3) B1A(0-3) B1B(0-3) B1E(0-3) B1F(0-3) B1E(0-3) B1F(0-3) + const __m512i rhs_mat_014589CD_11_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_11, (_MM_PERM_ENUM)136); //B10(8-11) B11(8-11) B10(8-11) B11(8-11) B14(8-11) B15(8-11) B14(8-11) B15(8-11) B18(8-11) B19(8-11) B18(8-11) B19(8-11) B1C(8-11) B1D(8-11) B1C(8-11) B1D(8-11) + const __m512i rhs_mat_2367ABEF_11_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_11, (_MM_PERM_ENUM)136); //B12(8-11) B13(8-11) B12(8-11) B13(8-11) B16(8-11) B17(8-11) B16(8-11) B17(8-11) B1A(8-11) B1B(8-11) B1A(8-11) B1B(8-11) B1E(8-11) B1F(8-11) B1E(8-11) B1F(8-11) + const __m512i rhs_mat_014589CD_12_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_12, (_MM_PERM_ENUM)136); //B10(16-19) B11(16-19) B10(16-19) B11(16-19) B14(16-19) B15(16-19) B14(16-19) B15(16-19) B18(16-19) B19(16-19) B18(16-19) B19(16-19) B1C(16-19) B1D(16-19) B1C(16-19) B1D(16-19) + const __m512i rhs_mat_2367ABEF_12_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_12, (_MM_PERM_ENUM)136); //B12(16-19) B13(16-19) B12(16-19) B13(16-19) B16(16-19) B17(16-19) B16(16-19) B17(16-19) B1A(16-19) B1B(16-19) B1A(16-19) B1B(16-19) B1E(16-19) B1F(16-19) B1E(16-19) B1F(16-19) + const __m512i rhs_mat_014589CD_13_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_13, (_MM_PERM_ENUM)136); //B10(24-27) B11(24-27) B10(24-27) B11(24-27) B14(24-27) B15(24-27) B14(24-27) B15(24-27) B18(24-27) B19(24-27) B18(24-27) B19(24-27) B1C(24-27) B1D(24-27) B1C(24-27) B1D(24-27) + const __m512i rhs_mat_2367ABEF_13_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_13, (_MM_PERM_ENUM)136); //B12(24-27) B13(24-27) B12(24-27) B13(24-27) B16(24-27) B17(24-27) B16(24-27) B17(24-27) B1A(24-27) B1B(24-27) B1A(24-27) B1B(24-27) B1E(24-27) B1F(24-27) B1E(24-27) B1F(24-27) + + // Shuffle pattern two - right side input + const __m512i rhs_mat_014589CD_00_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_00, (_MM_PERM_ENUM)221); //B00(4-7) B01(4-7) 
B00(4-7) B01(4-7) B04(4-7) B05(4-7) B04(4-7) B05(4-7) B08(4-7) B09(4-7) B08(4-7) B09(4-7) B0C(4-7) B0D(4-7) B0C(4-7) B0D(4-7) + const __m512i rhs_mat_2367ABEF_00_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_00, (_MM_PERM_ENUM)221); //B02(4-7) B03(4-7) B02(4-7) B03(4-7) B06(4-7) B07(4-7) B06(4-7) B07(4-7) B0A(4-7) B0B(4-7) B0A(4-7) B0B(4-7) B0E(4-7) B0F(4-7) B0E(4-7) B0F(4-7) + const __m512i rhs_mat_014589CD_01_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_01, (_MM_PERM_ENUM)221); //B00(12-15) B01(12-15) B00(12-15) B01(12-15) B04(12-15) B05(12-15) B04(12-15) B05(12-15) B08(12-15) B09(12-15) B08(12-15) B09(12-15) B0C(12-15) B0D(12-15) B0C(12-15) B0D(12-15) + const __m512i rhs_mat_2367ABEF_01_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_01, (_MM_PERM_ENUM)221); //B02(12-15) B03(12-15) B02(12-15) B03(12-15) B06(12-15) B07(12-15) B06(12-15) B07(12-15) B0A(12-15) B0B(12-15) B0A(12-15) B0B(12-15) B0E(12-15) B0F(12-15) B0E(12-15) B0F(12-15) + const __m512i rhs_mat_014589CD_02_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_02, (_MM_PERM_ENUM)221); //B00(20-23) B01(20-23) B00(20-23) B01(20-23) B04(20-23) B05(20-23) B04(20-23) B05(20-23) B08(20-23) B09(20-23) B08(20-23) B09(20-23) B0C(20-23) B0D(20-23) B0C(20-23) B0D(20-23) + const __m512i rhs_mat_2367ABEF_02_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_02, (_MM_PERM_ENUM)221); //B02(20-23) B03(20-23) B02(20-23) B03(20-23) B06(20-23) B07(20-23) B06(20-23) B07(20-23) B0A(20-23) B0B(20-23) B0A(20-23) B0B(20-23) B0E(20-23) B0F(20-23) B0E(20-23) B0F(20-23) + const __m512i rhs_mat_014589CD_03_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_03, (_MM_PERM_ENUM)221); //B00(28-31) B01(28-31) B00(28-31) B01(28-31) B04(28-31) B05(28-31) B04(28-31) B05(28-31) B08(28-31) B09(28-31) B08(28-31) B09(28-31) B0C(28-31) B0D(28-31) B0C(28-31) 0BD(28-31) + const __m512i rhs_mat_2367ABEF_03_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_03, (_MM_PERM_ENUM)221); //B02(28-31) B03(28-31) B02(28-31) B03(28-31) B06(28-31) B07(28-31) B06(28-31) B07(28-31) B0A(28-31) B0B(28-31) B0A(28-31) B0B(28-31) B0E(28-31) B0F(28-31) B0E(28-31) B0F(28-31) + + const __m512i rhs_mat_014589CD_10_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_10, (_MM_PERM_ENUM)221); //B10(4-7) B11(4-7) B10(4-7) B11(4-7) B14(4-7) B15(4-7) B14(4-7) B15(4-7) B18(4-7) B19(4-7) B18(4-7) B19(4-7) B1C(4-7) B1D(4-7) B1C(4-7) B1D(4-7) + const __m512i rhs_mat_2367ABEF_10_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_10, (_MM_PERM_ENUM)221); //B12(4-7) B13(4-7) B12(4-7) B13(4-7) B16(4-7) B17(4-7) B16(4-7) B17(4-7) B1A(4-7) B1B(4-7) B1A(4-7) B1B(4-7) B1E(4-7) B1F(4-7) B1E(4-7) B1F(4-7) + const __m512i rhs_mat_014589CD_11_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_11, (_MM_PERM_ENUM)221); //B10(12-15) B11(12-15) B10(12-15) B11(12-15) B14(12-15) B15(12-15) B14(12-15) B15(12-15) B18(12-15) B19(12-15) B18(12-15) B19(12-15) B1C(12-15) B1D(12-15) B1C(12-15) B1D(12-15) + const __m512i rhs_mat_2367ABEF_11_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_11, (_MM_PERM_ENUM)221); //B12(12-15) B13(12-15) B12(12-15) B13(12-15) B16(12-15) B17(12-15) B16(12-15) B17(12-15) B1A(12-15) B1B(12-15) B1A(12-15) B1B(12-15) B1E(12-15) B1F(12-15) B1E(12-15) B1F(12-15) + const __m512i rhs_mat_014589CD_12_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_12, (_MM_PERM_ENUM)221); //B10(20-23) B11(20-23) B10(20-23) B11(20-23) B14(20-23) B15(20-23) B14(20-23) B15(20-23) B18(20-23) B19(20-23) B18(20-23) B19(20-23) B1C(20-23) B1D(20-23) B1C(20-23) B1D(20-23) + const __m512i rhs_mat_2367ABEF_12_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_12, (_MM_PERM_ENUM)221); //B12(20-23) 
B13(20-23) B12(20-23) B13(20-23) B16(20-23) B17(20-23) B16(20-23) B17(20-23) B1A(20-23) B1B(20-23) B1A(20-23) B1B(20-23) B1E(20-23) B1F(20-23) B1E(20-23) B1F(20-23) + const __m512i rhs_mat_014589CD_13_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_13, (_MM_PERM_ENUM)221); //B10(28-31) B11(28-31) B10(28-31) B11(28-31) B14(28-31) B15(28-31) B14(28-31) B15(28-31) B18(28-31) B19(28-31) B18(28-31) B19(28-31) B1C(28-31) B1D(28-31) B1C(28-31) B1D(28-31) + const __m512i rhs_mat_2367ABEF_13_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_13, (_MM_PERM_ENUM)221); //B12(28-31) B13(28-31) B12(28-31) B13(28-31) B16(28-31) B17(28-31) B16(28-31) B17(28-31) B1A(28-31) B1B(28-31) B1A(28-31) B1B(28-31) B1E(28-31) B1F(28-31) B1E(28-31) B1F(28-31) + + uint32_t utmp_00[4], utmp_01[4], utmp_10[4], utmp_11[4]; + + // Scales and Mins of corresponding sub blocks from different Q4_K structures are stored together + // The below block is for eg to extract first sub block's scales and mins from different Q4_K structures for the sb loop + memcpy(utmp_00, b_ptr_0[b].scales + 24 * sb, 12); + utmp_00[3] = ((utmp_00[2] >> 4) & kmask2) | (((utmp_00[1] >> 6) & kmask3) << 4); + const uint32_t uaux_00 = utmp_00[1] & kmask1; + utmp_00[1] = (utmp_00[2] & kmask2) | (((utmp_00[0] >> 6) & kmask3) << 4); + utmp_00[2] = uaux_00; + utmp_00[0] &= kmask1; + + // The below block is for eg to extract second sub block's scales and mins from different Q4_K structures for the sb loop + memcpy(utmp_01, b_ptr_0[b].scales + 12 + sb * 24, 12); + utmp_01[3] = ((utmp_01[2] >> 4) & kmask2) | (((utmp_01[1] >> 6) & kmask3) << 4); + const uint32_t uaux_01 = utmp_01[1] & kmask1; + utmp_01[1] = (utmp_01[2] & kmask2) | (((utmp_01[0] >> 6) & kmask3) << 4); + utmp_01[2] = uaux_01; + utmp_01[0] &= kmask1; + + // The below block is for eg to extract first sub block's scales and mins from different Q4_K structures for the sb loop + memcpy(utmp_10, b_ptr_1[b].scales + sb * 24, 12); + utmp_10[3] = ((utmp_10[2] >> 4) & kmask2) | (((utmp_10[1] >> 6) & kmask3) << 4); + const uint32_t uaux_10 = utmp_10[1] & kmask1; + utmp_10[1] = (utmp_10[2] & kmask2) | (((utmp_10[0] >> 6) & kmask3) << 4); + utmp_10[2] = uaux_10; + utmp_10[0] &= kmask1; + + // The below block is for eg to extract second sub block's scales and mins from different Q4_K structures for the sb loop + memcpy(utmp_11, b_ptr_1[b].scales + 12 + sb * 24, 12); + utmp_11[3] = ((utmp_11[2] >> 4) & kmask2) | (((utmp_11[1] >> 6) & kmask3) << 4); + const uint32_t uaux_11 = utmp_11[1] & kmask1; + utmp_11[1] = (utmp_11[2] & kmask2) | (((utmp_11[0] >> 6) & kmask3) << 4); + utmp_11[2] = uaux_11; + utmp_11[0] &= kmask1; + + // Scales of first sub block in the sb loop + const __m256i mins_and_scales_0 = _mm256_set_epi32(utmp_10[3], utmp_10[2], utmp_10[1], utmp_10[0], utmp_00[3], utmp_00[2], utmp_00[1], utmp_00[0]); + const __m512i scales_0 = _mm512_cvtepu8_epi16(_mm256_unpacklo_epi8(mins_and_scales_0, mins_and_scales_0)); + + // Scales of second sub block in the sb loop + const __m256i mins_and_scales_1 = _mm256_set_epi32(utmp_11[3], utmp_11[2], utmp_11[1], utmp_11[0], utmp_01[3], utmp_01[2], utmp_01[1], utmp_01[0]); + const __m512i scales_1 = _mm512_cvtepu8_epi16(_mm256_unpacklo_epi8(mins_and_scales_1, mins_and_scales_1)); + + // Mins of first and second sub block of Q4_K block are arranged side by side + const __m512i mins_01 = _mm512_cvtepu8_epi16(_mm256_unpacklo_epi8(_mm256_shuffle_epi32(mins_and_scales_0, 78), _mm256_shuffle_epi32(mins_and_scales_1, 78))); + + const __m512i scale_014589CD_0 = 
_mm512_shuffle_epi32(scales_0, (_MM_PERM_ENUM)68); + const __m512i scale_2367ABEF_0 = _mm512_shuffle_epi32(scales_0, (_MM_PERM_ENUM)238); + + const __m512i scale_014589CD_1 = _mm512_shuffle_epi32(scales_1, (_MM_PERM_ENUM)68); + const __m512i scale_2367ABEF_1 = _mm512_shuffle_epi32(scales_1, (_MM_PERM_ENUM)238); + + // Load the four block_q8_k quantized values interleaved with each other in chunks of eight bytes - A0,A1,A2,A3 + // Loaded as set of 128 bit vectors and repeated into a 256 bit vector + __m256i lhs_mat_ymm_0123_00 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 256 * sb))); + __m256i lhs_mat_ymm_01_00 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_00, lhs_mat_ymm_0123_00, 0); + __m256i lhs_mat_ymm_23_00 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_00, lhs_mat_ymm_0123_00, 17); + __m256i lhs_mat_ymm_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 32 + 256 * sb))); + __m256i lhs_mat_ymm_01_01 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_01, lhs_mat_ymm_0123_01, 0); + __m256i lhs_mat_ymm_23_01 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_01, lhs_mat_ymm_0123_01, 17); + __m256i lhs_mat_ymm_0123_02 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 64 + 256 * sb))); + __m256i lhs_mat_ymm_01_02 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_02, lhs_mat_ymm_0123_02, 0); + __m256i lhs_mat_ymm_23_02 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_02, lhs_mat_ymm_0123_02, 17); + __m256i lhs_mat_ymm_0123_03 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 96 + 256 * sb))); + __m256i lhs_mat_ymm_01_03 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_03, lhs_mat_ymm_0123_03, 0); + __m256i lhs_mat_ymm_23_03 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_03, lhs_mat_ymm_0123_03, 17); + __m256i lhs_mat_ymm_0123_10 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 128 + 256 * sb))); + __m256i lhs_mat_ymm_01_10 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_10, lhs_mat_ymm_0123_10, 0); + __m256i lhs_mat_ymm_23_10 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_10, lhs_mat_ymm_0123_10, 17); + __m256i lhs_mat_ymm_0123_11 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 160 + 256 * sb))); + __m256i lhs_mat_ymm_01_11 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_11, lhs_mat_ymm_0123_11, 0); + __m256i lhs_mat_ymm_23_11 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_11, lhs_mat_ymm_0123_11, 17); + __m256i lhs_mat_ymm_0123_12 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 192 + 256 * sb))); + __m256i lhs_mat_ymm_01_12 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_12, lhs_mat_ymm_0123_12, 0); + __m256i lhs_mat_ymm_23_12 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_12, lhs_mat_ymm_0123_12, 17); + __m256i lhs_mat_ymm_0123_13 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 224 + 256 * sb))); + __m256i lhs_mat_ymm_01_13 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_13, lhs_mat_ymm_0123_13, 0); + __m256i lhs_mat_ymm_23_13 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_13, lhs_mat_ymm_0123_13, 17); + + //Loaded as set of 128 bit vectors and repeated and stored into a 256 bit vector before again repeating into a 512 bit vector + __m512i lhs_mat_01_00 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_00), lhs_mat_ymm_01_00, 1); + __m512i lhs_mat_23_00 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_00), lhs_mat_ymm_23_00, 1); + __m512i lhs_mat_01_01 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_01), lhs_mat_ymm_01_01, 1); + __m512i lhs_mat_23_01 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_01), lhs_mat_ymm_23_01, 1); + __m512i 
lhs_mat_01_02 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_02), lhs_mat_ymm_01_02, 1); + __m512i lhs_mat_23_02 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_02), lhs_mat_ymm_23_02, 1); + __m512i lhs_mat_01_03 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_03), lhs_mat_ymm_01_03, 1); + __m512i lhs_mat_23_03 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_03), lhs_mat_ymm_23_03, 1); + + __m512i lhs_mat_01_10 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_10), lhs_mat_ymm_01_10, 1); + __m512i lhs_mat_23_10 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_10), lhs_mat_ymm_23_10, 1); + __m512i lhs_mat_01_11 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_11), lhs_mat_ymm_01_11, 1); + __m512i lhs_mat_23_11 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_11), lhs_mat_ymm_23_11, 1); + __m512i lhs_mat_01_12 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_12), lhs_mat_ymm_01_12, 1); + __m512i lhs_mat_23_12 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_12), lhs_mat_ymm_23_12, 1); + __m512i lhs_mat_01_13 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_13), lhs_mat_ymm_01_13, 1); + __m512i lhs_mat_23_13 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_13), lhs_mat_ymm_23_13, 1); + + // Bsums are loaded - four bsums are loaded (for two sub blocks) for the different Q8_K blocks + __m256i lhs_bsums_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].bsums + 16 * sb))); + __m256i lhs_bsums_hsum_ymm_0123_01 = _mm256_castsi128_si256(_mm_hadd_epi16(_mm256_castsi256_si128(lhs_bsums_0123_01), _mm256_extractf128_si256(lhs_bsums_0123_01, 1))); + lhs_bsums_hsum_ymm_0123_01 = _mm256_permute2x128_si256(lhs_bsums_hsum_ymm_0123_01, lhs_bsums_hsum_ymm_0123_01, 0); + __m512i lhs_bsums_hsum_0123_01 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_bsums_hsum_ymm_0123_01), lhs_bsums_hsum_ymm_0123_01, 1); + + // Shuffle pattern one - left side input + const __m512i lhs_mat_01_00_sp1 = _mm512_shuffle_epi32(lhs_mat_01_00, (_MM_PERM_ENUM)160); //A00(0-3) A00(0-3) A01(0-3) A01(0-3) A00(0-3) A00(0-3) A01(0-3) A01(0-3) A00(0-3) A00(0-3) A01(0-3) A01(0-3) A00(0-3) A00(0-3) A01(0-3) A01(0-3) + const __m512i lhs_mat_23_00_sp1 = _mm512_shuffle_epi32(lhs_mat_23_00, (_MM_PERM_ENUM)160); //A02(0-3) A02(0-3) A03(0-3) A03(0-3) A02(0-3) A02(0-3) A03(0-3) A03(0-3) A02(0-3) A02(0-3) A03(0-3) A03(0-3) A02(0-3) A02(0-3) A03(0-3) A03(0-3) + const __m512i lhs_mat_01_01_sp1 = _mm512_shuffle_epi32(lhs_mat_01_01, (_MM_PERM_ENUM)160); //A00(8-11) A00(8-11) A01(8-11) A01(8-11) A00(8-11) A00(8-11) A01(8-11) A01(8-11) A00(8-11) A00(8-11) A01(8-11) A01(8-11) A00(8-11) A00(8-11) A01(8-11) A01(8-11) + const __m512i lhs_mat_23_01_sp1 = _mm512_shuffle_epi32(lhs_mat_23_01, (_MM_PERM_ENUM)160); //A02(8-11) A02(8-11) A03(8-11) A03(8-11) A02(8-11) A02(8-11) A03(8-11) A03(8-11) A02(8-11) A02(8-11) A03(8-11) A03(8-11) A02(8-11) A02(8-11) A03(8-11) A03(8-11) + const __m512i lhs_mat_01_02_sp1 = _mm512_shuffle_epi32(lhs_mat_01_02, (_MM_PERM_ENUM)160); //A00(16-19) A00(16-19) A01(16-19) A01(16-19) A00(16-19) A00(16-19) A01(16-19) A01(16-19) A00(16-19) A00(16-19) A01(16-19) A01(16-19) A00(16-19) A00(16-19) A01(16-19) A01(16-19) + const __m512i lhs_mat_23_02_sp1 = _mm512_shuffle_epi32(lhs_mat_23_02, (_MM_PERM_ENUM)160); //A02(16-19) A02(16-19) A03(16-19) A03(16-19) A02(16-19) A02(16-19) A03(16-19) A03(16-19) A02(16-19) A02(16-19) A03(16-19) A03(16-19) A02(16-19) A02(16-19) A03(16-19) A03(16-19) + const __m512i lhs_mat_01_03_sp1 
= _mm512_shuffle_epi32(lhs_mat_01_03, (_MM_PERM_ENUM)160); //A00(24-27) A00(24-27) A01(24-27) A01(24-27) A00(24-27) A00(24-27) A01(24-27) A01(24-27) A00(24-27) A00(24-27) A01(24-27) A01(24-27) A00(24-27) A00(24-27) A01(24-27) A01(24-27) + const __m512i lhs_mat_23_03_sp1 = _mm512_shuffle_epi32(lhs_mat_23_03, (_MM_PERM_ENUM)160); //A02(24-27) A02(24-27) A03(24-27) A03(24-27) A02(24-27) A02(24-27) A03(24-27) A03(24-27) A02(24-27) A02(24-27) A03(24-27) A03(24-27) A02(24-27) A02(24-27) A03(24-27) A03(24-27) + + const __m512i lhs_mat_01_10_sp1 = _mm512_shuffle_epi32(lhs_mat_01_10, (_MM_PERM_ENUM)160); //A10(0-3) A10(0-3) A11(0-3) A11(0-3) A10(0-3) A10(0-3) A11(0-3) A11(0-3) A10(0-3) A10(0-3) A11(0-3) A11(0-3) A10(0-3) A10(0-3) A11(0-3) A11(0-3) + const __m512i lhs_mat_23_10_sp1 = _mm512_shuffle_epi32(lhs_mat_23_10, (_MM_PERM_ENUM)160); //A12(0-3) A12(0-3) A13(0-3) A13(0-3) A12(0-3) A12(0-3) A13(0-3) A13(0-3) A12(0-3) A12(0-3) A13(0-3) A13(0-3) A12(0-3) A12(0-3) A13(0-3) A13(0-3) + const __m512i lhs_mat_01_11_sp1 = _mm512_shuffle_epi32(lhs_mat_01_11, (_MM_PERM_ENUM)160); //A10(8-11) A10(8-11) A11(8-11) A11(8-11) A10(8-11) A10(8-11) A11(8-11) A11(8-11) A10(8-11) A10(8-11) A11(8-11) A11(8-11) A10(8-11) A10(8-11) A11(8-11) A11(8-11) + const __m512i lhs_mat_23_11_sp1 = _mm512_shuffle_epi32(lhs_mat_23_11, (_MM_PERM_ENUM)160); //A12(8-11) A12(8-11) A13(8-11) A13(8-11) A12(8-11) A12(8-11) A13(8-11) A13(8-11) A12(8-11) A12(8-11) A13(8-11) A13(8-11) A12(8-11) A12(8-11) A13(8-11) A13(8-11) + const __m512i lhs_mat_01_12_sp1 = _mm512_shuffle_epi32(lhs_mat_01_12, (_MM_PERM_ENUM)160); //A10(16-19) A10(16-19) A11(16-19) A11(16-19) A10(16-19) A10(16-19) A11(16-19) A11(16-19) A10(16-19) A10(16-19) A11(16-19) A11(16-19) A10(16-19) A10(16-19) A11(16-19) A11(16-19) + const __m512i lhs_mat_23_12_sp1 = _mm512_shuffle_epi32(lhs_mat_23_12, (_MM_PERM_ENUM)160); //A12(16-19) A12(16-19) A13(16-19) A13(16-19) A12(16-19) A12(16-19) A13(16-19) A13(16-19) A12(16-19) A12(16-19) A13(16-19) A13(16-19) A12(16-19) A12(16-19) A13(16-19) A13(16-19) + const __m512i lhs_mat_01_13_sp1 = _mm512_shuffle_epi32(lhs_mat_01_13, (_MM_PERM_ENUM)160); //A10(24-27) A10(24-27) A11(24-27) A11(24-27) A10(24-27) A10(24-27) A11(24-27) A11(24-27) A10(24-27) A10(24-27) A11(24-27) A11(24-27) A10(24-27) A10(24-27) A11(24-27) A11(24-27) + const __m512i lhs_mat_23_13_sp1 = _mm512_shuffle_epi32(lhs_mat_23_13, (_MM_PERM_ENUM)160); //A12(24-27) A12(24-27) A13(24-27) A13(24-27) A12(24-27) A12(24-27) A13(24-27) A13(24-27) A12(24-27) A12(24-27) A13(24-27) A13(24-27) A12(24-27) A12(24-27) A13(24-27) A13(24-27) + + const __m512i lhs_mat_01_00_sp2 = _mm512_shuffle_epi32(lhs_mat_01_00, (_MM_PERM_ENUM)245); //A00(4-7) A00(4-7) A01(4-7) A01(4-7) A00(4-7) A00(4-7) A01(4-7) A01(4-7) A00(4-7) A00(4-7) A01(4-7) A01(4-7) A00(4-7) A00(4-7) A01(4-7) A01(4-7) + const __m512i lhs_mat_23_00_sp2 = _mm512_shuffle_epi32(lhs_mat_23_00, (_MM_PERM_ENUM)245); //A02(4-7) A02(4-7) A03(4-7) A03(4-7) A02(4-7) A02(4-7) A03(4-7) A03(4-7) A02(4-7) A02(4-7) A03(4-7) A03(4-7) A02(4-7) A02(4-7) A03(4-7) A03(4-7) + const __m512i lhs_mat_01_01_sp2 = _mm512_shuffle_epi32(lhs_mat_01_01, (_MM_PERM_ENUM)245); //A00(12-15) A00(12-15) A01(12-15) A01(12-15) A00(12-15) A00(12-15) A01(12-15) A01(12-15) A00(12-15) A00(12-15) A01(12-15) A01(12-15) A00(12-15) A00(12-15) A01(12-15) A01(12-15) + const __m512i lhs_mat_23_01_sp2 = _mm512_shuffle_epi32(lhs_mat_23_01, (_MM_PERM_ENUM)245); //A02(12-15) A02(12-15) A03(12-15) A03(12-15) A02(12-15) A02(12-15) A03(12-15) A03(12-15) A02(12-15) A02(12-15) A03(12-15) 
A03(12-15) A02(12-15) A02(12-15) A03(12-15) A03(12-15) + const __m512i lhs_mat_01_02_sp2 = _mm512_shuffle_epi32(lhs_mat_01_02, (_MM_PERM_ENUM)245); //A00(20-23) A00(20-23) A01(20-23) A01(20-23) A00(20-23) A00(20-23) A01(20-23) A01(20-23) A00(20-23) A00(20-23) A01(20-23) A01(20-23) A00(20-23) A00(20-23) A01(20-23) A01(20-23) + const __m512i lhs_mat_23_02_sp2 = _mm512_shuffle_epi32(lhs_mat_23_02, (_MM_PERM_ENUM)245); //A02(20-23) A02(20-23) A03(20-23) A03(20-23) A02(20-23) A02(20-23) A03(20-23) A03(20-23) A02(20-23) A02(20-23) A03(20-23) A03(20-23) A02(20-23) A02(20-23) A03(20-23) A03(20-23) + const __m512i lhs_mat_01_03_sp2 = _mm512_shuffle_epi32(lhs_mat_01_03, (_MM_PERM_ENUM)245); //A00(28-31) A00(28-31) A01(28-31) A01(28-31) A00(28-31) A00(28-31) A01(28-31) A01(28-31) A00(28-31) A00(28-31) A01(28-31) A01(28-31) A00(28-31) A00(28-31) A01(28-31) A01(28-31) + const __m512i lhs_mat_23_03_sp2 = _mm512_shuffle_epi32(lhs_mat_23_03, (_MM_PERM_ENUM)245); //A02(28-31) A02(28-31) A03(28-31) A03(28-31) A02(28-31) A02(28-31) A03(28-31) A03(28-31) A02(28-31) A02(28-31) A03(28-31) A03(28-31) A02(28-31) A02(28-31) A03(28-31) A03(28-31) + + const __m512i lhs_mat_01_10_sp2 = _mm512_shuffle_epi32(lhs_mat_01_10, (_MM_PERM_ENUM)245); //A10(4-7) A10(4-7) A11(4-7) A11(4-7) A10(4-7) A10(4-7) A11(4-7) A11(4-7) A10(4-7) A10(4-7) A11(4-7) A11(4-7) A10(4-7) A10(4-7) A11(4-7) A11(4-7) + const __m512i lhs_mat_23_10_sp2 = _mm512_shuffle_epi32(lhs_mat_23_10, (_MM_PERM_ENUM)245); //A12(4-7) A12(4-7) A13(4-7) A13(4-7) A12(4-7) A12(4-7) A13(4-7) A13(4-7) A12(4-7) A12(4-7) A13(4-7) A13(4-7) A12(4-7) A12(4-7) A13(4-7) A13(4-7) + const __m512i lhs_mat_01_11_sp2 = _mm512_shuffle_epi32(lhs_mat_01_11, (_MM_PERM_ENUM)245); //A10(12-15) A10(12-15) A11(12-15) A11(12-15) A10(12-15) A10(12-15) A11(12-15) A11(12-15) A10(12-15) A10(12-15) A11(12-15) A11(12-15) A10(12-15) A10(12-15) A11(12-15) A11(12-15) + const __m512i lhs_mat_23_11_sp2 = _mm512_shuffle_epi32(lhs_mat_23_11, (_MM_PERM_ENUM)245); //A12(12-15) A12(12-15) A13(12-15) A13(12-15) A12(12-15) A12(12-15) A13(12-15) A13(12-15) A12(12-15) A12(12-15) A13(12-15) A13(12-15) A12(12-15) A12(12-15) A13(12-15) A13(12-15) + const __m512i lhs_mat_01_12_sp2 = _mm512_shuffle_epi32(lhs_mat_01_12, (_MM_PERM_ENUM)245); //A10(20-23) A10(20-23) A11(20-23) A11(20-23) A10(20-23) A10(20-23) A11(20-23) A11(20-23) A10(20-23) A10(20-23) A11(20-23) A11(20-23) A10(20-23) A10(20-23) A11(20-23) A11(20-23) + const __m512i lhs_mat_23_12_sp2 = _mm512_shuffle_epi32(lhs_mat_23_12, (_MM_PERM_ENUM)245); //A12(20-23) A12(20-23) A13(20-23) A13(20-23) A12(20-23) A12(20-23) A13(20-23) A13(20-23) A12(20-23) A12(20-23) A13(20-23) A13(20-23) A12(20-23) A12(20-23) A13(20-23) A13(20-23) + const __m512i lhs_mat_01_13_sp2 = _mm512_shuffle_epi32(lhs_mat_01_13, (_MM_PERM_ENUM)245); //A10(28-31) A10(28-31) A11(28-31) A11(28-31) A10(28-31) A10(28-31) A11(28-31) A11(28-31) A10(28-31) A10(28-31) A11(28-31) A11(28-31) A10(28-31) A10(28-31) A11(28-31) A11(28-31) + const __m512i lhs_mat_23_13_sp2 = _mm512_shuffle_epi32(lhs_mat_23_13, (_MM_PERM_ENUM)245); //A12(28-31) A12(28-31) A13(28-31) A13(28-31) A12(28-31) A12(28-31) A13(28-31) A13(28-31) A12(28-31) A12(28-31) A13(28-31) A13(28-31) A12(28-31) A12(28-31) A13(28-31) A13(28-31) + + // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane + __m512i iacc_mat_00_0_sp1 = 
_mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_03_sp1, lhs_mat_01_03_sp1), _mm512_maddubs_epi16(rhs_mat_014589CD_02_sp1, lhs_mat_01_02_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_01_sp1, lhs_mat_01_01_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_00_sp1, lhs_mat_01_00_sp1)); + __m512i iacc_mat_01_0_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_03_sp1, lhs_mat_01_03_sp1), _mm512_maddubs_epi16(rhs_mat_2367ABEF_02_sp1, lhs_mat_01_02_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp1, lhs_mat_01_01_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp1, lhs_mat_01_00_sp1)); + __m512i iacc_mat_10_0_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_03_sp1, lhs_mat_23_03_sp1), _mm512_maddubs_epi16(rhs_mat_014589CD_02_sp1, lhs_mat_23_02_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_01_sp1, lhs_mat_23_01_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_00_sp1, lhs_mat_23_00_sp1)); + __m512i iacc_mat_11_0_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_03_sp1, lhs_mat_23_03_sp1), _mm512_maddubs_epi16(rhs_mat_2367ABEF_02_sp1, lhs_mat_23_02_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp1, lhs_mat_23_01_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp1, lhs_mat_23_00_sp1)); + __m512i iacc_mat_00_1_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_13_sp1, lhs_mat_01_13_sp1), _mm512_maddubs_epi16(rhs_mat_014589CD_12_sp1, lhs_mat_01_12_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_11_sp1, lhs_mat_01_11_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_10_sp1, lhs_mat_01_10_sp1)); + __m512i iacc_mat_01_1_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_13_sp1, lhs_mat_01_13_sp1), _mm512_maddubs_epi16(rhs_mat_2367ABEF_12_sp1, lhs_mat_01_12_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp1, lhs_mat_01_11_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp1, lhs_mat_01_10_sp1)); + __m512i iacc_mat_10_1_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_13_sp1, lhs_mat_23_13_sp1), _mm512_maddubs_epi16(rhs_mat_014589CD_12_sp1, lhs_mat_23_12_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_11_sp1, lhs_mat_23_11_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_10_sp1, lhs_mat_23_10_sp1)); + __m512i iacc_mat_11_1_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_13_sp1, lhs_mat_23_13_sp1), _mm512_maddubs_epi16(rhs_mat_2367ABEF_12_sp1, lhs_mat_23_12_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp1, lhs_mat_23_11_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp1, lhs_mat_23_10_sp1)); + + __m512i iacc_mat_00_0_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_03_sp2, lhs_mat_01_03_sp2), _mm512_maddubs_epi16(rhs_mat_014589CD_02_sp2, lhs_mat_01_02_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_01_sp2, lhs_mat_01_01_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_00_sp2, lhs_mat_01_00_sp2)); + __m512i iacc_mat_01_0_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_03_sp2, lhs_mat_01_03_sp2), _mm512_maddubs_epi16(rhs_mat_2367ABEF_02_sp2, lhs_mat_01_02_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp2, lhs_mat_01_01_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp2, lhs_mat_01_00_sp2)); + __m512i iacc_mat_10_0_sp2 = 
_mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_03_sp2, lhs_mat_23_03_sp2), _mm512_maddubs_epi16(rhs_mat_014589CD_02_sp2, lhs_mat_23_02_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_01_sp2, lhs_mat_23_01_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_00_sp2, lhs_mat_23_00_sp2)); + __m512i iacc_mat_11_0_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_03_sp2, lhs_mat_23_03_sp2), _mm512_maddubs_epi16(rhs_mat_2367ABEF_02_sp2, lhs_mat_23_02_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp2, lhs_mat_23_01_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp2, lhs_mat_23_00_sp2)); + __m512i iacc_mat_00_1_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_13_sp2, lhs_mat_01_13_sp2), _mm512_maddubs_epi16(rhs_mat_014589CD_12_sp2, lhs_mat_01_12_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_11_sp2, lhs_mat_01_11_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_10_sp2, lhs_mat_01_10_sp2)); + __m512i iacc_mat_01_1_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_13_sp2, lhs_mat_01_13_sp2), _mm512_maddubs_epi16(rhs_mat_2367ABEF_12_sp2, lhs_mat_01_12_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp2, lhs_mat_01_11_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp2, lhs_mat_01_10_sp2)); + __m512i iacc_mat_10_1_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_13_sp2, lhs_mat_23_13_sp2), _mm512_maddubs_epi16(rhs_mat_014589CD_12_sp2, lhs_mat_23_12_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_11_sp2, lhs_mat_23_11_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_10_sp2, lhs_mat_23_10_sp2)); + __m512i iacc_mat_11_1_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_13_sp2, lhs_mat_23_13_sp2), _mm512_maddubs_epi16(rhs_mat_2367ABEF_12_sp2, lhs_mat_23_12_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp2, lhs_mat_23_11_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp2, lhs_mat_23_10_sp2)); + + // Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block + __m512i iacc_mat_00_0 = _mm512_add_epi16(iacc_mat_00_0_sp1, iacc_mat_00_0_sp2); + __m512i iacc_mat_01_0 = _mm512_add_epi16(iacc_mat_01_0_sp1, iacc_mat_01_0_sp2); + __m512i iacc_mat_10_0 = _mm512_add_epi16(iacc_mat_10_0_sp1, iacc_mat_10_0_sp2); + __m512i iacc_mat_11_0 = _mm512_add_epi16(iacc_mat_11_0_sp1, iacc_mat_11_0_sp2); + + __m512i iacc_mat_00_1 = _mm512_add_epi16(iacc_mat_00_1_sp1, iacc_mat_00_1_sp2); + __m512i iacc_mat_01_1 = _mm512_add_epi16(iacc_mat_01_1_sp1, iacc_mat_01_1_sp2); + __m512i iacc_mat_10_1 = _mm512_add_epi16(iacc_mat_10_1_sp1, iacc_mat_10_1_sp2); + __m512i iacc_mat_11_1 = _mm512_add_epi16(iacc_mat_11_1_sp1, iacc_mat_11_1_sp2); + + iacc_mat_00_0 = _mm512_madd_epi16(iacc_mat_00_0, scale_014589CD_0); + iacc_mat_01_0 = _mm512_madd_epi16(iacc_mat_01_0, scale_2367ABEF_0); + iacc_mat_10_0 = _mm512_madd_epi16(iacc_mat_10_0, scale_014589CD_0); + iacc_mat_11_0 = _mm512_madd_epi16(iacc_mat_11_0, scale_2367ABEF_0); + + iacc_mat_00_1 = _mm512_madd_epi16(iacc_mat_00_1, scale_014589CD_1); + iacc_mat_01_1 = _mm512_madd_epi16(iacc_mat_01_1, scale_2367ABEF_1); + iacc_mat_10_1 = _mm512_madd_epi16(iacc_mat_10_1, scale_014589CD_1); + iacc_mat_11_1 = _mm512_madd_epi16(iacc_mat_11_1, scale_2367ABEF_1); + + // Straighten out to make 4 row vectors (4 for each sub block which are accumulated together in the next step) + __m512i iacc_row_0_0 = 
_mm512_mask_blend_epi32(0xCCCC, iacc_mat_00_0, _mm512_shuffle_epi32(iacc_mat_01_0, (_MM_PERM_ENUM)78));
+                    __m512i iacc_row_1_0 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00_0, (_MM_PERM_ENUM)78), iacc_mat_01_0);
+                    __m512i iacc_row_2_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10_0, _mm512_shuffle_epi32(iacc_mat_11_0, (_MM_PERM_ENUM)78));
+                    __m512i iacc_row_3_0 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10_0, (_MM_PERM_ENUM)78), iacc_mat_11_0);
+                    __m512i iacc_row_0_1 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00_1, _mm512_shuffle_epi32(iacc_mat_01_1, (_MM_PERM_ENUM)78));
+                    __m512i iacc_row_1_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00_1, (_MM_PERM_ENUM)78), iacc_mat_01_1);
+                    __m512i iacc_row_2_1 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10_1, _mm512_shuffle_epi32(iacc_mat_11_1, (_MM_PERM_ENUM)78));
+                    __m512i iacc_row_3_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10_1, (_MM_PERM_ENUM)78), iacc_mat_11_1);
+
+                    __m512i iacc_row_0 = _mm512_add_epi32(iacc_row_0_0, iacc_row_0_1);
+                    __m512i iacc_row_1 = _mm512_add_epi32(iacc_row_1_0, iacc_row_1_1);
+                    __m512i iacc_row_2 = _mm512_add_epi32(iacc_row_2_0, iacc_row_2_1);
+                    __m512i iacc_row_3 = _mm512_add_epi32(iacc_row_3_0, iacc_row_3_1);
+
+                    // Load the scale (d) values for all the 4 Q8_K blocks and repeat them across lanes
+                    const __m128 row_scale_f32_sse = _mm_load_ps(a_ptr[b].d);
+                    const __m256 row_scale_f32_ymm = _mm256_set_m128(row_scale_f32_sse, row_scale_f32_sse);
+                    const __m512 row_scale_f32 = _mm512_insertf32x8(_mm512_castps256_ps512(row_scale_f32_ymm), row_scale_f32_ymm, 1);
+
+                    // Multiply with appropriate scales and accumulate (for both d and dmin) below
+                    acc_rows[0] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_0), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[0]);
+                    acc_rows[1] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_1), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[1]);
+                    acc_rows[2] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_2), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[2]);
+                    acc_rows[3] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_3), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[3]);
+
+                    __m512i iacc_row_min_0 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_hsum_0123_01, (_MM_PERM_ENUM)0), mins_01);
+                    __m512i iacc_row_min_1 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_hsum_0123_01, (_MM_PERM_ENUM)85), mins_01);
+                    __m512i iacc_row_min_2 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_hsum_0123_01, (_MM_PERM_ENUM)170), mins_01);
+                    __m512i iacc_row_min_3 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_hsum_0123_01, (_MM_PERM_ENUM)255), mins_01);
+
+                    acc_min_rows[0] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_min_0), _mm512_mul_ps(col_dmin_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_min_rows[0]);
+                    acc_min_rows[1] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_min_1), _mm512_mul_ps(col_dmin_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_min_rows[1]);
+                    acc_min_rows[2] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_min_2), _mm512_mul_ps(col_dmin_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_min_rows[2]);
+                    acc_min_rows[3] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_min_3), _mm512_mul_ps(col_dmin_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_min_rows[3]);
+                }
+            }
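+            // At this point acc_rows holds d_B * d_A * (q4 . q8) for each output element, while acc_min_rows holds
+            // dmin_B * d_A * (sub block mins . Q8_K block sums); the Q4_K result is their difference, taken in the store below.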
+            // Store the accumulated values
+            for (int i = 0; i < 4; i++) {
+                _mm512_storeu_ps((float * )(s + ((y * 4 + i) * bs + x * 8)), _mm512_sub_ps(acc_rows[i], acc_min_rows[i]));
+            }
+        }
+    }
+    if (anc != nc) {
+        xstart = anc/8;
+        y = 0;
+    }
+#endif //AVX512F
+
     // Take group of four block_q8_Kx4 structures at each pass of the loop and perform dot product operation
     for (; y < anr / 4; y += 4) {
@@ -4099,7 +4807,7 @@ static void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, c
     }
     // Take group of eight block_q4_kx8 structures at each pass of the loop and perform dot product operation
-    for (int64_t x = 0; x < nc / 8; x++) {
+    for (int64_t x = xstart; x < nc / 8; x++) {
         const block_q4_Kx8 * b_ptr = b_ptr_start + (x * b_nb);
@@ -4433,7 +5141,7 @@ static void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, c
         const block_q8_Kx4 * a_ptr = a_ptr_start + (y * nb);
-        for (int64_t x = 0; x < nc / 8; x++) {
+        for (int64_t x = xstart; x < nc / 8; x++) {
            const block_q4_Kx8 * b_ptr = b_ptr_start + (x * b_nb);
diff --git a/ggml/src/ggml-cpu/ggml-cpu-impl.h b/ggml/src/ggml-cpu/ggml-cpu-impl.h
index 7f7d210cbe5..e4af07635c1 100644
--- a/ggml/src/ggml-cpu/ggml-cpu-impl.h
+++ b/ggml/src/ggml-cpu/ggml-cpu-impl.h
@@ -4,13 +4,13 @@
 #include "ggml.h"
 #include "ggml-impl.h"
+
 #include <stdlib.h> // load `stdlib.h` before other headers to work around MinGW bug: https://sourceforge.net/p/mingw-w64/bugs/192/
 //#include
 #include
 #include <string.h> // memcpy
 #include <math.h>   // fabsf
-
 #ifdef __cplusplus
 extern "C" {
 #endif
@@ -69,33 +69,16 @@ struct ggml_compute_params {
 #endif
 #if defined(__ARM_FEATURE_SVE)
-#include
 #include
 #endif
-// 16-bit float
-// on Arm, we use __fp16
-// on x86, we use uint16_t
 #if defined(__ARM_NEON)
-// if YCM cannot find <arm_neon.h>, make a symbolic link to it, for example:
-//
-//   $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/
-//
-#include <arm_neon.h>
-
+// ref: https://github.com/ggml-org/llama.cpp/pull/5404
 #ifdef _MSC_VER
-
-typedef uint16_t ggml_fp16_internal_t;
-
 #define ggml_vld1q_u32(w,x,y,z) { ((w) + ((uint64_t)(x) << 32)), ((y) + ((uint64_t)(z) << 32)) }
-
 #else
-
-typedef __fp16 ggml_fp16_internal_t;
-
 #define ggml_vld1q_u32(w,x,y,z) { (w), (x), (y), (z) }
-
 #endif // _MSC_VER
 #if !defined(__aarch64__)
@@ -340,8 +323,6 @@ inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b)
 #else
 #ifdef __POWER9_VECTOR__
 #include
-#undef bool
-#define bool _Bool
 #else
 #if defined(_MSC_VER) || defined(__MINGW32__)
 #include
diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c
index 34618c27aa4..dbad8f61a1e 100644
--- a/ggml/src/ggml-cpu/ggml-cpu.c
+++ b/ggml/src/ggml-cpu/ggml-cpu.c
@@ -1932,6 +1932,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_im2col_back_f32(params, tensor);
             } break;
+        case GGML_OP_CONV_2D_DW:
+            {
+                ggml_compute_forward_conv_2d_dw(params, tensor);
+            } break;
         case GGML_OP_CONV_TRANSPOSE_2D:
             {
                 ggml_compute_forward_conv_transpose_2d(params, tensor);
@@ -2027,41 +2031,6 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_rwkv_wkv7(params, tensor);
             } break;
-        case GGML_OP_MAP_UNARY:
-            {
-                ggml_unary_op_f32_t fun;
-                memcpy(&fun, tensor->op_params, sizeof(fun));
-                ggml_compute_forward_map_unary(params, tensor, fun);
-            }
-            break;
-        case GGML_OP_MAP_BINARY:
-            {
-                ggml_binary_op_f32_t fun;
-                memcpy(&fun, tensor->op_params, sizeof(fun));
-                ggml_compute_forward_map_binary(params, tensor, fun);
-            }
-            break;
-        case
GGML_OP_MAP_CUSTOM1_F32: - { - ggml_custom1_op_f32_t fun; - memcpy(&fun, tensor->op_params, sizeof(fun)); - ggml_compute_forward_map_custom1_f32(params, tensor, fun); - } - break; - case GGML_OP_MAP_CUSTOM2_F32: - { - ggml_custom2_op_f32_t fun; - memcpy(&fun, tensor->op_params, sizeof(fun)); - ggml_compute_forward_map_custom2_f32(params, tensor, fun); - } - break; - case GGML_OP_MAP_CUSTOM3_F32: - { - ggml_custom3_op_f32_t fun; - memcpy(&fun, tensor->op_params, sizeof(fun)); - ggml_compute_forward_map_custom3_f32(params, tensor, fun); - } - break; case GGML_OP_MAP_CUSTOM1: { ggml_compute_forward_map_custom1(params, tensor); @@ -2077,6 +2046,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm ggml_compute_forward_map_custom3(params, tensor); } break; + case GGML_OP_CUSTOM: + { + ggml_compute_forward_custom(params, tensor); + } + break; case GGML_OP_CROSS_ENTROPY_LOSS: { ggml_compute_forward_cross_entropy_loss(params, tensor); @@ -2298,6 +2272,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { } break; case GGML_OP_IM2COL: case GGML_OP_IM2COL_BACK: + case GGML_OP_CONV_2D_DW: case GGML_OP_CONV_TRANSPOSE_1D: case GGML_OP_CONV_TRANSPOSE_2D: { @@ -2328,11 +2303,6 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_OP_WIN_PART: case GGML_OP_WIN_UNPART: case GGML_OP_GET_REL_POS: - case GGML_OP_MAP_UNARY: - case GGML_OP_MAP_BINARY: - case GGML_OP_MAP_CUSTOM1_F32: - case GGML_OP_MAP_CUSTOM2_F32: - case GGML_OP_MAP_CUSTOM3_F32: { n_tasks = 1; } break; @@ -2366,6 +2336,16 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { n_tasks = MIN(p.n_tasks, n_threads); } } break; + case GGML_OP_CUSTOM: + { + struct ggml_custom_op_params p; + memcpy(&p, node->op_params, sizeof(p)); + if (p.n_tasks == GGML_N_TASKS_MAX) { + n_tasks = n_threads; + } else { + n_tasks = MIN(p.n_tasks, n_threads); + } + } break; case GGML_OP_CROSS_ENTROPY_LOSS: case GGML_OP_CROSS_ENTROPY_LOSS_BACK: case GGML_OP_OPT_STEP_ADAMW: diff --git a/ggml/src/ggml-cpu/ggml-cpu.cpp b/ggml/src/ggml-cpu/ggml-cpu.cpp index 09f8382b988..4b688a67eb2 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.cpp +++ b/ggml/src/ggml-cpu/ggml-cpu.cpp @@ -425,6 +425,8 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st } case GGML_OP_IM2COL_BACK: return src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32; + case GGML_OP_GET_ROWS_BACK: + return src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16; case GGML_OP_OUT_PROD: return (src0->type == GGML_TYPE_F32 || (ggml_is_quantized(src0->type) && src0->ne[2] == src1->ne[2] && src0->ne[3] == src1->ne[3])) && src1->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 7a8d5ac6fd9..3c2adb21726 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -6064,6 +6064,178 @@ void ggml_compute_forward_conv_transpose_2d( } } +// ggml_compute_forward_conv_2d_dw + +struct ggml_conv_2d_dw_params { + int64_t channels; + int64_t batch; + int64_t src_w; + int64_t src_h; + int64_t dst_w; + int64_t dst_h; + int64_t knl_w; + int64_t knl_h; + int stride_x; + int stride_y; + int pad_x; + int pad_y; + int dilation_x; + int dilation_y; +}; + +static void ggml_compute_forward_conv_2d_dw_cwhn( + const ggml_compute_params * params, + const ggml_tensor * src, + const ggml_tensor * kernel, + ggml_tensor * dst, + const ggml_conv_2d_dw_params & p) { + + const int64_t c = p.channels; + const float * knl_data = 
(const float *)kernel->data; + + const int64_t rows_total = p.dst_h * p.batch; + const int64_t rows_per_thread = (rows_total + params->nth - 1) / params->nth; + const int64_t row_start = params->ith * rows_per_thread; + const int64_t row_end = MIN(row_start + rows_per_thread, rows_total); + +#ifdef GGML_SIMD + const int64_t pkg_size = GGML_F32_EPR; + const int64_t pkg_count = c / pkg_size; + const int64_t c_pkg_end = pkg_count * pkg_size; +#else + const int64_t c_pkg_end = 0; +#endif + + for (int64_t row = row_start; row < row_end; ++row) { + const int64_t dst_y = row % p.dst_h; + const float * src_data = (const float *)src->data + (row / p.dst_h) * p.src_w * p.src_h * c; + for (int64_t dst_x = 0; dst_x < p.dst_w; ++dst_x) { + float * dst_data = (float *)dst->data + (row * p.dst_w + dst_x) * c; + const int64_t src_y_base = dst_y * p.stride_y - p.pad_y; + const int64_t src_x_base = dst_x * p.stride_x - p.pad_x; + +#ifdef GGML_SIMD + // Vectorized loop + for (int64_t c_i = 0; c_i < c_pkg_end; c_i += pkg_size) { + GGML_F32_VEC sum = GGML_F32_VEC_ZERO; + for (int64_t knl_y = 0; knl_y < p.knl_h; ++knl_y) { + const int64_t src_y = src_y_base + knl_y * p.dilation_y; + if (src_y < 0 || src_y >= p.src_h) { + continue; + } + for (int64_t knl_x = 0; knl_x < p.knl_w; ++knl_x) { + const int64_t src_x = src_x_base + knl_x * p.dilation_x; + if (src_x < 0 || src_x >= p.src_w) { + continue; + } + GGML_F32_VEC k = GGML_F32_VEC_LOAD(knl_data + (knl_y * p.knl_w + knl_x) * c + c_i); + GGML_F32_VEC s = GGML_F32_VEC_LOAD(src_data + (src_y * p.src_w + src_x) * c + c_i); + sum = GGML_F32_VEC_FMA(sum, k, s); + } + } + GGML_F32_VEC_STORE(dst_data + c_i, sum); + } +#endif + // Scalar loop + for (int64_t c_i = c_pkg_end; c_i < c; ++c_i) { + float sum = 0.0f; + for (int64_t knl_y = 0; knl_y < p.knl_h; ++knl_y) { + const int64_t src_y = src_y_base + knl_y * p.dilation_y; + if (src_y < 0 || src_y >= p.src_h) { + continue; + } + for (int64_t knl_x = 0; knl_x < p.knl_w; ++knl_x) { + const int64_t src_x = src_x_base + knl_x * p.dilation_x; + if (src_x < 0 || src_x >= p.src_w) { + continue; + } + sum += knl_data[(knl_y * p.knl_w + knl_x) * c + c_i] + * src_data[(src_y * p.src_w + src_x) * c + c_i]; + } + } + dst_data[c_i] = sum; + } + } + } +} + +static void ggml_compute_forward_conv_2d_dw_whcn( + const ggml_compute_params * params, + const ggml_tensor * src, + const ggml_tensor * kernel, + ggml_tensor * dst, + const ggml_conv_2d_dw_params & p) { + + const int64_t n = p.channels * p.batch; + const int64_t per_thread = (n + params->nth - 1) / params->nth; + const int64_t start = params->ith * per_thread; + const int64_t end = MIN(start + per_thread, n); + + for (int64_t i = start; i < end; ++i) { + const float * knl_data = (const float *)kernel->data + (i % p.channels) * p.knl_w * p.knl_h; + const float * src_data = (const float *)src->data + i * p.src_w * p.src_h; + float * dst_data = (float *)dst->data + i * p.dst_w * p.dst_h; + + for (int64_t dst_y = 0; dst_y < p.dst_h; ++dst_y) { + for (int64_t dst_x = 0; dst_x < p.dst_w; ++dst_x) { + + float sum = 0.0f; + for (int64_t knl_y = 0; knl_y < p.knl_h; ++knl_y) { + const int64_t src_y = dst_y * p.stride_y + knl_y * p.dilation_y - p.pad_y; + if (src_y < 0 || src_y >= p.src_h) { + continue; + } + for (int64_t knl_x = 0; knl_x < p.knl_w; ++knl_x) { + const int64_t src_x = dst_x * p.stride_x + knl_x * p.dilation_x - p.pad_x; + if (src_x < 0 || src_x >= p.src_w) { + continue; + } + sum += knl_data[knl_y * p.knl_w + knl_x] + * src_data[src_y * p.src_w + src_x]; + } + } + 
dst_data[dst_y * p.dst_w + dst_x] = sum; + } + } + } +} + +void ggml_compute_forward_conv_2d_dw( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * kernel = dst->src[0]; + const ggml_tensor * src = dst->src[1]; + ggml_conv_2d_dw_params p; + p.channels = src->ne[2]; + p.batch = src->ne[3]; + p.src_w = src->ne[0]; + p.src_h = src->ne[1]; + p.dst_w = dst->ne[0]; + p.dst_h = dst->ne[1]; + p.knl_w = kernel->ne[0]; + p.knl_h = kernel->ne[1]; + p.stride_x = dst->op_params[0]; + p.stride_y = dst->op_params[1]; + p.pad_x = dst->op_params[2]; + p.pad_y = dst->op_params[3]; + p.dilation_x = dst->op_params[4]; + p.dilation_y = dst->op_params[5]; + + GGML_ASSERT(kernel->ne[3] == p.channels); + GGML_ASSERT(dst->ne[3] == p.batch); + + if (ggml_is_contiguous(src)) { + ggml_compute_forward_conv_2d_dw_whcn(params, src, kernel, dst, p); + } else if (ggml_is_contiguous_channels(src)) { + // kernel should also have channels most contiguous in memory + GGML_ASSERT(kernel->nb[0] >= kernel->nb[2] && kernel->nb[1] >= kernel->nb[0]); + ggml_compute_forward_conv_2d_dw_cwhn(params, src, kernel, dst, p); + } else { + GGML_ABORT("non-contiguous memory layout not supported"); + } +} + // ggml_compute_forward_pool_1d_sk_p0 static void ggml_compute_forward_pool_1d_sk_p0( @@ -6351,24 +6523,72 @@ static void ggml_compute_forward_upscale_f32( const float sf2 = (float)ne2/src0->ne[2]; const float sf3 = (float)ne3/src0->ne[3]; - // TODO: optimize + const ggml_scale_mode mode = (ggml_scale_mode) ggml_get_op_params_i32(dst, 0); - for (int64_t i3 = 0; i3 < ne3; i3++) { - const int64_t i03 = i3 / sf3; - for (int64_t i2 = ith; i2 < ne2; i2 += nth) { - const int64_t i02 = i2 / sf2; - for (int64_t i1 = 0; i1 < ne1; i1++) { - const int64_t i01 = i1 / sf1; - for (int64_t i0 = 0; i0 < ne0; i0++) { - const int64_t i00 = i0 / sf0; + if (mode == GGML_SCALE_MODE_NEAREST) { + for (int64_t i3 = 0; i3 < ne3; i3++) { + const int64_t i03 = i3 / sf3; + for (int64_t i2 = ith; i2 < ne2; i2 += nth) { + const int64_t i02 = i2 / sf2; + for (int64_t i1 = 0; i1 < ne1; i1++) { + const int64_t i01 = i1 / sf1; + for (int64_t i0 = 0; i0 < ne0; i0++) { + const int64_t i00 = i0 / sf0; - const float * x = (float *)((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - float * y = (float *)((char *) dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3); + const float * x = (float *)((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + float * y = (float *)((char *) dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3); - *y = *x; + *y = *x; + } + } + } + } + } else if (mode == GGML_SCALE_MODE_BILINEAR) { + // setting a pixel offset of 0 would replicate the behavior of pytorch interpolate with align_corners=True + const float pixel_offset = 0.5f; + + for (int64_t i3 = 0; i3 < ne3; i3++) { + const int64_t i03 = i3 / sf3; + for (int64_t i2 = ith; i2 < ne2; i2 += nth) { + const int64_t i02 = i2 / sf2; + for (int64_t i1 = 0; i1 < ne1; i1++) { + const float y = ((float)i1 + pixel_offset) / sf1 - pixel_offset; + int64_t y0 = (int64_t)floorf(y); + int64_t y1 = y0 + 1; + + y0 = std::max(int64_t(0), std::min(y0, ne01 - 1)); + y1 = std::max(int64_t(0), std::min(y1, ne01 - 1)); + + float dy = y - (float)y0; + dy = std::max(0.0f, std::min(dy, 1.0f)); + + for (int64_t i0 = 0; i0 < ne0; i0++) { + const float x = ((float)i0 + pixel_offset) / sf0 - pixel_offset; + int64_t x0 = (int64_t)floorf(x); + int64_t x1 = x0 + 1; + + x0 = std::max(int64_t(0), std::min(x0, ne00 - 1)); + x1 = std::max(int64_t(0), std::min(x1, ne00 - 1)); + + 
float dx = x - (float)x0; + dx = std::max(0.0f, std::min(dx, 1.0f)); + + // fetch the four surrounding pixel values and interpolate + const float a = *(const float *)((const char *)src0->data + x0*nb00 + y0*nb01 + i02*nb02 + i03*nb03); + const float b = *(const float *)((const char *)src0->data + x1*nb00 + y0*nb01 + i02*nb02 + i03*nb03); + const float c = *(const float *)((const char *)src0->data + x0*nb00 + y1*nb01 + i02*nb02 + i03*nb03); + const float d = *(const float *)((const char *)src0->data + x1*nb00 + y1*nb01 + i02*nb02 + i03*nb03); + + const float val = a*(1 - dx)*(1 - dy) + b*dx*(1 - dy) + c*(1 - dx)*dy + d*dx*dy; + + float * y_dst = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3); + *y_dst = val; + } } } } + } else { + GGML_ABORT("unsupported upscale mode"); } } @@ -6721,8 +6941,8 @@ static void ggml_compute_forward_flash_attn_ext_f16( ggml_vec_dot_t const kq_vec_dot = ggml_get_type_traits_cpu(k->type)->vec_dot; ggml_to_float_t const v_to_float = ggml_get_type_traits(v->type)->to_float; - GGML_ASSERT(q_to_vec_dot && "fattn: unsupported K-type"); - GGML_ASSERT(v_to_float && "fattn: unsupported V-type"); + GGML_ASSERT(( q_to_vec_dot) && "fattn: unsupported K-type"); + GGML_ASSERT((v->type == GGML_TYPE_F32 || v_to_float ) && "fattn: unsupported V-type"); // loop over n_batch and n_head for (int ir = ir0; ir < ir1; ++ir) { @@ -6818,10 +7038,14 @@ static void ggml_compute_forward_flash_attn_ext_f16( vs = expf(s - M); } - v_to_float(v_data, V32, DV); - // V += v*expf(s - M) - ggml_vec_mad_f32(DV, VKQ32, V32, vs); + if (v_to_float) { + v_to_float(v_data, V32, DV); + ggml_vec_mad_f32(DV, VKQ32, V32, vs); + } else { + // V is F32 + ggml_vec_mad_f32(DV, VKQ32, (const float *) v_data, vs); + } } S = S*ms + vs; // scale and increment sum with partial sum @@ -8264,152 +8488,6 @@ void ggml_compute_forward_rwkv_wkv7( } } -// ggml_compute_forward_map_unary - -static void ggml_compute_forward_map_unary_f32( - const ggml_compute_params * params, - ggml_tensor * dst, - const ggml_unary_op_f32_t fun) { - - const ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); - assert(ggml_are_same_shape(src0, dst)); - - const int n = ggml_nrows(src0); - const int nc = src0->ne[0]; - - for (int i = 0; i < n; i++) { - fun(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1]))); - } -} - -void ggml_compute_forward_map_unary( - const ggml_compute_params * params, - ggml_tensor * dst, - const ggml_unary_op_f32_t fun) { - - const ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_map_unary_f32(params, dst, fun); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - -// ggml_compute_forward_map_binary - -static void ggml_compute_forward_map_binary_f32( - const ggml_compute_params * params, - ggml_tensor * dst, - const ggml_binary_op_f32_t fun) { - - const ggml_tensor * src0 = dst->src[0]; - const ggml_tensor * src1 = dst->src[1]; - - if (params->ith != 0) { - return; - } - - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(src1)); - assert(ggml_is_contiguous_1(dst)); - assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); - - const int n = ggml_nrows(src0); - const int nc = src0->ne[0]; - - for (int i = 0; i < n; i++) { - fun(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1])), - 
(float *) ((char *) src1->data + i*(src1->nb[1]))); - } -} - -void ggml_compute_forward_map_binary( - const ggml_compute_params * params, - ggml_tensor * dst, - const ggml_binary_op_f32_t fun) { - - const ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_map_binary_f32(params, dst, fun); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - -// ggml_compute_forward_map_custom1 - -void ggml_compute_forward_map_custom1_f32( - const ggml_compute_params * params, - ggml_tensor * dst, - const ggml_custom1_op_f32_t fun) { - - const ggml_tensor * a = dst->src[0]; - - if (params->ith != 0) { - return; - } - - fun(dst, a); -} - -// ggml_compute_forward_map_custom2 - -void ggml_compute_forward_map_custom2_f32( - const ggml_compute_params * params, - ggml_tensor * dst, - const ggml_custom2_op_f32_t fun) { - - const ggml_tensor * a = dst->src[0]; - const ggml_tensor * b = dst->src[1]; - - if (params->ith != 0) { - return; - } - - fun(dst, a, b); -} - -// ggml_compute_forward_map_custom3 - -void ggml_compute_forward_map_custom3_f32( - const ggml_compute_params * params, - ggml_tensor * dst, - const ggml_custom3_op_f32_t fun) { - - const ggml_tensor * a = dst->src[0]; - const ggml_tensor * b = dst->src[1]; - const ggml_tensor * c = dst->src[1]; - - if (params->ith != 0) { - return; - } - - fun(dst, a, b, c); -} - // ggml_compute_forward_map_custom1 void ggml_compute_forward_map_custom1( @@ -8455,6 +8533,18 @@ void ggml_compute_forward_map_custom3( p.fun(dst, a, b, c, params->ith, params->nth, p.userdata); } +// ggml_compute_forward_custom + +void ggml_compute_forward_custom( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + struct ggml_custom_op_params p; + memcpy(&p, dst->op_params, sizeof(p)); + + p.fun(dst, params->ith, params->nth, p.userdata); +} + // ggml_compute_forward_cross_entropy_loss static void ggml_compute_forward_cross_entropy_loss_f32( diff --git a/ggml/src/ggml-cpu/ops.h b/ggml/src/ggml-cpu/ops.h index d43fbc1fc47..dc081b9e663 100644 --- a/ggml/src/ggml-cpu/ops.h +++ b/ggml/src/ggml-cpu/ops.h @@ -65,6 +65,7 @@ void ggml_compute_forward_conv_transpose_1d(const struct ggml_compute_params * p void ggml_compute_forward_im2col(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_im2col_back_f32(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_conv_transpose_2d(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_conv_2d_dw(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_pool_1d(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_pool_2d(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_pool_2d_back(const struct ggml_compute_params * params, struct ggml_tensor * dst); @@ -96,29 +97,10 @@ void ggml_compute_forward_add_rel_pos(const struct ggml_compute_params * params, void ggml_compute_forward_rwkv_wkv6(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_rwkv_wkv7(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_gla(const struct ggml_compute_params * params, struct ggml_tensor * dst); -void ggml_compute_forward_map_unary( - const struct ggml_compute_params * params, - struct ggml_tensor * dst, - const ggml_unary_op_f32_t fun); -void 
ggml_compute_forward_map_binary( - const struct ggml_compute_params * params, - struct ggml_tensor * dst, - const ggml_binary_op_f32_t fun); -void ggml_compute_forward_map_custom1_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst, - const ggml_custom1_op_f32_t fun); -void ggml_compute_forward_map_custom2_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst, - const ggml_custom2_op_f32_t fun); -void ggml_compute_forward_map_custom3_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst, - const ggml_custom3_op_f32_t fun); void ggml_compute_forward_map_custom1(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_map_custom2(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_map_custom3(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_custom(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_cross_entropy_loss(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_cross_entropy_loss_back(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_opt_step_adamw(const struct ggml_compute_params * params, struct ggml_tensor * dst); diff --git a/ggml/src/ggml-cpu/simd-mappings.h b/ggml/src/ggml-cpu/simd-mappings.h index 28aaa1b7189..04d10cec266 100644 --- a/ggml/src/ggml-cpu/simd-mappings.h +++ b/ggml/src/ggml-cpu/simd-mappings.h @@ -71,7 +71,7 @@ #define GGML_F16x8 float16x8_t #define GGML_F16x8_ZERO vdupq_n_f16(0.0f) #define GGML_F16x8_SET1(x) vdupq_n_f16(x) - #define GGML_F16x8_LOAD(x) vld1q_f16((const ggml_fp16_internal_t *)(x)) + #define GGML_F16x8_LOAD(x) vld1q_f16((const __fp16 *)(x)) #define GGML_F16x8_STORE vst1q_f16 #define GGML_F16x8_FMA(a, b, c) vfmaq_f16(a, b, c) #define GGML_F16x8_ADD vaddq_f16 @@ -99,7 +99,7 @@ #define GGML_F16_VEC_ZERO GGML_F16x8_ZERO #define GGML_F16_VEC_SET1 GGML_F16x8_SET1 #define GGML_F16_VEC_LOAD(p, i) GGML_F16x8_LOAD(p) - #define GGML_F16_VEC_STORE(p, r, i) GGML_F16x8_STORE((ggml_fp16_internal_t *)(p), (r)[i]) + #define GGML_F16_VEC_STORE(p, r, i) GGML_F16x8_STORE((__fp16 *)(p), (r)[i]) #define GGML_F16_VEC_FMA GGML_F16x8_FMA #define GGML_F16_VEC_ADD GGML_F16x8_ADD #define GGML_F16_VEC_MUL GGML_F16x8_MUL @@ -114,7 +114,7 @@ #define GGML_F32Cx4 float32x4_t #define GGML_F32Cx4_ZERO vdupq_n_f32(0.0f) #define GGML_F32Cx4_SET1(x) vdupq_n_f32(x) - #define GGML_F32Cx4_LOAD(x) vcvt_f32_f16(vld1_f16((const ggml_fp16_internal_t *)(x))) + #define GGML_F32Cx4_LOAD(x) vcvt_f32_f16(vld1_f16((const __fp16 *)(x))) #define GGML_F32Cx4_STORE(x, y) vst1_f16(x, vcvt_f16_f32(y)) #define GGML_F32Cx4_FMA(a, b, c) vfmaq_f32(a, b, c) #define GGML_F32Cx4_ADD vaddq_f32 @@ -125,7 +125,7 @@ #define GGML_F16_VEC_ZERO GGML_F32Cx4_ZERO #define GGML_F16_VEC_SET1 GGML_F32Cx4_SET1 #define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx4_LOAD(p) - #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE((ggml_fp16_internal_t *)(p), r[i]) + #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE((__fp16 *)(p), r[i]) #define GGML_F16_VEC_FMA GGML_F32Cx4_FMA #define GGML_F16_VEC_ADD GGML_F32Cx4_ADD #define GGML_F16_VEC_MUL GGML_F32Cx4_MUL @@ -392,7 +392,11 @@ static inline void __avx_f32cx8_store(ggml_fp16_t *x, __m256 y) { #define GGML_F16_VEC_LOAD(p, i) (i & 0x1) ? 
\ vec_extract_fp32_from_shorth(vec_xl(0, p - GGML_F16_EPR)) : \ vec_extract_fp32_from_shortl(vec_xl(0, p)) -#define GGML_ENDIAN_BYTE(i) ((unsigned char *)&(uint16_t){1})[i] +static inline unsigned char ggml_endian_byte(int i) { + uint16_t tmp_val = 1; + return ((unsigned char *)&tmp_val)[i]; +} +#define GGML_ENDIAN_BYTE(i) ggml_endian_byte(i) #define GGML_F16_VEC_STORE(p, r, i) \ if (i & 0x1) \ vec_xst(vec_pack_to_short_fp32(r[i - GGML_ENDIAN_BYTE(1)], \ @@ -851,13 +855,17 @@ static inline __vector float __lzs_f16cx4_load(const ggml_fp16_t * x) { tmp[i] = GGML_FP16_TO_FP32(x[i]); } - return vec_xl(0, tmp); + // note: keep type-cast here to prevent compiler bugs + // see: https://github.com/ggml-org/llama.cpp/issues/12846 + return vec_xl(0, (const float *)(tmp)); } static inline void __lzs_f16cx4_store(ggml_fp16_t * x, __vector float y) { float arr[4]; - vec_xst(y, 0, arr); + // note: keep type-cast here to prevent compiler bugs + // see: https://github.com/ggml-org/llama.cpp/issues/12846 + vec_xst(y, 0, (float *)(arr)); for (int i = 0; i < 4; i++) { x[i] = GGML_FP32_TO_FP16(arr[i]); diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index a718b6a1288..8284a0017d2 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -729,7 +729,13 @@ struct ggml_cuda_graph { bool disable_due_to_failed_graph_capture = false; int number_consecutive_updates = 0; std::vector ggml_graph_properties; - std::vector updated_kernel_arg; + bool use_cpy_indirection = false; + std::vector cpy_dest_ptrs; + char ** dest_ptrs_d; + int dest_ptrs_size = 0; + // Index to allow each cpy kernel to be aware of it's position within the graph + // relative to other cpy nodes. + int graph_cpynode_index = -1; #endif }; diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu index 2997e2b4d5b..a224ec0e12d 100644 --- a/ggml/src/ggml-cuda/convert.cu +++ b/ggml/src/ggml-cuda/convert.cu @@ -579,7 +579,7 @@ static __global__ void convert_unary(const void * __restrict__ vx, dst_t * __res const src_t * x = (const src_t *) vx; - y[i] = x[i]; + y[i] = float(x[i]); } template @@ -588,6 +588,17 @@ static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict_ convert_unary<<>>(vx, y, k); } +to_bf16_cuda_t ggml_get_to_bf16_cuda(ggml_type type) { + switch (type) { + case GGML_TYPE_F32: + return convert_unary_cuda; + case GGML_TYPE_F16: + return convert_unary_cuda; + default: + return nullptr; + } +} + to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) { switch (type) { case GGML_TYPE_Q4_0: @@ -633,6 +644,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) { return dequantize_row_iq3_s_cuda; case GGML_TYPE_F32: return convert_unary_cuda; + case GGML_TYPE_BF16: + return convert_unary_cuda; default: return nullptr; } diff --git a/ggml/src/ggml-cuda/convert.cuh b/ggml/src/ggml-cuda/convert.cuh index 5394be9f161..411a13cf126 100644 --- a/ggml/src/ggml-cuda/convert.cuh +++ b/ggml/src/ggml-cuda/convert.cuh @@ -7,7 +7,10 @@ using to_t_cuda_t = void (*)(const void * __restrict__ x, T * __restrict__ y, in typedef to_t_cuda_t to_fp32_cuda_t; typedef to_t_cuda_t to_fp16_cuda_t; +typedef to_t_cuda_t to_bf16_cuda_t; to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type); +to_bf16_cuda_t ggml_get_to_bf16_cuda(ggml_type type); + to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type); diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu index cca2bee0b27..eca48052491 100644 --- a/ggml/src/ggml-cuda/cpy.cu +++ b/ggml/src/ggml-cuda/cpy.cu @@ -10,6 +10,13 @@ 
static __device__ void cpy_1_f32_f32(const char * cxi, char * cdsti) { *dsti = *xi; } +static __device__ void cpy_1_f32_bf16(const char * cxi, char * cdsti) { + const float * xi = (const float *) cxi; + nv_bfloat16 * dsti = (nv_bfloat16 *) cdsti; + + *dsti = *xi; +} + static __device__ void cpy_1_f32_f16(const char * cxi, char * cdsti) { const float * xi = (const float *) cxi; half * dsti = (half *) cdsti; @@ -32,16 +39,18 @@ static __device__ void cpy_1_f16_f32(const char * cxi, char * cdsti) { } template -static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne, +static __global__ void cpy_f32_f16(const char * cx, char * cdst_direct, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, - const int nb12, const int nb13) { + const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) { const int64_t i = blockDim.x*blockIdx.x + threadIdx.x; if (i >= ne) { return; } + char * cdst = (cdst_indirect != nullptr) ? cdst_indirect[graph_cpynode_index]: cdst_direct; + // determine indices i03/i13, i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor // then combine those indices with the corresponding byte offsets to get the total offsets const int64_t i03 = i/(ne00 * ne01 * ne02); @@ -288,16 +297,18 @@ static __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) { } template -static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne, +static __global__ void cpy_f32_q(const char * cx, char * cdst_direct, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, - const int nb12, const int nb13) { + const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) { const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk; if (i >= ne) { return; } + char * cdst = (cdst_indirect != nullptr) ? cdst_indirect[graph_cpynode_index]: cdst_direct; + const int i03 = i/(ne00 * ne01 * ne02); const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01); const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00; @@ -314,16 +325,18 @@ static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne, } template -static __global__ void cpy_q_f32(const char * cx, char * cdst, const int ne, +static __global__ void cpy_q_f32(const char * cx, char * cdst_direct, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, - const int nb12, const int nb13) { + const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) { const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk; if (i >= ne) { return; } + char * cdst = (cdst_indirect != nullptr) ? 
cdst_indirect[graph_cpynode_index]: cdst_direct; + const int i03 = i/(ne00 * ne01 * ne02); const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01); const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00; @@ -339,66 +352,97 @@ static __global__ void cpy_q_f32(const char * cx, char * cdst, const int ne, cpy_blck(cx + x_offset, cdst + dst_offset); } +// Copy destination pointers to GPU to be available when pointer indirection is in use + +void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_dest_ptrs, const int host_dest_ptrs_size, cudaStream_t stream) { +#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) + if (cuda_graph->dest_ptrs_size < host_dest_ptrs_size) { // (re-)allocate GPU memory for destination pointers + CUDA_CHECK(cudaStreamSynchronize(stream)); + if (cuda_graph->dest_ptrs_d != nullptr) { + CUDA_CHECK(cudaFree(cuda_graph->dest_ptrs_d)); + } + CUDA_CHECK(cudaMalloc(&cuda_graph->dest_ptrs_d, host_dest_ptrs_size*sizeof(char *))); + cuda_graph->dest_ptrs_size = host_dest_ptrs_size; + } + // copy destination pointers to GPU + CUDA_CHECK(cudaMemcpyAsync(cuda_graph->dest_ptrs_d, host_dest_ptrs, host_dest_ptrs_size*sizeof(char *), cudaMemcpyHostToDevice, stream)); + cuda_graph->graph_cpynode_index = 0; // reset index +#else + GGML_UNUSED(cuda_graph); GGML_UNUSED(host_dest_ptrs); + GGML_UNUSED(host_dest_ptrs_size); GGML_UNUSED(stream); +#endif +} + static void ggml_cpy_f16_f32_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { + const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; cpy_f32_f16<<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); } static void ggml_cpy_f32_f32_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { + const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; cpy_f32_f16<<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); +} + +static void ggml_cpy_f32_bf16_cuda( + const char * cx, char * cdst, const int ne, + const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, + const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { + + const int 
num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; + cpy_f32_f16<<>> + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); } static void ggml_cpy_f32_f16_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { + const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; cpy_f32_f16<<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); } static void ggml_cpy_f32_q8_0_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { + const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { GGML_ASSERT(ne % QK8_0 == 0); const int num_blocks = ne / QK8_0; cpy_f32_q<<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); } static void ggml_cpy_q8_0_f32_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { + const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { const int num_blocks = ne; cpy_q_f32<<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); } static void ggml_cpy_f32_q4_0_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { + const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { GGML_ASSERT(ne % QK4_0 == 0); const int num_blocks = ne / QK4_0; cpy_f32_q<<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, 
ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); } static void ggml_cpy_q4_0_f32_cuda( @@ -407,22 +451,22 @@ static void ggml_cpy_q4_0_f32_cuda( const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, - cudaStream_t stream) { + cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { const int num_blocks = ne; cpy_q_f32, QK4_0><<>>( cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, - ne10, ne11, ne12, nb10, nb11, nb12, nb13); + ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); } static void ggml_cpy_f32_q4_1_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { + const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { GGML_ASSERT(ne % QK4_1 == 0); const int num_blocks = ne / QK4_1; cpy_f32_q<<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); } static void ggml_cpy_q4_1_f32_cuda( @@ -431,22 +475,22 @@ static void ggml_cpy_q4_1_f32_cuda( const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, - cudaStream_t stream) { + cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { const int num_blocks = ne; cpy_q_f32, QK4_1><<>>( cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, - ne10, ne11, ne12, nb10, nb11, nb12, nb13); + ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); } static void ggml_cpy_f32_q5_0_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { + const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { GGML_ASSERT(ne % QK5_0 == 0); const int num_blocks = ne / QK5_0; cpy_f32_q<<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); } static void ggml_cpy_q5_0_f32_cuda( @@ -455,22 +499,22 @@ static void ggml_cpy_q5_0_f32_cuda( const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, - cudaStream_t stream) { + cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { const int num_blocks = ne; cpy_q_f32, QK5_0><<>>( cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, - ne10, ne11, ne12, nb10, nb11, nb12, nb13); + ne10, ne11, ne12, nb10, nb11, nb12, 
nb13, cdst_indirect, graph_cpynode_index++); } static void ggml_cpy_f32_q5_1_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { + const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { GGML_ASSERT(ne % QK5_1 == 0); const int num_blocks = ne / QK5_1; cpy_f32_q<<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); } static void ggml_cpy_q5_1_f32_cuda( @@ -479,35 +523,35 @@ static void ggml_cpy_q5_1_f32_cuda( const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, - cudaStream_t stream) { + cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { const int num_blocks = ne; cpy_q_f32, QK5_1><<>>( cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, - ne10, ne11, ne12, nb10, nb11, nb12, nb13); + ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); } static void ggml_cpy_f32_iq4_nl_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { + const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { GGML_ASSERT(ne % QK4_NL == 0); const int num_blocks = ne / QK4_NL; cpy_f32_q<<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); } static void ggml_cpy_f16_f16_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { + const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; cpy_f32_f16<<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); } -void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1) { +void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1, bool disable_indirection_for_this_node) { const int64_t ne = ggml_nelements(src0); GGML_ASSERT(ne == 
ggml_nelements(src1)); @@ -541,51 +585,70 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg char * src0_ddc = (char *) src0->data; char * src1_ddc = (char *) src1->data; + char ** dest_ptrs_d = nullptr; + int graph_cpynode_index = -1; +#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) + if(ctx.cuda_graph->use_cpy_indirection && !disable_indirection_for_this_node) { + dest_ptrs_d = ctx.cuda_graph->dest_ptrs_d; + graph_cpynode_index = ctx.cuda_graph->graph_cpynode_index; + } +#else + GGML_UNUSED(disable_indirection_for_this_node); +#endif if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) { GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1)); CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream)); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) { - ggml_cpy_f32_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_f32_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) { + ggml_cpy_f32_bf16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) { - ggml_cpy_f32_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_f32_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) { - ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) { - ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) { - ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) { ggml_cpy_q4_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, - nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) { - ggml_cpy_f32_q4_1_cuda(src0_ddc, 
src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) { ggml_cpy_q4_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, - nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) { - ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) { ggml_cpy_q5_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, - nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) { - ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) { - ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) { - ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) { - ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) { - ggml_cpy_f16_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_f16_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else { GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__, ggml_type_name(src0->type), ggml_type_name(src1->type)); } +#if 
defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) + if(ctx.cuda_graph->use_cpy_indirection && !disable_indirection_for_this_node) { + ctx.cuda_graph->graph_cpynode_index = graph_cpynode_index; + } +#endif + } void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; - ggml_cuda_cpy(ctx, src0, dst); + bool disable_indirection = true; + ggml_cuda_cpy(ctx, src0, dst, disable_indirection); } void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) { @@ -593,6 +656,8 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) { return nullptr; } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) { return (void*) cpy_f32_f16; + } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) { + return (void*) cpy_f32_f16; } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) { return (void*) cpy_f32_f16; } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) { diff --git a/ggml/src/ggml-cuda/cpy.cuh b/ggml/src/ggml-cuda/cpy.cuh index 28b06cddaa8..0bd3c0c6f8c 100644 --- a/ggml/src/ggml-cuda/cpy.cuh +++ b/ggml/src/ggml-cuda/cpy.cuh @@ -2,8 +2,10 @@ #define CUDA_CPY_BLOCK_SIZE 64 -void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1); +void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1, bool disable_indirection = false); void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1); + +void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_dest_ptrs, const int host_dest_ptrs_size, cudaStream_t stream); diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index 3fe22092f2c..56121705bdf 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -62,7 +62,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0( T sum = 0.0f; #pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) { + for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += warp_size) { const int k_KQ = k_KQ_0 + threadIdx.x; const int ib = k_KQ / QI8_1; @@ -102,7 +102,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1( T sum = 0.0f; #pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) { + for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += warp_size) { const int k_KQ = k_KQ_0 + threadIdx.x; const int ib = k_KQ / QI8_1; @@ -146,7 +146,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0( T sum = 0.0f; #pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) { + for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += warp_size) { const int k_KQ = k_KQ_0 + threadIdx.x; const int ib = k_KQ / QI8_1; @@ -193,7 +193,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1( T sum = 0.0f; #pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) { + for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += warp_size) { const int k_KQ = k_KQ_0 + threadIdx.x; const int ib = k_KQ / QI8_1; @@ -244,7 +244,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0( T sum = 0.0f; #pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) { + for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += warp_size) { const int k_KQ = k_KQ_0 + 
threadIdx.x; const int ib = k_KQ / QI8_0; diff --git a/ggml/src/ggml-cuda/fattn-tile-f32.cu b/ggml/src/ggml-cuda/fattn-tile-f32.cu index 81290c90134..fcb6f848fe0 100644 --- a/ggml/src/ggml-cuda/fattn-tile-f32.cu +++ b/ggml/src/ggml-cuda/fattn-tile-f32.cu @@ -52,6 +52,18 @@ static __global__ void flash_attn_tile_ext_f32( return; #endif // FP16_MMA_AVAILABLE if (use_logit_softcap && !(D == 128 || D == 256)) { + GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); + GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale); + GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1); + GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); + GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); + GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11); + GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); + GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); + GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); + GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22); + GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1); + GGML_UNUSED(ne2); GGML_UNUSED(ne3); NO_DEVICE_CODE; return; } diff --git a/ggml/src/ggml-cuda/fattn-vec-f32.cuh b/ggml/src/ggml-cuda/fattn-vec-f32.cuh index 7048748551f..d42ddca49f6 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f32.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f32.cuh @@ -45,6 +45,18 @@ static __global__ void flash_attn_vec_ext_f32( // Skip unused kernel variants for faster compilation: if (use_logit_softcap && !(D == 128 || D == 256)) { + GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); + GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale); + GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1); + GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); + GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); + GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11); + GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); + GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); + GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); + GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22); + GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1); + GGML_UNUSED(ne2); GGML_UNUSED(ne3); NO_DEVICE_CODE; return; } @@ -114,7 +126,7 @@ static __global__ void flash_attn_vec_ext_f32( // Set memory to zero if out of bounds: if (ncols > 2 && ic0 + j >= ne01) { #pragma unroll - for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) { + for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += WARP_SIZE) { const int i = i0 + threadIdx.x; tmp_q_i32[i] = 0; @@ -127,7 +139,7 @@ static __global__ void flash_attn_vec_ext_f32( const float * Q_f = (const float *) (Q + j*nb01); #pragma unroll - for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) { + for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += WARP_SIZE) { quantize_q8_1_to_shared(Q_f + 4*i0, scale, tmp_q_i32, tmp_q_ds); } } @@ -140,7 +152,7 @@ static __global__ void flash_attn_vec_ext_f32( float2 * tmp_q_ds = (float2 *) (tmp_q_i32 + D/sizeof(int)); #pragma unroll - for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) { + for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += WARP_SIZE) { const int i = i0 + threadIdx.x; Q_i32[j][i0/WARP_SIZE] = tmp_q_i32[i]; diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 8edc12649aa..7a2d1e45365 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -299,7 +299,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst const bool gqa_opt_applies = ((Q->ne[2] / 
K->ne[2]) % 2 == 0) && mask; // The mma-based kernels have GQA-specific optimizations const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16; const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && cc < GGML_CUDA_CC_ADA_LOVELACE && !mma_needs_data_conversion; - const bool can_use_vector_kernel = (Q->ne[0] % (2*warp_size) == 0) && (prec == GGML_PREC_DEFAULT || Q->ne[0] <= 128); + const bool can_use_vector_kernel = Q->ne[0] % (2*warp_size) == 0; if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) { if (prec == GGML_PREC_DEFAULT) { ggml_cuda_flash_attn_ext_vec_f16(ctx, dst); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 861927654ec..e0e0d2137f3 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -96,31 +96,32 @@ int ggml_cuda_get_device() { static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) { ggml_cuda_set_device(device); -#if defined(GGML_USE_HIP) && defined(GGML_HIP_UMA) - auto res = hipMallocManaged(ptr, size); - if (res == hipSuccess) { - // if error we "need" to know why... - CUDA_CHECK(hipMemAdvise(*ptr, size, hipMemAdviseSetCoarseGrain, device)); - } - return res; -#else - -#if !defined(GGML_USE_HIP) cudaError_t err; if (getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr) { err = cudaMallocManaged(ptr, size); +#if defined(GGML_USE_HIP) + if (err == hipSuccess) { + CUDA_CHECK(cudaMemAdvise(*ptr, size, hipMemAdviseSetCoarseGrain, device)); + } + + // fall back to cudaMalloc if not supported (e.g. on Windows) + if (err == hipErrorNotSupported) { + static bool warned_unsupported = false; + if (!warned_unsupported) { + GGML_LOG_WARN("hipMallocManaged unsupported, falling back to hipMalloc.\n"); + warned_unsupported = true; + } + + err = cudaMalloc(ptr, size); + } +#endif // defined(GGML_USE_HIP) } else { err = cudaMalloc(ptr, size); } return err; -#else - return cudaMalloc(ptr, size); -#endif // !defined(GGML_USE_HIP) - -#endif } #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) @@ -1194,7 +1195,35 @@ static void ggml_cuda_op_mul_mat_cublas( const bool use_fp16 = (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT; - if (((GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA) || GGML_CUDA_CC_IS_AMD(cc)) && use_fp16) { + if (src0->type == GGML_TYPE_BF16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) { + ggml_cuda_pool_alloc src1_as_bf16(ctx.pool(id)); + if (src1->type != GGML_TYPE_BF16) { + const to_bf16_cuda_t to_bf16_cuda = ggml_get_to_bf16_cuda(src1->type); + GGML_ASSERT(to_bf16_cuda != nullptr); + size_t ne = src1_ncols*ne10; + src1_as_bf16.alloc(ne); + to_bf16_cuda(src1_ddf_i, src1_as_bf16.get(), ne, stream); + } + const nv_bfloat16 * src1_ptr = src1->type == GGML_TYPE_BF16 ? 
(const nv_bfloat16 *) src1_ddf_i : src1_as_bf16.get(); + const nv_bfloat16 * src0_ptr = (const nv_bfloat16 *)src0_dd_i; + ggml_cuda_pool_alloc dst_bf16(ctx.pool(id), row_diff*src1_ncols); + + const float alpha_f32 = 1.0f; + const float beta_f32 = 0.0f; + + CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream)); + CUBLAS_CHECK( + cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N, + row_diff, src1_ncols, ne10, + &alpha_f32, src0_ptr, CUDA_R_16BF, ne00, + src1_ptr, CUDA_R_16BF, ne10, + &beta_f32, dst_bf16.get(), CUDA_R_16BF, ldc, + CUBLAS_COMPUTE_32F, + CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + + const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_BF16); + to_fp32_cuda(dst_bf16.get(), dst_dd_i, row_diff*src1_ncols, stream); + } else if (((GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA) || GGML_CUDA_CC_IS_AMD(cc)) && use_fp16) { // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32 ggml_cuda_pool_alloc src0_as_f16(ctx.pool(id)); if (src0->type != GGML_TYPE_F16) { @@ -1381,6 +1410,11 @@ static void ggml_cuda_op_mul_mat( const int64_t ne0 = dst->ne[0]; const int64_t ne1 = dst->ne[1]; + // const int64_t nb10 = src1->nb[0]; + const int64_t nb11 = src1->nb[1]; + const int64_t nb12 = src1->nb[2]; + const int64_t nb13 = src1->nb[3]; + const int64_t nb2 = dst->nb[2]; const int64_t nb3 = dst->nb[3]; @@ -1516,7 +1550,10 @@ static void ggml_cuda_op_mul_mat( dev[id].src1_ddq = dev[id].src1_ddq_alloc.alloc(ctx.pool(id), src_1_ddq_size); if (src1_on_device && src1_is_contiguous) { - quantize_src1(dev[id].src1_ddf, dev[id].src1_ddq, ne10, ne11, ne12*ne13, src1_padded_col_size, src0->type, stream); + quantize_src1( + dev[id].src1_ddf, dev[id].src1_ddq, src0->type, ne10, + nb11/sizeof(float), nb12/sizeof(float), nb13/sizeof(float), + src1_padded_col_size, ne11, ne12, ne13, stream); CUDA_CHECK(cudaGetLastError()); } } @@ -1611,7 +1648,9 @@ static void ggml_cuda_op_mul_mat( } if (quantize_src1 && !src1_is_contiguous) { - quantize_src1(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, 1, src1_padded_col_size, src0->type, stream); + quantize_src1( + src1_ddf_i, src1_ddq_i, src0->type, ne10, ne10, ne11*ne10, ne12*ne11*ne10, + src1_padded_col_size, src1_ncols, 1, 1, stream); CUDA_CHECK(cudaGetLastError()); } @@ -1849,7 +1888,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft); - bool use_mul_mat_vec = (src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16) + bool use_mul_mat_vec = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && src0->ne[0] % 2 == 0 && src1->ne[1] == 1; bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) @@ -1890,10 +1929,12 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name); //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name); - if (!split && use_mul_mat_vec && (src0->ne[1] < MMV_MAX_ROWS || any_gpus_without_fp16_mma)) { + if (!split && use_mul_mat_vec && (src0->ne[1] <= MMV_MAX_ROWS || 
any_gpus_without_fp16_mma)) { // the custom F16 vector kernel can be used over batched cuBLAS GEMM // but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention) - ggml_cuda_mul_mat_vec(ctx, src0, src1, dst); + ggml_cuda_mul_mat_vec(ctx, src0, src1, nullptr, dst); + } else if (!split && use_mul_mat_vec_q) { + ggml_cuda_mul_mat_vec_q(ctx, src0, src1, nullptr, dst); } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16) && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) { // general KQ + KQV multi-batch without FlashAttention @@ -1970,6 +2011,15 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * GGML_TENSOR_BINARY_OP_LOCALS + if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && ne2 == 1) { + if (ggml_is_quantized(src0->type)) { + ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst); + } else { + ggml_cuda_mul_mat_vec(ctx, src0, src1, ids, dst); + } + return; + } + GGML_ASSERT(!ggml_backend_buft_is_cuda_split(src0->buffer->buft) && "mul_mat_id does not support split buffers"); cudaStream_t stream = ctx.stream(); @@ -2006,97 +2056,75 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst_row.nb[2] = nb1; dst_row.nb[3] = nb1; - if (ne12 == 1) { - for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) { - for (int64_t id = 0; id < n_ids; id++) { - const int32_t i02 = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]); + ggml_cuda_pool_alloc src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1)); + ggml_cuda_pool_alloc dst_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst)); - GGML_ASSERT(i02 >= 0 && i02 < n_as); + src1_row.data = src1_contiguous.get(); + dst_row.data = dst_contiguous.get(); - const int64_t i11 = id % ne11; - const int64_t i12 = iid1; + for (int64_t i02 = 0; i02 < n_as; i02++) { + int64_t num_src1_rows = 0; - const int64_t i1 = id; - const int64_t i2 = i12; - - src0_row.data = src0_original + i02*nb02; - src1_row.data = src1_original + i11*nb11 + i12*nb12; - dst_row.data = dst_original + i1*nb1 + i2*nb2; - - ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row); - } - } - } else { - ggml_cuda_pool_alloc src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1)); - ggml_cuda_pool_alloc dst_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst)); - - src1_row.data = src1_contiguous.get(); - dst_row.data = dst_contiguous.get(); - - for (int64_t i02 = 0; i02 < n_as; i02++) { - int64_t num_src1_rows = 0; - - for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) { - for (int64_t id = 0; id < n_ids; id++) { - const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]); - - GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as); + for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) { + for (int64_t id = 0; id < n_ids; id++) { + const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]); - if (row_id_i != i02) { - continue; - } + GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as); - num_src1_rows++; + if (row_id_i != i02) { + continue; } - } - if (num_src1_rows == 0) { - continue; + num_src1_rows++; } + } - ggml_cuda_pool_alloc dev_cur_src1_row(ctx.pool(), 1); - ggml_cuda_pool_alloc dev_row_mapping(ctx.pool(), num_src1_rows); - CUDA_CHECK(cudaMemsetAsync(dev_cur_src1_row.get(), 0, sizeof(int), stream)); + if (num_src1_rows == 0) { + continue; + } - { - dim3 
block_dims(std::min((unsigned int)ne10, 768u)); - dim3 grid_dims(ids->ne[1], n_ids); - k_copy_src1_to_contiguous<<>>( - src1_original, src1_contiguous.get(), - dev_cur_src1_row.get(), dev_row_mapping.get(), - ids_dev, i02, ids->nb[1], ids->nb[0], - ne11, ne10, - nb11, nb12); - CUDA_CHECK(cudaGetLastError()); - } + ggml_cuda_pool_alloc dev_cur_src1_row(ctx.pool(), 1); + ggml_cuda_pool_alloc dev_row_mapping(ctx.pool(), num_src1_rows); + CUDA_CHECK(cudaMemsetAsync(dev_cur_src1_row.get(), 0, sizeof(int), stream)); - src0_row.data = src0_original + i02*nb02; + { + dim3 block_dims(std::min((unsigned int)ne10, 768u)); + dim3 grid_dims(ids->ne[1], n_ids); + k_copy_src1_to_contiguous<<>>( + src1_original, src1_contiguous.get(), + dev_cur_src1_row.get(), dev_row_mapping.get(), + ids_dev, i02, ids->nb[1], ids->nb[0], + ne11, ne10, + nb11, nb12); + CUDA_CHECK(cudaGetLastError()); + } - GGML_ASSERT(nb11 == sizeof(float)*ne10); - GGML_ASSERT(nb1 == sizeof(float)*ne0); + src0_row.data = src0_original + i02*nb02; - src1_row.ne[1] = num_src1_rows; - src1_row.nb[1] = nb11; - src1_row.nb[2] = num_src1_rows*nb11; - src1_row.nb[3] = num_src1_rows*nb11; + GGML_ASSERT(nb11 == sizeof(float)*ne10); + GGML_ASSERT(nb1 == sizeof(float)*ne0); - dst_row.ne[1] = num_src1_rows; - dst_row.nb[1] = nb1; - dst_row.nb[2] = num_src1_rows*nb1; - dst_row.nb[3] = num_src1_rows*nb1; + src1_row.ne[1] = num_src1_rows; + src1_row.nb[1] = nb11; + src1_row.nb[2] = num_src1_rows*nb11; + src1_row.nb[3] = num_src1_rows*nb11; - ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row); + dst_row.ne[1] = num_src1_rows; + dst_row.nb[1] = nb1; + dst_row.nb[2] = num_src1_rows*nb1; + dst_row.nb[3] = num_src1_rows*nb1; - { - dim3 block_dims(std::min((unsigned int)ne0, 768u)); - dim3 grid_dims(num_src1_rows); - k_copy_dst_from_contiguous<<>>( - dst_original, dst_contiguous.get(), - dev_row_mapping.get(), - ne0, - nb1, nb2); - CUDA_CHECK(cudaGetLastError()); - } + ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row); + + { + dim3 block_dims(std::min((unsigned int)ne0, 768u)); + dim3 grid_dims(num_src1_rows); + k_copy_dst_from_contiguous<<>>( + dst_original, dst_contiguous.get(), + dev_row_mapping.get(), + ne0, + nb1, nb2); + CUDA_CHECK(cudaGetLastError()); } } } @@ -2441,10 +2469,11 @@ static void ggml_backend_cuda_synchronize(ggml_backend_t backend) { #ifdef USE_CUDA_GRAPH static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, - std::vector & ggml_cuda_cpy_fn_ptrs, bool use_cuda_graph) { + bool use_cuda_graph) { // Loop over nodes in GGML graph to obtain info needed for CUDA graph - cuda_ctx->cuda_graph->updated_kernel_arg.clear(); + cuda_ctx->cuda_graph->cpy_dest_ptrs.clear(); + for (int i = 0; i < cgraph->n_nodes; i++) { ggml_tensor * node = cgraph->nodes[i]; @@ -2459,10 +2488,10 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud #endif } - if (node->op == GGML_OP_MUL_MAT_ID) { + if (node->op == GGML_OP_MUL_MAT_ID && node->ne[2] != 1) { use_cuda_graph = false; // This node type is not supported by CUDA graph capture #ifndef NDEBUG - GGML_LOG_DEBUG("%s: disabling CUDA graphs due to mul_mat_id\n", __func__); + GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported node type\n", __func__); #endif } @@ -2476,8 +2505,11 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud } if (node->op == GGML_OP_CPY) { - // store the copy op parameter which changes with each token. 
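// Side note on the MUL_MAT_ID gather/scatter above: for every expert i02, the rows of
// src1 routed to that expert are gathered into a contiguous buffer (recording a row
// mapping via an atomic counter), the regular mul_mat runs on the contiguous slice, and
// the results are scattered back through the same mapping. A simplified, hypothetical
// gather kernel for F32 rows with a flat 1-D id array (the ggml kernel handles 2-D ids
// and packs the mapping differently); `counter` is assumed to be zero-initialized, as
// done with cudaMemsetAsync above:
__global__ void k_gather_rows_f32(const float * src, float * dst, const int * row_ids,
                                  int * counter, int * mapping, int expert,
                                  int n_cols) {
    const int row = blockIdx.x;                 // one block per candidate src row
    if (row_ids[row] != expert) {
        return;                                 // uniform per block, so no divergent __syncthreads
    }
    __shared__ int slot;
    if (threadIdx.x == 0) {
        slot = atomicAdd(counter, 1);           // next free slot in the contiguous buffer
        mapping[slot] = row;                    // remember the source row for the scatter-back
    }
    __syncthreads();
    for (int col = threadIdx.x; col < n_cols; col += blockDim.x) {
        dst[(size_t) slot*n_cols + col] = src[(size_t) row*n_cols + col];
    }
}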
- cuda_ctx->cuda_graph->updated_kernel_arg.push_back((char **) &(node->src[1]->data)); + + // Store the pointers which are updated for each token, such that these can be sent + // to the device and accessed using indirection from CUDA graph + cuda_ctx->cuda_graph->cpy_dest_ptrs.push_back((char *) node->src[1]->data); + // store a pointer to each copy op CUDA kernel to identify it later void * ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]); if (!ptr) { @@ -2485,10 +2517,6 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud #ifndef NDEBUG GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported copy op\n", __func__); #endif - } else { - if (std::find(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), ptr) == ggml_cuda_cpy_fn_ptrs.end()) { - ggml_cuda_cpy_fn_ptrs.push_back(ptr); - } } } @@ -2497,6 +2525,12 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud } } + if (use_cuda_graph) { + cuda_ctx->cuda_graph->use_cpy_indirection = true; + // copy pointers to GPU so they can be accessed via indirection within CUDA graph + ggml_cuda_cpy_dest_ptrs_copy(cuda_ctx->cuda_graph.get(), cuda_ctx->cuda_graph->cpy_dest_ptrs.data(), cuda_ctx->cuda_graph->cpy_dest_ptrs.size(), cuda_ctx->stream()); + } + return use_cuda_graph; } @@ -2551,51 +2585,6 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra return true; } -static void maintain_cuda_graph(ggml_backend_cuda_context * cuda_ctx, std::vector & ggml_cuda_cpy_fn_ptrs, bool cuda_graph_update_required) { - - if (cuda_graph_update_required) { - // Extract nodes from graph - // First call with null argument gets number of nodes in graph - CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, nullptr, &cuda_ctx->cuda_graph->num_nodes)); - // Subsequent call with non-null argument gets nodes - cuda_ctx->cuda_graph->nodes.clear(); - cuda_ctx->cuda_graph->nodes.resize(cuda_ctx->cuda_graph->num_nodes); - cuda_ctx->cuda_graph->params.clear(); - cuda_ctx->cuda_graph->params.resize(cuda_ctx->cuda_graph->num_nodes); - if (cuda_ctx->cuda_graph->num_nodes > 0) { - CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, cuda_ctx->cuda_graph->nodes.data(), &cuda_ctx->cuda_graph->num_nodes)); - - // Loop over nodes, and extract kernel parameters from each node - for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) { - cudaGraphNodeType node_type; - CUDA_CHECK(cudaGraphNodeGetType(cuda_ctx->cuda_graph->nodes[i], &node_type)); - if (node_type == cudaGraphNodeTypeKernel) { - cudaError_t stat = cudaGraphKernelNodeGetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]); // Get params using runtime - if (stat == cudaErrorInvalidDeviceFunction) { - // Fails due to incorrect handling by CUDA runtime of CUDA BLAS node. - // We don't need to update blas nodes, so clear error and move on. 
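// Side note on the copy-op handling above: instead of patching each captured copy
// kernel's destination argument on every token (the maintain_cuda_graph path removed
// below), the new code collects the per-token destination pointers on the host, uploads
// them to a device array before the graph is launched, and lets the copy kernels resolve
// their destination through that array. A hypothetical sketch of the indirection idea
// (illustrative names, not the ggml copy kernels):
__global__ void k_cpy_indirect(const char * src, char * const * dev_dest_ptrs, int idx, size_t n) {
    char * dst = dev_dest_ptrs[idx];            // resolved at run time, not at capture time
    for (size_t i = blockIdx.x*(size_t) blockDim.x + threadIdx.x; i < n; i += (size_t) gridDim.x*blockDim.x) {
        dst[i] = src[i];
    }
}
// Host side, once per token, before cudaGraphLaunch:
//   cudaMemcpyAsync(dev_dest_ptrs, host_dest_ptrs.data(),
//                   host_dest_ptrs.size()*sizeof(char *), cudaMemcpyHostToDevice, stream);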
- (void)cudaGetLastError(); - } else { - GGML_ASSERT(stat == cudaSuccess); - } - } - } - } - } else { - // One of the arguments to the copy kernel is updated for each token, hence we need to - // replace that argument with the updated value in the CUDA graph - // on update steps, the live parameters will already be captured - int k = 0; - for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) { - if(count(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), cuda_ctx->cuda_graph->params[i].func) > 0) { - char ** updated_kernel_arg_ptr = cuda_ctx->cuda_graph->updated_kernel_arg.at(k++); - *(void**)cuda_ctx->cuda_graph->params[i].kernelParams[1] = *(void**)updated_kernel_arg_ptr; - CUDA_CHECK(cudaGraphKernelNodeSetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i])); - } - } - } -} - static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) { bool cuda_graph_update_required = false; @@ -2655,8 +2644,7 @@ static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) { #endif static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, - [[maybe_unused]] std::vector & ggml_cuda_cpy_fn_ptrs, bool & graph_evaluated_or_captured, bool & use_cuda_graph, - bool & cuda_graph_update_required) { + bool & graph_evaluated_or_captured, bool & use_cuda_graph, bool & cuda_graph_update_required) { while (!graph_evaluated_or_captured) { // Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph. @@ -2706,13 +2694,9 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx if (cuda_ctx->cuda_graph->instance == nullptr) { // Create executable graph from captured graph. CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0)); } - - // Perform update to graph (if required for this token), and change copy parameter (required for every token) - maintain_cuda_graph(cuda_ctx, ggml_cuda_cpy_fn_ptrs, cuda_graph_update_required); - - // Update graph executable - update_cuda_graph_executable(cuda_ctx); - + if (cuda_graph_update_required) { // Update graph executable + update_cuda_graph_executable(cuda_ctx); + } // Launch graph CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream())); #else @@ -2726,10 +2710,6 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cuda_set_device(cuda_ctx->device); - // vector of pointers to CUDA cpy kernels, which are required to identify - // kernel parameters which need updated in the graph for each token - std::vector ggml_cuda_cpy_fn_ptrs; - #ifdef USE_CUDA_GRAPH static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr); @@ -2763,8 +2743,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, if (use_cuda_graph) { cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph); - use_cuda_graph = check_node_graph_compatibility_and_refresh_copy_ops(cuda_ctx, cgraph, - ggml_cuda_cpy_fn_ptrs, use_cuda_graph); + use_cuda_graph = check_node_graph_compatibility_and_refresh_copy_ops(cuda_ctx, cgraph, use_cuda_graph); // Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates. 
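// Side note on the simplified graph maintenance above: with copy destinations resolved
// through indirection, the executable graph only has to be updated when the captured
// topology actually changed; otherwise it is relaunched unmodified. A rough sketch of
// that per-token flow, simplified to re-instantiate instead of calling
// cudaGraphExecUpdate (illustrative only, error handling elided):
#include <cuda_runtime.h>

static void run_graph_step(cudaStream_t stream, cudaGraph_t & graph, cudaGraphExec_t & instance,
                           bool update_required, void (*enqueue_work)(cudaStream_t)) {
    if (instance == nullptr || update_required) {
        if (graph != nullptr) {
            cudaGraphDestroy(graph);
        }
        cudaStreamBeginCapture(stream, cudaStreamCaptureModeRelaxed);
        enqueue_work(stream);                               // records the kernels issued on the stream
        cudaStreamEndCapture(stream, &graph);
        if (instance != nullptr) {
            cudaGraphExecDestroy(instance);
        }
        cudaGraphInstantiate(&instance, graph, nullptr, nullptr, 0);
    }
    cudaGraphLaunch(instance, stream);
}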
if (use_cuda_graph && cuda_graph_update_required) { @@ -2785,6 +2764,10 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed)); } + if (!use_cuda_graph) { + cuda_ctx->cuda_graph->use_cpy_indirection = false; + } + #else bool use_cuda_graph = false; bool cuda_graph_update_required = false; @@ -2792,7 +2775,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, bool graph_evaluated_or_captured = false; - evaluate_and_capture_cuda_graph(cuda_ctx, cgraph, ggml_cuda_cpy_fn_ptrs, graph_evaluated_or_captured, use_cuda_graph, cuda_graph_update_required); + evaluate_and_capture_cuda_graph(cuda_ctx, cgraph, graph_evaluated_or_captured, use_cuda_graph, cuda_graph_update_required); return GGML_STATUS_SUCCESS; } @@ -3096,6 +3079,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) { return true; } + if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_BF16) { + return true; + } if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) { return true; } @@ -3216,9 +3202,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g } case GGML_OP_ROPE: case GGML_OP_ROPE_BACK: { - const size_t ts = ggml_type_size(op->src[0]->type); - const int64_t ne0_012 = op->src[0]->ne[0] * op->src[0]->ne[1] * op->src[0]->ne[2]; - return op->src[0]->nb[0] == ts && op->src[0]->nb[3] == ne0_012*ts; + return op->src[0]->nb[0] == ggml_type_size(op->src[0]->type) && ggml_is_contiguous_2(op->src[0]); } case GGML_OP_IM2COL: case GGML_OP_POOL_2D: @@ -3230,6 +3214,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_GROUP_NORM: return ggml_is_contiguous(op->src[0]); case GGML_OP_UPSCALE: + return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST; case GGML_OP_PAD: case GGML_OP_ARANGE: case GGML_OP_TIMESTEP_EMBEDDING: @@ -3249,6 +3234,10 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g if (op->src[0]->ne[0] == 192) { return false; } + if (op->src[0]->ne[0] == 576) { + // DeepSeek MLA + return false; + } if (op->src[0]->ne[3] != 1) { return false; } diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 532358018f4..3cb2015520b 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -155,25 +155,27 @@ static constexpr __device__ int get_mmq_y_device() { #define MMQ_DP4A_TXS_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, mmq_y*WARP_SIZE/8 + mmq_y/8} static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) { - return type == GGML_TYPE_Q4_0 ? MMQ_DP4A_TXS_Q4_0 : - type == GGML_TYPE_Q4_1 ? MMQ_DP4A_TXS_Q4_1 : - type == GGML_TYPE_Q5_0 ? MMQ_DP4A_TXS_Q8_0 : - type == GGML_TYPE_Q5_1 ? MMQ_DP4A_TXS_Q8_1 : - type == GGML_TYPE_Q8_0 ? MMQ_DP4A_TXS_Q8_0 : - type == GGML_TYPE_Q2_K ? MMQ_DP4A_TXS_Q2_K : - type == GGML_TYPE_Q3_K ? MMQ_DP4A_TXS_Q3_K : - type == GGML_TYPE_Q4_K ? MMQ_DP4A_TXS_Q4_K : - type == GGML_TYPE_Q5_K ? MMQ_DP4A_TXS_Q5_K : - type == GGML_TYPE_Q6_K ? MMQ_DP4A_TXS_Q6_K : - type == GGML_TYPE_IQ2_XXS ? MMQ_DP4A_TXS_Q8_0 : - type == GGML_TYPE_IQ2_XS ? MMQ_DP4A_TXS_Q8_0_16 : - type == GGML_TYPE_IQ2_S ? MMQ_DP4A_TXS_Q8_0_16 : - type == GGML_TYPE_IQ3_XXS ? MMQ_DP4A_TXS_Q8_0 : - type == GGML_TYPE_IQ3_S ? 
MMQ_DP4A_TXS_Q8_0 : - type == GGML_TYPE_IQ1_S ? MMQ_DP4A_TXS_Q8_0 : - type == GGML_TYPE_IQ4_XS ? MMQ_DP4A_TXS_Q8_0 : - type == GGML_TYPE_IQ4_NL ? MMQ_DP4A_TXS_Q8_0 : - tile_x_sizes{0, 0, 0}; + switch (type) { + case GGML_TYPE_Q4_0: return MMQ_DP4A_TXS_Q4_0; + case GGML_TYPE_Q4_1: return MMQ_DP4A_TXS_Q4_1; + case GGML_TYPE_Q5_0: return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_Q5_1: return MMQ_DP4A_TXS_Q8_1; + case GGML_TYPE_Q8_0: return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_Q2_K: return MMQ_DP4A_TXS_Q2_K; + case GGML_TYPE_Q3_K: return MMQ_DP4A_TXS_Q3_K; + case GGML_TYPE_Q4_K: return MMQ_DP4A_TXS_Q4_K; + case GGML_TYPE_Q5_K: return MMQ_DP4A_TXS_Q5_K; + case GGML_TYPE_Q6_K: return MMQ_DP4A_TXS_Q6_K; + case GGML_TYPE_IQ2_XXS: return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_IQ2_XS: return MMQ_DP4A_TXS_Q8_0_16; + case GGML_TYPE_IQ2_S: return MMQ_DP4A_TXS_Q8_0_16; + case GGML_TYPE_IQ3_XXS: return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_IQ3_S: return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_IQ1_S: return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_IQ4_XS: return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_IQ4_NL: return MMQ_DP4A_TXS_Q8_0; + default: return tile_x_sizes{0, 0, 0}; + } } #define MMQ_MMA_TILE_X_K_Q8_0 (2*WARP_SIZE + 2*WARP_SIZE/QI8_0 + 4) @@ -189,25 +191,27 @@ static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding."); static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding."); static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) { - return type == GGML_TYPE_Q4_0 ? MMQ_MMA_TILE_X_K_Q8_0 : - type == GGML_TYPE_Q4_1 ? MMQ_MMA_TILE_X_K_Q8_1 : - type == GGML_TYPE_Q5_0 ? MMQ_MMA_TILE_X_K_Q8_0 : - type == GGML_TYPE_Q5_1 ? MMQ_MMA_TILE_X_K_Q8_1 : - type == GGML_TYPE_Q8_0 ? MMQ_MMA_TILE_X_K_Q8_0 : - type == GGML_TYPE_Q2_K ? MMQ_MMA_TILE_X_K_Q2_K : - type == GGML_TYPE_Q3_K ? MMQ_MMA_TILE_X_K_Q3_K : - type == GGML_TYPE_Q4_K ? MMQ_MMA_TILE_X_K_Q8_1 : - type == GGML_TYPE_Q5_K ? MMQ_MMA_TILE_X_K_Q8_1 : - type == GGML_TYPE_Q6_K ? MMQ_MMA_TILE_X_K_Q6_K : - type == GGML_TYPE_IQ2_XXS ? MMQ_MMA_TILE_X_K_Q8_0 : - type == GGML_TYPE_IQ2_XS ? MMQ_MMA_TILE_X_K_Q3_K : - type == GGML_TYPE_IQ2_S ? MMQ_MMA_TILE_X_K_Q3_K : - type == GGML_TYPE_IQ3_XXS ? MMQ_MMA_TILE_X_K_Q8_0 : - type == GGML_TYPE_IQ3_S ? MMQ_MMA_TILE_X_K_Q8_0 : - type == GGML_TYPE_IQ1_S ? MMQ_MMA_TILE_X_K_Q8_0 : - type == GGML_TYPE_IQ4_XS ? MMQ_MMA_TILE_X_K_Q8_0 : - type == GGML_TYPE_IQ4_NL ? 
MMQ_MMA_TILE_X_K_Q8_0 : - 0; + switch (type) { + case GGML_TYPE_Q4_0: return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_Q4_1: return MMQ_MMA_TILE_X_K_Q8_1; + case GGML_TYPE_Q5_0: return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_Q5_1: return MMQ_MMA_TILE_X_K_Q8_1; + case GGML_TYPE_Q8_0: return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_Q2_K: return MMQ_MMA_TILE_X_K_Q2_K; + case GGML_TYPE_Q3_K: return MMQ_MMA_TILE_X_K_Q3_K; + case GGML_TYPE_Q4_K: return MMQ_MMA_TILE_X_K_Q8_1; + case GGML_TYPE_Q5_K: return MMQ_MMA_TILE_X_K_Q8_1; + case GGML_TYPE_Q6_K: return MMQ_MMA_TILE_X_K_Q6_K; + case GGML_TYPE_IQ2_XXS: return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_IQ2_XS: return MMQ_MMA_TILE_X_K_Q3_K; + case GGML_TYPE_IQ2_S: return MMQ_MMA_TILE_X_K_Q3_K; + case GGML_TYPE_IQ3_XXS: return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_IQ3_S: return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_IQ1_S: return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_IQ4_XS: return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_IQ4_NL: return MMQ_MMA_TILE_X_K_Q8_0; + default: return 0; + } } #define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1) diff --git a/ggml/src/ggml-cuda/mmv.cu b/ggml/src/ggml-cuda/mmv.cu index b39961cd115..d8c385e2399 100644 --- a/ggml/src/ggml-cuda/mmv.cu +++ b/ggml/src/ggml-cuda/mmv.cu @@ -4,18 +4,23 @@ template static __global__ void mul_mat_vec( - const T * __restrict__ x, const float * __restrict__ y, float * __restrict__ dst, const int64_t ncols2, const int64_t stride_row, + const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst, + const int64_t ncols2, const int64_t nchannels_y, const int64_t stride_row, const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst) { - const int64_t row = blockIdx.x; - const int64_t channel = blockIdx.y; - const int64_t sample = blockIdx.z; - const int tid = threadIdx.x; - constexpr int warp_size = ggml_cuda_get_physical_warp_size(); - - x += (sample/sample_ratio)*stride_sample_x + (channel/channel_ratio)*stride_channel_x + row*stride_row; - y += sample *stride_sample_y + channel *stride_channel_y; - dst += sample *stride_sample_dst + channel *stride_channel_dst; + const int64_t row = blockIdx.x; + const int64_t channel_dst = blockIdx.y; + const int64_t channel_x = ids ? ids[channel_dst] : channel_dst / channel_ratio; + const int64_t channel_y = ids ? 
channel_dst % nchannels_y : channel_dst; + const int64_t sample_dst = blockIdx.z; + const int64_t sample_x = sample_dst / sample_ratio; + const int64_t sample_y = sample_dst; + const int tid = threadIdx.x; + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + + x += sample_x *stride_sample_x + channel_x *stride_channel_x + row*stride_row; + y += sample_y *stride_sample_y + channel_y *stride_channel_y; + dst += sample_dst*stride_sample_dst + channel_dst*stride_channel_dst; const float2 * y2 = (const float2 *) y; @@ -31,12 +36,19 @@ static __global__ void mul_mat_vec( float sumf = 0.0f; - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { + const float2 * x2 = (const float2 *) x; + + for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) { + const float2 tmpx = x2[col2]; + const float2 tmpy = y2[col2]; + sumf += tmpx.x*tmpy.x; + sumf += tmpx.y*tmpy.y; + } + } else if constexpr (std::is_same::value) { const half2 * x2 = (const half2 *) x; if (std::is_same::value) { - sumf = 0.0f; - for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) { const float2 tmpx = __half22float2(x2[col2]); const float2 tmpy = y2[col2]; @@ -59,8 +71,6 @@ static __global__ void mul_mat_vec( } } else if constexpr (std::is_same::value) { const int * x2 = (const int *) x; - sumf = 0.0f; - for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) { const int tmpx = x2[col2]; const float2 tmpy = y2[col2]; @@ -92,17 +102,17 @@ static __global__ void mul_mat_vec( template static void launch_mul_mat_vec_cuda( - const T * x, const float * y, float * dst, - const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y, + const T * x, const float * y, const int32_t * ids, float * dst, + const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, - const int64_t nsamples_y, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, + const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, cudaStream_t stream) { GGML_ASSERT(ncols % 2 == 0); GGML_ASSERT(stride_row % 2 == 0); - GGML_ASSERT(nchannels_y % nchannels_x == 0); - GGML_ASSERT(nsamples_y % nsamples_x == 0); - const int64_t channel_ratio = nchannels_y / nchannels_x; - const int64_t sample_ratio = nsamples_y / nsamples_x; + GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0); + GGML_ASSERT( nsamples_dst % nsamples_x == 0); + const int64_t channel_ratio = nchannels_dst / nchannels_x; + const int64_t sample_ratio = nsamples_dst / nsamples_x; int device; int warp_size; @@ -124,48 +134,48 @@ static void launch_mul_mat_vec_cuda( } const int smem = warp_size*sizeof(float); - const dim3 block_nums(nrows, nchannels_y, nsamples_y); + const dim3 block_nums(nrows, nchannels_dst, nsamples_dst); const dim3 block_dims(block_size_best, 1, 1); switch (block_size_best) { case 32: { mul_mat_vec<<>> - (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y, + stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); } break; case 
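// Side note on the reworked mul_mat_vec kernel above: grid.y now enumerates dst channels,
// and when an ids array is present (MUL_MAT_ID) the weight channel is the expert picked
// by ids[channel_dst] while the activation channel wraps over nchannels_y. A stripped-down
// F32-only sketch of that indexing, assuming blockDim.x == 32 and illustrative stride
// names (not the ggml kernel):
__global__ void k_mat_vec_f32_ids(const float * x, const float * y, const int * ids, float * dst,
                                  int ncols, int nchannels_y,
                                  long stride_row, long stride_channel_x,
                                  long stride_channel_y, long stride_channel_dst) {
    const int row         = blockIdx.x;
    const int channel_dst = blockIdx.y;
    const int channel_x   = ids ? ids[channel_dst]          : channel_dst;  // expert selection
    const int channel_y   = ids ? channel_dst % nchannels_y : channel_dst;

    const float * xr = x + channel_x*stride_channel_x + row*stride_row;
    const float * yr = y + channel_y*stride_channel_y;

    float sum = 0.0f;
    for (int col = threadIdx.x; col < ncols; col += blockDim.x) {
        sum += xr[col]*yr[col];
    }
    for (int offset = 16; offset > 0; offset >>= 1) {        // warp reduction over 32 threads
        sum += __shfl_down_sync(0xffffffffu, sum, offset);
    }
    if (threadIdx.x == 0) {
        dst[channel_dst*stride_channel_dst + row] = sum;
    }
}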
64: { mul_mat_vec<<>> - (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y, + stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); } break; case 96: { mul_mat_vec<<>> - (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y, + stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); } break; case 128: { mul_mat_vec<<>> - (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y, + stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); } break; case 160: { mul_mat_vec<<>> - (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y, + stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); } break; case 192: { mul_mat_vec<<>> - (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y, + stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); } break; case 224: { mul_mat_vec<<>> - (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y, + stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); } break; case 256: { mul_mat_vec<<>> - (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y, + stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); } break; default: { GGML_ABORT("fatal error"); @@ -175,28 +185,28 @@ static void launch_mul_mat_vec_cuda( template static void mul_mat_vec_cuda( - const T * x, const float * y, float * dst, - const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y, + const T * x, const float * y, const int32_t * ids, float * dst, + const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, - const int64_t nsamples_y, const int64_t stride_sample_x, const 
int64_t stride_sample_y, const int64_t stride_sample_dst, + const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, enum ggml_prec prec, cudaStream_t stream) { - switch (prec) { - case GGML_PREC_DEFAULT: { + if constexpr(std::is_same::value) { + if (prec == GGML_PREC_DEFAULT) { launch_mul_mat_vec_cuda - (x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_y, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - } break; - case GGML_PREC_F32: { - launch_mul_mat_vec_cuda - (x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_y, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - } break; + (x, y, ids, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + return; + } } + launch_mul_mat_vec_cuda + (x, y, ids, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); } -void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - GGML_ASSERT(src1->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); +void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) { + GGML_ASSERT( src1->type == GGML_TYPE_F32); + GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); GGML_TENSOR_BINARY_OP_LOCALS; @@ -204,21 +214,24 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * const size_t ts_src1 = ggml_type_size(src1->type); const size_t ts_dst = ggml_type_size(dst->type); - GGML_ASSERT(ne11 == 1); - GGML_ASSERT(ne12 == ne2); + GGML_ASSERT(!ids || ne12 == 1); // Implementation is only correct for batch size 1. GGML_ASSERT(ne13 == ne3); - GGML_ASSERT(nb00 == ts_src0); - GGML_ASSERT(nb10 == ts_src1); - GGML_ASSERT(nb0 == ts_dst); + GGML_ASSERT( nb00 == ts_src0); + GGML_ASSERT( nb10 == ts_src1); + GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type)); + GGML_ASSERT( nb0 == ts_dst); const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32; - const float * src1_d = (const float *) src1->data; - float * dst_d = (float *) dst->data; + const float * src1_d = (const float *) src1->data; + const int32_t * ids_d = ids ? (const int32_t *) ids->data : nullptr; + float * dst_d = (float *) dst->data; const int64_t s01 = src0->nb[1] / ts_src0; + const int64_t s11 = src1->nb[1] / ts_src1; + const int64_t s1 = dst->nb[1] / ts_dst; const int64_t s02 = src0->nb[2] / ts_src0; const int64_t s12 = src1->nb[2] / ts_src1; const int64_t s2 = dst->nb[2] / ts_dst; @@ -226,14 +239,33 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * const int64_t s13 = src1->nb[3] / ts_src1; const int64_t s3 = dst->nb[3] / ts_dst; + // For MUL_MAT_ID the memory layout is different than for MUL_MAT: + const int64_t ncols_dst = ids ? 
ne2 : ne1; + const int64_t nchannels_y = ids ? ne11 : ne12; + const int64_t nchannels_dst = ids ? ne1 : ne2; + const int64_t stride_channel_dst = ids ? s1 : s2; + const int64_t stride_channel_y = ids ? s11 : s12; + + GGML_ASSERT(ncols_dst == 1); + switch (src0->type) { + case GGML_TYPE_F32: { + const float * src0_d = (const float *) src0->data; + mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, s01, + ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst, + ne03, ne3, s03, s13, s3, prec, ctx.stream()); + } break; case GGML_TYPE_F16: { const half * src0_d = (const half *) src0->data; - mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, s01, ne02, ne12, s02, s12, s2, ne03, ne13, s03, s13, s3, prec, ctx.stream()); + mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, s01, + ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst, + ne03, ne3, s03, s13, s3, prec, ctx.stream()); } break; case GGML_TYPE_BF16: { const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data; - mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, s01, ne02, ne12, s02, s12, s2, ne03, ne13, s03, s13, s3, prec, ctx.stream()); + mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, s01, + ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst, + ne03, ne3, s03, s13, s3, prec, ctx.stream()); } break; default: GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type)); @@ -262,27 +294,34 @@ void ggml_cuda_op_mul_mat_vec( const int64_t stride_row = ne00; const int64_t nchannels_x = 1; const int64_t nchannels_y = 1; + const int64_t nchannels_dst = 1; const int64_t stride_channel_x = 0; const int64_t stride_channel_y = 0; const int64_t stride_channel_dst = 0; const int64_t nsamples_x = 1; - const int64_t nsamples_y = 1; + const int64_t nsamples_dst = 1; const int64_t stride_sample_x = 0; const int64_t stride_sample_y = 0; const int64_t stride_sample_dst = 0; switch (src0->type) { + case GGML_TYPE_F32: { + const float * src0_d = (const float *) src0_dd_i; + mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, stride_row, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream); + } break; case GGML_TYPE_F16: { const half * src0_d = (const half *) src0_dd_i; - mul_mat_vec_cuda(src0_d, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row, - nchannels_x, nchannels_y, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_y, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream); + mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, stride_row, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream); } break; case GGML_TYPE_BF16: { const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i; - mul_mat_vec_cuda(src0_d, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row, - nchannels_x, nchannels_y, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_y, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream); + mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, stride_row, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, 
stride_sample_dst, prec, stream); } break; default: GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type)); diff --git a/ggml/src/ggml-cuda/mmv.cuh b/ggml/src/ggml-cuda/mmv.cuh index 78a1cd4a690..756e7e1cc7f 100644 --- a/ggml/src/ggml-cuda/mmv.cuh +++ b/ggml/src/ggml-cuda/mmv.cuh @@ -3,7 +3,7 @@ // maximum number of src0 rows with which to use mul_mat_vec over cuBLAS if FP16 tensor cores are available #define MMV_MAX_ROWS 512 -void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst); +void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst); void ggml_cuda_op_mul_mat_vec( ggml_backend_cuda_context & ctx, diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index eef8585a738..d846e35a6a2 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -1,50 +1,57 @@ #include "mmvq.cuh" +#include "quantize.cuh" #include "vecdotq.cuh" +#include + typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs); static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) { - return type == GGML_TYPE_Q4_0 ? vec_dot_q4_0_q8_1 : - type == GGML_TYPE_Q4_1 ? vec_dot_q4_1_q8_1 : - type == GGML_TYPE_Q5_0 ? vec_dot_q5_0_q8_1 : - type == GGML_TYPE_Q5_1 ? vec_dot_q5_1_q8_1 : - type == GGML_TYPE_Q8_0 ? vec_dot_q8_0_q8_1 : - type == GGML_TYPE_Q2_K ? vec_dot_q2_K_q8_1 : - type == GGML_TYPE_Q3_K ? vec_dot_q3_K_q8_1 : - type == GGML_TYPE_Q4_K ? vec_dot_q4_K_q8_1 : - type == GGML_TYPE_Q5_K ? vec_dot_q5_K_q8_1 : - type == GGML_TYPE_Q6_K ? vec_dot_q6_K_q8_1 : - type == GGML_TYPE_IQ2_XXS ? vec_dot_iq2_xxs_q8_1 : - type == GGML_TYPE_IQ2_XS ? vec_dot_iq2_xs_q8_1 : - type == GGML_TYPE_IQ2_S ? vec_dot_iq2_s_q8_1 : - type == GGML_TYPE_IQ3_XXS ? vec_dot_iq3_xxs_q8_1 : - type == GGML_TYPE_IQ1_S ? vec_dot_iq1_s_q8_1 : - type == GGML_TYPE_IQ1_M ? vec_dot_iq1_m_q8_1 : - type == GGML_TYPE_IQ4_NL ? vec_dot_iq4_nl_q8_1 : - type == GGML_TYPE_IQ4_XS ? vec_dot_iq4_xs_q8_1 : - type == GGML_TYPE_IQ3_S ? vec_dot_iq3_s_q8_1 : - nullptr; + switch (type) { + case GGML_TYPE_Q4_0: return vec_dot_q4_0_q8_1; + case GGML_TYPE_Q4_1: return vec_dot_q4_1_q8_1; + case GGML_TYPE_Q5_0: return vec_dot_q5_0_q8_1; + case GGML_TYPE_Q5_1: return vec_dot_q5_1_q8_1; + case GGML_TYPE_Q8_0: return vec_dot_q8_0_q8_1; + case GGML_TYPE_Q2_K: return vec_dot_q2_K_q8_1; + case GGML_TYPE_Q3_K: return vec_dot_q3_K_q8_1; + case GGML_TYPE_Q4_K: return vec_dot_q4_K_q8_1; + case GGML_TYPE_Q5_K: return vec_dot_q5_K_q8_1; + case GGML_TYPE_Q6_K: return vec_dot_q6_K_q8_1; + case GGML_TYPE_IQ2_XXS: return vec_dot_iq2_xxs_q8_1; + case GGML_TYPE_IQ2_XS: return vec_dot_iq2_xs_q8_1; + case GGML_TYPE_IQ2_S: return vec_dot_iq2_s_q8_1; + case GGML_TYPE_IQ3_XXS: return vec_dot_iq3_xxs_q8_1; + case GGML_TYPE_IQ1_S: return vec_dot_iq1_s_q8_1; + case GGML_TYPE_IQ1_M: return vec_dot_iq1_m_q8_1; + case GGML_TYPE_IQ4_NL: return vec_dot_iq4_nl_q8_1; + case GGML_TYPE_IQ4_XS: return vec_dot_iq4_xs_q8_1; + case GGML_TYPE_IQ3_S: return vec_dot_iq3_s_q8_1; + default: return nullptr; + } } static constexpr __device__ int get_vdr_mmvq(ggml_type type) { - return type == GGML_TYPE_Q4_0 ? VDR_Q4_0_Q8_1_MMVQ : - type == GGML_TYPE_Q4_1 ? VDR_Q4_1_Q8_1_MMVQ : - type == GGML_TYPE_Q5_0 ? VDR_Q5_0_Q8_1_MMVQ : - type == GGML_TYPE_Q5_1 ? VDR_Q5_1_Q8_1_MMVQ : - type == GGML_TYPE_Q8_0 ? 
VDR_Q8_0_Q8_1_MMVQ : - type == GGML_TYPE_Q2_K ? VDR_Q2_K_Q8_1_MMVQ : - type == GGML_TYPE_Q3_K ? VDR_Q3_K_Q8_1_MMVQ : - type == GGML_TYPE_Q4_K ? VDR_Q4_K_Q8_1_MMVQ : - type == GGML_TYPE_Q5_K ? VDR_Q5_K_Q8_1_MMVQ : - type == GGML_TYPE_Q6_K ? VDR_Q6_K_Q8_1_MMVQ : - type == GGML_TYPE_IQ2_XXS ? VDR_IQ2_XXS_Q8_1_MMVQ : - type == GGML_TYPE_IQ2_XS ? VDR_IQ2_XS_Q8_1_MMVQ : - type == GGML_TYPE_IQ2_S ? VDR_IQ2_S_Q8_1_MMVQ : - type == GGML_TYPE_IQ3_XXS ? VDR_IQ3_XXS_Q8_1_MMVQ : - type == GGML_TYPE_IQ3_S ? VDR_IQ3_S_Q8_1_MMVQ : - type == GGML_TYPE_IQ4_NL ? VDR_IQ4_NL_Q8_1_MMVQ : - type == GGML_TYPE_IQ4_XS ? VDR_IQ4_XS_Q8_1_MMVQ : - 1; + switch (type) { + case GGML_TYPE_Q4_0: return VDR_Q4_0_Q8_1_MMVQ; + case GGML_TYPE_Q4_1: return VDR_Q4_1_Q8_1_MMVQ; + case GGML_TYPE_Q5_0: return VDR_Q5_0_Q8_1_MMVQ; + case GGML_TYPE_Q5_1: return VDR_Q5_1_Q8_1_MMVQ; + case GGML_TYPE_Q8_0: return VDR_Q8_0_Q8_1_MMVQ; + case GGML_TYPE_Q2_K: return VDR_Q2_K_Q8_1_MMVQ; + case GGML_TYPE_Q3_K: return VDR_Q3_K_Q8_1_MMVQ; + case GGML_TYPE_Q4_K: return VDR_Q4_K_Q8_1_MMVQ; + case GGML_TYPE_Q5_K: return VDR_Q5_K_Q8_1_MMVQ; + case GGML_TYPE_Q6_K: return VDR_Q6_K_Q8_1_MMVQ; + case GGML_TYPE_IQ2_XXS: return VDR_IQ2_XXS_Q8_1_MMVQ; + case GGML_TYPE_IQ2_XS: return VDR_IQ2_XS_Q8_1_MMVQ; + case GGML_TYPE_IQ2_S: return VDR_IQ2_S_Q8_1_MMVQ; + case GGML_TYPE_IQ3_XXS: return VDR_IQ3_XXS_Q8_1_MMVQ; + case GGML_TYPE_IQ3_S: return VDR_IQ3_S_Q8_1_MMVQ; + case GGML_TYPE_IQ4_NL: return VDR_IQ4_NL_Q8_1_MMVQ; + case GGML_TYPE_IQ4_XS: return VDR_IQ4_XS_Q8_1_MMVQ; + default: return 1; + } } enum mmvq_parameter_table_id { @@ -73,9 +80,9 @@ static __host__ mmvq_parameter_table_id get_device_table_id(int cc) { return MMVQ_PARAMETERS_GENERIC; } -static constexpr __host__ __device__ int calc_nwarps(int ncols_y, mmvq_parameter_table_id table_id) { +static constexpr __host__ __device__ int calc_nwarps(int ncols_dst, mmvq_parameter_table_id table_id) { if (table_id == MMVQ_PARAMETERS_GENERIC) { - switch (ncols_y) { + switch (ncols_dst) { case 1: case 2: case 3: @@ -90,7 +97,7 @@ static constexpr __host__ __device__ int calc_nwarps(int ncols_y, mmvq_paramete return 1; } } else if (table_id == MMVQ_PARAMETERS_GCN) { - switch (ncols_y) { + switch (ncols_dst) { case 1: case 2: case 3: @@ -107,9 +114,9 @@ static constexpr __host__ __device__ int calc_nwarps(int ncols_y, mmvq_paramete return 1; } -static constexpr __host__ __device__ int calc_rows_per_block(int ncols_y, int table_id) { +static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int table_id) { if (table_id == MMVQ_PARAMETERS_GENERIC || table_id == MMVQ_PARAMETERS_GCN) { - switch (ncols_y) { + switch (ncols_dst) { case 1: return 1; case 2: @@ -127,19 +134,21 @@ static constexpr __host__ __device__ int calc_rows_per_block(int ncols_y, int ta return 1; } -template +template // tell the compiler to use as many registers as it wants, see nwarps definition below -__launch_bounds__(calc_nwarps(ncols_y, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1) +__launch_bounds__(calc_nwarps(ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1) static __global__ void mul_mat_vec_q( - const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, float * __restrict__ dst, + const int ncols_x, const int nchannels_y, const int stride_row_x, const int stride_col_y, const int 
stride_col_dst, + const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, + const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) { constexpr int qk = ggml_cuda_type_traits::qk; constexpr int qi = ggml_cuda_type_traits::qi; constexpr int vdr = get_vdr_mmvq(type); constexpr mmvq_parameter_table_id table_id = get_device_table_id(); - constexpr int nwarps = calc_nwarps(ncols_y, table_id); - constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_y, table_id); + constexpr int nwarps = calc_nwarps(ncols_dst, table_id); + constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_dst, table_id); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type); @@ -147,13 +156,21 @@ static __global__ void mul_mat_vec_q( const int tid = warp_size*threadIdx.y + threadIdx.x; const int row0 = rows_per_cuda_block*blockIdx.x; const int blocks_per_row_x = ncols_x / qk; - const int blocks_per_col_y = nrows_y / QK8_1; constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi; + // The MUL_MAT_ID code path with ids != nullptr is only implemetned for ncols_dst == 1. + const int channel_dst = blockIdx.y; + const int channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : channel_dst / channel_ratio; + const int channel_y = ncols_dst == 1 && ids ? channel_dst % nchannels_y : channel_dst; + const int sample_dst = blockIdx.z; + const int sample_x = sample_dst / sample_ratio; + const int sample_y = sample_dst; + // partial sum for each thread - float tmp[ncols_y][rows_per_cuda_block] = {{0.0f}}; + float tmp[ncols_dst][rows_per_cuda_block] = {{0.0f}}; - const block_q8_1 * y = (const block_q8_1 *) vy; + const block_q8_1 * y = ((const block_q8_1 *) vy) + sample_y*stride_sample_y + channel_y*stride_channel_y; + const int kbx_offset = sample_x*stride_sample_x + channel_x*stride_channel_x + row0*stride_row_x; for (int kbx = tid / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) { const int kby = kbx * (qk/QK8_1); // y block index that aligns with kbx @@ -162,18 +179,19 @@ static __global__ void mul_mat_vec_q( const int kqs = vdr * (tid % (qi/vdr)); #pragma unroll - for (int j = 0; j < ncols_y; ++j) { + for (int j = 0; j < ncols_dst; ++j) { #pragma unroll for (int i = 0; i < rows_per_cuda_block; ++i) { - tmp[j][i] += vec_dot_q_cuda(vx, &y[j*blocks_per_col_y + kby], (row0 + i)*blocks_per_row_x + kbx, kqs); + tmp[j][i] += vec_dot_q_cuda( + vx, &y[j*stride_col_y + kby], kbx_offset + i*stride_row_x + kbx, kqs); } } } - __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_y][rows_per_cuda_block][warp_size]; + __shared__ float tmp_shared[nwarps-1 > 0 ? 
nwarps-1 : 1][ncols_dst][rows_per_cuda_block][warp_size]; if (threadIdx.y > 0) { #pragma unroll - for (int j = 0; j < ncols_y; ++j) { + for (int j = 0; j < ncols_dst; ++j) { #pragma unroll for (int i = 0; i < rows_per_cuda_block; ++i) { tmp_shared[threadIdx.y-1][j][i][threadIdx.x] = tmp[j][i]; @@ -185,9 +203,11 @@ static __global__ void mul_mat_vec_q( return; } + dst += sample_dst*stride_sample_dst + channel_dst*stride_channel_dst + row0; + // sum up partial sums and write back result #pragma unroll - for (int j = 0; j < ncols_y; ++j) { + for (int j = 0; j < ncols_dst; ++j) { #pragma unroll for (int i = 0; i < rows_per_cuda_block; ++i) { #pragma unroll @@ -197,88 +217,121 @@ static __global__ void mul_mat_vec_q( tmp[j][i] = warp_reduce_sum(tmp[j][i]); } - if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + threadIdx.x < (unsigned)nrows_dst)) { - dst[j*nrows_dst + row0 + threadIdx.x] = tmp[j][threadIdx.x]; + if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + int(threadIdx.x) < stride_col_dst)) { + dst[j*stride_col_dst + threadIdx.x] = tmp[j][threadIdx.x]; } } - - GGML_UNUSED(nrows_x); } -static std::pair calc_launch_params(const int ncols_y, const int nrows_x, const int warp_size, const mmvq_parameter_table_id table_id) { - const int64_t nblocks = (nrows_x + calc_rows_per_block(ncols_y, table_id) - 1) / calc_rows_per_block(ncols_y, table_id); - const dim3 block_nums(nblocks, 1, 1); - const dim3 block_dims(warp_size, calc_nwarps(ncols_y, table_id), 1); +static std::pair calc_launch_params( + const int ncols_dst, const int nrows_x, const int nchannels_y, const int nsamples_y, + const int warp_size, const mmvq_parameter_table_id table_id) { + const int64_t nblocks = (nrows_x + calc_rows_per_block(ncols_dst, table_id) - 1) / calc_rows_per_block(ncols_dst, table_id); + const dim3 block_nums(nblocks, nchannels_y, nsamples_y); + const dim3 block_dims(warp_size, calc_nwarps(ncols_dst, table_id), 1); return {block_nums, block_dims}; } template -static void mul_mat_vec_q_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { +static void mul_mat_vec_q_switch_ncols_dst( + const void * vx, const void * vy, const int32_t * ids, float * dst, + const int ncols_x, const int nrows_x, const int ncols_dst, + const int stride_row_x, const int stride_col_y, const int stride_col_dst, + const int nchannels_x, const int nchannels_y, const int nchannels_dst, + const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, + const int nsamples_x, const int nsamples_dst, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, + cudaStream_t stream) { GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0); - GGML_ASSERT(ncols_y <= MMVQ_MAX_BATCH_SIZE); + GGML_ASSERT(ncols_dst <= MMVQ_MAX_BATCH_SIZE); + + const int channel_ratio = nchannels_dst / nchannels_x; + const int sample_ratio = nsamples_dst / nsamples_x; const int device = ggml_cuda_get_device(); const int warp_size = ggml_cuda_info().devices[device].warp_size; const mmvq_parameter_table_id table_id = get_device_table_id(ggml_cuda_info().devices[device].cc); - switch (ncols_y) { + GGML_ASSERT(!ids || ncols_dst == 1); + switch (ncols_dst) { case 1: { - constexpr int c_ncols_y = 1; - std::pair dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id); - mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); + constexpr int 
c_ncols_dst = 1; + std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); + mul_mat_vec_q<<>> + (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); break; } case 2: { - constexpr int c_ncols_y = 2; - std::pair dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id); - mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); + constexpr int c_ncols_dst = 2; + std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); + mul_mat_vec_q<<>> + (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); break; } case 3: { - constexpr int c_ncols_y = 3; - std::pair dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id); - mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); + constexpr int c_ncols_dst = 3; + std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); + mul_mat_vec_q<<>> + (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); break; } case 4: { - constexpr int c_ncols_y = 4; - std::pair dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id); - mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); + constexpr int c_ncols_dst = 4; + std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); + mul_mat_vec_q<<>> + (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); break; } case 5: { - constexpr int c_ncols_y = 5; - std::pair dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id); - mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); + constexpr int c_ncols_dst = 5; + std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); + mul_mat_vec_q<<>> + (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); break; } case 6: { - constexpr int c_ncols_y = 6; - std::pair dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id); - mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); + constexpr int c_ncols_dst = 6; + std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); + mul_mat_vec_q<<>> + (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); break; } case 7: { - constexpr int c_ncols_y = 7; - std::pair dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id); - mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, 
nrows_dst); + constexpr int c_ncols_dst = 7; + std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); + mul_mat_vec_q<<>> + (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); break; } case 8: { - constexpr int c_ncols_y = 8; - std::pair dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id); - mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); + constexpr int c_ncols_dst = 8; + std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); + mul_mat_vec_q<<>> + (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); break; } default: @@ -287,221 +340,241 @@ static void mul_mat_vec_q_cuda( } } -static void mul_mat_vec_q4_0_q8_1_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { - - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); -} - -static void mul_mat_vec_q4_1_q8_1_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { - - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); -} - -static void mul_mat_vec_q5_0_q8_1_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { - - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); -} - -static void mul_mat_vec_q5_1_q8_1_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { - - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); -} - -static void mul_mat_vec_q8_0_q8_1_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { - - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); -} - -static void mul_mat_vec_q2_K_q8_1_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { - - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); -} - -static void mul_mat_vec_q3_K_q8_1_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { - - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); -} - -static void mul_mat_vec_q4_K_q8_1_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { - - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); -} - 
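// Side note on the hunks above and below: the per-type mul_mat_vec_*_q8_1_cuda wrappers
// are dropped in favour of a single switch that maps the runtime ggml_type onto a
// compile-time template argument, so per-type traits stay constexpr inside the kernel
// while callers keep passing plain enums. A generic, self-contained sketch of that
// dispatch pattern (illustrative names, not the ggml functions):
enum class quant_type { Q4_0, Q8_0, Q4_K };

template <quant_type T>
static void run_kernel(const void * x, float * dst, int n) {
    // A real launcher would pick block sizes and vector widths from constexpr traits of T.
    (void) x; (void) dst; (void) n;
}

static void dispatch(quant_type t, const void * x, float * dst, int n) {
    switch (t) {
        case quant_type::Q4_0: run_kernel<quant_type::Q4_0>(x, dst, n); break;
        case quant_type::Q8_0: run_kernel<quant_type::Q8_0>(x, dst, n); break;
        case quant_type::Q4_K: run_kernel<quant_type::Q4_K>(x, dst, n); break;
    }
}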
-static void mul_mat_vec_q5_K_q8_1_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { - - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); -} - -static void mul_mat_vec_q6_K_q8_1_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { - - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); -} - -static void mul_mat_vec_iq2_xxs_q8_1_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { - - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); -} - -static void mul_mat_vec_iq2_xs_q8_1_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { - - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); -} - -static void mul_mat_vec_iq2_s_q8_1_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { - - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); -} - -static void mul_mat_vec_iq3_xxs_q8_1_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { - - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); -} - -static void mul_mat_vec_iq1_s_q8_1_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { - - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); -} - -static void mul_mat_vec_iq1_m_q8_1_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { - - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); -} - -static void mul_mat_vec_iq4_nl_q8_1_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { - - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); -} - -static void mul_mat_vec_iq4_xs_q8_1_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { - - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); -} - -static void mul_mat_vec_iq3_s_q8_1_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { - - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); -} - -void ggml_cuda_op_mul_mat_vec_q( - ggml_backend_cuda_context & ctx, - const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * 
dst, const char * src0_dd_i, const float * src1_ddf_i, - const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, - const int64_t src1_padded_row_size, cudaStream_t stream) { - - const int64_t ne00 = src0->ne[0]; - const int64_t row_diff = row_high - row_low; - - const int64_t ne10 = src1->ne[0]; - GGML_ASSERT(ne10 % QK8_1 == 0); - - const int64_t ne0 = dst->ne[0]; - - int id = ggml_cuda_get_device(); - - // the main device has a larger memory buffer to hold the results from all GPUs - // nrows_dst == nrows of the matrix that the kernel writes into - const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff; - - switch (src0->type) { +static void mul_mat_vec_q_switch_type( + const void * vx, const ggml_type type_x, const void * vy, const int32_t * ids, float * dst, + const int ncols_x, const int nrows_x, const int ncols_dst, + const int stride_row_x, const int stride_col_y, const int stride_col_dst, + const int nchannels_x, const int nchannels_y, const int nchannels_dst, + const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, + const int nsamples_x, const int nsamples_dst, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, + cudaStream_t stream) { + switch (type_x) { case GGML_TYPE_Q4_0: - mul_mat_vec_q4_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); + mul_mat_vec_q_switch_ncols_dst + (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + stream); break; case GGML_TYPE_Q4_1: - mul_mat_vec_q4_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); + mul_mat_vec_q_switch_ncols_dst + (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + stream); break; case GGML_TYPE_Q5_0: - mul_mat_vec_q5_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); + mul_mat_vec_q_switch_ncols_dst + (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + stream); break; case GGML_TYPE_Q5_1: - mul_mat_vec_q5_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); + mul_mat_vec_q_switch_ncols_dst + (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + stream); break; case GGML_TYPE_Q8_0: - mul_mat_vec_q8_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); + mul_mat_vec_q_switch_ncols_dst + (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, 
stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + stream); break; case GGML_TYPE_Q2_K: - mul_mat_vec_q2_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); + mul_mat_vec_q_switch_ncols_dst + (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + stream); break; case GGML_TYPE_Q3_K: - mul_mat_vec_q3_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); + mul_mat_vec_q_switch_ncols_dst + (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + stream); break; case GGML_TYPE_Q4_K: - mul_mat_vec_q4_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); + mul_mat_vec_q_switch_ncols_dst + (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + stream); break; case GGML_TYPE_Q5_K: - mul_mat_vec_q5_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); + mul_mat_vec_q_switch_ncols_dst + (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + stream); break; case GGML_TYPE_Q6_K: - mul_mat_vec_q6_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); + mul_mat_vec_q_switch_ncols_dst + (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + stream); break; case GGML_TYPE_IQ2_XXS: - mul_mat_vec_iq2_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); + mul_mat_vec_q_switch_ncols_dst + (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + stream); break; case GGML_TYPE_IQ2_XS: - mul_mat_vec_iq2_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); + mul_mat_vec_q_switch_ncols_dst + (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + stream); break; case GGML_TYPE_IQ2_S: - 
mul_mat_vec_iq2_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); + mul_mat_vec_q_switch_ncols_dst + (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + stream); break; case GGML_TYPE_IQ3_XXS: - mul_mat_vec_iq3_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); + mul_mat_vec_q_switch_ncols_dst + (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + stream); break; case GGML_TYPE_IQ1_S: - mul_mat_vec_iq1_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); + mul_mat_vec_q_switch_ncols_dst + (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + stream); break; case GGML_TYPE_IQ1_M: - mul_mat_vec_iq1_m_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); + mul_mat_vec_q_switch_ncols_dst + (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + stream); break; case GGML_TYPE_IQ4_NL: - mul_mat_vec_iq4_nl_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); + mul_mat_vec_q_switch_ncols_dst + (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + stream); break; case GGML_TYPE_IQ4_XS: - mul_mat_vec_iq4_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); + mul_mat_vec_q_switch_ncols_dst + (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + stream); break; case GGML_TYPE_IQ3_S: - mul_mat_vec_iq3_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); + mul_mat_vec_q_switch_ncols_dst + (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + stream); break; default: GGML_ABORT("fatal error"); break; } +} + +void ggml_cuda_mul_mat_vec_q( + ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * 
dst) { + GGML_ASSERT( src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32); // Optional, used for batched GGML_MUL_MAT_ID. + + GGML_TENSOR_BINARY_OP_LOCALS; + + cudaStream_t stream = ctx.stream(); + + const size_t ts_src0 = ggml_type_size(src0->type); + const size_t ts_src1 = ggml_type_size(src1->type); + const size_t ts_dst = ggml_type_size(dst->type); + + GGML_ASSERT( nb00 == ts_src0); + GGML_ASSERT( nb10 == ts_src1); + GGML_ASSERT( nb0 == ts_dst); + GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type)); + + GGML_ASSERT(!ids || ne12 == 1); // Implementation is only correct for batch size 1. + + const float * src1_d = (const float *) src1->data; + const int32_t * ids_d = ids ? (const int32_t *) ids->data : nullptr; + float * dst_d = (float *) dst->data; + + const int64_t ne10_padded = GGML_PAD(ne10, MATRIX_ROW_PADDING); + ggml_cuda_pool_alloc src1_q8_1(ctx.pool(), ne13*ne12 * ne11*ne10_padded * sizeof(block_q8_1)/QK8_1); + { + const int64_t s11 = src1->nb[1] / ts_src1; + const int64_t s12 = src1->nb[2] / ts_src1; + const int64_t s13 = src1->nb[3] / ts_src1; + quantize_row_q8_1_cuda(src1_d, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded, ne11, ne12, ne13, stream); + } + + const int64_t s01 = src0->nb[1] / ts_src0; + const int64_t s11 = ne10_padded / QK8_1; + const int64_t s1 = dst->nb[1] / ts_dst; + const int64_t s02 = src0->nb[2] / ts_src0; + const int64_t s2 = dst->nb[2] / ts_dst; + const int64_t s03 = src0->nb[3] / ts_src0; + const int64_t s3 = dst->nb[3] / ts_dst; + + const int64_t s12 = ne11*s11; + const int64_t s13 = ne12*s12; + + // For MUL_MAT_ID the memory layout is different than for MUL_MAT: + const int64_t ncols_dst = ids ? ne2 : ne1; + const int64_t nchannels_y = ids ? ne11 : ne12; + const int64_t nchannels_dst = ids ? ne1 : ne2; + const int64_t stride_col_dst = ids ? s2 : s1; + const int64_t stride_col_y = ids ? s12 : s11; + const int64_t stride_channel_dst = ids ? s1 : s2; + const int64_t stride_channel_y = ids ? s11 : s12; + + mul_mat_vec_q_switch_type( + src0->data, src0->type, src1_q8_1.get(), ids_d, dst_d, ne00, + ne01, ncols_dst, s01, stride_col_y, stride_col_dst, + ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst, + ne03, ne3, s03, s13, s3, stream); +} + +void ggml_cuda_op_mul_mat_vec_q( + ggml_backend_cuda_context & ctx, + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, + const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, + const int64_t src1_padded_row_size, cudaStream_t stream) { + + const int64_t ne00 = src0->ne[0]; + const int64_t row_diff = row_high - row_low; + + const int64_t ne10 = src1->ne[0]; + GGML_ASSERT(ne10 % QK8_1 == 0); + + const int64_t ne0 = dst->ne[0]; + + int id = ggml_cuda_get_device(); + + // the main device has a larger memory buffer to hold the results from all GPUs + // nrows_dst == nrows of the matrix that the kernel writes into + const int64_t nrows_dst = id == ctx.device ? 
ne0 : row_diff; + + const int stride_row_x = ne00 / ggml_blck_size(src0->type); + const int stride_col_y = src1_padded_row_size / QK8_1; + + mul_mat_vec_q_switch_type( + src0_dd_i, src0->type, src1_ddq_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row_x, stride_col_y, nrows_dst, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, stream); GGML_UNUSED(src1); GGML_UNUSED(dst); diff --git a/ggml/src/ggml-cuda/mmvq.cuh b/ggml/src/ggml-cuda/mmvq.cuh index d9e42fdd6d1..39dc7d33eb5 100644 --- a/ggml/src/ggml-cuda/mmvq.cuh +++ b/ggml/src/ggml-cuda/mmvq.cuh @@ -2,6 +2,9 @@ #define MMVQ_MAX_BATCH_SIZE 8 // Max. batch size for which to use MMVQ kernels. +void ggml_cuda_mul_mat_vec_q(ggml_backend_cuda_context & ctx, + const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst); + void ggml_cuda_op_mul_mat_vec_q( ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, diff --git a/ggml/src/ggml-cuda/quantize.cu b/ggml/src/ggml-cuda/quantize.cu index 1702e4ce2fe..3bab47d56a2 100644 --- a/ggml/src/ggml-cuda/quantize.cu +++ b/ggml/src/ggml-cuda/quantize.cu @@ -1,30 +1,40 @@ #include "quantize.cuh" #include -static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int64_t kx, const int64_t kx0_padded) { - const int64_t ix0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x; +static __global__ void quantize_q8_1( + const float * __restrict__ x, void * __restrict__ vy, + const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03, + const int64_t ne0, const int ne1, const int ne2) { + const int64_t i0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x; - if (ix0 >= kx0_padded) { + if (i0 >= ne0) { return; } - const int64_t ix1 = blockIdx.y; + const int64_t i1 = blockIdx.y; + const int64_t i2 = blockIdx.z % ne2; + const int64_t i3 = blockIdx.z / ne2; + + const int64_t & i00 = i0; + const int64_t & i01 = i1; + const int64_t & i02 = i2; + const int64_t & i03 = i3; - const int64_t i_padded = ix1*kx0_padded + ix0; + const int64_t i_cont = ((i3*ne2 + i2) * ne1 + i1) * ne0 + i0; block_q8_1 * y = (block_q8_1 *) vy; - const int64_t ib = i_padded / QK8_1; // block index - const int64_t iqs = i_padded % QK8_1; // quant index + const int64_t ib = i_cont / QK8_1; // block index + const int64_t iqs = i_cont % QK8_1; // quant index - const float xi = ix0 < kx ? x[ix1*kx + ix0] : 0.0f; + const float xi = i0 < ne00 ? x[i03*s03 + i02*s02 + i01*s01 + i00] : 0.0f; float amax = fabsf(xi); float sum = xi; amax = warp_reduce_max(amax); - sum = warp_reduce_sum(sum); + sum = warp_reduce_sum(sum); - const float d = amax / 127; + const float d = amax / 127; const int8_t q = amax == 0.0f ? 
0 : roundf(xi / d); y[ib].qs[iqs] = q; @@ -127,43 +137,45 @@ static __global__ void quantize_mmq_q8_1( } void quantize_row_q8_1_cuda( - const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, - const int64_t kx0_padded, const ggml_type type_x, cudaStream_t stream) { + const float * x, void * vy, const ggml_type type_src0, const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03, + const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) { - GGML_ASSERT(kx0_padded % QK8_1 == 0); + GGML_ASSERT(ne0 % QK8_1 == 0); - const int64_t block_num_x = (kx0_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE; - const dim3 num_blocks(block_num_x, kx1*channels, 1); + const int64_t block_num_x = (ne0 + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE; + const dim3 num_blocks(block_num_x, ne1, ne2*ne3); const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1); - quantize_q8_1<<>>(x, vy, kx0, kx0_padded); - - GGML_UNUSED(type_x); + quantize_q8_1<<>>(x, vy, ne00, s01, s02, s03, ne0, ne1, ne2); + GGML_UNUSED(type_src0); } void quantize_mmq_q8_1_cuda( - const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, - const int64_t kx0_padded, const ggml_type type_x, cudaStream_t stream) { + const float * x, void * vy, const ggml_type type_src0, const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03, + const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) { - GGML_ASSERT(kx0_padded % (4*QK8_1) == 0); + GGML_ASSERT(ne0 % (4*QK8_1) == 0); - const int64_t block_num_x = (kx0_padded + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ); - const dim3 num_blocks(block_num_x, kx1, channels); + const int64_t block_num_x = (ne0 + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ); + const dim3 num_blocks(block_num_x, ne1, ne2*ne3); const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE_MMQ, 1, 1); - switch (mmq_get_q8_1_ds_layout(type_x)) { + switch (mmq_get_q8_1_ds_layout(type_src0)) { case MMQ_Q8_1_DS_LAYOUT_D4: quantize_mmq_q8_1 - <<>>(x, vy, kx0, kx1, kx0_padded); + <<>>(x, vy, ne00, ne1, ne0); break; case MMQ_Q8_1_DS_LAYOUT_DS4: quantize_mmq_q8_1 - <<>>(x, vy, kx0, kx1, kx0_padded); + <<>>(x, vy, ne00, ne1, ne0); break; case MMQ_Q8_1_DS_LAYOUT_D2S6: quantize_mmq_q8_1 - <<>>(x, vy, kx0, kx1, kx0_padded); + <<>>(x, vy, ne00, ne1, ne0); break; default: GGML_ABORT("fatal error"); break; } + GGML_UNUSED(s01); + GGML_UNUSED(s02); + GGML_UNUSED(s03); } diff --git a/ggml/src/ggml-cuda/quantize.cuh b/ggml/src/ggml-cuda/quantize.cuh index 03bf322b958..b627c4e4008 100644 --- a/ggml/src/ggml-cuda/quantize.cuh +++ b/ggml/src/ggml-cuda/quantize.cuh @@ -12,13 +12,13 @@ static_assert(MATRIX_ROW_PADDING % CUDA_QUANTIZE_BLOCK_SIZE == 0, "Risk static_assert(MATRIX_ROW_PADDING % (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ) == 0, "Risk of out-of-bounds access."); typedef void (*quantize_cuda_t)( - const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, const int64_t kx0_padded, - const ggml_type type_x, cudaStream_t stream); + const float * x, void * vy, const ggml_type type_src0, const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03, + const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream); void quantize_row_q8_1_cuda( - const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, const 
int64_t kx0_padded, - const ggml_type type_x, cudaStream_t stream); + const float * x, void * vy, const ggml_type type_src0, const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03, + const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream); void quantize_mmq_q8_1_cuda( - const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, const int64_t kx0_padded, - const ggml_type type_x, cudaStream_t stream); + const float * x, void * vy, const ggml_type type_src0, const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03, + const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream); diff --git a/ggml/src/ggml-cuda/ssm-conv.cu b/ggml/src/ggml-cuda/ssm-conv.cu index cfe03d68ff0..f6375719637 100644 --- a/ggml/src/ggml-cuda/ssm-conv.cu +++ b/ggml/src/ggml-cuda/ssm-conv.cu @@ -4,13 +4,14 @@ template static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float * __restrict__ src1, const int src0_nb0, const int src0_nb1, const int src0_nb2, const int src1_nb1, float * __restrict__ dst, const int dst_nb0, const int dst_nb1, const int dst_nb2, - const int nc, const int ncs, const int nr, const int n_t, const int n_s) { + const int64_t n_t) { + GGML_UNUSED(src0_nb0); const int tid = threadIdx.x; const int bidx = blockIdx.x; const int bidy = blockIdx.y; - const float * x_block = (const float *) ((char *) src0 + bidx * src0_nb2 + bidy * split_d_inner * src0_nb1); - const float * w_block = (const float *) ((char *) src1 + bidy * split_d_inner * src1_nb1); + const float * x_block = (const float *) ((const char *) src0 + bidx * src0_nb2 + bidy * split_d_inner * src0_nb1); + const float * w_block = (const float *) ((const char *) src1 + bidy * split_d_inner * src1_nb1); float * y_block = (float *) ((char *) dst + bidx * dst_nb2 + bidy * split_d_inner * dst_nb0); const int stride_x = src0_nb1 / sizeof(float); @@ -21,15 +22,15 @@ static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float float w[d_conv] = { 0.0f }; #pragma unroll - for (int j = 0; j < d_conv; j++) { + for (size_t j = 0; j < d_conv; j++) { w[j] = w_block[tid * stride_w + j]; } - for (int i = 0; i < n_t; i++) { + for (int64_t i = 0; i < n_t; i++) { float sumf = 0.0f; if (i == 0) { - for (int j = 0; j < d_conv; j++) { + for (size_t j = 0; j < d_conv; j++) { x[j] = x_block[tid * stride_x + j]; } } else { @@ -37,27 +38,26 @@ static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float } #pragma unroll - for (int j = 0; j < d_conv; j++) { + for (size_t j = 0; j < d_conv; j++) { sumf += x[(i + j) % d_conv] * w[j]; } y_block[i * stride_y + tid] = sumf; } } -template +template static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0, const float * __restrict__ src1, const int src0_nb0, const int src0_nb1, const int src0_nb2, const int src1_nb1, float * __restrict__ dst, const int dst_nb0, - const int dst_nb1, const int dst_nb2, const int nc, const int ncs, - const int nr, const int n_t, const int n_s) { + const int dst_nb1, const int dst_nb2, const int64_t n_t) { const int tid = threadIdx.x; const int bidx = blockIdx.x; const int bidy = blockIdx.y; const int bidz = blockIdx.z; - const float * x_block = (const float *) ((char *) src0 + bidx * src0_nb2 + bidy * split_d_inner * src0_nb1 + + const float * x_block = (const float *) ((const char *) src0 + bidx * src0_nb2 + bidy * split_d_inner * src0_nb1 + bidz * split_n_t * 
src0_nb0); - const float * w_block = (const float *) ((char *) src1 + bidy * split_d_inner * src1_nb1); + const float * w_block = (const float *) ((const char *) src1 + bidy * split_d_inner * src1_nb1); float * y_block = (float *) ((char *) dst + bidx * dst_nb2 + bidz * split_n_t * dst_nb1 + bidy * split_d_inner * dst_nb0); @@ -69,17 +69,17 @@ static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0, float w[d_conv] = { 0.0f }; #pragma unroll - for (int j = 0; j < d_conv; j++) { + for (size_t j = 0; j < d_conv; j++) { w[j] = w_block[tid * stride_w + j]; } #pragma unroll - for (int i = 0; i < split_n_t; i++) { + for (int64_t i = 0; i < split_n_t; i++) { if (bidz * split_n_t + i < n_t) { float sumf = 0.0f; if (i == 0) { - for (int j = 0; j < d_conv; j++) { + for (size_t j = 0; j < d_conv; j++) { x[j] = x_block[tid * stride_x + j]; } } else { @@ -87,7 +87,7 @@ static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0, } #pragma unroll - for (int j = 0; j < d_conv; j++) { + for (size_t j = 0; j < d_conv; j++) { sumf += x[(i + j) % d_conv] * w[j]; } y_block[i * stride_y + tid] = sumf; @@ -97,8 +97,8 @@ static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0, static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int src0_nb0, const int src0_nb1, const int src0_nb2, const int src1_nb1, float * dst, const int dst_nb0, const int dst_nb1, - const int dst_nb2, const int nc, const int ncs, const int nr, const int n_t, - const int n_s, cudaStream_t stream) { + const int dst_nb2, const int64_t nc, const int64_t nr, const int64_t n_t, + const int64_t n_s, cudaStream_t stream) { const int threads = 128; GGML_ASSERT(nr % threads == 0); @@ -106,18 +106,16 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int const dim3 blocks(n_s, (nr + threads - 1) / threads, 1); if (nc == 4) { ssm_conv_f32<<>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, - dst, dst_nb0, dst_nb1, dst_nb2, nc, ncs, nr, n_t, - n_s); + dst, dst_nb0, dst_nb1, dst_nb2, n_t); } else { GGML_ABORT("Only support kernel size = 4 now."); } } else { if (nc == 4) { - const int split_n_t = 32; - dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t); - ssm_conv_long_token_f32 - <<>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, - dst_nb1, dst_nb2, nc, ncs, nr, n_t, n_s); + const int64_t split_n_t = 32; + dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t); + ssm_conv_long_token_f32<<>>( + src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t); } else { GGML_ABORT("Only support kernel size = 4 right now."); } @@ -128,11 +126,10 @@ void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const struct ggml_tensor * src0 = dst->src[0]; // conv_x const struct ggml_tensor * src1 = dst->src[1]; // conv1d.weight - const int nc = src1->ne[0]; // d_conv - const int ncs = src0->ne[0]; // d_conv - 1 + n_t - const int nr = src0->ne[1]; // d_inner - const int n_t = dst->ne[1]; // tokens per sequence - const int n_s = dst->ne[2]; // number of sequences in the batch + const int64_t nc = src1->ne[0]; // d_conv + const int64_t nr = src0->ne[1]; // d_inner + const int64_t n_t = dst->ne[1]; // tokens per sequence + const int64_t n_s = dst->ne[2]; // number of sequences in the batch GGML_ASSERT(dst->ne[0] == nr); GGML_ASSERT(src0->nb[0] == sizeof(float)); @@ -147,5 +144,5 @@ void 
ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(dst->type == GGML_TYPE_F32); ssm_conv_f32_cuda(src0_d, src1_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, dst->nb[0], dst->nb[1], - dst->nb[2], nc, ncs, nr, n_t, n_s, stream); + dst->nb[2], nc, nr, n_t, n_s, stream); } diff --git a/ggml/src/ggml-cuda/ssm-scan.cu b/ggml/src/ggml-cuda/ssm-scan.cu index 52db17cd9ae..37ee208c09d 100644 --- a/ggml/src/ggml-cuda/ssm-scan.cu +++ b/ggml/src/ggml-cuda/ssm-scan.cu @@ -1,10 +1,5 @@ #include "ssm-scan.cuh" -// #include -// static __device__ void global_to_shared(const float *src, float *dst) { -// asm volatile("cp.async."); -// } - template __global__ void __launch_bounds__(splitD, 2) ssm_scan_f32(const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2, @@ -12,7 +7,9 @@ __global__ void __launch_bounds__(splitD, 2) const int src0_nb1, const int src0_nb2, const int src1_nb0, const int src1_nb1, const int src1_nb2, const int src1_nb3, const int src2_nb0, const int src2_nb1, const int src2_nb2, const int src3_nb1, const int src4_nb1, const int src4_nb2, const int src5_nb1, const int src5_nb2, - float * __restrict__ dst, const int D, const int L, const int B) { + float * __restrict__ dst, const int64_t L) { + GGML_UNUSED(src1_nb0); + GGML_UNUSED(src2_nb0); const int bidx = blockIdx.x; // split along B const int bidy = blockIdx.y; // split along D const int tid = threadIdx.x; @@ -25,12 +22,12 @@ __global__ void __launch_bounds__(splitD, 2) float * smem_A = smem; float * smem_s0 = smem_A + splitD * stride_sA; - const float * s0_block = (const float *) ((char *) src0 + bidx * src0_nb2 + bidy * splitD * src0_nb1); - const float * x_block = (const float *) ((char *) src1 + (bidx * src1_nb2) + bidy * splitD * sizeof(float)); - const float * dt_block = (const float *) ((char *) src2 + (bidx * src2_nb2) + bidy * splitD * sizeof(float)); - const float * A_block = (const float *) ((char *) src3 + bidy * splitD * src3_nb1); - const float * B_block = (const float *) ((char *) src4 + (bidx * src4_nb2)); - const float * C_block = (const float *) ((char *) src5 + (bidx * src5_nb2)); + const float * s0_block = (const float *) ((const char *) src0 + bidx * src0_nb2 + bidy * splitD * src0_nb1); + const float * x_block = (const float *) ((const char *) src1 + (bidx * src1_nb2) + bidy * splitD * sizeof(float)); + const float * dt_block = (const float *) ((const char *) src2 + (bidx * src2_nb2) + bidy * splitD * sizeof(float)); + const float * A_block = (const float *) ((const char *) src3 + bidy * splitD * src3_nb1); + const float * B_block = (const float *) ((const char *) src4 + (bidx * src4_nb2)); + const float * C_block = (const float *) ((const char *) src5 + (bidx * src5_nb2)); float * y_block = (float *) ((char *) dst + (bidx * src1_nb2) + bidy * splitD * sizeof(float)); float * s_block = (float *) ((char *) dst + src1_nb3 + bidx * src0_nb2 + bidy * splitD * src0_nb1); @@ -46,7 +43,7 @@ __global__ void __launch_bounds__(splitD, 2) // can N not be 16? for example 32? if (N == 16) { #pragma unroll - for (int i = 0; i < splitD / 4; i += 2) { + for (size_t i = 0; i < splitD / 4; i += 2) { float value = A_block[(wid * warpSize + i) * stride_A + wtid]; // todo: bank conflict // I am always confused with how to use the swizzling method to solve @@ -54,7 +51,7 @@ __global__ void __launch_bounds__(splitD, 2) smem_A[(wid * warpSize + i) * stride_sA + wtid + ((wtid / 16) > 0 ? 
1 : 0)] = value; } #pragma unroll - for (int i = 0; i < splitD / 4; i += 2) { + for (size_t i = 0; i < splitD / 4; i += 2) { float value = s0_block[(wid * warpSize + i) * stride_s0 + wtid]; smem_s0[(wid * warpSize + i) * stride_ss0 + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value; } @@ -62,7 +59,7 @@ __global__ void __launch_bounds__(splitD, 2) __syncthreads(); - for (int i = 0; i < L; i++) { + for (int64_t i = 0; i < L; i++) { float dt_soft_plus = dt_block[i * stride_dt + tid]; if (dt_soft_plus <= 20.0f) { dt_soft_plus = log1pf(exp(dt_soft_plus)); @@ -70,7 +67,7 @@ __global__ void __launch_bounds__(splitD, 2) float x_dt = x_block[i * stride_x + tid] * dt_soft_plus; float sumf = 0.0f; #pragma unroll - for (int j = 0; j < N; j++) { + for (size_t j = 0; j < N; j++) { float state = (smem_s0[tid * stride_ss0 + j] * expf(dt_soft_plus * smem_A[tid * stride_sA + j])) + (B_block[i * stride_B + j] * x_dt); sumf += state * C_block[i * stride_C + j]; @@ -90,7 +87,8 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa const int src1_nb0, const int src1_nb1, const int src1_nb2, const int src1_nb3, const int src2_nb0, const int src2_nb1, const int src2_nb2, const int src3_nb1, const int src4_nb1, const int src4_nb2, const int src5_nb1, const int src5_nb2, - float * dst, const int N, const int D, const int L, const int B, cudaStream_t stream) { + float * dst, const int64_t N, const int64_t D, const int64_t L, const int64_t B, + cudaStream_t stream) { const int threads = 128; // todo: consider D cannot be divided,does this situation exist? GGML_ASSERT(D % threads == 0); @@ -99,7 +97,7 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa if (N == 16) { ssm_scan_f32<128, 16><<>>( src0, src1, src2, src3, src4, src5, src0_nb1, src0_nb2, src1_nb0, src1_nb1, src1_nb2, src1_nb3, src2_nb0, - src2_nb1, src2_nb2, src3_nb1, src4_nb1, src4_nb2, src5_nb1, src5_nb2, dst, D, L, B); + src2_nb1, src2_nb2, src3_nb1, src4_nb1, src4_nb2, src5_nb1, src5_nb2, dst, L); } else { GGML_ABORT("doesn't support N!=16."); } diff --git a/ggml/src/ggml-cuda/vecdotq.cuh b/ggml/src/ggml-cuda/vecdotq.cuh index 40091a0ef07..ba195e1d100 100644 --- a/ggml/src/ggml-cuda/vecdotq.cuh +++ b/ggml/src/ggml-cuda/vecdotq.cuh @@ -1,3 +1,5 @@ +#pragma once + #include "common.cuh" #include diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h index 3983ce5b423..1a28831b7a9 100644 --- a/ggml/src/ggml-cuda/vendors/hip.h +++ b/ggml/src/ggml-cuda/vendors/hip.h @@ -20,6 +20,7 @@ #define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS #define CUBLAS_TF32_TENSOR_OP_MATH 0 #define CUDA_R_16F HIPBLAS_R_16F +#define CUDA_R_16BF HIPBLAS_R_16B #define CUDA_R_32F HIPBLAS_R_32F #define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED hipDeviceAttributeVirtualMemoryManagementSupported #define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED hipMemAllocationGranularityRecommended @@ -70,6 +71,8 @@ #define cudaLaunchHostFunc hipLaunchHostFunc #define cudaMalloc hipMalloc #define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault) +#define cudaMallocManaged hipMallocManaged +#define cudaMemAdvise hipMemAdvise #define cudaMemcpy hipMemcpy #define cudaMemcpyAsync hipMemcpyAsync #define cudaMemcpyPeerAsync hipMemcpyPeerAsync diff --git a/ggml/src/ggml-cuda/vendors/musa.h b/ggml/src/ggml-cuda/vendors/musa.h index f2d55796e78..937779a90af 100644 --- a/ggml/src/ggml-cuda/vendors/musa.h +++ b/ggml/src/ggml-cuda/vendors/musa.h @@ -15,6 +15,7 @@ #define CUBLAS_STATUS_SUCCESS 
MUBLAS_STATUS_SUCCESS #define CUBLAS_TF32_TENSOR_OP_MATH MUBLAS_MATH_MODE_DEFAULT #define CUDA_R_16F MUSA_R_16F +#define CUDA_R_16BF MUSA_R_16BF #define CUDA_R_32F MUSA_R_32F #define cublasComputeType_t cudaDataType_t #define cublasCreate mublasCreate diff --git a/ggml/src/ggml-hip/CMakeLists.txt b/ggml/src/ggml-hip/CMakeLists.txt index e3762649fd2..1fe8fe3b8d0 100644 --- a/ggml/src/ggml-hip/CMakeLists.txt +++ b/ggml/src/ggml-hip/CMakeLists.txt @@ -89,10 +89,6 @@ endif() add_compile_definitions(GGML_USE_HIP) -if (GGML_HIP_UMA) - add_compile_definitions(GGML_HIP_UMA) -endif() - if (GGML_CUDA_FORCE_MMQ) add_compile_definitions(GGML_CUDA_FORCE_MMQ) endif() diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h index be2e3fc9155..a19cfb14e0f 100644 --- a/ggml/src/ggml-impl.h +++ b/ggml/src/ggml-impl.h @@ -148,8 +148,14 @@ struct ggml_map_custom2_op_params { struct ggml_map_custom3_op_params { ggml_custom3_op_t fun; - int n_tasks; - void * userdata; + int n_tasks; + void * userdata; +}; + +struct ggml_custom_op_params { + ggml_custom_op_t fun; + int n_tasks; + void * userdata; }; // bitset @@ -311,29 +317,28 @@ GGML_API void ggml_aligned_free(void * ptr, size_t size); // FP16 to FP32 conversion -#if defined(__ARM_NEON) - #if defined(_MSC_VER) || (defined(__CUDACC__) && __CUDACC_VER_MAJOR__ <= 11) - typedef uint16_t ggml_fp16_internal_t; - #else - typedef __fp16 ggml_fp16_internal_t; - #endif -#endif - -#if defined(__ARM_NEON) && !defined(_MSC_VER) && !(defined(__CUDACC__) && __CUDACC_VER_MAJOR__ <= 11) +// 16-bit float +// on Arm, we use __fp16 +// on x86, we use uint16_t +// +// for old CUDA compilers (<= 11), we use uint16_t: ref https://github.com/ggml-org/llama.cpp/pull/10616 +// for MUSA compilers , we use uint16_t: ref https://github.com/ggml-org/llama.cpp/pull/11843 +// +#if defined(__ARM_NEON) && !(defined(__CUDACC__) && __CUDACC_VER_MAJOR__ <= 11) && !defined(__MUSACC__) #define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x) #define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x) #define GGML_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x) static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) { - ggml_fp16_internal_t tmp; + __fp16 tmp; memcpy(&tmp, &h, sizeof(ggml_fp16_t)); return (float)tmp; } static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) { ggml_fp16_t res; - ggml_fp16_internal_t tmp = f; + __fp16 tmp = f; memcpy(&res, &tmp, sizeof(ggml_fp16_t)); return res; } @@ -357,8 +362,8 @@ GGML_API void ggml_aligned_free(void * ptr, size_t size); #define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x) static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) { - register float f; - register double d; + float f; + double d; __asm__( "mtfprd %0,%2\n" "xscvhpdp %0,%0\n" @@ -370,8 +375,8 @@ GGML_API void ggml_aligned_free(void * ptr, size_t size); } static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) { - register double d; - register ggml_fp16_t r; + double d; + ggml_fp16_t r; __asm__( /* xscvdphp can work on double or single precision */ "xscvdphp %0,%2\n" "mffprd %1,%0\n" : @@ -485,7 +490,7 @@ GGML_API void ggml_aligned_free(void * ptr, size_t size); #define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x) #define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x) -#endif // defined(__ARM_NEON) && (!defined(__MSC_VER) +#endif // defined(__ARM_NEON) && !(defined(__CUDACC__) && __CUDACC_VER_MAJOR__ <= 11) && !defined(__MUSACC__) // precomputed f32 table for f16 (256 KB) // defined in ggml.c, initialized in ggml_init() diff 
--git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index 456e1fd994c..266d8af4693 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -354,6 +354,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H192, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96, @@ -362,6 +363,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H192, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96, @@ -370,6 +372,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H192, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96, @@ -378,6 +381,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H192, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96, @@ -386,6 +390,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H192, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96, @@ -394,6 +399,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H192, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96, @@ -402,6 +408,14 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H96, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H96, + 
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H96, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H96, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128, @@ -430,6 +444,13 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK576_HV512, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK576_HV512, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK576_HV512, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512, GGML_METAL_KERNEL_TYPE_SET_I32, GGML_METAL_KERNEL_TYPE_SET_F32, GGML_METAL_KERNEL_TYPE_CPY_F32_F32, @@ -460,6 +481,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte GGML_METAL_KERNEL_TYPE_SQRT, GGML_METAL_KERNEL_TYPE_SIN, GGML_METAL_KERNEL_TYPE_COS, + GGML_METAL_KERNEL_TYPE_NEG, GGML_METAL_KERNEL_TYPE_SUM_ROWS, GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, @@ -1011,6 +1033,7 @@ @implementation GGMLMetalClass GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H192, flash_attn_ext_f16_h192, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128, flash_attn_ext_f16_hk192_hv128, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512, flash_attn_ext_f16_hk576_hv512, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64, flash_attn_ext_bf16_h64, has_simdgroup_mm && use_bfloat); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80, flash_attn_ext_bf16_h80, has_simdgroup_mm && use_bfloat); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96, flash_attn_ext_bf16_h96, has_simdgroup_mm && use_bfloat); @@ -1019,6 +1042,7 @@ @implementation GGMLMetalClass GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H192, flash_attn_ext_bf16_h192, has_simdgroup_mm && use_bfloat); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128, flash_attn_ext_bf16_hk192_hv128, has_simdgroup_mm && use_bfloat); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256, flash_attn_ext_bf16_h256, has_simdgroup_mm && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512, flash_attn_ext_bf16_hk576_hv512, has_simdgroup_mm && use_bfloat); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64, flash_attn_ext_q4_0_h64, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80, flash_attn_ext_q4_0_h80, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96, flash_attn_ext_q4_0_h96, has_simdgroup_mm); @@ -1027,6 +1051,7 @@ @implementation GGMLMetalClass GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H192, flash_attn_ext_q4_0_h192, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128, flash_attn_ext_q4_0_hk192_hv128, has_simdgroup_mm); 
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256, flash_attn_ext_q4_0_h256, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512, flash_attn_ext_q4_0_hk576_hv512, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64, flash_attn_ext_q4_1_h64, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80, flash_attn_ext_q4_1_h80, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96, flash_attn_ext_q4_1_h96, has_simdgroup_mm); @@ -1035,6 +1060,7 @@ @implementation GGMLMetalClass GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H192, flash_attn_ext_q4_1_h192, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128, flash_attn_ext_q4_1_hk192_hv128, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256, flash_attn_ext_q4_1_h256, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512, flash_attn_ext_q4_1_hk576_hv512, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64, flash_attn_ext_q5_0_h64, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80, flash_attn_ext_q5_0_h80, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96, flash_attn_ext_q5_0_h96, has_simdgroup_mm); @@ -1043,6 +1069,7 @@ @implementation GGMLMetalClass GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H192, flash_attn_ext_q5_0_h192, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128, flash_attn_ext_q5_0_hk192_hv128, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256, flash_attn_ext_q5_0_h256, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512, flash_attn_ext_q5_0_hk576_hv512, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64, flash_attn_ext_q5_1_h64, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80, flash_attn_ext_q5_1_h80, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96, flash_attn_ext_q5_1_h96, has_simdgroup_mm); @@ -1051,6 +1078,7 @@ @implementation GGMLMetalClass GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H192, flash_attn_ext_q5_1_h192, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128, flash_attn_ext_q5_1_hk192_hv128, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256, flash_attn_ext_q5_1_h256, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512, flash_attn_ext_q5_1_hk576_hv512, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64, flash_attn_ext_q8_0_h64, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80, flash_attn_ext_q8_0_h80, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96, flash_attn_ext_q8_0_h96, has_simdgroup_mm); @@ -1059,6 +1087,14 @@ @implementation GGMLMetalClass GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192, flash_attn_ext_q8_0_h192, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128, flash_attn_ext_q8_0_hk192_hv128, 
has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512, flash_attn_ext_q8_0_hk576_hv512, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96, flash_attn_ext_vec_f16_h96, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96, flash_attn_ext_vec_bf16_h96, has_simdgroup_reduction && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96, flash_attn_ext_vec_q4_0_h96, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H96, flash_attn_ext_vec_q4_1_h96, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H96, flash_attn_ext_vec_q5_0_h96, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H96, flash_attn_ext_vec_q5_1_h96, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H96, flash_attn_ext_vec_q8_0_h96, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128, flash_attn_ext_vec_bf16_h128, has_simdgroup_reduction && use_bfloat); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128, flash_attn_ext_vec_q4_0_h128, has_simdgroup_reduction); @@ -1087,6 +1123,13 @@ @implementation GGMLMetalClass GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, flash_attn_ext_vec_q5_0_h256, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256, flash_attn_ext_vec_q5_1_h256, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, flash_attn_ext_vec_q8_0_h256, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512, flash_attn_ext_vec_f16_hk576_hv512, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK576_HV512, flash_attn_ext_vec_bf16_hk576_hv512, has_simdgroup_reduction && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK576_HV512, flash_attn_ext_vec_q4_0_hk576_hv512, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK576_HV512, flash_attn_ext_vec_q4_1_hk576_hv512, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512, flash_attn_ext_vec_q5_0_hk576_hv512, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512, flash_attn_ext_vec_q5_1_hk576_hv512, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512, flash_attn_ext_vec_q8_0_hk576_hv512, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_F32, set_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_I32, set_i32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true); @@ -1117,6 +1160,7 @@ @implementation GGMLMetalClass GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQRT, sqrt, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true); + 
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG, neg, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true); @@ -1278,6 +1322,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex case GGML_UNARY_OP_GELU_QUICK: case GGML_UNARY_OP_SILU: case GGML_UNARY_OP_ELU: + case GGML_UNARY_OP_NEG: return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; default: return false; @@ -1334,8 +1379,9 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex return op->src[0]->type == GGML_TYPE_F16; case GGML_OP_POOL_1D: return false; - case GGML_OP_POOL_2D: case GGML_OP_UPSCALE: + return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST; + case GGML_OP_POOL_2D: case GGML_OP_PAD: case GGML_OP_PAD_REFLECT_1D: case GGML_OP_TIMESTEP_EMBEDDING: @@ -1345,6 +1391,16 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex case GGML_OP_ARANGE: return true; case GGML_OP_FLASH_ATTN_EXT: + if (op->src[0]->ne[0] == 32) { + // head size == 32 (e.g. bert-bge-small) + // TODO: not sure if it is worth adding kernels for this size + return false; + } + if (op->src[0]->ne[0] == 576) { + // DeepSeek sizes + // TODO: disabled for now, until optmized + return false; + } if (op->src[1]->type != op->src[2]->type) { return false; } @@ -1957,6 +2013,18 @@ static void ggml_metal_encode_node( [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; + case GGML_UNARY_OP_NEG: + { + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_NEG].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; default: { GGML_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op)); @@ -3837,12 +3905,14 @@ static void ggml_metal_encode_node( // TODO: add vec kernels for (ne00%64 == 0) and maybe also for (ne00%32 == 0) // for now avoiding mainly to keep the number of templates/kernels a bit lower // these are now trivial to add after: https://github.com/ggml-org/llama.cpp/pull/12612 - if (ne01 >= 4 || (ne00%128 != 0 && ne00 != 192)) { + if (ne01 >= 4 || (ne00%128 != 0 && ne00 != 96 && ne00 != 192 && ne00 != 576)) { switch (src1->type) { case GGML_TYPE_F16: { if (ne00 == 192 && ne20 == 128) { pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128].pipeline; + } else if (ne00 == 576 && ne20 == 512) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512].pipeline; } else { switch (ne00) { case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break; @@ -3865,6 +3935,8 @@ static void ggml_metal_encode_node( { if (ne00 == 192 && ne20 == 128) { pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128].pipeline; + } else if (ne00 == 576 && ne20 == 512) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512].pipeline; } else { switch (ne00) { case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64 ].pipeline; break; @@ -3887,6 +3959,8 @@ static void 
ggml_metal_encode_node( { if (ne00 == 192 && ne20 == 128) { pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128].pipeline; + } else if (ne00 == 576 && ne20 == 512) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512].pipeline; } else { switch (ne00) { case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64 ].pipeline; break; @@ -3909,6 +3983,8 @@ static void ggml_metal_encode_node( { if (ne00 == 192 && ne20 == 128) { pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128].pipeline; + } else if (ne00 == 576 && ne20 == 512) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512].pipeline; } else { switch (ne00) { case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64 ].pipeline; break; @@ -3931,6 +4007,8 @@ static void ggml_metal_encode_node( { if (ne00 == 192 && ne20 == 128) { pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128].pipeline; + } else if (ne00 == 576 && ne20 == 512) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512].pipeline; } else { switch (ne00) { case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64 ].pipeline; break; @@ -3953,6 +4031,8 @@ static void ggml_metal_encode_node( { if (ne00 == 192 && ne20 == 128) { pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128].pipeline; + } else if (ne00 == 576 && ne20 == 512) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512].pipeline; } else { switch (ne00) { case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64 ].pipeline; break; @@ -3975,6 +4055,8 @@ static void ggml_metal_encode_node( { if (ne00 == 192 && ne20 == 128) { pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128].pipeline; + } else if (ne00 == 576 && ne20 == 512) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512].pipeline; } else { switch (ne00) { case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64 ].pipeline; break; @@ -4004,6 +4086,24 @@ static void ggml_metal_encode_node( use_vec_kernel = true; switch (ne00) { + case 96: + { + switch (src1->type) { + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96].pipeline; break; + case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96].pipeline; break; + case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96].pipeline; break; + case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H96].pipeline; break; + case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H96].pipeline; break; + case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H96].pipeline; break; + case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H96].pipeline; break; + default: + { + GGML_LOG_ERROR("unsupported type: %d\n", src1->type); + GGML_LOG_ERROR("add template specialization for this type\n"); + GGML_ABORT("add template specialization for this type"); + } + } + } break; case 128: { switch (src1->type) { @@ -4076,12 +4176,36 @@ static void ggml_metal_encode_node( } } } break; + case 576: + { + if (ne20 == 512) { + switch (src1->type) { + case GGML_TYPE_F16: pipeline = 
ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512].pipeline; break; + case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK576_HV512].pipeline; break; + case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK576_HV512].pipeline; break; + case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK576_HV512].pipeline; break; + case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512].pipeline; break; + case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512].pipeline; break; + case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512].pipeline; break; + default: + { + GGML_LOG_ERROR("unsupported type: %d\n", src1->type); + GGML_LOG_ERROR("add template specialization for this type\n"); + GGML_ABORT("add template specialization for this type"); + } + } + } else { + GGML_LOG_ERROR("unsupported size: %lld\n", ne20); + GGML_LOG_ERROR("add template specialization for this size\n"); + GGML_ABORT("add template specialization for this size"); + } + } break; default: - { - GGML_LOG_ERROR("unsupported size: %lld\n", ne00); - GGML_LOG_ERROR("add template specialization for this size\n"); - GGML_ABORT("add template specialization for this size"); - } + { + GGML_LOG_ERROR("unsupported size: %lld\n", ne00); + GGML_LOG_ERROR("add template specialization for this size\n"); + GGML_ABORT("add template specialization for this size"); + } } } diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index b08666e2799..9f4147e9397 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -949,6 +949,13 @@ kernel void kernel_cos( dst[tpig] = cos(src0[tpig]); } +kernel void kernel_neg( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = -src0[tpig]; +} + kernel void kernel_sum_rows( device const float * src0, device float * dst, @@ -3185,7 +3192,7 @@ kernel void kernel_flash_attn_ext( { float S[Q] = { [0 ... Q-1] = 0.0f }; - float M[Q] = { [0 ... Q-1] = -__FLT16_MAX__/2 }; + float M[Q] = { [0 ... 
Q-1] = -__FLT_MAX__/2 }; // thread indices inside the simdgroup // TODO: see if we can utilize quad-group functions for better performance @@ -3445,7 +3452,7 @@ kernel void kernel_flash_attn_ext( // reduce the warps sequentially for (ushort sg = 1; sg < nsg; ++sg) { float S = { 0.0f }; - float M = { -__FLT16_MAX__/2 }; + float M = { -__FLT_MAX__/2 }; threadgroup_barrier(mem_flags::mem_threadgroup); @@ -3546,6 +3553,7 @@ template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_f16_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; #if defined(GGML_METAL_USE_BF16) template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -3556,6 +3564,7 @@ template [[host_name("kernel_flash_attn_ext_bf16_h128")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_bf16_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_bf16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; #endif template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -3566,6 +3575,7 @@ template [[host_name("kernel_flash_attn_ext_q4_0_h128")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q4_0_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -3575,6 +3585,7 @@ template [[host_name("kernel_flash_attn_ext_q4_1_h128")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q4_1_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_1_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -3584,6 +3595,7 @@ template [[host_name("kernel_flash_attn_ext_q5_0_h128")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q5_0_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template 
[[host_name("kernel_flash_attn_ext_q5_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -3593,6 +3605,7 @@ template [[host_name("kernel_flash_attn_ext_q5_1_h128")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q5_1_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_1_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -3602,6 +3615,7 @@ template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q8_0_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q8_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; #undef FA_TYPES @@ -3685,7 +3699,7 @@ kernel void kernel_flash_attn_ext_vec( { float S = 0.0f; - float M = -__FLT16_MAX__/2; + float M = -__FLT_MAX__/2; // thread indices inside the simdgroup const short tx = tiisg%NL; @@ -3959,6 +3973,16 @@ kernel void kernel_flash_attn_ext_vec( typedef decltype(kernel_flash_attn_ext_vec) flash_attn_ext_vec_t; +template [[host_name("kernel_flash_attn_ext_vec_f16_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#if defined(GGML_METAL_USE_BF16) +template [[host_name("kernel_flash_attn_ext_vec_bf16_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#endif +template [[host_name("kernel_flash_attn_ext_vec_q4_0_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; + template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #if defined(GGML_METAL_USE_BF16) template [[host_name("kernel_flash_attn_ext_vec_bf16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; @@ -3999,6 +4023,16 @@ template [[host_name("kernel_flash_attn_ext_vec_q5_0_h256")]] kernel flash_attn_ template [[host_name("kernel_flash_attn_ext_vec_q5_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template 
[[host_name("kernel_flash_attn_ext_vec_f16_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#if defined(GGML_METAL_USE_BF16) +template [[host_name("kernel_flash_attn_ext_vec_bf16_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#endif +template [[host_name("kernel_flash_attn_ext_vec_q4_0_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; + #undef FA_TYPES template diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt index 624cb1b9d08..352deb321ec 100644 --- a/ggml/src/ggml-opencl/CMakeLists.txt +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -54,16 +54,41 @@ function(ggml_opencl_add_kernel KNAME) endfunction() set(GGML_OPENCL_KERNELS - ggml-opencl - ggml-opencl_mm - ggml-opencl_cvt - ggml-opencl_gemv_noshuffle - ggml-opencl_gemv_noshuffle_general - ggml-opencl_mul_mat_Ab_Bi_8x4 - ggml-opencl_transpose_16 - ggml-opencl_transpose_32 - ggml-opencl_transpose_32_16 - ggml-opencl_im2col + add + clamp + cpy + cvt + diag_mask_inf + gelu + gemv_noshuffle_general + gemv_noshuffle + get_rows + im2col_f32 + im2col_f16 + mul_mat_Ab_Bi_8x4 + mul_mv_f16_f16 + mul_mv_f16_f32_1row + mul_mv_f16_f32_l4 + mul_mv_f16_f32 + mul_mv_f32_f32 + mul_mv_q4_0_f32 + mul_mv_q4_0_f32_v + mul_mv_q4_0_f32_8x_flat + mul_mv_q4_0_f32_1d_8x_flat + mul_mv_q4_0_f32_1d_16x_flat + mul_mv_q6_k + mul + norm + relu + rms_norm + rope + scale + silu + softmax_4_f32 + softmax_4_f16 + softmax_f32 + softmax_f16 + transpose ) foreach (K ${GGML_OPENCL_KERNELS}) diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 723cab8b174..05a2f4e630a 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -64,11 +64,33 @@ enum ADRENO_GPU_GEN { X1E, }; +enum ADRENO_CL_COMPILER_TYPE { + E031, + DX, +}; + struct ggml_cl_version { cl_uint major = 0; cl_uint minor = 0; }; +struct ggml_cl_compiler_version { + ADRENO_CL_COMPILER_TYPE type; + int major = -1; + int minor = -1; + int patch = -1; + + bool same(ADRENO_CL_COMPILER_TYPE t, int x, int y, int z) const { + return major == x && minor == y && patch == z && type == t; + } + bool newer_than(ADRENO_CL_COMPILER_TYPE t, int x, int y, int z) const { + return major*10000 + minor*100 + patch > x*10000 + y*100 + z && type == t; + } + bool newer_than_or_same(ADRENO_CL_COMPILER_TYPE t, int x, int y, int z) const { + return same(t, x, y, z) || newer_than(t, x, y, z); + } +}; + // Parses a version string of form "XX.YY ". On an error returns ggml_cl_version with all zeroes. 
static ggml_cl_version parse_cl_version(std::string_view str) { size_t major_str_begin = 0; @@ -173,24 +195,30 @@ static ADRENO_GPU_GEN get_adreno_gpu_gen(const char *device_name) { return ADRENO_GPU_GEN::ADRENO_UNKNOWN; } -static int get_adreno_cl_compiler_version(const char *driver_version) { +static ggml_cl_compiler_version get_adreno_cl_compiler_version(const char *driver_version) { std::string driver_ver_str(driver_version); + ADRENO_CL_COMPILER_TYPE type = ADRENO_CL_COMPILER_TYPE::E031; size_t compiler_ver_pos = driver_ver_str.find("E031"); size_t compiler_ver_len = 13; - size_t compiler_ver_offset = 5; + size_t compiler_major_offset = 5; + size_t compiler_minor_offset = 8; + size_t compiler_patch_offset = 11; if (compiler_ver_pos == std::string::npos) { compiler_ver_pos = driver_ver_str.find("DX"); if (compiler_ver_pos == std::string::npos) { - return -1; + return {}; } + type = ADRENO_CL_COMPILER_TYPE::DX; compiler_ver_len = 11; - compiler_ver_offset = 3; + compiler_major_offset = 3; } std::string compiler_ver_str = driver_ver_str.substr(compiler_ver_pos, compiler_ver_len); - std::string major_ver_str = compiler_ver_str.substr(compiler_ver_offset, 2); - return std::atoi(major_ver_str.c_str()); + int major = std::atoi(compiler_ver_str.substr(compiler_major_offset, 2).c_str()); + int minor = std::atoi(compiler_ver_str.substr(compiler_minor_offset, 2).c_str()); + int patch = std::atoi(compiler_ver_str.substr(compiler_patch_offset, 2).c_str()); + return { type, major, minor, patch }; } // backend device context @@ -215,16 +243,48 @@ struct ggml_backend_opencl_context { cl_int alignment; size_t max_alloc_size; bool fp16_support; + bool has_vector_subgroup_broadcast; + ggml_cl_compiler_version adreno_cl_compiler_version; int adreno_wave_size; cl_context context; cl_command_queue queue; - cl_program program; - cl_program program_1; - cl_program program_2; - cl_program program_im2col; + cl_program program_add; + cl_program program_clamp; + cl_program program_cpy; + cl_program program_cvt; + cl_program program_diag_mask_inf; + cl_program program_gelu; + cl_program program_gemv_noshuffle_general; + cl_program program_gemv_noshuffle; + cl_program program_get_rows; + cl_program program_im2col_f16; + cl_program program_im2col_f32; + cl_program program_mul_mat_Ab_Bi_8x4; + cl_program program_mul_mv_q4_0_f32; + cl_program program_mul_mv_q4_0_f32_v; + cl_program program_mul_mv_q4_0_f32_8x_flat; + cl_program program_mul_mv_q4_0_f32_1d_8x_flat; + cl_program program_mul_mv_q4_0_f32_1d_16x_flat; + cl_program program_mul_mv_q6_K; + cl_program program_mul_mv_f16_f16; + cl_program program_mul_mv_f16_f32_1row; + cl_program program_mul_mv_f16_f32_l4; + cl_program program_mul_mv_f16_f32; + cl_program program_mul_mv_f32_f32; + cl_program program_mul; + cl_program program_norm; + cl_program program_relu; + cl_program program_rms_norm; + cl_program program_rope; + cl_program program_scale; + cl_program program_silu; + cl_program program_softmax_f32; + cl_program program_softmax_f16; + cl_program program_softmax_4_f32; + cl_program program_softmax_4_f16; cl_kernel kernel_add, kernel_add_row; cl_kernel kernel_mul, kernel_mul_row; @@ -249,19 +309,17 @@ struct ggml_backend_opencl_context { cl_kernel kernel_mul_mat_f16_f32; cl_kernel kernel_mul_mat_f16_f32_l4; cl_kernel kernel_mul_mat_q4_0_f32, kernel_mul_mat_q4_0_f32_v; - cl_kernel kernel_convert_block_q4_0, kernel_restore_block_q4_0, kernel_mul_mat_q4_0_f32_flat; + cl_kernel kernel_convert_block_q4_0, kernel_restore_block_q4_0; cl_kernel 
kernel_mul_mat_q4_0_f32_8x_flat; - cl_kernel kernel_convert_block_q4_0_noshuffle, kernel_mul_mat_q4_0_f32_flat_v0, - kernel_mul_mat_q4_0_f32_flat_img_v0; + cl_kernel kernel_convert_block_q4_0_noshuffle; cl_kernel kernel_mul_mat_q4_0_f32_1d_8x_flat, kernel_mul_mat_q4_0_f32_1d_16x_flat; cl_kernel kernel_mul_mv_q6_K_f32; cl_kernel kernel_im2col_f32, kernel_im2col_f16; #ifdef GGML_OPENCL_USE_ADRENO_KERNELS // Transpose kernels - cl_program program_transpose_32; - cl_program program_transpose_32_16; - cl_program program_transpose_16; + cl_program program_transpose; + cl_kernel kernel_transpose_32; cl_kernel kernel_transpose_32_16; cl_kernel kernel_transpose_16; @@ -374,6 +432,681 @@ static cl_program build_program_from_source(cl_context ctx, cl_device_id dev, co return p; } +static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_version opencl_c_version) { + cl_int err; + + // compiler options for general kernels + auto opencl_c_std = + std::string("CL") + std::to_string(opencl_c_version.major) + "." + std::to_string(opencl_c_version.minor); + std::string compile_opts = std::string("-cl-std=") + opencl_c_std + + " -cl-mad-enable -cl-unsafe-math-optimizations" + " -cl-finite-math-only -cl-fast-relaxed-math"; + + GGML_LOG_INFO("ggml_opencl: loading OpenCL kernels"); + + // add + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "add.cl.h" + }; +#else + const std::string kernel_src = read_file("add.cl"); +#endif + backend_ctx->program_add = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_add = clCreateKernel(backend_ctx->program_add, "kernel_add", &err), err)); + CL_CHECK((backend_ctx->kernel_add_row = clCreateKernel(backend_ctx->program_add, "kernel_add_row", &err), err)); + GGML_LOG_CONT("."); + } + + // clamp + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "clamp.cl.h" + }; +#else + const std::string kernel_src = read_file("clamp.cl"); +#endif + backend_ctx->program_clamp = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_clamp = clCreateKernel(backend_ctx->program_clamp, "kernel_clamp", &err), err)); + GGML_LOG_CONT("."); + } + + // cpy + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "cpy.cl.h" + }; +#else + const std::string kernel_src = read_file("cpy.cl"); +#endif + backend_ctx->program_cpy = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_cpy_f16_f16 = clCreateKernel(backend_ctx->program_cpy, "kernel_cpy_f16_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_cpy_f16_f32 = clCreateKernel(backend_ctx->program_cpy, "kernel_cpy_f16_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_cpy_f32_f16 = clCreateKernel(backend_ctx->program_cpy, "kernel_cpy_f32_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_cpy_f32_f32 = clCreateKernel(backend_ctx->program_cpy, "kernel_cpy_f32_f32", &err), err)); + GGML_LOG_CONT("."); + } + + // cvt + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "cvt.cl.h" + }; +#else + const std::string kernel_src = read_file("cvt.cl"); +#endif + backend_ctx->program_cvt = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_convert_block_q4_0_noshuffle = 
clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_0_noshuffle", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_q4_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_0", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q4_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_0", &err), err)); + GGML_LOG_CONT("."); + } + + // diag_mask_inf + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "diag_mask_inf.cl.h" + }; +#else + const std::string kernel_src = read_file("diag_mask_inf.cl"); +#endif + backend_ctx->program_diag_mask_inf = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_diag_mask_inf_8 = clCreateKernel(backend_ctx->program_diag_mask_inf, "kernel_diag_mask_inf_8", &err), err)); + CL_CHECK((backend_ctx->kernel_diag_mask_inf = clCreateKernel(backend_ctx->program_diag_mask_inf, "kernel_diag_mask_inf", &err), err)); + GGML_LOG_CONT("."); + } + + // gelu + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gelu.cl.h" + }; +#else + const std::string kernel_src = read_file("gelu.cl"); +#endif + backend_ctx->program_gelu = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_gelu = clCreateKernel(backend_ctx->program_gelu, "kernel_gelu", &err), err)); + CL_CHECK((backend_ctx->kernel_gelu_4 = clCreateKernel(backend_ctx->program_gelu, "kernel_gelu_4", &err), err)); + CL_CHECK((backend_ctx->kernel_gelu_quick = clCreateKernel(backend_ctx->program_gelu, "kernel_gelu_quick", &err), err)); + CL_CHECK((backend_ctx->kernel_gelu_quick_4 = clCreateKernel(backend_ctx->program_gelu, "kernel_gelu_quick_4", &err), err)); + GGML_LOG_CONT("."); + } + + // get_rows + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "get_rows.cl.h" + }; +#else + const std::string kernel_src = read_file("get_rows.cl"); +#endif + backend_ctx->program_get_rows = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_get_rows_f32 = clCreateKernel(backend_ctx->program_get_rows, "kernel_get_rows_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_get_rows_f16 = clCreateKernel(backend_ctx->program_get_rows, "kernel_get_rows_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_get_rows_q4_0 = clCreateKernel(backend_ctx->program_get_rows, "kernel_get_rows_q4_0", &err), err)); + GGML_LOG_CONT("."); + } + + // im2col_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "im2col_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("im2col_f32.cl"); +#endif + backend_ctx->program_im2col_f32 = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_im2col_f32 = clCreateKernel(backend_ctx->program_im2col_f32, "kernel_im2col_f32", &err), err)); + GGML_LOG_CONT("."); + } + + // im2col_f16 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "im2col_f16.cl.h" + }; +#else + const std::string kernel_src = read_file("im2col_f16.cl"); +#endif + backend_ctx->program_im2col_f16 = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_im2col_f16 = 
clCreateKernel(backend_ctx->program_im2col_f16, "kernel_im2col_f16", &err), err)); + GGML_LOG_CONT("."); + } + + // mul_mv_q4_0_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_q4_0_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_q4_0_f32.cl"); +#endif + backend_ctx->program_mul_mv_q4_0_f32 = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32 = clCreateKernel(backend_ctx->program_mul_mv_q4_0_f32, "kernel_mul_mat_q4_0_f32", &err), err)); + GGML_LOG_CONT("."); + } + + // mul_mv_q4_0_f32_v + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_q4_0_f32_v.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_q4_0_f32_v.cl"); +#endif + backend_ctx->program_mul_mv_q4_0_f32_v = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32_v = clCreateKernel(backend_ctx->program_mul_mv_q4_0_f32_v, "kernel_mul_mat_q4_0_f32_v", &err), err)); + GGML_LOG_CONT("."); + } + + // mul_mv_q4_0_f32_8x_flat + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_q4_0_f32_8x_flat.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_q4_0_f32_8x_flat.cl"); +#endif + backend_ctx->program_mul_mv_q4_0_f32_8x_flat = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32_8x_flat = clCreateKernel(backend_ctx->program_mul_mv_q4_0_f32_8x_flat, "kernel_mul_mat_q4_0_f32_8x_flat", &err), err)); + GGML_LOG_CONT("."); + } + + // mul_mv_q4_0_f32_1d_8x_flat + // This kernel does not compile on Adreno cl compiler 38.01. Skip it for + // those compiler versions since it is not used for Adreno anyway. + if (backend_ctx->gpu_family != ADRENO || + backend_ctx->adreno_cl_compiler_version.newer_than_or_same(E031, 38, 11, 0) || + backend_ctx->adreno_cl_compiler_version.type == DX) { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_q4_0_f32_1d_8x_flat.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_q4_0_f32_1d_8x_flat.cl"); +#endif + backend_ctx->program_mul_mv_q4_0_f32_1d_8x_flat = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32_1d_8x_flat = clCreateKernel(backend_ctx->program_mul_mv_q4_0_f32_1d_8x_flat, "kernel_mul_mat_q4_0_f32_1d_8x_flat", &err), err)); + GGML_LOG_CONT("."); + } + + // mul_mv_q4_0_f32_1d_16x_flat + // This kernel does not compile on Adreno cl compiler 38.01. Skip it for + // those compiler versions since it is not used for Adreno anyway.
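// For illustration, with a hypothetical driver string containing "E031.38.01.00":
// get_adreno_cl_compiler_version() above parses it to { E031, 38, 1, 0 }, and the helper
// methods compare versions as major*10000 + minor*100 + patch, so
// newer_than_or_same(E031, 38, 11, 0) is false (380100 < 381100) and the guarded build
// below is skipped. Non-Adreno devices and the DX compiler family always take the build path.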
+ if (backend_ctx->gpu_family != ADRENO || + backend_ctx->adreno_cl_compiler_version.newer_than_or_same(E031, 38, 11, 0) || + backend_ctx->adreno_cl_compiler_version.type == DX) { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_q4_0_f32_1d_16x_flat.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_q4_0_f32_1d_16x_flat.cl"); +#endif + backend_ctx->program_mul_mv_q4_0_f32_1d_16x_flat = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32_1d_16x_flat = clCreateKernel(backend_ctx->program_mul_mv_q4_0_f32_1d_16x_flat, "kernel_mul_mat_q4_0_f32_1d_16x_flat", &err), err)); + GGML_LOG_CONT("."); + } + + // mul_mv_q6_k + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_q6_k.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_q6_k.cl"); +#endif + backend_ctx->program_mul_mv_q6_K = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mv_q6_K_f32 = clCreateKernel(backend_ctx->program_mul_mv_q6_K, "kernel_mul_mv_q6_K_f32", &err), err)); + GGML_LOG_CONT("."); + } + + // mul_mv_f16_f16 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_f16_f16.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_f16_f16.cl"); +#endif + backend_ctx->program_mul_mv_f16_f16 = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mat_f16_f16 = clCreateKernel(backend_ctx->program_mul_mv_f16_f16, "kernel_mul_mat_f16_f16", &err), err)); + GGML_LOG_CONT("."); + } + + // mul_mv_f16_f32_1row + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_f16_f32_1row.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_f16_f32_1row.cl"); +#endif + backend_ctx->program_mul_mv_f16_f32_1row = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mat_f16_f32_1row = clCreateKernel(backend_ctx->program_mul_mv_f16_f32_1row, "kernel_mul_mat_f16_f32_1row", &err), err)); + GGML_LOG_CONT("."); + } + + // mul_mv_f16_f32_l4 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_f16_f32_l4.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_f16_f32_l4.cl"); +#endif + backend_ctx->program_mul_mv_f16_f32_l4 = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mat_f16_f32_l4 = clCreateKernel(backend_ctx->program_mul_mv_f16_f32_l4, "kernel_mul_mat_f16_f32_l4", &err), err)); + GGML_LOG_CONT("."); + } + + // mul_mv_f16_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_f16_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_f16_f32.cl"); +#endif + backend_ctx->program_mul_mv_f16_f32 = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mat_f16_f32 = clCreateKernel(backend_ctx->program_mul_mv_f16_f32, "kernel_mul_mat_f16_f32", &err), err)); + GGML_LOG_CONT("."); + } + + // mul_mv_f32_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include 
"mul_mv_f32_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_f32_f32.cl"); +#endif + backend_ctx->program_mul_mv_f32_f32 = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mat_f32_f32 = clCreateKernel(backend_ctx->program_mul_mv_f32_f32, "kernel_mul_mat_f32_f32", &err), err)); + GGML_LOG_CONT("."); + } + + // mul + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul.cl.h" + }; +#else + const std::string kernel_src = read_file("mul.cl"); +#endif + backend_ctx->program_mul = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul = clCreateKernel(backend_ctx->program_mul, "kernel_mul", &err), err)); + CL_CHECK((backend_ctx->kernel_mul_row = clCreateKernel(backend_ctx->program_mul, "kernel_mul_row", &err), err)); + GGML_LOG_CONT("."); + } + + // norm + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "norm.cl.h" + }; +#else + const std::string kernel_src = read_file("norm.cl"); +#endif + backend_ctx->program_norm = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_norm = clCreateKernel(backend_ctx->program_norm, "kernel_norm", &err), err)); + GGML_LOG_CONT("."); + } + + // relu + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "relu.cl.h" + }; +#else + const std::string kernel_src = read_file("relu.cl"); +#endif + backend_ctx->program_relu = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_relu = clCreateKernel(backend_ctx->program_relu, "kernel_relu", &err), err)); + GGML_LOG_CONT("."); + } + + // rms_norm + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "rms_norm.cl.h" + }; +#else + const std::string kernel_src = read_file("rms_norm.cl"); +#endif + backend_ctx->program_rms_norm = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_rms_norm = clCreateKernel(backend_ctx->program_rms_norm, "kernel_rms_norm", &err), err)); + GGML_LOG_CONT("."); + } + + // rope + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "rope.cl.h" + }; +#else + const std::string kernel_src = read_file("rope.cl"); +#endif + backend_ctx->program_rope = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_rope_norm_f32 = clCreateKernel(backend_ctx->program_rope, "kernel_rope_norm_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_rope_norm_f16 = clCreateKernel(backend_ctx->program_rope, "kernel_rope_norm_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_rope_neox_f32 = clCreateKernel(backend_ctx->program_rope, "kernel_rope_neox_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_rope_neox_f16 = clCreateKernel(backend_ctx->program_rope, "kernel_rope_neox_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_rope_multi_f32 = clCreateKernel(backend_ctx->program_rope, "kernel_rope_multi_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_rope_multi_f16 = clCreateKernel(backend_ctx->program_rope, "kernel_rope_multi_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_rope_vision_f32 = 
clCreateKernel(backend_ctx->program_rope, "kernel_rope_vision_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_rope_vision_f16 = clCreateKernel(backend_ctx->program_rope, "kernel_rope_vision_f16", &err), err)); + GGML_LOG_CONT("."); + } + + // scale + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "scale.cl.h" + }; +#else + const std::string kernel_src = read_file("scale.cl"); +#endif + backend_ctx->program_scale = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_scale = clCreateKernel(backend_ctx->program_scale, "kernel_scale", &err), err)); + GGML_LOG_CONT("."); + } + + // silu + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "silu.cl.h" + }; +#else + const std::string kernel_src = read_file("silu.cl"); +#endif + backend_ctx->program_silu = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_silu = clCreateKernel(backend_ctx->program_silu, "kernel_silu", &err), err)); + CL_CHECK((backend_ctx->kernel_silu_4 = clCreateKernel(backend_ctx->program_silu, "kernel_silu_4", &err), err)); + GGML_LOG_CONT("."); + } + + // softmax_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "softmax_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("softmax_f32.cl"); +#endif + backend_ctx->program_softmax_f32 = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_soft_max = clCreateKernel(backend_ctx->program_softmax_f32, "kernel_soft_max", &err), err)); + GGML_LOG_CONT("."); + } + + // softmax_f16 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "softmax_f16.cl.h" + }; +#else + const std::string kernel_src = read_file("softmax_f16.cl"); +#endif + backend_ctx->program_softmax_f16 = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_soft_max_f16 = clCreateKernel(backend_ctx->program_softmax_f16, "kernel_soft_max_f16", &err), err)); + GGML_LOG_CONT("."); + } + + // softmax_4_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "softmax_4_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("softmax_4_f32.cl"); +#endif + backend_ctx->program_softmax_4_f32 = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_soft_max_4 = clCreateKernel(backend_ctx->program_softmax_4_f32, "kernel_soft_max_4", &err), err)); + GGML_LOG_CONT("."); + } + + // softmax_4_f16 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "softmax_4_f16.cl.h" + }; +#else + const std::string kernel_src = read_file("softmax_4_f16.cl"); +#endif + backend_ctx->program_softmax_4_f16 = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_soft_max_4_f16 = clCreateKernel(backend_ctx->program_softmax_4_f16, "kernel_soft_max_4_f16", &err), err)); + GGML_LOG_CONT("."); + } + + // Adreno kernels +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + // transpose + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "transpose.cl.h" + }; +#else + const std::string kernel_src = read_file("transpose.cl"); 
+#endif + backend_ctx->program_transpose = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_transpose_32_16 = clCreateKernel(backend_ctx->program_transpose, "kernel_transpose_32_16", &err), err)); + CL_CHECK((backend_ctx->kernel_transpose_32 = clCreateKernel(backend_ctx->program_transpose, "kernel_transpose_32", &err), err)); + CL_CHECK((backend_ctx->kernel_transpose_16 = clCreateKernel(backend_ctx->program_transpose, "kernel_transpose_16", &err), err)); + GGML_LOG_CONT("."); + } + + // gemv_noshuffle_general + { + std::string CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + + " -cl-mad-enable " + " -DSIMDGROUP_WIDTH=" + + std::to_string(backend_ctx->adreno_wave_size); + if (backend_ctx->has_vector_subgroup_broadcast) { + CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT "; + } + +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src_CL_gemv_general { + #include "gemv_noshuffle_general.cl.h" + }; +#else + const std::string kernel_src_CL_gemv_general = read_file("gemv_noshuffle_general.cl"); +#endif + + backend_ctx->program_CL_gemv_general = build_program_from_source( + backend_ctx->context, backend_ctx->device, kernel_src_CL_gemv_general.c_str(), CL_gemv_compile_opts); + + CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_general = clCreateKernel(backend_ctx->program_CL_gemv_general, "kernel_gemv_noshuffle", &err), err)); + GGML_LOG_CONT("."); + } + + // gemv_noshuffle + { + // Gemv 2048, 16384 + std::string CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + + " -cl-mad-enable " + " -DLINE_STRIDE_A=2048 " + " -DBLOCK_STRIDE_A=16384 " + " -DSIMDGROUP_WIDTH=" + + std::to_string(backend_ctx->adreno_wave_size); + if (backend_ctx->has_vector_subgroup_broadcast) { + CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT "; + } + +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src_CL_gemv { + #include "gemv_noshuffle.cl.h" + }; +#else + const std::string kernel_src_CL_gemv = read_file("gemv_noshuffle.cl"); +#endif + + backend_ctx->program_CL_gemv_4096_1_4096 = build_program_from_source( + backend_ctx->context, backend_ctx->device, kernel_src_CL_gemv.c_str(), CL_gemv_compile_opts); + CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_4096_1_4096 = clCreateKernel(backend_ctx->program_CL_gemv_4096_1_4096, "kernel_gemv_noshuffle", &err), err)); + GGML_LOG_CONT("."); + + // Gemv 2048, 16384 + CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + + " -cl-mad-enable " + " -DLINE_STRIDE_A=2048 " + " -DBLOCK_STRIDE_A=16384 " + " -DSIMDGROUP_WIDTH=" + + std::to_string(backend_ctx->adreno_wave_size); + if (backend_ctx->has_vector_subgroup_broadcast) { + CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT "; + } + + backend_ctx->program_CL_gemv_4096_1_11008 = build_program_from_source( + backend_ctx->context, backend_ctx->device, kernel_src_CL_gemv.c_str(), CL_gemv_compile_opts); + CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_4096_1_11008 = clCreateKernel(backend_ctx->program_CL_gemv_4096_1_11008, "kernel_gemv_noshuffle", &err), err)); + GGML_LOG_CONT("."); + + // Gemv 5504, 44032 + CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + + " -cl-mad-enable " + " -DLINE_STRIDE_A=5504 " + " -DBLOCK_STRIDE_A=44032 " + " -DSIMDGROUP_WIDTH=" + + std::to_string(backend_ctx->adreno_wave_size); + if (backend_ctx->has_vector_subgroup_broadcast) { + CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT "; + } + + 
backend_ctx->program_CL_gemv_11008_1_4096 = build_program_from_source( + backend_ctx->context, backend_ctx->device, kernel_src_CL_gemv.c_str(), CL_gemv_compile_opts); + CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_11008_1_4096 = clCreateKernel(backend_ctx->program_CL_gemv_11008_1_4096, "kernel_gemv_noshuffle", &err), err)); + GGML_LOG_CONT("."); + + // Gemv 16000, 128000 + CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + + " -cl-mad-enable " + " -DLINE_STRIDE_A=16000 " + " -DBLOCK_STRIDE_A=128000 " + " -DSIMDGROUP_WIDTH=" + + std::to_string(backend_ctx->adreno_wave_size); + + if (backend_ctx->has_vector_subgroup_broadcast) { + CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT "; + } + + backend_ctx->program_CL_gemv_32000_1_4096 = build_program_from_source( + backend_ctx->context, backend_ctx->device, kernel_src_CL_gemv.c_str(), CL_gemv_compile_opts); + CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_32000_1_4096 = clCreateKernel(backend_ctx->program_CL_gemv_32000_1_4096, "kernel_gemv_noshuffle", &err), err)); + GGML_LOG_CONT("."); + } + + // mul_mat_Ab_Bi_8x4 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src_CL_gemm { + #include "mul_mat_Ab_Bi_8x4.cl.h" + }; +#else + const std::string kernel_src_CL_gemm = read_file("mul_mat_Ab_Bi_8x4.cl"); +#endif + backend_ctx->program_CL_gemm = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_CL_gemm.c_str(), compile_opts); + CL_CHECK((backend_ctx->CL_mul_mat_Ab_Bi_8x4 = clCreateKernel(backend_ctx->program_CL_gemm, "kernel_mul_mat_Ab_Bi_8x4", &err), err)); + GGML_LOG_CONT("."); + } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + GGML_LOG_CONT("\n"); +} + static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { static bool initialized = false; static ggml_backend_opencl_context *backend_ctx = nullptr; @@ -415,6 +1148,7 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { unsigned number; cl_device_type type; char name[128]; + char version[128]; }; enum { NPLAT = 16, NDEV = 16 }; @@ -455,6 +1189,7 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { d->platform = p; CL_CHECK(clGetDeviceInfo(d->id, CL_DEVICE_NAME, sizeof(d->name), &d->name, NULL)); CL_CHECK(clGetDeviceInfo(d->id, CL_DEVICE_TYPE, sizeof(d->type), &d->type, NULL)); + CL_CHECK(clGetDeviceInfo(d->id, CL_DEVICE_VERSION, sizeof(d->version), &d->version, NULL)); if (p->default_device == NULL && d->type == CL_DEVICE_TYPE_GPU) { p->default_device = d; @@ -547,7 +1282,7 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { } GGML_LOG_INFO("ggml_opencl: selecting platform: '%s'\n", default_device->platform->name); - GGML_LOG_INFO("ggml_opencl: selecting device: '%s'\n", default_device->name); + GGML_LOG_INFO("ggml_opencl: selecting device: '%s (%s)'\n", default_device->name, default_device->version); if (default_device->type != CL_DEVICE_TYPE_GPU) { GGML_LOG_WARN("ggml_opencl: warning, not a GPU: '%s'.\n", default_device->name); } @@ -556,9 +1291,15 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { dev_ctx->device = default_device->id; backend_ctx->device = default_device->id; - if (strstr(default_device->name, "Adreno")) { + if (strstr(default_device->name, "Adreno") || + strstr(default_device->name, "Qualcomm") || + strstr(default_device->version, "Adreno")) { backend_ctx->gpu_family = GPU_FAMILY::ADRENO; - backend_ctx->adreno_gen = get_adreno_gpu_gen(default_device->name); + // 
Usually device version contains the detailed device name + backend_ctx->adreno_gen = get_adreno_gpu_gen(default_device->version); + if (backend_ctx->adreno_gen == ADRENO_GPU_GEN::ADRENO_UNKNOWN) { + backend_ctx->adreno_gen = get_adreno_gpu_gen(default_device->name); + } // Use wave size of 64 for all Adreno GPUs. backend_ctx->adreno_wave_size = 64; @@ -604,11 +1345,12 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { GGML_LOG_INFO("ggml_opencl: OpenCL driver: %s\n", driver_version); backend_ctx->driver_version = driver_version; - int adreno_cl_compiler_version = get_adreno_cl_compiler_version(driver_version); - bool has_vector_subgroup_broadcast = - adreno_cl_compiler_version >= 47 || adreno_cl_compiler_version == 17; + backend_ctx->adreno_cl_compiler_version = get_adreno_cl_compiler_version(driver_version); + backend_ctx->has_vector_subgroup_broadcast = + backend_ctx->adreno_cl_compiler_version.major >= 47 || + backend_ctx->adreno_cl_compiler_version.major == 17; GGML_LOG_INFO("ggml_opencl: vector subgroup broadcast support: %s\n", - has_vector_subgroup_broadcast ? "true" : "false"); + backend_ctx->has_vector_subgroup_broadcast ? "true" : "false"); size_t ext_str_size; clGetDeviceInfo(device, CL_DEVICE_EXTENSIONS, 0, NULL, &ext_str_size); @@ -683,268 +1425,28 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { #endif CL_CHECK((backend_ctx->queue = clCreateCommandQueue(context, device, command_queue_props, &err), err)); -#ifdef GGML_OPENCL_EMBED_KERNELS - const std::string kernel_src { - #include "ggml-opencl.cl.h" - }; -#else - const std::string kernel_src = read_file("ggml-opencl.cl"); -#endif - - auto opencl_c_std = - std::string("CL") + std::to_string(opencl_c_version.major) + "." + std::to_string(opencl_c_version.minor); - - std::string compile_opts = std::string("-cl-std=") + opencl_c_std + - " -cl-mad-enable -cl-unsafe-math-optimizations" - " -cl-finite-math-only -cl-fast-relaxed-math"; - backend_ctx->program = build_program_from_source(context, device, kernel_src.c_str(), compile_opts); - - // Non matmul kernels. 
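// Note on the CL_CHECK((x = clCreateKernel(...), err)) pattern used throughout this file:
// the extra parentheses turn the argument into a single comma expression, so the kernel is
// created and assigned first and CL_CHECK then validates only the trailing `err`.
// An equivalent, de-sugared form of one such call (for illustration only):
//
//   backend_ctx->kernel_add = clCreateKernel(backend_ctx->program_add, "kernel_add", &err);
//   CL_CHECK(err);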
- CL_CHECK((backend_ctx->kernel_get_rows_f32 = clCreateKernel(backend_ctx->program, "kernel_get_rows_f32", &err), err)); - CL_CHECK((backend_ctx->kernel_get_rows_f16 = clCreateKernel(backend_ctx->program, "kernel_get_rows_f16", &err), err)); - CL_CHECK((backend_ctx->kernel_get_rows_q4_0 = clCreateKernel(backend_ctx->program, "kernel_get_rows_q4_0", &err), err)); - CL_CHECK((backend_ctx->kernel_add = clCreateKernel(backend_ctx->program, "kernel_add", &err), err)); - CL_CHECK((backend_ctx->kernel_add_row = clCreateKernel(backend_ctx->program, "kernel_add_row", &err), err)); - CL_CHECK((backend_ctx->kernel_mul = clCreateKernel(backend_ctx->program, "kernel_mul", &err), err)); - CL_CHECK((backend_ctx->kernel_mul_row = clCreateKernel(backend_ctx->program, "kernel_mul_row", &err), err)); - CL_CHECK((backend_ctx->kernel_scale = clCreateKernel(backend_ctx->program, "kernel_scale", &err), err)); - CL_CHECK((backend_ctx->kernel_silu = clCreateKernel(backend_ctx->program, "kernel_silu", &err), err)); - CL_CHECK((backend_ctx->kernel_silu_4 = clCreateKernel(backend_ctx->program, "kernel_silu_4", &err), err)); - CL_CHECK((backend_ctx->kernel_gelu = clCreateKernel(backend_ctx->program, "kernel_gelu", &err), err)); - CL_CHECK((backend_ctx->kernel_gelu_4 = clCreateKernel(backend_ctx->program, "kernel_gelu_4", &err), err)); - CL_CHECK((backend_ctx->kernel_gelu_quick = clCreateKernel(backend_ctx->program, "kernel_gelu_quick", &err), err)); - CL_CHECK((backend_ctx->kernel_gelu_quick_4 = clCreateKernel(backend_ctx->program, "kernel_gelu_quick_4", &err), err)); - CL_CHECK((backend_ctx->kernel_relu = clCreateKernel(backend_ctx->program, "kernel_relu", &err), err)); - CL_CHECK((backend_ctx->kernel_clamp = clCreateKernel(backend_ctx->program, "kernel_clamp", &err), err)); - CL_CHECK((backend_ctx->kernel_norm = clCreateKernel(backend_ctx->program, "kernel_norm", &err), err)); - CL_CHECK((backend_ctx->kernel_rms_norm = clCreateKernel(backend_ctx->program, "kernel_rms_norm", &err), err)); - CL_CHECK((backend_ctx->kernel_diag_mask_inf = clCreateKernel(backend_ctx->program, "kernel_diag_mask_inf", &err), err)); - CL_CHECK((backend_ctx->kernel_diag_mask_inf_8 = clCreateKernel(backend_ctx->program, "kernel_diag_mask_inf_8", &err), err)); - CL_CHECK((backend_ctx->kernel_soft_max = clCreateKernel(backend_ctx->program, "kernel_soft_max", &err), err)); - CL_CHECK((backend_ctx->kernel_soft_max_4 = clCreateKernel(backend_ctx->program, "kernel_soft_max_4", &err), err)); - CL_CHECK((backend_ctx->kernel_soft_max_f16 = clCreateKernel(backend_ctx->program, "kernel_soft_max_f16", &err), err)); - CL_CHECK((backend_ctx->kernel_soft_max_4_f16 = clCreateKernel(backend_ctx->program, "kernel_soft_max_4_f16", &err), err)); - CL_CHECK((backend_ctx->kernel_rope_norm_f32 = clCreateKernel(backend_ctx->program, "kernel_rope_norm_f32", &err), err)); - CL_CHECK((backend_ctx->kernel_rope_norm_f16 = clCreateKernel(backend_ctx->program, "kernel_rope_norm_f16", &err), err)); - CL_CHECK((backend_ctx->kernel_rope_neox_f32 = clCreateKernel(backend_ctx->program, "kernel_rope_neox_f32", &err), err)); - CL_CHECK((backend_ctx->kernel_rope_neox_f16 = clCreateKernel(backend_ctx->program, "kernel_rope_neox_f16", &err), err)); - CL_CHECK((backend_ctx->kernel_rope_multi_f32 = clCreateKernel(backend_ctx->program, "kernel_rope_multi_f32", &err), err)); - CL_CHECK((backend_ctx->kernel_rope_multi_f16 = clCreateKernel(backend_ctx->program, "kernel_rope_multi_f16", &err), err)); - CL_CHECK((backend_ctx->kernel_rope_vision_f32 = clCreateKernel(backend_ctx->program, 
"kernel_rope_vision_f32", &err), err)); - CL_CHECK((backend_ctx->kernel_rope_vision_f16 = clCreateKernel(backend_ctx->program, "kernel_rope_vision_f16", &err), err)); - CL_CHECK((backend_ctx->kernel_cpy_f16_f16 = clCreateKernel(backend_ctx->program, "kernel_cpy_f16_f16", &err), err)); - CL_CHECK((backend_ctx->kernel_cpy_f16_f32 = clCreateKernel(backend_ctx->program, "kernel_cpy_f16_f32", &err), err)); - CL_CHECK((backend_ctx->kernel_cpy_f32_f16 = clCreateKernel(backend_ctx->program, "kernel_cpy_f32_f16", &err), err)); - CL_CHECK((backend_ctx->kernel_cpy_f32_f32 = clCreateKernel(backend_ctx->program, "kernel_cpy_f32_f32", &err), err)); - - // Matmul kernels. - CL_CHECK((backend_ctx->kernel_mul_mat_f32_f32 = clCreateKernel(backend_ctx->program, "kernel_mul_mat_f32_f32", &err), err)); - CL_CHECK((backend_ctx->kernel_mul_mat_f16_f16 = clCreateKernel(backend_ctx->program, "kernel_mul_mat_f16_f16", &err), err)); - CL_CHECK((backend_ctx->kernel_mul_mat_f16_f32_1row = clCreateKernel(backend_ctx->program, "kernel_mul_mat_f16_f32_1row", &err), err)); - CL_CHECK((backend_ctx->kernel_mul_mat_f16_f32 = clCreateKernel(backend_ctx->program, "kernel_mul_mat_f16_f32", &err), err)); - CL_CHECK((backend_ctx->kernel_mul_mat_f16_f32_l4 = clCreateKernel(backend_ctx->program, "kernel_mul_mat_f16_f32_l4", &err), err)); - CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32 = clCreateKernel(backend_ctx->program, "kernel_mul_mat_q4_0_f32", &err), err)); - CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32_v = clCreateKernel(backend_ctx->program, "kernel_mul_mat_q4_0_f32_v", &err), err)); - - CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32_flat = clCreateKernel(backend_ctx->program, "kernel_mul_mat_q4_0_f32_flat", &err), err)); - CL_CHECK((backend_ctx->kernel_convert_block_q4_0 = clCreateKernel(backend_ctx->program, "kernel_convert_block_q4_0", &err), err)); - CL_CHECK((backend_ctx->kernel_restore_block_q4_0 = clCreateKernel(backend_ctx->program, "kernel_restore_block_q4_0", &err), err)); - CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32_8x_flat = clCreateKernel(backend_ctx->program, "kernel_mul_mat_q4_0_f32_8x_flat", &err), err)); - - // Load additional mulmat kernels. -#ifdef GGML_OPENCL_EMBED_KERNELS - const std::string kernel_src_1 { - #include "ggml-opencl_mm.cl.h" - }; -#else - const std::string kernel_src_1 = read_file("ggml-opencl_mm.cl"); -#endif - backend_ctx->program_1 = build_program_from_source(context, device, kernel_src_1.c_str(), compile_opts); - - CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32_1d_8x_flat = clCreateKernel(backend_ctx->program_1, "kernel_mul_mat_q4_0_f32_1d_8x_flat", &err), err)); - CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32_1d_16x_flat = clCreateKernel(backend_ctx->program_1, "kernel_mul_mat_q4_0_f32_1d_16x_flat", &err), err)); - CL_CHECK((backend_ctx->kernel_mul_mv_q6_K_f32 = clCreateKernel(backend_ctx->program_1, "kernel_mul_mv_q6_K_f32", &err), err)); - CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32_flat_v0 = clCreateKernel(backend_ctx->program_1, "kernel_mul_mat_q4_0_f32_flat_v0", &err), err)); - CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32_flat_img_v0 = clCreateKernel(backend_ctx->program_1, "kernel_mul_mat_q4_0_f32_flat_img_v0", &err), err)); - - // Load additional data conversion kernels. 
-#ifdef GGML_OPENCL_EMBED_KERNELS - const std::string kernel_src_2 { - #include "ggml-opencl_cvt.cl.h" - }; -#else - const std::string kernel_src_2 = read_file("ggml-opencl_cvt.cl"); -#endif - backend_ctx->program_2 = build_program_from_source(context, device, kernel_src_2.c_str(), compile_opts); - - CL_CHECK((backend_ctx->kernel_convert_block_q4_0_noshuffle = clCreateKernel(backend_ctx->program_2, "kernel_convert_block_q4_0_noshuffle", &err), err)); + // Load kernels + load_cl_kernels(backend_ctx, opencl_c_version); - // im2col kernels -#ifdef GGML_OPENCL_EMBED_KERNELS - const std::string kernel_src_im2col { - #include "ggml-opencl_im2col.cl.h" - }; -#else - const std::string kernel_src_im2col = read_file("ggml-opencl_im2col.cl"); -#endif - backend_ctx->program_im2col = build_program_from_source(context, device, kernel_src_im2col.c_str(), compile_opts); - - CL_CHECK((backend_ctx->kernel_im2col_f32 = clCreateKernel(backend_ctx->program_im2col, "kernel_im2col_f32", &err), err)); - CL_CHECK((backend_ctx->kernel_im2col_f16 = clCreateKernel(backend_ctx->program_im2col, "kernel_im2col_f16", &err), err)); - - // Kernels for Adreno #ifdef GGML_OPENCL_USE_ADRENO_KERNELS -#ifdef GGML_OPENCL_EMBED_KERNELS - const std::string transpose_32_src { - #include "ggml-opencl_transpose_32.cl.h" - }; -#else - const std::string transpose_32_src = read_file("ggml-opencl_transpose_32.cl"); -#endif - backend_ctx->program_transpose_32 = build_program_from_source(context, device, transpose_32_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_transpose_32 = clCreateKernel(backend_ctx->program_transpose_32, "kernel_transpose_32", &err), err)); - -#ifdef GGML_OPENCL_EMBED_KERNELS - const std::string transpose_32_16_src { - #include "ggml-opencl_transpose_32_16.cl.h" - }; -#else - const std::string transpose_32_16_src = read_file("ggml-opencl_transpose_32_16.cl"); -#endif - backend_ctx->program_transpose_32_16 = build_program_from_source(context, device, transpose_32_16_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_transpose_32_16 = clCreateKernel(backend_ctx->program_transpose_32_16, "kernel_transpose_32_16", &err), err)); - -#ifdef GGML_OPENCL_EMBED_KERNELS - const std::string transpose_16_src { - #include "ggml-opencl_transpose_16.cl.h" - }; -#else - const std::string transpose_16_src = read_file("ggml-opencl_transpose_16.cl"); -#endif - backend_ctx->program_transpose_16 = build_program_from_source(context, device, transpose_16_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_transpose_16 = clCreateKernel(backend_ctx->program_transpose_16, "kernel_transpose_16", &err), err)); - - // Gemv general - std::string CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + - " -cl-mad-enable " - " -DSIMDGROUP_WIDTH=" + - std::to_string(backend_ctx->adreno_wave_size); - if (has_vector_subgroup_broadcast) { - CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT "; - } -#ifdef GGML_OPENCL_EMBED_KERNELS - const std::string kernel_src_CL_gemv_general { - #include "ggml-opencl_gemv_noshuffle_general.cl.h" - }; -#else - const std::string kernel_src_CL_gemv_general = read_file("ggml-opencl_gemv_noshuffle_general.cl"); -#endif - - backend_ctx->program_CL_gemv_general = build_program_from_source( - context, device, kernel_src_CL_gemv_general.c_str(), CL_gemv_compile_opts); - CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_general = clCreateKernel(backend_ctx->program_CL_gemv_general, "kernel_gemv_noshuffle", &err), err)); - - // Gemv 2048, 16384 - CL_gemv_compile_opts = 
std::string("-cl-std=") + opencl_c_std + - " -cl-mad-enable " - " -DLINE_STRIDE_A=2048 " - " -DBLOCK_STRIDE_A=16384 " - " -DSIMDGROUP_WIDTH=" + - std::to_string(backend_ctx->adreno_wave_size); - if (has_vector_subgroup_broadcast) { - CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT "; - } -#ifdef GGML_OPENCL_EMBED_KERNELS - const std::string kernel_src_CL_gemv { - #include "ggml-opencl_gemv_noshuffle.cl.h" - }; -#else - const std::string kernel_src_CL_gemv = read_file("ggml-opencl_gemv_noshuffle.cl"); -#endif - - backend_ctx->program_CL_gemv_4096_1_4096 = build_program_from_source( - context, device, kernel_src_CL_gemv.c_str(), CL_gemv_compile_opts); - CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_4096_1_4096 = clCreateKernel(backend_ctx->program_CL_gemv_4096_1_4096, "kernel_gemv_noshuffle", &err), err)); - - // Gemv 2048, 16384 - CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + - " -cl-mad-enable " - " -DLINE_STRIDE_A=2048 " - " -DBLOCK_STRIDE_A=16384 " - " -DSIMDGROUP_WIDTH=" + - std::to_string(backend_ctx->adreno_wave_size); - if (has_vector_subgroup_broadcast) { - CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT "; - } - - backend_ctx->program_CL_gemv_4096_1_11008 = build_program_from_source( - context, device, kernel_src_CL_gemv.c_str(), CL_gemv_compile_opts); - CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_4096_1_11008 = clCreateKernel(backend_ctx->program_CL_gemv_4096_1_11008, "kernel_gemv_noshuffle", &err), err)); - - // Gemv 5504, 44032 - CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + - " -cl-mad-enable " - " -DLINE_STRIDE_A=5504 " - " -DBLOCK_STRIDE_A=44032 " - " -DSIMDGROUP_WIDTH=" + - std::to_string(backend_ctx->adreno_wave_size); - if (has_vector_subgroup_broadcast) { - CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT "; - } - - backend_ctx->program_CL_gemv_11008_1_4096 = build_program_from_source( - context, device, kernel_src_CL_gemv.c_str(), CL_gemv_compile_opts); - CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_11008_1_4096 = clCreateKernel(backend_ctx->program_CL_gemv_11008_1_4096, "kernel_gemv_noshuffle", &err), err)); - - // Gemv 16000, 128000 - CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + - " -cl-mad-enable " - " -DLINE_STRIDE_A=16000 " - " -DBLOCK_STRIDE_A=128000 " - " -DSIMDGROUP_WIDTH=" + - std::to_string(backend_ctx->adreno_wave_size); - if (has_vector_subgroup_broadcast) { - CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT "; - } - - backend_ctx->program_CL_gemv_32000_1_4096 = build_program_from_source(context, device, kernel_src_CL_gemv.c_str(), CL_gemv_compile_opts); - CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_32000_1_4096 = clCreateKernel(backend_ctx->program_CL_gemv_32000_1_4096, "kernel_gemv_noshuffle", &err), err)); - - // Gemm -#ifdef GGML_OPENCL_EMBED_KERNELS - const std::string kernel_src_CL_gemm { - #include "ggml-opencl_mul_mat_Ab_Bi_8x4.cl.h" - }; -#else - const std::string kernel_src_CL_gemm = read_file("ggml-opencl_mul_mat_Ab_Bi_8x4.cl"); -#endif - backend_ctx->program_CL_gemm = build_program_from_source(context, device, kernel_src_CL_gemm.c_str(), compile_opts); - CL_CHECK((backend_ctx->CL_mul_mat_Ab_Bi_8x4 = clCreateKernel(backend_ctx->program_CL_gemm, "kernel_mul_mat_Ab_Bi_8x4", &err), err)); - - // TODO: fixme: these sizes are hardcoded for now. 
- // they should be allocated based on the model's size - // and the device's max alloc size - size_t max_alloc_size; - CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_MAX_MEM_ALLOC_SIZE, sizeof(size_t), &max_alloc_size, NULL)); - // Allocate intermediate buffers and images size_t required_A_q_d_bytes = 311164928; size_t required_A_s_d_bytes = 38895616; size_t required_B_d_bytes = 45088768; // Ensure buffer sizes do not exceed the maximum allocation size - size_t max_A_q_d_bytes = MIN(required_A_q_d_bytes, max_alloc_size); - size_t max_A_s_d_bytes = MIN(required_A_s_d_bytes, max_alloc_size); - size_t max_B_d_bytes = MIN(required_B_d_bytes, max_alloc_size); - if (required_A_q_d_bytes > max_alloc_size) { + size_t max_A_q_d_bytes = MIN(required_A_q_d_bytes, backend_ctx->max_alloc_size); + size_t max_A_s_d_bytes = MIN(required_A_s_d_bytes, backend_ctx->max_alloc_size); + size_t max_B_d_bytes = MIN(required_B_d_bytes, backend_ctx->max_alloc_size); + if (required_A_q_d_bytes > backend_ctx->max_alloc_size) { GGML_LOG_WARN("ggml_opencl: A_q_d buffer size reduced from %zu to %zu due to device limitations.\n", required_A_q_d_bytes, max_A_q_d_bytes); } - if (required_A_s_d_bytes > max_alloc_size) { + if (required_A_s_d_bytes > backend_ctx->max_alloc_size) { GGML_LOG_WARN("ggml_opencl: A_s_d buffer size reduced from %zu to %zu due to device limitations.\n", required_A_s_d_bytes, max_A_s_d_bytes); } - if (required_B_d_bytes > max_alloc_size) { + if (required_B_d_bytes > backend_ctx->max_alloc_size) { GGML_LOG_WARN("ggml_opencl: B_d buffer size reduced from %zu to %zu due to device limitations.\n", required_B_d_bytes, max_B_d_bytes); } @@ -1019,7 +1521,7 @@ static void ggml_cl2_free(void) { info.cmd_complete_duration_ns/1.e6f, info.cmd_total_duration_ns/1.e6f, info.global_size[0], info.global_size[1], info.global_size[2], - info.local_size[0], info.local_size[2], info.local_size[2], + info.local_size[0], info.local_size[1], info.local_size[2], info.output_size[0], info.output_size[1], info.output_size[2], info.output_size[3]); } fclose(fperf); @@ -1490,8 +1992,15 @@ static enum ggml_status ggml_backend_opencl_buffer_init_tensor(ggml_backend_buff // The optimized gemm and gemv kernels are used for large matrices without batch. // tensor is the quantized weights matrix. -inline bool use_adreno_kernels(const ggml_tensor *tensor) { - return tensor->ne[0] >= 512 && tensor->ne[1] >= 512 && +inline bool use_adreno_kernels(const ggml_backend_opencl_context *backend_ctx, const ggml_tensor *tensor) { + int64_t threshold_ne0 = 512; + int64_t threshold_ne1 = 512; + if (!backend_ctx->adreno_cl_compiler_version.newer_than_or_same(E031, 38, 11, 0) && + backend_ctx->adreno_cl_compiler_version.type != DX) { + threshold_ne0 = 128; + threshold_ne1 = 128; + } + return tensor->ne[0] >= threshold_ne0 && tensor->ne[1] >= threshold_ne1 && tensor->ne[2] == 1 && tensor->ne[3] == 1; } @@ -1569,7 +2078,7 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, cl_kernel kernel = backend_ctx->kernel_convert_block_q4_0; // The optimized kernels need weights in natural order, so unshuffle. 
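For readers following the backend changes above: the reworked use_adreno_kernels() heuristic only routes large, non-batched 2D weight matrices to the optimized Adreno GEMM/GEMV path, and it lowers the size cutoff from 512 to 128 when the Adreno OpenCL compiler does not report at least version E031.38.11.0 and is not of the DX type. Below is a minimal host-side sketch of that gate; the struct and field names are illustrative stand-ins rather than the backend's real types, and the actual call-site updates continue in the hunks that follow.

    #include <cstdint>
    #include <cstdio>

    struct compiler_version_info {   // hypothetical stand-in for adreno_cl_compiler_version
        bool newer_than_or_same_38_11_0;
        bool is_dx;
    };

    // Mirrors the gate in the hunk above: 2D only, no batch, and a lower size
    // cutoff (128 instead of 512) for older, non-DX Adreno compilers.
    static bool use_adreno_kernels_sketch(const compiler_version_info & ver, const int64_t ne[4]) {
        int64_t threshold = 512;
        if (!ver.newer_than_or_same_38_11_0 && !ver.is_dx) {
            threshold = 128;
        }
        return ne[0] >= threshold && ne[1] >= threshold && ne[2] == 1 && ne[3] == 1;
    }

    int main() {
        const compiler_version_info old_compiler = { false, false };
        const int64_t ne[4] = { 256, 256, 1, 1 };
        std::printf("optimized Adreno path: %s\n", use_adreno_kernels_sketch(old_compiler, ne) ? "yes" : "no");
        return 0;
    }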
- if (use_adreno_kernels(tensor)) { + if (use_adreno_kernels(backend_ctx, tensor)) { kernel = backend_ctx->kernel_convert_block_q4_0_noshuffle; } #else @@ -1593,7 +2102,7 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, #ifdef GGML_OPENCL_USE_ADRENO_KERNELS // Only do transpose for large, non batched matrix // TODO: use preallocated images instead of sub-buffer then image - if (use_adreno_kernels(tensor)) { + if (use_adreno_kernels(backend_ctx, tensor)) { // <----------------------------------------------------------------------------------> // // start transpose // <----------------------------------------------------------------------------------> // @@ -2894,8 +3403,8 @@ static void ggml_cl_rms_norm(ggml_backend_t backend, const ggml_tensor * src0, c ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; cl_command_queue queue = backend_ctx->queue; - ggml_backend_opencl_device_context * dev_ctx = - (ggml_backend_opencl_device_context *)backend->device->context; + //ggml_backend_opencl_device_context * dev_ctx = + // (ggml_backend_opencl_device_context *)backend->device->context; ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; @@ -2926,13 +3435,20 @@ static void ggml_cl_rms_norm(ggml_backend_t backend, const ggml_tensor * src0, c // Note, this kernel declares local memory in kernel args and the size // depends on subgroup size. - // Retrieve subgroup size. // Note, this requires OpenCL 2.1 and above + // For now we use fixed subgroup size to simplify support for OpenCL 2.0. size_t sgs; - CL_CHECK(clGetKernelSubGroupInfo(kernel, dev_ctx->device, - CL_KERNEL_MAX_SUB_GROUP_SIZE_FOR_NDRANGE, - sizeof(local_work_size), local_work_size, - sizeof(size_t), &sgs, NULL)); + //CL_CHECK(clGetKernelSubGroupInfo(kernel, dev_ctx->device, + // CL_KERNEL_MAX_SUB_GROUP_SIZE_FOR_NDRANGE, + // sizeof(local_work_size), local_work_size, + // sizeof(size_t), &sgs, NULL)); + if (backend_ctx->gpu_family == ADRENO) { + sgs = 64; + } else if (backend_ctx->gpu_family == INTEL) { + sgs = 32; + } else { + GGML_ASSERT(false && "Unsupported GPU"); + } CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); @@ -3025,7 +3541,7 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co #ifdef GGML_OPENCL_USE_ADRENO_KERNELS cl_context context = backend_ctx->context; - if (ne01 && ne1 && use_adreno_kernels(src0)) { + if (ne01 && ne1 && use_adreno_kernels(backend_ctx, src0)) { // init CL objects // <--------------------------------------------> // diff --git a/ggml/src/ggml-opencl/kernels/add.cl b/ggml/src/ggml-opencl/kernels/add.cl new file mode 100644 index 00000000000..f73f3c01343 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/add.cl @@ -0,0 +1,83 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +//------------------------------------------------------------------------------ +// add +//------------------------------------------------------------------------------ + +// general-purpose kernel for addition of two tensors +// pros: works for non-contiguous tensors, supports broadcast across dims 1, 2 and 3 +// cons: not very efficient +kernel void kernel_add( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global char * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong 
nb02, + ulong nb03, + int ne10, + int ne11, + int ne12, + int ne13, + ulong nb10, + ulong nb11, + ulong nb12, + ulong nb13, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 +) { + src0 = src0 + offset0; + src1 = src1 + offset1; + dst = dst + offsetd; + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + int i13 = i03 % ne13; + int i12 = i02 % ne12; + int i11 = i01 % ne11; + + global char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01; + global char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11; + global char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1; + + for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) { + const int i10 = i0 % ne10; + *((global float *)(dst_ptr + i0*nb0)) = *((global float *)(src0_ptr + i0*nb00)) + *((global float *)(src1_ptr + i10*nb10)); + } +} + +// assumption: src1 is a row +// broadcast src1 into src0 +kernel void kernel_add_row( + global float4 * src0, + ulong offset0, + global float4 * src1, + ulong offset1, + global float4 * dst, + ulong offsetd, + int ne +) { + src0 = (global float4*)((global char*)src0 + offset0); + src1 = (global float4*)((global char*)src1 + offset1); + dst = (global float4*)((global char*)dst + offsetd); + + // This performs better than using %. + uint gid = get_global_id(0); + uint idx1 = gid - (gid/ne)*ne; // get_global_id(0) % ne + dst[gid] = src0[gid] + src1[idx1]; +} diff --git a/ggml/src/ggml-opencl/kernels/clamp.cl b/ggml/src/ggml-opencl/kernels/clamp.cl new file mode 100644 index 00000000000..ae6032444e8 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/clamp.cl @@ -0,0 +1,20 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +//------------------------------------------------------------------------------ +// clamp +//------------------------------------------------------------------------------ +kernel void kernel_clamp( + global float * src0, + ulong offset0, + global float * dst, + ulong offsetd, + float min, + float max +) { + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + dst[get_global_id(0)] = src0[get_global_id(0)] < min ? + min : + (src0[get_global_id(0)] > max ? 
max : src0[get_global_id(0)]); +} diff --git a/ggml/src/ggml-opencl/kernels/cpy.cl b/ggml/src/ggml-opencl/kernels/cpy.cl new file mode 100644 index 00000000000..9369351a60c --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/cpy.cl @@ -0,0 +1,184 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +//------------------------------------------------------------------------------ +// cpy +//------------------------------------------------------------------------------ + +kernel void kernel_cpy_f16_f16( + global half * src0, + ulong offset0, + global half * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 +) { + src0 = (global half*)((global char*)src0 + offset0); + dst = (global half*)((global char*)dst + offsetd); + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + int n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + int i3 = n / (ne2*ne1*ne0); + int i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + int i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + int i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); + + global half * dst_data = (global half *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + global const half * src = (global half *)((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + dst_data[i00] = src[0]; + } +} + +kernel void kernel_cpy_f16_f32( + global half * src0, + ulong offset0, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 +) { + + src0 = (global half*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + int n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + int i3 = n / (ne2*ne1*ne0); + int i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + int i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + int i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); + + global float * dst_data = (global float *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + global half * src = (global half *)((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + dst_data[i00] = src[0]; + } +} + +kernel void kernel_cpy_f32_f16( + global float * src0, + ulong offset0, + global half * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 +) { + src0 = (global float*)((global char*)src0 + offset0); + dst = (global half*)((global char*)dst + offsetd); + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + int n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + int i3 = n / (ne2*ne1*ne0); + int i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + int i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + int i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); + + global half * dst_data = (global half *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int i00 = get_local_id(0); i00 < ne00; i00 += 
get_local_size(0)) { + global const float * src = (global float *)((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + + dst_data[i00] = src[0]; + } +} + +kernel void kernel_cpy_f32_f32( + global float * src0, + ulong offset0, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 +) { + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + int n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + int i3 = n / (ne2*ne1*ne0); + int i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + int i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + int i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); + + global float * dst_data = (global float *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + global const float * src = (global float *)((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + + dst_data[i00] = src[0]; + } +} diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl_cvt.cl b/ggml/src/ggml-opencl/kernels/cvt.cl similarity index 69% rename from ggml/src/ggml-opencl/kernels/ggml-opencl_cvt.cl rename to ggml/src/ggml-opencl/kernels/cvt.cl index e2024332f81..fe7975e3dbf 100644 --- a/ggml/src/ggml-opencl/kernels/ggml-opencl_cvt.cl +++ b/ggml/src/ggml-opencl/kernels/cvt.cl @@ -1,39 +1,20 @@ //------------------------------------------------------------------------------ -// This file is contains additional kernels for data conversion. +// This file is contains kernels for data conversion. // These kernels are used when loading the model, so its performance is less // important. //------------------------------------------------------------------------------ -#ifdef cl_khr_fp16 #pragma OPENCL EXTENSION cl_khr_fp16 : enable -#elif defined(cl_amd_fp16) -#pragma OPENCL EXTENSION cl_amd_fp16 : enable -#else -#error "Half precision floating point not supportedby OpenCL implementation on your device." -#endif - -#ifdef cl_khr_subgroups -#pragma OPENCL EXTENSION cl_khr_subgroups : enable -#elif defined(cl_intel_subgroups) -#pragma OPENCL EXTENSION cl_intel_subgroups : enable -#else -#error "Subgroup not supported on your device." -#endif #ifdef cl_intel_required_subgroup_size -// Always use subgroup size of 32 on Intel. #pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable #define INTEL_GPU 1 #define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) #define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) #elif defined(cl_qcom_reqd_sub_group_size) -// Always use subgroups size of 64 on Adreno. #pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable #define ADRENO_GPU 1 #define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) #define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) -#else -// TODO: do not know how to choose subgroup size on other GPUs. -#error "Selecting subgroup size is not supported on your device." 
#endif #define QK4_0 32 @@ -66,13 +47,44 @@ struct block_q4_0 }; //------------------------------------------------------------------------------ -// mul_vec_q_n_f32_flat_noshuffle -// -// This variation uses flat arrays (struct of arrays, SOA) representation for -// quant tensors. It also uses non shuffled bit order for weights. -// -// The shuffled version is kept in the original file because moving it here -// seems to result in worse performance for adreno. +// kernel_convert_block_q4_0 +// Convert the block_q4_0 format to 2 separate arrays (AOS -> SOA). +// This kernel does not deshuffle the bits. +//------------------------------------------------------------------------------ +kernel void kernel_convert_block_q4_0( + global struct block_q4_0 * src0, + global uchar * dst_q, + global half * dst_d +) { + global struct block_q4_0 * b = (global struct block_q4_0 *) src0 + get_global_id(0); + global uchar * q = (global uchar *) dst_q + QK4_0/2*get_global_id(0); + global half * d = (global half *) dst_d + get_global_id(0); + + *d = b->d; + + for (int i = 0; i < QK4_0/2; ++i) { + q[i] = b->qs[i]; + } +} + +kernel void kernel_restore_block_q4_0( + global uchar * src_q, + global half * src_d, + global struct block_q4_0 * dst +) { + global struct block_q4_0 * b = (global struct block_q4_0 *) dst + get_global_id(0); + global uchar * q = (global uchar *) src_q + QK4_0/2*get_global_id(0); + global half * d = (global half *) src_d + get_global_id(0); + + b->d = *d; + for (int i = 0; i < QK4_0/2; ++i) { + b->qs[i] = q[i]; + } +} + +//------------------------------------------------------------------------------ +// kernel_convert_block_q4_0_noshuffle +// Flatten q4_0 weights and unshuffle the bits //------------------------------------------------------------------------------ kernel void kernel_convert_block_q4_0_noshuffle( diff --git a/ggml/src/ggml-opencl/kernels/diag_mask_inf.cl b/ggml/src/ggml-opencl/kernels/diag_mask_inf.cl new file mode 100644 index 00000000000..36eff0439fa --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/diag_mask_inf.cl @@ -0,0 +1,58 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +//------------------------------------------------------------------------------ +// diag_mask_inf kernels +//------------------------------------------------------------------------------ +kernel void kernel_diag_mask_inf( + global float * src0, + ulong offset0, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int n_past +) { + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + int i02 = get_global_id(2); + int i01 = get_global_id(1); + int i00 = get_global_id(0); + + if (i00 > n_past + i01) { + dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY; + } else { + dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00]; + } +} + +kernel void kernel_diag_mask_inf_8( + global float4 * src0, + ulong offset0, + global float4 * dst, + ulong offsetd, + int ne00, + int ne01, + int n_past +) { + src0 = (global float4*)((global char*)src0 + offset0); + dst = (global float4*)((global char*)dst + offsetd); + + int i = 2*get_global_id(0); + + dst[i+0] = src0[i+0]; + dst[i+1] = src0[i+1]; + int i4 = 4*i; + int i02 = i4/(ne00*ne01); i4 -= i02*ne00*ne01; + int i01 = i4/(ne00); i4 -= i01*ne00; + int i00 = i4; + for (int k = 3; k >= 0; --k) { + if (i00 + 4 + k <= n_past + i01) { + break; + } + (&dst[i+1])[k] = -INFINITY; + if (i00 + k > n_past + i01) { + (&dst[i])[k] = -INFINITY; + } + } +} diff --git 
a/ggml/src/ggml-opencl/kernels/gelu.cl b/ggml/src/ggml-opencl/kernels/gelu.cl new file mode 100644 index 00000000000..71c310cc9f9 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gelu.cl @@ -0,0 +1,62 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +//------------------------------------------------------------------------------ +// gelu +//------------------------------------------------------------------------------ +#define GELU_COEF_A 0.044715f +#define GELU_QUICK_COEF -1.702f +#define SQRT_2_OVER_PI 0.79788456080286535587989211986876f + +kernel void kernel_gelu( + global float * src0, + ulong offset0, + global float * dst, + ulong offsetd +) { + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + float x = src0[get_global_id(0)]; + + dst[get_global_id(0)] = 0.5f*x*(1.0f + tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); +} + +kernel void kernel_gelu_4( + global float4 * src0, + ulong offset0, + global float4 * dst, + ulong offsetd +) { + src0 = (global float4*)((global char*)src0 + offset0); + dst = (global float4*)((global char*)dst + offsetd); + + float4 x = src0[get_global_id(0)]; + + dst[get_global_id(0)] = 0.5f*x*(1.0f + tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); +} + +kernel void kernel_gelu_quick( + global float * src0, + ulong offset0, + global float * dst, + ulong offsetd +) { + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + float x = src0[get_global_id(0)]; + dst[get_global_id(0)] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x))); +} + +kernel void kernel_gelu_quick_4( + global float4 * src0, + ulong offset0, + global float4 * dst, + ulong offsetd +) { + src0 = (global float4*)((global char*)src0 + offset0); + dst = (global float4*)((global char*)dst + offsetd); + + float4 x = src0[get_global_id(0)]; + dst[get_global_id(0)] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x))); +} diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl_gemv_noshuffle.cl b/ggml/src/ggml-opencl/kernels/gemv_noshuffle.cl similarity index 100% rename from ggml/src/ggml-opencl/kernels/ggml-opencl_gemv_noshuffle.cl rename to ggml/src/ggml-opencl/kernels/gemv_noshuffle.cl diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl_gemv_noshuffle_general.cl b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general.cl similarity index 100% rename from ggml/src/ggml-opencl/kernels/ggml-opencl_gemv_noshuffle_general.cl rename to ggml/src/ggml-opencl/kernels/gemv_noshuffle_general.cl diff --git a/ggml/src/ggml-opencl/kernels/get_rows.cl b/ggml/src/ggml-opencl/kernels/get_rows.cl new file mode 100644 index 00000000000..b3fea2923df --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/get_rows.cl @@ -0,0 +1,163 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +typedef char int8_t; +typedef uchar uint8_t; +typedef short int16_t; +typedef ushort uint16_t; +typedef int int32_t; +typedef uint uint32_t; + +#define QK4_0 32 + +//------------------------------------------------------------------------------ +// block_q4_0 +//------------------------------------------------------------------------------ +struct block_q4_0 +{ + half d; + uint8_t qs[QK4_0 / 2]; +}; + + +//------------------------------------------------------------------------------ +// dequantize_q4_0_f32, dequantize_q4_0_f16 +//------------------------------------------------------------------------------ +void dequantize_q4_0_f32(global struct block_q4_0 * xb, short il, float16 * reg) { + global ushort * qs = ((global ushort *)xb + 1); + float d1 
= il ? (xb->d / 16.h) : xb->d; + float d2 = d1 / 256.f; + float md = -8.h * xb->d; + ushort mask0 = il ? 0x00F0 : 0x000F; + ushort mask1 = mask0 << 8; + + reg->s0 = d1 * (qs[0] & mask0) + md; + reg->s1 = d2 * (qs[0] & mask1) + md; + + reg->s2 = d1 * (qs[1] & mask0) + md; + reg->s3 = d2 * (qs[1] & mask1) + md; + + reg->s4 = d1 * (qs[2] & mask0) + md; + reg->s5 = d2 * (qs[2] & mask1) + md; + + reg->s6 = d1 * (qs[3] & mask0) + md; + reg->s7 = d2 * (qs[3] & mask1) + md; + + reg->s8 = d1 * (qs[4] & mask0) + md; + reg->s9 = d2 * (qs[4] & mask1) + md; + + reg->sa = d1 * (qs[5] & mask0) + md; + reg->sb = d2 * (qs[5] & mask1) + md; + + reg->sc = d1 * (qs[6] & mask0) + md; + reg->sd = d2 * (qs[6] & mask1) + md; + + reg->se = d1 * (qs[7] & mask0) + md; + reg->sf = d2 * (qs[7] & mask1) + md; +} + + +//------------------------------------------------------------------------------ +// get_rows +//------------------------------------------------------------------------------ +kernel void kernel_get_rows_f32( + global void * src0, + ulong offset0, + global int * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + ulong nb01, + ulong nb02, + int ne10, + ulong nb10, + ulong nb11, + ulong nb1, + ulong nb2 +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global int*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + int i10 = get_group_id(0); + int i11 = get_group_id(1); + + int r = ((global int *) ((global char *) src1 + i11*nb11 + i10*nb10))[0]; + + int i02 = i11; + + for (int ind = get_local_id(0); ind < ne00; ind += get_local_size(0)) { + ((global float *) ((global char *) dst + i11*nb2 + i10*nb1))[ind] = + ((global float *) ((global char *) src0 + r*nb01 + i02*nb02))[ind]; + } +} + +kernel void kernel_get_rows_f16( + global void * src0, + ulong offset0, + global int * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + ulong nb01, + ulong nb02, + int ne10, + ulong nb10, + ulong nb11, + ulong nb1, + ulong nb2 +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global int*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + int i10 = get_group_id(0); + int i11 = get_group_id(1); + + int r = ((global int32_t *) ((global char *) src1 + i11*nb11 + i10*nb10))[0]; + + int i02 = i11; + + for (int ind = get_local_id(0); ind < ne00; ind += get_local_size(0)) { + ((global float *) ((global char *) dst + i11*nb2 + i10*nb1))[ind] = + ((global half *) ((global char *) src0 + r*nb01 + i02*nb02))[ind]; + } +} + +kernel void kernel_get_rows_q4_0( + global void * src0, + ulong offset0, + global int * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + ulong nb01, + ulong nb02, + int ne10, + ulong nb10, + ulong nb11, + ulong nb1, + ulong nb2 +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global int*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + const int NL = 2; + + int i10 = get_group_id(0); + int i11 = get_group_id(1); + + int r = ((global int32_t *) ((global char *) src1 + i11*nb11 + i10*nb10))[0]; + + int i02 = i11; + + for (int ind = get_local_id(0); ind < ne00/16; ind += get_local_size(0)) { + float16 temp; + dequantize_q4_0_f32( + ((global struct block_q4_0 *) ((global char *) src0 + r*nb01 + i02*nb02)) + ind/NL, ind%NL, &temp); + *(((global float16 *) ((global char *) dst + i11*nb2 + i10*nb1)) + ind) = temp; + } +} diff --git 
a/ggml/src/ggml-opencl/kernels/ggml-opencl.cl b/ggml/src/ggml-opencl/kernels/ggml-opencl.cl deleted file mode 100644 index b8879288793..00000000000 --- a/ggml/src/ggml-opencl/kernels/ggml-opencl.cl +++ /dev/null @@ -1,3231 +0,0 @@ -#ifdef cl_khr_fp16 -#pragma OPENCL EXTENSION cl_khr_fp16 : enable -#elif defined(cl_amd_fp16) -#pragma OPENCL EXTENSION cl_amd_fp16 : enable -#else -#error "Half precision floating point not supportedby OpenCL implementation on your device." -#endif - -#ifdef cl_khr_subgroups -#pragma OPENCL EXTENSION cl_khr_subgroups : enable -#elif defined(cl_intel_subgroups) -#pragma OPENCL EXTENSION cl_intel_subgroups : enable -#else -#error "Subgroup not supported on your device." -#endif - -#ifdef cl_intel_required_subgroup_size -// Always use subgroup size of 32 on Intel. -#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable -#define INTEL_GPU 1 -#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) -#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) -#elif defined(cl_qcom_reqd_sub_group_size) -// Always use subgroups size of 64 on Adreno. -#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable -#define ADRENO_GPU 1 -#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) -#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) -#else -// TODO: do not know how to choose subgroup size on other GPUs. -#error "Selecting subgroup size is not supported on your device." -#endif - -#define QK4_0 32 -#define QR4_0 2 -#define QK4_1 32 -#define QR4_1 2 -#define QK5_0 32 -#define QR5_0 2 -#define QK5_1 32 -#define QR5_1 2 -#define QK8_0 32 -#define QR8_0 1 -#define QK_K 256 -#define K_QUANTS_PER_ITERATION 2 - -typedef char int8_t; -typedef uchar uint8_t; -typedef short int16_t; -typedef ushort uint16_t; -typedef int int32_t; -typedef uint uint32_t; - -//------------------------------------------------------------------------------ -// block_q4_0 -//------------------------------------------------------------------------------ -struct block_q4_0 -{ - half d; - uint8_t qs[QK4_0 / 2]; -}; - -//------------------------------------------------------------------------------ -// block_q4_1 -//------------------------------------------------------------------------------ -struct block_q4_1 -{ - half d; - half m; - uint8_t qs[QK4_1 / 2]; -}; - -//------------------------------------------------------------------------------ -// block_q5_0 -//------------------------------------------------------------------------------ -struct block_q5_0 -{ - half d; - uint32_t qh; - uint8_t qs[QK5_0 / 2]; -}; - -//------------------------------------------------------------------------------ -// block_q5_1 -//------------------------------------------------------------------------------ -struct block_q5_1 -{ - half d; - half m; - uint32_t qh; - uint8_t qs[QK5_1 / 2]; -}; - -//------------------------------------------------------------------------------ -// block_q8_0 -//------------------------------------------------------------------------------ -struct block_q8_0 -{ - half d; - int8_t qs[QK8_0]; -}; - -//------------------------------------------------------------------------------ -// block_q2_K -//------------------------------------------------------------------------------ -struct block_q2_K -{ - uint8_t scales[16]; - uint8_t qs[64]; - half d; - half dmin; -}; - -//------------------------------------------------------------------------------ -// block_q3_K 
-//------------------------------------------------------------------------------ -struct block_q3_K -{ - uint8_t hmask[32]; - uint8_t qs[64]; - uint8_t scales[12]; - half d; -}; - -//------------------------------------------------------------------------------ -// block_q4_K -//------------------------------------------------------------------------------ -struct block_q4_K -{ - half d; - half dmin; - uint8_t scales[12]; - uint8_t qs[128]; -}; - -//------------------------------------------------------------------------------ -// block_q5_K -//------------------------------------------------------------------------------ -struct block_q5_K -{ - half d; - half dmin; - uint8_t scales[12]; - uint8_t qh[32]; - uint8_t qs[128]; -}; - -//------------------------------------------------------------------------------ -// block_q6_K -//------------------------------------------------------------------------------ -struct block_q6_K -{ - uint8_t ql[128]; - uint8_t qh[64]; - int8_t scales[16]; - half d; -}; - -//------------------------------------------------------------------------------ -// dequantize_q4_0_f32, dequantize_q4_0_f16 -//------------------------------------------------------------------------------ -void dequantize_q4_0_f32(global struct block_q4_0 * xb, short il, float16 * reg) { - global ushort * qs = ((global ushort *)xb + 1); - float d1 = il ? (xb->d / 16.h) : xb->d; - float d2 = d1 / 256.f; - float md = -8.h * xb->d; - ushort mask0 = il ? 0x00F0 : 0x000F; - ushort mask1 = mask0 << 8; - - reg->s0 = d1 * (qs[0] & mask0) + md; - reg->s1 = d2 * (qs[0] & mask1) + md; - - reg->s2 = d1 * (qs[1] & mask0) + md; - reg->s3 = d2 * (qs[1] & mask1) + md; - - reg->s4 = d1 * (qs[2] & mask0) + md; - reg->s5 = d2 * (qs[2] & mask1) + md; - - reg->s6 = d1 * (qs[3] & mask0) + md; - reg->s7 = d2 * (qs[3] & mask1) + md; - - reg->s8 = d1 * (qs[4] & mask0) + md; - reg->s9 = d2 * (qs[4] & mask1) + md; - - reg->sa = d1 * (qs[5] & mask0) + md; - reg->sb = d2 * (qs[5] & mask1) + md; - - reg->sc = d1 * (qs[6] & mask0) + md; - reg->sd = d2 * (qs[6] & mask1) + md; - - reg->se = d1 * (qs[7] & mask0) + md; - reg->sf = d2 * (qs[7] & mask1) + md; -} - -void dequantize_q4_0_f16(global struct block_q4_0 * xb, short il, half16 * reg) { - global ushort * qs = ((global ushort *)xb + 1); - half d1 = il ? (xb->d / 16.h) : xb->d; - half d2 = d1 / 256.h; - half md = -8.h * xb->d; - ushort mask0 = il ? 
0x00F0 : 0x000F; - ushort mask1 = mask0 << 8; - - reg->s0 = d1 * (qs[0] & mask0) + md; - reg->s1 = d2 * (qs[0] & mask1) + md; - - reg->s2 = d1 * (qs[1] & mask0) + md; - reg->s3 = d2 * (qs[1] & mask1) + md; - - reg->s4 = d1 * (qs[2] & mask0) + md; - reg->s5 = d2 * (qs[2] & mask1) + md; - - reg->s6 = d1 * (qs[3] & mask0) + md; - reg->s7 = d2 * (qs[3] & mask1) + md; - - reg->s8 = d1 * (qs[4] & mask0) + md; - reg->s9 = d2 * (qs[4] & mask1) + md; - - reg->sa = d1 * (qs[5] & mask0) + md; - reg->sb = d2 * (qs[5] & mask1) + md; - - reg->sc = d1 * (qs[6] & mask0) + md; - reg->sd = d2 * (qs[6] & mask1) + md; - - reg->se = d1 * (qs[7] & mask0) + md; - reg->sf = d2 * (qs[7] & mask1) + md; -} - -//------------------------------------------------------------------------------ -// add -//------------------------------------------------------------------------------ - -// general-purpose kernel for addition of two tensors -// pros: works for non-contiguous tensors, supports broadcast across dims 1, 2 and 3 -// cons: not very efficient -kernel void kernel_add( - global char * src0, - ulong offset0, - global char * src1, - ulong offset1, - global char * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne03, - ulong nb00, - ulong nb01, - ulong nb02, - ulong nb03, - int ne10, - int ne11, - int ne12, - int ne13, - ulong nb10, - ulong nb11, - ulong nb12, - ulong nb13, - int ne0, - int ne1, - int ne2, - int ne3, - ulong nb0, - ulong nb1, - ulong nb2, - ulong nb3 -) { - src0 = src0 + offset0; - src1 = src1 + offset1; - dst = dst + offsetd; - - int i03 = get_group_id(2); - int i02 = get_group_id(1); - int i01 = get_group_id(0); - - int i13 = i03 % ne13; - int i12 = i02 % ne12; - int i11 = i01 % ne11; - - global char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01; - global char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11; - global char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1; - - for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) { - const int i10 = i0 % ne10; - *((global float *)(dst_ptr + i0*nb0)) = *((global float *)(src0_ptr + i0*nb00)) + *((global float *)(src1_ptr + i10*nb10)); - } -} - -// assumption: src1 is a row -// broadcast src1 into src0 -kernel void kernel_add_row( - global float4 * src0, - ulong offset0, - global float4 * src1, - ulong offset1, - global float4 * dst, - ulong offsetd, - int ne -) { - src0 = (global float4*)((global char*)src0 + offset0); - src1 = (global float4*)((global char*)src1 + offset1); - dst = (global float4*)((global char*)dst + offsetd); - - // This performs better than using %. 
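The "This performs better than using %." comment above, and its twin in the new add.cl added earlier in this diff, refer to rewriting the integer modulo gid % ne as gid - (gid/ne)*ne. The two forms are mathematically identical for unsigned integers with ne > 0; a tiny self-check of the identity, purely illustrative and not part of the patch:

    #include <cassert>
    #include <cstdio>

    int main() {
        const unsigned ne = 192;                          // row length, arbitrary example value
        for (unsigned gid = 0; gid < 100000; ++gid) {
            const unsigned idx1 = gid - (gid / ne) * ne;  // the kernels' formulation
            assert(idx1 == gid % ne);                     // identical result to plain modulo
        }
        std::printf("gid - (gid/ne)*ne == gid %% ne holds for the tested range\n");
        return 0;
    }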
- uint gid = get_global_id(0); - uint idx1 = gid - (gid/ne)*ne; // get_global_id(0) % ne - dst[gid] = src0[gid] + src1[idx1]; -} - -//------------------------------------------------------------------------------ -// mul -//------------------------------------------------------------------------------ -kernel void kernel_mul( - global char * src0, - ulong offset0, - global char * src1, - ulong offset1, - global char * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne03, - ulong nb00, - ulong nb01, - ulong nb02, - ulong nb03, - int ne10, - int ne11, - int ne12, - int ne13, - ulong nb10, - ulong nb11, - ulong nb12, - ulong nb13, - int ne0, - int ne1, - int ne2, - int ne3, - ulong nb0, - ulong nb1, - ulong nb2, - ulong nb3 -) { - src0 = src0 + offset0; - src1 = src1 + offset1; - dst = dst + offsetd; - - int i03 = get_group_id(2); - int i02 = get_group_id(1); - int i01 = get_group_id(0); - - int i13 = i03 % ne13; - int i12 = i02 % ne12; - int i11 = i01 % ne11; - - global char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01; - global char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11; - global char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1; - - for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) { - const int i10 = i0 % ne10; - *((global float *)(dst_ptr + i0*nb0)) = *((global float *)(src0_ptr + i0*nb00)) * *((global float *)(src1_ptr + i10*nb10)); - } -} - -// assumption: src1 is a row -// broadcast src1 into src0 -kernel void kernel_mul_row( - global float4 * src0, - ulong offset0, - global float4 * src1, - ulong offset1, - global float4 * dst, - ulong offsetd, - int ne -) { - src0 = (global float4*)((global char*)src0 + offset0); - src1 = (global float4*)((global char*)src1 + offset1); - dst = (global float4*)((global char*)dst + offsetd); - - // This performs better than using %. 
- uint gid = get_global_id(0); - uint idx1 = gid - (gid/ne)*ne; // get_global_id(0) % ne - dst[gid] = src0[gid] * src1[idx1]; -} - -//------------------------------------------------------------------------------ -// scale -//------------------------------------------------------------------------------ -kernel void kernel_scale( - global float4 * src0, - ulong offset0, - global float4 * dst, - ulong offsetd, - float scale -) { - src0 = (global float4*)((global char*)src0 + offset0); - dst = (global float4*)((global char*)dst + offsetd); - dst[get_global_id(0)] = src0[get_global_id(0)] * scale; -} - -//------------------------------------------------------------------------------ -// gelu -//------------------------------------------------------------------------------ -#define GELU_COEF_A 0.044715f -#define GELU_QUICK_COEF -1.702f -#define SQRT_2_OVER_PI 0.79788456080286535587989211986876f - -kernel void kernel_gelu( - global float * src0, - ulong offset0, - global float * dst, - ulong offsetd -) { - src0 = (global float*)((global char*)src0 + offset0); - dst = (global float*)((global char*)dst + offsetd); - - float x = src0[get_global_id(0)]; - - dst[get_global_id(0)] = 0.5f*x*(1.0f + tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); -} - -kernel void kernel_gelu_4( - global float4 * src0, - ulong offset0, - global float4 * dst, - ulong offsetd -) { - src0 = (global float4*)((global char*)src0 + offset0); - dst = (global float4*)((global char*)dst + offsetd); - - float4 x = src0[get_global_id(0)]; - - dst[get_global_id(0)] = 0.5f*x*(1.0f + tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); -} - -kernel void kernel_gelu_quick( - global float * src0, - ulong offset0, - global float * dst, - ulong offsetd -) { - src0 = (global float*)((global char*)src0 + offset0); - dst = (global float*)((global char*)dst + offsetd); - - float x = src0[get_global_id(0)]; - dst[get_global_id(0)] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x))); -} - -kernel void kernel_gelu_quick_4( - global float4 * src0, - ulong offset0, - global float4 * dst, - ulong offsetd -) { - src0 = (global float4*)((global char*)src0 + offset0); - dst = (global float4*)((global char*)dst + offsetd); - - float4 x = src0[get_global_id(0)]; - dst[get_global_id(0)] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x))); -} - -//------------------------------------------------------------------------------ -// silu -//------------------------------------------------------------------------------ -kernel void kernel_silu( - global float * src0, - ulong offset0, - global float * dst, - ulong offsetd -) { - src0 = (global float*)((global char*)src0 + offset0); - dst = (global float*)((global char*)dst + offsetd); - - float x = src0[get_global_id(0)]; - dst[get_global_id(0)] = x / (1.0f + exp(-x)); -} - -kernel void kernel_silu_4( - global float4 * src0, - ulong offset0, - global float4 * dst, - ulong offsetd -) { - src0 = (global float4*)((global char*)src0 + offset0); - dst = (global float4*)((global char*)dst + offsetd); - - float4 x = src0[get_global_id(0)]; - dst[get_global_id(0)] = x / (1.0f + exp(-x)); -} - -//------------------------------------------------------------------------------ -// relu -//------------------------------------------------------------------------------ -kernel void kernel_relu( - global float * src0, - ulong offset0, - global float * dst, - ulong offsetd -) { - src0 = (global float*)((global char*)src0 + offset0); - dst = (global float*)((global char*)dst + offsetd); - - dst[get_global_id(0)] = fmax(0.0f, src0[get_global_id(0)]); 
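For reference, the GELU kernels a few lines up (removed from the monolithic file here and also present in the new gelu.cl earlier in this diff) implement the tanh approximation y = 0.5*x*(1 + tanh(sqrt(2/pi)*x*(1 + 0.044715*x*x))) and the quicker sigmoid form x*sigmoid(1.702*x). A scalar sketch with the same constants, useful only for host-side spot checks and not part of the patch:

    #include <cmath>
    #include <cstdio>
    #include <initializer_list>

    static float gelu_ref(float x) {
        const float GELU_COEF_A    = 0.044715f;
        const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
        return 0.5f * x * (1.0f + std::tanh(SQRT_2_OVER_PI * x * (1.0f + GELU_COEF_A * x * x)));
    }

    static float gelu_quick_ref(float x) {
        const float GELU_QUICK_COEF = -1.702f;   // x * sigmoid(1.702 * x)
        return x * (1.0f / (1.0f + std::exp(GELU_QUICK_COEF * x)));
    }

    int main() {
        for (float x : { -2.0f, -0.5f, 0.0f, 0.5f, 2.0f }) {
            std::printf("x=% .2f  gelu=% .6f  gelu_quick=% .6f\n", x, gelu_ref(x), gelu_quick_ref(x));
        }
        return 0;
    }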
-} - -//------------------------------------------------------------------------------ -// clamp -//------------------------------------------------------------------------------ -kernel void kernel_clamp( - global float * src0, - ulong offset0, - global float * dst, - ulong offsetd, - float min, - float max -) { - src0 = (global float*)((global char*)src0 + offset0); - dst = (global float*)((global char*)dst + offsetd); - - dst[get_global_id(0)] = src0[get_global_id(0)] < min ? - min : - (src0[get_global_id(0)] > max ? max : src0[get_global_id(0)]); -} - -//------------------------------------------------------------------------------ -// norm -//------------------------------------------------------------------------------ -kernel void kernel_norm( - global void * src0, - ulong offset0, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne03, - ulong nb01, - ulong nb02, - ulong nb03, - float eps, - local float * sum -) { - src0 = (global void*)((global char*)src0 + offset0); - dst = (global void*)((global char*)dst + offsetd); - - int i03 = get_group_id(2); - int i02 = get_group_id(1); - int i01 = get_group_id(0); - - global float * x = (global float *) ((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01); - - // MEAN - // parallel sum - sum[get_local_id(0)] = 0.0f; - for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { - sum[get_local_id(0)] += x[i00]; - } - // reduce - barrier(CLK_LOCAL_MEM_FENCE); - for (uint i = get_local_size(0)/2; i > 0; i /= 2) { - if (get_local_id(0) < i) { - sum[get_local_id(0)] += sum[get_local_id(0) + i]; - } - barrier(CLK_LOCAL_MEM_FENCE); - } - float mean = sum[0] / ne00; - - // recenter and VARIANCE - barrier(CLK_LOCAL_MEM_FENCE); - global float * y = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - sum[get_local_id(0)] = 0.0f; - for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { - y[i00] = x[i00] - mean; - sum[get_local_id(0)] += y[i00] * y[i00]; - } - - // reduce - barrier(CLK_LOCAL_MEM_FENCE); - for (uint i = get_local_size(0)/2; i > 0; i /= 2) { - if (get_local_id(0) < i) { - sum[get_local_id(0)] += sum[get_local_id(0) + i]; - } - barrier(CLK_LOCAL_MEM_FENCE); - } - float variance = sum[0] / ne00; - - float scale = 1.0f/sqrt(variance + eps); - for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { - y[i00] = y[i00] * scale; - } -} - -//------------------------------------------------------------------------------ -// rms_norm -//------------------------------------------------------------------------------ -// This kernel depends on subgroup size. 
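Before the kernel_rms_norm definition below, whose subgroup handling is what the ggml-opencl.cpp hunk above replaces with fixed subgroup sizes (64 on Adreno, 32 on Intel), here is a scalar per-row reference of the computation, y[i] = x[i] / sqrt(mean(x^2) + eps). It can be handy for sanity-checking the GPU reduction; the helper name and the tiny driver are illustrative only:

    #include <cmath>
    #include <cstddef>
    #include <cstdio>
    #include <vector>

    // Per-row reference: y[i] = x[i] / sqrt(mean(x*x) + eps). The GPU kernel
    // computes the same value, but accumulates the sum of squares with
    // subgroup/workgroup reductions instead of a plain loop.
    static void rms_norm_row_ref(const float * x, float * y, size_t ne00, float eps) {
        float sum_sq = 0.0f;
        for (size_t i = 0; i < ne00; ++i) {
            sum_sq += x[i] * x[i];
        }
        const float scale = 1.0f / std::sqrt(sum_sq / (float) ne00 + eps);
        for (size_t i = 0; i < ne00; ++i) {
            y[i] = x[i] * scale;
        }
    }

    int main() {
        std::vector<float> x = { 1.0f, -2.0f, 3.0f, -4.0f };
        std::vector<float> y(x.size());
        rms_norm_row_ref(x.data(), y.data(), x.size(), 1e-6f);
        for (float v : y) std::printf("% .6f ", v);
        std::printf("\n");
        return 0;
    }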
-kernel void kernel_rms_norm( - global void * src0, - ulong offset0, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne03, - ulong nb01, - ulong nb02, - ulong nb03, - float eps, - local float * sum // Note, the size depends on number of subgroups -) { - src0 = (global void*)((global char*)src0 + offset0); - dst = (global float*)((global char*)dst + offsetd); - - int i03 = get_group_id(2); - int i02 = get_group_id(1); - int i01 = get_group_id(0); - - global float4 * x = (global float4 *) ((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01); - global float * x_scalar = (global float *) x; - float4 sumf = 0; - float all_sum = 0; - - // parallel sum - for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { - sumf += x[i00] * x[i00]; - } - all_sum = sumf.s0 + sumf.s1 + sumf.s2 + sumf.s3; - all_sum = sub_group_reduce_add(all_sum); - if (get_sub_group_local_id() == 0) { - sum[get_sub_group_id()] = all_sum; - } - - barrier(CLK_LOCAL_MEM_FENCE); - // broadcast - for (uint i = get_local_size(0) / get_max_sub_group_size() / 2; i > 0; i /= 2) { - if (get_local_id(0) < i) { - sum[get_local_id(0)] += sum[get_local_id(0) + i]; - } - } - if (get_local_id(0) == 0) { - for (int i = 4 * (ne00 / 4); i < ne00; i++) { - sum[0] += x_scalar[i]; - } - sum[0] /= ne00; - } - - barrier(CLK_LOCAL_MEM_FENCE); - - const float mean = sum[0]; - const float scale = 1.0f/sqrt(mean + eps); - - global float4 * y = (global float4 *) (dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); - global float * y_scalar = (global float *) y; - for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { - y[i00] = x[i00] * scale; - } - if (get_local_id(0) == 0) { - for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) { - y_scalar[i00] = x_scalar[i00] * scale; - } - } -} - -//------------------------------------------------------------------------------ -// diag_mask_inf kernels -//------------------------------------------------------------------------------ -kernel void kernel_diag_mask_inf( - global float * src0, - ulong offset0, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int n_past -) { - src0 = (global float*)((global char*)src0 + offset0); - dst = (global float*)((global char*)dst + offsetd); - - int i02 = get_global_id(2); - int i01 = get_global_id(1); - int i00 = get_global_id(0); - - if (i00 > n_past + i01) { - dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY; - } else { - dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00]; - } -} - -kernel void kernel_diag_mask_inf_8( - global float4 * src0, - ulong offset0, - global float4 * dst, - ulong offsetd, - int ne00, - int ne01, - int n_past -) { - src0 = (global float4*)((global char*)src0 + offset0); - dst = (global float4*)((global char*)dst + offsetd); - - int i = 2*get_global_id(0); - - dst[i+0] = src0[i+0]; - dst[i+1] = src0[i+1]; - int i4 = 4*i; - int i02 = i4/(ne00*ne01); i4 -= i02*ne00*ne01; - int i01 = i4/(ne00); i4 -= i01*ne00; - int i00 = i4; - for (int k = 3; k >= 0; --k) { - if (i00 + 4 + k <= n_past + i01) { - break; - } - (&dst[i+1])[k] = -INFINITY; - if (i00 + k > n_past + i01) { - (&dst[i])[k] = -INFINITY; - } - } -} - -//------------------------------------------------------------------------------ -// softmax -//------------------------------------------------------------------------------ -#ifdef ADRENO_GPU -REQD_SUBGROUP_SIZE_64 -#endif -kernel void kernel_soft_max( - global float * src0, - ulong offset0, - global float * src1, - ulong offset1, - global 
float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - float scale, - float max_bias, - float m0, - float m1, - int n_head_log2 -) { - src0 = (global float*)((global char*)src0 + offset0); - src1 = (global float*)((global char*)src1 + offset1); - dst = (global float*)((global char*)dst + offsetd); - - int i03 = get_group_id(2); - int i02 = get_group_id(1); - int i01 = get_group_id(0); - - global float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - global float * pmask = src1 != src0 ? src1 + i01*ne00 : 0; - global float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - - float slope = 1.0f; - - // ALiBi - if (max_bias > 0.0f) { - int h = i02; - - float base = h < n_head_log2 ? m0 : m1; - int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; - - slope = pow(base, exp); - } - - // parallel max - float lmax = -INFINITY; - for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { - lmax = fmax(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)); - } - float max = sub_group_reduce_max(lmax); - - // parallel sum - float lsum = 0.0f; - for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { - float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max); - lsum += exp_psrc0; - // Remember the result of exp here. exp is expensive, so we really do not - // wish to compute it twice. - pdst[i00] = exp_psrc0; - } - - const float sum = sub_group_reduce_add(lsum); - - for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { - pdst[i00] /= sum; - } -} - -#ifdef ADRENO_GPU -REQD_SUBGROUP_SIZE_64 -#endif -kernel void kernel_soft_max_4( - global float * src0, - ulong offset0, - global float * src1, - ulong offset1, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - float scale, - float max_bias, - float m0, - float m1, - int n_head_log2 -) { - src0 = (global float*)((global char*)src0 + offset0); - src1 = (global float*)((global char*)src1 + offset1); - dst = (global float*)((global char*)dst + offsetd); - - int i03 = get_group_id(2); - int i02 = get_group_id(1); - int i01 = get_group_id(0); - - global float4 * psrc4 = (global float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); - global float4 * pmask = src1 != src0 ? (global float4 *)(src1 + i01*ne00) : 0; - global float4 * pdst4 = (global float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); - - float slope = 1.0f; - - // ALiBi - if (max_bias > 0.0f) { - int h = i02; - - float base = h < n_head_log2 ? m0 : m1; - int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; - - slope = pow(base, exp); - } - - // parallel max - float4 lmax4 = -INFINITY; - for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { - lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)); - } - float lmax = fmax(fmax(lmax4.s0, lmax4.s1), fmax(lmax4.s2, lmax4.s3)); - - const float max = sub_group_reduce_max(lmax); - - // parallel sum - float4 lsum4 = 0.0f; - for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { - const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? 
slope*pmask[i00] : 0.0f)) - max); - lsum4 += exp_psrc4; - pdst4[i00] = exp_psrc4; - } - float lsum = lsum4.s0 + lsum4.s1 + lsum4.s2 + lsum4.s3; - - const float sum = sub_group_reduce_add(lsum); - - for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { - pdst4[i00] /= sum; - } -} - -#ifdef ADRENO_GPU -REQD_SUBGROUP_SIZE_64 -#endif -kernel void kernel_soft_max_f16( - global float * src0, - ulong offset0, - global half * src1, - ulong offset1, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - float scale, - float max_bias, - float m0, - float m1, - int n_head_log2 -) { - src0 = (global float *)((global char *)src0 + offset0); - src1 = (global half *)((global char *)src1 + offset1); - dst = (global float *)((global char *)dst + offsetd); - - int i03 = get_group_id(2); - int i02 = get_group_id(1); - int i01 = get_group_id(0); - - global float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - global half * pmask = (global char *)src1 != (global char *)src0 ? src1 + i01*ne00 : 0; - global float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - - float slope = 1.0f; - - // ALiBi - if (max_bias > 0.0f) { - int h = i02; - - float base = h < n_head_log2 ? m0 : m1; - int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; - - slope = pow(base, exp); - } - - // parallel max - float lmax = -INFINITY; - for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { - lmax = fmax(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)); - } - float max = sub_group_reduce_max(lmax); - - // parallel sum - float lsum = 0.0f; - for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { - float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max); - lsum += exp_psrc0; - // Remember the result of exp here. exp is expensive, so we really do not - // wish to compute it twice. - pdst[i00] = exp_psrc0; - } - - const float sum = sub_group_reduce_add(lsum); - - for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { - pdst[i00] /= sum; - } -} - -#ifdef ADRENO_GPU -REQD_SUBGROUP_SIZE_64 -#endif -kernel void kernel_soft_max_4_f16( - global float * src0, - ulong offset0, - global half * src1, - ulong offset1, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - float scale, - float max_bias, - float m0, - float m1, - int n_head_log2 -) { - src0 = (global float *)((global char *)src0 + offset0); - src1 = (global half *)((global char *)src1 + offset1); - dst = (global float *)((global char *)dst + offsetd); - - int i03 = get_group_id(2); - int i02 = get_group_id(1); - int i01 = get_group_id(0); - - global float4 * psrc4 = (global float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); - global half4 * pmask = (global char *)src1 != (global char *)src0 ? (global half4 *)(src1 + i01*ne00) : 0; - global float4 * pdst4 = (global float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); - - float slope = 1.0f; - - // ALiBi - if (max_bias > 0.0f) { - int h = i02; - - float base = h < n_head_log2 ? m0 : m1; - int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; - - slope = pow(base, exp); - } - - // parallel max - float4 lmax4 = -INFINITY; - for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { - lmax4 = fmax(lmax4, psrc4[i00]*scale + slope*(pmask ? 
convert_float4(pmask[i00]) : 0.0f)); - } - float lmax = fmax(fmax(lmax4.s0, lmax4.s1), fmax(lmax4.s2, lmax4.s3)); - - const float max = sub_group_reduce_max(lmax); - - // parallel sum - float4 lsum4 = 0.0f; - for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { - const float4 exp_psrc4 = exp((psrc4[i00]*scale + slope*(pmask ? convert_float4(pmask[i00]) : 0.0f)) - max); - lsum4 += exp_psrc4; - pdst4[i00] = exp_psrc4; - } - float lsum = lsum4.s0 + lsum4.s1 + lsum4.s2 + lsum4.s3; - - const float sum = sub_group_reduce_add(lsum); - - for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { - pdst4[i00] /= sum; - } -} - -//------------------------------------------------------------------------------ -// kernel_rope -//------------------------------------------------------------------------------ -float rope_yarn_ramp(float low, float high, int i0) { - const float y = (i0 / 2 - low) / max(0.001f, high - low); - return 1.0f - min(1.0f, max(0.0f, y)); -} - -// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn -// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng. -float2 rope_yarn( - float theta_extrap, float freq_scale, float2 corr_dims, int i0, float ext_factor, float mscale -) { - // Get n-d rotational scaling corrected for extrapolation - float theta_interp = freq_scale * theta_extrap; - float theta = theta_interp; - if (ext_factor != 0.0f) { - float ramp_mix = rope_yarn_ramp(corr_dims.s0, corr_dims.s1, i0) * ext_factor; - theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; - - // Get n-d magnitude scaling corrected for interpolation - mscale *= 1.0f + 0.1f * log(1.0f / freq_scale); - } - return (float2)(cos(theta) * mscale, sin(theta) * mscale); -} - -// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get -// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))` -float rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) { - return n_dims * log(n_ctx_orig / (n_rot * 2 * M_PI_F)) / (2 * log(base)); -} - -float2 rope_yarn_corr_dims( - int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow -) { - // start and end correction dims - return (float2)( - max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, freq_base))), - min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base))) - ); -} - -kernel void kernel_rope_norm_f32( - global void * src0, - ulong offset0, - global int * src1, - ulong offset1, - global float * src2, - ulong offset2, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne03, - ulong nb00, - ulong nb01, - ulong nb02, - ulong nb03, - int ne0, - int ne1, - int ne2, - int ne3, - ulong nb0, - ulong nb1, - ulong nb2, - ulong nb3, - int n_past, - int n_dims, - int n_ctx_orig, - float freq_base, - float freq_scale, - float ext_factor, - float attn_factor, - float beta_fast, - float beta_slow -) { - src0 = (global void*)((global char*)src0 + offset0); - src1 = (global int*)((global char*)src1 + offset1); - src2 = (global float*)((global char*)src2 + offset2); - dst = (global float*)((global char*)dst + offsetd); - - int i3 = get_group_id(2); - int i2 = get_group_id(1); - int i1 = get_group_id(0); - - float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); - - global int * pos = src1; - - float theta_base = (float) pos[i2]; - float inv_ndims = -1.f/n_dims; - - 
for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { - if (i0 < n_dims) { - int ic = i0/2; - - float theta = theta_base * pow(freq_base, inv_ndims*i0); - - float freq_factor = src2 != src0 ? src2[ic] : 1.0f; - - float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); - - global float * src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - float x0 = src[0]; - float x1 = src[1]; - - dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; - dst_data[1] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; - } else { - global float * src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - dst_data[0] = src[0]; - dst_data[1] = src[1]; - } - } -} - -kernel void kernel_rope_norm_f16( - global void * src0, - ulong offset0, - global int * src1, - ulong offset1, - global float * src2, - ulong offset2, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne03, - ulong nb00, - ulong nb01, - ulong nb02, - ulong nb03, - int ne0, - int ne1, - int ne2, - int ne3, - ulong nb0, - ulong nb1, - ulong nb2, - ulong nb3, - int n_past, - int n_dims, - int n_ctx_orig, - float freq_base, - float freq_scale, - float ext_factor, - float attn_factor, - float beta_fast, - float beta_slow -) { - src0 = (global void*)((global char*)src0 + offset0); - src1 = (global int*)((global char*)src1 + offset1); - src2 = (global float*)((global char*)src2 + offset2); - dst = (global float*)((global char*)dst + offsetd); - - int i3 = get_group_id(2); - int i2 = get_group_id(1); - int i1 = get_group_id(0); - - float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); - - global int * pos = src1; - - float theta_base = (float) pos[i2]; - float inv_ndims = -1.f/n_dims; - - for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { - if (i0 < n_dims) { - int ic = i0/2; - - float theta = theta_base * pow(freq_base, inv_ndims*i0); - - float freq_factor = src2 != src0 ? 
src2[ic] : 1.0f; - - float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); - - global half * src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - float x0 = src[0]; - float x1 = src[1]; - - dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; - dst_data[1] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; - } else { - global half * src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - dst_data[0] = src[0]; - dst_data[1] = src[1]; - } - } -} - -kernel void kernel_rope_neox_f32( - global void * src0, - ulong offset0, - global int * src1, - ulong offset1, - global float * src2, - ulong offset2, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne03, - ulong nb00, - ulong nb01, - ulong nb02, - ulong nb03, - int ne0, - int ne1, - int ne2, - int ne3, - ulong nb0, - ulong nb1, - ulong nb2, - ulong nb3, - int n_past, - int n_dims, - int n_ctx_orig, - float freq_base, - float freq_scale, - float ext_factor, - float attn_factor, - float beta_fast, - float beta_slow -) { - src0 = (global void*)((global char*)src0 + offset0); - src1 = (global int*)((global char*)src1 + offset1); - src2 = (global float*)((global char*)src2 + offset2); - dst = (global float*)((global char*)dst + offsetd); - - int i3 = get_group_id(2); - int i2 = get_group_id(1); - int i1 = get_group_id(0); - - float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); - - global int * pos = src1; - - float theta_base = (float) pos[i2]; - float inv_ndims = -1.f/n_dims; - - for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { - if (i0 < n_dims) { - int ic = i0/2; - - const float theta = theta_base * pow(freq_base, inv_ndims*i0); - - const float freq_factor = src2 != src0 ? 
src2[ic] : 1.0f; - - float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); - - global float * src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); - global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); - - const float x0 = src[0]; - const float x1 = src[n_dims/2]; - - dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; - dst_data[n_dims/2] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; - } else { - global float * const src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - dst_data[0] = src[0]; - dst_data[1] = src[1]; - } - } -} - -kernel void kernel_rope_neox_f16( - global void * src0, - ulong offset0, - global int * src1, - ulong offset1, - global float * src2, - ulong offset2, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne03, - ulong nb00, - ulong nb01, - ulong nb02, - ulong nb03, - int ne0, - int ne1, - int ne2, - int ne3, - ulong nb0, - ulong nb1, - ulong nb2, - ulong nb3, - int n_past, - int n_dims, - int n_ctx_orig, - float freq_base, - float freq_scale, - float ext_factor, - float attn_factor, - float beta_fast, - float beta_slow -) { - src0 = (global void*)((global char*)src0 + offset0); - src1 = (global int*)((global char*)src1 + offset1); - src2 = (global float*)((global char*)src2 + offset2); - dst = (global float*)((global char*)dst + offsetd); - - int i3 = get_group_id(2); - int i2 = get_group_id(1); - int i1 = get_group_id(0); - - float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); - - global int * pos = src1; - - float theta_base = (float) pos[i2]; - float inv_ndims = -1.f/n_dims; - - for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { - if (i0 < n_dims) { - int ic = i0/2; - - const float theta = theta_base * pow(freq_base, inv_ndims*i0); - - const float freq_factor = src2 != src0 ? 
src2[ic] : 1.0f; - - float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); - - global half * src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); - global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); - - const float x0 = src[0]; - const float x1 = src[n_dims/2]; - - dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; - dst_data[n_dims/2] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; - } else { - global half * const src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - dst_data[0] = src[0]; - dst_data[1] = src[1]; - } - } -} - -kernel void kernel_rope_multi_f32( - global void * src0, - ulong offset0, - global int * src1, - ulong offset1, - global float * src2, - ulong offset2, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne03, - ulong nb00, - ulong nb01, - ulong nb02, - ulong nb03, - int ne0, - int ne1, - int ne2, - int ne3, - ulong nb0, - ulong nb1, - ulong nb2, - ulong nb3, - int n_past, - int n_dims, - int n_ctx_orig, - float freq_base, - float freq_scale, - float ext_factor, - float attn_factor, - float beta_fast, - float beta_slow, - int4 sections -) { - src0 = (global void*)((global char*)src0 + offset0); - src1 = (global int*)((global char*)src1 + offset1); - src2 = (global float*)((global char*)src2 + offset2); - dst = (global float*)((global char*)dst + offsetd); - - int i3 = get_group_id(2); - int i2 = get_group_id(1); - int i1 = get_group_id(0); - - float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); - - global int * pos = src1; - - const int sect_dims = sections.s0 + sections.s1 + sections.s2 + sections.s3; - const int sec_w = sections.s1 + sections.s0; - - float inv_ndims = -1.f/n_dims; - - for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { - if (i0 < n_dims) { - int ic = i0/2; - - const int sector = (i0 / 2) % sect_dims; - float theta_base = 0.0f; - - if (sector < sections.s0) { - theta_base = pos[i2]; - } - else if (sector >= sections.s0 && sector < sec_w) { - theta_base = pos[i2 + ne2 * 1]; - } - else if (sector >= sec_w && sector < sec_w + sections.s2) { - theta_base = pos[i2 + ne2 * 2]; - } - else if (sector >= sec_w + sections.s2) { - theta_base = pos[i2 + ne2 * 3]; - } - - const float theta = theta_base * pow(freq_base, inv_ndims*i0); - - const float freq_factor = src2 != src0 ? 
src2[ic] : 1.0f; - - float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); - - global float * src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); - global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); - - const float x0 = src[0]; - const float x1 = src[n_dims/2]; - - dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; - dst_data[n_dims/2] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; - } else { - global float * const src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - dst_data[0] = src[0]; - dst_data[1] = src[1]; - } - } -} - -kernel void kernel_rope_multi_f16( - global void * src0, - ulong offset0, - global int * src1, - ulong offset1, - global float * src2, - ulong offset2, - global half * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne03, - ulong nb00, - ulong nb01, - ulong nb02, - ulong nb03, - int ne0, - int ne1, - int ne2, - int ne3, - ulong nb0, - ulong nb1, - ulong nb2, - ulong nb3, - int n_past, - int n_dims, - int n_ctx_orig, - float freq_base, - float freq_scale, - float ext_factor, - float attn_factor, - float beta_fast, - float beta_slow, - int4 sections -) { - src0 = (global void*)((global char*)src0 + offset0); - src1 = (global int*)((global char*)src1 + offset1); - src2 = (global float*)((global char*)src2 + offset2); - dst = (global float*)((global char*)dst + offsetd); - - int i3 = get_group_id(2); - int i2 = get_group_id(1); - int i1 = get_group_id(0); - - float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); - - global int * pos = src1; - - const int sect_dims = sections.s0 + sections.s1 + sections.s2 + sections.s3; - const int sec_w = sections.s1 + sections.s0; - - float inv_ndims = -1.f/n_dims; - - for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { - if (i0 < n_dims) { - int ic = i0/2; - - const int sector = (i0 / 2) % sect_dims; - float theta_base = 0.0f; - - if (sector < sections.s0) { - theta_base = pos[i2]; - } - else if (sector >= sections.s0 && sector < sec_w) { - theta_base = pos[i2 + ne2 * 1]; - } - else if (sector >= sec_w && sector < sec_w + sections.s2) { - theta_base = pos[i2 + ne2 * 2]; - } - else if (sector >= sec_w + sections.s2) { - theta_base = pos[i2 + ne2 * 3]; - } - - const float theta = theta_base * pow(freq_base, inv_ndims*i0); - - const float freq_factor = src2 != src0 ? 
src2[ic] : 1.0f; - - float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); - - global half * src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); - global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); - - const float x0 = src[0]; - const float x1 = src[n_dims/2]; - - dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; - dst_data[n_dims/2] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; - } else { - global half * const src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - dst_data[0] = src[0]; - dst_data[1] = src[1]; - } - } -} - -kernel void kernel_rope_vision_f32( - global void * src0, - ulong offset0, - global int * src1, - ulong offset1, - global float * src2, - ulong offset2, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne03, - ulong nb00, - ulong nb01, - ulong nb02, - ulong nb03, - int ne0, - int ne1, - int ne2, - int ne3, - ulong nb0, - ulong nb1, - ulong nb2, - ulong nb3, - int n_past, - int n_dims, - int n_ctx_orig, - float freq_base, - float freq_scale, - float ext_factor, - float attn_factor, - float beta_fast, - float beta_slow, - int4 sections -) { - src0 = (global void*)((global char*)src0 + offset0); - src1 = (global int*)((global char*)src1 + offset1); - src2 = (global float*)((global char*)src2 + offset2); - dst = (global float*)((global char*)dst + offsetd); - - int i3 = get_group_id(2); - int i2 = get_group_id(1); - int i1 = get_group_id(0); - - float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); - - global int * pos = src1; - - const int sect_dims = sections.s0 + sections.s1; - const int sec_w = sections.s1 + sections.s0; - - float inv_ndims = -1.f/n_dims; - - for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { - int ic = i0/2; - - const int sector = (i0/2) % sect_dims; - float theta_base = 0.0f; - - if (sector < sections.s0) { - const int p = sector; - theta_base = pos[i2] * pow(freq_base, inv_ndims*2.0f*p); - } else if (sector >= sections.s0 && sector < sec_w) { - const int p = sector - sections.s0; - theta_base = pos[i2 + ne2] * pow(freq_base, inv_ndims*2.0f*p); - } - - const float freq_factor = src2 != src0 ? 
src2[ic] : 1.0f; - - float2 cos_sin_theta = rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); - - global float * src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); - global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); - - const float x0 = src[0]; - const float x1 = src[n_dims]; - - dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; - dst_data[n_dims] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; - } -} - -kernel void kernel_rope_vision_f16( - global void * src0, - ulong offset0, - global int * src1, - ulong offset1, - global float * src2, - ulong offset2, - global half * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne03, - ulong nb00, - ulong nb01, - ulong nb02, - ulong nb03, - int ne0, - int ne1, - int ne2, - int ne3, - ulong nb0, - ulong nb1, - ulong nb2, - ulong nb3, - int n_past, - int n_dims, - int n_ctx_orig, - float freq_base, - float freq_scale, - float ext_factor, - float attn_factor, - float beta_fast, - float beta_slow, - int4 sections -) { - src0 = (global void*)((global char*)src0 + offset0); - src1 = (global int*)((global char*)src1 + offset1); - src2 = (global float*)((global char*)src2 + offset2); - dst = (global float*)((global char*)dst + offsetd); - - int i3 = get_group_id(2); - int i2 = get_group_id(1); - int i1 = get_group_id(0); - - float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); - - global int * pos = src1; - - const int sect_dims = sections.s0 + sections.s1; - const int sec_w = sections.s1 + sections.s0; - - float inv_ndims = -1.f/n_dims; - - for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { - int ic = i0/2; - - const int sector = (i0/2) % sect_dims; - float theta_base = 0.0f; - - if (sector < sections.s0) { - const int p = sector; - theta_base = pos[i2] * pow(freq_base, inv_ndims*2.0f*p); - } else if (sector >= sections.s0 && sector < sec_w) { - const int p = sector - sections.s0; - theta_base = pos[i2 + ne2] * pow(freq_base, inv_ndims*2.0f*p); - } - - const float freq_factor = src2 != src0 ? 
src2[ic] : 1.0f; - - float2 cos_sin_theta = rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); - - global half * src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); - global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); - - const float x0 = src[0]; - const float x1 = src[n_dims]; - - dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; - dst_data[n_dims] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; - } -} - -//------------------------------------------------------------------------------ -// cpy -//------------------------------------------------------------------------------ - -kernel void kernel_cpy_f16_f16( - global half * src0, - ulong offset0, - global half * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne03, - ulong nb00, - ulong nb01, - ulong nb02, - ulong nb03, - int ne0, - int ne1, - int ne2, - int ne3, - ulong nb0, - ulong nb1, - ulong nb2, - ulong nb3 -) { - src0 = (global half*)((global char*)src0 + offset0); - dst = (global half*)((global char*)dst + offsetd); - - int i03 = get_group_id(2); - int i02 = get_group_id(1); - int i01 = get_group_id(0); - - int n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - - int i3 = n / (ne2*ne1*ne0); - int i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); - int i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; - int i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); - - global half * dst_data = (global half *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { - global const half * src = (global half *)((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); - dst_data[i00] = src[0]; - } -} - -kernel void kernel_cpy_f16_f32( - global half * src0, - ulong offset0, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne03, - ulong nb00, - ulong nb01, - ulong nb02, - ulong nb03, - int ne0, - int ne1, - int ne2, - int ne3, - ulong nb0, - ulong nb1, - ulong nb2, - ulong nb3 -) { - - src0 = (global half*)((global char*)src0 + offset0); - dst = (global float*)((global char*)dst + offsetd); - - int i03 = get_group_id(2); - int i02 = get_group_id(1); - int i01 = get_group_id(0); - - int n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - - int i3 = n / (ne2*ne1*ne0); - int i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); - int i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; - int i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); - - global float * dst_data = (global float *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { - global half * src = (global half *)((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); - dst_data[i00] = src[0]; - } -} - -kernel void kernel_cpy_f32_f16( - global float * src0, - ulong offset0, - global half * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne03, - ulong nb00, - ulong nb01, - ulong nb02, - ulong nb03, - int ne0, - int ne1, - int ne2, - int ne3, - ulong nb0, - ulong nb1, - ulong nb2, - ulong nb3 -) { - src0 = (global float*)((global char*)src0 + offset0); - dst = (global half*)((global char*)dst + offsetd); - - int i03 = get_group_id(2); - int i02 = get_group_id(1); - int i01 = get_group_id(0); - - int n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - - int i3 = n / (ne2*ne1*ne0); - int i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); - int i1 = (n - i3*ne2*ne1*ne0 - 
i2*ne1*ne0) / ne0; - int i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); - - global half * dst_data = (global half *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { - global const float * src = (global float *)((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); - - dst_data[i00] = src[0]; - } -} - -kernel void kernel_cpy_f32_f32( - global float * src0, - ulong offset0, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne03, - ulong nb00, - ulong nb01, - ulong nb02, - ulong nb03, - int ne0, - int ne1, - int ne2, - int ne3, - ulong nb0, - ulong nb1, - ulong nb2, - ulong nb3 -) { - src0 = (global float*)((global char*)src0 + offset0); - dst = (global float*)((global char*)dst + offsetd); - - int i03 = get_group_id(2); - int i02 = get_group_id(1); - int i01 = get_group_id(0); - - int n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - - int i3 = n / (ne2*ne1*ne0); - int i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); - int i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; - int i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); - - global float * dst_data = (global float *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { - global const float * src = (global float *)((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); - - dst_data[i00] = src[0]; - } -} - -//------------------------------------------------------------------------------ -// get_rows -//------------------------------------------------------------------------------ -kernel void kernel_get_rows_f32( - global void * src0, - ulong offset0, - global int * src1, - ulong offset1, - global float * dst, - ulong offsetd, - int ne00, - ulong nb01, - ulong nb02, - int ne10, - ulong nb10, - ulong nb11, - ulong nb1, - ulong nb2 -) { - src0 = (global void*)((global char*)src0 + offset0); - src1 = (global int*)((global char*)src1 + offset1); - dst = (global float*)((global char*)dst + offsetd); - - int i10 = get_group_id(0); - int i11 = get_group_id(1); - - int r = ((global int *) ((global char *) src1 + i11*nb11 + i10*nb10))[0]; - - int i02 = i11; - - for (int ind = get_local_id(0); ind < ne00; ind += get_local_size(0)) { - ((global float *) ((global char *) dst + i11*nb2 + i10*nb1))[ind] = - ((global float *) ((global char *) src0 + r*nb01 + i02*nb02))[ind]; - } -} - -kernel void kernel_get_rows_f16( - global void * src0, - ulong offset0, - global int * src1, - ulong offset1, - global float * dst, - ulong offsetd, - int ne00, - ulong nb01, - ulong nb02, - int ne10, - ulong nb10, - ulong nb11, - ulong nb1, - ulong nb2 -) { - src0 = (global void*)((global char*)src0 + offset0); - src1 = (global int*)((global char*)src1 + offset1); - dst = (global float*)((global char*)dst + offsetd); - - int i10 = get_group_id(0); - int i11 = get_group_id(1); - - int r = ((global int32_t *) ((global char *) src1 + i11*nb11 + i10*nb10))[0]; - - int i02 = i11; - - for (int ind = get_local_id(0); ind < ne00; ind += get_local_size(0)) { - ((global float *) ((global char *) dst + i11*nb2 + i10*nb1))[ind] = - ((global half *) ((global char *) src0 + r*nb01 + i02*nb02))[ind]; - } -} - -kernel void kernel_get_rows_q4_0( - global void * src0, - ulong offset0, - global int * src1, - ulong offset1, - global float * dst, - ulong offsetd, - int ne00, - ulong nb01, - ulong nb02, - int ne10, - ulong nb10, - ulong nb11, - ulong nb1, - ulong nb2 -) { - src0 = 
(global void*)((global char*)src0 + offset0); - src1 = (global int*)((global char*)src1 + offset1); - dst = (global float*)((global char*)dst + offsetd); - - const int NL = 2; - - int i10 = get_group_id(0); - int i11 = get_group_id(1); - - int r = ((global int32_t *) ((global char *) src1 + i11*nb11 + i10*nb10))[0]; - - int i02 = i11; - - for (int ind = get_local_id(0); ind < ne00/16; ind += get_local_size(0)) { - float16 temp; - dequantize_q4_0_f32( - ((global struct block_q4_0 *) ((global char *) src0 + r*nb01 + i02*nb02)) + ind/NL, ind%NL, &temp); - *(((global float16 *) ((global char *) dst + i11*nb2 + i10*nb1)) + ind) = temp; - } -} - -//------------------------------------------------------------------------------ -// mul_mat_f32_f32 -//------------------------------------------------------------------------------ -#define N_F32_F32 4 - -kernel void kernel_mul_mat_f32_f32( - global char * src0, - ulong offset0, - global char * src1, - ulong offset1, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - ulong nb00, - ulong nb01, - ulong nb02, - ulong nb03, - int ne10, - int ne11, - int ne12, - ulong nb10, - ulong nb11, - ulong nb12, - ulong nb13, - int ne0, - int ne1, - int r2, - int r3 -) { - src0 = (global char*)((global char*)src0 + offset0); - src1 = (global char*)((global char*)src1 + offset1); - dst = (global float*)((global char*)dst + offsetd); - - int r0 = get_group_id(0); - int rb = get_group_id(1)*N_F32_F32; - int im = get_group_id(2); - - int i12 = im%ne12; - int i13 = im/ne12; - - ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - - global float * x = (global float *) (src0 + offset_src0); - - if (ne00 < 128) { - for (int row = 0; row < N_F32_F32; ++row) { - int r1 = rb + row; - if (r1 >= ne11) { - break; - } - - ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; - - global float * y = (global float *) (src1 + offset_src1); - - float sumf = 0; - for (int i = get_sub_group_local_id(); i < ne00; i += get_max_sub_group_size()) { - sumf += (float) x[i] * (float) y[i]; - } - - float all_sum = sub_group_reduce_add(sumf); - if (get_sub_group_local_id() == 0) { - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } - } else { - global float4 * x4 = (global float4 *)x; - for (int row = 0; row < N_F32_F32; ++row) { - int r1 = rb + row; - if (r1 >= ne11) { - break; - } - - ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; - - global float * y = (global float *) (src1 + offset_src1); - global float4 * y4 = (global float4 *) y; - - float sumf = 0; - for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) { - sumf += (float) x4[i].s0 * y4[i].s0; - sumf += (float) x4[i].s1 * y4[i].s1; - sumf += (float) x4[i].s2 * y4[i].s2; - sumf += (float) x4[i].s3 * y4[i].s3; - } - - float all_sum = sub_group_reduce_add(sumf); - if (get_sub_group_local_id() == 0) { - for (int i = 4*(ne00/4); i < ne00; ++i) { - all_sum += (float) x[i] * y[i]; - } - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } - } -} - -//------------------------------------------------------------------------------ -// mul_mat_f16_f16 -//------------------------------------------------------------------------------ -#define N_F16_F16 4 - -kernel void kernel_mul_mat_f16_f16( - global char * src0, - ulong offset0, - global char * src1, - ulong offset1, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - ulong nb00, - ulong nb01, - ulong nb02, - ulong nb03, - int ne10, - int ne11, - int ne12, - ulong nb10, - ulong nb11, - ulong nb12, - ulong 
nb13, - int ne0, - int ne1, - int r2, - int r3) -{ - src0 = (global char*)((global char*)src0 + offset0); - src1 = (global char*)((global char*)src1 + offset1); - dst = (global float*)((global char*)dst + offsetd); - - int r0 = get_group_id(0); - int rb = get_group_id(1)*N_F16_F16; - int im = get_group_id(2); - - int i12 = im%ne12; - int i13 = im/ne12; - - ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - - global half * x = (global half *) (src0 + offset_src0); - - if (ne00 < 128) { - for (int row = 0; row < N_F16_F16; ++row) { - int r1 = rb + row; - if (r1 >= ne11) { - break; - } - - ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; - - global half * y = (global half *) (src1 + offset_src1); - - float sumf = 0; - for (int i = get_sub_group_local_id(); i < ne00; i += get_max_sub_group_size()) { - sumf += (half) x[i] * (half) y[i]; - } - - float all_sum = sub_group_reduce_add(sumf); - if (get_sub_group_local_id() == 0) { - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } - } else { - global half4 * x4 = (global half4 *)x; - for (int row = 0; row < N_F16_F16; ++row) { - int r1 = rb + row; - if (r1 >= ne11) { - break; - } - - ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; - - global half * y = (global half *) (src1 + offset_src1); - global half4 * y4 = (global half4 *) y; - - float sumf = 0; - for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) { - sumf += (half) x4[i].s0 * y4[i].s0; - sumf += (half) x4[i].s1 * y4[i].s1; - sumf += (half) x4[i].s2 * y4[i].s2; - sumf += (half) x4[i].s3 * y4[i].s3; - } - - float all_sum = sub_group_reduce_add(sumf); - if (get_sub_group_local_id() == 0) { - for (int i = 4*(ne00/4); i < ne00; ++i) { - all_sum += (half) x[i] * y[i]; - } - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } - } -} - -//------------------------------------------------------------------------------ -// mul_mat_f16_f32_1row -//------------------------------------------------------------------------------ -#ifdef ADRENO_GPU -REQD_SUBGROUP_SIZE_64 -#endif -kernel void kernel_mul_mat_f16_f32_1row( - global char * src0, - ulong offset0, - global char * src1, - ulong offset1, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - ulong nb00, - ulong nb01, - ulong nb02, - ulong nb03, - int ne10, - int ne11, - int ne12, - ulong nb10, - ulong nb11, - ulong nb12, - ulong nb13, - int ne0, - int ne1, - int r2, - int r3 -) { - src0 = (global char*)((global char*)src0 + offset0); - src1 = (global char*)((global char*)src1 + offset1); - dst = (global float*)((global char*)dst + offsetd); - - int r0 = get_group_id(0); - int r1 = get_group_id(1); - int im = get_group_id(2); - - int i12 = im%ne12; - int i13 = im/ne12; - - ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; - - global half * x = (global half *) (src0 + offset_src0); - global float * y = (global float *) (src1 + offset_src1); - - float sumf = 0; - if (ne00 < 128) { - for (int i = get_sub_group_local_id(); i < ne00; i += get_max_sub_group_size()) { - sumf += (float) x[i] * (float) y[i]; - } - float all_sum = sub_group_reduce_add(sumf); - if (get_sub_group_local_id() == 0) { - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } else { - global half4 * x4 = (global half4 *) x; - global float4 * y4 = (global float4 *) y; - for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) { - sumf += (float) x4[i].s0 * y4[i].s0; - sumf += (float) x4[i].s1 * y4[i].s1; - sumf += (float) 
x4[i].s2 * y4[i].s2; - sumf += (float) x4[i].s3 * y4[i].s3; - } - float all_sum = sub_group_reduce_add(sumf); - if (get_sub_group_local_id() == 0) { - for (int i = 4*(ne00/4); i < ne00; ++i) { - all_sum += (float) x[i] * y[i]; - } - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } - -} - -//------------------------------------------------------------------------------ -// mul_mat_f16_f32 -//------------------------------------------------------------------------------ -#define N_F16_F32 4 - -#ifdef ADRENO_GPU -REQD_SUBGROUP_SIZE_64 -#endif -kernel void kernel_mul_mat_f16_f32( - global char * src0, - ulong offset0, - global char * src1, - ulong offset1, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - ulong nb00, - ulong nb01, - ulong nb02, - ulong nb03, - int ne10, - int ne11, - int ne12, - ulong nb10, - ulong nb11, - ulong nb12, - ulong nb13, - int ne0, - int ne1, - int r2, - int r3 -) { - src0 = (global char*)((global char*)src0 + offset0); - src1 = (global char*)((global char*)src1 + offset1); - dst = (global float*)((global char*)dst + offsetd); - - int r0 = get_group_id(0); - int rb = get_group_id(1)*N_F16_F32; - int im = get_group_id(2); - - int i12 = im%ne12; - int i13 = im/ne12; - - ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - - global half * x = (global half *) (src0 + offset_src0); - - if (ne00 < 128) { - for (int row = 0; row < N_F16_F32; ++row) { - int r1 = rb + row; - if (r1 >= ne11) { - break; - } - - ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; - - global float * y = (global float *) (src1 + offset_src1); - - float sumf = 0; - for (int i = get_sub_group_local_id(); i < ne00; i += get_max_sub_group_size()) { - sumf += convert_float(x[i]) * y[i]; - } - - float all_sum = sub_group_reduce_add(sumf); - if (get_sub_group_local_id() == 0) { - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } - } else { - global half4 * x4 = (global half4 *)x; - for (int row = 0; row < N_F16_F32; ++row) { - int r1 = rb + row; - if (r1 >= ne11) { - break; - } - - ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; - - global float * y = (global float *) (src1 + offset_src1); - global float4 * y4 = (global float4 *) y; - - float sumf = 0; - for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) { - sumf += convert_float(x4[i].s0) * y4[i].s0; - sumf += convert_float(x4[i].s1) * y4[i].s1; - sumf += convert_float(x4[i].s2) * y4[i].s2; - sumf += convert_float(x4[i].s3) * y4[i].s3; - } - - float all_sum = sub_group_reduce_add(sumf); - if (get_sub_group_local_id() == 0) { - for (int i = 4*(ne00/4); i < ne00; ++i) { - all_sum += (float) x[i] * y[i]; - } - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } - } -} - -//------------------------------------------------------------------------------ -// mul_mat_f16_f32_l4 -//------------------------------------------------------------------------------ -// Assumes row size (ne00) is a multiple of 4 -#ifdef ADRENO_GPU -REQD_SUBGROUP_SIZE_64 -#endif -kernel void kernel_mul_mat_f16_f32_l4( - global char * src0, - ulong offset0, - global char * src1, - ulong offset1, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - ulong nb00, - ulong nb01, - ulong nb02, - ulong nb03, - int ne10, - int ne11, - int ne12, - ulong nb10, - ulong nb11, - ulong nb12, - ulong nb13, - int ne0, - int ne1, - int r2, - int r3 -) { - src0 = (global char*)((global char*)src0 + offset0); - src1 = (global char*)((global char*)src1 + offset1); - dst = (global float*)((global char*)dst + 
offsetd); - - int nrows = ne11; - int r0 = get_group_id(0); - int im = get_group_id(2); - - int i12 = im%ne12; - int i13 = im/ne12; - - ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - - global half4 * x4 = (global half4 *) (src0 + offset_src0); - - for (int r1 = 0; r1 < nrows; ++r1) { - ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; - - global float4 * y4 = (global float4 *) (src1 + offset_src1); - - float sumf = 0; - for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) { - sumf += convert_float(x4[i].s0) * y4[i].s0; - sumf += convert_float(x4[i].s1) * y4[i].s1; - sumf += convert_float(x4[i].s2) * y4[i].s2; - sumf += convert_float(x4[i].s3) * y4[i].s3; - } - - float all_sum = sub_group_reduce_add(sumf); - if (get_sub_group_local_id() == 0) { - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } -} - -//------------------------------------------------------------------------------ -// mul_vec_q_n_f32 -//------------------------------------------------------------------------------ -// function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i]) -// il indicates where the q4 quants begin (0 or QK4_0/4) -// we assume that the yl's have been multiplied with the appropriate scale factor -// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) -inline float block_q_4_0_dot_y( - global struct block_q4_0 * qb_curr, - float sumy, - private float * yl, - int il -) { - float d = qb_curr->d; - float2 acc = 0.f; - global ushort * qs = ((global ushort *)qb_curr + 1 + il/2); - for (int i = 0; i < 8; i+=2) { - acc.s0 += yl[i + 0] * (qs[i / 2] & 0x000F) - + yl[i + 1] * (qs[i / 2] & 0x0F00); - acc.s1 += yl[i + 8] * (qs[i / 2] & 0x00F0) - + yl[i + 9] * (qs[i / 2] & 0xF000); - } - return d * (sumy * -8.f + acc.s0 + acc.s1); -} - -#ifdef INTEL_GPU -#define N_DST 4 // each SIMD group works on 4 rows -#define N_SIMDGROUP 1 // number of SIMD groups in a thread group -#define N_SIMDWIDTH 16 // assuming SIMD group size is 16 -#elif defined (ADRENO_GPU) -#define N_DST 4 -#define N_SIMDGROUP 1 -#define N_SIMDWIDTH 64 -#endif - -inline void mul_vec_q_n_f32( - global void * src0, - global float * src1, - global float * dst, - int ne00, - int ne01, - int ne02, - int ne10, - int ne12, - int ne0, - int ne1, - int r2, - int r3 -) { - - const ulong nb = ne00/QK4_0; - - int r0 = get_group_id(0); - int r1 = get_group_id(1); - int im = get_group_id(2); - - // (r0 * N_SIMDGROUP + get_sub_group_id()) is essenatially the linear global - // id of a SIMD group in the grid. - int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; - - int i12 = im%ne12; - int i13 = im/ne12; - - ulong offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - - global struct block_q4_0 * x = (global struct block_q4_0 *) src0 + offset0; - global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; - - float yl[16]; // src1 vector cache - float sumf[N_DST]={0.f}; - - int ix = get_sub_group_local_id()/2; - int il = 8*(get_sub_group_local_id()%2); - - global float * yb = y + ix * QK4_0 + il; - - // each thread in a SIMD group deals with half a block. 
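-    // yl[] caches 16 activations pre-scaled by 1, 1/256, 1/16 and 1/4096 so that
-    // block_q_4_0_dot_y can extract the nibbles with plain bit masks instead of
-    // shifts; sumy feeds the "sumy * -8" term that undoes the q4_0 nibble bias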
- for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { - float sumy = 0; - for (int i = 0; i < 8; i += 2) { - sumy += yb[i] + yb[i+1]; - yl[i+0] = yb[i+ 0]; - yl[i+1] = yb[i+ 1]/256.f; - sumy += yb[i+16] + yb[i+17]; - yl[i+8] = yb[i+16]/16.f; - yl[i+9] = yb[i+17]/4096.f; - } - - for (int row = 0; row < N_DST; row++) { - sumf[row] += block_q_4_0_dot_y(x+ib+row*nb, sumy, yl, il); - } - - // One thread in a SIMD group (i.e., subgroup) handles a half block, - // hence then entire SIMD group handles SIMDWIDTH/2 blocks. - // y points to the activation matrix (of type float). Therefore for - // one thread, the # of blocks y should advance is SIMDWIDTH/2 (because - // SIMDWIDTH/2 blocks are processed by a SIMD group) - in terms of - // floats, it is QK4_0 * (SIMDWIDTH/2), where QK4_0 is the block size. - yb += QK4_0 * (N_SIMDWIDTH/2); - } - - // The above does not work for Adreno - it produces incorrect results for - // row = 1, 2, 3 and only row = 0 gives the correct result. - // If N_DST is changed, the below array must be initialized accordingly. - // This also seems to perform better on Intel. - float tot[N_DST] = { - sub_group_reduce_add(sumf[0]), sub_group_reduce_add(sumf[1]), - sub_group_reduce_add(sumf[2]), sub_group_reduce_add(sumf[3])}; - for (int row = 0; row < N_DST; ++row) { - if (get_sub_group_local_id() == 0 && first_row + row < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot[row]; - } - } -} - -#ifdef INTEL_GPU -REQD_SUBGROUP_SIZE_16 -#elif defined (ADRENO_GPU) -REQD_SUBGROUP_SIZE_64 -#endif -kernel void kernel_mul_mat_q4_0_f32( - global void * src0, - ulong offset0, - global float * src1, - ulong offset1, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne10, - int ne12, - int ne0, - int ne1, - int r2, - int r3 -) { - src0 = (global void*)((global char*)src0 + offset0); - src1 = (global float*)((global char*)src1 + offset1); - dst = (global float*)((global char*)dst + offsetd); - - mul_vec_q_n_f32(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); -} - -// -// This variant unrolls the loops and uses vector types instead of pointers. -// It improves performance on Adreno but not so much on Intel. 
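-// The float16 yl keeps the activation cache in registers rather than a private
-// array, which is presumably what helps the Adreno compiler here.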
-// -inline float block_q_4_0_dot_y_v( - global struct block_q4_0 * qb_curr, - float sumy, - float16 yl, - int il -) { - float d = qb_curr->d; - float acc = 0.f; - global ushort * qs = ((global ushort *)qb_curr + 1 + il/2); - - acc += yl.s0 * (qs[0] & 0x000F); - acc += yl.s1 * (qs[0] & 0x0F00); - acc += yl.s8 * (qs[0] & 0x00F0); - acc += yl.s9 * (qs[0] & 0xF000); - - acc += yl.s2 * (qs[1] & 0x000F); - acc += yl.s3 * (qs[1] & 0x0F00); - acc += yl.sa * (qs[1] & 0x00F0); - acc += yl.sb * (qs[1] & 0xF000); - - acc += yl.s4 * (qs[2] & 0x000F); - acc += yl.s5 * (qs[2] & 0x0F00); - acc += yl.sc * (qs[2] & 0x00F0); - acc += yl.sd * (qs[2] & 0xF000); - - acc += yl.s6 * (qs[3] & 0x000F); - acc += yl.s7 * (qs[3] & 0x0F00); - acc += yl.se * (qs[3] & 0x00F0); - acc += yl.sf * (qs[3] & 0xF000); - - return d * (sumy * -8.f + acc); -} - -#undef N_DST -#undef N_SIMDGROUP -#undef N_SIMDWIDTH - -#ifdef INTEL_GPU -#define N_DST 4 // each SIMD group works on 4 rows -#define N_SIMDGROUP 1 // number of SIMD groups in a thread group -#define N_SIMDWIDTH 16 // assuming SIMD group size is 16 -#elif defined (ADRENO_GPU) -#define N_DST 4 -#define N_SIMDGROUP 1 -#define N_SIMDWIDTH 64 -#endif - -inline void mul_vec_q_n_f32_v( - global void * src0, - global float * src1, - global float * dst, - int ne00, - int ne01, - int ne02, - int ne10, - int ne12, - int ne0, - int ne1, - int r2, - int r3 -) { - const ulong nb = ne00/QK4_0; - - int r0 = get_group_id(0); - int r1 = get_group_id(1); - int im = get_group_id(2); - - // (r0 * N_SIMDGROUP + get_sub_group_id()) is essenatially the linear global - // id of a SIMD group in the grid. - int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; - - int i12 = im%ne12; - int i13 = im/ne12; - - ulong offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - - global struct block_q4_0 * x = (global struct block_q4_0 *) src0 + offset0; - global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; - - float16 yl; // src1 vector cache - float4 sumf = (float4)(0.f, 0.f, 0.f, 0.f); - - int ix = get_sub_group_local_id()/2; - int il = 8*(get_sub_group_local_id()%2); - - global float * yb = y + ix * QK4_0 + il; - - // each thread in a SIMD group deals with half a block. - for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { - float sumy = 0; - - sumy += yb[0]; - sumy += yb[1]; - sumy += yb[2]; - sumy += yb[3]; - sumy += yb[4]; - sumy += yb[5]; - sumy += yb[6]; - sumy += yb[7]; - - sumy += yb[16]; - sumy += yb[17]; - sumy += yb[18]; - sumy += yb[19]; - sumy += yb[20]; - sumy += yb[21]; - sumy += yb[22]; - sumy += yb[23]; - - - yl.s0 = yb[0]; - yl.s1 = yb[1]/256.f; - - yl.s2 = yb[2]; - yl.s3 = yb[3]/256.f; - - yl.s4 = yb[4]; - yl.s5 = yb[5]/256.f; - - yl.s6 = yb[6]; - yl.s7 = yb[7]/256.f; - - yl.s8 = yb[16]/16.f; - yl.s9 = yb[17]/4096.f; - - yl.sa = yb[18]/16.f; - yl.sb = yb[19]/4096.f; - - yl.sc = yb[20]/16.f; - yl.sd = yb[21]/4096.f; - - yl.se = yb[22]/16.f; - yl.sf = yb[23]/4096.f; - - sumf.s0 += block_q_4_0_dot_y_v(x+ib+0*nb, sumy, yl, il); - sumf.s1 += block_q_4_0_dot_y_v(x+ib+1*nb, sumy, yl, il); - sumf.s2 += block_q_4_0_dot_y_v(x+ib+2*nb, sumy, yl, il); - sumf.s3 += block_q_4_0_dot_y_v(x+ib+3*nb, sumy, yl, il); - - // One thread in a SIMD group (i.e., subgroup) handles a half block, - // hence then entire SIMD group handles SIMDWIDTH/2 blocks. - // y points to the activation matrix (of type float). 
Therefore for - // one thread, the # of blocks y should advance is SIMDWIDTH/2 (because - // SIMDWIDTH/2 blocks are processed by a SIMD group) - in terms of - // floats, it is QK4_0 * (SIMDWIDTH/2), where QK4_0 is the block size. - yb += QK4_0 * (N_SIMDWIDTH/2); - } - - // The above does not work for Adreno - it produces incorrect results for - // row = 1, 2, 3 and only row = 0 gives the correct result. - // If N_DST is changed, the below array must be initialized accordingly. - // This also seems to perform better on Intel. - float4 tot = (float4)( - sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1), - sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3) - ); - - if (get_sub_group_local_id() == 0) { - if (first_row + 0 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; - } - if (first_row + 1 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; - } - if (first_row + 2 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; - } - if (first_row + 3 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; - } - } -} - -#ifdef INTEL_GPU -REQD_SUBGROUP_SIZE_16 -#elif defined (ADRENO_GPU) -REQD_SUBGROUP_SIZE_64 -#endif -kernel void kernel_mul_mat_q4_0_f32_v( - global void * src0, - ulong offset0, - global float * src1, - ulong offset1, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne10, - int ne12, - int ne0, - int ne1, - int r2, - int r3 -) { - src0 = (global void*)((global char*)src0 + offset0); - src1 = (global float*)((global char*)src1 + offset1); - dst = (global float*)((global char*)dst + offsetd); - - mul_vec_q_n_f32_v(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); -} - -//------------------------------------------------------------------------------ -// kernel_convert_block_q4_0 -// Convert the block_q4_0 format to 2 separate arrays (AOS -> SOA). -// This kernel does not deshuffle the bits. -//------------------------------------------------------------------------------ -kernel void kernel_convert_block_q4_0( - global struct block_q4_0 * src0, - global uchar * dst_q, - global half * dst_d -) { - global struct block_q4_0 * b = (global struct block_q4_0 *) src0 + get_global_id(0); - global uchar * q = (global uchar *) dst_q + QK4_0/2*get_global_id(0); - global half * d = (global half *) dst_d + get_global_id(0); - - *d = b->d; - - for (int i = 0; i < QK4_0/2; ++i) { - q[i] = b->qs[i]; - } -} - -kernel void kernel_restore_block_q4_0( - global uchar * src_q, - global half * src_d, - global struct block_q4_0 * dst -) { - global struct block_q4_0 * b = (global struct block_q4_0 *) dst + get_global_id(0); - global uchar * q = (global uchar *) src_q + QK4_0/2*get_global_id(0); - global half * d = (global half *) src_d + get_global_id(0); - - b->d = *d; - for (int i = 0; i < QK4_0/2; ++i) { - b->qs[i] = q[i]; - } -} - -//------------------------------------------------------------------------------ -// mul_vec_q_n_f32_flat -// -// This variation uses flat arrays (struct of arrays, SOA) representation for -// quant tensors. -//------------------------------------------------------------------------------ - -// This function requires the original shuffled weights. -// As a reminder, the original weights are shuffled so that (q[0], q[16]) are -// packed together in a byte, so are (q[1], q[17]) and so on. 
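-// Reading qs as ushorts, the masks 0x000F/0x0F00 pick out two low-nibble quants
-// and 0x00F0/0xF000 the two high-nibble quants packed with them; the pre-scaled
-// yl entries (1, 1/256, 1/16, 1/4096) absorb the implicit bit shifts, and the
-// "sumy * -8" term removes the q4_0 nibble bias.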
-inline float block_q_4_0_dot_y_flat( - global uchar * x, - global half * dh, - float sumy, - float16 yl, - int il -) { - float d = *dh; - global ushort * qs = ((global ushort *)x + il/2); - float acc = 0.f; - - acc += yl.s0 * (qs[0] & 0x000F); - acc += yl.s1 * (qs[0] & 0x0F00); - acc += yl.s8 * (qs[0] & 0x00F0); - acc += yl.s9 * (qs[0] & 0xF000); - - acc += yl.s2 * (qs[1] & 0x000F); - acc += yl.s3 * (qs[1] & 0x0F00); - acc += yl.sa * (qs[1] & 0x00F0); - acc += yl.sb * (qs[1] & 0xF000); - - acc += yl.s4 * (qs[2] & 0x000F); - acc += yl.s5 * (qs[2] & 0x0F00); - acc += yl.sc * (qs[2] & 0x00F0); - acc += yl.sd * (qs[2] & 0xF000); - - acc += yl.s6 * (qs[3] & 0x000F); - acc += yl.s7 * (qs[3] & 0x0F00); - acc += yl.se * (qs[3] & 0x00F0); - acc += yl.sf * (qs[3] & 0xF000); - - return d * (sumy * -8.f + acc); -} - -#undef N_DST -#undef N_SIMDGROUP -#undef N_SIMDWIDTH - -#ifdef INTEL_GPU -#define N_DST 4 // each SIMD group works on 4 rows -#define N_SIMDGROUP 1 // number of SIMD groups in a thread group -#define N_SIMDWIDTH 16 // assuming SIMD group size is 32 -#elif defined (ADRENO_GPU) -#define N_DST 4 -#define N_SIMDGROUP 1 -#define N_SIMDWIDTH 64 -#endif - -inline void mul_vec_q_n_f32_flat( - global uchar * src0_q, - global half * src0_d, - global float * src1, - global float * dst, - int ne00, - int ne01, - int ne02, - int ne10, - int ne12, - int ne0, - int ne1, - int r2, - int r3 -) { - const ulong nb = ne00/QK4_0; - - int r0 = get_group_id(0); - int r1 = get_group_id(1); - int im = get_group_id(2); - - // (r0 * N_SIMDGROUP + get_sub_group_id()) is the linear global id of - // a SIMD group in the grid. Each SIMD group produces N_DST values in the - // result, hence uses nb blocks, i.e., the offset becomes first_row*nb. - // Currently with llama2 7B, im is always 0. - // TODO: how to handle im/gqa*(nb*ne0)? - int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; - - int i12 = im%ne12; - int i13 = im/ne12; - - // The number of scales is the same as the number of blocks. - ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - // Each block contains QK4_0/2 uchars, hence offset for qs is as follows. 
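-    // i.e. offset0_q is offset0_d scaled by the QK4_0/2 quant bytes stored per block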
- ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_0/2; - - global uchar * x = (global uchar *) src0_q + offset0_q; - global half * d = (global half *) src0_d + offset0_d; - global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; - - float16 yl; - float4 sumf = (float4)(0.f, 0.f, 0.f, 0.f); - - int ix = get_sub_group_local_id()/2; - int il = 8*(get_sub_group_local_id()%2); - - global float * yb = y + ix*QK4_0 + il; - - for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { - float sumy = 0.f; - - sumy += yb[0]; - sumy += yb[1]; - sumy += yb[2]; - sumy += yb[3]; - sumy += yb[4]; - sumy += yb[5]; - sumy += yb[6]; - sumy += yb[7]; - - sumy += yb[16]; - sumy += yb[17]; - sumy += yb[18]; - sumy += yb[19]; - sumy += yb[20]; - sumy += yb[21]; - sumy += yb[22]; - sumy += yb[23]; - - yl.s0 = yb[0]; - yl.s1 = yb[1]/256.f; - - yl.s2 = yb[2]; - yl.s3 = yb[3]/256.f; - - yl.s4 = yb[4]; - yl.s5 = yb[5]/256.f; - - yl.s6 = yb[6]; - yl.s7 = yb[7]/256.f; - - yl.s8 = yb[16]/16.f; - yl.s9 = yb[17]/4096.f; - - yl.sa = yb[18]/16.f; - yl.sb = yb[19]/4096.f; - - yl.sc = yb[20]/16.f; - yl.sd = yb[21]/4096.f; - - yl.se = yb[22]/16.f; - yl.sf = yb[23]/4096.f; - - sumf.s0 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 0*nb*QK4_0/2, d + ib + 0*nb, sumy, yl, il); - sumf.s1 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 1*nb*QK4_0/2, d + ib + 1*nb, sumy, yl, il); - sumf.s2 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 2*nb*QK4_0/2, d + ib + 2*nb, sumy, yl, il); - sumf.s3 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 3*nb*QK4_0/2, d + ib + 3*nb, sumy, yl, il); - - yb += QK4_0 * (N_SIMDWIDTH/2); - } - - float4 tot = (float4)( - sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1), - sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3) - ); - - if (get_sub_group_local_id() == 0) { - if (first_row + 0 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; - } - if (first_row + 1 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; - } - if (first_row + 2 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; - } - if (first_row + 3 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; - } - } -} - -#ifdef INTEL_GPU -REQD_SUBGROUP_SIZE_16 -#elif defined (ADRENO_GPU) -REQD_SUBGROUP_SIZE_64 -#endif -kernel void kernel_mul_mat_q4_0_f32_flat( - global uchar * src0_q, - global half * src0_d, - global float * src1, - ulong offset1, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne10, - int ne12, - int ne0, - int ne1, - int r2, - int r3 -) { - src1 = (global float*)((global char*)src1 + offset1); - dst = (global float*)((global char*)dst + offsetd); - - mul_vec_q_n_f32_flat(src0_q, src0_d, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); -} - -// -// This variant outputs 8 values. 
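-// Each subgroup computes N_DST = 8 output rows, reusing the same flat (SOA)
-// q4_0 layout as the 4-row kernel above.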
-// -#undef N_DST -#undef N_SIMDGROUP -#undef N_SIMDWIDTH - -#ifdef INTEL_GPU -#define N_DST 8 // each SIMD group works on 8 rows -#define N_SIMDGROUP 1 // number of SIMD groups in a thread group -#define N_SIMDWIDTH 16 // assuming SIMD group size is 32 -#elif defined (ADRENO_GPU) -#define N_DST 8 -#define N_SIMDGROUP 1 -#define N_SIMDWIDTH 64 -#endif - -inline void mul_vec_q_n_f32_8x_flat( - global uchar * src0_q, - global half * src0_d, - global float * src1, - global float * dst, - int ne00, - int ne01, - int ne02, - int ne10, - int ne12, - int ne0, - int ne1, - int r2, - int r3 -) { - const ulong nb = ne00/QK4_0; - - int r0 = get_group_id(0); - int r1 = get_group_id(1); - int im = get_group_id(2); - - // (r0 * N_SIMDGROUP + get_sub_group_id()) is the linear global id of - // a SIMD group in the grid. Each SIMD group produces N_DST values in the - // result, hence uses nb blocks, i.e., the offset becomes first_row*nb. - // Currently with llama2 7B, im is always 0. - // TODO: how to handle im/gqa*(nb*ne0)? - int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; - - int i12 = im%ne12; - int i13 = im/ne12; - - // The number of scales is the same as the number of blocks. - ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - // Each block contains QK4_0/2 uchars, hence offset for qs is as follows. - ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_0/2; - - global uchar * x = (global uchar *) src0_q + offset0_q; - global half * d = (global half *) src0_d + offset0_d; - global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; - - float16 yl; - float8 sumf = 0.f; - - int ix = get_sub_group_local_id()/2; - int il = 8*(get_sub_group_local_id()%2); - - global float * yb = y + ix*QK4_0 + il; - - for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { - float sumy = 0.f; - - sumy += yb[0]; - sumy += yb[1]; - sumy += yb[2]; - sumy += yb[3]; - sumy += yb[4]; - sumy += yb[5]; - sumy += yb[6]; - sumy += yb[7]; - - sumy += yb[16]; - sumy += yb[17]; - sumy += yb[18]; - sumy += yb[19]; - sumy += yb[20]; - sumy += yb[21]; - sumy += yb[22]; - sumy += yb[23]; - - yl.s0 = yb[0]; - yl.s1 = yb[1]/256.f; - - yl.s2 = yb[2]; - yl.s3 = yb[3]/256.f; - - yl.s4 = yb[4]; - yl.s5 = yb[5]/256.f; - - yl.s6 = yb[6]; - yl.s7 = yb[7]/256.f; - - yl.s8 = yb[16]/16.f; - yl.s9 = yb[17]/4096.f; - - yl.sa = yb[18]/16.f; - yl.sb = yb[19]/4096.f; - - yl.sc = yb[20]/16.f; - yl.sd = yb[21]/4096.f; - - yl.se = yb[22]/16.f; - yl.sf = yb[23]/4096.f; - - sumf.s0 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 0*nb*QK4_0/2, d + ib + 0*nb, sumy, yl, il); - sumf.s1 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 1*nb*QK4_0/2, d + ib + 1*nb, sumy, yl, il); - sumf.s2 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 2*nb*QK4_0/2, d + ib + 2*nb, sumy, yl, il); - sumf.s3 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 3*nb*QK4_0/2, d + ib + 3*nb, sumy, yl, il); - - sumf.s4 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 4*nb*QK4_0/2, d + ib + 4*nb, sumy, yl, il); - sumf.s5 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 5*nb*QK4_0/2, d + ib + 5*nb, sumy, yl, il); - sumf.s6 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 6*nb*QK4_0/2, d + ib + 6*nb, sumy, yl, il); - sumf.s7 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 7*nb*QK4_0/2, d + ib + 7*nb, sumy, yl, il); - - yb += QK4_0 * (N_SIMDWIDTH/2); - } - - float8 tot = (float8)( - sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1), - sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3), - sub_group_reduce_add(sumf.s4), 
sub_group_reduce_add(sumf.s5), - sub_group_reduce_add(sumf.s6), sub_group_reduce_add(sumf.s7) - ); - - if (get_sub_group_local_id() == 0) { - if (first_row + 0 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; - } - if (first_row + 1 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; - } - if (first_row + 2 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; - } - if (first_row + 3 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; - } - - if (first_row + 4 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 4] = tot.s4; - } - if (first_row + 5 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 5] = tot.s5; - } - if (first_row + 6 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 6] = tot.s6; - } - if (first_row + 7 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 7] = tot.s7; - } - } -} - -#ifdef INTEL_GPU -REQD_SUBGROUP_SIZE_16 -#elif defined (ADRENO_GPU) -REQD_SUBGROUP_SIZE_64 -#endif -kernel void kernel_mul_mat_q4_0_f32_8x_flat( - global uchar * src0_q, - global half * src0_d, - global float * src1, - ulong offset1, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne10, - int ne12, - int ne0, - int ne1, - int r2, - int r3 -) { - src1 = (global float*)((global char*)src1 + offset1); - dst = (global float*)((global char*)dst + offsetd); - - mul_vec_q_n_f32_8x_flat(src0_q, src0_d, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); -} diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl_im2col.cl b/ggml/src/ggml-opencl/kernels/ggml-opencl_im2col.cl deleted file mode 100644 index 9b41dfb2555..00000000000 --- a/ggml/src/ggml-opencl/kernels/ggml-opencl_im2col.cl +++ /dev/null @@ -1,146 +0,0 @@ -#ifdef cl_khr_fp16 -#pragma OPENCL EXTENSION cl_khr_fp16 : enable -#elif defined(cl_amd_fp16) -#pragma OPENCL EXTENSION cl_amd_fp16 : enable -#else -#error "Half precision floating point not supportedby OpenCL implementation on your device." -#endif - -#ifdef cl_khr_subgroups -#pragma OPENCL EXTENSION cl_khr_subgroups : enable -#elif defined(cl_intel_subgroups) -#pragma OPENCL EXTENSION cl_intel_subgroups : enable -#else -#error "Subgroup not supported on your device." -#endif - -#ifdef cl_intel_required_subgroup_size -// Always use subgroup size of 32 on Intel. -#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable -#define INTEL_GPU 1 -#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) -#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) -#elif defined(cl_qcom_reqd_sub_group_size) -// Always use subgroups size of 64 on Adreno. -#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable -#define ADRENO_GPU 1 -#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) -#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) -#else -// TODO: do not know how to choose subgroup size on other GPUs. -#error "Selecting subgroup size is not supported on your device." 
-#endif - -kernel void kernel_im2col_f32( - global float * src1, - ulong offset1, - global float * dst, - ulong offsetd, - ulong batch_offset, - ulong delta_offset, - long IW, - long IH, - long IC, - long OW, - long OH, - long KW, - long KH, - long pelements, - long CHW, - int s0, - int s1, - int p0, - int p1, - int d0, - int d1 -) { - // threadIdx.x + blockIdx.x * blockDim.x - long i = get_global_id(0); - if (i >= pelements) { - return; - } - - src1 = (global float*)((global char*)src1 + offset1); - dst = (global float*)((global char*)dst + offsetd); - - long ksize = OW * (KH > 1 ? KW : 1); - long kx = i / ksize; - long kd = kx * ksize; - long ky = (i - kd) / OW; - long ix = i % OW; - - long oh = get_group_id(1); - long batch = get_group_id(2) / IC; - long ic = get_group_id(2) % IC; - - long iiw = ix * s0 + kx * d0 - p0; - long iih = oh * s1 + ky * d1 - p1; - - long offset_dst = - ((batch * OH + oh) * OW + ix) * CHW + - (ic * (KW * KH) + ky * KW + kx); - - if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { - dst[offset_dst] = 0.0f; - } else { - long offset_src = ic * delta_offset + batch * batch_offset; - dst[offset_dst] = src1[offset_src + iih * IW + iiw]; - } -} - -kernel void kernel_im2col_f16( - global float * src1, - ulong offset1, - global half * dst, - ulong offsetd, - ulong batch_offset, - ulong delta_offset, - long IW, - long IH, - long IC, - long OW, - long OH, - long KW, - long KH, - long pelements, - long CHW, - int s0, - int s1, - int p0, - int p1, - int d0, - int d1 -) { - long i = get_global_id(0); - - if (i >= pelements) { - return; - } - - src1 = (global float*)((global char*)src1 + offset1); - dst = (global half*)((global char*)dst + offsetd); - - long ksize = OW * (KH > 1 ? KW : 1); - long kx = i / ksize; - long kd = kx * ksize; - long ky = (i - kd) / OW; - long ix = i % OW; - - long oh = get_group_id(1); - long batch = get_group_id(2) / IC; - long ic = get_group_id(2) % IC; - - long iiw = ix * s0 + kx * d0 - p0; - long iih = oh * s1 + ky * d1 - p1; - - long offset_dst = - ((batch * OH + oh) * OW + ix) * CHW + - (ic * (KW * KH) + ky * KW + kx); - - if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { - dst[offset_dst] = 0.0f; - } else { - long offset_src = ic * delta_offset + batch * batch_offset; - dst[offset_dst] = src1[offset_src + iih * IW + iiw]; - } -} diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl_mm.cl b/ggml/src/ggml-opencl/kernels/ggml-opencl_mm.cl deleted file mode 100644 index e19e9a2f436..00000000000 --- a/ggml/src/ggml-opencl/kernels/ggml-opencl_mm.cl +++ /dev/null @@ -1,1225 +0,0 @@ -//------------------------------------------------------------------------------ -// This file is contains additional mulmat kernels -// (and potentially other kernels). -//------------------------------------------------------------------------------ -#ifdef cl_khr_fp16 -#pragma OPENCL EXTENSION cl_khr_fp16 : enable -#elif defined(cl_amd_fp16) -#pragma OPENCL EXTENSION cl_amd_fp16 : enable -#else -#error "Half precision floating point not supportedby OpenCL implementation on your device." -#endif - -#ifdef cl_khr_subgroups -#pragma OPENCL EXTENSION cl_khr_subgroups : enable -#elif defined(cl_intel_subgroups) -#pragma OPENCL EXTENSION cl_intel_subgroups : enable -#else -#error "Subgroup not supported on your device." -#endif - -#ifdef cl_intel_required_subgroup_size -// Always use subgroup size of 32 on Intel. 
-#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable -#define INTEL_GPU 1 -#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) -#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) -#elif defined(cl_qcom_reqd_sub_group_size) -// Always use subgroups size of 64 on Adreno. -#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable -#define ADRENO_GPU 1 -#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) -#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) -#else -// TODO: do not know how to choose subgroup size on other GPUs. -#error "Selecting subgroup size is not supported on your device." -#endif - -#define QK4_0 32 -#define QR4_0 2 -#define QK4_1 32 -#define QR4_1 2 -#define QK5_0 32 -#define QR5_0 2 -#define QK5_1 32 -#define QR5_1 2 -#define QK8_0 32 -#define QR8_0 1 -#define QK_K 256 -#define K_QUANTS_PER_ITERATION 2 - -typedef char int8_t; -typedef uchar uint8_t; -typedef short int16_t; -typedef ushort uint16_t; -typedef int int32_t; -typedef uint uint32_t; - -//------------------------------------------------------------------------------ -// block_q4_0 -//------------------------------------------------------------------------------ -struct block_q4_0 -{ - half d; - uint8_t qs[QK4_0 / 2]; -}; - -//------------------------------------------------------------------------------ -// block_q6_K -//------------------------------------------------------------------------------ -// 6-bit quantization -// weight is represented as x = a * q -// 16 blocks of 16 elements each -// Effectively 6.5625 bits per weight -typedef struct { - uint8_t ql[QK_K/2]; // quants, lower 4 bits - uint8_t qh[QK_K/4]; // quants, upper 2 bits - int8_t scales[QK_K/16]; // scales, quantized with 8 bits - half d; // super-block scale -} block_q6_K; - -//------------------------------------------------------------------------------ -// These are the variant for matmatmul, based on the matvecmul kernel with -// flattened block_q4_0. -//------------------------------------------------------------------------------ - -// Common dot prod. -inline float mm_block_q_4_0_dot_y_flat( - global uchar * x, - global half * dh, - float sumy, - float16 yl, - int il -) { - float d = *dh; - global ushort * qs = ((global ushort *)x + il/2); - float acc = 0.f; - - acc += yl.s0 * (qs[0] & 0x000F); - acc += yl.s1 * (qs[0] & 0x0F00); - acc += yl.s8 * (qs[0] & 0x00F0); - acc += yl.s9 * (qs[0] & 0xF000); - - acc += yl.s2 * (qs[1] & 0x000F); - acc += yl.s3 * (qs[1] & 0x0F00); - acc += yl.sa * (qs[1] & 0x00F0); - acc += yl.sb * (qs[1] & 0xF000); - - acc += yl.s4 * (qs[2] & 0x000F); - acc += yl.s5 * (qs[2] & 0x0F00); - acc += yl.sc * (qs[2] & 0x00F0); - acc += yl.sd * (qs[2] & 0xF000); - - acc += yl.s6 * (qs[3] & 0x000F); - acc += yl.s7 * (qs[3] & 0x0F00); - acc += yl.se * (qs[3] & 0x00F0); - acc += yl.sf * (qs[3] & 0xF000); - - return d * (sumy * -8.f + acc); -} - -#undef N_DST -#undef N_SIMDGROUP -#undef N_SIMDWIDTH - -#ifdef INTEL_GPU -#define N_DST 8 // each SIMD group works on 8 rows (in weights matrix) -#define N_SIMDGROUP 1 // number of SIMD groups in a thread group -#define N_SIMDWIDTH 16 // assuming SIMD group size is 16 -#elif defined (ADRENO_GPU) -#define N_DST 8 -#define N_SIMDGROUP 1 -#define N_SIMDWIDTH 64 -#endif -// -// This variant performs 1d blocking with 8x output. -// Eeach simdgroup outputs 8 values on `n0` dim (row in the output matrix). 
-// -inline void mul_mat_q_n_f32_1d_8x_flat( - global uchar * src0_q, - global half * src0_d, - global float * src1, - global float * dst, - int ne00, - int ne01, - int ne02, - int ne10, - int ne12, - int ne0, - int ne1, - int r2, - int r3 -) { - const int nb = ne00/QK4_0; - - int r0 = get_group_id(0); - int r1 = get_group_id(1); - int im = get_group_id(2); - - // (r0 * N_SIMDGROUP + get_sub_group_id()) is the linear global id of - // a SIMD group in the grid. Each SIMD group produces N_DST values in the - // result, hence uses nb blocks, i.e., the offset becomes first_row*nb. - // Currently with llama2 7B, im is always 0. - // TODO: how to handle im/gqa*(nb*ne0)? - int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; - - int i12 = im%ne12; - int i13 = im/ne12; - - // The number of scales is the same as the number of blocks. - ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - // Each block contains QK4_0/2 uchars, hence offset for qs is as follows. - ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_0/2; - - global uchar * x = (global uchar *) src0_q + offset0_q; - global half * d = (global half *) src0_d + offset0_d; - global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; - - float16 yl; - float8 sumf = (float8)(0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f); - - int ix = get_sub_group_local_id()/2; - int il = 8*(get_sub_group_local_id()%2); - - global float * yb = y + ix*QK4_0 + il; - - for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { - float sumy = 0.f; - - sumy += yb[0]; - sumy += yb[1]; - sumy += yb[2]; - sumy += yb[3]; - sumy += yb[4]; - sumy += yb[5]; - sumy += yb[6]; - sumy += yb[7]; - - sumy += yb[16]; - sumy += yb[17]; - sumy += yb[18]; - sumy += yb[19]; - sumy += yb[20]; - sumy += yb[21]; - sumy += yb[22]; - sumy += yb[23]; - - yl.s0 = yb[0]; - yl.s1 = yb[1]/256.f; - - yl.s2 = yb[2]; - yl.s3 = yb[3]/256.f; - - yl.s4 = yb[4]; - yl.s5 = yb[5]/256.f; - - yl.s6 = yb[6]; - yl.s7 = yb[7]/256.f; - - yl.s8 = yb[16]/16.f; - yl.s9 = yb[17]/4096.f; - - yl.sa = yb[18]/16.f; - yl.sb = yb[19]/4096.f; - - yl.sc = yb[20]/16.f; - yl.sd = yb[21]/4096.f; - - yl.se = yb[22]/16.f; - yl.sf = yb[23]/4096.f; - - sumf.s0 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 0*nb*QK4_0/2, d + ib + 0*nb, sumy, yl, il); - sumf.s1 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 1*nb*QK4_0/2, d + ib + 1*nb, sumy, yl, il); - sumf.s2 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 2*nb*QK4_0/2, d + ib + 2*nb, sumy, yl, il); - sumf.s3 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 3*nb*QK4_0/2, d + ib + 3*nb, sumy, yl, il); - - sumf.s4 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 4*nb*QK4_0/2, d + ib + 4*nb, sumy, yl, il); - sumf.s5 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 5*nb*QK4_0/2, d + ib + 5*nb, sumy, yl, il); - sumf.s6 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 6*nb*QK4_0/2, d + ib + 6*nb, sumy, yl, il); - sumf.s7 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 7*nb*QK4_0/2, d + ib + 7*nb, sumy, yl, il); - - yb += QK4_0 * (N_SIMDWIDTH/2); - } - - float8 tot = (float8)( - sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1), - sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3), - sub_group_reduce_add(sumf.s4), sub_group_reduce_add(sumf.s5), - sub_group_reduce_add(sumf.s6), sub_group_reduce_add(sumf.s7) - ); - - if (get_sub_group_local_id() == 0) { - if (first_row + 0 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; - } - if (first_row + 1 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = 
tot.s1; - } - if (first_row + 2 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; - } - if (first_row + 3 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; - } - - if (first_row + 4 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 4] = tot.s4; - } - if (first_row + 5 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 5] = tot.s5; - } - if (first_row + 6 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 6] = tot.s6; - } - if (first_row + 7 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 7] = tot.s7; - } - } -} - -#ifdef INTEL_GPU -REQD_SUBGROUP_SIZE_16 -#elif defined (ADRENO_GPU) -REQD_SUBGROUP_SIZE_64 -#endif -kernel void kernel_mul_mat_q4_0_f32_1d_8x_flat( - global uchar * src0_q, - global half * src0_d, - global float * src1, - ulong offset1, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne10, - int ne12, - int ne0, - int ne1, - int r2, - int r3 -) { - src1 = (global float*)((global char*)src1 + offset1); - dst = (global float*)((global char*)dst + offsetd); - - mul_mat_q_n_f32_1d_8x_flat(src0_q, src0_d, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); -} - -#undef N_DST -#undef N_SIMDGROUP -#undef N_SIMDWIDTH - -#ifdef INTEL_GPU -#define N_DST 16 // each SIMD group works on 8 rows (in weights matrix) -#define N_SIMDGROUP 1 // number of SIMD groups in a thread group -#define N_SIMDWIDTH 16 // assuming SIMD group size is 16 -#elif defined (ADRENO_GPU) -#define N_DST 16 -#define N_SIMDGROUP 1 -#define N_SIMDWIDTH 64 -#endif -// -// This variant performs 1d blocking with 16x output. -// Eeach simdgroup outputs 16 values on `n0` dim (row in the output matrix). -// -inline void mul_mat_q_n_f32_1d_16x_flat( - global uchar * src0_q, - global half * src0_d, - global float * src1, - global float * dst, - int ne00, - int ne01, - int ne02, - int ne10, - int ne12, - int ne0, - int ne1, - int r2, - int r3 -) { - const int nb = ne00/QK4_0; - - int r0 = get_group_id(0); - int r1 = get_group_id(1); - int im = get_group_id(2); - - // (r0 * N_SIMDGROUP + get_sub_group_id()) is the linear global id of - // a SIMD group in the grid. Each SIMD group produces N_DST values in the - // result, hence uses nb blocks, i.e., the offset becomes first_row*nb. - // Currently with llama2 7B, im is always 0. - // TODO: how to handle im/gqa*(nb*ne0)? - int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; - - int i12 = im%ne12; - int i13 = im/ne12; - - // The number of scales is the same as the number of blocks. - ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - // Each block contains QK4_0/2 uchars, hence offset for qs is as follows. 
- ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_0/2; - - global uchar * x = (global uchar *) src0_q + offset0_q; - global half * d = (global half *) src0_d + offset0_d; - global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; - - float16 yl; - float16 sumf = (float16)(0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, - 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f); - - int ix = get_sub_group_local_id()/2; - int il = 8*(get_sub_group_local_id()%2); - - global float * yb = y + ix*QK4_0 + il; - - for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { - float sumy = 0.f; - - sumy += yb[0]; - sumy += yb[1]; - sumy += yb[2]; - sumy += yb[3]; - sumy += yb[4]; - sumy += yb[5]; - sumy += yb[6]; - sumy += yb[7]; - - sumy += yb[16]; - sumy += yb[17]; - sumy += yb[18]; - sumy += yb[19]; - sumy += yb[20]; - sumy += yb[21]; - sumy += yb[22]; - sumy += yb[23]; - - yl.s0 = yb[0]; - yl.s1 = yb[1]/256.f; - - yl.s2 = yb[2]; - yl.s3 = yb[3]/256.f; - - yl.s4 = yb[4]; - yl.s5 = yb[5]/256.f; - - yl.s6 = yb[6]; - yl.s7 = yb[7]/256.f; - - yl.s8 = yb[16]/16.f; - yl.s9 = yb[17]/4096.f; - - yl.sa = yb[18]/16.f; - yl.sb = yb[19]/4096.f; - - yl.sc = yb[20]/16.f; - yl.sd = yb[21]/4096.f; - - yl.se = yb[22]/16.f; - yl.sf = yb[23]/4096.f; - - sumf.s0 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 0*nb*QK4_0/2, d + ib + 0*nb, sumy, yl, il); - sumf.s1 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 1*nb*QK4_0/2, d + ib + 1*nb, sumy, yl, il); - sumf.s2 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 2*nb*QK4_0/2, d + ib + 2*nb, sumy, yl, il); - sumf.s3 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 3*nb*QK4_0/2, d + ib + 3*nb, sumy, yl, il); - - sumf.s4 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 4*nb*QK4_0/2, d + ib + 4*nb, sumy, yl, il); - sumf.s5 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 5*nb*QK4_0/2, d + ib + 5*nb, sumy, yl, il); - sumf.s6 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 6*nb*QK4_0/2, d + ib + 6*nb, sumy, yl, il); - sumf.s7 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 7*nb*QK4_0/2, d + ib + 7*nb, sumy, yl, il); - - sumf.s8 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 8*nb*QK4_0/2, d + ib + 8*nb, sumy, yl, il); - sumf.s9 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 9*nb*QK4_0/2, d + ib + 9*nb, sumy, yl, il); - sumf.sa += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 10*nb*QK4_0/2, d + ib + 10*nb, sumy, yl, il); - sumf.sb += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 11*nb*QK4_0/2, d + ib + 11*nb, sumy, yl, il); - - sumf.sc += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 12*nb*QK4_0/2, d + ib + 12*nb, sumy, yl, il); - sumf.sd += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 13*nb*QK4_0/2, d + ib + 13*nb, sumy, yl, il); - sumf.se += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 14*nb*QK4_0/2, d + ib + 14*nb, sumy, yl, il); - sumf.sf += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 15*nb*QK4_0/2, d + ib + 15*nb, sumy, yl, il); - - yb += QK4_0 * (N_SIMDWIDTH/2); - } - - float16 tot = (float16)( - sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1), - sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3), - sub_group_reduce_add(sumf.s4), sub_group_reduce_add(sumf.s5), - sub_group_reduce_add(sumf.s6), sub_group_reduce_add(sumf.s7), - - sub_group_reduce_add(sumf.s8), sub_group_reduce_add(sumf.s9), - sub_group_reduce_add(sumf.sa), sub_group_reduce_add(sumf.sb), - sub_group_reduce_add(sumf.sc), sub_group_reduce_add(sumf.sd), - sub_group_reduce_add(sumf.se), sub_group_reduce_add(sumf.sf) - ); - - if (get_sub_group_local_id() == 0) { - if (first_row + 0 < ne01) { - dst[r1*ne0 + 
im*ne0*ne1 + first_row + 0] = tot.s0; - } - if (first_row + 1 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; - } - if (first_row + 2 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; - } - if (first_row + 3 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; - } - - if (first_row + 4 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 4] = tot.s4; - } - if (first_row + 5 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 5] = tot.s5; - } - if (first_row + 6 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 6] = tot.s6; - } - if (first_row + 7 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 7] = tot.s7; - } - - if (first_row + 8 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 8] = tot.s8; - } - if (first_row + 9 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 9] = tot.s9; - } - if (first_row + 10 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 10] = tot.sa; - } - if (first_row + 11 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 11] = tot.sb; - } - - if (first_row + 12 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 12] = tot.sc; - } - if (first_row + 13 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 13] = tot.sd; - } - if (first_row + 14 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 14] = tot.se; - } - if (first_row + 15 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 15] = tot.sf; - } - } -} - -#ifdef INTEL_GPU -REQD_SUBGROUP_SIZE_16 -#elif defined (ADRENO_GPU) -REQD_SUBGROUP_SIZE_64 -#endif -kernel void kernel_mul_mat_q4_0_f32_1d_16x_flat( - global uchar * src0_q, - global half * src0_d, - global float * src1, - ulong offset1, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne10, - int ne12, - int ne0, - int ne1, - int r2, - int r3 -) { - src1 = (global float*)((global char*)src1 + offset1); - dst = (global float*)((global char*)dst + offsetd); - - mul_mat_q_n_f32_1d_16x_flat(src0_q, src0_d, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); -} - -//------------------------------------------------------------------------------ -// kernel_mul_mat_q4_0_f32_flat_v0 -//------------------------------------------------------------------------------ -inline float block_q_4_0_dot_y_flat_v2( - half x, - half d, - float sumy, - float4 yl -) { - uchar2 q = as_uchar2(x); - float acc = 0.0f; - - acc += (q.s0 & 0x0F) * yl.s0; - acc += (q.s1 & 0x0F) * yl.s1; - - acc += (q.s0 & 0xF0) * yl.s2; - acc += (q.s1 & 0xF0) * yl.s3; - - return d * (sumy * -8.f + acc);; -} - -inline float block_q_4_0_dot_y_flat_v4( - float x, - half d, - float sumy, - float8 yl -) { - uchar4 q = as_uchar4(x); - float acc = 0.0f; - - acc += (q.s0 & 0x0F) * yl.s0; - acc += (q.s1 & 0x0F) * yl.s1; - acc += (q.s2 & 0x0F) * yl.s2; - acc += (q.s3 & 0x0F) * yl.s3; - - acc += (q.s0 & 0xF0) * yl.s4; - acc += (q.s1 & 0xF0) * yl.s5; - acc += (q.s2 & 0xF0) * yl.s6; - acc += (q.s3 & 0xF0) * yl.s7; - - return d * (sumy * -8.f + acc);; -} - -inline float block_q_4_0_dot_y_flat_v8( - float2 x, - half d, - float sumy, - float16 yl -) { - uchar8 q = as_uchar8(x); - float acc = 0.0f; - - acc += (q.s0 & 0x0F) * yl.s0; - acc += (q.s1 & 0x0F) * yl.s1; - acc += (q.s2 & 0x0F) * yl.s2; - acc += (q.s3 & 0x0F) * yl.s3; - acc += (q.s4 & 0x0F) * yl.s4; - acc += (q.s5 & 0x0F) * yl.s5; - acc += (q.s6 & 0x0F) * yl.s6; - acc += (q.s7 & 0x0F) * yl.s7; - - acc += (q.s0 & 0xF0) * yl.s8; - acc += (q.s1 & 0xF0) * yl.s9; - acc += (q.s2 & 0xF0) * yl.sa; - acc += (q.s3 & 0xF0) * yl.sb; - acc += (q.s4 & 0xF0) * yl.sc; - acc += (q.s5 & 0xF0) * yl.sd; - acc += (q.s6 & 0xF0) * 
yl.se; - acc += (q.s7 & 0xF0) * yl.sf; - - return d * (sumy * -8.f + acc);; -} - -#undef N_DST -#undef N_SIMDGROUP -#undef N_SIMDWIDTH - -#ifdef INTEL_GPU -#define THREADS_PER_BLK 4 // Number of threads per block, or each thread process 1/THREADS_PER_BLK of a block -#define N_DST 4 -#define N_SIMDGROUP 1 -#define N_SIMDWIDTH 16 -#elif defined (ADRENO_GPU) -#define THREADS_PER_BLK 4 -#define N_DST 4 -#define N_SIMDGROUP 1 -#define N_SIMDWIDTH 64 -#endif - -#if THREADS_PER_BLK == 2 // Each thread processes 1/2 block -# define ACT_TY float16 -# define Q_BLK_LD_TY float2 -# define block_q_4_0_dot_y_flat block_q_4_0_dot_y_flat_v8 -#elif THREADS_PER_BLK == 4 // Each thread processes 1/4 block -# define ACT_TY float8 -# define Q_BLK_LD_TY float -# define block_q_4_0_dot_y_flat block_q_4_0_dot_y_flat_v4 -#elif THREADS_PER_BLK == 8 // Each thread processes 1/8 block -# define ACT_TY float4 -# define Q_BLK_LD_TY half -# define block_q_4_0_dot_y_flat block_q_4_0_dot_y_flat_v2 -#endif - -#define BTYES_PER_THREAD_IN_BLK (QK4_0/2/THREADS_PER_BLK) - -#if N_DST == 2 -# define SUM_TY float2 -#elif N_DST == 4 -# define SUM_TY float4 -#elif N_DST == 8 -# define SUM_TY float8 -#elif N_DST == 16 -# define SUM_TY float16 -#endif - -#ifdef INTEL_GPU -REQD_SUBGROUP_SIZE_16 -#elif defined (ADRENO_GPU) -REQD_SUBGROUP_SIZE_64 -#endif -kernel void kernel_mul_mat_q4_0_f32_flat_v0( - global uchar * src0_q, - global half * src0_d, - global float * src1, - ulong offset1, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne10, - int ne12, - int ne0, - int ne1, - int r2, - int r3 -) { - src1 = (global float*)((global char*)src1 + offset1); - dst = (global float*)((global char*)dst + offsetd); - - const int nb = ne00/QK4_0; - - int r0 = get_group_id(0); - int r1 = get_group_id(1); - int im = get_group_id(2); - - int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; - - int i12 = im%ne12; - int i13 = im/ne12; - - // The number of scales is the same as the number of blocks. - ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - // Each block contains QK4_0/2 uchars, hence offset for qs is as follows. 
- ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_0/2; - - global uchar * x = (global uchar *) src0_q + offset0_q; - global half * d = (global half *) src0_d + offset0_d; - global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; - - int ix = get_sub_group_local_id()/THREADS_PER_BLK; - int il = get_sub_group_local_id()%THREADS_PER_BLK; - - global float * yb = y + ix*QK4_0 + BTYES_PER_THREAD_IN_BLK*il; - - // Registers for caching activation - ACT_TY yl = 0.f; - - // Registers for caching quants - Q_BLK_LD_TY q_blk_0 = 0, q_blk_1 = 0; -#if N_DST == 4 || N_DST == 8 || N_DST == 16 - Q_BLK_LD_TY q_blk_2 = 0, q_blk_3 = 0; -#endif -#if N_DST == 8 || N_DST == 16 - Q_BLK_LD_TY q_blk_4 = 0, q_blk_5 = 0, q_blk_6 = 0, q_blk_7 = 0; -#endif - - // Partial sum - SUM_TY sumf = 0.f; - - for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/THREADS_PER_BLK) { - float sumy = 0.f; - - q_blk_0 = *(global Q_BLK_LD_TY*)(x + ib*QK4_0/2 + BTYES_PER_THREAD_IN_BLK*il + 0*nb*QK4_0/2); - q_blk_1 = *(global Q_BLK_LD_TY*)(x + ib*QK4_0/2 + BTYES_PER_THREAD_IN_BLK*il + 1*nb*QK4_0/2); -#if N_DST == 4 || N_DST == 8 || N_DST == 16 - q_blk_2 = *(global Q_BLK_LD_TY*)(x + ib*QK4_0/2 + BTYES_PER_THREAD_IN_BLK*il + 2*nb*QK4_0/2); - q_blk_3 = *(global Q_BLK_LD_TY*)(x + ib*QK4_0/2 + BTYES_PER_THREAD_IN_BLK*il + 3*nb*QK4_0/2); -#endif -#if N_DST == 8 || N_DST == 16 - q_blk_4 = (*(global Q_BLK_LD_TY*)(x + ib*QK4_0/2 + BTYES_PER_THREAD_IN_BLK*il + 4*nb*QK4_0/2)); - q_blk_5 = (*(global Q_BLK_LD_TY*)(x + ib*QK4_0/2 + BTYES_PER_THREAD_IN_BLK*il + 5*nb*QK4_0/2)); - q_blk_6 = (*(global Q_BLK_LD_TY*)(x + ib*QK4_0/2 + BTYES_PER_THREAD_IN_BLK*il + 6*nb*QK4_0/2)); - q_blk_7 = (*(global Q_BLK_LD_TY*)(x + ib*QK4_0/2 + BTYES_PER_THREAD_IN_BLK*il + 7*nb*QK4_0/2)); -#endif - - // Load activation -#if THREADS_PER_BLK == 2 // Each thread processes 1/2 block - yl.s01234567 = *(global float8 *)(yb); - yl.s89abcdef = *(global float8 *)(yb + 16); - - sumy += yl.s0; - sumy += yl.s1; - sumy += yl.s2; - sumy += yl.s3; - sumy += yl.s4; - sumy += yl.s5; - sumy += yl.s6; - sumy += yl.s7; - sumy += yl.s8; yl.s8 /= 16.f; - sumy += yl.s9; yl.s9 /= 16.f; - sumy += yl.sa; yl.sa /= 16.f; - sumy += yl.sb; yl.sb /= 16.f; - sumy += yl.sc; yl.sc /= 16.f; - sumy += yl.sd; yl.sd /= 16.f; - sumy += yl.se; yl.se /= 16.f; - sumy += yl.sf; yl.sf /= 16.f; -#elif THREADS_PER_BLK == 4 // Each thread processes 1/4 block - yl.s0123 = *(global float4 *)(yb); - yl.s4567 = *(global float4 *)(yb + 16); - - sumy += yl.s0; - sumy += yl.s1; - sumy += yl.s2; - sumy += yl.s3; - sumy += yl.s4; yl.s4 /= 16.f; - sumy += yl.s5; yl.s5 /= 16.f; - sumy += yl.s6; yl.s6 /= 16.f; - sumy += yl.s7; yl.s7 /= 16.f; -#elif THREADS_PER_BLK == 8 // Each thread processes 1/8 block - yl.s01 = *(global float2 *)(yb); - yl.s23 = *(global float2 *)(yb + 16); - - sumy += yl.s0; - sumy += yl.s1; - sumy += yl.s2; yl.s2 /= 16.f; - sumy += yl.s3; yl.s3 /= 16.f; -#endif - - sumf.s0 += block_q_4_0_dot_y_flat(q_blk_0, *(d + ib + 0*nb), sumy, yl); - sumf.s1 += block_q_4_0_dot_y_flat(q_blk_1, *(d + ib + 1*nb), sumy, yl); -#if N_DST == 4 || N_DST == 8 || N_DST == 16 - sumf.s2 += block_q_4_0_dot_y_flat(q_blk_2, *(d + ib + 2*nb), sumy, yl); - sumf.s3 += block_q_4_0_dot_y_flat(q_blk_3, *(d + ib + 3*nb), sumy, yl); -#endif -#if N_DST == 8 || N_DST == 16 - sumf.s4 += block_q_4_0_dot_y_flat(q_blk_4, *(d + ib + 4*nb), sumy, yl); - sumf.s5 += block_q_4_0_dot_y_flat(q_blk_5, *(d + ib + 5*nb), sumy, yl); - sumf.s6 += block_q_4_0_dot_y_flat(q_blk_6, *(d + ib + 6*nb), sumy, yl); - sumf.s7 += 
block_q_4_0_dot_y_flat(q_blk_7, *(d + ib + 7*nb), sumy, yl); -#endif - - yb += QK4_0 * (N_SIMDWIDTH/THREADS_PER_BLK); - } - - SUM_TY tot = (SUM_TY)( - sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1) -#if N_DST == 4 || N_DST == 8 || N_DST == 16 - , sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3) -#endif -#if N_DST == 8 || N_DST == 16 - , sub_group_reduce_add(sumf.s4), sub_group_reduce_add(sumf.s5) - , sub_group_reduce_add(sumf.s6), sub_group_reduce_add(sumf.s7) -#endif - ); - - if (get_sub_group_local_id() == 0) { - if (first_row + 0 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; - } - if (first_row + 1 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; - } -#if N_DST == 4 || N_DST == 8 || N_DST == 16 - if (first_row + 2 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; - } - if (first_row + 3 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; - } -#endif -#if N_DST == 8 || N_DST == 16 - if (first_row + 4 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 4] = tot.s4; - } - if (first_row + 5 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 5] = tot.s5; - } - if (first_row + 6 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 6] = tot.s6; - } - if (first_row + 7 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 7] = tot.s7; - } -#endif - } -} - -//------------------------------------------------------------------------------ -// Using image1d_buffer_t - -#if defined(cl_qcom_subgroup_shuffle) -#pragma OPENCL EXTENSION cl_qcom_subgroup_shuffle : enable -float qcom_sub_group_reduce_add(float sum) { - sum += qcom_sub_group_shuffle_down(sum, 32, CLK_SUB_GROUP_SHUFFLE_WIDTH_WAVE_SIZE_QCOM, 0.f); - sum += qcom_sub_group_shuffle_down(sum, 16, CLK_SUB_GROUP_SHUFFLE_WIDTH_WAVE_SIZE_QCOM, 0.f); - sum += qcom_sub_group_shuffle_down(sum, 8, CLK_SUB_GROUP_SHUFFLE_WIDTH_WAVE_SIZE_QCOM, 0.f); - sum += qcom_sub_group_shuffle_down(sum, 4, CLK_SUB_GROUP_SHUFFLE_WIDTH_WAVE_SIZE_QCOM, 0.f); - sum += qcom_sub_group_shuffle_down(sum, 2, CLK_SUB_GROUP_SHUFFLE_WIDTH_WAVE_SIZE_QCOM, 0.f); - sum += qcom_sub_group_shuffle_down(sum, 1, CLK_SUB_GROUP_SHUFFLE_WIDTH_WAVE_SIZE_QCOM, 0.f); - return sum; -} -#define sub_group_reduce_add qcom_sub_group_reduce_add -#else -#define sub_group_reduce_add sub_group_reduce_add -#endif - -#undef THREADS_PER_BLK -#undef N_DST -#undef N_SIMDGROUP -#undef N_SIMDWIDTH - -#ifdef INTEL_GPU -#define THREADS_PER_BLK 4 // Number of threads per block, or each thread process 1/THREADS_PER_BLK of a block -#define N_DST 4 -#define N_SIMDGROUP 1 -#define N_SIMDWIDTH 16 -#elif defined (ADRENO_GPU) -#define THREADS_PER_BLK 4 -#define N_DST 4 -#define N_SIMDGROUP 1 -#define N_SIMDWIDTH 64 -#endif - -#if THREADS_PER_BLK == 2 // Each thread processes 1/2 block -# define ACT_TY float16 -# define Q_BLK_LD_TY float2 -# define EXTRACT_BLK_DATA(tmp, part) *((float2*)&tmp + part) -# define block_q_4_0_dot_y_flat block_q_4_0_dot_y_flat_v8 -#elif THREADS_PER_BLK == 4 // Each thread processes 1/4 block -# define ACT_TY float8 -# define Q_BLK_LD_TY float -# define EXTRACT_BLK_DATA(tmp, part) *((float*)&tmp + part) -# define block_q_4_0_dot_y_flat block_q_4_0_dot_y_flat_v4 -#elif THREADS_PER_BLK == 8 // Each thread processes 1/8 block -# define ACT_TY float4 -# define Q_BLK_LD_TY half -# define EXTRACT_BLK_DATA(tmp, part) *((half*)&tmp + part) -# define block_q_4_0_dot_y_flat block_q_4_0_dot_y_flat_v2 -#endif - -#define BTYES_PER_THREAD_IN_BLK (QK4_0/2/THREADS_PER_BLK) - -#if N_DST == 2 -# define SUM_TY float2 -#elif N_DST == 4 -# 
define SUM_TY float4 -#elif N_DST == 8 -# define SUM_TY float8 -#elif N_DST == 16 -# define SUM_TY float16 -#endif - -#ifdef INTEL_GPU -REQD_SUBGROUP_SIZE_16 -#elif defined (ADRENO_GPU) -REQD_SUBGROUP_SIZE_64 -#endif -kernel void kernel_mul_mat_q4_0_f32_flat_img_v0( - read_only image1d_buffer_t src0_q, - read_only image1d_buffer_t src0_d, - global float * src1, - ulong offset1, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne10, - int ne12, - int ne0, - int ne1, - int r2, - int r3 -) { - src1 = (global float*)((global char*)src1 + offset1); - dst = (global float*)((global char*)dst + offsetd); - - const int nb = ne00/QK4_0; - - int r0 = get_group_id(0); - int r1 = get_group_id(1); - int im = get_group_id(2); - - int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; - - int i12 = im%ne12; - int i13 = im/ne12; - - // The number of scales is the same as the number of blocks. - ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - // Each block contains QK4_0/2 uchars, hence offset for qs is as follows. - ulong offset0_q = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - - global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; - - int ix = get_sub_group_local_id()/THREADS_PER_BLK; - int il = get_sub_group_local_id()%THREADS_PER_BLK; - - global float * yb = y + ix*QK4_0 + BTYES_PER_THREAD_IN_BLK*il; - - // Registers for caching activation - ACT_TY yl = 0.f; - - // Registers for caching quants - Q_BLK_LD_TY q_blk_0 = 0, q_blk_1 = 0; -#if N_DST == 4 || N_DST == 8 || N_DST == 16 - Q_BLK_LD_TY q_blk_2 = 0, q_blk_3 = 0; -#endif -#if N_DST == 8 || N_DST == 16 - Q_BLK_LD_TY q_blk_4 = 0, q_blk_5 = 0, q_blk_6 = 0, q_blk_7 = 0; -#endif - - // Partial sum - SUM_TY sumf = 0.f; - - for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/THREADS_PER_BLK) { - float sumy = 0.f;; - - float4 tmp; - tmp = read_imagef(src0_q, offset0_q + ib + 0*nb); - q_blk_0 = EXTRACT_BLK_DATA(tmp, il); - tmp = read_imagef(src0_q, offset0_q + ib + 1*nb); - q_blk_1 = EXTRACT_BLK_DATA(tmp, il); -#if N_DST == 4 || N_DST == 8 || N_DST == 16 - tmp = read_imagef(src0_q, offset0_q + ib + 2*nb); - q_blk_2 = EXTRACT_BLK_DATA(tmp, il); - tmp = read_imagef(src0_q, offset0_q + ib + 3*nb); - q_blk_3 = EXTRACT_BLK_DATA(tmp, il); -#endif -#if N_DST == 8 || N_DST == 16 - tmp = read_imagef(src0_q, offset0_q + ib + 4*nb); - q_blk_4 = EXTRACT_BLK_DATA(tmp, il); - tmp = read_imagef(src0_q, offset0_q + ib + 5*nb); - q_blk_5 = EXTRACT_BLK_DATA(tmp, il); - tmp = read_imagef(src0_q, offset0_q + ib + 6*nb); - q_blk_6 = EXTRACT_BLK_DATA(tmp, il); - tmp = read_imagef(src0_q, offset0_q + ib + 7*nb); - q_blk_7 = EXTRACT_BLK_DATA(tmp, il); -#endif - - // Load activation -#if THREADS_PER_BLK == 2 // Each thread processes 1/2 block - yl.s01234567 = *(global float8 *)(yb); - yl.s89abcdef = *(global float8 *)(yb + 16); - - sumy += yl.s0; - sumy += yl.s1; - sumy += yl.s2; - sumy += yl.s3; - sumy += yl.s4; - sumy += yl.s5; - sumy += yl.s6; - sumy += yl.s7; - sumy += yl.s8; yl.s8 /= 16.f; - sumy += yl.s9; yl.s9 /= 16.f; - sumy += yl.sa; yl.sa /= 16.f; - sumy += yl.sb; yl.sb /= 16.f; - sumy += yl.sc; yl.sc /= 16.f; - sumy += yl.sd; yl.sd /= 16.f; - sumy += yl.se; yl.se /= 16.f; - sumy += yl.sf; yl.sf /= 16.f; -#elif THREADS_PER_BLK == 4 // Each thread processes 1/4 block - yl.s0123 = *(global float4 *)(yb); - yl.s4567 = *(global float4 *)(yb + 16); - - sumy += yl.s0; - sumy += yl.s1; - sumy += yl.s2; - sumy += yl.s3; - sumy += yl.s4; yl.s4 /= 16.f; - sumy += yl.s5; yl.s5 /= 
16.f; - sumy += yl.s6; yl.s6 /= 16.f; - sumy += yl.s7; yl.s7 /= 16.f; -#elif THREADS_PER_BLK == 8 // Each thread processes 1/8 block - yl.s01 = *(global float2 *)(yb); - yl.s23 = *(global float2 *)(yb + 16); - - sumy += yl.s0; - sumy += yl.s1; - sumy += yl.s2; yl.s2 /= 16.f; - sumy += yl.s3; yl.s3 /= 16.f; -#endif - - sumf.s0 += block_q_4_0_dot_y_flat(q_blk_0, read_imageh(src0_d, offset0_d + ib + 0*nb).s0, sumy, yl); - sumf.s1 += block_q_4_0_dot_y_flat(q_blk_1, read_imageh(src0_d, offset0_d + ib + 1*nb).s0, sumy, yl); -#if N_DST == 4 || N_DST == 8 || N_DST == 16 - sumf.s2 += block_q_4_0_dot_y_flat(q_blk_2, read_imageh(src0_d, offset0_d + ib + 2*nb).s0, sumy, yl); - sumf.s3 += block_q_4_0_dot_y_flat(q_blk_3, read_imageh(src0_d, offset0_d + ib + 3*nb).s0, sumy, yl); -#endif -#if N_DST == 8 || N_DST == 16 - sumf.s4 += block_q_4_0_dot_y_flat(q_blk_4, read_imageh(src0_d, offset0_d + ib + 4*nb).s0, sumy, yl); - sumf.s5 += block_q_4_0_dot_y_flat(q_blk_5, read_imageh(src0_d, offset0_d + ib + 5*nb).s0, sumy, yl); - sumf.s6 += block_q_4_0_dot_y_flat(q_blk_6, read_imageh(src0_d, offset0_d + ib + 6*nb).s0, sumy, yl); - sumf.s7 += block_q_4_0_dot_y_flat(q_blk_7, read_imageh(src0_d, offset0_d + ib + 7*nb).s0, sumy, yl); -#endif - - yb += QK4_0 * (N_SIMDWIDTH/THREADS_PER_BLK); - } - - SUM_TY tot = (SUM_TY)( - sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1) -#if N_DST == 4 || N_DST == 8 || N_DST == 16 - , sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3) -#endif -#if N_DST == 8 || N_DST == 16 - , sub_group_reduce_add(sumf.s4), sub_group_reduce_add(sumf.s5) - , sub_group_reduce_add(sumf.s6), sub_group_reduce_add(sumf.s7) -#endif - ); - - if (get_sub_group_local_id() == 0) { - if (first_row + 0 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; - } - if (first_row + 1 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; - } -#if N_DST == 4 || N_DST == 8 || N_DST == 16 - if (first_row + 2 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; - } - if (first_row + 3 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; - } -#endif -#if N_DST == 8 || N_DST == 16 - if (first_row + 4 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 4] = tot.s4; - } - if (first_row + 5 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 5] = tot.s5; - } - if (first_row + 6 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 6] = tot.s6; - } - if (first_row + 7 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 7] = tot.s7; - } -#endif - } -} - -//------------------------------------------------------------------------------ -// kernel_mul_mv_q6_K_f32 -//------------------------------------------------------------------------------ - -#undef N_DST -#undef N_SIMDGROUP -#undef N_SIMDWIDTH - -#ifdef INTEL_GPU -#define N_DST 1 // number of rows each SIMD group works on -#define N_SIMDGROUP 2 // number of SIMD groups in a thread group -#define N_SIMDWIDTH 16 // SIMD group size -#elif defined (ADRENO_GPU) -#define N_DST 1 -#define N_SIMDGROUP 2 -#define N_SIMDWIDTH 64 -#endif - -#define BLOCK_STRIDE (N_SIMDWIDTH/16) // number of blocks each subgroup processes - -#ifdef INTEL_GPU -REQD_SUBGROUP_SIZE_16 -#elif defined (ADRENO_GPU) -REQD_SUBGROUP_SIZE_64 -#endif -kernel void kernel_mul_mv_q6_K_f32( - global void * src0, - ulong offset0, - global float * src1, - ulong offset1, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne10, - int ne12, - int ne0, - int ne1, - int r2, - int r3 -) { - src0 = (global void*)((global char*)src0 + offset0); - src1 = 
(global float*)((global char*)src1 + offset1); - dst = (global float*)((global char*)dst + offsetd); - - uchar kmask1 = 0x03; - uchar kmask2 = 0x0C; - uchar kmask3 = 0x30; - uchar kmask4 = 0xC0; - - int nb = ne00/QK_K; - - int r0 = get_group_id(0); - int r1 = get_group_id(1); - int im = get_group_id(2); - - int row = N_SIMDGROUP * r0 + get_sub_group_id(); - - int i12 = im%ne12; - int i13 = im/ne12; - - ulong offset_src0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - - global block_q6_K * x = (global block_q6_K *) src0 + row*nb + offset_src0; - global float * yy = (global float *) src1 + r1*ne10 + im*ne00*ne1; - - float sumf = 0; - - // For Q6_K quantization, 16 values forms a subblock, 16 subblock forms a - // block. Values in a subblock shares a scale that is quantized with 8 bits; - // the entire block shares a single floating point scale. - // For work distribution, each thread processes a subblock (16 weights), hence - // 16 threads process a (super) block -- a subgroup thus handles SIMDWIDTH/16 - // (super) blocks -- this is the block stride. - // The 16 threads that process a (super) block are split into 2 portions, each has - // 8 threads; each portion works on 8 subblocks. - // For subgroup of 16 threads, the entire subgroup works on a single (super) block - // before moving to the next (super) block. Thread0 - thread7 work on the - // first 8 subblocks; thread8 - thread15 works on the last 8 subblocks. - // Thread0 - thread3 work on subblocks 0, 2, 4, 6; thread4 - thread7 work on - // subblocks 1, 3, 5, 7. Each thread does not work on an entire subblock, but - // works on a total of 16 weight values. - int tid = get_sub_group_local_id()/BLOCK_STRIDE; // first block_stride groups have tid=0 - int ix = get_sub_group_local_id()%BLOCK_STRIDE; // first block is 0..block_stride-1 - int ip = tid/8; // first or second half of (super) block (0 or 1) - int il = tid%8; // each half has 8 parts, one per scale - int n = 4; // 4 scales at a time (and 4 sums) - int l0 = n*il; // offset into half-block, 0..28 - int is = 8*ip + l0/16; // 0, 1, 8, 9 - - int y_offset = 128*ip + l0; - int q_offset_l = 64*ip + l0; - int q_offset_h = 32*ip + l0; - - for (int i = ix; i < nb; i += BLOCK_STRIDE) { - - global uint8_t * q1 = x[i].ql + q_offset_l; - global uint8_t * q2 = q1 + QK_K/8; - global uint8_t * qh = x[i].qh + q_offset_h; - global int8_t * sc = x[i].scales + is; - - global float * y = yy + i * QK_K + y_offset; - - float dall = x[i].d; - - float4 sums = {0.f, 0.f, 0.f, 0.f}; - - sums.s0 += y[0+ 0] * ((float)((q1[0] & 0xF) | ((qh[0] & kmask1) << 4)) - 32.f); - sums.s1 += y[0+32] * ((float)((q2[0] & 0xF) | ((qh[0] & kmask2) << 2)) - 32.f); - sums.s2 += y[0+64] * ((float)((q1[0] >> 4) | ((qh[0] & kmask3) << 0)) - 32.f); - sums.s3 += y[0+96] * ((float)((q2[0] >> 4) | ((qh[0] & kmask4) >> 2)) - 32.f); - - sums.s0 += y[1+ 0] * ((float)((q1[1] & 0xF) | ((qh[1] & kmask1) << 4)) - 32.f); - sums.s1 += y[1+32] * ((float)((q2[1] & 0xF) | ((qh[1] & kmask2) << 2)) - 32.f); - sums.s2 += y[1+64] * ((float)((q1[1] >> 4) | ((qh[1] & kmask3) << 0)) - 32.f); - sums.s3 += y[1+96] * ((float)((q2[1] >> 4) | ((qh[1] & kmask4) >> 2)) - 32.f); - - sums.s0 += y[2+ 0] * ((float)((q1[2] & 0xF) | ((qh[2] & kmask1) << 4)) - 32.f); - sums.s1 += y[2+32] * ((float)((q2[2] & 0xF) | ((qh[2] & kmask2) << 2)) - 32.f); - sums.s2 += y[2+64] * ((float)((q1[2] >> 4) | ((qh[2] & kmask3) << 0)) - 32.f); - sums.s3 += y[2+96] * ((float)((q2[2] >> 4) | ((qh[2] & kmask4) >> 2)) - 32.f); - - sums.s0 += y[3+ 0] * ((float)((q1[3] & 0xF) | ((qh[3] & 
kmask1) << 4)) - 32.f); - sums.s1 += y[3+32] * ((float)((q2[3] & 0xF) | ((qh[3] & kmask2) << 2)) - 32.f); - sums.s2 += y[3+64] * ((float)((q1[3] >> 4) | ((qh[3] & kmask3) << 0)) - 32.f); - sums.s3 += y[3+96] * ((float)((q2[3] >> 4) | ((qh[3] & kmask4) >> 2)) - 32.f); - - sumf += dall * (sums.s0 * sc[0] + sums.s1 * sc[2] + sums.s2 * sc[4] + sums.s3 * sc[6]); - } - - float tot = sub_group_reduce_add(sumf); - if (get_sub_group_local_id() == 0) { - dst[r1*ne0 + im*ne0*ne1 + row] = tot; - } -} diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_16.cl b/ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_16.cl deleted file mode 100644 index cd4e0afbad2..00000000000 --- a/ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_16.cl +++ /dev/null @@ -1,26 +0,0 @@ -// 16-bit transpose, loading/storing a 4x4 tile of elements - -#pragma OPENCL EXTENSION cl_khr_fp16 : enable - -kernel void kernel_transpose_16( - __read_only image1d_buffer_t input, - __write_only image1d_buffer_t output, - const uint rows, - const uint cols -) { - - const int i = get_global_id(0); - const int j = get_global_id(1); - const int i_2 = i<<2; - const int j_2 = j<<2; - - half4 temp0 = read_imageh(input, (j_2+0)*cols+i); - half4 temp1 = read_imageh(input, (j_2+1)*cols+i); - half4 temp2 = read_imageh(input, (j_2+2)*cols+i); - half4 temp3 = read_imageh(input, (j_2+3)*cols+i); - - write_imageh(output, (i_2+0)*rows+j, (half4)(temp0.s0, temp1.s0, temp2.s0, temp3.s0)); - write_imageh(output, (i_2+1)*rows+j, (half4)(temp0.s1, temp1.s1, temp2.s1, temp3.s1)); - write_imageh(output, (i_2+2)*rows+j, (half4)(temp0.s2, temp1.s2, temp2.s2, temp3.s2)); - write_imageh(output, (i_2+3)*rows+j, (half4)(temp0.s3, temp1.s3, temp2.s3, temp3.s3)); -} diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_32.cl b/ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_32.cl deleted file mode 100644 index 914ec0193e7..00000000000 --- a/ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_32.cl +++ /dev/null @@ -1,25 +0,0 @@ -// 32-bit transpose, loading/storing a 4x4 tile of elements - -kernel void kernel_transpose_32( - __read_only image1d_buffer_t input, - __write_only image1d_buffer_t output, - const uint rows, - const uint cols -) { - - const int i = get_global_id(0); - const int j = get_global_id(1); - const int i_2 = i<<2; - const int j_2 = j<<2; - - float4 temp0 = read_imagef(input, (j_2+0)*cols+i); - float4 temp1 = read_imagef(input, (j_2+1)*cols+i); - float4 temp2 = read_imagef(input, (j_2+2)*cols+i); - float4 temp3 = read_imagef(input, (j_2+3)*cols+i); - - write_imagef(output, (i_2+0)*rows+j, (float4)(temp0.s0, temp1.s0, temp2.s0, temp3.s0)); - write_imagef(output, (i_2+1)*rows+j, (float4)(temp0.s1, temp1.s1, temp2.s1, temp3.s1)); - write_imagef(output, (i_2+2)*rows+j, (float4)(temp0.s2, temp1.s2, temp2.s2, temp3.s2)); - write_imagef(output, (i_2+3)*rows+j, (float4)(temp0.s3, temp1.s3, temp2.s3, temp3.s3)); - -} diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_32_16.cl b/ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_32_16.cl deleted file mode 100644 index d3bd1fabb76..00000000000 --- a/ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_32_16.cl +++ /dev/null @@ -1,35 +0,0 @@ -// 32-bit transpose, loading/storing a 4x4 tile of elements -// Only used for activations -// converts to FP16 -// also adds zero padding for non multiple of 8 prompt lengths -#pragma OPENCL EXTENSION cl_khr_fp16 : enable - -kernel void kernel_transpose_32_16(__read_only image1d_buffer_t input, __write_only image1d_buffer_t output, const 
uint rows, const uint cols, const uint padded_rows) { - - const int i = get_global_id(0); - const int j = get_global_id(1); - const int i_2 = i<<2; - const int j_2 = j<<2; - half4 temp0 = {0,0,0,0}; // initialize outputs to 0 - half4 temp1 = {0,0,0,0}; - half4 temp2 = {0,0,0,0}; - half4 temp3 = {0,0,0,0}; - - if((j_2+0)*cols+i*4+3 < rows*cols*16){ // only load from a valid location. Otherwise keep register data as 0 - temp0 = read_imageh(input, (j_2+0)*cols+i); - } - if((j_2+1)*cols+i*4+3 < rows*cols*16){ - temp1 = read_imageh(input, (j_2+1)*cols+i); - } - if((j_2+2)*cols+i*4+3 < rows*cols*16){ - temp2 = read_imageh(input, (j_2+2)*cols+i); - } - if((j_2+3)*cols+i*4+3 < rows*cols*16){ - temp3 = read_imageh(input, (j_2+3)*cols+i); - } - - write_imageh(output, (i_2+0)*padded_rows+j, (half4)(temp0.s0, temp1.s0, temp2.s0, temp3.s0)); // no conditionals for output, includes zero padding - write_imageh(output, (i_2+1)*padded_rows+j, (half4)(temp0.s1, temp1.s1, temp2.s1, temp3.s1)); - write_imageh(output, (i_2+2)*padded_rows+j, (half4)(temp0.s2, temp1.s2, temp2.s2, temp3.s2)); - write_imageh(output, (i_2+3)*padded_rows+j, (half4)(temp0.s3, temp1.s3, temp2.s3, temp3.s3)); -} diff --git a/ggml/src/ggml-opencl/kernels/im2col_f16.cl b/ggml/src/ggml-opencl/kernels/im2col_f16.cl new file mode 100644 index 00000000000..b84c8984653 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/im2col_f16.cl @@ -0,0 +1,57 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +kernel void kernel_im2col_f16( + global float * src1, + ulong offset1, + global half * dst, + ulong offsetd, + ulong batch_offset, + ulong delta_offset, + long IW, + long IH, + long IC, + long OW, + long OH, + long KW, + long KH, + long pelements, + long CHW, + int s0, + int s1, + int p0, + int p1, + int d0, + int d1 +) { + long i = get_global_id(0); + if (i >= pelements) { + return; + } + + src1 = (global float*)((global char*)src1 + offset1); + dst = (global half*)((global char*)dst + offsetd); + + long ksize = OW * (KH > 1 ? KW : 1); + long kx = i / ksize; + long kd = kx * ksize; + long ky = (i - kd) / OW; + long ix = i % OW; + + long oh = get_group_id(1); + long batch = get_group_id(2) / IC; + long ic = get_group_id(2) % IC; + + long iiw = ix * s0 + kx * d0 - p0; + long iih = oh * s1 + ky * d1 - p1; + + long offset_dst = + ((batch * OH + oh) * OW + ix) * CHW + + (ic * (KW * KH) + ky * KW + kx); + + if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { + dst[offset_dst] = 0.0f; + } else { + long offset_src = ic * delta_offset + batch * batch_offset; + dst[offset_dst] = src1[offset_src + iih * IW + iiw]; + } +} diff --git a/ggml/src/ggml-opencl/kernels/im2col_f32.cl b/ggml/src/ggml-opencl/kernels/im2col_f32.cl new file mode 100644 index 00000000000..4bf65e4eaaf --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/im2col_f32.cl @@ -0,0 +1,57 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +kernel void kernel_im2col_f32( + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + ulong batch_offset, + ulong delta_offset, + long IW, + long IH, + long IC, + long OW, + long OH, + long KW, + long KH, + long pelements, + long CHW, + int s0, + int s1, + int p0, + int p1, + int d0, + int d1 +) { + long i = get_global_id(0); + if (i >= pelements) { + return; + } + + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + long ksize = OW * (KH > 1 ? 
KW : 1); + long kx = i / ksize; + long kd = kx * ksize; + long ky = (i - kd) / OW; + long ix = i % OW; + + long oh = get_group_id(1); + long batch = get_group_id(2) / IC; + long ic = get_group_id(2) % IC; + + long iiw = ix * s0 + kx * d0 - p0; + long iih = oh * s1 + ky * d1 - p1; + + long offset_dst = + ((batch * OH + oh) * OW + ix) * CHW + + (ic * (KW * KH) + ky * KW + kx); + + if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { + dst[offset_dst] = 0.0f; + } else { + long offset_src = ic * delta_offset + batch * batch_offset; + dst[offset_dst] = src1[offset_src + iih * IW + iiw]; + } +} diff --git a/ggml/src/ggml-opencl/kernels/mul.cl b/ggml/src/ggml-opencl/kernels/mul.cl new file mode 100644 index 00000000000..2a2b4eb70a1 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul.cl @@ -0,0 +1,79 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +//------------------------------------------------------------------------------ +// mul +//------------------------------------------------------------------------------ +kernel void kernel_mul( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global char * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne10, + int ne11, + int ne12, + int ne13, + ulong nb10, + ulong nb11, + ulong nb12, + ulong nb13, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 +) { + src0 = src0 + offset0; + src1 = src1 + offset1; + dst = dst + offsetd; + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + int i13 = i03 % ne13; + int i12 = i02 % ne12; + int i11 = i01 % ne11; + + global char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01; + global char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11; + global char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1; + + for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) { + const int i10 = i0 % ne10; + *((global float *)(dst_ptr + i0*nb0)) = *((global float *)(src0_ptr + i0*nb00)) * *((global float *)(src1_ptr + i10*nb10)); + } +} + +// assumption: src1 is a row +// broadcast src1 into src0 +kernel void kernel_mul_row( + global float4 * src0, + ulong offset0, + global float4 * src1, + ulong offset1, + global float4 * dst, + ulong offsetd, + int ne +) { + src0 = (global float4*)((global char*)src0 + offset0); + src1 = (global float4*)((global char*)src1 + offset1); + dst = (global float4*)((global char*)dst + offsetd); + + // This performs better than using %. 
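+    // src0/src1/dst are addressed as float4, so ne is expected to be the length
+    // of the broadcast row (src1) counted in float4 elements. idx1 below equals
+    // gid % ne, computed as gid - (gid/ne)*ne, which the note above indicates
+    // performs better than the % operator on the targeted GPUs.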
+ uint gid = get_global_id(0); + uint idx1 = gid - (gid/ne)*ne; // get_global_id(0) % ne + dst[gid] = src0[gid] * src1[idx1]; +} diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl_mul_mat_Ab_Bi_8x4.cl b/ggml/src/ggml-opencl/kernels/mul_mat_Ab_Bi_8x4.cl similarity index 100% rename from ggml/src/ggml-opencl/kernels/ggml-opencl_mul_mat_Ab_Bi_8x4.cl rename to ggml/src/ggml-opencl/kernels/mul_mat_Ab_Bi_8x4.cl diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_f16_f16.cl b/ggml/src/ggml-opencl/kernels/mul_mv_f16_f16.cl new file mode 100644 index 00000000000..9393b549415 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_f16_f16.cl @@ -0,0 +1,118 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define N_F16_F16 4 + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mat_f16_f16( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne10, + int ne11, + int ne12, + ulong nb10, + ulong nb11, + ulong nb12, + ulong nb13, + int ne0, + int ne1, + int r2, + int r3) +{ + src0 = (global char*)((global char*)src0 + offset0); + src1 = (global char*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + int r0 = get_group_id(0); + int rb = get_group_id(1)*N_F16_F16; + int im = get_group_id(2); + + int i12 = im%ne12; + int i13 = im/ne12; + + ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + + global half * x = (global half *) (src0 + offset_src0); + + if (ne00 < 128) { + for (int row = 0; row < N_F16_F16; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + } + + ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + global half * y = (global half *) (src1 + offset_src1); + + float sumf = 0; + for (int i = get_sub_group_local_id(); i < ne00; i += get_max_sub_group_size()) { + sumf += (half) x[i] * (half) y[i]; + } + + float all_sum = sub_group_reduce_add(sumf); + if (get_sub_group_local_id() == 0) { + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + } else { + global half4 * x4 = (global half4 *)x; + for (int row = 0; row < N_F16_F16; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + } + + ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + global half * y = (global half *) (src1 + offset_src1); + global half4 * y4 = (global half4 *) y; + + float sumf = 0; + for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) { + sumf += (half) x4[i].s0 * y4[i].s0; + sumf += (half) x4[i].s1 * y4[i].s1; + sumf += (half) x4[i].s2 * y4[i].s2; + sumf += (half) x4[i].s3 * y4[i].s3; + } + + float all_sum = sub_group_reduce_add(sumf); + if (get_sub_group_local_id() == 0) { + for (int i = 4*(ne00/4); i < ne00; 
++i) { + all_sum += (half) x[i] * y[i]; + } + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32.cl b/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32.cl new file mode 100644 index 00000000000..e52d3c6d475 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32.cl @@ -0,0 +1,118 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define N_F16_F32 4 + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mat_f16_f32( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne10, + int ne11, + int ne12, + ulong nb10, + ulong nb11, + ulong nb12, + ulong nb13, + int ne0, + int ne1, + int r2, + int r3 +) { + src0 = (global char*)((global char*)src0 + offset0); + src1 = (global char*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + int r0 = get_group_id(0); + int rb = get_group_id(1)*N_F16_F32; + int im = get_group_id(2); + + int i12 = im%ne12; + int i13 = im/ne12; + + ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + + global half * x = (global half *) (src0 + offset_src0); + + if (ne00 < 128) { + for (int row = 0; row < N_F16_F32; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + } + + ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + global float * y = (global float *) (src1 + offset_src1); + + float sumf = 0; + for (int i = get_sub_group_local_id(); i < ne00; i += get_max_sub_group_size()) { + sumf += convert_float(x[i]) * y[i]; + } + + float all_sum = sub_group_reduce_add(sumf); + if (get_sub_group_local_id() == 0) { + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + } else { + global half4 * x4 = (global half4 *)x; + for (int row = 0; row < N_F16_F32; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + } + + ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + global float * y = (global float *) (src1 + offset_src1); + global float4 * y4 = (global float4 *) y; + + float sumf = 0; + for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) { + sumf += convert_float(x4[i].s0) * y4[i].s0; + sumf += convert_float(x4[i].s1) * y4[i].s1; + sumf += convert_float(x4[i].s2) * y4[i].s2; + sumf += convert_float(x4[i].s3) * y4[i].s3; + } + + float all_sum = sub_group_reduce_add(sumf); + if (get_sub_group_local_id() == 0) { + for (int i = 4*(ne00/4); i < ne00; ++i) { + all_sum += (float) x[i] * y[i]; + } + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_1row.cl b/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_1row.cl new file mode 100644 index 
00000000000..28d30212cda --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_1row.cl @@ -0,0 +1,94 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mat_f16_f32_1row( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne10, + int ne11, + int ne12, + ulong nb10, + ulong nb11, + ulong nb12, + ulong nb13, + int ne0, + int ne1, + int r2, + int r3 +) { + src0 = (global char*)((global char*)src0 + offset0); + src1 = (global char*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + int i12 = im%ne12; + int i13 = im/ne12; + + ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + global half * x = (global half *) (src0 + offset_src0); + global float * y = (global float *) (src1 + offset_src1); + + float sumf = 0; + if (ne00 < 128) { + for (int i = get_sub_group_local_id(); i < ne00; i += get_max_sub_group_size()) { + sumf += (float) x[i] * (float) y[i]; + } + float all_sum = sub_group_reduce_add(sumf); + if (get_sub_group_local_id() == 0) { + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } else { + global half4 * x4 = (global half4 *) x; + global float4 * y4 = (global float4 *) y; + for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) { + sumf += (float) x4[i].s0 * y4[i].s0; + sumf += (float) x4[i].s1 * y4[i].s1; + sumf += (float) x4[i].s2 * y4[i].s2; + sumf += (float) x4[i].s3 * y4[i].s3; + } + float all_sum = sub_group_reduce_add(sumf); + if (get_sub_group_local_id() == 0) { + for (int i = 4*(ne00/4); i < ne00; ++i) { + all_sum += (float) x[i] * y[i]; + } + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_l4.cl b/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_l4.cl new file mode 100644 index 00000000000..cdf8197c470 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_l4.cl @@ -0,0 +1,84 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION 
cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +// Assumes row size (ne00) is a multiple of 4 +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mat_f16_f32_l4( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne10, + int ne11, + int ne12, + ulong nb10, + ulong nb11, + ulong nb12, + ulong nb13, + int ne0, + int ne1, + int r2, + int r3 +) { + src0 = (global char*)((global char*)src0 + offset0); + src1 = (global char*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + int nrows = ne11; + int r0 = get_group_id(0); + int im = get_group_id(2); + + int i12 = im%ne12; + int i13 = im/ne12; + + ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + + global half4 * x4 = (global half4 *) (src0 + offset_src0); + + for (int r1 = 0; r1 < nrows; ++r1) { + ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + global float4 * y4 = (global float4 *) (src1 + offset_src1); + + float sumf = 0; + for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) { + sumf += convert_float(x4[i].s0) * y4[i].s0; + sumf += convert_float(x4[i].s1) * y4[i].s1; + sumf += convert_float(x4[i].s2) * y4[i].s2; + sumf += convert_float(x4[i].s3) * y4[i].s3; + } + + float all_sum = sub_group_reduce_add(sumf); + if (get_sub_group_local_id() == 0) { + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_f32_f32.cl b/ggml/src/ggml-opencl/kernels/mul_mv_f32_f32.cl new file mode 100644 index 00000000000..ec71b875652 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_f32_f32.cl @@ -0,0 +1,118 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define N_F32_F32 4 + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mat_f32_f32( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne10, + int ne11, + int ne12, + ulong nb10, + ulong nb11, + ulong nb12, + ulong nb13, + int ne0, + int ne1, + int r2, + int r3 +) { + src0 = (global char*)((global char*)src0 + offset0); + src1 = (global char*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + int r0 = get_group_id(0); + int rb = get_group_id(1)*N_F32_F32; + int im = get_group_id(2); + + int i12 = im%ne12; + int i13 = im/ne12; + + ulong offset_src0 = r0*nb01 
+ (i12/r2)*nb02 + (i13/r3)*nb03; + + global float * x = (global float *) (src0 + offset_src0); + + if (ne00 < 128) { + for (int row = 0; row < N_F32_F32; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + } + + ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + global float * y = (global float *) (src1 + offset_src1); + + float sumf = 0; + for (int i = get_sub_group_local_id(); i < ne00; i += get_max_sub_group_size()) { + sumf += (float) x[i] * (float) y[i]; + } + + float all_sum = sub_group_reduce_add(sumf); + if (get_sub_group_local_id() == 0) { + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + } else { + global float4 * x4 = (global float4 *)x; + for (int row = 0; row < N_F32_F32; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + } + + ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + global float * y = (global float *) (src1 + offset_src1); + global float4 * y4 = (global float4 *) y; + + float sumf = 0; + for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) { + sumf += (float) x4[i].s0 * y4[i].s0; + sumf += (float) x4[i].s1 * y4[i].s1; + sumf += (float) x4[i].s2 * y4[i].s2; + sumf += (float) x4[i].s3 * y4[i].s3; + } + + float all_sum = sub_group_reduce_add(sumf); + if (get_sub_group_local_id() == 0) { + for (int i = 4*(ne00/4); i < ne00; ++i) { + all_sum += (float) x[i] * y[i]; + } + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32.cl new file mode 100644 index 00000000000..52141e0ed55 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32.cl @@ -0,0 +1,192 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define QK4_0 32 +#define QR4_0 2 +#define QK4_1 32 +#define QR4_1 2 +#define QK5_0 32 +#define QR5_0 2 +#define QK5_1 32 +#define QR5_1 2 +#define QK8_0 32 +#define QR8_0 1 +#define QK_K 256 +#define K_QUANTS_PER_ITERATION 2 + +typedef char int8_t; +typedef uchar uint8_t; +typedef short int16_t; +typedef ushort uint16_t; +typedef int int32_t; +typedef uint uint32_t; + +//------------------------------------------------------------------------------ +// block_q4_0 +//------------------------------------------------------------------------------ +struct block_q4_0 +{ + half d; + uint8_t qs[QK4_0 / 2]; +}; + +//------------------------------------------------------------------------------ +// mul_vec_q_n_f32 +//------------------------------------------------------------------------------ +// function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i]) +// il indicates where the q4 quants begin (0 or QK4_0/4) +// we assume that the yl's have been multiplied with the appropriate scale factor +// that corresponds to the missing 
bit shifts (1, 1/16, 1/256, 1/4096) +inline float block_q_4_0_dot_y( + global struct block_q4_0 * qb_curr, + float sumy, + private float * yl, + int il +) { + float d = qb_curr->d; + float2 acc = 0.f; + global ushort * qs = ((global ushort *)qb_curr + 1 + il/2); + for (int i = 0; i < 8; i+=2) { + acc.s0 += yl[i + 0] * (qs[i / 2] & 0x000F) + + yl[i + 1] * (qs[i / 2] & 0x0F00); + acc.s1 += yl[i + 8] * (qs[i / 2] & 0x00F0) + + yl[i + 9] * (qs[i / 2] & 0xF000); + } + return d * (sumy * -8.f + acc.s0 + acc.s1); +} + +#ifdef INTEL_GPU +#define N_DST 4 // each SIMD group works on 4 rows +#define N_SIMDGROUP 1 // number of SIMD groups in a thread group +#define N_SIMDWIDTH 16 // assuming SIMD group size is 16 +#elif defined (ADRENO_GPU) +#define N_DST 4 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 64 +#endif + +inline void mul_vec_q_n_f32( + global void * src0, + global float * src1, + global float * dst, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + + const ulong nb = ne00/QK4_0; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + // (r0 * N_SIMDGROUP + get_sub_group_id()) is essenatially the linear global + // id of a SIMD group in the grid. + int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + ulong offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + global struct block_q4_0 * x = (global struct block_q4_0 *) src0 + offset0; + global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + float yl[16]; // src1 vector cache + float sumf[N_DST]={0.f}; + + int ix = get_sub_group_local_id()/2; + int il = 8*(get_sub_group_local_id()%2); + + global float * yb = y + ix * QK4_0 + il; + + // each thread in a SIMD group deals with half a block. + for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { + float sumy = 0; + for (int i = 0; i < 8; i += 2) { + sumy += yb[i] + yb[i+1]; + yl[i+0] = yb[i+ 0]; + yl[i+1] = yb[i+ 1]/256.f; + sumy += yb[i+16] + yb[i+17]; + yl[i+8] = yb[i+16]/16.f; + yl[i+9] = yb[i+17]/4096.f; + } + + for (int row = 0; row < N_DST; row++) { + sumf[row] += block_q_4_0_dot_y(x+ib+row*nb, sumy, yl, il); + } + + // One thread in a SIMD group (i.e., subgroup) handles a half block, + // hence then entire SIMD group handles SIMDWIDTH/2 blocks. + // y points to the activation matrix (of type float). Therefore for + // one thread, the # of blocks y should advance is SIMDWIDTH/2 (because + // SIMDWIDTH/2 blocks are processed by a SIMD group) - in terms of + // floats, it is QK4_0 * (SIMDWIDTH/2), where QK4_0 is the block size. + yb += QK4_0 * (N_SIMDWIDTH/2); + } + + // The above does not work for Adreno - it produces incorrect results for + // row = 1, 2, 3 and only row = 0 gives the correct result. + // If N_DST is changed, the below array must be initialized accordingly. + // This also seems to perform better on Intel. 
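Stepping back from the reduction detail: the masked-nibble arithmetic in block_q_4_0_dot_y above is easier to follow next to a scalar reference. The sketch below (plain host C with made-up test values, not part of this diff) computes the same per-block dot product twice: once by dequantizing each 4-bit quant as d * (q - 8), and once the way the kernel does it, reading the quants as 16-bit words, masking the four nibble positions in place, folding the missing shifts into the 1, 1/256, 1/16 and 1/4096 pre-scales on y, and applying the zero-point once through sumy * -8. It assumes a little-endian host, matching the kernel's ushort view of the quants.

// q4_0 dot-product sketch: reference form vs. the kernel's shift-free form.
#include <assert.h>
#include <math.h>
#include <stdint.h>
#include <stdio.h>

#define QK4_0 32

// Plain dequantize-then-dot: x[i] = d * (q[i] - 8), with byte j holding q[j]
// in its low nibble and q[j + 16] in its high nibble.
static float dot_q4_0_reference(float d, const uint8_t qs[QK4_0/2], const float y[QK4_0]) {
    float sum = 0.f;
    for (int j = 0; j < QK4_0/2; ++j) {
        int q_lo = qs[j] & 0x0F;  // q[j]
        int q_hi = qs[j] >> 4;    // q[j + 16]
        sum += d * (float)(q_lo - 8) * y[j];
        sum += d * (float)(q_hi - 8) * y[j + 16];
    }
    return sum;
}

// Same dot product computed the way the kernel does it: mask nibbles of a
// 16-bit load without shifting, pre-scale y to compensate, and fold the
// zero-point of 8 into a single sumy * -8 term.
static float dot_q4_0_masked(float d, const uint8_t qs8[QK4_0/2], const float y[QK4_0]) {
    const uint16_t * qs = (const uint16_t *) qs8;  // little-endian view, as in the kernel
    float sumy = 0.f;
    for (int i = 0; i < QK4_0; ++i) {
        sumy += y[i];
    }
    float acc = 0.f;
    for (int i = 0; i < QK4_0/2; i += 2) {         // one ushort (two quant bytes) per step
        acc += (y[i +  0]          ) * (qs[i/2] & 0x000F)
             + (y[i +  1] /  256.f ) * (qs[i/2] & 0x0F00)
             + (y[i + 16] /   16.f ) * (qs[i/2] & 0x00F0)
             + (y[i + 17] / 4096.f ) * (qs[i/2] & 0xF000);
    }
    return d * (sumy * -8.f + acc);
}

int main(void) {
    uint8_t qs[QK4_0/2];
    float   y[QK4_0];
    for (int j = 0; j < QK4_0/2; ++j) qs[j] = (uint8_t)(17*j + 3); // arbitrary nibbles
    for (int i = 0; i < QK4_0;   ++i) y[i]  = 0.01f*i - 0.1f;      // arbitrary activations
    float a = dot_q4_0_reference(0.5f, qs, y);
    float b = dot_q4_0_masked   (0.5f, qs, y);
    printf("reference = %f, masked = %f\n", a, b);
    assert(fabsf(a - b) < 1e-3f);
    return 0;
}

The payoff of the shift-free form is that all four nibble positions of one 16-bit load feed the accumulator with no extract instructions; the compensating divisions are done once per loaded y value and then reused for all N_DST rows.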
+ float tot[N_DST] = { + sub_group_reduce_add(sumf[0]), sub_group_reduce_add(sumf[1]), + sub_group_reduce_add(sumf[2]), sub_group_reduce_add(sumf[3])}; + for (int row = 0; row < N_DST; ++row) { + if (get_sub_group_local_id() == 0 && first_row + row < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot[row]; + } + } +} + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mat_q4_0_f32( + global void * src0, + ulong offset0, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + mul_vec_q_n_f32(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_16x_flat.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_16x_flat.cl new file mode 100644 index 00000000000..3eebab8f0f2 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_16x_flat.cl @@ -0,0 +1,307 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define QK4_0 32 +#define QR4_0 2 +#define QK4_1 32 +#define QR4_1 2 +#define QK5_0 32 +#define QR5_0 2 +#define QK5_1 32 +#define QR5_1 2 +#define QK8_0 32 +#define QR8_0 1 +#define QK_K 256 +#define K_QUANTS_PER_ITERATION 2 + +typedef char int8_t; +typedef uchar uint8_t; +typedef short int16_t; +typedef ushort uint16_t; +typedef int int32_t; +typedef uint uint32_t; + +//------------------------------------------------------------------------------ +// block_q4_0 +//------------------------------------------------------------------------------ +struct block_q4_0 +{ + half d; + uint8_t qs[QK4_0 / 2]; +}; + +inline float mm_block_q_4_0_dot_y_flat( + global uchar * x, + global half * dh, + float sumy, + float16 yl, + int il +) { + float d = *dh; + global ushort * qs = ((global ushort *)x + il/2); + float acc = 0.f; + + acc += yl.s0 * (qs[0] & 0x000F); + acc += yl.s1 * (qs[0] & 0x0F00); + acc += yl.s8 * (qs[0] & 0x00F0); + acc += yl.s9 * (qs[0] & 0xF000); + + acc += yl.s2 * (qs[1] & 0x000F); + acc += yl.s3 * (qs[1] & 0x0F00); + acc += yl.sa * (qs[1] & 0x00F0); + acc += yl.sb * (qs[1] & 0xF000); + + acc += yl.s4 * (qs[2] & 0x000F); + acc += yl.s5 * (qs[2] & 0x0F00); + acc += yl.sc * (qs[2] & 0x00F0); + acc += yl.sd * (qs[2] & 0xF000); + + acc += yl.s6 * (qs[3] & 0x000F); + acc += yl.s7 * (qs[3] & 0x0F00); + acc += yl.se * (qs[3] & 0x00F0); + acc += yl.sf * (qs[3] & 0xF000); + + return d * (sumy * -8.f + acc); +} + +#ifdef INTEL_GPU +#define N_DST 16 // each SIMD group works on 8 rows (in weights matrix) +#define 
N_SIMDGROUP 1 // number of SIMD groups in a thread group +#define N_SIMDWIDTH 16 // assuming SIMD group size is 16 +#elif defined (ADRENO_GPU) +#define N_DST 16 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 64 +#endif +// +// This variant performs 1d blocking with 16x output. +// Eeach simdgroup outputs 16 values on `n0` dim (row in the output matrix). +// +inline void mul_mat_q_n_f32_1d_16x_flat( + global uchar * src0_q, + global half * src0_d, + global float * src1, + global float * dst, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + const int nb = ne00/QK4_0; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + // (r0 * N_SIMDGROUP + get_sub_group_id()) is the linear global id of + // a SIMD group in the grid. Each SIMD group produces N_DST values in the + // result, hence uses nb blocks, i.e., the offset becomes first_row*nb. + // Currently with llama2 7B, im is always 0. + // TODO: how to handle im/gqa*(nb*ne0)? + int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + // The number of scales is the same as the number of blocks. + ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + // Each block contains QK4_0/2 uchars, hence offset for qs is as follows. + ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_0/2; + + global uchar * x = (global uchar *) src0_q + offset0_q; + global half * d = (global half *) src0_d + offset0_d; + global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + float16 yl; + float16 sumf = (float16)(0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f); + + int ix = get_sub_group_local_id()/2; + int il = 8*(get_sub_group_local_id()%2); + + global float * yb = y + ix*QK4_0 + il; + + for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { + float sumy = 0.f; + + sumy += yb[0]; + sumy += yb[1]; + sumy += yb[2]; + sumy += yb[3]; + sumy += yb[4]; + sumy += yb[5]; + sumy += yb[6]; + sumy += yb[7]; + + sumy += yb[16]; + sumy += yb[17]; + sumy += yb[18]; + sumy += yb[19]; + sumy += yb[20]; + sumy += yb[21]; + sumy += yb[22]; + sumy += yb[23]; + + yl.s0 = yb[0]; + yl.s1 = yb[1]/256.f; + + yl.s2 = yb[2]; + yl.s3 = yb[3]/256.f; + + yl.s4 = yb[4]; + yl.s5 = yb[5]/256.f; + + yl.s6 = yb[6]; + yl.s7 = yb[7]/256.f; + + yl.s8 = yb[16]/16.f; + yl.s9 = yb[17]/4096.f; + + yl.sa = yb[18]/16.f; + yl.sb = yb[19]/4096.f; + + yl.sc = yb[20]/16.f; + yl.sd = yb[21]/4096.f; + + yl.se = yb[22]/16.f; + yl.sf = yb[23]/4096.f; + + sumf.s0 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 0*nb*QK4_0/2, d + ib + 0*nb, sumy, yl, il); + sumf.s1 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 1*nb*QK4_0/2, d + ib + 1*nb, sumy, yl, il); + sumf.s2 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 2*nb*QK4_0/2, d + ib + 2*nb, sumy, yl, il); + sumf.s3 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 3*nb*QK4_0/2, d + ib + 3*nb, sumy, yl, il); + + sumf.s4 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 4*nb*QK4_0/2, d + ib + 4*nb, sumy, yl, il); + sumf.s5 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 5*nb*QK4_0/2, d + ib + 5*nb, sumy, yl, il); + sumf.s6 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 6*nb*QK4_0/2, d + ib + 6*nb, sumy, yl, il); + sumf.s7 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 7*nb*QK4_0/2, d + ib + 7*nb, sumy, yl, il); + + sumf.s8 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 8*nb*QK4_0/2, d + ib + 8*nb, sumy, yl, il); + sumf.s9 += 
mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 9*nb*QK4_0/2, d + ib + 9*nb, sumy, yl, il); + sumf.sa += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 10*nb*QK4_0/2, d + ib + 10*nb, sumy, yl, il); + sumf.sb += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 11*nb*QK4_0/2, d + ib + 11*nb, sumy, yl, il); + + sumf.sc += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 12*nb*QK4_0/2, d + ib + 12*nb, sumy, yl, il); + sumf.sd += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 13*nb*QK4_0/2, d + ib + 13*nb, sumy, yl, il); + sumf.se += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 14*nb*QK4_0/2, d + ib + 14*nb, sumy, yl, il); + sumf.sf += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 15*nb*QK4_0/2, d + ib + 15*nb, sumy, yl, il); + + yb += QK4_0 * (N_SIMDWIDTH/2); + } + + float16 tot = (float16)( + sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1), + sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3), + sub_group_reduce_add(sumf.s4), sub_group_reduce_add(sumf.s5), + sub_group_reduce_add(sumf.s6), sub_group_reduce_add(sumf.s7), + + sub_group_reduce_add(sumf.s8), sub_group_reduce_add(sumf.s9), + sub_group_reduce_add(sumf.sa), sub_group_reduce_add(sumf.sb), + sub_group_reduce_add(sumf.sc), sub_group_reduce_add(sumf.sd), + sub_group_reduce_add(sumf.se), sub_group_reduce_add(sumf.sf) + ); + + if (get_sub_group_local_id() == 0) { + if (first_row + 0 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; + } + if (first_row + 1 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; + } + if (first_row + 2 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; + } + if (first_row + 3 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; + } + + if (first_row + 4 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 4] = tot.s4; + } + if (first_row + 5 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 5] = tot.s5; + } + if (first_row + 6 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 6] = tot.s6; + } + if (first_row + 7 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 7] = tot.s7; + } + + if (first_row + 8 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 8] = tot.s8; + } + if (first_row + 9 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 9] = tot.s9; + } + if (first_row + 10 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 10] = tot.sa; + } + if (first_row + 11 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 11] = tot.sb; + } + + if (first_row + 12 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 12] = tot.sc; + } + if (first_row + 13 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 13] = tot.sd; + } + if (first_row + 14 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 14] = tot.se; + } + if (first_row + 15 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 15] = tot.sf; + } + } +} + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mat_q4_0_f32_1d_16x_flat( + global uchar * src0_q, + global half * src0_d, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + mul_mat_q_n_f32_1d_16x_flat(src0_q, src0_d, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_8x_flat.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_8x_flat.cl new file mode 100644 index 00000000000..38024d00ad5 --- /dev/null +++ 
b/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_8x_flat.cl @@ -0,0 +1,265 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define QK4_0 32 +#define QR4_0 2 +#define QK4_1 32 +#define QR4_1 2 +#define QK5_0 32 +#define QR5_0 2 +#define QK5_1 32 +#define QR5_1 2 +#define QK8_0 32 +#define QR8_0 1 +#define QK_K 256 +#define K_QUANTS_PER_ITERATION 2 + +typedef char int8_t; +typedef uchar uint8_t; +typedef short int16_t; +typedef ushort uint16_t; +typedef int int32_t; +typedef uint uint32_t; + +//------------------------------------------------------------------------------ +// block_q4_0 +//------------------------------------------------------------------------------ +struct block_q4_0 +{ + half d; + uint8_t qs[QK4_0 / 2]; +}; + +inline float mm_block_q_4_0_dot_y_flat( + global uchar * x, + global half * dh, + float sumy, + float16 yl, + int il +) { + float d = *dh; + global ushort * qs = ((global ushort *)x + il/2); + float acc = 0.f; + + acc += yl.s0 * (qs[0] & 0x000F); + acc += yl.s1 * (qs[0] & 0x0F00); + acc += yl.s8 * (qs[0] & 0x00F0); + acc += yl.s9 * (qs[0] & 0xF000); + + acc += yl.s2 * (qs[1] & 0x000F); + acc += yl.s3 * (qs[1] & 0x0F00); + acc += yl.sa * (qs[1] & 0x00F0); + acc += yl.sb * (qs[1] & 0xF000); + + acc += yl.s4 * (qs[2] & 0x000F); + acc += yl.s5 * (qs[2] & 0x0F00); + acc += yl.sc * (qs[2] & 0x00F0); + acc += yl.sd * (qs[2] & 0xF000); + + acc += yl.s6 * (qs[3] & 0x000F); + acc += yl.s7 * (qs[3] & 0x0F00); + acc += yl.se * (qs[3] & 0x00F0); + acc += yl.sf * (qs[3] & 0xF000); + + return d * (sumy * -8.f + acc); +} + +#ifdef INTEL_GPU +#define N_DST 8 // each SIMD group works on 8 rows (in weights matrix) +#define N_SIMDGROUP 1 // number of SIMD groups in a thread group +#define N_SIMDWIDTH 16 // assuming SIMD group size is 16 +#elif defined (ADRENO_GPU) +#define N_DST 8 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 64 +#endif +// +// This variant performs 1d blocking with 8x output. +// Eeach simdgroup outputs 8 values on `n0` dim (row in the output matrix). +// +inline void mul_mat_q_n_f32_1d_8x_flat( + global uchar * src0_q, + global half * src0_d, + global float * src1, + global float * dst, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + const int nb = ne00/QK4_0; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + // (r0 * N_SIMDGROUP + get_sub_group_id()) is the linear global id of + // a SIMD group in the grid. Each SIMD group produces N_DST values in the + // result, hence uses nb blocks, i.e., the offset becomes first_row*nb. + // Currently with llama2 7B, im is always 0. + // TODO: how to handle im/gqa*(nb*ne0)? 
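As context for the flat offsets computed just below: the *_flat variants receive the q4_0 weights as two separate buffers, src0_q with the packed nibbles and src0_d with the per-block half scales, which is why the quant offset is scaled by QK4_0/2 while the scale offset is simply the block index. A minimal host-side sketch of that flattening is shown here; the names and the fp16-as-raw-bits shortcut are illustrative, not the backend's actual conversion code.

// Flat q4_0 layout sketch: split interleaved blocks into quant and scale buffers.
#include <stdint.h>

#define QK4_0 32

typedef uint16_t fp16_t;                  // raw fp16 bits; only copied here, converted on the GPU

typedef struct {
    fp16_t  d;                            // per-block scale
    uint8_t qs[QK4_0/2];                  // 32 4-bit quants, two per byte
} block_q4_0_ref;

// Block ib of row r ends up at (r*nb + ib)*QK4_0/2 in src0_q and at r*nb + ib
// in src0_d, matching the offset0_q / offset0_d arithmetic in these kernels.
static void flatten_q4_0(const block_q4_0_ref * blocks, int nrow, int nb,
                         uint8_t * src0_q, fp16_t * src0_d) {
    for (int r = 0; r < nrow; ++r) {
        for (int ib = 0; ib < nb; ++ib) {
            const block_q4_0_ref * b = &blocks[r*nb + ib];
            src0_d[r*nb + ib] = b->d;
            for (int j = 0; j < QK4_0/2; ++j) {
                src0_q[(r*nb + ib)*(QK4_0/2) + j] = b->qs[j];
            }
        }
    }
}

int main(void) {
    enum { NROW = 2, NB = 4 };
    block_q4_0_ref blocks[NROW*NB] = {{0}};
    uint8_t src0_q[NROW*NB*QK4_0/2];
    fp16_t  src0_d[NROW*NB];
    flatten_q4_0(blocks, NROW, NB, src0_q, src0_d);
    (void) src0_q;
    (void) src0_d;
    return 0;
}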
+ int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + // The number of scales is the same as the number of blocks. + ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + // Each block contains QK4_0/2 uchars, hence offset for qs is as follows. + ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_0/2; + + global uchar * x = (global uchar *) src0_q + offset0_q; + global half * d = (global half *) src0_d + offset0_d; + global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + float16 yl; + float8 sumf = (float8)(0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f); + + int ix = get_sub_group_local_id()/2; + int il = 8*(get_sub_group_local_id()%2); + + global float * yb = y + ix*QK4_0 + il; + + for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { + float sumy = 0.f; + + sumy += yb[0]; + sumy += yb[1]; + sumy += yb[2]; + sumy += yb[3]; + sumy += yb[4]; + sumy += yb[5]; + sumy += yb[6]; + sumy += yb[7]; + + sumy += yb[16]; + sumy += yb[17]; + sumy += yb[18]; + sumy += yb[19]; + sumy += yb[20]; + sumy += yb[21]; + sumy += yb[22]; + sumy += yb[23]; + + yl.s0 = yb[0]; + yl.s1 = yb[1]/256.f; + + yl.s2 = yb[2]; + yl.s3 = yb[3]/256.f; + + yl.s4 = yb[4]; + yl.s5 = yb[5]/256.f; + + yl.s6 = yb[6]; + yl.s7 = yb[7]/256.f; + + yl.s8 = yb[16]/16.f; + yl.s9 = yb[17]/4096.f; + + yl.sa = yb[18]/16.f; + yl.sb = yb[19]/4096.f; + + yl.sc = yb[20]/16.f; + yl.sd = yb[21]/4096.f; + + yl.se = yb[22]/16.f; + yl.sf = yb[23]/4096.f; + + sumf.s0 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 0*nb*QK4_0/2, d + ib + 0*nb, sumy, yl, il); + sumf.s1 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 1*nb*QK4_0/2, d + ib + 1*nb, sumy, yl, il); + sumf.s2 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 2*nb*QK4_0/2, d + ib + 2*nb, sumy, yl, il); + sumf.s3 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 3*nb*QK4_0/2, d + ib + 3*nb, sumy, yl, il); + + sumf.s4 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 4*nb*QK4_0/2, d + ib + 4*nb, sumy, yl, il); + sumf.s5 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 5*nb*QK4_0/2, d + ib + 5*nb, sumy, yl, il); + sumf.s6 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 6*nb*QK4_0/2, d + ib + 6*nb, sumy, yl, il); + sumf.s7 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 7*nb*QK4_0/2, d + ib + 7*nb, sumy, yl, il); + + yb += QK4_0 * (N_SIMDWIDTH/2); + } + + float8 tot = (float8)( + sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1), + sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3), + sub_group_reduce_add(sumf.s4), sub_group_reduce_add(sumf.s5), + sub_group_reduce_add(sumf.s6), sub_group_reduce_add(sumf.s7) + ); + + if (get_sub_group_local_id() == 0) { + if (first_row + 0 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; + } + if (first_row + 1 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; + } + if (first_row + 2 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; + } + if (first_row + 3 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; + } + + if (first_row + 4 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 4] = tot.s4; + } + if (first_row + 5 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 5] = tot.s5; + } + if (first_row + 6 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 6] = tot.s6; + } + if (first_row + 7 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 7] = tot.s7; + } + } +} + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void 
kernel_mul_mat_q4_0_f32_1d_8x_flat( + global uchar * src0_q, + global half * src0_d, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + mul_mat_q_n_f32_1d_8x_flat(src0_q, src0_d, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_8x_flat.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_8x_flat.cl new file mode 100644 index 00000000000..aed1ce7b260 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_8x_flat.cl @@ -0,0 +1,272 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define QK4_0 32 +#define QR4_0 2 +#define QK4_1 32 +#define QR4_1 2 +#define QK5_0 32 +#define QR5_0 2 +#define QK5_1 32 +#define QR5_1 2 +#define QK8_0 32 +#define QR8_0 1 +#define QK_K 256 +#define K_QUANTS_PER_ITERATION 2 + +typedef char int8_t; +typedef uchar uint8_t; +typedef short int16_t; +typedef ushort uint16_t; +typedef int int32_t; +typedef uint uint32_t; + +//------------------------------------------------------------------------------ +// block_q4_0 +//------------------------------------------------------------------------------ +struct block_q4_0 +{ + half d; + uint8_t qs[QK4_0 / 2]; +}; + +// This function requires the original shuffled weights. +// As a reminder, the original weights are shuffled so that (q[0], q[16]) are +// packed together in a byte, so are (q[1], q[17]) and so on. +inline float block_q_4_0_dot_y_flat( + global uchar * x, + global half * dh, + float sumy, + float16 yl, + int il +) { + float d = *dh; + global ushort * qs = ((global ushort *)x + il/2); + float acc = 0.f; + + acc += yl.s0 * (qs[0] & 0x000F); + acc += yl.s1 * (qs[0] & 0x0F00); + acc += yl.s8 * (qs[0] & 0x00F0); + acc += yl.s9 * (qs[0] & 0xF000); + + acc += yl.s2 * (qs[1] & 0x000F); + acc += yl.s3 * (qs[1] & 0x0F00); + acc += yl.sa * (qs[1] & 0x00F0); + acc += yl.sb * (qs[1] & 0xF000); + + acc += yl.s4 * (qs[2] & 0x000F); + acc += yl.s5 * (qs[2] & 0x0F00); + acc += yl.sc * (qs[2] & 0x00F0); + acc += yl.sd * (qs[2] & 0xF000); + + acc += yl.s6 * (qs[3] & 0x000F); + acc += yl.s7 * (qs[3] & 0x0F00); + acc += yl.se * (qs[3] & 0x00F0); + acc += yl.sf * (qs[3] & 0xF000); + + return d * (sumy * -8.f + acc); +} + +// +// This variant outputs 8 values. 
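For reference, the shuffled packing that block_q_4_0_dot_y_flat above depends on, where byte j of a block stores q[j] in its low nibble and q[j + 16] in its high nibble, looks like this as a host-side sketch (illustrative helpers, not part of this diff):

// q4_0 nibble packing sketch: element j in the low nibble, element j+16 in the high nibble.
#include <assert.h>
#include <stdint.h>

#define QK4_0 32

static void pack_q4_0(const uint8_t q[QK4_0], uint8_t qs[QK4_0/2]) {
    for (int j = 0; j < QK4_0/2; ++j) {
        qs[j] = (uint8_t)((q[j] & 0x0F) | ((q[j + 16] & 0x0F) << 4));
    }
}

static void unpack_q4_0(const uint8_t qs[QK4_0/2], uint8_t q[QK4_0]) {
    for (int j = 0; j < QK4_0/2; ++j) {
        q[j]      = qs[j] & 0x0F;  // low nibble:  first half of the block
        q[j + 16] = qs[j] >> 4;    // high nibble: second half of the block
    }
}

int main(void) {
    uint8_t q[QK4_0], qs[QK4_0/2], back[QK4_0];
    for (int j = 0; j < QK4_0; ++j) q[j] = (uint8_t)(j % 16); // 4-bit values
    pack_q4_0(q, qs);
    unpack_q4_0(qs, back);
    for (int j = 0; j < QK4_0; ++j) assert(back[j] == q[j]);
    return 0;
}

A single byte (or ushort) load therefore yields quants from both halves of the block at once, which is what the masked dot-product forms exploit.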
+// +#undef N_DST +#undef N_SIMDGROUP +#undef N_SIMDWIDTH + +#ifdef INTEL_GPU +#define N_DST 8 // each SIMD group works on 8 rows +#define N_SIMDGROUP 1 // number of SIMD groups in a thread group +#define N_SIMDWIDTH 16 // assuming SIMD group size is 32 +#elif defined (ADRENO_GPU) +#define N_DST 8 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 64 +#endif + +inline void mul_vec_q_n_f32_8x_flat( + global uchar * src0_q, + global half * src0_d, + global float * src1, + global float * dst, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + const ulong nb = ne00/QK4_0; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + // (r0 * N_SIMDGROUP + get_sub_group_id()) is the linear global id of + // a SIMD group in the grid. Each SIMD group produces N_DST values in the + // result, hence uses nb blocks, i.e., the offset becomes first_row*nb. + // Currently with llama2 7B, im is always 0. + // TODO: how to handle im/gqa*(nb*ne0)? + int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + // The number of scales is the same as the number of blocks. + ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + // Each block contains QK4_0/2 uchars, hence offset for qs is as follows. + ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_0/2; + + global uchar * x = (global uchar *) src0_q + offset0_q; + global half * d = (global half *) src0_d + offset0_d; + global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + float16 yl; + float8 sumf = 0.f; + + int ix = get_sub_group_local_id()/2; + int il = 8*(get_sub_group_local_id()%2); + + global float * yb = y + ix*QK4_0 + il; + + for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { + float sumy = 0.f; + + sumy += yb[0]; + sumy += yb[1]; + sumy += yb[2]; + sumy += yb[3]; + sumy += yb[4]; + sumy += yb[5]; + sumy += yb[6]; + sumy += yb[7]; + + sumy += yb[16]; + sumy += yb[17]; + sumy += yb[18]; + sumy += yb[19]; + sumy += yb[20]; + sumy += yb[21]; + sumy += yb[22]; + sumy += yb[23]; + + yl.s0 = yb[0]; + yl.s1 = yb[1]/256.f; + + yl.s2 = yb[2]; + yl.s3 = yb[3]/256.f; + + yl.s4 = yb[4]; + yl.s5 = yb[5]/256.f; + + yl.s6 = yb[6]; + yl.s7 = yb[7]/256.f; + + yl.s8 = yb[16]/16.f; + yl.s9 = yb[17]/4096.f; + + yl.sa = yb[18]/16.f; + yl.sb = yb[19]/4096.f; + + yl.sc = yb[20]/16.f; + yl.sd = yb[21]/4096.f; + + yl.se = yb[22]/16.f; + yl.sf = yb[23]/4096.f; + + sumf.s0 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 0*nb*QK4_0/2, d + ib + 0*nb, sumy, yl, il); + sumf.s1 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 1*nb*QK4_0/2, d + ib + 1*nb, sumy, yl, il); + sumf.s2 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 2*nb*QK4_0/2, d + ib + 2*nb, sumy, yl, il); + sumf.s3 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 3*nb*QK4_0/2, d + ib + 3*nb, sumy, yl, il); + + sumf.s4 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 4*nb*QK4_0/2, d + ib + 4*nb, sumy, yl, il); + sumf.s5 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 5*nb*QK4_0/2, d + ib + 5*nb, sumy, yl, il); + sumf.s6 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 6*nb*QK4_0/2, d + ib + 6*nb, sumy, yl, il); + sumf.s7 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 7*nb*QK4_0/2, d + ib + 7*nb, sumy, yl, il); + + yb += QK4_0 * (N_SIMDWIDTH/2); + } + + float8 tot = (float8)( + sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1), + sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3), + sub_group_reduce_add(sumf.s4), 
sub_group_reduce_add(sumf.s5), + sub_group_reduce_add(sumf.s6), sub_group_reduce_add(sumf.s7) + ); + + if (get_sub_group_local_id() == 0) { + if (first_row + 0 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; + } + if (first_row + 1 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; + } + if (first_row + 2 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; + } + if (first_row + 3 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; + } + + if (first_row + 4 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 4] = tot.s4; + } + if (first_row + 5 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 5] = tot.s5; + } + if (first_row + 6 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 6] = tot.s6; + } + if (first_row + 7 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 7] = tot.s7; + } + } +} + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mat_q4_0_f32_8x_flat( + global uchar * src0_q, + global half * src0_d, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + mul_vec_q_n_f32_8x_flat(src0_q, src0_d, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_v.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_v.cl new file mode 100644 index 00000000000..92955217971 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_v.cl @@ -0,0 +1,254 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define QK4_0 32 +#define QR4_0 2 +#define QK4_1 32 +#define QR4_1 2 +#define QK5_0 32 +#define QR5_0 2 +#define QK5_1 32 +#define QR5_1 2 +#define QK8_0 32 +#define QR8_0 1 +#define QK_K 256 +#define K_QUANTS_PER_ITERATION 2 + +typedef char int8_t; +typedef uchar uint8_t; +typedef short int16_t; +typedef ushort uint16_t; +typedef int int32_t; +typedef uint uint32_t; + +//------------------------------------------------------------------------------ +// block_q4_0 +//------------------------------------------------------------------------------ +struct block_q4_0 +{ + half d; + uint8_t qs[QK4_0 / 2]; +}; + +// +// This variant unrolls the loops and uses vector types instead of pointers. +// It improves performance on Adreno but not so much on Intel. 
+// +inline float block_q_4_0_dot_y_v( + global struct block_q4_0 * qb_curr, + float sumy, + float16 yl, + int il +) { + float d = qb_curr->d; + float acc = 0.f; + global ushort * qs = ((global ushort *)qb_curr + 1 + il/2); + + acc += yl.s0 * (qs[0] & 0x000F); + acc += yl.s1 * (qs[0] & 0x0F00); + acc += yl.s8 * (qs[0] & 0x00F0); + acc += yl.s9 * (qs[0] & 0xF000); + + acc += yl.s2 * (qs[1] & 0x000F); + acc += yl.s3 * (qs[1] & 0x0F00); + acc += yl.sa * (qs[1] & 0x00F0); + acc += yl.sb * (qs[1] & 0xF000); + + acc += yl.s4 * (qs[2] & 0x000F); + acc += yl.s5 * (qs[2] & 0x0F00); + acc += yl.sc * (qs[2] & 0x00F0); + acc += yl.sd * (qs[2] & 0xF000); + + acc += yl.s6 * (qs[3] & 0x000F); + acc += yl.s7 * (qs[3] & 0x0F00); + acc += yl.se * (qs[3] & 0x00F0); + acc += yl.sf * (qs[3] & 0xF000); + + return d * (sumy * -8.f + acc); +} + +#undef N_DST +#undef N_SIMDGROUP +#undef N_SIMDWIDTH + +#ifdef INTEL_GPU +#define N_DST 4 // each SIMD group works on 4 rows +#define N_SIMDGROUP 1 // number of SIMD groups in a thread group +#define N_SIMDWIDTH 16 // assuming SIMD group size is 16 +#elif defined (ADRENO_GPU) +#define N_DST 4 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 64 +#endif + +inline void mul_vec_q_n_f32_v( + global void * src0, + global float * src1, + global float * dst, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + const ulong nb = ne00/QK4_0; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + // (r0 * N_SIMDGROUP + get_sub_group_id()) is essenatially the linear global + // id of a SIMD group in the grid. + int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + ulong offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + global struct block_q4_0 * x = (global struct block_q4_0 *) src0 + offset0; + global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + float16 yl; // src1 vector cache + float4 sumf = (float4)(0.f, 0.f, 0.f, 0.f); + + int ix = get_sub_group_local_id()/2; + int il = 8*(get_sub_group_local_id()%2); + + global float * yb = y + ix * QK4_0 + il; + + // each thread in a SIMD group deals with half a block. + for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { + float sumy = 0; + + sumy += yb[0]; + sumy += yb[1]; + sumy += yb[2]; + sumy += yb[3]; + sumy += yb[4]; + sumy += yb[5]; + sumy += yb[6]; + sumy += yb[7]; + + sumy += yb[16]; + sumy += yb[17]; + sumy += yb[18]; + sumy += yb[19]; + sumy += yb[20]; + sumy += yb[21]; + sumy += yb[22]; + sumy += yb[23]; + + + yl.s0 = yb[0]; + yl.s1 = yb[1]/256.f; + + yl.s2 = yb[2]; + yl.s3 = yb[3]/256.f; + + yl.s4 = yb[4]; + yl.s5 = yb[5]/256.f; + + yl.s6 = yb[6]; + yl.s7 = yb[7]/256.f; + + yl.s8 = yb[16]/16.f; + yl.s9 = yb[17]/4096.f; + + yl.sa = yb[18]/16.f; + yl.sb = yb[19]/4096.f; + + yl.sc = yb[20]/16.f; + yl.sd = yb[21]/4096.f; + + yl.se = yb[22]/16.f; + yl.sf = yb[23]/4096.f; + + sumf.s0 += block_q_4_0_dot_y_v(x+ib+0*nb, sumy, yl, il); + sumf.s1 += block_q_4_0_dot_y_v(x+ib+1*nb, sumy, yl, il); + sumf.s2 += block_q_4_0_dot_y_v(x+ib+2*nb, sumy, yl, il); + sumf.s3 += block_q_4_0_dot_y_v(x+ib+3*nb, sumy, yl, il); + + // One thread in a SIMD group (i.e., subgroup) handles a half block, + // hence then entire SIMD group handles SIMDWIDTH/2 blocks. + // y points to the activation matrix (of type float). 
Therefore for + // one thread, the # of blocks y should advance is SIMDWIDTH/2 (because + // SIMDWIDTH/2 blocks are processed by a SIMD group) - in terms of + // floats, it is QK4_0 * (SIMDWIDTH/2), where QK4_0 is the block size. + yb += QK4_0 * (N_SIMDWIDTH/2); + } + + // The above does not work for Adreno - it produces incorrect results for + // row = 1, 2, 3 and only row = 0 gives the correct result. + // If N_DST is changed, the below array must be initialized accordingly. + // This also seems to perform better on Intel. + float4 tot = (float4)( + sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1), + sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3) + ); + + if (get_sub_group_local_id() == 0) { + if (first_row + 0 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; + } + if (first_row + 1 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; + } + if (first_row + 2 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; + } + if (first_row + 3 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; + } + } +} + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mat_q4_0_f32_v( + global void * src0, + ulong offset0, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + mul_vec_q_n_f32_v(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q6_k.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q6_k.cl new file mode 100644 index 00000000000..8a17b9aae63 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_q6_k.cl @@ -0,0 +1,190 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define QK4_0 32 +#define QR4_0 2 +#define QK4_1 32 +#define QR4_1 2 +#define QK5_0 32 +#define QR5_0 2 +#define QK5_1 32 +#define QR5_1 2 +#define QK8_0 32 +#define QR8_0 1 +#define QK_K 256 +#define K_QUANTS_PER_ITERATION 2 + +typedef char int8_t; +typedef uchar uint8_t; +typedef short int16_t; +typedef ushort uint16_t; +typedef int int32_t; +typedef uint uint32_t; + +//------------------------------------------------------------------------------ +// block_q6_K +//------------------------------------------------------------------------------ +// 6-bit quantization +// weight is represented as x = a * q +// 16 blocks of 16 elements each +// Effectively 6.5625 bits per weight +typedef struct { + uint8_t ql[QK_K/2]; // quants, lower 4 bits + uint8_t qh[QK_K/4]; // quants, upper 2 bits + int8_t scales[QK_K/16]; // 
scales, quantized with 8 bits + half d; // super-block scale +} block_q6_K; + +//------------------------------------------------------------------------------ +// kernel_mul_mv_q6_K_f32 +//------------------------------------------------------------------------------ + +#undef N_DST +#undef N_SIMDGROUP +#undef N_SIMDWIDTH + +#ifdef INTEL_GPU +#define N_DST 1 // number of rows each SIMD group works on +#define N_SIMDGROUP 2 // number of SIMD groups in a thread group +#define N_SIMDWIDTH 16 // SIMD group size +#elif defined (ADRENO_GPU) +#define N_DST 1 +#define N_SIMDGROUP 2 +#define N_SIMDWIDTH 64 +#endif + +#define BLOCK_STRIDE (N_SIMDWIDTH/16) // number of blocks each subgroup processes + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mv_q6_K_f32( + global void * src0, + ulong offset0, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + uchar kmask1 = 0x03; + uchar kmask2 = 0x0C; + uchar kmask3 = 0x30; + uchar kmask4 = 0xC0; + + int nb = ne00/QK_K; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + int row = N_SIMDGROUP * r0 + get_sub_group_id(); + + int i12 = im%ne12; + int i13 = im/ne12; + + ulong offset_src0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + global block_q6_K * x = (global block_q6_K *) src0 + row*nb + offset_src0; + global float * yy = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + float sumf = 0; + + // For Q6_K quantization, 16 values forms a subblock, 16 subblock forms a + // block. Values in a subblock shares a scale that is quantized with 8 bits; + // the entire block shares a single floating point scale. + // For work distribution, each thread processes a subblock (16 weights), hence + // 16 threads process a (super) block -- a subgroup thus handles SIMDWIDTH/16 + // (super) blocks -- this is the block stride. + // The 16 threads that process a (super) block are split into 2 portions, each has + // 8 threads; each portion works on 8 subblocks. + // For subgroup of 16 threads, the entire subgroup works on a single (super) block + // before moving to the next (super) block. Thread0 - thread7 work on the + // first 8 subblocks; thread8 - thread15 works on the last 8 subblocks. + // Thread0 - thread3 work on subblocks 0, 2, 4, 6; thread4 - thread7 work on + // subblocks 1, 3, 5, 7. Each thread does not work on an entire subblock, but + // works on a total of 16 weight values. 
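As a companion to the bit manipulation in the loop below, here is a scalar sketch of how one q6_K super-block dequantizes: each weight is rebuilt from 4 low bits in ql and 2 high bits in qh into the range 0..63, re-centered by -32, then multiplied by its 16-element sub-block scale and the super-block scale d. The struct uses a float scale and illustrative names; it is a reference sketch consistent with the masks and scale indices used by the kernel, not the backend's own dequantization code.

// q6_K dequantization sketch for a single 256-value super-block.
#include <stdint.h>

#define QK_K 256

typedef struct {
    uint8_t ql[QK_K/2];       // lower 4 bits of the quants
    uint8_t qh[QK_K/4];       // upper 2 bits of the quants
    int8_t  scales[QK_K/16];  // per-sub-block scales, 8 bits each
    float   d;                // super-block scale (half in the kernel)
} block_q6_K_ref;

static void dequantize_q6_K(const block_q6_K_ref * x, float y[QK_K]) {
    const uint8_t * ql = x->ql;
    const uint8_t * qh = x->qh;
    const int8_t  * sc = x->scales;
    for (int n = 0; n < QK_K; n += 128) {             // two 128-value halves per super-block
        for (int l = 0; l < 32; ++l) {
            int is = l/16;                            // which 16-value sub-block within this 32
            int q1 = ((ql[l +  0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
            int q2 = ((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
            int q3 = ((ql[l +  0] >>  4) | (((qh[l] >> 4) & 3) << 4)) - 32;
            int q4 = ((ql[l + 32] >>  4) | (((qh[l] >> 6) & 3) << 4)) - 32;
            y[l +  0] = x->d * sc[is + 0] * q1;
            y[l + 32] = x->d * sc[is + 2] * q2;
            y[l + 64] = x->d * sc[is + 4] * q3;
            y[l + 96] = x->d * sc[is + 6] * q4;
        }
        y  += 128;                                    // advance to the next half of the super-block
        ql += 64;
        qh += 32;
        sc += 8;
    }
}

int main(void) {
    block_q6_K_ref b = {{0}};                         // zero quants and zero scales
    b.d = 1.0f;
    float y[QK_K];
    dequantize_q6_K(&b, y);                           // with zero scales every output is 0
    return 0;
}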
+ int tid = get_sub_group_local_id()/BLOCK_STRIDE; // first block_stride groups have tid=0 + int ix = get_sub_group_local_id()%BLOCK_STRIDE; // first block is 0..block_stride-1 + int ip = tid/8; // first or second half of (super) block (0 or 1) + int il = tid%8; // each half has 8 parts, one per scale + int n = 4; // 4 scales at a time (and 4 sums) + int l0 = n*il; // offset into half-block, 0..28 + int is = 8*ip + l0/16; // 0, 1, 8, 9 + + int y_offset = 128*ip + l0; + int q_offset_l = 64*ip + l0; + int q_offset_h = 32*ip + l0; + + for (int i = ix; i < nb; i += BLOCK_STRIDE) { + + global uint8_t * q1 = x[i].ql + q_offset_l; + global uint8_t * q2 = q1 + QK_K/8; + global uint8_t * qh = x[i].qh + q_offset_h; + global int8_t * sc = x[i].scales + is; + + global float * y = yy + i * QK_K + y_offset; + + float dall = x[i].d; + + float4 sums = {0.f, 0.f, 0.f, 0.f}; + + sums.s0 += y[0+ 0] * ((float)((q1[0] & 0xF) | ((qh[0] & kmask1) << 4)) - 32.f); + sums.s1 += y[0+32] * ((float)((q2[0] & 0xF) | ((qh[0] & kmask2) << 2)) - 32.f); + sums.s2 += y[0+64] * ((float)((q1[0] >> 4) | ((qh[0] & kmask3) << 0)) - 32.f); + sums.s3 += y[0+96] * ((float)((q2[0] >> 4) | ((qh[0] & kmask4) >> 2)) - 32.f); + + sums.s0 += y[1+ 0] * ((float)((q1[1] & 0xF) | ((qh[1] & kmask1) << 4)) - 32.f); + sums.s1 += y[1+32] * ((float)((q2[1] & 0xF) | ((qh[1] & kmask2) << 2)) - 32.f); + sums.s2 += y[1+64] * ((float)((q1[1] >> 4) | ((qh[1] & kmask3) << 0)) - 32.f); + sums.s3 += y[1+96] * ((float)((q2[1] >> 4) | ((qh[1] & kmask4) >> 2)) - 32.f); + + sums.s0 += y[2+ 0] * ((float)((q1[2] & 0xF) | ((qh[2] & kmask1) << 4)) - 32.f); + sums.s1 += y[2+32] * ((float)((q2[2] & 0xF) | ((qh[2] & kmask2) << 2)) - 32.f); + sums.s2 += y[2+64] * ((float)((q1[2] >> 4) | ((qh[2] & kmask3) << 0)) - 32.f); + sums.s3 += y[2+96] * ((float)((q2[2] >> 4) | ((qh[2] & kmask4) >> 2)) - 32.f); + + sums.s0 += y[3+ 0] * ((float)((q1[3] & 0xF) | ((qh[3] & kmask1) << 4)) - 32.f); + sums.s1 += y[3+32] * ((float)((q2[3] & 0xF) | ((qh[3] & kmask2) << 2)) - 32.f); + sums.s2 += y[3+64] * ((float)((q1[3] >> 4) | ((qh[3] & kmask3) << 0)) - 32.f); + sums.s3 += y[3+96] * ((float)((q2[3] >> 4) | ((qh[3] & kmask4) >> 2)) - 32.f); + + sumf += dall * (sums.s0 * sc[0] + sums.s1 * sc[2] + sums.s2 * sc[4] + sums.s3 * sc[6]); + } + + float tot = sub_group_reduce_add(sumf); + if (get_sub_group_local_id() == 0) { + dst[r1*ne0 + im*ne0*ne1 + row] = tot; + } +} diff --git a/ggml/src/ggml-opencl/kernels/norm.cl b/ggml/src/ggml-opencl/kernels/norm.cl new file mode 100644 index 00000000000..43167ba4d22 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/norm.cl @@ -0,0 +1,81 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +//------------------------------------------------------------------------------ +// norm +//------------------------------------------------------------------------------ +kernel void kernel_norm( + global void * src0, + ulong offset0, + global float * dst, + ulong offsetd, + int ne00, + int 
ne01, + int ne02, + int ne03, + ulong nb01, + ulong nb02, + ulong nb03, + float eps, + local float * sum +) { + src0 = (global void*)((global char*)src0 + offset0); + dst = (global void*)((global char*)dst + offsetd); + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + global float * x = (global float *) ((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01); + + // MEAN + // parallel sum + sum[get_local_id(0)] = 0.0f; + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + sum[get_local_id(0)] += x[i00]; + } + // reduce + barrier(CLK_LOCAL_MEM_FENCE); + for (uint i = get_local_size(0)/2; i > 0; i /= 2) { + if (get_local_id(0) < i) { + sum[get_local_id(0)] += sum[get_local_id(0) + i]; + } + barrier(CLK_LOCAL_MEM_FENCE); + } + float mean = sum[0] / ne00; + + // recenter and VARIANCE + barrier(CLK_LOCAL_MEM_FENCE); + global float * y = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + sum[get_local_id(0)] = 0.0f; + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + y[i00] = x[i00] - mean; + sum[get_local_id(0)] += y[i00] * y[i00]; + } + + // reduce + barrier(CLK_LOCAL_MEM_FENCE); + for (uint i = get_local_size(0)/2; i > 0; i /= 2) { + if (get_local_id(0) < i) { + sum[get_local_id(0)] += sum[get_local_id(0) + i]; + } + barrier(CLK_LOCAL_MEM_FENCE); + } + float variance = sum[0] / ne00; + + float scale = 1.0f/sqrt(variance + eps); + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + y[i00] = y[i00] * scale; + } +} diff --git a/ggml/src/ggml-opencl/kernels/relu.cl b/ggml/src/ggml-opencl/kernels/relu.cl new file mode 100644 index 00000000000..60ff28a61a0 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/relu.cl @@ -0,0 +1,16 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +//------------------------------------------------------------------------------ +// relu +//------------------------------------------------------------------------------ +kernel void kernel_relu( + global float * src0, + ulong offset0, + global float * dst, + ulong offsetd +) { + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + dst[get_global_id(0)] = fmax(0.0f, src0[get_global_id(0)]); +} diff --git a/ggml/src/ggml-opencl/kernels/rms_norm.cl b/ggml/src/ggml-opencl/kernels/rms_norm.cl new file mode 100644 index 00000000000..9d21f3398ec --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/rms_norm.cl @@ -0,0 +1,96 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +//------------------------------------------------------------------------------ +// rms_norm +//------------------------------------------------------------------------------ +// This kernel depends on subgroup size. 
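Before the subgroup-based RMS variant, a scalar sketch of what kernel_norm above computes for each row; the kernel only parallelizes the two sums across the work group with a local-memory tree reduction, so the per-row math is just this (illustrative host code, not part of this diff):

// Per-row layer normalization, as in kernel_norm: mean, variance, then scale.
#include <math.h>
#include <stddef.h>

static void norm_row_ref(const float * x, float * y, size_t n, float eps) {
    float mean = 0.f;
    for (size_t i = 0; i < n; ++i) mean += x[i];
    mean /= (float) n;

    float variance = 0.f;
    for (size_t i = 0; i < n; ++i) {
        y[i] = x[i] - mean;            // recenter
        variance += y[i] * y[i];
    }
    variance /= (float) n;

    const float scale = 1.0f / sqrtf(variance + eps);
    for (size_t i = 0; i < n; ++i) {
        y[i] *= scale;
    }
}

int main(void) {
    float x[4] = {1.f, 2.f, 3.f, 4.f}, y[4];
    norm_row_ref(x, y, 4, 1e-5f);
    return 0;
}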
+#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_32 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_rms_norm( + global void * src0, + ulong offset0, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb01, + ulong nb02, + ulong nb03, + float eps, + local float * sum // Note, the size depends on number of subgroups +) { + src0 = (global void*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + global float4 * x = (global float4 *) ((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01); + global float * x_scalar = (global float *) x; + float4 sumf = 0; + float all_sum = 0; + + // parallel sum + for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { + sumf += x[i00] * x[i00]; + } + all_sum = sumf.s0 + sumf.s1 + sumf.s2 + sumf.s3; + all_sum = sub_group_reduce_add(all_sum); + if (get_sub_group_local_id() == 0) { + sum[get_sub_group_id()] = all_sum; + } + + barrier(CLK_LOCAL_MEM_FENCE); + // broadcast + for (uint i = get_local_size(0) / get_max_sub_group_size() / 2; i > 0; i /= 2) { + if (get_local_id(0) < i) { + sum[get_local_id(0)] += sum[get_local_id(0) + i]; + } + } + if (get_local_id(0) == 0) { + for (int i = 4 * (ne00 / 4); i < ne00; i++) { + sum[0] += x_scalar[i]; + } + sum[0] /= ne00; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + const float mean = sum[0]; + const float scale = 1.0f/sqrt(mean + eps); + + global float4 * y = (global float4 *) (dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + global float * y_scalar = (global float *) y; + for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { + y[i00] = x[i00] * scale; + } + if (get_local_id(0) == 0) { + for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) { + y_scalar[i00] = x_scalar[i00] * scale; + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/rope.cl b/ggml/src/ggml-opencl/kernels/rope.cl new file mode 100644 index 00000000000..0247730c036 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/rope.cl @@ -0,0 +1,721 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +//------------------------------------------------------------------------------ +// kernel_rope +//------------------------------------------------------------------------------ +float rope_yarn_ramp(float low, float high, int i0) { + const float y = (i0 / 2 - low) / max(0.001f, high - low); + return 1.0f - min(1.0f, max(0.0f, y)); +} + +// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn +// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng. 
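For comparison with kernel_norm, kernel_rms_norm above drops the mean subtraction and scales each row by the reciprocal of its root mean square; the kernel vectorizes the sum of squares with float4 loads, reduces it first within each subgroup and then across subgroups in local memory, and handles the ne00 % 4 remainder separately, but per row the math reduces to this sketch (illustrative host code, not part of this diff):

// Per-row RMS normalization, as in kernel_rms_norm: mean of squares, then scale.
#include <math.h>
#include <stddef.h>

static void rms_norm_row_ref(const float * x, float * y, size_t n, float eps) {
    float sum_sq = 0.f;
    for (size_t i = 0; i < n; ++i) sum_sq += x[i] * x[i];

    const float mean  = sum_sq / (float) n;        // mean of the squares
    const float scale = 1.0f / sqrtf(mean + eps);  // 1 / RMS, stabilized by eps

    for (size_t i = 0; i < n; ++i) {
        y[i] = x[i] * scale;
    }
}

int main(void) {
    float x[4] = {1.f, 2.f, 3.f, 4.f}, y[4];
    rms_norm_row_ref(x, y, 4, 1e-5f);
    return 0;
}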
+float2 rope_yarn( + float theta_extrap, float freq_scale, float2 corr_dims, int i0, float ext_factor, float mscale +) { + // Get n-d rotational scaling corrected for extrapolation + float theta_interp = freq_scale * theta_extrap; + float theta = theta_interp; + if (ext_factor != 0.0f) { + float ramp_mix = rope_yarn_ramp(corr_dims.s0, corr_dims.s1, i0) * ext_factor; + theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; + + // Get n-d magnitude scaling corrected for interpolation + mscale *= 1.0f + 0.1f * log(1.0f / freq_scale); + } + return (float2)(cos(theta) * mscale, sin(theta) * mscale); +} + +// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get +// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))` +float rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) { + return n_dims * log(n_ctx_orig / (n_rot * 2 * M_PI_F)) / (2 * log(base)); +} + +float2 rope_yarn_corr_dims( + int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow +) { + // start and end correction dims + return (float2)( + max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, freq_base))), + min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base))) + ); +} + +kernel void kernel_rope_norm_f32( + global void * src0, + ulong offset0, + global int * src1, + ulong offset1, + global float * src2, + ulong offset2, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3, + int n_past, + int n_dims, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global int*)((global char*)src1 + offset1); + src2 = (global float*)((global char*)src2 + offset2); + dst = (global float*)((global char*)dst + offsetd); + + int i3 = get_group_id(2); + int i2 = get_group_id(1); + int i1 = get_group_id(0); + + float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); + + global int * pos = src1; + + float theta_base = (float) pos[i2]; + float inv_ndims = -1.f/n_dims; + + for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { + if (i0 < n_dims) { + int ic = i0/2; + + float theta = theta_base * pow(freq_base, inv_ndims*i0); + + float freq_factor = src2 != src0 ? 
src2[ic] : 1.0f; + + float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); + + global float * src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + float x0 = src[0]; + float x1 = src[1]; + + dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; + dst_data[1] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; + } else { + global float * src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; + } + } +} + +kernel void kernel_rope_norm_f16( + global void * src0, + ulong offset0, + global int * src1, + ulong offset1, + global float * src2, + ulong offset2, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3, + int n_past, + int n_dims, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global int*)((global char*)src1 + offset1); + src2 = (global float*)((global char*)src2 + offset2); + dst = (global float*)((global char*)dst + offsetd); + + int i3 = get_group_id(2); + int i2 = get_group_id(1); + int i1 = get_group_id(0); + + float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); + + global int * pos = src1; + + float theta_base = (float) pos[i2]; + float inv_ndims = -1.f/n_dims; + + for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { + if (i0 < n_dims) { + int ic = i0/2; + + float theta = theta_base * pow(freq_base, inv_ndims*i0); + + float freq_factor = src2 != src0 ? 
src2[ic] : 1.0f; + + float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); + + global half * src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + float x0 = src[0]; + float x1 = src[1]; + + dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; + dst_data[1] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; + } else { + global half * src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; + } + } +} + +kernel void kernel_rope_neox_f32( + global void * src0, + ulong offset0, + global int * src1, + ulong offset1, + global float * src2, + ulong offset2, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3, + int n_past, + int n_dims, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global int*)((global char*)src1 + offset1); + src2 = (global float*)((global char*)src2 + offset2); + dst = (global float*)((global char*)dst + offsetd); + + int i3 = get_group_id(2); + int i2 = get_group_id(1); + int i1 = get_group_id(0); + + float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); + + global int * pos = src1; + + float theta_base = (float) pos[i2]; + float inv_ndims = -1.f/n_dims; + + for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { + if (i0 < n_dims) { + int ic = i0/2; + + const float theta = theta_base * pow(freq_base, inv_ndims*i0); + + const float freq_factor = src2 != src0 ? 
src2[ic] : 1.0f; + + float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); + + global float * src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); + global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); + + const float x0 = src[0]; + const float x1 = src[n_dims/2]; + + dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; + dst_data[n_dims/2] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; + } else { + global float * const src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; + } + } +} + +kernel void kernel_rope_neox_f16( + global void * src0, + ulong offset0, + global int * src1, + ulong offset1, + global float * src2, + ulong offset2, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3, + int n_past, + int n_dims, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global int*)((global char*)src1 + offset1); + src2 = (global float*)((global char*)src2 + offset2); + dst = (global float*)((global char*)dst + offsetd); + + int i3 = get_group_id(2); + int i2 = get_group_id(1); + int i1 = get_group_id(0); + + float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); + + global int * pos = src1; + + float theta_base = (float) pos[i2]; + float inv_ndims = -1.f/n_dims; + + for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { + if (i0 < n_dims) { + int ic = i0/2; + + const float theta = theta_base * pow(freq_base, inv_ndims*i0); + + const float freq_factor = src2 != src0 ? 
src2[ic] : 1.0f; + + float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); + + global half * src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); + global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); + + const float x0 = src[0]; + const float x1 = src[n_dims/2]; + + dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; + dst_data[n_dims/2] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; + } else { + global half * const src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; + } + } +} + +kernel void kernel_rope_multi_f32( + global void * src0, + ulong offset0, + global int * src1, + ulong offset1, + global float * src2, + ulong offset2, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3, + int n_past, + int n_dims, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow, + int4 sections +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global int*)((global char*)src1 + offset1); + src2 = (global float*)((global char*)src2 + offset2); + dst = (global float*)((global char*)dst + offsetd); + + int i3 = get_group_id(2); + int i2 = get_group_id(1); + int i1 = get_group_id(0); + + float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); + + global int * pos = src1; + + const int sect_dims = sections.s0 + sections.s1 + sections.s2 + sections.s3; + const int sec_w = sections.s1 + sections.s0; + + float inv_ndims = -1.f/n_dims; + + for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { + if (i0 < n_dims) { + int ic = i0/2; + + const int sector = (i0 / 2) % sect_dims; + float theta_base = 0.0f; + + if (sector < sections.s0) { + theta_base = pos[i2]; + } + else if (sector >= sections.s0 && sector < sec_w) { + theta_base = pos[i2 + ne2 * 1]; + } + else if (sector >= sec_w && sector < sec_w + sections.s2) { + theta_base = pos[i2 + ne2 * 2]; + } + else if (sector >= sec_w + sections.s2) { + theta_base = pos[i2 + ne2 * 3]; + } + + const float theta = theta_base * pow(freq_base, inv_ndims*i0); + + const float freq_factor = src2 != src0 ? 
src2[ic] : 1.0f; + + float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); + + global float * src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); + global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); + + const float x0 = src[0]; + const float x1 = src[n_dims/2]; + + dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; + dst_data[n_dims/2] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; + } else { + global float * const src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; + } + } +} + +kernel void kernel_rope_multi_f16( + global void * src0, + ulong offset0, + global int * src1, + ulong offset1, + global float * src2, + ulong offset2, + global half * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3, + int n_past, + int n_dims, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow, + int4 sections +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global int*)((global char*)src1 + offset1); + src2 = (global float*)((global char*)src2 + offset2); + dst = (global float*)((global char*)dst + offsetd); + + int i3 = get_group_id(2); + int i2 = get_group_id(1); + int i1 = get_group_id(0); + + float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); + + global int * pos = src1; + + const int sect_dims = sections.s0 + sections.s1 + sections.s2 + sections.s3; + const int sec_w = sections.s1 + sections.s0; + + float inv_ndims = -1.f/n_dims; + + for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { + if (i0 < n_dims) { + int ic = i0/2; + + const int sector = (i0 / 2) % sect_dims; + float theta_base = 0.0f; + + if (sector < sections.s0) { + theta_base = pos[i2]; + } + else if (sector >= sections.s0 && sector < sec_w) { + theta_base = pos[i2 + ne2 * 1]; + } + else if (sector >= sec_w && sector < sec_w + sections.s2) { + theta_base = pos[i2 + ne2 * 2]; + } + else if (sector >= sec_w + sections.s2) { + theta_base = pos[i2 + ne2 * 3]; + } + + const float theta = theta_base * pow(freq_base, inv_ndims*i0); + + const float freq_factor = src2 != src0 ? 
src2[ic] : 1.0f; + + float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); + + global half * src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); + global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); + + const float x0 = src[0]; + const float x1 = src[n_dims/2]; + + dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; + dst_data[n_dims/2] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; + } else { + global half * const src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; + } + } +} + +kernel void kernel_rope_vision_f32( + global void * src0, + ulong offset0, + global int * src1, + ulong offset1, + global float * src2, + ulong offset2, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3, + int n_past, + int n_dims, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow, + int4 sections +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global int*)((global char*)src1 + offset1); + src2 = (global float*)((global char*)src2 + offset2); + dst = (global float*)((global char*)dst + offsetd); + + int i3 = get_group_id(2); + int i2 = get_group_id(1); + int i1 = get_group_id(0); + + float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); + + global int * pos = src1; + + const int sect_dims = sections.s0 + sections.s1; + const int sec_w = sections.s1 + sections.s0; + + float inv_ndims = -1.f/n_dims; + + for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { + int ic = i0/2; + + const int sector = (i0/2) % sect_dims; + float theta_base = 0.0f; + + if (sector < sections.s0) { + const int p = sector; + theta_base = pos[i2] * pow(freq_base, inv_ndims*2.0f*p); + } else if (sector >= sections.s0 && sector < sec_w) { + const int p = sector - sections.s0; + theta_base = pos[i2 + ne2] * pow(freq_base, inv_ndims*2.0f*p); + } + + const float freq_factor = src2 != src0 ? 
src2[ic] : 1.0f; + + float2 cos_sin_theta = rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); + + global float * src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); + global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); + + const float x0 = src[0]; + const float x1 = src[n_dims]; + + dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; + dst_data[n_dims] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; + } +} + +kernel void kernel_rope_vision_f16( + global void * src0, + ulong offset0, + global int * src1, + ulong offset1, + global float * src2, + ulong offset2, + global half * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3, + int n_past, + int n_dims, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow, + int4 sections +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global int*)((global char*)src1 + offset1); + src2 = (global float*)((global char*)src2 + offset2); + dst = (global float*)((global char*)dst + offsetd); + + int i3 = get_group_id(2); + int i2 = get_group_id(1); + int i1 = get_group_id(0); + + float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); + + global int * pos = src1; + + const int sect_dims = sections.s0 + sections.s1; + const int sec_w = sections.s1 + sections.s0; + + float inv_ndims = -1.f/n_dims; + + for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { + int ic = i0/2; + + const int sector = (i0/2) % sect_dims; + float theta_base = 0.0f; + + if (sector < sections.s0) { + const int p = sector; + theta_base = pos[i2] * pow(freq_base, inv_ndims*2.0f*p); + } else if (sector >= sections.s0 && sector < sec_w) { + const int p = sector - sections.s0; + theta_base = pos[i2 + ne2] * pow(freq_base, inv_ndims*2.0f*p); + } + + const float freq_factor = src2 != src0 ? 
src2[ic] : 1.0f; + + float2 cos_sin_theta = rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); + + global half * src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); + global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); + + const float x0 = src[0]; + const float x1 = src[n_dims]; + + dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; + dst_data[n_dims] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; + } +} diff --git a/ggml/src/ggml-opencl/kernels/scale.cl b/ggml/src/ggml-opencl/kernels/scale.cl new file mode 100644 index 00000000000..8cfd518fa5a --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/scale.cl @@ -0,0 +1,16 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +//------------------------------------------------------------------------------ +// scale +//------------------------------------------------------------------------------ +kernel void kernel_scale( + global float4 * src0, + ulong offset0, + global float4 * dst, + ulong offsetd, + float scale +) { + src0 = (global float4*)((global char*)src0 + offset0); + dst = (global float4*)((global char*)dst + offsetd); + dst[get_global_id(0)] = src0[get_global_id(0)] * scale; +} diff --git a/ggml/src/ggml-opencl/kernels/silu.cl b/ggml/src/ggml-opencl/kernels/silu.cl new file mode 100644 index 00000000000..1d95e1b50fd --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/silu.cl @@ -0,0 +1,30 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +//------------------------------------------------------------------------------ +// silu +//------------------------------------------------------------------------------ +kernel void kernel_silu( + global float * src0, + ulong offset0, + global float * dst, + ulong offsetd +) { + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + float x = src0[get_global_id(0)]; + dst[get_global_id(0)] = x / (1.0f + exp(-x)); +} + +kernel void kernel_silu_4( + global float4 * src0, + ulong offset0, + global float4 * dst, + ulong offsetd +) { + src0 = (global float4*)((global char*)src0 + offset0); + dst = (global float4*)((global char*)dst + offsetd); + + float4 x = src0[get_global_id(0)]; + dst[get_global_id(0)] = x / (1.0f + exp(-x)); +} diff --git a/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl b/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl new file mode 100644 index 00000000000..62c05369a87 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl @@ -0,0 +1,87 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_soft_max_4_f16( + global float * src0, + ulong offset0, + global half * src1, + ulong offset1, + global float * dst, + ulong 
offsetd, + int ne00, + int ne01, + int ne02, + float scale, + float max_bias, + float m0, + float m1, + int n_head_log2 +) { + src0 = (global float *)((global char *)src0 + offset0); + src1 = (global half *)((global char *)src1 + offset1); + dst = (global float *)((global char *)dst + offsetd); + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + global float4 * psrc4 = (global float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + global half4 * pmask = (global char *)src1 != (global char *)src0 ? (global half4 *)(src1 + i01*ne00) : 0; + global float4 * pdst4 = (global float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + + float slope = 1.0f; + + // ALiBi + if (max_bias > 0.0f) { + int h = i02; + + float base = h < n_head_log2 ? m0 : m1; + int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + + slope = pow(base, exp); + } + + // parallel max + float4 lmax4 = -INFINITY; + for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { + lmax4 = fmax(lmax4, psrc4[i00]*scale + slope*(pmask ? convert_float4(pmask[i00]) : 0.0f)); + } + float lmax = fmax(fmax(lmax4.s0, lmax4.s1), fmax(lmax4.s2, lmax4.s3)); + + const float max = sub_group_reduce_max(lmax); + + // parallel sum + float4 lsum4 = 0.0f; + for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { + const float4 exp_psrc4 = exp((psrc4[i00]*scale + slope*(pmask ? convert_float4(pmask[i00]) : 0.0f)) - max); + lsum4 += exp_psrc4; + pdst4[i00] = exp_psrc4; + } + float lsum = lsum4.s0 + lsum4.s1 + lsum4.s2 + lsum4.s3; + + const float sum = sub_group_reduce_add(lsum); + + for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { + pdst4[i00] /= sum; + } +} diff --git a/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl b/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl new file mode 100644 index 00000000000..d562774eaba --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl @@ -0,0 +1,87 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_soft_max_4( + global float * src0, + ulong offset0, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + float scale, + float max_bias, + float m0, + float m1, + int n_head_log2 +) { + src0 = (global float*)((global char*)src0 + offset0); + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + global float4 * psrc4 = (global float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + global float4 * pmask = src1 != src0 ? 
(global float4 *)(src1 + i01*ne00) : 0; + global float4 * pdst4 = (global float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + + float slope = 1.0f; + + // ALiBi + if (max_bias > 0.0f) { + int h = i02; + + float base = h < n_head_log2 ? m0 : m1; + int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + + slope = pow(base, exp); + } + + // parallel max + float4 lmax4 = -INFINITY; + for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { + lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)); + } + float lmax = fmax(fmax(lmax4.s0, lmax4.s1), fmax(lmax4.s2, lmax4.s3)); + + const float max = sub_group_reduce_max(lmax); + + // parallel sum + float4 lsum4 = 0.0f; + for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { + const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max); + lsum4 += exp_psrc4; + pdst4[i00] = exp_psrc4; + } + float lsum = lsum4.s0 + lsum4.s1 + lsum4.s2 + lsum4.s3; + + const float sum = sub_group_reduce_add(lsum); + + for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { + pdst4[i00] /= sum; + } +} diff --git a/ggml/src/ggml-opencl/kernels/softmax_f16.cl b/ggml/src/ggml-opencl/kernels/softmax_f16.cl new file mode 100644 index 00000000000..d38d099671e --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/softmax_f16.cl @@ -0,0 +1,86 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_soft_max_f16( + global float * src0, + ulong offset0, + global half * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + float scale, + float max_bias, + float m0, + float m1, + int n_head_log2 +) { + src0 = (global float *)((global char *)src0 + offset0); + src1 = (global half *)((global char *)src1 + offset1); + dst = (global float *)((global char *)dst + offsetd); + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + global float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + global half * pmask = (global char *)src1 != (global char *)src0 ? src1 + i01*ne00 : 0; + global float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + float slope = 1.0f; + + // ALiBi + if (max_bias > 0.0f) { + int h = i02; + + float base = h < n_head_log2 ? m0 : m1; + int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + + slope = pow(base, exp); + } + + // parallel max + float lmax = -INFINITY; + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + lmax = fmax(lmax, psrc0[i00]*scale + (pmask ? 
slope*pmask[i00] : 0.0f)); + } + float max = sub_group_reduce_max(lmax); + + // parallel sum + float lsum = 0.0f; + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max); + lsum += exp_psrc0; + // Remember the result of exp here. exp is expensive, so we really do not + // wish to compute it twice. + pdst[i00] = exp_psrc0; + } + + const float sum = sub_group_reduce_add(lsum); + + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + pdst[i00] /= sum; + } +} diff --git a/ggml/src/ggml-opencl/kernels/softmax_f32.cl b/ggml/src/ggml-opencl/kernels/softmax_f32.cl new file mode 100644 index 00000000000..001b587abe3 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/softmax_f32.cl @@ -0,0 +1,86 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_soft_max( + global float * src0, + ulong offset0, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + float scale, + float max_bias, + float m0, + float m1, + int n_head_log2 +) { + src0 = (global float*)((global char*)src0 + offset0); + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + global float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + global float * pmask = src1 != src0 ? src1 + i01*ne00 : 0; + global float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + float slope = 1.0f; + + // ALiBi + if (max_bias > 0.0f) { + int h = i02; + + float base = h < n_head_log2 ? m0 : m1; + int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + + slope = pow(base, exp); + } + + // parallel max + float lmax = -INFINITY; + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + lmax = fmax(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)); + } + float max = sub_group_reduce_max(lmax); + + // parallel sum + float lsum = 0.0f; + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max); + lsum += exp_psrc0; + // Remember the result of exp here. exp is expensive, so we really do not + // wish to compute it twice. 
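+        // Subtracting the row maximum before exp() keeps every exponent non-positive,
+        // which avoids overflow; the shift cancels out in the final division by `sum`.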
+ pdst[i00] = exp_psrc0; + } + + const float sum = sub_group_reduce_add(lsum); + + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + pdst[i00] /= sum; + } +} diff --git a/ggml/src/ggml-opencl/kernels/transpose.cl b/ggml/src/ggml-opencl/kernels/transpose.cl new file mode 100644 index 00000000000..a11490b304c --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/transpose.cl @@ -0,0 +1,84 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +// 16-bit transpose, loading/storing a 4x4 tile of elements +kernel void kernel_transpose_16( + __read_only image1d_buffer_t input, + __write_only image1d_buffer_t output, + const uint rows, + const uint cols +) { + + const int i = get_global_id(0); + const int j = get_global_id(1); + const int i_2 = i<<2; + const int j_2 = j<<2; + + half4 temp0 = read_imageh(input, (j_2+0)*cols+i); + half4 temp1 = read_imageh(input, (j_2+1)*cols+i); + half4 temp2 = read_imageh(input, (j_2+2)*cols+i); + half4 temp3 = read_imageh(input, (j_2+3)*cols+i); + + write_imageh(output, (i_2+0)*rows+j, (half4)(temp0.s0, temp1.s0, temp2.s0, temp3.s0)); + write_imageh(output, (i_2+1)*rows+j, (half4)(temp0.s1, temp1.s1, temp2.s1, temp3.s1)); + write_imageh(output, (i_2+2)*rows+j, (half4)(temp0.s2, temp1.s2, temp2.s2, temp3.s2)); + write_imageh(output, (i_2+3)*rows+j, (half4)(temp0.s3, temp1.s3, temp2.s3, temp3.s3)); +} + +// 32-bit transpose, loading/storing a 4x4 tile of elements +kernel void kernel_transpose_32( + __read_only image1d_buffer_t input, + __write_only image1d_buffer_t output, + const uint rows, + const uint cols +) { + + const int i = get_global_id(0); + const int j = get_global_id(1); + const int i_2 = i<<2; + const int j_2 = j<<2; + + float4 temp0 = read_imagef(input, (j_2+0)*cols+i); + float4 temp1 = read_imagef(input, (j_2+1)*cols+i); + float4 temp2 = read_imagef(input, (j_2+2)*cols+i); + float4 temp3 = read_imagef(input, (j_2+3)*cols+i); + + write_imagef(output, (i_2+0)*rows+j, (float4)(temp0.s0, temp1.s0, temp2.s0, temp3.s0)); + write_imagef(output, (i_2+1)*rows+j, (float4)(temp0.s1, temp1.s1, temp2.s1, temp3.s1)); + write_imagef(output, (i_2+2)*rows+j, (float4)(temp0.s2, temp1.s2, temp2.s2, temp3.s2)); + write_imagef(output, (i_2+3)*rows+j, (float4)(temp0.s3, temp1.s3, temp2.s3, temp3.s3)); + +} + +// 32-bit transpose, loading/storing a 4x4 tile of elements +// Only used for activations +// converts to FP16 +// also adds zero padding for non multiple of 8 prompt lengths +kernel void kernel_transpose_32_16(__read_only image1d_buffer_t input, __write_only image1d_buffer_t output, const uint rows, const uint cols, const uint padded_rows) { + + const int i = get_global_id(0); + const int j = get_global_id(1); + const int i_2 = i<<2; + const int j_2 = j<<2; + half4 temp0 = {0,0,0,0}; // initialize outputs to 0 + half4 temp1 = {0,0,0,0}; + half4 temp2 = {0,0,0,0}; + half4 temp3 = {0,0,0,0}; + + if((j_2+0)*cols+i*4+3 < rows*cols*16){ // only load from a valid location. 
Otherwise keep register data as 0 + temp0 = read_imageh(input, (j_2+0)*cols+i); + } + if((j_2+1)*cols+i*4+3 < rows*cols*16){ + temp1 = read_imageh(input, (j_2+1)*cols+i); + } + if((j_2+2)*cols+i*4+3 < rows*cols*16){ + temp2 = read_imageh(input, (j_2+2)*cols+i); + } + if((j_2+3)*cols+i*4+3 < rows*cols*16){ + temp3 = read_imageh(input, (j_2+3)*cols+i); + } + + write_imageh(output, (i_2+0)*padded_rows+j, (half4)(temp0.s0, temp1.s0, temp2.s0, temp3.s0)); // no conditionals for output, includes zero padding + write_imageh(output, (i_2+1)*padded_rows+j, (half4)(temp0.s1, temp1.s1, temp2.s1, temp3.s1)); + write_imageh(output, (i_2+2)*padded_rows+j, (half4)(temp0.s2, temp1.s2, temp2.s2, temp3.s2)); + write_imageh(output, (i_2+3)*padded_rows+j, (half4)(temp0.s3, temp1.s3, temp2.s3, temp3.s3)); +} diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp index 862b9b66617..a0667b7d702 100644 --- a/ggml/src/ggml-rpc/ggml-rpc.cpp +++ b/ggml/src/ggml-rpc/ggml-rpc.cpp @@ -1,6 +1,7 @@ #include "ggml-rpc.h" #include "ggml-impl.h" #include "ggml-backend-impl.h" +#include "ggml-cpp.h" #include #include @@ -91,12 +92,19 @@ enum rpc_cmd { RPC_CMD_GET_DEVICE_MEMORY, RPC_CMD_INIT_TENSOR, RPC_CMD_GET_ALLOC_SIZE, + RPC_CMD_HELLO, RPC_CMD_COUNT, }; // Try RPC_CMD_SET_TENSOR_HASH first when data size is larger than this threshold const size_t HASH_THRESHOLD = 10 * 1024 * 1024; +struct rpc_msg_hello_rsp { + uint8_t major; + uint8_t minor; + uint8_t patch; +}; + struct rpc_msg_get_alloc_size_req { rpc_tensor tensor; }; @@ -399,6 +407,20 @@ static bool send_rpc_cmd(const std::shared_ptr & sock, enum rpc_cmd cm // RPC client-side implementation +static bool check_server_version(const std::shared_ptr & sock) { + rpc_msg_hello_rsp response; + bool status = send_rpc_cmd(sock, RPC_CMD_HELLO, nullptr, 0, &response, sizeof(response)); + GGML_ASSERT(status); + if (response.major != RPC_PROTO_MAJOR_VERSION || response.minor > RPC_PROTO_MINOR_VERSION) { + fprintf(stderr, "RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch); + return false; + } + if (response.minor != RPC_PROTO_MINOR_VERSION || response.patch != RPC_PROTO_PATCH_VERSION) { + fprintf(stderr, "WARNING: RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch); + } + return true; +} + static std::shared_ptr get_socket(const std::string & endpoint) { static std::mutex mutex; std::lock_guard lock(mutex); @@ -432,6 +454,9 @@ static std::shared_ptr get_socket(const std::string & endpoint) { if (sock == nullptr) { return nullptr; } + if (!check_server_version(sock)) { + return nullptr; + } GGML_PRINT_DEBUG("[%s] connected to %s, sockfd=%d\n", __func__, endpoint.c_str(), sock->fd); sockets[endpoint] = sock; return sock; @@ -817,6 +842,7 @@ class rpc_server { } ~rpc_server(); + void hello(rpc_msg_hello_rsp & response); void alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response); void get_alignment(rpc_msg_get_alignment_rsp & response); void get_max_size(rpc_msg_get_max_size_rsp & response); @@ -845,6 +871,13 @@ class rpc_server { std::unordered_set buffers; }; +void rpc_server::hello(rpc_msg_hello_rsp & response) { + response.major = RPC_PROTO_MAJOR_VERSION; + response.minor = RPC_PROTO_MINOR_VERSION; + response.patch = RPC_PROTO_PATCH_VERSION; + GGML_PRINT_DEBUG("[%s] version: %d.%d.%d\n", __func__, response.major, response.minor, response.patch); +} + bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, 
rpc_msg_get_alloc_size_rsp & response) { ggml_backend_buffer_type_t buft; struct ggml_init_params params { @@ -853,12 +886,13 @@ bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_ /*.no_alloc =*/ true, }; - struct ggml_context * ctx = ggml_init(params); + ggml_context_ptr ctx_ptr { ggml_init(params) }; + GGML_ASSERT(ctx_ptr != nullptr); + ggml_context * ctx = ctx_ptr.get(); ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor); if (tensor == nullptr) { GGML_LOG_ERROR("Null tensor pointer passed to server get_alloc_size function.\n"); - ggml_free(ctx); return false; } @@ -871,7 +905,6 @@ bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_ response.alloc_size = ggml_backend_buft_get_alloc_size(buft,tensor); - ggml_free(ctx); return true; } @@ -985,11 +1018,12 @@ bool rpc_server::set_tensor(const std::vector & input) { /*.mem_buffer =*/ NULL, /*.no_alloc =*/ true, }; - struct ggml_context * ctx = ggml_init(params); + ggml_context_ptr ctx_ptr { ggml_init(params) }; + GGML_ASSERT(ctx_ptr != nullptr); + ggml_context * ctx = ctx_ptr.get(); ggml_tensor * tensor = deserialize_tensor(ctx, in_tensor); if (tensor == nullptr) { GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__); - ggml_free(ctx); return false; } GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %zu\n", __func__, (void*)tensor->buffer, tensor->data, offset, size); @@ -1016,7 +1050,6 @@ bool rpc_server::set_tensor(const std::vector & input) { printf("[%s] saved to '%s'\n", __func__, cache_file.c_str()); } ggml_backend_tensor_set(tensor, data, offset, size); - ggml_free(ctx); return true; } @@ -1060,11 +1093,12 @@ bool rpc_server::set_tensor_hash(const std::vector & input, rpc_msg_set /*.mem_buffer =*/ NULL, /*.no_alloc =*/ true, }; - struct ggml_context * ctx = ggml_init(params); + ggml_context_ptr ctx_ptr { ggml_init(params) }; + GGML_ASSERT(ctx_ptr != nullptr); + ggml_context * ctx = ctx_ptr.get(); ggml_tensor * tensor = deserialize_tensor(ctx, in_tensor); if (tensor == nullptr) { GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__); - ggml_free(ctx); return false; } GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %zu, hash: %" PRIx64 "\n", __func__, (void*)tensor->buffer, tensor->data, offset, size, *hash); @@ -1080,7 +1114,6 @@ bool rpc_server::set_tensor_hash(const std::vector & input, rpc_msg_set } ggml_backend_tensor_set(tensor, cached_file.data(), offset, size); response.result = 1; - ggml_free(ctx); return true; } @@ -1090,11 +1123,12 @@ bool rpc_server::init_tensor(const rpc_msg_init_tensor_req & request) { /*.mem_buffer =*/ NULL, /*.no_alloc =*/ true, }; - struct ggml_context * ctx = ggml_init(params); + ggml_context_ptr ctx_ptr { ggml_init(params) }; + GGML_ASSERT(ctx_ptr != nullptr); + ggml_context * ctx = ctx_ptr.get(); ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor); if (tensor == nullptr) { GGML_LOG_ERROR("Null tensor pointer passed to server init_tensor function.\n"); - ggml_free(ctx); return false; } @@ -1110,11 +1144,9 @@ bool rpc_server::init_tensor(const rpc_msg_init_tensor_req & request) { // This pointer can either be passed around client/server, or probably better stored server-side and kept track of. // Currently unimplemented. 
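        // Note: this early return, like the others in this function, relies on the
        // ggml_context_ptr declared above (an RAII smart-pointer wrapper provided by
        // ggml-cpp.h) to free the context, which is why the explicit ggml_free(ctx)
        // calls are removed throughout this file.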
GGML_LOG_ERROR("tensor->extra populated by the backend, this is currently unsupported.\n"); - ggml_free(ctx); return false; } - ggml_free(ctx); return true; } @@ -1124,11 +1156,12 @@ bool rpc_server::get_tensor(const rpc_msg_get_tensor_req & request, std::vector< /*.mem_buffer =*/ NULL, /*.no_alloc =*/ true, }; - struct ggml_context * ctx = ggml_init(params); + ggml_context_ptr ctx_ptr { ggml_init(params) }; + GGML_ASSERT(ctx_ptr != nullptr); + ggml_context * ctx = ctx_ptr.get(); ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor); if (tensor == nullptr) { GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__); - ggml_free(ctx); return false; } GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %" PRIu64 "\n", __func__, (void*)tensor->buffer, tensor->data, request.offset, request.size); @@ -1147,7 +1180,6 @@ bool rpc_server::get_tensor(const rpc_msg_get_tensor_req & request, std::vector< response.resize(request.size, 0); ggml_backend_tensor_get(tensor, response.data(), request.offset, request.size); - ggml_free(ctx); return true; } @@ -1157,12 +1189,14 @@ bool rpc_server::copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_co /*.mem_buffer =*/ NULL, /*.no_alloc =*/ true, }; - struct ggml_context * ctx = ggml_init(params); + ggml_context_ptr ctx_ptr { ggml_init(params) }; + GGML_ASSERT(ctx_ptr != nullptr); + ggml_context * ctx = ctx_ptr.get(); + ggml_tensor * src = deserialize_tensor(ctx, &request.src); ggml_tensor * dst = deserialize_tensor(ctx, &request.dst); if (src == nullptr || dst == nullptr) { GGML_LOG_ERROR("[%s] error deserializing tensors\n", __func__); - ggml_free(ctx); return false; } @@ -1180,7 +1214,6 @@ bool rpc_server::copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_co dst_data + src_size, dst_base, dst_base + dst_buf_sz); - ggml_free(ctx); return false; } @@ -1188,7 +1221,6 @@ bool rpc_server::copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_co __func__, (void*) src->buffer, (void*) dst->buffer); response.result = ggml_backend_buffer_copy_tensor(src, dst); - ggml_free(ctx); return true; } @@ -1242,7 +1274,9 @@ bool rpc_server::graph_compute(const std::vector & input, rpc_msg_graph /*.mem_buffer =*/ NULL, /*.no_alloc =*/ true, }; - struct ggml_context * ctx = ggml_init(params); + ggml_context_ptr ctx_ptr { ggml_init(params) }; + GGML_ASSERT(ctx_ptr != nullptr); + ggml_context * ctx = ctx_ptr.get(); struct ggml_cgraph * graph = ggml_new_graph_custom(ctx, n_nodes, false); graph->n_nodes = n_nodes; std::unordered_map tensor_ptrs; @@ -1257,7 +1291,6 @@ bool rpc_server::graph_compute(const std::vector & input, rpc_msg_graph } ggml_status status = ggml_backend_graph_compute(backend, graph); response.result = status; - ggml_free(ctx); return true; } @@ -1270,8 +1303,24 @@ rpc_server::~rpc_server() { static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir, sockfd_t sockfd, size_t free_mem, size_t total_mem) { rpc_server server(backend, cache_dir); + uint8_t cmd; + if (!recv_data(sockfd, &cmd, 1)) { + return; + } + // the first command sent by the client must be HELLO + if (cmd != RPC_CMD_HELLO) { + fprintf(stderr, "Expected HELLO command, update client\n"); + return; + } + if (!recv_msg(sockfd, nullptr, 0)) { + return; + } + rpc_msg_hello_rsp response; + server.hello(response); + if (!send_msg(sockfd, &response, sizeof(response))) { + return; + } while (true) { - uint8_t cmd; if (!recv_data(sockfd, &cmd, 1)) { break; } @@ -1281,6 +1330,10 @@ static void rpc_serve_client(ggml_backend_t 
backend, const char * cache_dir, break; } switch (cmd) { + case RPC_CMD_HELLO: { + // HELLO command is handled above + return; + } case RPC_CMD_ALLOC_BUFFER: { rpc_msg_alloc_buffer_req request; if (!recv_msg(sockfd, &request, sizeof(request))) { diff --git a/ggml/src/ggml-sycl/CMakeLists.txt b/ggml/src/ggml-sycl/CMakeLists.txt index 6747fd88361..6699b70bad0 100644 --- a/ggml/src/ggml-sycl/CMakeLists.txt +++ b/ggml/src/ggml-sycl/CMakeLists.txt @@ -27,6 +27,15 @@ file(GLOB GGML_HEADERS_SYCL "*.hpp") file(GLOB GGML_SOURCES_SYCL "*.cpp") target_sources(ggml-sycl PRIVATE ${GGML_HEADERS_SYCL} ${GGML_SOURCES_SYCL}) +if (WIN32) + # To generate a Visual Studio solution, using Intel C++ Compiler for ggml-sycl is mandatory + if( ${CMAKE_GENERATOR} MATCHES "Visual Studio" AND NOT (${CMAKE_GENERATOR_TOOLSET} MATCHES "Intel C")) + set_target_properties(ggml-sycl PROPERTIES VS_PLATFORM_TOOLSET "Intel C++ Compiler 2025") + set(CMAKE_CXX_COMPILER "icx") + set(CMAKE_CXX_COMPILER_ID "IntelLLVM") + endif() +endif() + find_package(IntelSYCL) if (IntelSYCL_FOUND) # Use oneAPI CMake when possible diff --git a/ggml/src/ggml-sycl/backend.hpp b/ggml/src/ggml-sycl/backend.hpp index 73d807cab0b..de814ef91a0 100644 --- a/ggml/src/ggml-sycl/backend.hpp +++ b/ggml/src/ggml-sycl/backend.hpp @@ -13,6 +13,7 @@ #ifndef GGML_SYCL_BACKEND_HPP #define GGML_SYCL_BACKEND_HPP +#include "binbcast.hpp" #include "concat.hpp" #include "common.hpp" #include "conv.hpp" diff --git a/ggml/src/ggml-sycl/binbcast.cpp b/ggml/src/ggml-sycl/binbcast.cpp new file mode 100644 index 00000000000..0a9d3a927c2 --- /dev/null +++ b/ggml/src/ggml-sycl/binbcast.cpp @@ -0,0 +1,350 @@ +#include "binbcast.hpp" + +#include +#include +#include + +#include "ggml.h" + +template +static void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst, + int ne0, int ne1, int ne2, int ne3, + int ne10, int ne11, int ne12, int ne13, + /*int s0, */ int s1, int s2, int s3, + /*int s00,*/ int s01, int s02, int s03, + /*int s10,*/ int s11, int s12, int s13, + const sycl::nd_item<3> &item_ct1) { + const int i0s = item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2); + const int i1 = (item_ct1.get_local_range(1) * item_ct1.get_group(1) + + item_ct1.get_local_id(1)); + const int i2 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) + + item_ct1.get_local_id(0)) / + ne3; + const int i3 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) + + item_ct1.get_local_id(0)) % + ne3; + + if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) { + return; + } + + const int i11 = i1 % ne11; + const int i12 = i2 % ne12; + const int i13 = i3 % ne13; + + const size_t i_src0 = i3*s03 + i2*s02 + i1*s01; + const size_t i_src1 = i13*s13 + i12*s12 + i11*s11; + const size_t i_dst = i3*s3 + i2*s2 + i1*s1; + + const src0_t * src0_row = src0 + i_src0; + const src1_t * src1_row = src1 + i_src1; + dst_t * dst_row = dst + i_dst; + + for (int i0 = i0s; i0 < ne0; + i0 += item_ct1.get_local_range(2) * item_ct1.get_group_range(2)) { + const int i10 = i0 % ne10; + dst_row[i0] = (dst_t)bin_op(src0 ? 
(float)src0_row[i0] : 0.0f, (float)src1_row[i10]); + } +} + +template +static void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst, + int ne0, int ne1, int ne2, int ne3, + int ne10, int ne11, int ne12, int ne13, + /*int s0, */ int s1, int s2, int s3, + /*int s00,*/ int s01, int s02, int s03, + /*int s10,*/ int s11, int s12, int s13, + const sycl::nd_item<3> &item_ct1) { + + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2); + + const int i3 = i/(ne2*ne1*ne0); + const int i2 = (i/(ne1*ne0)) % ne2; + const int i1 = (i/ne0) % ne1; + const int i0 = i % ne0; + + if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) { + return; + } + + const int i11 = i1 % ne11; + const int i12 = i2 % ne12; + const int i13 = i3 % ne13; + + const size_t i_src0 = i3*s03 + i2*s02 + i1*s01; + const size_t i_src1 = i13*s13 + i12*s12 + i11*s11; + const size_t i_dst = i3*s3 + i2*s2 + i1*s1; + + const src0_t * src0_row = src0 + i_src0; + const src1_t * src1_row = src1 + i_src1; + dst_t * dst_row = dst + i_dst; + + const int i10 = i0 % ne10; + dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]); +} + + +template +struct bin_bcast_sycl { + template + void operator()(const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd, const int64_t ne00, + const int64_t ne01, const int64_t ne02, const int64_t ne03, const int64_t ne10, const int64_t ne11, + const int64_t ne12, const int64_t ne13, const int64_t ne0, const int64_t ne1, const int64_t ne2, + const int64_t ne3, const size_t nb00, const size_t nb01, const size_t nb02, const size_t nb03, + const size_t nb10, const size_t nb11, const size_t nb12, const size_t nb13, const size_t nb0, + const size_t nb1, const size_t nb2, const size_t nb3, const bool src0_is_contiguous, + const bool src1_is_contiguous, const bool dst_is_contiguous, queue_ptr stream) { + int nr0 = ne10 / ne0; + int nr1 = ne11/ne1; + int nr2 = ne12/ne2; + int nr3 = ne13/ne3; + + int nr[4] = { nr0, nr1, nr2, nr3 }; + + // collapse dimensions until first broadcast dimension + int64_t cne[] = {ne0, ne1, ne2, ne3}; + int64_t cne0[] = {ne00, ne01, ne02, ne03}; + int64_t cne1[] = {ne10, ne11, ne12, ne13}; + size_t cnb[] = {nb0, nb1, nb2, nb3}; + size_t cnb0[] = {nb00, nb01, nb02, nb03}; + size_t cnb1[] = {nb10, nb11, nb12, nb13}; + auto collapse = [](int64_t cne[]) { + cne[0] *= cne[1]; + cne[1] = cne[2]; + cne[2] = cne[3]; + cne[3] = 1; + }; + + auto collapse_nb = [](size_t cnb[], int64_t cne[]) { + cnb[1] *= cne[1]; + cnb[2] *= cne[2]; + cnb[3] *= cne[3]; + }; + + if (src0_is_contiguous && src1_is_contiguous && dst_is_contiguous) { + for (int i = 0; i < 4; i++) { + if (nr[i] != 1) { + break; + } + if (i > 0) { + collapse_nb(cnb, cne); + collapse_nb(cnb0, cne0); + collapse_nb(cnb1, cne1); + collapse(cne); + collapse(cne0); + collapse(cne1); + } + } + } + { + int64_t ne0 = cne[0]; + int64_t ne1 = cne[1]; + int64_t ne2 = cne[2]; + int64_t ne3 = cne[3]; + + int64_t ne10 = cne1[0]; + int64_t ne11 = cne1[1]; + int64_t ne12 = cne1[2]; + int64_t ne13 = cne1[3]; + + size_t nb0 = cnb[0]; + size_t nb1 = cnb[1]; + size_t nb2 = cnb[2]; + size_t nb3 = cnb[3]; + + size_t nb00 = cnb0[0]; + size_t nb01 = cnb0[1]; + size_t nb02 = cnb0[2]; + size_t nb03 = cnb0[3]; + + size_t nb10 = cnb1[0]; + size_t nb11 = cnb1[1]; + size_t nb12 = cnb1[2]; + size_t nb13 = cnb1[3]; + + size_t s0 = nb0 / sizeof(dst_t); + size_t s1 = nb1 / sizeof(dst_t); + size_t s2 = nb2 / sizeof(dst_t); + size_t s3 = nb3 / sizeof(dst_t); + + size_t s10 = nb10 / 
sizeof(src1_t); + size_t s11 = nb11 / sizeof(src1_t); + size_t s12 = nb12 / sizeof(src1_t); + size_t s13 = nb13 / sizeof(src1_t); + + size_t s00 = nb00 / sizeof(src0_t); + size_t s01 = nb01 / sizeof(src0_t); + size_t s02 = nb02 / sizeof(src0_t); + size_t s03 = nb03 / sizeof(src0_t); + + GGML_UNUSED(s00); + + GGML_ASSERT(nb0 % sizeof(dst_t) == 0); + GGML_ASSERT(nb1 % sizeof(dst_t) == 0); + GGML_ASSERT(nb2 % sizeof(dst_t) == 0); + GGML_ASSERT(nb3 % sizeof(dst_t) == 0); + + GGML_ASSERT(nb00 % sizeof(src0_t) == 0); + GGML_ASSERT(nb01 % sizeof(src0_t) == 0); + GGML_ASSERT(nb02 % sizeof(src0_t) == 0); + GGML_ASSERT(nb03 % sizeof(src0_t) == 0); + + GGML_ASSERT(nb10 % sizeof(src1_t) == 0); + GGML_ASSERT(nb11 % sizeof(src1_t) == 0); + GGML_ASSERT(nb12 % sizeof(src1_t) == 0); + GGML_ASSERT(nb13 % sizeof(src1_t) == 0); + + GGML_ASSERT(s0 == 1); + GGML_ASSERT(s10 == 1); + + const int block_size = 128; + + int64_t hne0 = std::max(ne0/2LL, 1LL); + + sycl::range<3> block_dims(1, 1, 1); + block_dims[2] = std::min(hne0, block_size); + block_dims[1] = std::min( + ne1, block_size / (unsigned int)block_dims[2]); + block_dims[0] = std::min( + std::min( + ne2 * ne3, block_size / (unsigned int)block_dims[2] / + (unsigned int)block_dims[1]), + 64U); + + sycl::range<3> block_nums( + (ne2 * ne3 + block_dims[0] - 1) / block_dims[0], + (ne1 + block_dims[1] - 1) / block_dims[1], + (hne0 + block_dims[2] - 1) / block_dims[2]); + + if (block_nums[0] > 65535) { + // this is the maximum number of blocks in z direction, fallback to 1D grid kernel + int block_num = (ne0*ne1*ne2*ne3 + block_size - 1) / block_size; + { + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, block_num) * + sycl::range<3>(1, 1, block_size), + sycl::range<3>(1, 1, block_size)), + [=](sycl::nd_item<3> item_ct1) { + k_bin_bcast_unravel( + src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3, + ne10, ne11, ne12, ne13, s1, s2, s3, s01, s02, + s03, s11, s12, s13, item_ct1); + }); + } + } else { + /* + DPCT1049:16: The work-group size passed to the SYCL kernel may + exceed the limit. To get the device limit, query + info::device::max_work_group_size. Adjust the work-group size if + needed. 
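+            Note: in this path block_dims is already capped at 128 work-items in total
+            (block_size = 128, and each dimension is derived by integer division of
+            that budget), so the warning above only matters on devices whose maximum
+            work-group size is smaller than 128.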
+ */ + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + k_bin_bcast<bin_op>(src0_dd, src1_dd, dst_dd, ne0, ne1, + ne2, ne3, ne10, ne11, ne12, ne13, + s1, s2, s3, s01, s02, s03, s11, s12, s13, + item_ct1); + }); + } + } + } +}; + +template <class op> +inline void ggml_sycl_op_bin_bcast(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, + ggml_tensor * dst) { + dpct::queue_ptr main_stream = ctx.stream(); + GGML_TENSOR_BINARY_OP_LOCALS + + if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + op()((const float *) src0->data, (const float *) src1->data, (float *) dst->data, ne00, ne01, ne02, ne03, ne10, + ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2, nb3, + ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_contiguous(dst), main_stream); + } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { + op()((const sycl::half *) src0->data, (const sycl::half *) src1->data, (sycl::half *) dst->data, ne00, ne01, + ne02, ne03, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, + nb0, nb1, nb2, nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_contiguous(dst), + main_stream); + } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) { + op()((const sycl::half *) src0->data, (const float *) src1->data, (sycl::half *) dst->data, ne00, ne01, ne02, + ne03, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, + nb2, nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_contiguous(dst), main_stream); + } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) { + op()((const int32_t *) src0->data, (const int32_t *) src1->data, (int32_t *) dst->data, ne00, ne01, ne02, ne03, + ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2, + nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_contiguous(dst), main_stream); + } else if (src0->type == GGML_TYPE_I16 && src1->type == GGML_TYPE_I16 && dst->type == GGML_TYPE_I16) { + op()((const int16_t *) src0->data, (const int16_t *) src1->data, (int16_t *) dst->data, ne00, ne01, ne02, ne03, + ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2, + nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_contiguous(dst), main_stream); + } else { + fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__, ggml_type_name(dst->type), + ggml_type_name(src0->type), ggml_type_name(src1->type)); + GGML_ABORT("fatal error"); + } +} + +inline void ggml_sycl_op_add(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { + + ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_add>>(ctx, dst->src[0], dst->src[1], dst); +} + +inline void ggml_sycl_op_sub(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { + + ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_sub>>(ctx, dst->src[0], dst->src[1], dst); +} + +inline void ggml_sycl_op_mul(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { + + ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_mul>>(ctx, dst->src[0], dst->src[1], dst); +} + +inline void ggml_sycl_op_div(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { + + ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_div>>(ctx, dst->src[0], dst->src[1], dst); +}
+ +inline void ggml_sycl_op_repeat(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { + ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_repeat>>(ctx, dst, dst->src[0], dst); +} + + +void ggml_sycl_add(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + GGML_SYCL_DEBUG("call %s\n", __func__); + ggml_sycl_op_add(ctx, dst); + GGML_SYCL_DEBUG("call %s done\n", __func__); +} + +void ggml_sycl_sub(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + GGML_SYCL_DEBUG("call %s\n", __func__); + ggml_sycl_op_sub(ctx, dst); + GGML_SYCL_DEBUG("call %s done\n", __func__); +} + +void ggml_sycl_mul(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + GGML_SYCL_DEBUG("call %s\n", __func__); + ggml_sycl_op_mul(ctx, dst); + GGML_SYCL_DEBUG("call %s done\n", __func__); +} + +void ggml_sycl_div(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + GGML_SYCL_DEBUG("call %s\n", __func__); + ggml_sycl_op_div(ctx, dst); + GGML_SYCL_DEBUG("call %s done\n", __func__); +} + +void ggml_sycl_repeat(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + GGML_SYCL_DEBUG("call %s\n", __func__); + ggml_sycl_op_repeat(ctx, dst); + GGML_SYCL_DEBUG("call %s done\n", __func__); +} + diff --git a/ggml/src/ggml-sycl/binbcast.hpp b/ggml/src/ggml-sycl/binbcast.hpp new file mode 100644 index 00000000000..9cce0f053a5 --- /dev/null +++ b/ggml/src/ggml-sycl/binbcast.hpp @@ -0,0 +1,39 @@ +#ifndef GGML_SYCL_BINBCAST_HPP +#define GGML_SYCL_BINBCAST_HPP +#include "common.hpp" + + +static __dpct_inline__ float op_repeat(const float a, const float b) { + return b; + GGML_UNUSED(a); +} + +static __dpct_inline__ float op_add(const float a, const float b) { + return a + b; +} + +static __dpct_inline__ float op_sub(const float a, const float b) { + return a - b; +} + +static __dpct_inline__ float op_mul(const float a, const float b) { + return a * b; +} + +static __dpct_inline__ float op_div(const float a, const float b) { + return a / b; +} + +void ggml_sycl_add(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +void ggml_sycl_sub(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +void ggml_sycl_mul(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +void ggml_sycl_div(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +void ggml_sycl_repeat(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + + +#endif //GGML_SYCL_BINBCAST_HPP + diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp index 3e1ceeaa494..96becabc85a 100644 --- a/ggml/src/ggml-sycl/common.hpp +++ b/ggml/src/ggml-sycl/common.hpp @@ -494,286 +494,5 @@ static __dpct_inline__ Tp* get_pointer(sycl::local_accessor<Tp, 1> acc) { int64_t downsample_sycl_global_range(int64_t accumulate_block_num, int64_t block_size); -template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t> -static void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst, - int ne0, int ne1, int ne2, int ne3, - int ne10, int ne11, int ne12, int ne13, - /*int s0, */ int s1, int s2, int s3, - /*int s00,*/ int s01, int s02, int s03, - /*int s10,*/ int s11, int s12, int s13, - const sycl::nd_item<3> &item_ct1) { - const int i0s = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - const int i1 = (item_ct1.get_local_range(1) * item_ct1.get_group(1) + - item_ct1.get_local_id(1)); - const int i2 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) + - item_ct1.get_local_id(0)) / - ne3; - const int i3 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) + - item_ct1.get_local_id(0)) % - ne3; - - if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) { - return; - } - - const int i11 = i1 % ne11; - const 
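Note on the extension point introduced by the new binbcast.cpp/binbcast.hpp pair above: a broadcasting binary op is just a scalar float device function plugged into bin_bcast_sycl and routed through ggml_sycl_op_bin_bcast, which supplies the broadcasting and the F32/F16/I32/I16 type dispatch. A minimal sketch of wiring up one more op this way; op_max and ggml_sycl_op_max are hypothetical names used only for illustration, not part of this change.

// Sketch only, assuming the binbcast machinery above is in scope.
static __dpct_inline__ float op_max(const float a, const float b) {
    return a > b ? a : b;   // scalar body; broadcasting is handled by the shared kernels
}

inline void ggml_sycl_op_max(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    // same src0/src1 broadcasting and type dispatch as add/sub/mul/div
    ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_max>>(ctx, dst->src[0], dst->src[1], dst);
}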
int i12 = i2 % ne12; - const int i13 = i3 % ne13; - - const size_t i_src0 = i3*s03 + i2*s02 + i1*s01; - const size_t i_src1 = i13*s13 + i12*s12 + i11*s11; - const size_t i_dst = i3*s3 + i2*s2 + i1*s1; - - const src0_t * src0_row = src0 + i_src0; - const src1_t * src1_row = src1 + i_src1; - dst_t * dst_row = dst + i_dst; - - for (int i0 = i0s; i0 < ne0; - i0 += item_ct1.get_local_range(2) * item_ct1.get_group_range(2)) { - const int i10 = i0 % ne10; - dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]); - } -} - -template -static void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst, - int ne0, int ne1, int ne2, int ne3, - int ne10, int ne11, int ne12, int ne13, - /*int s0, */ int s1, int s2, int s3, - /*int s00,*/ int s01, int s02, int s03, - /*int s10,*/ int s11, int s12, int s13, - const sycl::nd_item<3> &item_ct1) { - - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - - const int i3 = i/(ne2*ne1*ne0); - const int i2 = (i/(ne1*ne0)) % ne2; - const int i1 = (i/ne0) % ne1; - const int i0 = i % ne0; - - if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) { - return; - } - - const int i11 = i1 % ne11; - const int i12 = i2 % ne12; - const int i13 = i3 % ne13; - - const size_t i_src0 = i3*s03 + i2*s02 + i1*s01; - const size_t i_src1 = i13*s13 + i12*s12 + i11*s11; - const size_t i_dst = i3*s3 + i2*s2 + i1*s1; - - const src0_t * src0_row = src0 + i_src0; - const src1_t * src1_row = src1 + i_src1; - dst_t * dst_row = dst + i_dst; - - const int i10 = i0 % ne10; - dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]); -} - - -template -struct bin_bcast_sycl { - template - void operator()(ggml_backend_sycl_context & ctx, - const struct ggml_tensor *src0, - const struct ggml_tensor *src1, struct ggml_tensor *dst, - const src0_t *src0_dd, const src1_t *src1_dd, dst_t *dst_dd, - queue_ptr stream) { - - GGML_TENSOR_BINARY_OP_LOCALS - - int nr0 = ne10/ne0; - int nr1 = ne11/ne1; - int nr2 = ne12/ne2; - int nr3 = ne13/ne3; - - int nr[4] = { nr0, nr1, nr2, nr3 }; - - // collapse dimensions until first broadcast dimension - int64_t cne[] = {ne0, ne1, ne2, ne3}; - int64_t cne0[] = {ne00, ne01, ne02, ne03}; - int64_t cne1[] = {ne10, ne11, ne12, ne13}; - size_t cnb[] = {nb0, nb1, nb2, nb3}; - size_t cnb0[] = {nb00, nb01, nb02, nb03}; - size_t cnb1[] = {nb10, nb11, nb12, nb13}; - auto collapse = [](int64_t cne[]) { - cne[0] *= cne[1]; - cne[1] = cne[2]; - cne[2] = cne[3]; - cne[3] = 1; - }; - - auto collapse_nb = [](size_t cnb[], int64_t cne[]) { - cnb[1] *= cne[1]; - cnb[2] *= cne[2]; - cnb[3] *= cne[3]; - }; - - if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) { - for (int i = 0; i < 4; i++) { - if (nr[i] != 1) { - break; - } - if (i > 0) { - collapse_nb(cnb, cne); - collapse_nb(cnb0, cne0); - collapse_nb(cnb1, cne1); - collapse(cne); - collapse(cne0); - collapse(cne1); - } - } - } - { - int64_t ne0 = cne[0]; - int64_t ne1 = cne[1]; - int64_t ne2 = cne[2]; - int64_t ne3 = cne[3]; - - int64_t ne10 = cne1[0]; - int64_t ne11 = cne1[1]; - int64_t ne12 = cne1[2]; - int64_t ne13 = cne1[3]; - - size_t nb0 = cnb[0]; - size_t nb1 = cnb[1]; - size_t nb2 = cnb[2]; - size_t nb3 = cnb[3]; - - size_t nb00 = cnb0[0]; - size_t nb01 = cnb0[1]; - size_t nb02 = cnb0[2]; - size_t nb03 = cnb0[3]; - - size_t nb10 = cnb1[0]; - size_t nb11 = cnb1[1]; - size_t nb12 = cnb1[2]; - size_t nb13 = cnb1[3]; - - size_t s0 = nb0 / sizeof(dst_t); - size_t s1 = nb1 / 
sizeof(dst_t); - size_t s2 = nb2 / sizeof(dst_t); - size_t s3 = nb3 / sizeof(dst_t); - - size_t s10 = nb10 / sizeof(src1_t); - size_t s11 = nb11 / sizeof(src1_t); - size_t s12 = nb12 / sizeof(src1_t); - size_t s13 = nb13 / sizeof(src1_t); - - size_t s00 = nb00 / sizeof(src0_t); - size_t s01 = nb01 / sizeof(src0_t); - size_t s02 = nb02 / sizeof(src0_t); - size_t s03 = nb03 / sizeof(src0_t); - - GGML_UNUSED(s00); - - GGML_ASSERT(nb0 % sizeof(dst_t) == 0); - GGML_ASSERT(nb1 % sizeof(dst_t) == 0); - GGML_ASSERT(nb2 % sizeof(dst_t) == 0); - GGML_ASSERT(nb3 % sizeof(dst_t) == 0); - - GGML_ASSERT(nb00 % sizeof(src0_t) == 0); - GGML_ASSERT(nb01 % sizeof(src0_t) == 0); - GGML_ASSERT(nb02 % sizeof(src0_t) == 0); - GGML_ASSERT(nb03 % sizeof(src0_t) == 0); - - GGML_ASSERT(nb10 % sizeof(src1_t) == 0); - GGML_ASSERT(nb11 % sizeof(src1_t) == 0); - GGML_ASSERT(nb12 % sizeof(src1_t) == 0); - GGML_ASSERT(nb13 % sizeof(src1_t) == 0); - - GGML_ASSERT(s0 == 1); - GGML_ASSERT(s10 == 1); - - const int block_size = 128; - - int64_t hne0 = std::max(ne0/2LL, 1LL); - - sycl::range<3> block_dims(1, 1, 1); - block_dims[2] = std::min(hne0, block_size); - block_dims[1] = std::min( - ne1, block_size / (unsigned int)block_dims[2]); - block_dims[0] = std::min( - std::min( - ne2 * ne3, block_size / (unsigned int)block_dims[2] / - (unsigned int)block_dims[1]), - 64U); - - sycl::range<3> block_nums( - (ne2 * ne3 + block_dims[0] - 1) / block_dims[0], - (ne1 + block_dims[1] - 1) / block_dims[1], - (hne0 + block_dims[2] - 1) / block_dims[2]); - - if (block_nums[0] > 65535) { - // this is the maximum number of blocks in z direction, fallback to 1D grid kernel - int block_num = (ne0*ne1*ne2*ne3 + block_size - 1) / block_size; - { - dpct::has_capability_or_fail(stream->get_device(), - {sycl::aspect::fp16}); - - stream->parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, block_num) * - sycl::range<3>(1, 1, block_size), - sycl::range<3>(1, 1, block_size)), - [=](sycl::nd_item<3> item_ct1) { - k_bin_bcast_unravel( - src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3, - ne10, ne11, ne12, ne13, s1, s2, s3, s01, s02, - s03, s11, s12, s13, item_ct1); - }); - } - } else { - /* - DPCT1049:16: The work-group size passed to the SYCL kernel may - exceed the limit. To get the device limit, query - info::device::max_work_group_size. Adjust the work-group size if - needed. 
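The k_bin_bcast_unravel fallback (removed here from common.hpp and re-added in binbcast.cpp) exists because the 3D launch caps the slowest-varying grid dimension at 65535 work-groups; beyond that, a 1D grid is used and each flat work-item index is unravelled back into (i3, i2, i1, i0). A small standalone C++ check of that index math, with illustrative extents only:

// Mirrors the divisions/modulos in k_bin_bcast_unravel; not backend code.
#include <cassert>

int main() {
    const int ne0 = 5, ne1 = 3, ne2 = 2, ne3 = 4;
    for (int i = 0; i < ne0 * ne1 * ne2 * ne3; ++i) {
        const int i3 = i / (ne2 * ne1 * ne0);
        const int i2 = (i / (ne1 * ne0)) % ne2;
        const int i1 = (i / ne0) % ne1;
        const int i0 = i % ne0;
        // re-linearizing must reproduce the flat work-item index
        assert(((i3 * ne2 + i2) * ne1 + i1) * ne0 + i0 == i);
    }
    return 0;
}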
- */ - dpct::has_capability_or_fail(stream->get_device(), - {sycl::aspect::fp16}); - - stream->parallel_for( - sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) { - k_bin_bcast(src0_dd, src1_dd, dst_dd, ne0, ne1, - ne2, ne3, ne10, ne11, ne12, ne13, - s1, s2, s3, s01, s02, s03, s11, s12, s13, - item_ct1); - }); - } - } - GGML_UNUSED(ctx); - } -}; - -template -inline void ggml_sycl_op_bin_bcast(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, - const ggml_tensor *src1, ggml_tensor *dst) { - dpct::queue_ptr main_stream = ctx.stream(); - - if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - op()(ctx, src0, src1, dst, (const float *)src0->data, (const float *)src1->data, (float *)dst->data, main_stream); - } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { - op()(ctx, src0, src1, dst, (const sycl::half *)src0->data, (const float *)src1->data, - (sycl::half *)dst->data, main_stream); - } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) { - op()(ctx, src0, src1, dst, (const sycl::half *)src0->data, (const float *)src1->data, (float *)dst->data, - main_stream); - } else if (src0->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) { - op()(ctx, src0, src1, dst, (const int32_t *)src0->data, (const int32_t *)src1->data, (int32_t *)dst->data, - main_stream); - } else if (src0->type == GGML_TYPE_I16 && dst->type == GGML_TYPE_I16) { - op()(ctx, src0, src1, dst, (const int16_t *)src0->data, (const int16_t *)src1->data, (int16_t *)dst->data, - main_stream); - } else { - fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__, - ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type)); - GGML_ABORT("fatal error"); - } -} - bool gpu_has_xmx(sycl::device &dev); #endif // GGML_SYCL_COMMON_HPP diff --git a/ggml/src/ggml-sycl/element_wise.cpp b/ggml/src/ggml-sycl/element_wise.cpp index 0423305bb40..fc25d98ddff 100644 --- a/ggml/src/ggml-sycl/element_wise.cpp +++ b/ggml/src/ggml-sycl/element_wise.cpp @@ -1,4 +1,5 @@ #include "common.hpp" +#include "ggml.h" #include "element_wise.hpp" static void acc_f32(const float * x, const float * y, float * dst, const int ne, @@ -20,10 +21,11 @@ static void acc_f32(const float * x, const float * y, float * dst, const int ne, } } -static void gelu_f32(const float * x, float * dst, const int k, +template +static void gelu(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) { - const float GELU_COEF_A = 0.044715f; - const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; + const T GELU_COEF_A = static_cast(0.044715f); + const T SQRT_2_OVER_PI = static_cast(0.79788456080286535587989211986876f); const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2); @@ -32,12 +34,13 @@ static void gelu_f32(const float * x, float * dst, const int k, } float xi = x[i]; - dst[i] = 0.5f * xi * - (1.0f + - sycl::tanh(SQRT_2_OVER_PI * xi * (1.0f + GELU_COEF_A * xi * xi))); + dst[i] = static_cast(0.5f) * xi * + (static_cast(1.0f) + + sycl::tanh(SQRT_2_OVER_PI * xi * (static_cast(1.0f) + GELU_COEF_A * xi * xi))); } -static void silu_f32(const float * x, float * dst, const int k, +template +static void silu(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) { const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2); @@ -45,10 +48,11 @@ static void silu_f32(const float * x, float * dst, const int k, if (i >= k) { return; } - dst[i] = 
x[i] / (1.0f + sycl::native::exp(-x[i])); + dst[i] = x[i] / (static_cast(1.0f) + sycl::native::exp(-x[i])); } -static void gelu_quick_f32(const float *x, float *dst, int k, +template +static void gelu_quick(const T *x, T *dst, int k, const sycl::nd_item<3> &item_ct1) { const float GELU_QUICK_COEF = -1.702f; const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + @@ -56,20 +60,22 @@ static void gelu_quick_f32(const float *x, float *dst, int k, if (i >= k) { return; } - dst[i] = x[i] * (1.0f / (1.0f + sycl::native::exp(GELU_QUICK_COEF * x[i]))); + dst[i] = x[i] * (static_cast(1.0f) / (static_cast(1.0f) + sycl::native::exp(GELU_QUICK_COEF * x[i]))); } -static void tanh_f32(const float *x, float *dst, int k, +template +static void tanh(const T *x, T *dst, int k, const sycl::nd_item<3> &item_ct1) { const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2); if (i >= k) { return; } - dst[i] = sycl::tanh((float)(x[i])); + dst[i] = sycl::tanh((x[i])); } -static void relu_f32(const float * x, float * dst, const int k, +template +static void relu(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) { const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2); @@ -77,10 +83,11 @@ static void relu_f32(const float * x, float * dst, const int k, if (i >= k) { return; } - dst[i] = sycl::fmax((float)(x[i]), (float)0); + dst[i] = sycl::fmax((x[i]), static_cast(0)); } -static void sigmoid_f32(const float * x, float * dst, const int k, +template +static void sigmoid(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) { const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2); @@ -88,10 +95,11 @@ static void sigmoid_f32(const float * x, float * dst, const int k, if (i >= k) { return; } - dst[i] = 1.0f / (1.0f + sycl::native::exp(-x[i])); + dst[i] = 1.0f / (static_cast(1.0f) + sycl::native::exp(-x[i])); } -static void sqrt_f32(const float * x, float * dst, const int k, +template +static void sqrt(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) { const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2); @@ -102,7 +110,8 @@ static void sqrt_f32(const float * x, float * dst, const int k, dst[i] = sycl::sqrt(x[i]); } -static void sin_f32(const float * x, float * dst, const int k, +template +static void sin(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) { const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2); @@ -113,7 +122,8 @@ static void sin_f32(const float * x, float * dst, const int k, dst[i] = sycl::sin(x[i]); } -static void cos_f32(const float * x, float * dst, const int k, +template +static void cos(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) { const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2); @@ -124,7 +134,8 @@ static void cos_f32(const float * x, float * dst, const int k, dst[i] = sycl::cos(x[i]); } -static void hardsigmoid_f32(const float * x, float * dst, const int k, +template +static void hardsigmoid(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) { const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2); @@ -132,10 +143,11 @@ static void hardsigmoid_f32(const float * x, float * dst, const int k, if (i >= k) { return; } - dst[i] = sycl::fmin(1.0f, sycl::fmax(0.0f, (x[i] + 3.0f) / 6.0f)); + dst[i] = 
sycl::fmin(static_cast(1.0f), sycl::fmax(static_cast(0.0f), (x[i] + static_cast(3.0f)) / static_cast(6.0f))); } -static void hardswish_f32(const float * x, float * dst, const int k, +template +static void hardswish(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) { const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2); @@ -143,10 +155,11 @@ static void hardswish_f32(const float * x, float * dst, const int k, if (i >= k) { return; } - dst[i] = x[i] * sycl::fmin(1.0f, sycl::fmax(0.0f, (x[i] + 3.0f) / 6.0f)); + dst[i] = x[i] * sycl::fmin(static_cast(1.0f), sycl::fmax(static_cast(0.0f), (x[i] + static_cast(3.0f)) / static_cast(6.0f))); } -static void exp_f32(const float * x, float * dst, const int k, +template +static void exp(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) { const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2); @@ -157,7 +170,8 @@ static void exp_f32(const float * x, float * dst, const int k, dst[i] = sycl::exp(x[i]); } -static void log_f32(const float * x, float * dst, const int k, +template +static void log(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) { const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2); @@ -165,15 +179,16 @@ static void log_f32(const float * x, float * dst, const int k, if (i >= k) { return; } - float xi = x[i]; + T xi = x[i]; if (xi <= 0) { - dst[i] = -INFINITY; + dst[i] = neg_infinity(); } else { dst[i] = sycl::log(xi); } } -static void neg_f32(const float * x, float * dst, const int k, +template +static void neg(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) { const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2); @@ -184,7 +199,8 @@ static void neg_f32(const float * x, float * dst, const int k, dst[i] = -x[i]; } -static void step_f32(const float * x, float * dst, const int k, +template +static void step(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) { const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2); @@ -192,21 +208,23 @@ static void step_f32(const float * x, float * dst, const int k, if (i >= k) { return; } - dst[i] = x[i] > 0.0f; + dst[i] = x[i] > static_cast(0.0f); } -static void leaky_relu_f32(const float *x, float *dst, const int k, const float negative_slope, +template +static void leaky_relu(const T *x, T *dst, const int k, const float negative_slope, const sycl::nd_item<3> &item_ct1) { const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2); if (i >= k) { return; } - dst[i] = sycl::fmax((float)(x[i]), (float)0) + - sycl::fmin((float)(x[i]), 0.0f) * negative_slope; + dst[i] = sycl::fmax((x[i]), static_cast(0)) + + sycl::fmin((x[i]), static_cast(0.0f)) * negative_slope; } -static void sqr_f32(const float * x, float * dst, const int k, +template +static void sqr(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) { const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2); @@ -217,7 +235,8 @@ static void sqr_f32(const float * x, float * dst, const int k, dst[i] = x[i] * x[i]; } -static void upscale_f32(const float *x, float *dst, const int nb00, const int nb01, +template +static void upscale(const T *x, T *dst, const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11, const int ne12, const int ne13, const float sf0, 
const float sf1, const float sf2, const float sf3, const sycl::nd_item<1> &item_ct1) { @@ -237,10 +256,11 @@ static void upscale_f32(const float *x, float *dst, const int nb00, const int n int i02 = i12 / sf2; int i03 = i13 / sf3; - dst[index] = *(const float *)((const char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00); + dst[index] = *(const T *)((const char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00); } -static void pad_f32(const float *x, float *dst, const int ne0, const int ne00, const int ne01, const int ne02, +template +static void pad(const T *x, T *dst, const int ne0, const int ne00, const int ne01, const int ne02, const sycl::nd_item<3> &item_ct1) { int nidx = item_ct1.get_local_id(2) + item_ct1.get_group(2) * item_ct1.get_local_range(2); @@ -256,11 +276,23 @@ static void pad_f32(const float *x, float *dst, const int ne0, const int ne00, item_ct1.get_group(0) * ne00 * ne01; dst[offset_dst] = x[offset_src]; } else { - dst[offset_dst] = 0.0f; + dst[offset_dst] = static_cast(0.0f); } } +template +static void clamp(const T * x, T * dst, const float min, const float max, const int k, + const sycl::nd_item<3> &item_ct1) { + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2); + + if (i >= k) { + return; + } + + dst[i] = x[i] < static_cast(min) ? static_cast(min) : (x[i] > static_cast(max) ? static_cast(max) : x[i]); +} static void acc_f32_sycl(const float *x, const float *y, float *dst, const int n_elements, const int ne10, const int ne11, @@ -277,7 +309,8 @@ static void acc_f32_sycl(const float *x, const float *y, float *dst, }); } -static void gelu_f32_sycl(const float *x, float *dst, const int k, +template +static void gelu_sycl(const T *x, T *dst, const int k, queue_ptr stream) { const int num_blocks = (k + SYCL_GELU_BLOCK_SIZE - 1) / SYCL_GELU_BLOCK_SIZE; stream->parallel_for( @@ -285,11 +318,12 @@ static void gelu_f32_sycl(const float *x, float *dst, const int k, sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) { - gelu_f32(x, dst, k, item_ct1); + gelu(x, dst, k, item_ct1); }); } -static void silu_f32_sycl(const float *x, float *dst, const int k, +template +static void silu_sycl(const T *x, T *dst, const int k, queue_ptr stream) { const int num_blocks = (k + SYCL_SILU_BLOCK_SIZE - 1) / SYCL_SILU_BLOCK_SIZE; stream->parallel_for( @@ -297,11 +331,12 @@ static void silu_f32_sycl(const float *x, float *dst, const int k, sycl::range<3>(1, 1, SYCL_SILU_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_SILU_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) { - silu_f32(x, dst, k, item_ct1); + silu(x, dst, k, item_ct1); }); } -static void gelu_quick_f32_sycl(const float *x, float *dst, const int k, +template +static void gelu_quick_sycl(const T *x, T *dst, const int k, queue_ptr stream) { const int num_blocks = (k + SYCL_GELU_BLOCK_SIZE - 1) / SYCL_GELU_BLOCK_SIZE; stream->parallel_for( @@ -309,11 +344,12 @@ static void gelu_quick_f32_sycl(const float *x, float *dst, const int k, sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) { - gelu_quick_f32(x, dst, k, item_ct1); + gelu_quick(x, dst, k, item_ct1); }); } -static void tanh_f32_sycl(const float *x, float *dst, const int k, +template +static void tanh_sycl(const T *x, T *dst, const int k, queue_ptr stream) { const int num_blocks = (k + SYCL_TANH_BLOCK_SIZE - 1) / SYCL_TANH_BLOCK_SIZE; stream->parallel_for( @@ -321,11 +357,12 @@ static void 
tanh_f32_sycl(const float *x, float *dst, const int k, sycl::range<3>(1, 1, SYCL_TANH_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_TANH_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) { - tanh_f32(x, dst, k, item_ct1); + tanh(x, dst, k, item_ct1); }); } -static void relu_f32_sycl(const float *x, float *dst, const int k, +template +static void relu_sycl(const T *x, T *dst, const int k, queue_ptr stream) { const int num_blocks = (k + SYCL_RELU_BLOCK_SIZE - 1) / SYCL_RELU_BLOCK_SIZE; stream->parallel_for( @@ -333,11 +370,12 @@ static void relu_f32_sycl(const float *x, float *dst, const int k, sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) { - relu_f32(x, dst, k, item_ct1); + relu(x, dst, k, item_ct1); }); } -static void hardsigmoid_f32_sycl(const float *x, float *dst, const int k, +template +static void hardsigmoid_sycl(const T *x, T *dst, const int k, queue_ptr stream) { const int num_blocks = (k + SYCL_HARDSIGMOID_BLOCK_SIZE - 1) / SYCL_HARDSIGMOID_BLOCK_SIZE; stream->parallel_for( @@ -345,11 +383,12 @@ static void hardsigmoid_f32_sycl(const float *x, float *dst, const int k, sycl::range<3>(1, 1, SYCL_HARDSIGMOID_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_HARDSIGMOID_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) { - hardsigmoid_f32(x, dst, k, item_ct1); + hardsigmoid(x, dst, k, item_ct1); }); } -static void hardswish_f32_sycl(const float *x, float *dst, const int k, +template +static void hardswish_sycl(const T *x, T *dst, const int k, queue_ptr stream) { const int num_blocks = (k + SYCL_HARDSWISH_BLOCK_SIZE - 1) / SYCL_HARDSWISH_BLOCK_SIZE; stream->parallel_for( @@ -357,11 +396,12 @@ static void hardswish_f32_sycl(const float *x, float *dst, const int k, sycl::range<3>(1, 1, SYCL_HARDSWISH_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_HARDSWISH_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) { - hardswish_f32(x, dst, k, item_ct1); + hardswish(x, dst, k, item_ct1); }); } -static void exp_f32_sycl(const float *x, float *dst, const int k, +template +static void exp_sycl(const T *x, T *dst, const int k, queue_ptr stream) { const int num_blocks = (k + SYCL_EXP_BLOCK_SIZE - 1) / SYCL_EXP_BLOCK_SIZE; stream->parallel_for( @@ -369,11 +409,12 @@ static void exp_f32_sycl(const float *x, float *dst, const int k, sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) { - exp_f32(x, dst, k, item_ct1); + exp(x, dst, k, item_ct1); }); } -static void log_f32_sycl(const float *x, float *dst, const int k, +template +static void log_sycl(const T *x, T *dst, const int k, queue_ptr stream) { const int num_blocks = (k + SYCL_EXP_BLOCK_SIZE - 1) / SYCL_EXP_BLOCK_SIZE; stream->parallel_for( @@ -381,11 +422,12 @@ static void log_f32_sycl(const float *x, float *dst, const int k, sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) { - log_f32(x, dst, k, item_ct1); + log(x, dst, k, item_ct1); }); } -static void neg_f32_sycl(const float *x, float *dst, const int k, +template +static void neg_sycl(const T *x, T *dst, const int k, queue_ptr stream) { const int num_blocks = (k + SYCL_NEG_BLOCK_SIZE - 1) / SYCL_NEG_BLOCK_SIZE; stream->parallel_for( @@ -393,11 +435,12 @@ static void neg_f32_sycl(const float *x, float *dst, const int k, sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) { - neg_f32(x, dst, k, item_ct1); + neg(x, dst, k, item_ct1); }); } -static void 
step_f32_sycl(const float *x, float *dst, const int k, +template +static void step_sycl(const T *x, T *dst, const int k, queue_ptr stream) { const int num_blocks = (k + SYCL_NEG_BLOCK_SIZE - 1) / SYCL_NEG_BLOCK_SIZE; stream->parallel_for( @@ -405,11 +448,12 @@ static void step_f32_sycl(const float *x, float *dst, const int k, sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) { - step_f32(x, dst, k, item_ct1); + step(x, dst, k, item_ct1); }); } -static void sigmoid_f32_sycl(const float *x, float *dst, const int k, +template +static void sigmoid_sycl(const T *x, T *dst, const int k, queue_ptr stream) { const int num_blocks = (k + SYCL_SIGMOID_BLOCK_SIZE - 1) / SYCL_SIGMOID_BLOCK_SIZE; stream->parallel_for( @@ -417,11 +461,12 @@ static void sigmoid_f32_sycl(const float *x, float *dst, const int k, sycl::range<3>(1, 1, SYCL_SIGMOID_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_SIGMOID_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) { - sigmoid_f32(x, dst, k, item_ct1); + sigmoid(x, dst, k, item_ct1); }); } -static void sqrt_f32_sycl(const float *x, float *dst, const int k, +template +static void sqrt_sycl(const T *x, T *dst, const int k, queue_ptr stream) { const int num_blocks = (k + SYCL_SQRT_BLOCK_SIZE - 1) / SYCL_SQRT_BLOCK_SIZE; stream->parallel_for( @@ -429,11 +474,12 @@ static void sqrt_f32_sycl(const float *x, float *dst, const int k, sycl::range<3>(1, 1, SYCL_SQRT_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_SQRT_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) { - sqrt_f32(x, dst, k, item_ct1); + sqrt(x, dst, k, item_ct1); }); } -static void sin_f32_sycl(const float *x, float *dst, const int k, +template +static void sin_sycl(const T *x, T *dst, const int k, queue_ptr stream) { const int num_blocks = (k + SYCL_SIN_BLOCK_SIZE - 1) / SYCL_SIN_BLOCK_SIZE; stream->parallel_for( @@ -441,11 +487,12 @@ static void sin_f32_sycl(const float *x, float *dst, const int k, sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) { - sin_f32(x, dst, k, item_ct1); + sin(x, dst, k, item_ct1); }); } -static void cos_f32_sycl(const float *x, float *dst, const int k, +template +static void cos_sycl(const T *x, T *dst, const int k, queue_ptr stream) { const int num_blocks = (k + SYCL_SIN_BLOCK_SIZE - 1) / SYCL_SIN_BLOCK_SIZE; stream->parallel_for( @@ -453,11 +500,12 @@ static void cos_f32_sycl(const float *x, float *dst, const int k, sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) { - cos_f32(x, dst, k, item_ct1); + cos(x, dst, k, item_ct1); }); } -static void leaky_relu_f32_sycl(const float *x, float *dst, const int k, +template +static void leaky_relu_sycl(const T *x, T *dst, const int k, const float negative_slope, queue_ptr stream) { const int num_blocks = (k + SYCL_RELU_BLOCK_SIZE - 1) / SYCL_RELU_BLOCK_SIZE; @@ -466,11 +514,12 @@ static void leaky_relu_f32_sycl(const float *x, float *dst, const int k, sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) { - leaky_relu_f32(x, dst, k, negative_slope, item_ct1); + leaky_relu(x, dst, k, negative_slope, item_ct1); }); } -static void sqr_f32_sycl(const float *x, float *dst, const int k, +template +static void sqr_sycl(const T *x, T *dst, const int k, queue_ptr stream) { const int num_blocks = (k + SYCL_SQR_BLOCK_SIZE - 1) / SYCL_SQR_BLOCK_SIZE; stream->parallel_for( @@ -478,11 +527,12 @@ static void 
sqr_f32_sycl(const float *x, float *dst, const int k, sycl::range<3>(1, 1, SYCL_SQR_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_SQR_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) { - sqr_f32(x, dst, k, item_ct1); + sqr(x, dst, k, item_ct1); }); } -static void upscale_f32_sycl(const float *x, float *dst, const int nb00, const int nb01, +template +static void upscale_sycl(const T *x, T *dst, const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11, const int ne12, const int ne13, const float sf0, const float sf1, const float sf2, const float sf3, queue_ptr stream) { @@ -492,11 +542,12 @@ static void upscale_f32_sycl(const float *x, float *dst, const int nb00, const i stream->parallel_for( sycl::nd_range<1>(gridDim, sycl::range<1>(SYCL_UPSCALE_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) { - upscale_f32(x, dst, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3, item_ct1); + upscale(x, dst, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3, item_ct1); }); } -static void pad_f32_sycl(const float *x, float *dst, const int ne00, +template +static void pad_sycl(const T *x, T *dst, const int ne00, const int ne01, const int ne02, const int ne0, const int ne1, const int ne2, queue_ptr stream) { int num_blocks = (ne0 + SYCL_PAD_BLOCK_SIZE - 1) / SYCL_PAD_BLOCK_SIZE; @@ -505,252 +556,688 @@ static void pad_f32_sycl(const float *x, float *dst, const int ne00, sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) { - pad_f32(x, dst, ne0, ne00, ne01, ne02, item_ct1); + pad(x, dst, ne0, ne00, ne01, ne02, item_ct1); }); } -inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { +template +static void clamp_sycl(const T *x, T *dst, const float min, + const float max, const int k, + queue_ptr stream) { + const int num_blocks = (k + SYCL_CLAMP_BLOCK_SIZE - 1) / SYCL_CLAMP_BLOCK_SIZE; + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * + sycl::range<3>(1, 1, SYCL_CLAMP_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_CLAMP_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + clamp(x, dst, min, max, k, item_ct1); + }); +} +inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { +#if defined (GGML_SYCL_F16) + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); +#else GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); +#endif + GGML_ASSERT(dst->src[0]->type == dst->type); dpct::queue_ptr main_stream = ctx.stream(); SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - const float * src0_dd = static_cast(dst->src[0]->data); - float * dst_dd = static_cast(dst->data); - - - silu_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream); + switch (dst->type) { +#if defined (GGML_SYCL_F16) + case GGML_TYPE_F16: + { + auto data_pts = cast_data(dst); + silu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } +#endif + case GGML_TYPE_F32: + { + auto data_pts = cast_data(dst); + silu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } + default: + GGML_ABORT("GGML tensor type not supported!\n"); + break; + } } -inline void ggml_sycl_op_gelu(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - +inline void 
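Every unary op converted below follows the same three-layer shape: a kernel templated on the element type T, a *_sycl launcher that picks the block size, and a ggml_sycl_op_* wrapper whose switch casts the tensor data with cast_data<sycl::half> or cast_data<float> (the helper's definition is not part of this hunk; it is assumed to return the src/dst pointers cast to T). A condensed sketch of that pattern, under those assumptions; my_op, my_op_sycl and ggml_sycl_op_my_op are placeholder names, not functions added by this change.

// Assumes ggml-sycl common.hpp is included (queue_ptr, cast_data, SYCL_CHECK, ...).
template <typename T>
static void my_op(const T * x, T * dst, const int k, const sycl::nd_item<3> & item_ct1) {
    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2);
    if (i >= k) {
        return;
    }
    dst[i] = x[i] + static_cast<T>(1.0f); // placeholder elementwise body
}

template <typename T>
static void my_op_sycl(const T * x, T * dst, const int k, queue_ptr stream) {
    constexpr int block_size = 256; // stand-in for the SYCL_*_BLOCK_SIZE constants
    const int num_blocks = (k + block_size - 1) / block_size;
    stream->parallel_for(
        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, block_size),
                          sycl::range<3>(1, 1, block_size)),
        [=](sycl::nd_item<3> item_ct1) { my_op(x, dst, k, item_ct1); });
}

inline void ggml_sycl_op_my_op(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    GGML_ASSERT(dst->src[0]->type == dst->type);
    dpct::queue_ptr main_stream = ctx.stream();
    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
    switch (dst->type) {
#if defined (GGML_SYCL_F16)
        case GGML_TYPE_F16:
            {
                auto data_pts = cast_data<sycl::half>(dst);
                my_op_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
                break;
            }
#endif
        case GGML_TYPE_F32:
            {
                auto data_pts = cast_data<float>(dst);
                my_op_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
                break;
            }
        default:
            GGML_ABORT("GGML tensor type not supported!\n");
            break;
    }
}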
ggml_sycl_op_gelu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { +#if defined (GGML_SYCL_F16) + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); +#else GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); +#endif + GGML_ASSERT(dst->src[0]->type == dst->type); dpct::queue_ptr main_stream = ctx.stream(); SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - const float * src0_dd = static_cast(dst->src[0]->data); - float * dst_dd = static_cast(dst->data); - - gelu_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream); + switch (dst->type) { +#if defined (GGML_SYCL_F16) + case GGML_TYPE_F16: + { + auto data_pts = cast_data(dst); + gelu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } +#endif + case GGML_TYPE_F32: + { + auto data_pts = cast_data(dst); + gelu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } + default: + GGML_ABORT("GGML tensor type not supported!\n"); + break; + } } inline void ggml_sycl_op_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - +#if defined (GGML_SYCL_F16) + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); +#else GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); +#endif + GGML_ASSERT(dst->src[0]->type == dst->type); dpct::queue_ptr main_stream = ctx.stream(); SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - const float * src0_dd = static_cast(dst->src[0]->data); - float * dst_dd = static_cast(dst->data); - - gelu_quick_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream); + switch (dst->type) { +#if defined (GGML_SYCL_F16) + case GGML_TYPE_F16: + { + auto data_pts = cast_data(dst); + gelu_quick_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } +#endif + case GGML_TYPE_F32: + { + auto data_pts = cast_data(dst); + gelu_quick_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } + default: + GGML_ABORT("GGML tensor type not supported!\n"); + break; + } } -inline void ggml_sycl_op_tanh(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - +inline void ggml_sycl_op_tanh(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { +#if defined (GGML_SYCL_F16) + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); +#else GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); +#endif + GGML_ASSERT(dst->src[0]->type == dst->type); dpct::queue_ptr main_stream = ctx.stream(); SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - const float * src0_dd = static_cast(dst->src[0]->data); - float * dst_dd = static_cast(dst->data); - tanh_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream); + switch (dst->type) { +#if defined (GGML_SYCL_F16) + case GGML_TYPE_F16: + { + auto data_pts = cast_data(dst); + tanh_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } +#endif + case GGML_TYPE_F32: + { + auto data_pts = cast_data(dst); + tanh_sycl(data_pts.src, data_pts.dst, 
ggml_nelements(dst->src[0]), main_stream); + break; + } + default: + GGML_ABORT("GGML tensor type not supported!\n"); + break; + } } -inline void ggml_sycl_op_relu(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - +inline void ggml_sycl_op_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { +#if defined (GGML_SYCL_F16) + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); +#else GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); +#endif + GGML_ASSERT(dst->src[0]->type == dst->type); dpct::queue_ptr main_stream = ctx.stream(); SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - const float * src0_dd = static_cast(dst->src[0]->data); - float * dst_dd = static_cast(dst->data); - relu_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream); + switch (dst->type) { +#if defined (GGML_SYCL_F16) + case GGML_TYPE_F16: + { + auto data_pts = cast_data(dst); + relu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } +#endif + case GGML_TYPE_F32: + { + auto data_pts = cast_data(dst); + relu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } + default: + GGML_ABORT("GGML tensor type not supported!\n"); + break; + } } -inline void ggml_sycl_op_hardsigmoid(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - +inline void ggml_sycl_op_hardsigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { +#if defined (GGML_SYCL_F16) + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); +#else GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); +#endif + GGML_ASSERT(dst->src[0]->type == dst->type); + dpct::queue_ptr main_stream = ctx.stream(); SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - const float * src0_dd = static_cast(dst->src[0]->data); - float * dst_dd = static_cast(dst->data); - hardsigmoid_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream); + switch (dst->type) { +#if defined (GGML_SYCL_F16) + case GGML_TYPE_F16: + { + auto data_pts = cast_data(dst); + hardsigmoid_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } +#endif + case GGML_TYPE_F32: + { + auto data_pts = cast_data(dst); + hardsigmoid_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } + default: + GGML_ABORT("GGML tensor type not supported!\n"); + break; + } } -inline void ggml_sycl_op_hardswish(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - +inline void ggml_sycl_op_hardswish(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { +#if defined (GGML_SYCL_F16) + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); +#else GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); +#endif + GGML_ASSERT(dst->src[0]->type == dst->type); dpct::queue_ptr main_stream = ctx.stream(); SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - const float * src0_dd = static_cast(dst->src[0]->data); - float * dst_dd = static_cast(dst->data); - - hardswish_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), 
main_stream); + switch (dst->type) { +#if defined (GGML_SYCL_F16) + case GGML_TYPE_F16: + { + auto data_pts = cast_data(dst); + hardswish_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } +#endif + case GGML_TYPE_F32: + { + auto data_pts = cast_data(dst); + hardswish_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } + default: + GGML_ABORT("GGML tensor type not supported!\n"); + break; + } } -inline void ggml_sycl_op_exp(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - +inline void ggml_sycl_op_exp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { +#if defined (GGML_SYCL_F16) + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); +#else GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); +#endif + GGML_ASSERT(dst->src[0]->type == dst->type); dpct::queue_ptr main_stream = ctx.stream(); SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - const float * src0_dd = static_cast(dst->src[0]->data); - float * dst_dd = static_cast(dst->data); - - exp_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream); + switch (dst->type) { +#if defined (GGML_SYCL_F16) + case GGML_TYPE_F16: + { + auto data_pts = cast_data(dst); + exp_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } +#endif + case GGML_TYPE_F32: + { + auto data_pts = cast_data(dst); + exp_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } + default: + GGML_ABORT("GGML tensor type not supported!\n"); + break; + } } -inline void ggml_sycl_op_log(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - +inline void ggml_sycl_op_log(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { +#if defined (GGML_SYCL_F16) + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); +#else GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); +#endif + GGML_ASSERT(dst->src[0]->type == dst->type); dpct::queue_ptr main_stream = ctx.stream(); SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - const float * src0_dd = static_cast(dst->src[0]->data); - float * dst_dd = static_cast(dst->data); - - log_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream); + switch (dst->type) { +#if defined (GGML_SYCL_F16) + case GGML_TYPE_F16: + { + auto data_pts = cast_data(dst); + log_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } +#endif + case GGML_TYPE_F32: + { + auto data_pts = cast_data(dst); + log_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } + default: + GGML_ABORT("GGML tensor type not supported!\n"); + break; + } } -inline void ggml_sycl_op_sigmoid(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - +inline void ggml_sycl_op_sigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { +#if defined (GGML_SYCL_F16) + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); +#else GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); +#endif + 
GGML_ASSERT(dst->src[0]->type == dst->type); dpct::queue_ptr main_stream = ctx.stream(); SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - const float * src0_dd = static_cast(dst->src[0]->data); - float * dst_dd = static_cast(dst->data); - - - sigmoid_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream); + switch (dst->type) { +#if defined (GGML_SYCL_F16) + case GGML_TYPE_F16: + { + auto data_pts = cast_data(dst); + sigmoid_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } +#endif + case GGML_TYPE_F32: + { + auto data_pts = cast_data(dst); + sigmoid_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } + default: + GGML_ABORT("GGML tensor type not supported!\n"); + break; + } } -inline void ggml_sycl_op_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - +inline void ggml_sycl_op_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { +#if defined (GGML_SYCL_F16) + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); +#else GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); +#endif + GGML_ASSERT(dst->src[0]->type == dst->type); + dpct::queue_ptr main_stream = ctx.stream(); SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - const float * src0_dd = static_cast(dst->src[0]->data); - float * dst_dd = static_cast(dst->data); - - - sqrt_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream); + switch (dst->type) { +#if defined (GGML_SYCL_F16) + case GGML_TYPE_F16: + { + auto data_pts = cast_data(dst); + sqrt_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } +#endif + case GGML_TYPE_F32: + { + auto data_pts = cast_data(dst); + sqrt_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } + default: + GGML_ABORT("GGML tensor type not supported!\n"); + break; + } } -inline void ggml_sycl_op_sin(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - +inline void ggml_sycl_op_sin(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { +#if defined (GGML_SYCL_F16) + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); +#else GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); +#endif + GGML_ASSERT(dst->src[0]->type == dst->type); dpct::queue_ptr main_stream = ctx.stream(); SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - const float * src0_dd = static_cast(dst->src[0]->data); - float * dst_dd = static_cast(dst->data); - - sin_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream); + switch (dst->type) { +#if defined (GGML_SYCL_F16) + case GGML_TYPE_F16: + { + auto data_pts = cast_data(dst); + sin_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } +#endif + case GGML_TYPE_F32: + { + auto data_pts = cast_data(dst); + sin_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } + default: + GGML_ABORT("GGML tensor type not supported!\n"); + break; + } } -inline void ggml_sycl_op_cos(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - +inline void ggml_sycl_op_cos(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { +#if defined (GGML_SYCL_F16) + 
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); +#else GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); +#endif + GGML_ASSERT(dst->src[0]->type == dst->type); dpct::queue_ptr main_stream = ctx.stream(); SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - const float * src0_dd = static_cast(dst->src[0]->data); - float * dst_dd = static_cast(dst->data); - - cos_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream); + switch (dst->type) { +#if defined (GGML_SYCL_F16) + case GGML_TYPE_F16: + { + auto data_pts = cast_data(dst); + cos_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } +#endif + case GGML_TYPE_F32: + { + auto data_pts = cast_data(dst); + cos_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } + default: + GGML_ABORT("GGML tensor type not supported!\n"); + break; + } } -inline void ggml_sycl_op_step(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - +inline void ggml_sycl_op_step(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { +#if defined (GGML_SYCL_F16) + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); +#else GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); +#endif + GGML_ASSERT(dst->src[0]->type == dst->type); dpct::queue_ptr main_stream = ctx.stream(); SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - const float * src0_dd = static_cast(dst->src[0]->data); - float * dst_dd = static_cast(dst->data); - - step_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream); + switch (dst->type) { +#if defined (GGML_SYCL_F16) + case GGML_TYPE_F16: + { + auto data_pts = cast_data(dst); + step_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } +#endif + case GGML_TYPE_F32: + { + auto data_pts = cast_data(dst); + step_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } + default: + GGML_ABORT("GGML tensor type not supported!\n"); + break; + } } - -inline void ggml_sycl_op_neg(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - +inline void ggml_sycl_op_neg(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { +#if defined (GGML_SYCL_F16) + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); +#else GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); +#endif + GGML_ASSERT(dst->src[0]->type == dst->type); dpct::queue_ptr main_stream = ctx.stream(); SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - const float * src0_dd = static_cast(dst->src[0]->data); - float * dst_dd = static_cast(dst->data); - - neg_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream); + switch (dst->type) { +#if defined (GGML_SYCL_F16) + case GGML_TYPE_F16: + { + auto data_pts = cast_data(dst); + neg_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } +#endif + case GGML_TYPE_F32: + { + auto data_pts = cast_data(dst); + neg_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } + 
default: + GGML_ABORT("GGML tensor type not supported!\n"); + break; + } } -inline void ggml_sycl_op_leaky_relu(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - +inline void ggml_sycl_op_leaky_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { +#if defined (GGML_SYCL_F16) + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); +#else GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); - dpct::queue_ptr main_stream = ctx.stream(); - SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - const float * src0_dd = static_cast(dst->src[0]->data); - float * dst_dd = static_cast(dst->data); + GGML_ASSERT(dst->type == GGML_TYPE_F32); +#endif + GGML_ASSERT(dst->src[0]->type == dst->type); float negative_slope; memcpy(&negative_slope, dst->op_params, sizeof(float)); - - leaky_relu_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), negative_slope, main_stream); + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + switch (dst->type) { +#if defined (GGML_SYCL_F16) + case GGML_TYPE_F16: + { + auto data_pts = cast_data(dst); + leaky_relu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), negative_slope, main_stream); + break; + } +#endif + case GGML_TYPE_F32: + { + auto data_pts = cast_data(dst); + leaky_relu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), negative_slope, main_stream); + break; + } + default: + GGML_ABORT("GGML tensor type not supported!\n"); + break; + } } -inline void ggml_sycl_op_sqr(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - +inline void ggml_sycl_op_sqr(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + #if defined (GGML_SYCL_F16) + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); +#else GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); +#endif + GGML_ASSERT(dst->src[0]->type == dst->type); dpct::queue_ptr main_stream = ctx.stream(); SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - const float * src0_dd = static_cast(dst->src[0]->data); - float * dst_dd = static_cast(dst->data); - - sqr_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream); + switch (dst->type) { +#if defined (GGML_SYCL_F16) + case GGML_TYPE_F16: + { + auto data_pts = cast_data(dst); + sqr_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } +#endif + case GGML_TYPE_F32: + { + auto data_pts = cast_data(dst); + sqr_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } + default: + GGML_ABORT("GGML tensor type not supported!\n"); + break; + } } -inline void ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - +inline void ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { +#if defined (GGML_SYCL_F16) + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); +#else GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); GGML_ASSERT(dst->type == GGML_TYPE_F32); +#endif + GGML_ASSERT(dst->src[0]->type == dst->type); + dpct::queue_ptr main_stream = ctx.stream(); SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - const float * src0_dd = static_cast(dst->src[0]->data); - 
float * dst_dd = static_cast(dst->data); - const float sf0 = (float)dst->ne[0]/dst->src[0]->ne[0]; - const float sf1 = (float)dst->ne[1]/dst->src[0]->ne[1]; - const float sf2 = (float)dst->ne[2]/dst->src[0]->ne[2]; - const float sf3 = (float)dst->ne[3]/dst->src[0]->ne[3]; + const float sf0 = (float) dst->ne[0] / dst->src[0]->ne[0]; + const float sf1 = (float) dst->ne[1] / dst->src[0]->ne[1]; + const float sf2 = (float) dst->ne[2] / dst->src[0]->ne[2]; + const float sf3 = (float) dst->ne[3] / dst->src[0]->ne[3]; + switch (dst->type) { +#if defined (GGML_SYCL_F16) + case GGML_TYPE_F16: + { + auto data_pts = cast_data(dst); + upscale_sycl(data_pts.src, data_pts.dst, dst->src[0]->nb[0], dst->src[0]->nb[1], dst->src[0]->nb[2], + dst->src[0]->nb[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3, + main_stream); + break; + } +#endif + case GGML_TYPE_F32: + { + auto data_pts = cast_data(dst); + upscale_sycl(data_pts.src, data_pts.dst, dst->src[0]->nb[0], dst->src[0]->nb[1], dst->src[0]->nb[2], + dst->src[0]->nb[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3, + main_stream); + break; + } + default: + GGML_ABORT("GGML tensor type not supported!\n"); + break; + } +} - upscale_f32_sycl(src0_dd, dst_dd, dst->src[0]->nb[0], dst->src[0]->nb[1], dst->src[0]->nb[2], dst->src[0]->nb[3], - dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3, - main_stream); +inline void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { +#if defined (GGML_SYCL_F16) + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); +#else + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); +#endif + GGML_ASSERT(dst->src[0]->type == dst->type); + GGML_ASSERT(dst->src[0]->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + switch (dst->type) { +#if defined (GGML_SYCL_F16) + case GGML_TYPE_F16: + { + auto data_pts = cast_data(dst); + pad_sycl(data_pts.src, data_pts.dst, dst->src[0]->ne[0], dst->src[0]->ne[1], dst->src[0]->ne[2], dst->ne[0], + dst->ne[1], dst->ne[2], main_stream); + break; + } +#endif + case GGML_TYPE_F32: + { + auto data_pts = cast_data(dst); + pad_sycl(data_pts.src, data_pts.dst, dst->src[0]->ne[0], dst->src[0]->ne[1], dst->src[0]->ne[2], dst->ne[0], + dst->ne[1], dst->ne[2], main_stream); + break; + } + default: + GGML_ABORT("GGML tensor type not supported!\n"); + break; + } } -inline void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { +inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { +#if defined(GGML_SYCL_F16) + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); +#else GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); GGML_ASSERT(dst->type == GGML_TYPE_F32); - GGML_ASSERT(dst->src[0]->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors +#endif + GGML_ASSERT(dst->src[0]->type == dst->type); dpct::queue_ptr main_stream = ctx.stream(); SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - const float * src0_dd = static_cast(dst->src[0]->data); - float * dst_dd = static_cast(dst->data); - - pad_f32_sycl(src0_dd, dst_dd, - dst->src[0]->ne[0], dst->src[0]->ne[1], dst->src[0]->ne[2], - dst->ne[0], dst->ne[1], dst->ne[2], main_stream); + float min; + 
float max; + memcpy(&min, dst->op_params, sizeof(float)); + memcpy(&max, (float *) dst->op_params + 1, sizeof(float)); + + switch (dst->type) { +#if defined(GGML_SYCL_F16) + case GGML_TYPE_F16: + { + auto data_pts = cast_data(dst); + clamp_sycl(data_pts.src, data_pts.dst, min, max, ggml_nelements(dst->src[0]), main_stream); + break; + } +#endif + case GGML_TYPE_F32: + { + auto data_pts = cast_data(dst); + clamp_sycl(data_pts.src, data_pts.dst, min, max, ggml_nelements(dst->src[0]), main_stream); + break; + } + default: + GGML_ABORT("GGML tensor type not supported!\n"); + break; + } } inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { @@ -773,170 +1260,131 @@ inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, ggml_tensor *dst) acc_f32_sycl(src0_dd, src1_dd, dst_dd, ggml_nelements(dst), dst->src[1]->ne[0], dst->src[1]->ne[1], dst->src[1]->ne[2], nb1, nb2, offset, main_stream); } -inline void ggml_sycl_op_add(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - - ggml_sycl_op_bin_bcast>(ctx, dst->src[0], dst->src[1], dst); -} - -inline void ggml_sycl_op_sub(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - - ggml_sycl_op_bin_bcast>(ctx, dst->src[0], dst->src[1], dst); -} - -inline void ggml_sycl_op_mul(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - - ggml_sycl_op_bin_bcast>(ctx, dst->src[0], dst->src[1], dst); -} - -inline void ggml_sycl_op_div(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - - ggml_sycl_op_bin_bcast>(ctx, dst->src[0], dst->src[1], dst); -} - void ggml_sycl_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); + GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); ggml_sycl_op_sqrt(ctx, dst); GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_sin(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); + GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); ggml_sycl_op_sin(ctx, dst); GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_cos(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); + GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); ggml_sycl_op_cos(ctx, dst); GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_acc(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); + GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); ggml_sycl_op_acc(ctx, dst); GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_gelu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); + GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); ggml_sycl_op_gelu(ctx, dst); GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); + GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); ggml_sycl_op_silu(ctx, dst); GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); + GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); ggml_sycl_op_gelu_quick(ctx, dst); GGML_SYCL_DEBUG("call %s done\n", __func__); } void 
ggml_sycl_tanh(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); + GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); ggml_sycl_op_tanh(ctx, dst); GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); + GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); ggml_sycl_op_relu(ctx, dst); GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_sigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); + GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); ggml_sycl_op_sigmoid(ctx, dst); GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_hardsigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); + GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); ggml_sycl_op_hardsigmoid(ctx, dst); GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_hardswish(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); + GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); ggml_sycl_op_hardswish(ctx, dst); GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_exp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); + GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); ggml_sycl_op_exp(ctx, dst); GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_log(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); + GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); ggml_sycl_op_log(ctx, dst); GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_neg(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); + GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); ggml_sycl_op_neg(ctx, dst); GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_step(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); + GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); ggml_sycl_op_step(ctx, dst); GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_leaky_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); + GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); ggml_sycl_op_leaky_relu(ctx, dst); GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_sqr(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); + GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); ggml_sycl_op_sqr(ctx, dst); GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); + GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); ggml_sycl_op_upscale(ctx, dst); GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); + 
GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); ggml_sycl_op_pad(ctx, dst); GGML_SYCL_DEBUG("call %s done\n", __func__); } - - -void ggml_sycl_add(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_add(ctx, dst); - GGML_SYCL_DEBUG("call %s done\n", __func__); -} - -void ggml_sycl_sub(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_sub(ctx, dst); +void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); + ggml_sycl_op_clamp(ctx, dst); GGML_SYCL_DEBUG("call %s done\n", __func__); } -void ggml_sycl_mul(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_mul(ctx, dst); - GGML_SYCL_DEBUG("call %s done\n", __func__); -} - -void ggml_sycl_div(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_div(ctx, dst); - GGML_SYCL_DEBUG("call %s done\n", __func__); -} diff --git a/ggml/src/ggml-sycl/element_wise.hpp b/ggml/src/ggml-sycl/element_wise.hpp index 46443264505..e623cb56f76 100644 --- a/ggml/src/ggml-sycl/element_wise.hpp +++ b/ggml/src/ggml-sycl/element_wise.hpp @@ -2,29 +2,28 @@ #define GGML_SYCL_ELEMENTWISE_HPP #include "common.hpp" +#include "ggml.h" +#include -static __dpct_inline__ float op_repeat(const float a, const float b) { - return b; - GGML_UNUSED(a); +template +T neg_infinity() { + return -std::numeric_limits::infinity(); } -static __dpct_inline__ float op_add(const float a, const float b) { - return a + b; +template +struct typed_data { + const T * src; + T * dst; +}; + +template +typed_data cast_data(ggml_tensor * dst) { + return { + /* .src = */ static_cast(dst->src[0]->data), + /* .dst = */ static_cast(dst->data) + }; } -static __dpct_inline__ float op_sub(const float a, const float b) { - return a - b; -} - -static __dpct_inline__ float op_mul(const float a, const float b) { - return a * b; -} - -static __dpct_inline__ float op_div(const float a, const float b) { - return a / b; -} - - void ggml_sycl_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst); void ggml_sycl_sin(ggml_backend_sycl_context & ctx, ggml_tensor * dst); @@ -65,12 +64,7 @@ void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst); void ggml_sycl_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst); -void ggml_sycl_add(ggml_backend_sycl_context & ctx, ggml_tensor * dst); - -void ggml_sycl_sub(ggml_backend_sycl_context & ctx, ggml_tensor * dst); - -void ggml_sycl_mul(ggml_backend_sycl_context & ctx, ggml_tensor * dst); - -void ggml_sycl_div(ggml_backend_sycl_context & ctx, ggml_tensor * dst); +void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst); #endif // GGML_SYCL_ELEMENTWISE_HPP + diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index dff9f8d4c4a..8081a77b74f 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -372,6 +372,8 @@ static void ggml_backend_sycl_buffer_set_tensor(ggml_backend_buffer_t buffer, auto stream = &(dpct::dev_mgr::instance().get_device(ctx->device).default_queue()); SYCL_CHECK( CHECK_TRY_ERROR(dpct::dev_mgr::instance().get_device(ctx->device).queues_wait_and_throw())); + // Note: Use host buffer to save the data from mmap(), then copy to device. It's workaround for mmap() issue on PVC GPU. 
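The note above motivates the staging copy that the patch performs next; as a minimal standalone sketch of the same pattern (the function name staged_set_tensor and the parameters q, dst_dev and mmapped_src are placeholders for illustration, not symbols from this patch):

#include <sycl/sycl.hpp>
#include <cstdlib>
#include <cstring>

// Sketch of the staged host-to-device copy described in the note above:
// read the mmap()ed region into a plain host buffer first, then copy that
// buffer to the device and block until the copy has finished.
static void staged_set_tensor(sycl::queue & q, void * dst_dev, const void * mmapped_src, size_t size) {
    void * host_buf = std::malloc(size);        // temporary pageable staging buffer
    std::memcpy(host_buf, mmapped_src, size);   // read the mmap()ed pages on the host side once
    q.memcpy(dst_dev, host_buf, size).wait();   // host -> device copy, synchronous for simplicity
    std::free(host_buf);                        // the staging buffer is not kept around
}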
+ // This function is called while loading the model from disk. Replacing the per-call allocation with a cached/reused buffer would not save meaningful time and would add a risk of memory leaks here. char* host_buf = (char*)malloc(size); memcpy(host_buf, data, size); SYCL_CHECK( @@ -1615,17 +1617,6 @@ static void scale_f32(const float * x, float * dst, const float scale, const int dst[i] = scale * x[i]; } -static void clamp_f32(const float * x, float * dst, const float min, const float max, const int k, - const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - - if (i >= k) { - return; - } - - dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]); -} template <typename Ti, typename To> static void pool2d_nchw_kernel( @@ -1766,18 +1757,6 @@ static void scale_f32_sycl(const float *x, float *dst, const float scale, }); } -static void clamp_f32_sycl(const float *x, float *dst, const float min, - const float max, const int k, - queue_ptr stream) { - const int num_blocks = (k + SYCL_CLAMP_BLOCK_SIZE - 1) / SYCL_CLAMP_BLOCK_SIZE; - stream->parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * - sycl::range<3>(1, 1, SYCL_CLAMP_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_CLAMP_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { - clamp_f32(x, dst, min, max, k, item_ct1); - }); -} static void sum_rows_f32_sycl(const float *x, float *dst, const int ncols, const int nrows, queue_ptr stream) { @@ -1988,11 +1967,6 @@ catch (sycl::exception const &exc) { std::exit(1); } -static void ggml_sycl_op_repeat(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_repeat>>(ctx, dst, dst->src[0], dst); -} - - inline void ggml_sycl_op_mul_mat_sycl( ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, @@ -2256,26 +2230,6 @@ inline void ggml_sycl_op_scale(ggml_backend_sycl_context & ctx, ggml_tensor *dst SYCL_CHECK(0); } -inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); - const float * src0_dd = static_cast<const float *>(dst->src[0]->data); - float * dst_dd = static_cast<float *>(dst->data); - - float min; - float max; - memcpy(&min, dst->op_params, sizeof(float)); - memcpy(&max, (float *) dst->op_params + 1, sizeof(float)); - - clamp_f32_sycl(src0_dd, dst_dd, min, max, ggml_nelements(dst->src[0]), ctx.stream()); - /* - DPCT1010:88: SYCL uses exceptions to report errors and does not use the - error codes. The call was replaced with 0. You need to rewrite this code.
- */ - SYCL_CHECK(0); -} - static void ggml_sycl_set_peer_access(const int n_tokens, int main_device) { static bool peer_access_enabled = false; @@ -2641,12 +2595,6 @@ catch (sycl::exception const &exc) { } -static void ggml_sycl_repeat(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_repeat(ctx, dst); - GGML_SYCL_DEBUG("call %s done\n", __func__); -} - static void ggml_sycl_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { GGML_SYCL_DEBUG("call %s\n", __func__); ggml_sycl_op_get_rows(ctx, dst); @@ -3216,19 +3164,10 @@ static void ggml_sycl_scale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) ggml_sycl_op_scale(ctx, dst); } -static void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - ggml_sycl_op_clamp(ctx, dst); -} - static void ggml_sycl_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { ggml_sycl_op_diag_mask_inf(ctx, dst); } -static void ggml_sycl_rope(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_ASSERT(ggml_is_contiguous(dst->src[0])); // TODO: this restriction is temporary until non-cont support is implemented - ggml_sycl_op_rope(ctx, dst); -} - static void ggml_sycl_pool2d(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { ggml_sycl_op_pool2d(ctx, dst); } @@ -3698,7 +3637,8 @@ static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_ #ifdef GGML_SYCL_GRAPH if (!g_ggml_sycl_disable_graph) { - if (!sycl_ctx->exec_graph && !dpct::get_device(sycl_ctx->device).has(sycl::aspect::ext_oneapi_graph)) { + const bool graph_support = dpct::get_device(sycl_ctx->device).has(sycl::aspect::ext_oneapi_limited_graph); + if (!graph_support) { GGML_SYCL_DEBUG("[SYCL-GRAPH] can not use graphs on device:%d\n", sycl_ctx->device); ggml_backend_sycl_graph_compute_impl(sycl_ctx, cgraph); return GGML_STATUS_SUCCESS; @@ -3709,8 +3649,10 @@ static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_ ggml_backend_sycl_graph_compute_impl(sycl_ctx, cgraph); model_sycl_graph.end_recording(); - if (!sycl_ctx->exec_graph) { - auto exec_graph = model_sycl_graph.finalize({sycl_ex::property::graph::updatable{}}); + const bool graph_update_support = dpct::get_device(sycl_ctx->device).has(sycl::aspect::ext_oneapi_graph); + if (!sycl_ctx->exec_graph || !graph_update_support) { + auto exec_graph = graph_update_support ? 
model_sycl_graph.finalize(sycl_ex::property::graph::updatable{}) : + model_sycl_graph.finalize(); sycl_ctx->exec_graph = std::make_unique< sycl_ex::command_graph>(exec_graph); } else { @@ -3898,7 +3840,11 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g case GGML_UNARY_OP_GELU_QUICK: case GGML_UNARY_OP_TANH: case GGML_UNARY_OP_EXP: - return ggml_is_contiguous(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32); +#if defined (GGML_SYCL_F16) + return ggml_is_contiguous(op->src[0]) && (op->type == op->src[0]->type); +#else + return ggml_is_contiguous(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) && (op->type == op->src[0]->type); +#endif default: return false; } @@ -4010,7 +3956,6 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_ARGMAX: case GGML_OP_NONE: case GGML_OP_RESHAPE: - case GGML_OP_REPEAT: case GGML_OP_VIEW: case GGML_OP_PERMUTE: case GGML_OP_TRANSPOSE: @@ -4020,13 +3965,19 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_SUB: case GGML_OP_MUL: case GGML_OP_DIV: + case GGML_OP_REPEAT: + return true; case GGML_OP_SQR: case GGML_OP_SQRT: case GGML_OP_SIN: case GGML_OP_COS: case GGML_OP_CLAMP: case GGML_OP_LOG: - return (op->src[0]->type == GGML_TYPE_F32); +#if defined (GGML_SYCL_F16) + return ((op->type == GGML_TYPE_F32 || op->type == GGML_SYCL_F16) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_SYCL_F16) && (op->type == op->src[0]->type)); +#else + return (op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32) && (op->type == op->src[0]->type); +#endif case GGML_OP_NORM: case GGML_OP_RMS_NORM: case GGML_OP_L2_NORM: @@ -4042,23 +3993,21 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_ROPE: { const int mode = ((const int32_t *) op->op_params)[2]; - if (mode & GGML_ROPE_TYPE_MROPE) { + // mode is not used as a bitmask in practice, the various rope type modes are independent implementations + if (mode == GGML_ROPE_TYPE_MROPE) { return false; } - if (mode & GGML_ROPE_TYPE_VISION) { - return false; - } - return ggml_is_contiguous(op->src[0]); + return true; } case GGML_OP_IM2COL: - // TODO: add support for the new F32 operations - return op->src[0]->type == GGML_TYPE_F16; + return true; + case GGML_OP_UPSCALE: + return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST; case GGML_OP_POOL_2D: case GGML_OP_SUM: case GGML_OP_SUM_ROWS: case GGML_OP_ARGSORT: case GGML_OP_ACC: - case GGML_OP_UPSCALE: case GGML_OP_PAD: case GGML_OP_LEAKY_RELU: case GGML_OP_TIMESTEP_EMBEDDING: diff --git a/ggml/src/ggml-sycl/im2col.cpp b/ggml/src/ggml-sycl/im2col.cpp index 009b42035d0..aa19c2527dc 100644 --- a/ggml/src/ggml-sycl/im2col.cpp +++ b/ggml/src/ggml-sycl/im2col.cpp @@ -12,110 +12,125 @@ #include "im2col.hpp" +#include +#include // For std::is_same_v + +#include "ggml.h" + template -static void im2col_kernel( - const float *x, T *dst, int64_t batch_offset, int64_t offset_delta, - int64_t IC, int64_t IW, int64_t IH, int64_t OH, int64_t OW, int64_t KW, int64_t KH, - int64_t pelements, int64_t CHW, int s0, int s1, int p0, int p1, int d0, int d1, - const sycl::nd_item<3> &item_ct1) { +static void im2col_kernel(const float * x, T * dst, int64_t batch_offset, int64_t offset_delta, int64_t IC, int64_t IW, + int64_t IH, int64_t OH, int64_t OW, int64_t KW, int64_t KH, int64_t pelements, int64_t CHW, + int s0, int s1, int p0, int p1, int d0, int d1, 
const sycl::nd_item<3> & item_ct1) { const int64_t work_group_size = item_ct1.get_local_range(2); - const int64_t global_id = item_ct1.get_local_id(2) + work_group_size * item_ct1.get_group(2); + const int64_t global_id = item_ct1.get_local_id(2) + (work_group_size * item_ct1.get_group(2)); // make each work-item deal with more elements since sycl global range can not exceed max int - for (int64_t i = global_id; i < pelements; i += work_group_size * item_ct1.get_group_range(2)) { - + for (int64_t i = global_id; i < pelements; i += (work_group_size * item_ct1.get_group_range(2))) { const int64_t ksize = OW * (KH > 1 ? KW : 1); - const int64_t kx = i / ksize; - const int64_t kd = kx * ksize; - const int64_t ky = (i - kd) / OW; - const int64_t ix = i % OW; - - const int64_t oh = item_ct1.get_group(1); - const int64_t batch = item_ct1.get_group(0) / IC; - const int64_t ic = item_ct1.get_group(0) % IC; - - const int64_t iiw = ix * s0 + kx * d0 - p0; - const int64_t iih = oh * s1 + ky * d1 - p1; - - const int64_t offset_dst = - ((batch * OH + oh) * OW + ix) * CHW + - (ic * (KW * KH) + ky * KW + kx); - - if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { - dst[offset_dst] = - sycl::vec(0.0f) - .convert()[0]; - } else { - const int64_t offset_src = ic * offset_delta + batch * batch_offset; - dst[offset_dst] = - sycl::vec(x[offset_src + iih * IW + iiw]) - .convert()[0]; + const int64_t kx = i / ksize; + const int64_t kd = kx * ksize; + const int64_t ky = (i - kd) / OW; + const int64_t ix = i % OW; + + const int64_t oh = item_ct1.get_group(1); + const int64_t batch = item_ct1.get_group(0) / IC; + const int64_t ic = item_ct1.get_group(0) % IC; + + const int64_t iiw = (ix * s0) + (kx * d0) - p0; + const int64_t iih = (oh * s1) + (ky * d1) - p1; + + const int64_t offset_dst = (((batch * OH + oh) * OW + ix) * CHW) + (ic * (KW * KH) + ky * KW + kx); + + const int64_t offset_src_base = (ic * offset_delta) + (batch * batch_offset); + const int64_t offset_src = offset_src_base + (iih * IW) + iiw; + + const bool out_of_bounds = (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW); + const float src_val = out_of_bounds ? 
0.0f : x[offset_src]; + + if constexpr (std::is_same_v) { + dst[offset_dst] = sycl::half(src_val); + } else if constexpr (std::is_same_v) { + dst[offset_dst] = src_val; } } } template -static void im2col_sycl( - const float *x, T *dst, int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, - int64_t KH, int64_t IC, int64_t batch, int64_t batch_offset, int64_t offset_delta, - int s0, int s1, int p0, int p1, int d0, int d1, - queue_ptr stream) { +static void im2col_sycl_internal(const float * x, T * dst, int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, + int64_t KH, int64_t IC, int64_t batch, int64_t batch_offset, int64_t offset_delta, + int s0, int s1, int p0, int p1, int d0, int d1, queue_ptr stream) { const int64_t parallel_elements = OW * KW * KH; - const int64_t num_blocks = (parallel_elements + SYCL_IM2COL_BLOCK_SIZE - 1) / SYCL_IM2COL_BLOCK_SIZE; + const int64_t num_blocks = (parallel_elements + SYCL_IM2COL_BLOCK_SIZE - 1) / SYCL_IM2COL_BLOCK_SIZE; // decrease global range when it exceeds the max int int64_t local_size = downsample_sycl_global_range(batch * IC * OH * num_blocks, SYCL_IM2COL_BLOCK_SIZE); + sycl::range<3> block_nums(batch * IC, OH, num_blocks); sycl::range<3> local_range(1, 1, local_size); - { - dpct::has_capability_or_fail(stream->get_device(), - {sycl::aspect::fp16}); - - stream->parallel_for( - sycl::nd_range<3>(block_nums * local_range, local_range), - [=](sycl::nd_item<3> item_ct1) { - im2col_kernel(x, dst, batch_offset, offset_delta, IC, IW, IH, OH, OW, KW, KH, - parallel_elements, (IC * KH * KW), s0, s1, p0, - p1, d0, d1, item_ct1); - }); + const int64_t CHW = IC * KH * KW; + + stream->parallel_for(sycl::nd_range<3>(block_nums * local_range, local_range), [=](sycl::nd_item<3> item_ct1) { + im2col_kernel(x, dst, batch_offset, offset_delta, IC, IW, IH, OH, OW, KW, KH, parallel_elements, CHW, s0, s1, + p0, p1, d0, d1, item_ct1); + }); +} + +static void im2col_sycl_f16(const float * x, sycl::half * dst, int64_t IW, int64_t IH, int64_t OW, int64_t OH, + int64_t KW, int64_t KH, int64_t IC, int64_t batch, int64_t batch_offset, + int64_t offset_delta, int s0, int s1, int p0, int p1, int d0, int d1, queue_ptr stream) { + if (!stream->get_device().has(sycl::aspect::fp16)) { + throw sycl::exception(sycl::make_error_code(sycl::errc::kernel_not_supported), + "Device does not support half precision (fp16) operations!"); } + im2col_sycl_internal(x, dst, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, offset_delta, s0, s1, p0, + p1, d0, d1, stream); +} + +static void im2col_sycl_f32(const float * x, float * dst, int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, + int64_t KH, int64_t IC, int64_t batch, int64_t batch_offset, int64_t offset_delta, int s0, + int s1, int p0, int p1, int d0, int d1, queue_ptr stream) { + im2col_sycl_internal(x, dst, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, offset_delta, s0, s1, p0, p1, + d0, d1, stream); } -void ggml_sycl_op_im2col(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { +void ggml_sycl_op_im2col(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; - GGML_ASSERT(src0->type == GGML_TYPE_F16); GGML_ASSERT(src1->type == GGML_TYPE_F32); GGML_ASSERT(dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32); - const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; - const int32_t s1 = ((const int32_t*)(dst->op_params))[1]; - const int32_t p0 = ((const int32_t*)(dst->op_params))[2]; - const int32_t p1 = ((const 
int32_t*)(dst->op_params))[3]; - const int32_t d0 = ((const int32_t*)(dst->op_params))[4]; - const int32_t d1 = ((const int32_t*)(dst->op_params))[5]; + const int32_t s0 = ((const int32_t *) (dst->op_params))[0]; + const int32_t s1 = ((const int32_t *) (dst->op_params))[1]; + const int32_t p0 = ((const int32_t *) (dst->op_params))[2]; + const int32_t p1 = ((const int32_t *) (dst->op_params))[3]; + const int32_t d0 = ((const int32_t *) (dst->op_params))[4]; + const int32_t d1 = ((const int32_t *) (dst->op_params))[5]; - const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1; + const bool is_2D = ((const int32_t *) (dst->op_params))[6] == 1; const int64_t IC = src1->ne[is_2D ? 2 : 1]; const int64_t IH = is_2D ? src1->ne[1] : 1; - const int64_t IW = src1->ne[0]; + const int64_t IW = src1->ne[0]; const int64_t KH = is_2D ? src0->ne[1] : 1; - const int64_t KW = src0->ne[0]; + const int64_t KW = src0->ne[0]; const int64_t OH = is_2D ? dst->ne[2] : 1; - const int64_t OW = dst->ne[1]; + const int64_t OW = dst->ne[1]; + + const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / sizeof(float); + const int64_t batch = src1->ne[is_2D ? 3 : 2]; + const size_t batch_offset = src1->nb[is_2D ? 3 : 2] / sizeof(float); - const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32 - const int64_t batch = src1->ne[3]; - const size_t batch_offset = src1->nb[3] / 4; // nb is byte offset, src is type float32 + queue_ptr stream = ctx.stream(); if (dst->type == GGML_TYPE_F16) { - im2col_sycl((const float *) src1->data, (sycl::half *)dst->data, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, ctx.stream()); + im2col_sycl_f16((const float *) src1->data, (sycl::half *) dst->data, IW, IH, OW, OH, KW, KH, IC, batch, + batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, stream); } else { - im2col_sycl((const float *) src1->data, (float *)dst->data, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, ctx.stream()); + im2col_sycl_f32((const float *) src1->data, (float *) dst->data, IW, IH, OW, OH, KW, KH, IC, batch, + batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, stream); } } diff --git a/ggml/src/ggml-sycl/rope.cpp b/ggml/src/ggml-sycl/rope.cpp index bbcb356e979..4e276d3b62e 100644 --- a/ggml/src/ggml-sycl/rope.cpp +++ b/ggml/src/ggml-sycl/rope.cpp @@ -1,9 +1,15 @@ #include "rope.hpp" +#include "ggml-sycl/common.hpp" +#include "ggml.h" struct rope_corr_dims { float v[2]; }; +struct mrope_sections { + int v[4]; +}; + static float rope_yarn_ramp(const float low, const float high, const int i0) { const float y = (i0 / 2 - low) / sycl::max(0.001f, high - low); return 1.0f - sycl::min(1.0f, sycl::max(0.0f, y)); @@ -28,23 +34,21 @@ static void rope_yarn( *sin_theta = sycl::sin(theta) * mscale; } -template -static void rope_norm( - const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows, - float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors, - const sycl::nd_item<3> &item_ct1) { - const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) + - item_ct1.get_local_id(1)); +template +static void rope_norm(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, + const int32_t * pos, float freq_scale, float ext_factor, float attn_factor, + const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors, + const sycl::nd_item<3> & item_ct1) { 
+ const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) + item_ct1.get_local_id(1)); if (i0 >= ne0) { return; } - const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); + const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2); if (i0 >= n_dims) { - const int i = row*ne0 + i0; + const int i = row * ne0 + i0; dst[i + 0] = x[i + 0]; dst[i + 1] = x[i + 1]; @@ -52,42 +56,43 @@ static void rope_norm( return; } - const int i = row*ne0 + i0; - const int i2 = row/p_delta_rows; + const int row0 = row % ne1; + const int channel0 = row / ne1; - const float theta_base = pos[i2] * sycl::pow(theta_scale, i0 / 2.0f); + const int i = row * ne0 + i0; + const int i2 = channel0 * s2 + row0 * s1 + i0; - const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f; + const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f); + + const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f; float cos_theta; float sin_theta; - rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); + rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); - const float x0 = x[i + 0]; - const float x1 = x[i + 1]; + const float x0 = x[i2 + 0]; + const float x1 = x[i2 + 1]; - dst[i + 0] = x0*cos_theta - x1*sin_theta; - dst[i + 1] = x0*sin_theta + x1*cos_theta; + dst[i + 0] = x0 * cos_theta - x1 * sin_theta; + dst[i + 1] = x0 * sin_theta + x1 * cos_theta; } -template -static void rope_neox( - const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows, - float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors, - const sycl::nd_item<3> &item_ct1) { - const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) + - item_ct1.get_local_id(1)); +template +static void rope_neox(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, + const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor, + const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors, + const sycl::nd_item<3> & item_ct1) { + const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) + item_ct1.get_local_id(1)); if (i0 >= ne0) { return; } - const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); + const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2); if (i0 >= n_dims) { - const int i = row*ne0 + i0; + const int i = row * ne0 + i0; dst[i + 0] = x[i + 0]; dst[i + 1] = x[i + 1]; @@ -95,38 +100,83 @@ static void rope_neox( return; } - const int i = row*ne0 + i0/2; - const int i2 = row/p_delta_rows; + const int row0 = row % ne1; + const int channel0 = row / ne1; + + const int i = row * ne0 + i0 / 2; + const int i2 = channel0 * s2 + row0 * s1 + i0 / 2; - const float theta_base = pos[i2] * sycl::pow(theta_scale, i0 / 2.0f); + const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f); - const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f; + const float freq_factor = has_ff ? 
freq_factors[i0 / 2] : 1.0f; float cos_theta; float sin_theta; - rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); + rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); - const float x0 = x[i + 0]; - const float x1 = x[i + n_dims/2]; + const float x0 = x[i2 + 0]; + const float x1 = x[i2 + n_dims / 2]; - dst[i + 0] = x0*cos_theta - x1*sin_theta; - dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta; + dst[i + 0] = x0 * cos_theta - x1 * sin_theta; + dst[i + n_dims / 2] = x0 * sin_theta + x1 * cos_theta; +} + +template +static void rope_vision(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1, + const size_t s2, const int n_dims, const int32_t * pos, const float freq_scale, + const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims, + const float theta_scale, const float * freq_factors, const mrope_sections sections, + const sycl::nd_item<3> & item_ct1) { + // get index pos + const int i0 = 2 * (item_ct1.get_group(1) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1)); + if (i0 >= ne0) { + return; + } + const int row_dst = (item_ct1.get_group(2) * item_ct1.get_local_range(2)) + item_ct1.get_local_id(2); + const int row_x = row_dst % ne1; + const int channel_x = row_dst / ne1; + const int idst = (row_dst * ne0) + (i0 / 2); + const size_t ix = ((size_t) channel_x * s2) + ((size_t) row_x * s1) + (i0 / 2); + + const int sect_dims = sections.v[0] + sections.v[1]; + const int sector = (i0 / 2) % sect_dims; + + float theta_base = 0.0f; + if (sector < sections.v[0]) { + const int p = sector; + theta_base = pos[channel_x] * sycl::pow(theta_scale, (float) p); + } else { + // Simplified from CUDA backend code: if (sector >= sections.v[0] && sector < sec_w) which is just sector >= sections.v[0] + const int p = sector - sections.v[0]; + theta_base = pos[channel_x + ne2] * sycl::pow(theta_scale, (float) p); + } + + const float freq_factor = has_ff ? 
freq_factors[i0 / 2] : 1.0f; + float cos_theta; + float sin_theta; + rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); + const float x0 = x[ix + 0]; + const float x1 = x[ix + n_dims]; + + // store results in dst + dst[idst + 0] = x0 * cos_theta - x1 * sin_theta; + dst[idst + n_dims] = x0 * sin_theta + x1 * cos_theta; } template -static void rope_norm_sycl( - const T *x, T *dst, int ne0, int n_dims, int nr, const int32_t *pos, float freq_scale, int p_delta_rows, - float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, queue_ptr stream) { +static void rope_norm_sycl(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, + const int n_dims, int nr, const int32_t * pos, const float freq_scale, const float freq_base, + const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims, + const float * freq_factors, queue_ptr stream) { GGML_ASSERT(ne0 % 2 == 0); const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1); - const int num_blocks_x = (ne0 + 2*SYCL_ROPE_BLOCK_SIZE - 1) / (2*SYCL_ROPE_BLOCK_SIZE); + const int num_blocks_x = (ne0 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE); const sycl::range<3> block_nums(1, num_blocks_x, nr); - const float theta_scale = powf(freq_base, -2.0f/n_dims); + const float theta_scale = powf(freq_base, -2.0f / n_dims); - dpct::has_capability_or_fail(stream->get_device(), - {sycl::aspect::fp16}); + dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 }); if (freq_factors == nullptr) { /* @@ -134,79 +184,102 @@ static void rope_norm_sycl( the limit. To get the device limit, query info::device::max_work_group_size. Adjust the work-group size if needed. */ - stream->parallel_for( - sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) { - rope_norm(x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, - ext_factor, attn_factor, corr_dims, theta_scale, freq_factors, - item_ct1); - }); + stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { + rope_norm(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, + theta_scale, freq_factors, item_ct1); + }); } else { /* DPCT1049:41: The work-group size passed to the SYCL kernel may exceed the limit. To get the device limit, query info::device::max_work_group_size. Adjust the work-group size if needed. 
*/ - stream->parallel_for( - sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) { - rope_norm(x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, - ext_factor, attn_factor, corr_dims, theta_scale, freq_factors, - item_ct1); - }); + stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { + rope_norm(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, + theta_scale, freq_factors, item_ct1); + }); } } template -static void rope_neox_sycl( - const T *x, T *dst, int ne0, int n_dims, int nr, const int32_t *pos, float freq_scale, int p_delta_rows, - float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, queue_ptr stream) { +static void rope_neox_sycl(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, + const int n_dims, const int nr, const int32_t * pos, const float freq_scale, + const float freq_base, const float ext_factor, const float attn_factor, + const rope_corr_dims corr_dims, const float * freq_factors, queue_ptr stream) { GGML_ASSERT(ne0 % 2 == 0); const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1); - const int num_blocks_x = (ne0 + 2*SYCL_ROPE_BLOCK_SIZE - 1) / (2*SYCL_ROPE_BLOCK_SIZE); + const int num_blocks_x = (ne0 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE); const sycl::range<3> block_nums(1, num_blocks_x, nr); - const float theta_scale = powf(freq_base, -2.0f/n_dims); + const float theta_scale = powf(freq_base, -2.0f / n_dims); - dpct::has_capability_or_fail(stream->get_device(), - {sycl::aspect::fp16}); + dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 }); if (freq_factors == nullptr) { - stream->parallel_for( - sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) { - rope_neox(x, dst, ne0, n_dims, pos, freq_scale, - p_delta_rows, ext_factor, attn_factor, - corr_dims, theta_scale, freq_factors, - item_ct1); - }); + stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { + rope_neox(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, + theta_scale, freq_factors, item_ct1); + }); } else { - stream->parallel_for( - sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) { - rope_neox(x, dst, ne0, n_dims, pos, freq_scale, - p_delta_rows, ext_factor, attn_factor, - corr_dims, theta_scale, freq_factors, - item_ct1); - }); + stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { + rope_neox(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, + theta_scale, freq_factors, item_ct1); + }); } } -void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { +// rope vision +template +static void rope_vision_sycl(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1, + const size_t s2, const int n_dims, const int nr, const int32_t * pos, + const float freq_scale, const float freq_base, const float ext_factor, + const float attn_factor, const rope_corr_dims corr_dims, const float * freq_factors, + const mrope_sections sections, queue_ptr stream) { + GGML_ASSERT(ne0 % 2 == 0); + const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1); + const int n_blocks_y = (ne0 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE); + const sycl::range<3> 
grid_dims(1, n_blocks_y, nr); + const sycl::nd_range<3> nd_range(grid_dims * block_dims, block_dims); + + const float theta_scale = std::pow(freq_base, -2.0f / n_dims); + // Add FP16 capability check if T could be sycl::half + if constexpr (std::is_same_v) { + dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 }); + } + // launch kernel + if (freq_factors == nullptr) { + stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) { + rope_vision(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, + corr_dims, theta_scale, freq_factors, sections, item_ct1); + }); + } else { + stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) { + rope_vision(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, + corr_dims, theta_scale, freq_factors, sections, item_ct1); + }); + } +} + +inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); GGML_ASSERT(dst->src[0]->type == dst->type); - - const int64_t ne00 = dst->src[0]->ne[0]; - const int64_t ne01 = dst->src[0]->ne[1]; + const int64_t ne00 = dst->src[0]->ne[0]; // head dims + const int64_t ne01 = dst->src[0]->ne[1]; // num heads + const int64_t ne02 = dst->src[0]->ne[2]; // num heads const int64_t nr = ggml_nrows(dst->src[0]); + const size_t s01 = dst->src[0]->nb[1] / ggml_type_size(dst->src[0]->type); + const size_t s02 = dst->src[0]->nb[2] / ggml_type_size(dst->src[0]->type); + + //const int n_past = ((int32_t *) dst->op_params)[0]; const int n_dims = ((int32_t *) dst->op_params)[1]; const int mode = ((int32_t *) dst->op_params)[2]; //const int n_ctx = ((int32_t *) dst->op_params)[3]; const int n_ctx_orig = ((int32_t *) dst->op_params)[4]; + mrope_sections sections; // RoPE alteration for extended context float freq_base; @@ -222,8 +295,10 @@ void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float)); memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); + memcpy(§ions.v, (int32_t *) dst->op_params + 11, sizeof(int)*4); const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; + const bool is_vision = mode == GGML_ROPE_TYPE_VISION; const int32_t * pos = (const int32_t *) dst->src[1]->data; @@ -240,32 +315,48 @@ void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { // compute if (is_neox) { + GGML_SYCL_DEBUG("%s: neox path\n", __func__); if (dst->src[0]->type == GGML_TYPE_F32) { - rope_neox_sycl( - (const float *)dst->src[0]->data, (float *)dst->data, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, - attn_factor, corr_dims, freq_factors, main_stream - ); + rope_neox_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, s01, s02, n_dims, nr, + pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, main_stream); } else if (dst->src[0]->type == GGML_TYPE_F16) { - rope_neox_sycl( - (const sycl::half *)dst->src[0]->data, (sycl::half *)dst->data, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, - attn_factor, corr_dims, freq_factors, main_stream - ); + rope_neox_sycl((const sycl::half *) dst->src[0]->data, (sycl::half *) dst->data, ne00, ne01, s01, s02, + n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, 
freq_factors, + main_stream); } else { GGML_ABORT("fatal error"); } + } else if (is_vision) { + GGML_SYCL_DEBUG("%s: vision path\n", __func__); + if (dst->src[0]->type == GGML_TYPE_F16) { + rope_vision_sycl((const sycl::half *) dst->src[0]->data, (sycl::half *) dst->data, ne00, ne01, ne02, s01, + s02, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, + freq_factors, sections, main_stream); + } else if (dst->src[0]->type == GGML_TYPE_F32) { + rope_vision_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, ne02, s01, s02, n_dims, + nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, + main_stream); + } else { + GGML_ABORT("Fatal error: Tensor type unsupported!"); + } } else { + GGML_SYCL_DEBUG("%s: norm path\n", __func__); if (dst->src[0]->type == GGML_TYPE_F32) { - rope_norm_sycl( - (const float *)dst->src[0]->data, (float *)dst->data, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, - attn_factor, corr_dims, freq_factors, main_stream - ); + rope_norm_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, s01, s02, n_dims, nr, + pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, main_stream); } else if (dst->src[0]->type == GGML_TYPE_F16) { - rope_norm_sycl( - (const sycl::half *)dst->src[0]->data, (sycl::half *)dst->data, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, - attn_factor, corr_dims, freq_factors, main_stream - ); + rope_norm_sycl((const sycl::half *) dst->src[0]->data, (sycl::half *) dst->data, ne00, ne01, s01, s02, + n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, + main_stream); } else { GGML_ABORT("fatal error"); } } } + +void ggml_sycl_rope(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + GGML_SYCL_DEBUG("call %s\n", __func__); + ggml_sycl_op_rope(ctx, dst); + GGML_SYCL_DEBUG("call %s done\n", __func__); +} + diff --git a/ggml/src/ggml-sycl/rope.hpp b/ggml/src/ggml-sycl/rope.hpp index a399bddb8a0..8c7141aac5c 100644 --- a/ggml/src/ggml-sycl/rope.hpp +++ b/ggml/src/ggml-sycl/rope.hpp @@ -15,6 +15,6 @@ #include "common.hpp" -void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst); +void ggml_sycl_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst); #endif // GGML_SYCL_ROPE_HPP diff --git a/ggml/src/ggml-vulkan/CMakeLists.txt b/ggml/src/ggml-vulkan/CMakeLists.txt index e3c59b75fd5..9d028f718d0 100644 --- a/ggml/src/ggml-vulkan/CMakeLists.txt +++ b/ggml/src/ggml-vulkan/CMakeLists.txt @@ -23,49 +23,35 @@ if (Vulkan_FOUND) ../../include/ggml-vulkan.h ) - if(NOT DEFINED GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) - # Compile a test shader to determine whether GL_KHR_cooperative_matrix is supported. - # If it's not, there will be an error to stderr. 
- # If it's supported, set a define to indicate that we should compile those shaders - execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_coopmat_support.comp" - OUTPUT_VARIABLE glslc_output - ERROR_VARIABLE glslc_error) - - if (${glslc_error} MATCHES ".*extension not supported: GL_KHR_cooperative_matrix.*") - message(STATUS "GL_KHR_cooperative_matrix not supported by glslc") - set(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT OFF CACHE INTERNAL "Whether coopmat is supported by glslc") - else() - message(STATUS "GL_KHR_cooperative_matrix supported by glslc") - set(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT ON CACHE INTERNAL "Whether coopmat is supported by glslc") - endif() - else() - message(STATUS "GL_KHR_cooperative_matrix support already defined: ${GGML_VULKAN_COOPMAT_GLSLC_SUPPORT}") - endif() + # Compile a test shader to determine whether GL_KHR_cooperative_matrix is supported. + # If it's not, there will be an error to stderr. + # If it's supported, set a define to indicate that we should compile those shaders + execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_coopmat_support.comp" + OUTPUT_VARIABLE glslc_output + ERROR_VARIABLE glslc_error) - if(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + if (${glslc_error} MATCHES ".*extension not supported: GL_KHR_cooperative_matrix.*") + message(STATUS "GL_KHR_cooperative_matrix not supported by glslc") + set(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT OFF) + else() + message(STATUS "GL_KHR_cooperative_matrix supported by glslc") + set(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT ON) add_compile_definitions(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) endif() - if(NOT DEFINED GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) - # Compile a test shader to determine whether GL_NV_cooperative_matrix2 is supported. - # If it's not, there will be an error to stderr. - # If it's supported, set a define to indicate that we should compile those shaders - execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_coopmat2_support.comp" - OUTPUT_VARIABLE glslc_output - ERROR_VARIABLE glslc_error) - - if (${glslc_error} MATCHES ".*extension not supported: GL_NV_cooperative_matrix2.*") - message(STATUS "GL_NV_cooperative_matrix2 not supported by glslc") - set(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT OFF CACHE INTERNAL "Whether coopmat2 is supported by glslc") - else() - message(STATUS "GL_NV_cooperative_matrix2 supported by glslc") - set(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT ON CACHE INTERNAL "Whether coopmat2 is supported by glslc") - endif() - else() - message(STATUS "GL_NV_cooperative_matrix2 support already defined: ${GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT}") - endif() + # Compile a test shader to determine whether GL_NV_cooperative_matrix2 is supported. + # If it's not, there will be an error to stderr. 
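For context on how these probe results are consumed: the GGML_VULKAN_*_GLSLC_SUPPORT values set below become ordinary compile definitions, so the C++ sources can gate entire shader paths at build time. A tiny standalone sketch of that mechanism (illustrative only; the actual guarded code in ggml-vulkan.cpp is more involved):

#include <cstdio>

// Illustrative only: a compile definition produced by the glslc capability probe
// decides at build time whether the corresponding shader path exists at all.
static bool coopmat2_shaders_built() {
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
    return true;   // glslc accepted GL_NV_cooperative_matrix2, the shaders were generated
#else
    return false;  // the extension was unknown to glslc, so this path is compiled out
#endif
}

int main() {
    std::printf("coopmat2 shader path compiled in: %d\n", coopmat2_shaders_built() ? 1 : 0);
    return 0;
}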
+ # If it's supported, set a define to indicate that we should compile those shaders + execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_coopmat2_support.comp" + OUTPUT_VARIABLE glslc_output + ERROR_VARIABLE glslc_error) - if(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + if (${glslc_error} MATCHES ".*extension not supported: GL_NV_cooperative_matrix2.*") + message(STATUS "GL_NV_cooperative_matrix2 not supported by glslc") + set(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT OFF) + else() + message(STATUS "GL_NV_cooperative_matrix2 supported by glslc") + set(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT ON) add_compile_definitions(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) endif() @@ -78,8 +64,10 @@ if (Vulkan_FOUND) if (${glslc_error} MATCHES ".*extension not supported: GL_EXT_integer_dot_product.*") message(STATUS "GL_EXT_integer_dot_product not supported by glslc") + set(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT OFF) else() message(STATUS "GL_EXT_integer_dot_product supported by glslc") + set(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT ON) add_compile_definitions(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) endif() @@ -153,6 +141,7 @@ if (Vulkan_FOUND) -DCMAKE_INSTALL_PREFIX=${CMAKE_BINARY_DIR} -DGGML_VULKAN_COOPMAT_GLSLC_SUPPORT=${GGML_VULKAN_COOPMAT_GLSLC_SUPPORT} -DGGML_VULKAN_COOPMAT2_GLSLC_SUPPORT=${GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT} + -DGGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT=${GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT} BUILD_COMMAND ${CMAKE_COMMAND} --build . INSTALL_COMMAND ${CMAKE_COMMAND} --install . INSTALL_DIR ${CMAKE_BINARY_DIR} diff --git a/ggml/src/ggml-vulkan/cmake/host-toolchain.cmake.in b/ggml/src/ggml-vulkan/cmake/host-toolchain.cmake.in index b6af747a500..2d8a85696d3 100644 --- a/ggml/src/ggml-vulkan/cmake/host-toolchain.cmake.in +++ b/ggml/src/ggml-vulkan/cmake/host-toolchain.cmake.in @@ -4,8 +4,8 @@ set(CMAKE_CXX_FLAGS -O2) set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER) set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY NEVER) set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE NEVER) -set(CMAKE_C_COMPILER @HOST_C_COMPILER@) -set(CMAKE_CXX_COMPILER @HOST_CXX_COMPILER@) +set(CMAKE_C_COMPILER "@HOST_C_COMPILER@") +set(CMAKE_CXX_COMPILER "@HOST_CXX_COMPILER@") set(CMAKE_RUNTIME_OUTPUT_DIRECTORY @CMAKE_RUNTIME_OUTPUT_DIRECTORY@) if("@CMAKE_C_COMPILER_ID@" STREQUAL "MSVC") diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index ee0969fe189..c0bdb9e17a7 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -24,6 +24,28 @@ #include #include +#if defined(_MSC_VER) +# define NOMINMAX 1 +# include +# define YIELD() YieldProcessor() +#elif defined(__clang__) || defined(__GNUC__) +# if defined(__x86_64__) ||defined(__i386__) +# include +# define YIELD() _mm_pause() +# elif defined(__arm__) || defined(__aarch64__) +# if defined(__clang__) +# include +# define YIELD() __yield() +# else +# define YIELD() asm volatile("yield") +# endif +# endif +#endif + +#if !defined(YIELD) +#define YIELD() +#endif + #include "ggml-impl.h" #include "ggml-backend-impl.h" @@ -31,6 +53,7 @@ #define ROUNDUP_POW2(M, N) (((M) + (N) - 1) & ~((N) - 1)) #define CEIL_DIV(M, N) (((M) + (N)-1) / (N)) +static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; } #define VK_VENDOR_ID_AMD 0x1002 #define VK_VENDOR_ID_APPLE 0x106b @@ -223,6 +246,7 @@ struct vk_device_struct { bool pipeline_robustness; vk::Device device; uint32_t vendor_id; + vk::DriverId driver_id; vk_device_architecture architecture; vk_queue 
compute_queue; vk_queue transfer_queue; @@ -352,6 +376,7 @@ struct vk_device_struct { vk_pipeline pipeline_flash_attn_f32_f16_D112[GGML_TYPE_COUNT][2][2][2]; vk_pipeline pipeline_flash_attn_f32_f16_D128[GGML_TYPE_COUNT][2][2][2]; vk_pipeline pipeline_flash_attn_f32_f16_D256[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_split_k_reduce; std::unordered_map pipelines; std::unordered_map pipeline_descriptor_set_requirements; @@ -501,6 +526,10 @@ struct vk_flash_attn_push_constants { uint32_t n_head_log2; float m0; float m1; + + uint32_t gqa_ratio; + uint32_t split_kv; + uint32_t k_num; }; struct vk_op_push_constants { @@ -781,7 +810,8 @@ struct ggml_backend_vk_context { ggml_vk_garbage_collector gc; size_t prealloc_size_x, prealloc_size_y, prealloc_size_split_k; vk_buffer prealloc_x, prealloc_y, prealloc_split_k; - vk::Fence fence; + vk::Fence fence, almost_ready_fence; + bool almost_ready_fence_pending {}; vk_buffer buffer_pool[MAX_VK_BUFFERS]; @@ -872,6 +902,39 @@ typedef void (*ggml_vk_func_t)(ggml_backend_vk_context * ctx, vk_context& subctx static void ggml_backend_vk_free(ggml_backend_t backend); +// Wait for ctx->fence to be signaled. +static void ggml_vk_wait_for_fence(ggml_backend_vk_context * ctx) { + // Use waitForFences while most of the graph executes. Hopefully the CPU can sleep + // during this wait. + if (ctx->almost_ready_fence_pending) { + VK_CHECK(ctx->device->device.waitForFences({ ctx->almost_ready_fence }, true, UINT64_MAX), "almost_ready_fence"); + ctx->device->device.resetFences({ ctx->almost_ready_fence }); + ctx->almost_ready_fence_pending = false; + } + + // Spin (w/pause) waiting for the graph to finish executing. + vk::Result result; + while ((result = ctx->device->device.getFenceStatus(ctx->fence)) != vk::Result::eSuccess) { + if (result != vk::Result::eNotReady) { + fprintf(stderr, "ggml_vulkan: error %s at %s:%d\n", to_string(result).c_str(), __FILE__, __LINE__); + exit(1); + } + for (uint32_t i = 0; i < 100; ++i) { + YIELD(); + YIELD(); + YIELD(); + YIELD(); + YIELD(); + YIELD(); + YIELD(); + YIELD(); + YIELD(); + YIELD(); + } + } + ctx->device->device.resetFences({ ctx->fence }); +} + // variables to track number of compiles in progress static uint32_t compile_count = 0; static std::mutex compile_count_mutex; @@ -1473,7 +1536,7 @@ static std::array fa_rows_cols(uint32_t D, uint32_t clamp, ggml_typ // small rows, large cols if (small_rows) { - return {flash_attention_num_small_rows, 128}; + return {flash_attention_num_small_rows, 64}; } // small cols to reduce register count if (ggml_is_quantized(type) || D == 256) { @@ -1674,19 +1737,14 @@ static void ggml_vk_load_shaders(vk_device& device) { m_warptile_mmq = { 128, 64, 64, 32, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 }; s_warptile_mmq = { subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 }; - const uint32_t tm_int_l = device->coopmat_int_support ? device->coopmat_int_m : 4; - const uint32_t tm_int_m = device->coopmat_int_support ? device->coopmat_int_m : 4; - const uint32_t tm_int_s = device->coopmat_int_support ? device->coopmat_int_m : 2; - const uint32_t tn_int_l = device->coopmat_int_support ? device->coopmat_int_n : 4; - const uint32_t tn_int_m = device->coopmat_int_support ? device->coopmat_int_n : 2; - const uint32_t tn_int_s = device->coopmat_int_support ? device->coopmat_int_n : 2; - const uint32_t tk_int_l = device->coopmat_int_support ? device->coopmat_int_k : 1; - const uint32_t tk_int_m = device->coopmat_int_support ? 
device->coopmat_int_k : 1; - const uint32_t tk_int_s = device->coopmat_int_support ? device->coopmat_int_k : 1; + l_warptile_mmq_int = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 }; + m_warptile_mmq_int = { 128, 64, 64, 32, subgroup_size_8, 32, 2, 2, 2, 1, subgroup_size_8 }; + s_warptile_mmq_int = { subgroup_size_32, 32, 32, 32, 32, 32, 2, 2, 1, 1, subgroup_size_8 }; - l_warptile_mmq_int = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, tm_int_l, tn_int_l, tk_int_l, subgroup_size_8 }; - m_warptile_mmq_int = { 128, 64, 64, 32, subgroup_size_8, 32, 2, tm_int_m, tn_int_m, tk_int_m, subgroup_size_8 }; - s_warptile_mmq_int = { subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_int_s, tn_int_s, tk_int_s, subgroup_size_8 }; + // chip specific tuning + if ((device->architecture == AMD_GCN) && (device->driver_id != vk::DriverId::eAmdProprietary)) { + m_warptile_mmq = m_warptile_mmq_int = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 }; + } l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 }; m_mmq_wg_denoms = m_wg_denoms = { 64, 64, 1 }; @@ -1781,6 +1839,8 @@ static void ggml_vk_load_shaders(vk_device& device) { // can't use 256 for D==80. uint32_t wg_size = (small_rows && (D % 32) == 0) ? 256 : 128; auto rows_cols = fa_rows_cols(D, clamp, type, small_rows); + // mask dim1 is padded to 64, we rely on this to avoid clamping mask loads + GGML_ASSERT((GGML_KQ_MASK_PAD % rows_cols[0]) == 0); return {wg_size, rows_cols[0], rows_cols[1], (D), clamp}; }; @@ -2329,6 +2389,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 2, 3 * sizeof(uint32_t), {1, 1, 1}, {}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1, "quantize_q8_1", quantize_q8_1_len, quantize_q8_1_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1); for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) { @@ -2342,7 +2403,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, 
sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); @@ -2603,6 +2664,7 @@ static vk_device ggml_vk_get_device(size_t idx) { device->physical_device.getProperties2(&props2); device->properties = props2.properties; device->vendor_id = device->properties.vendorID; + device->driver_id = driver_props.driverID; const char* GGML_VK_FORCE_MAX_ALLOCATION_SIZE = getenv("GGML_VK_FORCE_MAX_ALLOCATION_SIZE"); @@ -3348,6 +3410,7 @@ static void ggml_vk_init(ggml_backend_vk_context * ctx, size_t idx) { ctx->prealloc_size_split_k = 0; ctx->fence = ctx->device->device.createFence({}); + ctx->almost_ready_fence = ctx->device->device.createFence({}); #ifdef GGML_VULKAN_CHECK_RESULTS const char* skip_checks = getenv("GGML_VULKAN_SKIP_CHECKS"); @@ -4138,6 +4201,12 @@ static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, int m, int if (split_k == 3) { split_k = 2; } + if (ctx->device->coopmat2) { + // coopmat2 shader expects splits to be aligned to 256 + while (split_k > 1 && ((k / split_k) % 256) != 0) { + split_k /= 2; + } + } } } @@ -5402,7 +5471,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx const uint32_t nbm1 = mask ? mask->nb[1] : 0; const uint32_t D = neq0; - const uint32_t N = neq1; + uint32_t N = neq1; const uint32_t KV = nek1; GGML_ASSERT(ne0 == D); @@ -5457,12 +5526,60 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx // the "aligned" shader variant will forcibly align strides, for performance (q_stride & 7) == 0 && (k_stride & 7) == 0 && (v_stride & 7) == 0; + // mask dim1 is padded to 64, we rely on this to avoid clamping mask loads + GGML_ASSERT((nem1 % GGML_KQ_MASK_PAD) == 0); + vk_pipeline pipeline = pipelines[aligned]; assert(pipeline); + uint32_t gqa_ratio = 1; + uint32_t qk_ratio = neq2 / nek2; + uint32_t workgroups_x = (uint32_t)neq1; + uint32_t workgroups_y = (uint32_t)neq2; + uint32_t workgroups_z = (uint32_t)neq3; + + if (N == 1 && qk_ratio > 1 && gqa_ratio <= flash_attention_num_small_rows && + qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1) { + // grouped query attention - make the N dimension equal to gqa_ratio, reduce + // workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1 + // and change addressing calculations to index Q's dimension 2. + gqa_ratio = qk_ratio; + N = gqa_ratio; + workgroups_y /= N; + } + + uint32_t split_kv = KV; + uint32_t split_k = 1; + + // Try to use split_k when KV is large enough to be worth the overhead + if (workgroups_x == 1 && ctx->device->shader_core_count > 0 && KV >= 512) { + // Try to run two workgroups per SM. + split_k = ctx->device->shader_core_count * 2 / workgroups_y; + if (split_k > 1) { + // Try to evenly split KV into split_k chunks, but it needs to be a multiple + // of "align", so recompute split_k based on that. + split_kv = ROUNDUP_POW2(KV / split_k, pipelines[1]->align); + split_k = CEIL_DIV(KV, split_kv); + workgroups_x = split_k; + } + } + + // Reserve space for split_k temporaries. For each split, we need to store the O matrix (D x ne1) + // and the per-row m and L values (ne1 rows). + const uint64_t split_k_size = split_k > 1 ? 
(D * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k : 0; + if (split_k_size > ctx->device->max_memory_allocation_size) { + GGML_ABORT("Requested preallocation size is too large"); + } + if (ctx->prealloc_size_split_k < split_k_size) { + ctx->prealloc_size_split_k = split_k_size; + } + if (dryrun) { // Request descriptor sets ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1); + if (split_k > 1) { + ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_flash_attn_split_k_reduce, 1); + } return; } @@ -5483,8 +5600,6 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); - ggml_vk_sync_buffers(subctx); - vk_buffer d_Q = nullptr, d_K = nullptr, d_V = nullptr, d_D = nullptr, d_M = nullptr; size_t q_buf_offset = 0, k_buf_offset = 0, v_buf_offset = 0, d_buf_offset = 0, m_buf_offset = 0; @@ -5549,16 +5664,45 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx v_stride, (uint32_t)nbv2, (uint32_t)nbv3, nbm1, scale, max_bias, logit_softcap, - mask != nullptr, n_head_log2, m0, m1 }; - ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, - { - vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE}, - vk_subbuffer{d_K, k_buf_offset, VK_WHOLE_SIZE}, - vk_subbuffer{d_V, v_buf_offset, VK_WHOLE_SIZE}, - vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE}, - vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE}, - }, - sizeof(vk_flash_attn_push_constants), &pc, { (uint32_t)neq1, (uint32_t)neq2, (uint32_t)neq3 }); + mask != nullptr, n_head_log2, m0, m1, + gqa_ratio, split_kv, split_k }; + + ggml_vk_sync_buffers(subctx); + + if (split_k > 1) { + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, + { + vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE}, + vk_subbuffer{d_K, k_buf_offset, VK_WHOLE_SIZE}, + vk_subbuffer{d_V, v_buf_offset, VK_WHOLE_SIZE}, + vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE}, + vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE}, + }, + // We only use split_k when group query attention is enabled, which means + // there's no more than one tile of rows (i.e. workgroups_x would have been + // one). We reuse workgroups_x to mean the number of splits, so we need to + // cancel out the divide by wg_denoms[0]. 
+ sizeof(vk_flash_attn_push_constants), &pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z }); + + ggml_vk_sync_buffers(subctx); + const std::array pc2 = { D, (uint32_t)ne1, split_k }; + ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce, + { + vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE}, + vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE}, + }, + pc2.size() * uint32_t{sizeof(uint32_t)}, pc2.data(), { (uint32_t)ne1, 1, 1 }); + } else { + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, + { + vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE}, + vk_subbuffer{d_K, k_buf_offset, VK_WHOLE_SIZE}, + vk_subbuffer{d_V, v_buf_offset, VK_WHOLE_SIZE}, + vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE}, + vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE}, + }, + sizeof(vk_flash_attn_push_constants), &pc, { workgroups_x, workgroups_y, workgroups_z }); + } } static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op) { @@ -5612,7 +5756,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const } return nullptr; case GGML_OP_UPSCALE: - if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && dst->op_params[0] == GGML_SCALE_MODE_NEAREST) { return ctx->device->pipeline_upscale_f32; } return nullptr; @@ -5869,6 +6013,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) { case GGML_OP_REPEAT: case GGML_OP_REPEAT_BACK: case GGML_OP_ROPE: + case GGML_OP_RMS_NORM: return true; default: return false; @@ -6079,7 +6224,6 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co switch (op) { case GGML_OP_NORM: - case GGML_OP_RMS_NORM: case GGML_OP_RMS_NORM_BACK: case GGML_OP_L2_NORM: case GGML_OP_SOFT_MAX: @@ -6096,6 +6240,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co elements = { nr, 1, 1 }; } } break; + case GGML_OP_RMS_NORM: + elements = { (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne03 }; + break; + case GGML_OP_SUM: // We use GGML_OP_SUM_ROWS with 1 row. 
elements = { 1, 1, 1 }; @@ -6746,7 +6894,17 @@ static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { float * op_params = (float *)dst->op_params; - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun); + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + op_params[0], 0.0f, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + }, dryrun); } static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { @@ -7786,7 +7944,7 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) { 128, 49, 49, 4096, 49, 4096, }; - const size_t num_it = 1; + const size_t num_it = 100; ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q4_0); ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q4_0); @@ -7880,11 +8038,11 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) { } } -static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_tensor* tensor, int tensor_idx, bool use_fence); +static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_tensor* tensor, int tensor_idx, bool use_fence, bool almost_ready); // Returns true if node has enqueued work into the queue, false otherwise // If submit is true the current all operations queued so far are being submitted to Vulkan to overlap cmdlist creation and GPU execution. 
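The almost_ready flag threaded through ggml_vk_build_graph and ggml_vk_compute_forward below drives a two-phase wait: the CPU first blocks (and can sleep) on almost_ready_fence, which is signaled once most of the graph has executed, and then spins only for the short remaining tail guarded by the final fence (see ggml_vk_wait_for_fence earlier in this file). A minimal, self-contained C++ sketch of that pattern, with illustrative names that are not part of the backend:

#include <atomic>
#include <condition_variable>
#include <mutex>
#include <thread>

struct two_phase_signal {
    std::mutex              mtx;
    std::condition_variable cv;
    bool                    almost_ready = false; // plays the role of almost_ready_fence
    std::atomic<bool>       done{false};          // plays the role of the final fence
};

void wait_two_phase(two_phase_signal & s) {
    // Phase 1: cheap blocking wait while the bulk of the work executes (CPU may sleep).
    {
        std::unique_lock<std::mutex> lock(s.mtx);
        s.cv.wait(lock, [&] { return s.almost_ready; });
    }
    // Phase 2: spin with a yield so the final completion is observed with low latency.
    while (!s.done.load(std::memory_order_acquire)) {
        std::this_thread::yield();
    }
}

The backend's version uses vk::Device::waitForFences for the first phase and getFenceStatus polling for the second.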
-static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * node, int node_idx, ggml_tensor *node_begin, int node_idx_begin, bool dryrun, bool last_node, bool submit){ +static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * node, int node_idx, ggml_tensor *node_begin, int node_idx_begin, bool dryrun, bool last_node, bool almost_ready, bool submit){ if (ggml_is_empty(node) || !node->buffer) { return false; } @@ -8256,7 +8414,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod ctx->compute_ctx.reset(); - bool ok = ggml_vk_compute_forward(ctx, node_begin, node_idx_begin, false); + bool ok = ggml_vk_compute_forward(ctx, node_begin, node_idx_begin, false, almost_ready); if (!ok) { if (node->op == GGML_OP_UNARY) { std::cerr << __func__ << ": error: op not supported UNARY " << node->name << " (" << ggml_unary_op_name(static_cast(node->op_params[0])) << ")" << std::endl; @@ -8270,7 +8428,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod return true; } -static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * tensor, int tensor_idx, bool use_fence = true){ +static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * tensor, int tensor_idx, bool use_fence = true, bool almost_ready = false) { ggml_backend_buffer * buf = nullptr; switch (tensor->op) { @@ -8373,12 +8531,15 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * memcpy(cpy.dst, cpy.src, cpy.n); } - ggml_vk_submit(subctx, use_fence ? ctx->fence : vk::Fence{}); + if (almost_ready && !ctx->almost_ready_fence_pending && !use_fence) { + ggml_vk_submit(subctx, ctx->almost_ready_fence); + ctx->almost_ready_fence_pending = true; + } else { + ggml_vk_submit(subctx, use_fence ? 
ctx->fence : vk::Fence{}); + } if (use_fence) { - VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_compute_forward waitForFences"); - - ctx->device->device.resetFences({ ctx->fence }); + ggml_vk_wait_for_fence(ctx); } #ifdef GGML_VULKAN_CHECK_RESULTS ggml_vk_check_results_1(tensor); @@ -8464,6 +8625,7 @@ static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) { ctx->gc.events.clear(); ctx->device->device.destroyFence(ctx->fence); + ctx->device->device.destroyFence(ctx->almost_ready_fence); } static int ggml_vk_get_device_count() { @@ -8810,8 +8972,7 @@ static void ggml_backend_vk_synchronize(ggml_backend_t backend) { } ggml_vk_submit(transfer_ctx, ctx->fence); - VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_backend_vk_synchronize waitForFences"); - ctx->device->device.resetFences({ ctx->fence }); + ggml_vk_wait_for_fence(ctx); for (auto& cpy : transfer_ctx->out_memcpys) { memcpy(cpy.dst, cpy.src, cpy.n); @@ -8830,7 +8991,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg uint64_t total_mat_mul_bytes = 0; for (int i = 0; i < cgraph->n_nodes; i++) { - ggml_vk_build_graph(ctx, cgraph->nodes[i], i, nullptr, 0, true, false, false); + ggml_vk_build_graph(ctx, cgraph->nodes[i], i, nullptr, 0, true, false, false, false); if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) { total_mat_mul_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]); } @@ -8872,11 +9033,14 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg mul_mat_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]); } + // Signal the almost_ready fence when the graph is mostly complete (< 20% remaining) + bool almost_ready = (cgraph->n_nodes - i) < cgraph->n_nodes / 5; bool submit = (submitted_nodes >= nodes_per_submit) || (mul_mat_bytes >= mul_mat_bytes_per_submit) || - (i == last_node); + (i == last_node) || + (almost_ready && !ctx->almost_ready_fence_pending); - bool enqueued = ggml_vk_build_graph(ctx, cgraph->nodes[i], i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i == last_node, submit); + bool enqueued = ggml_vk_build_graph(ctx, cgraph->nodes[i], i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i == last_node, almost_ready, submit); if (enqueued) { ++submitted_nodes; @@ -8888,7 +9052,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg #endif } - if (submit) { + if (submit && enqueued) { first_node_in_batch = true; submitted_nodes = 0; mul_mat_bytes = 0; @@ -9118,6 +9282,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case 112: case 128: case 256: + case 575: // DeepSeek MLA break; default: return false; @@ -9244,10 +9409,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_VIEW: case GGML_OP_PERMUTE: case GGML_OP_TRANSPOSE: + case GGML_OP_RMS_NORM: return true; case GGML_OP_NORM: case GGML_OP_GROUP_NORM: - case GGML_OP_RMS_NORM: case GGML_OP_L2_NORM: return ggml_is_contiguous(op->src[0]); case GGML_OP_ADD: @@ -9261,9 +9426,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_COS: case GGML_OP_CLAMP: return op->src[0]->type == GGML_TYPE_F32; + case GGML_OP_UPSCALE: + return op->op_params[0] == GGML_SCALE_MODE_NEAREST; case GGML_OP_ACC: case GGML_OP_CONCAT: - case GGML_OP_UPSCALE: case GGML_OP_SCALE: case GGML_OP_PAD: case GGML_OP_DIAG_MASK_INF: @@ -9631,7 +9797,7 @@ static void 
ggml_vk_check_results_0(ggml_tensor * tensor) { } else if (tensor->op == GGML_OP_CONCAT) { tensor_clone = ggml_concat(ggml_ctx, src_clone[0], src_clone[1], *(int *)tensor->op_params); } else if (tensor->op == GGML_OP_UPSCALE) { - tensor_clone = ggml_upscale_ext(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]); + tensor_clone = ggml_upscale_ext(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], (ggml_scale_mode) tensor->op_params[0]); } else if (tensor->op == GGML_OP_SCALE) { const float * params = (const float *)tensor->op_params; tensor_clone = ggml_scale(ggml_ctx, src_clone[0], params[0]); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt b/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt index b1e1750219f..d6e0b2a5a5d 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +++ b/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt @@ -1,3 +1,6 @@ +cmake_minimum_required(VERSION 3.19) +project("vulkan-shaders-gen" C CXX) + find_package (Threads REQUIRED) if (GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) @@ -6,6 +9,9 @@ endif() if (GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) add_compile_definitions(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) endif() +if (GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + add_compile_definitions(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) +endif() set(TARGET vulkan-shaders-gen) add_executable(${TARGET} vulkan-shaders-gen.cpp) install(TARGETS ${TARGET} RUNTIME) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp index b3fad35e21d..962d2353f88 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp @@ -167,6 +167,101 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4 block_q4_K_packed128 block; }; +#if defined(IS_MUL_MM2) + +// For Q4_K and Q5_K in the mat-mul shader, we decode a tile's worth of scales +// into shared memory and then process the whole tile using those scales. +// There is a fetch function that loads into private variables and then a store +// function that stores into shared memory. +// Q4_K and Q5_K have the same encoding of scales, so everything is shared except +// the part that fetches from the structure (which has a different block layout).
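As a reading aid, the 6-bit scale/min unpacking that the fetch/store helpers below perform can be restated on the host as follows. The three 32-bit words correspond to the 12 packed scale bytes of a Q4_K/Q5_K block; the function name and the little-endian/contiguous-bytes assumptions are illustrative only:

#include <cstdint>
#include <cstring>
#include <utility>

// Returns the (scale, min) pair for sub-block 'is' (0..7) of a K-quant block,
// given the block's 12 packed scale bytes and its super-scales d and dmin.
std::pair<float, float> decode_kquant_scale(const uint8_t * scales, float d, float dmin, uint32_t is) {
    uint32_t sc0, sc4, sc8;
    std::memcpy(&sc0, scales + 0, 4);
    std::memcpy(&sc4, scales + 4, 4);
    std::memcpy(&sc8, scales + 8, 4);

    // Same bit manipulation as the shader's store_scales helper, 4 bytes at a time.
    const uint32_t sc_lo = sc0;
    const uint32_t mb_lo = sc4;
    const uint32_t sc_hi = (sc8 & 0x0F0F0F0Fu) | ((sc0 & 0xC0C0C0C0u) >> 2);
    const uint32_t mb_hi = ((sc8 & 0xF0F0F0F0u) >> 4) | ((sc4 & 0xC0C0C0C0u) >> 2);

    uint32_t sc    = (is < 4 ? sc_lo : sc_hi) >> (8 * (is & 3));
    uint32_t mbyte = (is < 4 ? mb_lo : mb_hi) >> (8 * (is & 3));
    sc    &= 0x3F;
    mbyte &= 0x3F;

    return { d * float(sc), dmin * float(mbyte) };
}

The shader performs exactly this arithmetic, but amortizes it across the rows of a BM tile and caches the results in shared memory so each dequant call only does a shared-memory lookup.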
+#if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K) +const uint shAscales_stride = (BM + 2); +// 1 scale per 32 elements -> 8 scales per block, per row +shared vec2 shAscales[8 * shAscales_stride]; +uvec4 row_v; +#endif + +#if defined(DATA_A_Q4_K) +layout (binding = 0) readonly buffer A_Q4_K_128 {block_q4_K_packed128 data_a_q4_k_packed128[];}; + +void fetch_scalesQ4_K(uint ir_BM, uint pos_a, uint stride_a, uint block_k, uint tid, bool in_bounds) +{ + uint tids_per_row = BLOCK_SIZE / BM; + uint is_per_tid = 8 / tids_per_row; + uint is_start = is_per_tid * (tid % tids_per_row); + uint tid_row = tid / tids_per_row; + + uint row = ir_BM + tid_row; + uint block_index = pos_a + row * stride_a + (block_k / QUANT_K); + if (in_bounds || row < p.M) { + row_v = data_a_q4_k_packed128[block_index].q4k[0]; + } +} +#endif +#if defined(DATA_A_Q5_K) +layout (binding = 0) readonly buffer A_Q5_K_128 {block_q5_K_packed128 data_a_q5_k_packed128[];}; + +void fetch_scalesQ5_K(uint ir_BM, uint pos_a, uint stride_a, uint block_k, uint tid, bool in_bounds) +{ + uint tids_per_row = BLOCK_SIZE / BM; + uint is_per_tid = 8 / tids_per_row; + uint is_start = is_per_tid * (tid % tids_per_row); + uint tid_row = tid / tids_per_row; + + uint row = ir_BM + tid_row; + uint block_index = pos_a + row * stride_a + (block_k / QUANT_K); + if (in_bounds || row < p.M) { + row_v = data_a_q5_k_packed128[block_index].q5k[0]; + } +} +#endif + +#if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K) +void store_scalesQ4_K(uint tid) +{ + barrier(); + + uint tids_per_row = BLOCK_SIZE / BM; + uint is_per_tid = 8 / tids_per_row; + uint is_start = is_per_tid * (tid % tids_per_row); + uint tid_row = tid / tids_per_row; + + [[unroll]] for (uint idx = 0; idx < is_per_tid; ++idx) { + uint is = idx + is_start; + uvec4 v = row_v; + const vec2 loadd = vec2(unpackFloat2x16(v.x)); + + uint32_t sc; + uint32_t mbyte; + + uint32_t scale0 = v.y; + uint32_t scale4 = v.z; + uint32_t scale8 = v.w; + + uint32_t sc_lo = scale0; + uint32_t mb_lo = scale4; + uint32_t sc_hi = (scale8 & 0x0F0F0F0F) | ((scale0 & 0xC0C0C0C0) >> 2); + uint32_t mb_hi = ((scale8 & 0xF0F0F0F0) >> 4) | ((scale4 & 0xC0C0C0C0) >> 2); + + sc = is < 4 ? sc_lo : sc_hi; + mbyte = is < 4 ? 
mb_lo : mb_hi; + sc = sc >> (8 * (is & 3)); + mbyte = mbyte >> (8 * (is & 3)); + sc &= 0x3F; + mbyte &= 0x3F; + + const float d = loadd.x * float(sc); + const float m = loadd.y * float(mbyte); + shAscales[is * shAscales_stride + tid_row] = vec2(d,m); + } + + barrier(); +} +#endif + +#endif + float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) { decodeBufQ4_K_packed16 bl16 = decodeBufQ4_K_packed16(bl); @@ -176,8 +271,12 @@ float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2 const uint b = (idx & 0x20) >> 5; // 0,1 const uint is = (idx & 0xE0) >> 5; // 0..7 +#if defined(IS_MUL_MM2) && defined(DATA_A_Q4_K) + vec2 v = shAscales[is * shAscales_stride + (blockCoords[0] % BM)]; + float d = v.x; + float m = v.y; +#else uvec4 v = bl128.block.q4k[0]; - const vec2 loadd = vec2(unpackFloat2x16(v.x)); uint32_t sc; @@ -201,6 +300,7 @@ float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2 const float d = loadd.x * float(sc); const float m = loadd.y * float(mbyte); +#endif uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]); qs = (qs >> (b * 4 + 8 * (idx & 1))) & 0xF; @@ -231,6 +331,11 @@ float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2 const uint b = (idx & 0x20) >> 5; // 0,1 const uint is = (idx & 0xE0) >> 5; // 0..7 +#if defined(IS_MUL_MM2) && defined(DATA_A_Q5_K) + vec2 v = shAscales[is * shAscales_stride + (blockCoords[0] % BM)]; + float d = v.x; + float m = v.y; +#else uvec4 v = bl128.block.q5k[0]; const f16vec2 loadd = unpackFloat2x16(v.x); @@ -256,6 +361,7 @@ float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2 const float16_t d = loadd.x * float16_t(sc); const float16_t m = loadd.y * float16_t(mbyte); +#endif uint qh = uint32_t(bl16.block.qh[(idx & 0x1E) >> 1]); qh = ((qh >> is) & 0x101) << 4; @@ -264,9 +370,9 @@ float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2 qs = (qs >> (b * 4)) & 0x0F0F; qs = unpack8(qs | qh)[idx & 1]; - float16_t ret = d * (float16_t(qs)) - m; + float ret = d * float(qs) - m; - return ret; + return float16_t(ret); } layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ6_K { @@ -564,8 +670,12 @@ float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoor #define dequantFuncA dequantFuncQ3_K #elif defined(DATA_A_Q4_K) #define dequantFuncA dequantFuncQ4_K +#define fetch_scales fetch_scalesQ4_K +#define store_scales store_scalesQ4_K #elif defined(DATA_A_Q5_K) #define dequantFuncA dequantFuncQ5_K +#define fetch_scales fetch_scalesQ5_K +#define store_scales store_scalesQ4_K #elif defined(DATA_A_Q6_K) #define dequantFuncA dequantFuncQ6_K #elif defined(DATA_A_IQ1_S) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp index df30355f635..b926a578ade 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp @@ -61,6 +61,10 @@ layout (push_constant) uniform parameter { uint32_t n_head_log2; float m0; float m1; + + uint32_t gqa_ratio; + uint32_t split_kv; + uint32_t k_num; } p; layout (binding = 0) readonly buffer Q {uint8_t data_q[];}; @@ -103,6 +107,38 @@ ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE ele #define DECODEFUNC #endif +// Store the output when doing grouped query attention. 
+// Rows index by Q's dimension 2, and the first N rows are valid. +D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) +{ + if (r < N && c < D) { + uint32_t offset = (iq2 + r) * D + c; + data_o[o_offset + offset] = D_TYPE(elem); + } + return elem; +} + +// Store column zero. This is used to save per-row m and L values for split_k. +ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) +{ + if (r < N && c == 0) { + uint32_t offset = iq2 + r; + data_o[o_offset + offset] = D_TYPE(elem); + } + return elem; +} + +// Load the slope matrix, indexed by Q's dimension 2. +ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2) +{ + const uint32_t h = iq2 + (r % p.gqa_ratio); + + const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1); + const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1); + + return ACC_TYPE(pow(base, ACC_TYPE(exph))); +} + void main() { #ifdef NEEDS_INIT_IQ_SHMEM init_iq_shmem(gl_WorkGroupSize); @@ -111,12 +147,22 @@ void main() { const uint32_t N = p.N; const uint32_t KV = p.KV; + uint32_t i = gl_WorkGroupID.x; + uint32_t split_k_index = 0; + + if (p.k_num > 1) { + i = 0; + split_k_index = gl_WorkGroupID.x; + } + const uint32_t Tr = CEIL_DIV(N, Br); - const uint32_t Tc = CEIL_DIV(KV, Bc); - const uint32_t i = gl_WorkGroupID.x; + const uint32_t start_j = split_k_index * p.split_kv / Bc; + const uint32_t end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc); - const uint32_t iq2 = gl_WorkGroupID.y; + // When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y. + // When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2. + const uint32_t iq2 = gl_WorkGroupID.y * p.gqa_ratio; const uint32_t iq3 = gl_WorkGroupID.z; // broadcast factors @@ -149,10 +195,17 @@ void main() { tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, D); tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, D); - // nb?1 are already divided by the type size and are in units of elements - uint32_t q_stride = p.nb01; + // nb?1 are already divided by the type size and are in units of elements. + // When using grouped query attention, Q is indexed by iq2, so the stride + // should be nb02 (which is in bytes). + uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01; uint32_t k_stride = p.nb11; uint32_t v_stride = p.nb21; + // When using grouped query attention, all rows use the same mask (stride 0). + // "p.gqa_ratio >> 16" is just a roundabout way of writing zero + // that prevents the compiler from folding the "&" through the select + // and breaking the alignment detection. + uint32_t m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV; // hint to the compiler that strides are aligned for the aligned variant of the shader if (Clamp != gl_CooperativeMatrixClampModeConstantNV) { @@ -161,6 +214,7 @@ void main() { k_stride &= ~7; v_stride &= ~7; #endif + m_stride &= ~7; } tensorLayoutQ = setTensorLayoutStrideNV(tensorLayoutQ, q_stride, 1); tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1); @@ -179,23 +233,21 @@ void main() { coopmat L, M; + // Use -FLT_MAX/2 rather than -inf to reduce the possibility of NaNs, e.g. when computing Mold-M. 
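The choice of a large finite sentinel instead of -infinity matters when a row never sees an unmasked element: the running max then stays at its initial value and the rescale factor is computed from two identical sentinels. With -inf that difference is (-inf) - (-inf) = NaN, while the finite difference is 0 and exp(0) = 1 is harmless. A small standalone C++ check of this behaviour:

#include <cmath>
#include <cstdio>
#include <limits>

int main() {
    const float neg_inf  = -std::numeric_limits<float>::infinity();
    const float sentinel = -std::numeric_limits<float>::max() / 2.0f; // same value as uintBitsToFloat(0xFEFFFFFF)

    std::printf("exp(-inf - -inf) = %f\n", std::exp(neg_inf - neg_inf));   // prints nan
    std::printf("exp(sent - sent) = %f\n", std::exp(sentinel - sentinel)); // prints 1.000000
    return 0;
}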
+ const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF); + L = coopmat(0); - M = coopmat(-1.0/0.0); + M = coopmat(NEG_FLT_MAX_OVER_2); - ACC_TYPE slope = ACC_TYPE(1.0); + coopmat slopeMat = coopmat(1.0); // ALiBi if (p.max_bias > 0.0f) { - const uint32_t h = iq2; - - const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1); - const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1); - - slope = pow(base, ACC_TYPE(exph)); + coopMatPerElementNV(slopeMat, slopeMat, perElemOpComputeSlope, iq2); } [[dont_unroll]] - for (uint32_t j = 0; j < Tc; ++j) { + for (uint32_t j = start_j; j < end_j; ++j) { coopmat S = coopmat(0); @@ -213,14 +265,15 @@ void main() { } if (p.mask != 0) { - tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); + tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp); tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV); + tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1); coopmat mv; coopMatLoadTensorNV(mv, data_m, 0, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc)); - S += slope*coopmat(mv); + S += slopeMat*coopmat(mv); } // Clear padding elements to -inf, so they don't contribute to rowmax @@ -231,7 +284,7 @@ void main() { uint R = ((i + 1) * Br > N) ? (N % Br) : Br; uint C = ((j + 1) * Bc > KV) ? (KV % Bc) : Bc; - coopMatPerElementNV(S, S, replacePadding, ACC_TYPE(-1.0/0.0), R, C); + coopMatPerElementNV(S, S, replacePadding, ACC_TYPE(NEG_FLT_MAX_OVER_2), R, C); } coopmat rowmax, P, rowsum, eM; @@ -280,9 +333,25 @@ void main() { // resize eM by using smear/reduce coopMatReduceNV(eMdiag, eM, gl_CooperativeMatrixReduceRowNV, smearReduce); - O = eMdiag * O; + // multiply with fp16 accumulation, then add to O. + coopmat PV = coopmat(0); + PV = coopMatMulAdd(P_A, V, PV); - O = coopMatMulAdd(P_A, V, O); + O = eMdiag * O + coopmat(PV); + } + + // If there is split_k, then the split_k resolve shader does the final + // division by L. Store the intermediate O value and per-row m and L values. 
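For reference, the per-row combination performed by the new flash_attn_split_k_reduce shader (added later in this diff) can be written on the host as below: each split k contributes an unnormalized output O[k] of length D, a running max m[k] and a running sum L[k], and the partials are merged in a numerically stable way before a single division by the total L. This is an illustrative sketch, not the shader itself:

#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

// O: k_num x D unnormalized outputs; m/L: per-split running max and sum (k_num >= 1).
std::vector<float> reduce_split_k(const std::vector<std::vector<float>> & O,
                                  const std::vector<float> & m,
                                  const std::vector<float> & L) {
    const size_t k_num = O.size();
    const size_t D     = O[0].size();

    float m_max = -std::numeric_limits<float>::infinity();
    for (size_t k = 0; k < k_num; ++k) {
        m_max = std::max(m_max, m[k]);
    }

    float L_total = 0.0f;
    for (size_t k = 0; k < k_num; ++k) {
        L_total += std::exp(m[k] - m_max) * L[k];
    }

    std::vector<float> out(D, 0.0f);
    for (size_t k = 0; k < k_num; ++k) {
        const float w = std::exp(m[k] - m_max);
        for (size_t d = 0; d < D; ++d) {
            out[d] += w * O[k][d];
        }
    }
    for (size_t d = 0; d < D; ++d) {
        out[d] /= L_total;
    }
    return out;
}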
+ if (p.k_num > 1) { + coopmat O_D = coopmat(O); + + uint32_t o_offset = D * p.ne1 * split_k_index; + coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N); + + o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2; + coopMatPerElementNV(L, L, perElemOpStoreCol0, o_offset, iq2, N); + coopMatPerElementNV(M, M, perElemOpStoreCol0, o_offset + p.ne1, iq2, N); + return; } coopmat Ldiag; @@ -297,13 +366,18 @@ void main() { O = Ldiag*O; - tensorLayoutNV<3, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(3, gl_CooperativeMatrixClampModeConstantNV); - tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.ne2, p.ne1, D); - - // permute dimensions - tensorViewNV<3, false, 1, 0, 2> tensorViewPermute = createTensorViewNV(3, false, 1, 0, 2); uint32_t o_offset = iq3*p.ne2*p.ne1; coopmat O_D = coopmat(O); - coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, 1, 0, D), tensorViewPermute); + if (p.gqa_ratio > 1) { + coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N); + } else { + tensorLayoutNV<3, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(3, gl_CooperativeMatrixClampModeConstantNV); + tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.ne2, p.ne1, D); + + // permute dimensions + tensorViewNV<3, false, 1, 0, 2> tensorViewPermute = createTensorViewNV(3, false, 1, 0, 2); + + coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, N, 0, D), tensorViewPermute); + } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp new file mode 100644 index 00000000000..a7e3956854c --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp @@ -0,0 +1,59 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable + +#define BLOCK_SIZE 32 + +layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {float data_a[];}; +layout (binding = 1) writeonly buffer D {float data_d[];}; + +layout (push_constant) uniform parameter { + uint D; + uint N; + uint k_num; +} p; + +void main() { + // Each workgroup handles a row + const uint n = gl_WorkGroupID.x; + const uint tid = gl_LocalInvocationID.x; + + uint D = p.D; + uint N = p.N; + uint k_num = p.k_num; + + uint l_offset = D * N * k_num + n; + uint m_offset = D * N * k_num + N + n; + uint lm_stride = N * 2; + + // Compute the max m value for the row + float m_max = -1.0/0.0; + [[unroll]] for (uint k = 0; k < k_num; ++k) { + float m = data_a[m_offset + k * lm_stride]; + m_max = max(m_max, m); + } + + // Compute L based on m_max + float L = 0; + [[unroll]] for (uint k = 0; k < k_num; ++k) { + float l = data_a[l_offset + k * lm_stride]; + float m = data_a[m_offset + k * lm_stride]; + L += exp(m - m_max) * l; + } + + L = 1.0 / L; + + // Scale and sum the O contributions based on m_max and store the result to memory + for (uint d = tid; d < D; d += BLOCK_SIZE) { + float O = 0.0; + [[unroll]] for (uint k = 0; k < k_num; ++k) { + uint o_offset = D * N * k + D * n + d; + float m = data_a[m_offset + k * lm_stride]; + O += exp(m - m_max) * data_a[o_offset]; + } + O *= L; + data_d[D * n + d] = O; + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp index 7649febb071..06b7ab09ea5 100644 --- 
a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp @@ -19,6 +19,9 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; +#define IS_MUL_MM2 1 + +layout (constant_id = 0) const uint BLOCK_SIZE = 256; layout (constant_id = 1) const uint BM = 64; layout (constant_id = 2) const uint BN = 64; layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working with a quant @@ -70,6 +73,13 @@ layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; #define DECODEFUNCA #endif +#if !defined(fetch_scales) +#define fetch_scales(a, b, c, d, e, f) +#endif +#if !defined(store_scales) +#define store_scales(a) +#endif + #ifdef MUL_MAT_ID layout (binding = 3) readonly buffer IDS {int data_ids[];}; @@ -116,6 +126,8 @@ void main() { init_iq_shmem(gl_WorkGroupSize); #endif + const uint tid = gl_LocalInvocationIndex; + #ifdef MUL_MAT_ID const uint expert_idx = gl_GlobalInvocationID.z; #else @@ -218,14 +230,21 @@ void main() { tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0); #if !defined(MUL_MAT_ID) + + const uint START_ALIGN_K = 256; + // For Qi_K (block size 256), unroll whole 256 element tiles. + // For legacy quants (block size 32), unroll 8x. + const uint UNROLL_K = (QUANT_K == 256) ? 256 : (BK * 8); + const uint unroll_count = UNROLL_K / BK; + // Detect a fast path where all loads are entirely in bounds and no clamping is required - if ((ir + 1) * BM <= p.M && (ic + 1) * BN <= p.padded_N && (start_k % BK) == 0 && (end_k % BK) == 0 && + if ((ir + 1) * BM <= p.M && (ic + 1) * BN <= p.padded_N && (start_k % START_ALIGN_K) == 0 && (end_k % BK) == 0 && #if QUANT_K == 1 (stride_a % 8) == 0 && #endif - (stride_b % 8) == 0 && (start_k % 8) == 0) { + (stride_b % 8) == 0) { // Hint to the compiler that values are aligned (want 16B alignment) - start_k &= ~7; + start_k &= ~(START_ALIGN_K-1); stride_b &= ~7; #if QUANT_K == 1 stride_a &= ~7; @@ -234,11 +253,39 @@ void main() { tensorLayoutA = setTensorLayoutStrideNV(tensorLayoutA, stride_a, 1); tensorLayoutB = setTensorLayoutStrideNV(tensorLayoutB, stride_b, 1); - uint k_iters = (end_k - start_k + BK - 1) / BK; + uint k_iters = (end_k - start_k) / UNROLL_K; + uint block_k = start_k; + + // fetch scale values for a tile of quants. These will be copied into shared memory. + // The fetches and stores are pipelined to hide the latency. 
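The restructured loops below follow a standard software-pipelining shape: the scale fetch for tile i+1 is issued before the work that consumes tile i, so memory latency overlaps with the coopMatMulAdd work. A generic C++ sketch of that structure, with start_load/commit/compute as illustrative stand-ins for fetch_scales, store_scales and the matrix-multiply body:

#include <cstddef>

// start_load(i): kick off the (asynchronous) load of tile i into private registers
// commit(i):     publish the staged data for tile i (e.g. into shared memory)
// compute(i):    do the math that consumes tile i
void pipelined_tiles(size_t n_tiles,
                     void (*start_load)(size_t),
                     void (*commit)(size_t),
                     void (*compute)(size_t)) {
    if (n_tiles == 0) {
        return;
    }
    start_load(0); // prime the pipeline
    for (size_t i = 0; i < n_tiles; ++i) {
        commit(i);             // corresponds to store_scales
        if (i + 1 < n_tiles) {
            start_load(i + 1); // overlap the next fetch with this tile's work
        }
        compute(i);            // corresponds to the coopMatMulAdd body
    }
}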
+ fetch_scales(ir * BM, pos_a, stride_a, start_k, tid, true); + if (enable_smaller_matrices && ic * BN + BNover4 >= p.N) { coopmat sum = coopmat(0.0); - for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) { + for (uint i = 0; i < k_iters; ++i) { + + store_scales(tid); + if (block_k + UNROLL_K < end_k) { + fetch_scales(ir * BM, pos_a, stride_a, block_k + UNROLL_K, tid, true); + } + // Manually partial unroll + [[unroll]] for (uint j = 0; j < unroll_count; ++j) { + coopmat mat_a; + coopmat mat_b; + + coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose); + + sum = coopMatMulAdd(mat_a, mat_b, sum); + block_k += BK; + } + } + // Do any remaining iterations that were not unrolled + if (block_k < end_k) { + store_scales(tid); + } + while (block_k < end_k) { coopmat mat_a; coopmat mat_b; @@ -246,6 +293,7 @@ void main() { coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose); sum = coopMatMulAdd(mat_a, mat_b, sum); + block_k += BK; } coopmat mat_d = coopmat(sum); @@ -253,8 +301,30 @@ void main() { return; } else if (enable_smaller_matrices && ic * BN + BNover2 >= p.N) { coopmat sum = coopmat(0.0); - for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) { + for (uint i = 0; i < k_iters; ++i) { + + store_scales(tid); + if (block_k + UNROLL_K < end_k) { + fetch_scales(ir * BM, pos_a, stride_a, block_k + UNROLL_K, tid, true); + } + + // Manually partial unroll + [[unroll]] for (uint j = 0; j < unroll_count; ++j) { + coopmat mat_a; + coopmat mat_b; + coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose); + + sum = coopMatMulAdd(mat_a, mat_b, sum); + block_k += BK; + } + } + // Do any remaining iterations that were not unrolled + if (block_k < end_k) { + store_scales(tid); + } + while (block_k < end_k) { coopmat mat_a; coopmat mat_b; @@ -262,6 +332,7 @@ void main() { coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose); sum = coopMatMulAdd(mat_a, mat_b, sum); + block_k += BK; } coopmat mat_d = coopmat(sum); @@ -269,8 +340,31 @@ void main() { return; } else { coopmat sum = coopmat(0.0); - for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) { + for (uint i = 0; i < k_iters; ++i) { + + store_scales(tid); + if (block_k + UNROLL_K < end_k) { + fetch_scales(ir * BM, pos_a, stride_a, block_k + UNROLL_K, tid, true); + } + + // Manually partial unroll + [[unroll]] for (uint j = 0; j < unroll_count; ++j) { + coopmat mat_a; + coopmat mat_b; + + coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose); + + sum = coopMatMulAdd(mat_a, mat_b, sum); + block_k += BK; + } + } + // Do any remaining iterations that were not unrolled + if (block_k < end_k) { + store_scales(tid); + } + while (block_k < end_k) { coopmat mat_a; coopmat mat_b; @@ -278,6 +372,7 @@ void main() { coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, 
BN, block_k, BK), tensorViewTranspose); sum = coopMatMulAdd(mat_a, mat_b, sum); + block_k += BK; } coopmat mat_d = coopmat(sum); @@ -298,47 +393,29 @@ void main() { coopmat sum; sum = coopmat(0.0); + uint k_iters = (end_k - start_k + BK - 1) / BK; + + fetch_scales(ir * BM, pos_a, stride_a, start_k, tid, false); + [[dont_unroll]] - for (uint block_k = start_k; block_k < end_k; block_k += BK) { + for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) { + + store_scales(tid); + if (block_k + BK < end_k) { + fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false); + } coopmat mat_a; coopmat mat_b; - // Clamping is expensive, so detect different code paths for each combination - // of A and B needing clamping. - bool unclampedA = (ir + 1) * BM <= p.M && block_k + BK <= end_k && (block_k % 8) == 0; + coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA); #ifdef MUL_MAT_ID - bool unclampedB = true; + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB); #else - bool unclampedB = (ic + 1) * BN <= p.padded_N && block_k + BK <= end_k && (block_k % 8) == 0; + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose); #endif - if (unclampedA && unclampedB) { - coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, (block_k & ~7), BK) DECODEFUNCA); -#ifdef MUL_MAT_ID - coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB); -#else - coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, (block_k & ~7), BK), tensorViewTranspose); -#endif - sum = coopMatMulAdd(mat_a, mat_b, sum); - } else if (unclampedA && !unclampedB) { - coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, (block_k & ~7), BK) DECODEFUNCA); - coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose); - - sum = coopMatMulAdd(mat_a, mat_b, sum); - } else if (!unclampedA && unclampedB) { - coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA); -#ifdef MUL_MAT_ID - coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB); -#else - coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, (block_k & ~7), BK), tensorViewTranspose); -#endif - sum = coopMatMulAdd(mat_a, mat_b, sum); - } else if (!unclampedA && !unclampedB) { - coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA); - coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose); - sum = coopMatMulAdd(mat_a, mat_b, sum); - } + sum = coopMatMulAdd(mat_a, mat_b, sum); } // Convert from ACC_TYPE to D_TYPE diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp index 42f81356e8f..284a35caa68 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp @@ -234,9 +234,9 @@ void main() { #endif #if QUANT_AUXF == 1 - FLOAT_TYPE cache_a_dm[TM]; + FLOAT_TYPE 
cache_a_dm[WMITER * TM]; #else - FLOAT_TYPE_VEC2 cache_a_dm[TM]; + FLOAT_TYPE_VEC2 cache_a_dm[WMITER * TM]; #endif FLOAT_TYPE_VEC2 cache_b_ds[TN]; @@ -247,7 +247,6 @@ void main() { const uint iqs = loadr_a; const uint buf_ib = loadc_a + l; - // Should ds be gated to a single thread? if (iqs == 0) { #if QUANT_AUXF == 1 buf_a_dm[buf_ib] = get_d(ib); @@ -276,7 +275,6 @@ void main() { const uint buf_ib = loadc_b + l; - // Should ds be gated to a single thread? if (iqs == 0) { buf_b_ds[buf_ib] = FLOAT_TYPE_VEC2(data_b[ib].ds); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp index c4c35e105a7..63b15471bd3 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp @@ -17,7 +17,7 @@ i32vec2 repack(uint ib, uint iqs) { } ACC_TYPE mul_q8_1(int32_t q_sum, float da, vec2 dsb) { - return ACC_TYPE(da * (float(q_sum) * dsb.x - 8.0 * dsb.y)); + return ACC_TYPE(da * (float(q_sum) * dsb.x - 8.0f * dsb.y)); } #endif @@ -51,7 +51,7 @@ i32vec2 repack(uint ib, uint iqs) { } ACC_TYPE mul_q8_1(int32_t q_sum, float da, vec2 dsb) { - return ACC_TYPE(da * (float(q_sum) * dsb.x - 16.0 * dsb.y)); + return ACC_TYPE(da * (float(q_sum) * dsb.x - 16.0f * dsb.y)); } #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp index b554400ba39..deb8ee9960f 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp @@ -1,6 +1,6 @@ #version 450 -#include "generic_head.comp" +#include "generic_unary_head.comp" #include "types.comp" #extension GL_EXT_control_flow_attributes : enable @@ -8,19 +8,29 @@ layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; -layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; -layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; - shared FLOAT_TYPE sum[BLOCK_SIZE]; void main() { - const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; - const uint tid = gl_LocalInvocationID.x; + const uint ncols = p.ne00; + const uint nrows = gl_NumWorkGroups.x; + const uint nchannels = gl_NumWorkGroups.y; + + const uint row = gl_WorkGroupID.x; + const uint channel = gl_WorkGroupID.y; + const uint samp = gl_WorkGroupID.z; + const uint tid = gl_LocalInvocationID.x; + + const uint stride_row = p.nb01; + const uint stride_channel = p.nb02; + const uint stride_sample = p.nb03; + + uint32_t a_offset = samp*stride_sample + channel*stride_channel + row*stride_row + get_aoffset(); + uint32_t d_offset = ((samp*nchannels + channel)*nrows + row)*ncols + get_doffset(); sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp - [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { - const FLOAT_TYPE xi = FLOAT_TYPE(data_a[row*p.KX + col]); + [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) { + const FLOAT_TYPE xi = FLOAT_TYPE(data_a[a_offset + col]); sum[tid] += xi * xi; } @@ -33,10 +43,10 @@ void main() { barrier(); } - const FLOAT_TYPE mean = sum[0] / FLOAT_TYPE(p.KX); + const FLOAT_TYPE mean = sum[0] / FLOAT_TYPE(ncols); const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1)); - [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { - data_d[row*p.KX + col] = D_TYPE(scale * FLOAT_TYPE(data_a[row*p.KX + col])); + [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) { + data_d[d_offset + col] = D_TYPE(scale * 
FLOAT_TYPE(data_a[a_offset + col])); } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 2ac4caee70e..cf74625cc56 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -465,6 +465,7 @@ void process_shaders() { string_to_spv("acc_f32", "acc.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {}); + string_to_spv("fa_split_k_reduce", "flash_attn_split_k_reduce.comp", {}); string_to_spv("quantize_q8_1", "quantize_q8_1.comp", {}); string_to_spv("mul_f32", "mul.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 161dd3fa945..2a39dc7bfd1 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -956,6 +956,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "CONV_TRANSPOSE_1D", "IM2COL", "IM2COL_BACK", + "CONV_2D_DW", "CONV_TRANSPOSE_2D", "POOL_1D", "POOL_2D", @@ -982,23 +983,18 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "UNARY", - "MAP_UNARY", - "MAP_BINARY", - - "MAP_CUSTOM1_F32", - "MAP_CUSTOM2_F32", - "MAP_CUSTOM3_F32", - "MAP_CUSTOM1", "MAP_CUSTOM2", "MAP_CUSTOM3", + "CUSTOM", + "CROSS_ENTROPY_LOSS", "CROSS_ENTROPY_LOSS_BACK", "OPT_STEP_ADAMW", }; -static_assert(GGML_OP_COUNT == 85, "GGML_OP_COUNT != 85"); +static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -1055,6 +1051,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "conv_transpose_1d(x)", "im2col(x)", "im2col_back(x)", + "conv_2d_dw(x)", "conv_transpose_2d(x)", "pool_1d(x)", "pool_2d(x)", @@ -1081,23 +1078,18 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "unary(x)", - "f(x)", - "f(x,y)", - - "custom_f32(x)", - "custom_f32(x,y)", - "custom_f32(x,y,z)", + "map_custom(x)", + "map_custom(x,y)", + "map_custom(x,y,z)", "custom(x)", - "custom(x,y)", - "custom(x,y,z)", "cross_entropy_loss(x,y)", "cross_entropy_loss_back(x,y)", "adamw(x)", }; -static_assert(GGML_OP_COUNT == 85, "GGML_OP_COUNT != 85"); +static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -1159,6 +1151,12 @@ int64_t ggml_nrows(const struct ggml_tensor * tensor) { } size_t ggml_nbytes(const struct ggml_tensor * tensor) { + for (int i = 0; i < GGML_MAX_DIMS; ++i) { + if (tensor->ne[i] <= 0) { + return 0; + } + } + size_t nbytes; const size_t blck_size = ggml_blck_size(tensor->type); if (blck_size == 1) { @@ -1348,6 +1346,13 @@ bool ggml_is_permuted(const struct ggml_tensor * tensor) { return tensor->nb[0] > tensor->nb[1] || tensor->nb[1] > tensor->nb[2] || tensor->nb[2] > tensor->nb[3]; } +bool ggml_is_contiguous_channels(const struct ggml_tensor * tensor) { + return + tensor->nb[0] > tensor->nb[2] && + tensor->nb[1] > tensor->nb[0] && + tensor->nb[2] == ggml_type_size(tensor->type); +} + static inline bool ggml_is_padded_1d(const struct ggml_tensor * tensor) { static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); @@ -4054,6 +4059,46 @@ struct ggml_tensor * ggml_conv_2d_dw( return result; } +// ggml_conv_2d_dw_direct + +struct ggml_tensor * ggml_conv_2d_dw_direct( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int stride0, + int stride1, + int pad0, + int 
pad1, + int dilation0, + int dilation1) { + GGML_ASSERT(a->ne[2] == 1); + GGML_ASSERT(a->ne[3] == b->ne[2]); + int64_t ne[4]; + ne[0] = ggml_calc_conv_output_size(b->ne[0], a->ne[0], stride0, pad0, dilation0); + ne[1] = ggml_calc_conv_output_size(b->ne[1], a->ne[1], stride1, pad1, dilation1); + ne[2] = b->ne[2]; + ne[3] = b->ne[3]; + + struct ggml_tensor * result = ggml_new_tensor(ctx, b->type, 4, ne); + + if (ggml_is_contiguous_channels(b)) { + // Result will be permuted the same way as input (CWHN order) + const int64_t type_size = ggml_type_size(result->type); + GGML_ASSERT(ggml_blck_size(result->type) == 1); + result->nb[0] = result->ne[2] * type_size; + result->nb[1] = result->ne[0] * result->nb[0]; + result->nb[2] = type_size; + } + + int32_t params[] = { stride0, stride1, pad0, pad1, dilation0, dilation1 }; + ggml_set_op_params(result, params, sizeof(params)); + + result->op = GGML_OP_CONV_2D_DW; + result->src[0] = a; + result->src[1] = b; + return result; +} + // ggml_conv_transpose_2d_p0 static int64_t ggml_calc_conv_transpose_output_size(int64_t ins, int64_t ks, int s, int p) { @@ -4178,7 +4223,8 @@ static struct ggml_tensor * ggml_upscale_impl( int ne0, int ne1, int ne2, - int ne3) { + int ne3, + enum ggml_scale_mode mode) { GGML_ASSERT(a->ne[0] <= ne0); GGML_ASSERT(a->ne[1] <= ne1); GGML_ASSERT(a->ne[2] <= ne2); @@ -4186,6 +4232,8 @@ static struct ggml_tensor * ggml_upscale_impl( struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3); + ggml_set_op_params_i32(result, 0, mode); + result->op = GGML_OP_UPSCALE; result->src[0] = a; @@ -4195,8 +4243,9 @@ static struct ggml_tensor * ggml_upscale_impl( struct ggml_tensor * ggml_upscale( struct ggml_context * ctx, struct ggml_tensor * a, - int scale_factor) { - return ggml_upscale_impl(ctx, a, a->ne[0] * scale_factor, a->ne[1] * scale_factor, a->ne[2], a->ne[3]); + int scale_factor, + enum ggml_scale_mode mode) { + return ggml_upscale_impl(ctx, a, a->ne[0] * scale_factor, a->ne[1] * scale_factor, a->ne[2], a->ne[3], mode); } struct ggml_tensor * ggml_upscale_ext( @@ -4205,8 +4254,9 @@ struct ggml_tensor * ggml_upscale_ext( int ne0, int ne1, int ne2, - int ne3) { - return ggml_upscale_impl(ctx, a, ne0, ne1, ne2, ne3); + int ne3, + enum ggml_scale_mode mode) { + return ggml_upscale_impl(ctx, a, ne0, ne1, ne2, ne3, mode); } // ggml_pad @@ -4836,179 +4886,6 @@ struct ggml_tensor * ggml_unary_inplace( return ggml_unary_impl(ctx, a, op, true); } -// ggml_map_unary - -static struct ggml_tensor * ggml_map_unary_impl_f32( - struct ggml_context * ctx, - struct ggml_tensor * a, - const ggml_unary_op_f32_t fun, - bool inplace) { - struct ggml_tensor * result = inplace ? 
ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - - ggml_set_op_params(result, (const void *) &fun, sizeof(fun)); - - result->op = GGML_OP_MAP_UNARY; - result->src[0] = a; - - return result; -} - -struct ggml_tensor * ggml_map_unary_f32( - struct ggml_context * ctx, - struct ggml_tensor * a, - const ggml_unary_op_f32_t fun) { - return ggml_map_unary_impl_f32(ctx, a, fun, false); -} - -struct ggml_tensor * ggml_map_unary_inplace_f32( - struct ggml_context * ctx, - struct ggml_tensor * a, - const ggml_unary_op_f32_t fun) { - return ggml_map_unary_impl_f32(ctx, a, fun, true); -} - -// ggml_map_binary - -static struct ggml_tensor * ggml_map_binary_impl_f32( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - const ggml_binary_op_f32_t fun, - bool inplace) { - GGML_ASSERT(ggml_are_same_shape(a, b)); - - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - - ggml_set_op_params(result, (const void *) &fun, sizeof(fun)); - - result->op = GGML_OP_MAP_BINARY; - result->src[0] = a; - result->src[1] = b; - - return result; -} - -struct ggml_tensor * ggml_map_binary_f32( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - const ggml_binary_op_f32_t fun) { - return ggml_map_binary_impl_f32(ctx, a, b, fun, false); -} - -struct ggml_tensor * ggml_map_binary_inplace_f32( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - const ggml_binary_op_f32_t fun) { - return ggml_map_binary_impl_f32(ctx, a, b, fun, true); -} - -// ggml_map_custom1_f32 - -static struct ggml_tensor * ggml_map_custom1_impl_f32( - struct ggml_context * ctx, - struct ggml_tensor * a, - const ggml_custom1_op_f32_t fun, - bool inplace) { - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - - ggml_set_op_params(result, (const void *) &fun, sizeof(fun)); - - result->op = GGML_OP_MAP_CUSTOM1_F32; - result->src[0] = a; - - return result; -} - -struct ggml_tensor * ggml_map_custom1_f32( - struct ggml_context * ctx, - struct ggml_tensor * a, - const ggml_custom1_op_f32_t fun) { - return ggml_map_custom1_impl_f32(ctx, a, fun, false); -} - -struct ggml_tensor * ggml_map_custom1_inplace_f32( - struct ggml_context * ctx, - struct ggml_tensor * a, - const ggml_custom1_op_f32_t fun) { - return ggml_map_custom1_impl_f32(ctx, a, fun, true); -} - -// ggml_map_custom2_f32 - -static struct ggml_tensor * ggml_map_custom2_impl_f32( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - const ggml_custom2_op_f32_t fun, - bool inplace) { - struct ggml_tensor * result = inplace ? 
ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - - ggml_set_op_params(result, (const void *) &fun, sizeof(fun)); - - result->op = GGML_OP_MAP_CUSTOM2_F32; - result->src[0] = a; - result->src[1] = b; - - return result; -} - -struct ggml_tensor * ggml_map_custom2_f32( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - const ggml_custom2_op_f32_t fun) { - return ggml_map_custom2_impl_f32(ctx, a, b, fun, false); -} - -struct ggml_tensor * ggml_map_custom2_inplace_f32( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - const ggml_custom2_op_f32_t fun) { - return ggml_map_custom2_impl_f32(ctx, a, b, fun, true); -} - -// ggml_map_custom3_f32 - -static struct ggml_tensor * ggml_map_custom3_impl_f32( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - struct ggml_tensor * c, - const ggml_custom3_op_f32_t fun, - bool inplace) { - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - - ggml_set_op_params(result, (const void *) &fun, sizeof(fun)); - - result->op = GGML_OP_MAP_CUSTOM3_F32; - result->src[0] = a; - result->src[1] = b; - result->src[2] = c; - - return result; -} - -struct ggml_tensor * ggml_map_custom3_f32( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - struct ggml_tensor * c, - const ggml_custom3_op_f32_t fun) { - return ggml_map_custom3_impl_f32(ctx, a, b, c, fun, false); -} - -struct ggml_tensor * ggml_map_custom3_inplace_f32( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - struct ggml_tensor * c, - const ggml_custom3_op_f32_t fun) { - return ggml_map_custom3_impl_f32(ctx, a, b, c, fun, true); -} - // ggml_map_custom1 static struct ggml_tensor * ggml_map_custom1_impl( @@ -5027,7 +4904,7 @@ static struct ggml_tensor * ggml_map_custom1_impl( /*.n_tasks =*/ n_tasks, /*.userdata =*/ userdata }; - ggml_set_op_params(result, (const void *) ¶ms, sizeof(params)); + ggml_set_op_params(result, ¶ms, sizeof(params)); result->op = GGML_OP_MAP_CUSTOM1; result->src[0] = a; @@ -5072,7 +4949,7 @@ static struct ggml_tensor * ggml_map_custom2_impl( /*.n_tasks =*/ n_tasks, /*.userdata =*/ userdata }; - ggml_set_op_params(result, (const void *) ¶ms, sizeof(params)); + ggml_set_op_params(result, ¶ms, sizeof(params)); result->op = GGML_OP_MAP_CUSTOM2; result->src[0] = a; @@ -5121,7 +4998,7 @@ static struct ggml_tensor * ggml_map_custom3_impl( /*.n_tasks =*/ n_tasks, /*.userdata =*/ userdata }; - ggml_set_op_params(result, (const void *) ¶ms, sizeof(params)); + ggml_set_op_params(result, ¶ms, sizeof(params)); result->op = GGML_OP_MAP_CUSTOM3; result->src[0] = a; @@ -5153,6 +5030,66 @@ struct ggml_tensor * ggml_map_custom3_inplace( return ggml_map_custom3_impl(ctx, a, b, c, fun, n_tasks, userdata, true); } +struct ggml_tensor * ggml_custom_4d( + struct ggml_context * ctx, + enum ggml_type type, + int64_t ne0, + int64_t ne1, + int64_t ne2, + int64_t ne3, + struct ggml_tensor ** args, + int n_args, + ggml_custom_op_t fun, + int n_tasks, + void * userdata) { + + GGML_ASSERT(n_args < GGML_MAX_SRC); + + struct ggml_tensor * result = ggml_new_tensor_4d(ctx, type, ne0, ne1, ne2, ne3); + + struct ggml_custom_op_params params = { + /*.fun =*/ fun, + /*.n_tasks =*/ n_tasks, + /*.userdata =*/ userdata + }; + ggml_set_op_params(result, ¶ms, sizeof(params)); + + result->op = GGML_OP_CUSTOM; + for (int i = 0; i < n_args; i++) { + result->src[i] = args[i]; + } + + return result; +} + +struct ggml_tensor * 
ggml_custom_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor ** args, + int n_args, + ggml_custom_op_t fun, + int n_tasks, + void * userdata) { + + GGML_ASSERT(n_args < GGML_MAX_SRC - 1); + + struct ggml_tensor * result = ggml_view_tensor(ctx, a); + + struct ggml_custom_op_params params = { + /*.fun =*/ fun, + /*.n_tasks =*/ n_tasks, + /*.userdata =*/ userdata + }; + ggml_set_op_params(result, ¶ms, sizeof(params)); + + result->op = GGML_OP_CUSTOM; + result->src[0] = a; + for (int i = 0; i < n_args; i++) { + result->src[i + 1] = args[i]; + } + + return result; +} // ggml_cross_entropy_loss struct ggml_tensor * ggml_cross_entropy_loss( diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index 4ade1b830d1..b1490b6256a 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -d920dfd7da37b22d1eb0813cdaf340c1870d76c3 +489716ba99ecd51164f79e8c6fec0b5bf634eac9
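Since the deprecated ggml_map_*_f32 entry points are removed above in favour of GGML_OP_CUSTOM, a short usage sketch may help. One caveat: this diff only shows that ggml_custom_op_params stores a ggml_custom_op_t together with n_tasks and userdata, so the callback below assumes the (dst, ith, nth, userdata) shape of the existing custom-op callbacks, with the extra sources reachable via dst->src[]; check ggml.h for the exact typedef before relying on it.

#include "ggml.h"

// Assumed callback shape (see note above); dst->src[0] is the first tensor passed in args[].
static void my_scale2_op(struct ggml_tensor * dst, int ith, int nth, void * userdata) {
    (void) userdata;
    const struct ggml_tensor * src = dst->src[0];
    // For brevity this assumes contiguous F32 tensors.
    const float * in  = (const float *) src->data;
    float       * out = (float *) dst->data;
    const int64_t n = ggml_nelements(dst);
    for (int64_t i = ith; i < n; i += nth) { // each of the n_tasks threads handles a strided slice
        out[i] = 2.0f * in[i];
    }
}

static struct ggml_tensor * build_custom_node(struct ggml_context * ctx, struct ggml_tensor * a) {
    struct ggml_tensor * args[] = { a };
    // Replacement for the removed ggml_map_custom*_f32 helpers: the output type and shape
    // are given explicitly instead of being copied from the first argument.
    return ggml_custom_4d(ctx, GGML_TYPE_F32,
                          a->ne[0], a->ne[1], a->ne[2], a->ne[3],
                          args, 1, my_scale2_op, /*n_tasks=*/ 1, /*userdata=*/ NULL);
}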