From 03e85da848f551e3ab3ea543c0fb7ba14a4c6e9f Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Sat, 12 Jul 2025 21:31:38 +0800 Subject: [PATCH 01/20] CUDA: add set rows for f32 and f16 (llama/14551) * CUDA: add set rows for f32 and f16 * Review: change kernel params, use strides from host * Use 1-d kernel * Review: use int64_t for blockDim.x, rename nb->s for clarity --- ggml/src/ggml-cuda/ggml-cuda.cu | 10 +++ ggml/src/ggml-cuda/set-rows.cu | 130 ++++++++++++++++++++++++++++++++ ggml/src/ggml-cuda/set-rows.cuh | 7 ++ 3 files changed, 147 insertions(+) create mode 100644 ggml/src/ggml-cuda/set-rows.cu create mode 100644 ggml/src/ggml-cuda/set-rows.cuh diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 72406f0af36..88b17dd682c 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -43,6 +43,7 @@ #include "ggml-cuda/upscale.cuh" #include "ggml-cuda/wkv.cuh" #include "ggml-cuda/gla.cuh" +#include "ggml-cuda/set-rows.cuh" #include "ggml.h" #include @@ -2230,6 +2231,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_GET_ROWS_BACK: ggml_cuda_op_get_rows_back(ctx, dst); break; + case GGML_OP_SET_ROWS: + ggml_cuda_op_set_rows(ctx, dst); + break; case GGML_OP_DUP: ggml_cuda_dup(ctx, dst); break; @@ -3216,6 +3220,12 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g { return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1; } break; + case GGML_OP_SET_ROWS: + { + return (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && + op->src[0]->type == GGML_TYPE_F32 && + op->src[1]->type == GGML_TYPE_I64; + } break; case GGML_OP_CPY: { ggml_type src0_type = op->src[0]->type; diff --git a/ggml/src/ggml-cuda/set-rows.cu b/ggml/src/ggml-cuda/set-rows.cu new file mode 100644 index 00000000000..d8b3e63e1aa --- /dev/null +++ b/ggml/src/ggml-cuda/set-rows.cu @@ -0,0 +1,130 @@ +#include "set-rows.cuh" + +typedef void (*set_rows_kernel_t)(const char * src, char * dst); + +template +__device__ void set_rows_1(const src_t * src_f, dst_t * dst_f) {} + +template<> +__device__ __forceinline__ void set_rows_1(const float * src_f, half * dst_h) { + *dst_h = __float2half(*src_f); +} + +template<> +__device__ __forceinline__ void set_rows_1(const float * src_f, float * dst_f) { + *dst_f = *src_f; +} + +template +static __global__ void k_set_rows( + const src_t * __restrict__ src0, const int64_t * __restrict__ src1, dst_t * __restrict__ dst, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, + const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13, + const int64_t s01, const int64_t s02, const int64_t s03, + const int64_t s10, const int64_t s11, const int64_t s12, + const int64_t s1, const int64_t s2, const int64_t s3) { + + const int64_t i = int64_t(blockDim.x) * blockIdx.x + threadIdx.x; + const int64_t ne_total = ne00 * ne01 * ne02 * ne03; + + if (i >= ne_total) { + return; + } + + const int64_t i03 = i / (ne00 * ne01 * ne02); + const int64_t i02 = (i - i03 * ne00 * ne01 * ne02) / (ne00 * ne01); + const int64_t i01 = (i - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01) / ne00; + const int64_t i00 = i - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01 - i01 * ne00; + + const int64_t i12 = i03 % ne12; + const int64_t i11 = i02 % ne11; + const int64_t i10 = i01; + + const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12); + + const src_t * src0_row = src0 + i01*s01 + i02*s02 
+ i03*s03; + dst_t * dst_row_ptr = dst + dst_row*s1 + i02*s2 + i03*s3; + + const src_t* src_elem = src0_row + i00; + dst_t* dst_elem = dst_row_ptr + i00; + set_rows_1(src_elem, dst_elem); +} + +template +static void set_rows_cuda( + const src_t * src0_d, const int64_t * src1_d, dst_t * dst_d, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, + const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13, + const size_t nb01, const size_t nb02, const size_t nb03, + const size_t nb10, const size_t nb11, const size_t nb12, + const size_t nb1, const size_t nb2, const size_t nb3, + cudaStream_t stream) { + + const int64_t ne_total = ne00 * ne01 * ne02 * ne03; + const int num_blocks = (ne_total + CUDA_SET_ROWS_BLOCK_SIZE - 1) / CUDA_SET_ROWS_BLOCK_SIZE; + const dim3 block_size(CUDA_SET_ROWS_BLOCK_SIZE); + const dim3 grid_size(num_blocks); + + + const int64_t s01 = nb01/sizeof(src_t); + const int64_t s02 = nb02/sizeof(src_t); + const int64_t s03 = nb03/sizeof(src_t); + const int64_t s10 = nb10/sizeof(int64_t); + const int64_t s11 = nb11/sizeof(int64_t); + const int64_t s12 = nb12/sizeof(int64_t); + const int64_t s1 = nb1/sizeof(dst_t); + const int64_t s2 = nb2/sizeof(dst_t); + const int64_t s3 = nb3/sizeof(dst_t); + + if (ne_total > 0) { + k_set_rows<<>>( + src0_d, src1_d, dst_d, + ne00, ne01, ne02, ne03, + ne10, ne11, ne12, ne13, + s01, s02, s03, + s10, s11, s12, + s1, s2, s3); + } +} + + +void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_I64); + + GGML_TENSOR_BINARY_OP_LOCALS + + const float * src0_d = (const float *)src0->data; + const int64_t * src1_d = (const int64_t *)src1->data; + + cudaStream_t stream = ctx.stream(); + + + + if (dst->type == GGML_TYPE_F32) { + set_rows_cuda( + src0_d, src1_d, (float*)dst->data, + ne00, ne01, ne02, ne03, + ne10, ne11, ne12, ne13, + nb01, nb02, nb03, + nb10, nb11, nb12, + nb1, nb2, nb3, + stream + ); + } else if (dst->type == GGML_TYPE_F16) { + set_rows_cuda( + src0_d, src1_d, (half*)dst->data, + ne00, ne01, ne02, ne03, + ne10, ne11, ne12, ne13, + nb01, nb02, nb03, + nb10, nb11, nb12, + nb1, nb2, nb3, + stream + ); + } else { + GGML_ABORT("unsupported type"); + } +} diff --git a/ggml/src/ggml-cuda/set-rows.cuh b/ggml/src/ggml-cuda/set-rows.cuh new file mode 100644 index 00000000000..c140c0873c8 --- /dev/null +++ b/ggml/src/ggml-cuda/set-rows.cuh @@ -0,0 +1,7 @@ +#pragma once + +#include "common.cuh" + +#define CUDA_SET_ROWS_BLOCK_SIZE 256 + +void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst); From a6b85bc8dde7dcfe0ab19acef66adde91d37c838 Mon Sep 17 00:00:00 2001 From: Yavor Ivanov Date: Sat, 12 Jul 2025 22:38:13 -0700 Subject: [PATCH 02/20] metal : Add missing unary ops Metal support (llama/14660) --- ggml/src/ggml-metal/ggml-metal.m | 90 ++++++++++++++++++++++++++++ ggml/src/ggml-metal/ggml-metal.metal | 45 ++++++++++++++ 2 files changed, 135 insertions(+) diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index 83a0739809a..44ddc69d08f 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -173,6 +173,12 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte GGML_METAL_KERNEL_TYPE_SILU, GGML_METAL_KERNEL_TYPE_SILU_4, GGML_METAL_KERNEL_TYPE_ELU, + GGML_METAL_KERNEL_TYPE_ABS, + 
GGML_METAL_KERNEL_TYPE_SGN, + GGML_METAL_KERNEL_TYPE_STEP, + GGML_METAL_KERNEL_TYPE_HARDSWISH, + GGML_METAL_KERNEL_TYPE_HARDSIGMOID, + GGML_METAL_KERNEL_TYPE_EXP, GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, @@ -1155,6 +1161,12 @@ @implementation GGMLMetalClass GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ELU, elu, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ABS, abs, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SGN, sgn, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_STEP, step, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_HARDSWISH, hardswish, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_HARDSIGMOID, hardsigmoid, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_EXP, exp, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, has_simdgroup_reduction); @@ -1688,6 +1700,12 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex case GGML_UNARY_OP_SILU: case GGML_UNARY_OP_ELU: case GGML_UNARY_OP_NEG: + case GGML_UNARY_OP_ABS: + case GGML_UNARY_OP_SGN: + case GGML_UNARY_OP_STEP: + case GGML_UNARY_OP_HARDSWISH: + case GGML_UNARY_OP_HARDSIGMOID: + case GGML_UNARY_OP_EXP: return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; default: return false; @@ -2439,6 +2457,78 @@ static bool ggml_metal_encode_node( [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; + case GGML_UNARY_OP_ABS: + { + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ABS].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_UNARY_OP_SGN: + { + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SGN].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_UNARY_OP_STEP: + { + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_STEP].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_UNARY_OP_HARDSWISH: + { + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_HARDSWISH].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_UNARY_OP_HARDSIGMOID: + { 
+ id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_HARDSIGMOID].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_UNARY_OP_EXP: + { + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_EXP].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; default: { GGML_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op)); diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 239ec31fbcb..13235e28852 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1199,6 +1199,51 @@ kernel void kernel_neg( dst[tpig] = -src0[tpig]; } +kernel void kernel_abs( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = fabs(src0[tpig]); +} + +kernel void kernel_sgn( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + device const float & x = src0[tpig]; + dst[tpig] = (x > 0.0f) ? 1.0f : ((x < 0.0f) ? -1.0f : 0.0f); +} + +kernel void kernel_step( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] > 0.0f ? 1.0f : 0.0f; +} + +kernel void kernel_hardswish( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + device const float & x = src0[tpig]; + dst[tpig] = x * fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f)); +} + +kernel void kernel_hardsigmoid( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + device const float & x = src0[tpig]; + dst[tpig] = fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f)); +} + +kernel void kernel_exp( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = exp(src0[tpig]); +} + kernel void kernel_reglu( device const char * src0, device const char * src1, From 8610c4cd64ddad03a14b68e8e6964dc1d5084ba5 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 13 Jul 2025 10:36:33 +0300 Subject: [PATCH 03/20] ggml : add build-time message to remind about ggml_set_rows (llama/14661) ggml-ci --- ggml/src/ggml-cann/ggml-cann.cpp | 1 + ggml/src/ggml-cuda/ggml-cuda.cu | 1 + ggml/src/ggml-opencl/ggml-opencl.cpp | 1 + ggml/src/ggml-sycl/ggml-sycl.cpp | 1 + 4 files changed, 4 insertions(+) diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index ccb17eb072e..e5e11d4cdce 100755 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -2090,6 +2090,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, { // TODO: add support // ref: https://github.com/ggml-org/llama.cpp/pull/14274 +#pragma message("TODO: implement F32, F16, BF16, Q4_0, Q4_1, Q5_0, Q5_1, Q8_0, IQ4_NL support (https://github.com/ggml-org/llama.cpp/pull/14661)") return false; } break; case GGML_OP_CPY: { diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 88b17dd682c..1478245998a 100644 --- 
a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3222,6 +3222,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g } break; case GGML_OP_SET_ROWS: { +#pragma message("TODO: implement BF16, Q4_0, Q4_1, Q5_0, Q5_1, Q8_0, IQ4_NL support (https://github.com/ggml-org/llama.cpp/pull/14661)") return (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_I64; diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 58830b733a8..3388259152b 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -2280,6 +2280,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te { // TODO: add support // ref: https://github.com/ggml-org/llama.cpp/pull/14274 +#pragma message("TODO: implement BF16, Q4_0, Q4_1, Q5_0, Q5_1, Q8_0, IQ4_NL support (https://github.com/ggml-org/llama.cpp/pull/14661)") if (op->src[0]->type != GGML_TYPE_F32) { return false; } diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 65b26fd0276..7f74fbfe5c1 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -4303,6 +4303,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g { // TODO: add support // ref: https://github.com/ggml-org/llama.cpp/pull/14274 +#pragma message("TODO: implement BF16, Q4_0, Q4_1, Q5_0, Q5_1, Q8_0, IQ4_NL support (https://github.com/ggml-org/llama.cpp/pull/14661)") return (op->type == GGML_TYPE_F32 || (op->type == GGML_TYPE_F16 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_I64)); } break; case GGML_OP_CPY: From e6509f7b1b635dbb4c7dfb9a1af4b6ccf97bd726 Mon Sep 17 00:00:00 2001 From: Yavor Ivanov Date: Sun, 13 Jul 2025 02:33:16 -0700 Subject: [PATCH 04/20] cuda : add ELU support (llama/14657) --- ggml/src/ggml-cuda/ggml-cuda.cu | 4 ++++ ggml/src/ggml-cuda/unary.cu | 7 +++++++ ggml/src/ggml-cuda/unary.cuh | 2 ++ 3 files changed, 13 insertions(+) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 1478245998a..c7222207efe 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2303,6 +2303,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_UNARY_OP_EXP: ggml_cuda_op_exp(ctx, dst); break; + case GGML_UNARY_OP_ELU: + ggml_cuda_op_elu(ctx, dst); + break; default: return false; } @@ -3116,6 +3119,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_UNARY_OP_GELU_QUICK: case GGML_UNARY_OP_TANH: case GGML_UNARY_OP_EXP: + case GGML_UNARY_OP_ELU: return ggml_is_contiguous(op->src[0]); default: return false; diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu index f9c7b83c40d..91c830c4dac 100644 --- a/ggml/src/ggml-cuda/unary.cu +++ b/ggml/src/ggml-cuda/unary.cu @@ -83,6 +83,10 @@ static __device__ __forceinline__ float op_log(float x) { return logf(x); } +static __device__ __forceinline__ float op_elu(float x) { + return (x > 0.f) ? 
x : expm1f(x); +} + template static __global__ void unary_op_kernel(const T * x, T * dst, const int k) { const int i = blockDim.x*blockIdx.x + threadIdx.x; @@ -196,6 +200,9 @@ void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { ggml_cuda_op_unary(ctx, dst); } +void ggml_cuda_op_elu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary(ctx, dst); +} /* gated ops */ template diff --git a/ggml/src/ggml-cuda/unary.cuh b/ggml/src/ggml-cuda/unary.cuh index 289d690e5cf..cb14d16f8f3 100644 --- a/ggml/src/ggml-cuda/unary.cuh +++ b/ggml/src/ggml-cuda/unary.cuh @@ -59,6 +59,8 @@ void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst); +void ggml_cuda_op_elu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + void ggml_cuda_op_reglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_geglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); From 6055fb47e05acac9480c3d84cfcac531460bb235 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?= Date: Sun, 13 Jul 2025 15:01:24 +0200 Subject: [PATCH 05/20] cuda : add set rows for bf16 (llama/14664) --- ggml/src/ggml-cuda/ggml-cuda.cu | 4 ++-- ggml/src/ggml-cuda/set-rows.cu | 15 +++++++++++++++ 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index c7222207efe..8015b0d4e8d 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3226,8 +3226,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g } break; case GGML_OP_SET_ROWS: { -#pragma message("TODO: implement BF16, Q4_0, Q4_1, Q5_0, Q5_1, Q8_0, IQ4_NL support (https://github.com/ggml-org/llama.cpp/pull/14661)") - return (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && +#pragma message("TODO: implement Q4_0, Q4_1, Q5_0, Q5_1, Q8_0, IQ4_NL support (https://github.com/ggml-org/llama.cpp/pull/14661)") + return (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_BF16) && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_I64; } break; diff --git a/ggml/src/ggml-cuda/set-rows.cu b/ggml/src/ggml-cuda/set-rows.cu index d8b3e63e1aa..3fade72b84e 100644 --- a/ggml/src/ggml-cuda/set-rows.cu +++ b/ggml/src/ggml-cuda/set-rows.cu @@ -10,6 +10,11 @@ __device__ __forceinline__ void set_rows_1(const float * src_f, hal *dst_h = __float2half(*src_f); } +template<> +__device__ __forceinline__ void set_rows_1(const float * src_f, nv_bfloat16 * dst_b) { + *dst_b = *src_f; +} + template<> __device__ __forceinline__ void set_rows_1(const float * src_f, float * dst_f) { *dst_f = *src_f; @@ -124,6 +129,16 @@ void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { nb1, nb2, nb3, stream ); + } else if (dst->type == GGML_TYPE_BF16) { + set_rows_cuda( + src0_d, src1_d, (nv_bfloat16*)dst->data, + ne00, ne01, ne02, ne03, + ne10, ne11, ne12, ne13, + nb01, nb02, nb03, + nb10, nb11, nb12, + nb1, nb2, nb3, + stream + ); } else { GGML_ABORT("unsupported type"); } From 24643c0e64fe3f54ab12f89c01500aaf9a851222 Mon Sep 17 00:00:00 2001 From: Anton Mitkov Date: Mon, 14 Jul 2025 10:37:35 +0100 Subject: [PATCH 06/20] sycl: Batched mulmat rework for oneDNN dispatch (llama/14617) --- ggml/src/ggml-sycl/gemm.hpp | 40 +++----- ggml/src/ggml-sycl/ggml-sycl.cpp | 165 ++++++++++++++++++++++--------- 2 files changed, 133 insertions(+), 72 deletions(-) diff --git 
a/ggml/src/ggml-sycl/gemm.hpp b/ggml/src/ggml-sycl/gemm.hpp index 5efe03d364b..dcf6c7aeeb4 100644 --- a/ggml/src/ggml-sycl/gemm.hpp +++ b/ggml/src/ggml-sycl/gemm.hpp @@ -32,39 +32,28 @@ class DnnlGemmWrapper { else static_assert(0); } - // matrix A has m rows, k columns - // matrix B has k rows, n columns - // nra - number of elements to skip when moving into next row in A - // nrb - number of elements to skip when moving into next row in B - // nca - number of elements to skip when moving into next column in A - // ncb - number of elements to skip when moving into next column in B - // stride_a - number of elements to skip when moving to next A matrix - // stride_b - number of elements to skip when moving to next B matrix - // batches_a - number of A matrices - // batches_b - number of B matrices static void gemm(ggml_backend_sycl_context & ctx, int m, int n, int k, - const void * a, dt at, dnnl_dim_t nra, dnnl_dim_t nca, dnnl_dim_t stride_a, - const void * b, dt bt, dnnl_dim_t nrb, dnnl_dim_t ncb, dnnl_dim_t stride_b, + const void * a, dt at, dnnl_dim_t stra0, dnnl_dim_t stra1, dnnl_dim_t stra2, + const void * b, dt bt, dnnl_dim_t strb0, dnnl_dim_t strb1, dnnl_dim_t strb2, void * c, dt ct, const queue_ptr & q, dnnl_dim_t batches_a, dnnl_dim_t batches_b) { auto stream = ctx.stream_dnnl(q); auto eng = ctx.engine_dnnl(q); - // { # strides, # rows, # columns } - dnnl::memory::dims a_dims = { batches_a, m, k }; - dnnl::memory::dims b_dims = { batches_b, k, n }; - dnnl::memory::dims c_dims = { std::max(batches_a, batches_b), m, n }; - - // { # elements to skip to next stride, # elements to skip to next row, # elements to skip to next column } - dnnl::memory::dims a_strides = { stride_a, nra, nca }; - dnnl::memory::dims b_strides = { stride_b, nrb, ncb }; - + dnnl::memory::dims a_dims = {batches_a, m, k }; + dnnl::memory::dims a_strides = {stra2, stra1, stra0}; const auto a_in_md = dnnl::memory::desc(a_dims, at, a_strides); + + dnnl::memory::dims b_dims = {batches_b, k, n }; + dnnl::memory::dims b_strides = {strb2, strb0, strb1}; const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_strides); - const auto c_md = dnnl::memory::desc(c_dims, ct, tag::abc); + dnnl::memory::dims c_dims = { std::max(batches_a, batches_b), m, n}; + dnnl::memory::dims c_strides = {m*n, 1, m }; + const auto c_md = dnnl::memory::desc(c_dims, ct, c_strides); dnnl::primitive_attr primitive_attr; primitive_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); + #ifdef GGML_SYCL_F16 primitive_attr.set_fpmath_mode(dnnl::fpmath_mode::f16); #endif @@ -76,24 +65,23 @@ class DnnlGemmWrapper { auto scratchpad_md = matmul_pd.scratchpad_desc(); auto scratchpad_mem = ctx.get_scratchpad_mem(scratchpad_md, eng, q); + auto matmul_prim = dnnl::matmul(matmul_pd); std::unordered_map matmul_args; matmul_args.insert({ DNNL_ARG_SRC, a_mem }); matmul_args.insert({ DNNL_ARG_WEIGHTS, b_mem }); + matmul_args.insert({ DNNL_ARG_DST, c_mem }); matmul_args.insert({ DNNL_ARG_SCRATCHPAD, scratchpad_mem }); matmul_prim.execute(stream, matmul_args); } - // matrices A and B are column major, both having k rows - // matrix A has m column, matrix B has n columns - // output: column major matrix C = A transposed * B static void row_gemm(ggml_backend_sycl_context & ctx, int m, int n, int k, const void * a, dt at, const void * b, dt bt, void * c, dt ct, const queue_ptr & q) { - gemm(ctx, m, n, k, a, at, k, 1, k * m, b, bt, 1, k, n * k, c, ct, q, 1, 1); + gemm(ctx, m, n, k, a, at, 1, k, k * m, b, bt, 1, k, n * k, c, ct, q, 1, 1); } }; diff --git 
a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 7f74fbfe5c1..cf46012be81 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -1546,7 +1546,7 @@ static void mul_mat_p021_f16_f32( static void mul_mat_vec_nc_f16_f32( // nc == non-contiguous const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x, - const int row_stride_x, const int channel_stride_x, const int channel_x_divisor, + const int row_stride_x, const int channel_stride_x,const int channel_stride_y, const int channel_x_divisor, const sycl::nd_item<3> &item_ct1) { const sycl::half *x = (const sycl::half *)vx; @@ -1557,7 +1557,6 @@ static void mul_mat_vec_nc_f16_f32( // nc == non-contiguous item_ct1.get_local_id(0); const int channel_x = channel / channel_x_divisor; - const int nrows_y = ncols_x; const int nrows_dst = nrows_x; const int row_dst = row_x; @@ -1576,7 +1575,7 @@ static void mul_mat_vec_nc_f16_f32( // nc == non-contiguous const int row_y = col_x; const int ix = channel_x*channel_stride_x + row_x*row_stride_x + col_x; - const int iy = channel*nrows_y + row_y; + const int iy = channel * channel_stride_y + row_y; const float xi = sycl::vec(x[ix]) @@ -1823,7 +1822,7 @@ static void ggml_mul_mat_p021_f16_f32_sycl(const void *vx, const float *y, static void ggml_mul_mat_vec_nc_f16_f32_sycl( const void *vx, const float *y, float *dst, const int ncols_x, const int nrows_x, const int row_stride_x, const int nchannels_x, - const int nchannels_y, const int channel_stride_x, queue_ptr stream) { + const int nchannels_y, const int channel_stride_x, const int channel_stride_y, queue_ptr stream) { const sycl::range<3> block_nums(nchannels_y, nrows_x, 1); const sycl::range<3> block_dims(1, 1, WARP_SIZE); @@ -1835,7 +1834,7 @@ static void ggml_mul_mat_vec_nc_f16_f32_sycl( sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { mul_mat_vec_nc_f16_f32(vx, y, dst, ncols_x, nrows_x, - row_stride_x, channel_stride_x, + row_stride_x, channel_stride_x, channel_stride_y, nchannels_y / nchannels_x, item_ct1); }); } @@ -2124,8 +2123,8 @@ inline void ggml_sycl_op_mul_mat_sycl( #if GGML_SYCL_DNNL if (!g_ggml_sycl_disable_dnn) { - DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ptr, - DnnlGemmWrapper::to_dt(), src0_ptr, DnnlGemmWrapper::to_dt(), + DnnlGemmWrapper::row_gemm(ctx,row_diff, src1_ncols , ne10, src0_ptr, + DnnlGemmWrapper::to_dt(), src1_ptr, DnnlGemmWrapper::to_dt(), dst_dd_i, DnnlGemmWrapper::to_dt(), stream); } else @@ -2171,8 +2170,8 @@ inline void ggml_sycl_op_mul_mat_sycl( #if GGML_SYCL_DNNL if (!g_ggml_sycl_disable_dnn) { - DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ddf1_i, - DnnlGemmWrapper::to_dt(), src0_ddf_i, DnnlGemmWrapper::to_dt(), + DnnlGemmWrapper::row_gemm(ctx, row_diff, src1_ncols, ne10, src0_ddf_i, + DnnlGemmWrapper::to_dt(), src1_ddf1_i, DnnlGemmWrapper::to_dt(), dst_dd_i, DnnlGemmWrapper::to_dt(), stream); } else @@ -2776,6 +2775,7 @@ static void ggml_sycl_mul_mat_vec_nc(ggml_backend_sycl_context & ctx, const ggml const int64_t nb02 = src0->nb[2]; const int64_t ne12 = src1->ne[2]; + const int64_t nb11 = src1->nb[1]; SYCL_CHECK(ggml_sycl_set_device(ctx.device)); queue_ptr main_stream = ctx.stream(); @@ -2786,8 +2786,9 @@ static void ggml_sycl_mul_mat_vec_nc(ggml_backend_sycl_context & ctx, const ggml const int64_t row_stride_x = nb01 / sizeof(sycl::half); const int64_t channel_stride_x = 
nb02 / sizeof(sycl::half); + const int64_t channel_stride_y = nb11 / sizeof(float); - ggml_mul_mat_vec_nc_f16_f32_sycl(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream); + ggml_mul_mat_vec_nc_f16_f32_sycl(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x,channel_stride_y, main_stream); } catch (sycl::exception const &exc) { std::cerr << exc.what() << "Exception caught at file:" << __FILE__ @@ -2841,8 +2842,8 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons float * dst_ddf = static_cast(dst->data); const sycl::half * src1_f16 = static_cast(src1->data); + const size_t type_size_src0 = ggml_type_size(src0->type); const size_t type_size_src1 = ggml_type_size(src1->type); - GGML_ASSERT(nb10 == type_size_src1); // SRC1 strides int64_t s11 = nb11 / type_size_src1; @@ -2854,11 +2855,32 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons if (src1->type != GGML_TYPE_F16) { scope_op_debug_print scope_dbg_print(__func__, "/to_fp16_nc_sycl", dst, /*num_src=*/2, " : converting src1 to fp16"); - const to_fp16_nc_sycl_t to_fp16_nc_sycl = get_to_fp16_nc_sycl(src1->type); - GGML_ASSERT(to_fp16_nc_sycl != nullptr); - const int64_t ne_src1 = ggml_nelements(src1); + + // iterate tensor dims and find the slowest moving dim and stride + int64_t last_dim=0; + int64_t last_str=0; + int64_t largest_str=0; + for(int i = 0; i< 4; i++){ + // last stride is always the largest + if(src1->nb[i] == largest_str){ + if(src1->ne[last_dim] == 1){ + last_str = i; + last_dim = i; + } + } + if(src1->nb[i] > largest_str){ + largest_str = src1->nb[i]; + last_str = i; + last_dim = i; + } + + } + const int64_t ne_src1 = src1->nb[last_str] * src1->ne[last_dim] / type_size_src1; src1_f16_alloc.alloc(ne_src1); - to_fp16_nc_sycl(src1_f16, src1_f16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, queue); + + const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type, dst); + GGML_ASSERT(to_fp16_sycl != nullptr); + to_fp16_sycl(src1_f16, src1_f16_alloc.get(), ne_src1, queue); src1_f16 = src1_f16_alloc.get(); s11 = ne10; @@ -2892,38 +2914,89 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons #if GGML_SYCL_DNNL if (!g_ggml_sycl_disable_dnn) { - auto dnn_gemm = [&ctx, queue, ne11, ne01, ne10, nb00, nb01, nb02, s11, s12] - (const sycl::half* src1, const sycl::half* src0, float* dst, const dnnl_dim_t batches_a, const dnnl_dim_t batches_b) { - - DnnlGemmWrapper::gemm(ctx, ne11,ne01, ne10, - src1, DnnlGemmWrapper::to_dt(), s11, 1, s12, - src0, DnnlGemmWrapper::to_dt(), 1, nb01/nb00, nb02/nb00, - dst, DnnlGemmWrapper::to_dt(), queue, batches_a, batches_b); - }; - - if (r2 == 1 && r3 == 1) { - if (ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) { - dnn_gemm(src1_f16, src0_f16, dst_ddf, ne12*ne13, ne02 * ne03); - } - else { - for (int64_t ie03 = 0; ie03 < ne03; ++ie03) { - const sycl::half* src0_f16_shifted = src0_f16 + ((ie03*nb03)/sizeof(sycl::half)); // nb is in bytes - const sycl::half* src1_f16_shifted = src1_f16 + ie03*s13; - float* dst_shifted = dst_ddf + ((ie03*nb3)/sizeof(float)); - dnn_gemm(src1_f16_shifted, src0_f16_shifted, dst_shifted, ne12, ne02); + int64_t str_a0 = nb00 / type_size_src0; + int64_t str_a1 = nb01 / type_size_src0; + int64_t str_a2 = nb02 / type_size_src0; + + int64_t str_b0 = nb10 / type_size_src1; + int64_t str_b1 = nb11 / type_size_src1; + int64_t str_b2 = nb12 / type_size_src1; + + auto launch_gemm_for_batches = 
[&ctx, queue](const sycl::half *src0, + const sycl::half *src1, float *dst, + int64_t a0, int64_t a1, int64_t batcha, + int64_t b0, int64_t b1, int64_t batchb, + int64_t sa0, int64_t sa1, int64_t sa2, + int64_t sb0, int64_t sb1, int64_t sb2, + int64_t sd2) { + bool supported_broadcast = batchb == batcha ? true + : batchb == 1 || batcha == 1 ? true + : false; + if (supported_broadcast) { + DnnlGemmWrapper::gemm(ctx, a1, b1, a0, src0, + DnnlGemmWrapper::to_dt(), sa0, sa1, sa2, src1, + DnnlGemmWrapper::to_dt(), sb0, sb1, sb2, dst, + DnnlGemmWrapper::to_dt(), queue, batcha, batchb); + } else { + // iterate over batches from smaller set of matrices (matrix 0) + int64_t batches0 = batcha; + int64_t batches1 = batchb; + + if (batches0 > batches1) { + int64_t num_mul_mats = batches1; + int64_t sub_batch = batches0 / num_mul_mats; + // src0 is batched and bigger, shift and multiply with src1 + for (int64_t i0 = 0; i0 < num_mul_mats; i0++) { + const sycl::half *src0_shifted = src0 + (sa2 * i0 * sub_batch); + const sycl::half *src1_shifted = src1 + (sb2 * i0); + float *dst_shifted = dst + (sd2 * i0 * sub_batch); + DnnlGemmWrapper::gemm(ctx, a1, b1, a0, src0_shifted, + DnnlGemmWrapper::to_dt(), sa0, sa1, sa2, + src1_shifted, DnnlGemmWrapper::to_dt(), sb0, + sb1, sb2, dst_shifted, DnnlGemmWrapper::to_dt(), + queue, sub_batch, 1); + } + } else { + int64_t num_mul_mats = batches0; + int64_t sub_batch = batches1 / num_mul_mats; + // src1 is batched and bigger, shift and multiply with src0 + for (int64_t i1 = 0; i1 < num_mul_mats; i1++) { + const sycl::half *src0_shifted = src0 + (sa2 * i1); + const sycl::half *src1_shifted = src1 + (sb2 * i1 * sub_batch); + float *dst_shifted = dst + (sd2 * i1 * sub_batch); + DnnlGemmWrapper::gemm(ctx, a1, b1, a0, src0_shifted, + DnnlGemmWrapper::to_dt(), sa0, sa1, sa2, + src1_shifted, DnnlGemmWrapper::to_dt(), sb0, + sb1, sb2, dst_shifted, DnnlGemmWrapper::to_dt(), + queue, 1, sub_batch); + } + } } - } - } else { - // iterate over batches from smaller set of matrices (matrix 0) - for (int64_t ie02 = 0; ie02 < ne02; ++ie02) { - for (int64_t ie03 = 0; ie03 < ne03; ++ie03) { - const sycl::half* src0_f16_shifted = src0_f16 + ((ie02*nb02 + ie03*nb03)/sizeof(sycl::half)); - const sycl::half* src1_f16_shifted = src1_f16 + ie02*s12*r2 + ie03*s13*r3; - float* dst_shifted = dst_ddf + ((ie02*nb2*r2 + ie03*nb3*r3)/sizeof(float)); - dnn_gemm(src1_f16_shifted, src0_f16_shifted, dst_shifted, r2*r3, 1); + }; + + bool cont_batches_a = nb02 * ne02 == nb03; + bool cont_batches_b = nb12 * ne12 == nb13; + if (cont_batches_a && cont_batches_b) { + int64_t batches0 = ne02 * ne03; + int64_t batches1 = ne12 * ne13; + launch_gemm_for_batches(src0_f16, src1_f16, dst_ddf, ne00, ne01, batches0, + ne10, ne11, batches1, str_a0, str_a1, str_a2, str_b0, str_b1, + str_b2, nb2 / sizeof(float)); + } else { + for (int64_t b_a = 0; b_a < ne03; b_a++) { + const sycl::half *src0_f16_shifted + = src0_f16 + (nb03 * b_a / type_size_src0); + const sycl::half *src1_f16_shifted + = src1_f16 + (nb13 * b_a / type_size_src1); + float *dst_shifted = dst_ddf + (nb3 * b_a / sizeof(float)); + int64_t batches0 = ne02; + int64_t batches1 = ne12; + launch_gemm_for_batches(src0_f16_shifted, src1_f16_shifted, dst_shifted, + ne00, ne01, batches0, ne10, ne11, batches1, str_a0, str_a1, + str_a2, str_b0, str_b1, str_b2, nb2 / sizeof(float)); } } - } + } else #endif @@ -3263,10 +3336,10 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor // The kernel from the if path is faster for that specific case, 
but does not support all mul mats. ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst); } - } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && !ggml_is_transposed(src1) && src1->ne[1] == 1) { + } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) { // KQV single-batch ggml_sycl_mul_mat_vec_nc(ctx, src0, src1, dst); - } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) { + } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2] * src1->ne[3] > 1) { // KQ + KQV multi-batch ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst); } else if (use_dequantize_mul_mat_vec) { From e5d8efc1d93a362150534a8b698bd80b3dab28e5 Mon Sep 17 00:00:00 2001 From: Akarshan Biswas Date: Mon, 14 Jul 2025 15:07:55 +0530 Subject: [PATCH 07/20] SYCL: use 1D kernel for set_rows (llama/14618) * SYCL: Use 1D kernel for set_rows * Remove dangling comment * Refactor and use ceil_div --- ggml/src/ggml-sycl/set_rows.cpp | 86 ++++++++++++++++----------------- 1 file changed, 43 insertions(+), 43 deletions(-) diff --git a/ggml/src/ggml-sycl/set_rows.cpp b/ggml/src/ggml-sycl/set_rows.cpp index 4a76a63d354..3091fab3995 100644 --- a/ggml/src/ggml-sycl/set_rows.cpp +++ b/ggml/src/ggml-sycl/set_rows.cpp @@ -6,46 +6,49 @@ static constexpr bool is_arithmetic_v() { return std::is_arithmetic_v || std::is_same_v || std::is_same_v; } } + template static inline std::enable_if_t() && utils::is_arithmetic_v(), void> convert (const char* src, char* dst) { auto src_val = *reinterpret_cast(src); auto dst_val = sycl::vec(src_val).template convert()[0]; - *reinterpret_cast(dst) = dst_val;; + *reinterpret_cast(dst) = dst_val; } template static void k_set_rows( const char * __restrict__ src0, const int64_t * __restrict__ src1, char * __restrict__ dst, - const int64_t ne00, const int64_t ne01, const int64_t ne11, const int64_t ne12, + const int64_t ne00, const int64_t ne01, const int64_t ne02, + const int64_t ne11, const int64_t ne12, const size_t nb01, const size_t nb02, const size_t nb03, const size_t nb10, const size_t nb11, const size_t nb12, const size_t nb1, const size_t nb2, const size_t nb3, const size_t src_type_size, const size_t dst_type_size, - const sycl::nd_item<3> & item_ct1) { - - const int i03 = item_ct1.get_group(0); - const int i02 = item_ct1.get_group(1); - const int i01 = item_ct1.get_group(2) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1); // Row index + const int64_t total_elements, + const sycl::nd_item<1> & item_ct1) { - if (i01 >= ne01) { + const int64_t i = item_ct1.get_global_linear_id(); + if (i >= total_elements) { return; } - const int i12 = i03 % ne12; - const int i11 = i02 % ne11; - const int i10 = i01; + const int64_t i03 = i / (ne00 * ne01 * ne02); + const int64_t i02 = (i - i03 * ne00 * ne01 * ne02) / (ne00 * ne01); + const int64_t i01 = (i - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01) / ne00; + const int64_t i00 = i - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01 - i01 * ne00; + + const int64_t i12 = i03 % ne12; + const int64_t i11 = i02 % ne11; + const int64_t i10 = i01; const int64_t dst_row = *(const int64_t *)((const char *)src1 + calculate_offset<3>({nb10, nb11, nb12}, {i10, i11, i12})); const char * src0_row = src0 + calculate_offset<3>({nb01, nb02, nb03}, {i01, i02, i03}); - char * dst_row_ptr = dst + dst_row*nb1 
+ i02*nb2 + i03*nb3; + const char * src_elem = src0_row + i00 * src_type_size; + char * dst_row_ptr = dst + dst_row*nb1 + i02*nb2 + i03*nb3; + char * dst_elem = dst_row_ptr + i00 * dst_type_size; - for (int col = item_ct1.get_local_id(0); col < ne00; col += item_ct1.get_local_range(0)) { - const char * src_elem = src0_row + col * src_type_size; - char * dst_elem = dst_row_ptr + col * dst_type_size; - convert(src_elem, dst_elem); - } + convert(src_elem, dst_elem); } template @@ -58,32 +61,29 @@ static void set_rows_sycl( const size_t src_type_size, const size_t dst_type_size, queue_ptr stream) { - constexpr int max_threads_per_row = 64; // KEEPING 64 for now - const int threads_per_row = std::min((int)ne00, max_threads_per_row); - - constexpr int max_threads_per_block = 64; - const int rows_per_block = std::max(1, max_threads_per_block / threads_per_row); - - const sycl::range<3> block_size(1, rows_per_block, threads_per_row); - const sycl::range<3> grid_size(ne03, ne02, (ne01 + rows_per_block - 1) / rows_per_block); - - sycl_parallel_for( - stream, - sycl::nd_range<3>(grid_size * block_size, block_size), - [=](sycl::nd_item<3> item_ct1) { - k_set_rows( - src0_d, src1_d, dst_d, - ne00, ne01, ne11, ne12, - nb01, nb02, nb03, - nb10, nb11, nb12, - nb1, nb2, nb3, - src_type_size, dst_type_size, - item_ct1 - ); - } - ); -} + const int64_t total_elements = ne00 * ne01 * ne02 * ne03; + constexpr int block_size = 64; + const int64_t grid_size = ceil_div(total_elements, block_size); + + sycl_parallel_for( + stream, + sycl::nd_range<1>(grid_size * block_size, block_size), + [=](sycl::nd_item<1> item_ct1) { + k_set_rows( + src0_d, src1_d, dst_d, + ne00, ne01, ne02, + ne11, ne12, + nb01, nb02, nb03, + nb10, nb11, nb12, + nb1, nb2, nb3, + src_type_size, dst_type_size, + total_elements, + item_ct1 + ); + } + ); +} void ggml_sycl_op_set_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2); @@ -122,7 +122,7 @@ void ggml_sycl_op_set_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { nb1, nb2, nb3, sizeof(float), sizeof(sycl::half), stream - ); + ); break; default: GGML_ABORT("Unsupported tensor type!"); From 07b4522c02249119879454ba506e836d72005514 Mon Sep 17 00:00:00 2001 From: shalinib-ibm Date: Mon, 14 Jul 2025 18:46:42 +0530 Subject: [PATCH 08/20] ggml : refactor llamafile_sgemm PPC code (llama/14673) Remove un-necessary templates from class definition and packing functions Reduce deeply nested conditionals, if-else switching in mnapck function Replace repetitive code with inline functions in Packing functions 2 ~ 7% improvement in Q8 Model 15 ~ 50% improvement in Q4 Model Signed-off-by: Shalini Salomi Bodapati --- ggml/src/ggml-cpu/llamafile/sgemm.cpp | 1437 ++++++------------------- 1 file changed, 343 insertions(+), 1094 deletions(-) diff --git a/ggml/src/ggml-cpu/llamafile/sgemm.cpp b/ggml/src/ggml-cpu/llamafile/sgemm.cpp index ed61869a550..2be54c31b5f 100644 --- a/ggml/src/ggml-cpu/llamafile/sgemm.cpp +++ b/ggml/src/ggml-cpu/llamafile/sgemm.cpp @@ -1541,7 +1541,7 @@ class tinyBLAS_BF16_PPC { } else if constexpr(RM == 8 && RN == 4) { KERNEL_8x4(ii,jj); } else { - static_assert(false, "RN/RM values not supported"); + assert(false && "RN/RM values not supported"); } } @@ -1573,13 +1573,13 @@ class tinyBLAS_BF16_PPC { const int nth; }; -template +template class tinyBLAS_Q0_PPC { public: tinyBLAS_Q0_PPC(int64_t k, const TA *A, int64_t lda, - const TB *B, int64_t ldb, - TC *C, int64_t ldc, + const block_q8_0 *B, int64_t ldb, 
+ float *C, int64_t ldc, int ith, int nth) : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) { } @@ -1590,8 +1590,7 @@ class tinyBLAS_Q0_PPC { private: - template - inline void save_res(int ii, int jj, int idx, vector float* fin_res) { + inline void save_res(int ii, int jj, int idx, vector float* fin_res, int RM=4, int RN=4) { for (int I = 0; I < RM; I++) { for (int J = 0; J < RN; J++) { *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&fin_res[idx+I]+J); @@ -1611,29 +1610,67 @@ class tinyBLAS_Q0_PPC { fin_res[s_idx+i] = vec_madd(res[i], vs[s_idx+i], fin_res[s_idx+i]); } } - - template - void packNormalInt4(const TA* a, int64_t lda, int rows, int cols, VA* vec, std::array& comparray) { - int64_t i, j; - TA *aoffset = NULL; - VA *vecOffset = NULL; - TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL; - TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL; - VB c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0}; - VB c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0}; - VB t1, t2, t3, t4, t5, t6, t7, t8; + /* This function processes quantized data from block_q4_0 elements. + * First the we try to extract the two int4 values stored in single int8_t into two signed int8. + * And then we subtract each of the resultant element with 8, to convert signed int8 to unsigned int8. + * Also compute the rowsum which is required to compensate the above conversion. */ + inline void process_q4_elements(vector signed char (&c)[2], int* ca) { const vector signed char lowMask = vec_splats((signed char)0xF); const vector unsigned char v4 = vec_splats((unsigned char)0x4); const vector signed char v8 = vec_splats((signed char)0x8); - aoffset = const_cast(a); - vecOffset = vec; + vector signed int vsum = {0}; + vector signed int vsum2 = {0}; + c[0] = vec_and(c[1], lowMask); + c[1] = vec_sr(c[1], v4); + c[0] = vec_sub(c[0], v8); + c[1] = vec_sub(c[1], v8); + vsum = vec_sum4s(c[0], vsum); + vsum2 = vec_sum4s(c[1], vsum2); + vsum = vec_add(vsum, vsum2); + *(ca) = vsum[0] + vsum[1] + vsum[2] + vsum[3]; + } + + template + inline void vector_permute_store(V2 &s1, V2 &s2, V2 &s3, V2 &s4, V1 *vecOffset, bool flip) { vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23}; vector unsigned char swiz2 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31}; vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27}; vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31}; - vector signed int vsum = {0}; - vector signed int vsum2 = {0}; + V2 t1, t2, t3, t4, t5, t6, t7, t8; + vector unsigned char xor_vector; + uint8_t flip_vec = 0x80; + xor_vector = vec_splats(flip_vec); + t1 = vec_perm(s1, s2, swiz1); + t2 = vec_perm(s1, s2, swiz2); + t3 = vec_perm(s3, s4, swiz1); + t4 = vec_perm(s3, s4, swiz2); + t5 = vec_perm(t1, t3, swiz3); + t6 = vec_perm(t1, t3, swiz4); + t7 = vec_perm(t2, t4, swiz3); + t8 = vec_perm(t2, t4, swiz4); + if (flip == true) { + t5 = vec_xor(t5, xor_vector); + t6 = vec_xor(t6, xor_vector); + t7 = vec_xor(t7, xor_vector); + t8 = vec_xor(t8, xor_vector); + } + vec_xst(t5, 0, vecOffset); + vec_xst(t6, 0, vecOffset+16); + vec_xst(t7, 0, vecOffset+32); + vec_xst(t8, 0, vecOffset+48); + } + template + void packNormalInt4(const TA* a, int64_t lda, int rows, int cols, int8_t* vec, std::array& comparray) { + int64_t i, j; + TA *aoffset = NULL; + int8_t *vecOffset = NULL; + TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL; + TA 
*aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL; + vector signed char c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0}; + vector signed char c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0}; + aoffset = const_cast(a); + vecOffset = vec; j = (rows >> 3); if (j > 0) { do { @@ -1646,159 +1683,30 @@ class tinyBLAS_Q0_PPC { aoffset7 = aoffset6 + lda; aoffset8 = aoffset7 + lda; aoffset += 8 * lda; - i = (cols >> 2); if (i > 0) { do { - c1[1] = reinterpret_cast(vec_xl(0, aoffset1->qs)); - c2[1] = reinterpret_cast(vec_xl(0, aoffset2->qs)); - c3[1] = reinterpret_cast(vec_xl(0, aoffset3->qs)); - c4[1] = reinterpret_cast(vec_xl(0, aoffset4->qs)); - c5[1] = reinterpret_cast(vec_xl(0, aoffset5->qs)); - c6[1] = reinterpret_cast(vec_xl(0, aoffset6->qs)); - c7[1] = reinterpret_cast(vec_xl(0, aoffset7->qs)); - c8[1] = reinterpret_cast(vec_xl(0, aoffset8->qs)); - - c1[0] = vec_and(c1[1], lowMask); - c1[1] = vec_sr(c1[1], v4); - c1[0] = vec_sub(c1[0], v8); - c1[1] = vec_sub(c1[1], v8); - vsum = vec_sum4s(c1[0], vsum); - vsum2 = vec_sum4s(c1[1], vsum2); - vsum = vec_add(vsum, vsum2); - comparray[0] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; - vsum = vec_splats(0); - vsum2 = vec_splats(0); - - c2[0] = vec_and(c2[1], lowMask); - c2[1] = vec_sr(c2[1], v4); - c2[0] = vec_sub(c2[0], v8); - c2[1] = vec_sub(c2[1], v8); - vsum = vec_sum4s(c2[0], vsum); - vsum2 = vec_sum4s(c2[1], vsum2); - vsum = vec_add(vsum, vsum2); - comparray[1] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; - vsum = vec_splats(0); - vsum2 = vec_splats(0); - - c3[0] = vec_and(c3[1], lowMask); - c3[1] = vec_sr(c3[1], v4); - c3[0] = vec_sub(c3[0], v8); - c3[1] = vec_sub(c3[1], v8); - vsum = vec_sum4s(c3[0], vsum); - vsum2 = vec_sum4s(c3[1], vsum2); - vsum = vec_add(vsum, vsum2); - comparray[2] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; - vsum = vec_splats(0); - vsum2 = vec_splats(0); - - c4[0] = vec_and(c4[1], lowMask); - c4[1] = vec_sr(c4[1], v4); - c4[0] = vec_sub(c4[0], v8); - c4[1] = vec_sub(c4[1], v8); - vsum = vec_sum4s(c4[0], vsum); - vsum2 = vec_sum4s(c4[1], vsum2); - vsum = vec_add(vsum, vsum2); - comparray[3] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; - vsum = vec_splats(0); - vsum2 = vec_splats(0); - - c5[0] = vec_and(c5[1], lowMask); - c5[1] = vec_sr(c5[1], v4); - c5[0] = vec_sub(c5[0], v8); - c5[1] = vec_sub(c5[1], v8); - vsum = vec_sum4s(c5[0], vsum); - vsum2 = vec_sum4s(c5[1], vsum2); - vsum = vec_add(vsum, vsum2); - comparray[4] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; - vsum = vec_splats(0); - vsum2 = vec_splats(0); - - c6[0] = vec_and(c6[1], lowMask); - c6[1] = vec_sr(c6[1], v4); - c6[0] = vec_sub(c6[0], v8); - c6[1] = vec_sub(c6[1], v8); - vsum = vec_sum4s(c6[0], vsum); - vsum2 = vec_sum4s(c6[1], vsum2); - vsum = vec_add(vsum, vsum2); - comparray[5] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; - vsum = vec_splats(0); - vsum2 = vec_splats(0); - - c7[0] = vec_and(c7[1], lowMask); - c7[1] = vec_sr(c7[1], v4); - c7[0] = vec_sub(c7[0], v8); - c7[1] = vec_sub(c7[1], v8); - vsum = vec_sum4s(c7[0], vsum); - vsum2 = vec_sum4s(c7[1], vsum2); - vsum = vec_add(vsum, vsum2); - comparray[6] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; - vsum = vec_splats(0); - vsum2 = vec_splats(0); - - c8[0] = vec_and(c8[1], lowMask); - c8[1] = vec_sr(c8[1], v4); - c8[0] = vec_sub(c8[0], v8); - c8[1] = vec_sub(c8[1], v8); - vsum = vec_sum4s(c8[0], vsum); - vsum2 = vec_sum4s(c8[1], vsum2); - vsum = vec_add(vsum, vsum2); - comparray[7] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; - vsum = vec_splats(0); - vsum2 = vec_splats(0); - - t1 = vec_perm(c1[0], c2[0], 
swiz1); - t2 = vec_perm(c1[0], c2[0], swiz2); - t3 = vec_perm(c3[0], c4[0], swiz1); - t4 = vec_perm(c3[0], c4[0], swiz2); - t5 = vec_perm(t1, t3, swiz3); - t6 = vec_perm(t1, t3, swiz4); - t7 = vec_perm(t2, t4, swiz3); - t8 = vec_perm(t2, t4, swiz4); - vec_xst(t5, 0, vecOffset); - vec_xst(t6, 0, vecOffset+16); - vec_xst(t7, 0, vecOffset+32); - vec_xst(t8, 0, vecOffset+48); - - t1 = vec_perm(c1[1], c2[1], swiz1); - t2 = vec_perm(c1[1], c2[1], swiz2); - t3 = vec_perm(c3[1], c4[1], swiz1); - t4 = vec_perm(c3[1], c4[1], swiz2); - t5 = vec_perm(t1, t3, swiz3); - t6 = vec_perm(t1, t3, swiz4); - t7 = vec_perm(t2, t4, swiz3); - t8 = vec_perm(t2, t4, swiz4); - vec_xst(t5, 0, vecOffset+64); - vec_xst(t6, 0, vecOffset+80); - vec_xst(t7, 0, vecOffset+96); - vec_xst(t8, 0, vecOffset+112); - - t1 = vec_perm(c5[0], c6[0], swiz1); - t2 = vec_perm(c5[0], c6[0], swiz2); - t3 = vec_perm(c7[0], c8[0], swiz1); - t4 = vec_perm(c7[0], c8[0], swiz2); - t5 = vec_perm(t1, t3, swiz3); - t6 = vec_perm(t1, t3, swiz4); - t7 = vec_perm(t2, t4, swiz3); - t8 = vec_perm(t2, t4, swiz4); - vec_xst(t5, 0, vecOffset+128); - vec_xst(t6, 0, vecOffset+144); - vec_xst(t7, 0, vecOffset+160); - vec_xst(t8, 0, vecOffset+176); - - t1 = vec_perm(c5[1], c6[1], swiz1); - t2 = vec_perm(c5[1], c6[1], swiz2); - t3 = vec_perm(c7[1], c8[1], swiz1); - t4 = vec_perm(c7[1], c8[1], swiz2); - t5 = vec_perm(t1, t3, swiz3); - t6 = vec_perm(t1, t3, swiz4); - t7 = vec_perm(t2, t4, swiz3); - t8 = vec_perm(t2, t4, swiz4); - vec_xst(t5, 0, vecOffset+192); - vec_xst(t6, 0, vecOffset+208); - vec_xst(t7, 0, vecOffset+224); - vec_xst(t8, 0, vecOffset+240); - + c1[1] = reinterpret_cast(vec_xl(0, aoffset1->qs)); + c2[1] = reinterpret_cast(vec_xl(0, aoffset2->qs)); + c3[1] = reinterpret_cast(vec_xl(0, aoffset3->qs)); + c4[1] = reinterpret_cast(vec_xl(0, aoffset4->qs)); + c5[1] = reinterpret_cast(vec_xl(0, aoffset5->qs)); + c6[1] = reinterpret_cast(vec_xl(0, aoffset6->qs)); + c7[1] = reinterpret_cast(vec_xl(0, aoffset7->qs)); + c8[1] = reinterpret_cast(vec_xl(0, aoffset8->qs)); + + process_q4_elements(c1, &comparray[0]); + process_q4_elements(c2, &comparray[1]); + process_q4_elements(c3, &comparray[2]); + process_q4_elements(c4, &comparray[3]); + process_q4_elements(c5, &comparray[4]); + process_q4_elements(c6, &comparray[5]); + process_q4_elements(c7, &comparray[6]); + process_q4_elements(c8, &comparray[7]); + vector_permute_store(c1[0], c2[0], c3[0], c4[0], vecOffset, false); + vector_permute_store(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false); + vector_permute_store(c5[0], c6[0], c7[0], c8[0], vecOffset+128, false); + vector_permute_store(c5[1], c6[1], c7[1], c8[1], vecOffset+192, false); aoffset1 += lda; aoffset2 += lda; aoffset3 += lda; @@ -1821,85 +1729,20 @@ class tinyBLAS_Q0_PPC { aoffset3 = aoffset2 + lda; aoffset4 = aoffset3 + lda; aoffset += 4 * lda; - i = (cols >> 2); if (i > 0) { do { - c1[1] = reinterpret_cast(vec_xl(0, aoffset1->qs)); - c2[1] = reinterpret_cast(vec_xl(0, aoffset2->qs)); - c3[1] = reinterpret_cast(vec_xl(0, aoffset3->qs)); - c4[1] = reinterpret_cast(vec_xl(0, aoffset4->qs)); - - c1[0] = vec_and(c1[1], lowMask); - c1[1] = vec_sr(c1[1], v4); - c1[0] = vec_sub(c1[0], v8); - c1[1] = vec_sub(c1[1], v8); - vsum = vec_sum4s(c1[0], vsum); - vsum2 = vec_sum4s(c1[1], vsum2); - vsum = vec_add(vsum, vsum2); - comparray[0] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; - vsum = vec_splats(0); - vsum2 = vec_splats(0); - - c2[0] = vec_and(c2[1], lowMask); - c2[1] = vec_sr(c2[1], v4); - c2[0] = vec_sub(c2[0], v8); - c2[1] = vec_sub(c2[1], v8); - vsum 
= vec_sum4s(c2[0], vsum); - vsum2 = vec_sum4s(c2[1], vsum2); - vsum = vec_add(vsum, vsum2); - comparray[1] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; - vsum = vec_splats(0); - vsum2 = vec_splats(0); - - c3[0] = vec_and(c3[1], lowMask); - c3[1] = vec_sr(c3[1], v4); - c3[0] = vec_sub(c3[0], v8); - c3[1] = vec_sub(c3[1], v8); - vsum = vec_sum4s(c3[0], vsum); - vsum2 = vec_sum4s(c3[1], vsum2); - vsum = vec_add(vsum, vsum2); - comparray[2] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; - vsum = vec_splats(0); - vsum2 = vec_splats(0); - - c4[0] = vec_and(c4[1], lowMask); - c4[1] = vec_sr(c4[1], v4); - c4[0] = vec_sub(c4[0], v8); - c4[1] = vec_sub(c4[1], v8); - vsum = vec_sum4s(c4[0], vsum); - vsum2 = vec_sum4s(c4[1], vsum2); - vsum = vec_add(vsum, vsum2); - comparray[3] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; - vsum = vec_splats(0); - vsum2 = vec_splats( 0); - - t1 = vec_perm(c1[0], c2[0], swiz1); - t2 = vec_perm(c1[0], c2[0], swiz2); - t3 = vec_perm(c3[0], c4[0], swiz1); - t4 = vec_perm(c3[0], c4[0], swiz2); - t5 = vec_perm(t1, t3, swiz3); - t6 = vec_perm(t1, t3, swiz4); - t7 = vec_perm(t2, t4, swiz3); - t8 = vec_perm(t2, t4, swiz4); - vec_xst(t5, 0, vecOffset); - vec_xst(t6, 0, vecOffset+16); - vec_xst(t7, 0, vecOffset+32); - vec_xst(t8, 0, vecOffset+48); - - t1 = vec_perm(c1[1], c2[1], swiz1); - t2 = vec_perm(c1[1], c2[1], swiz2); - t3 = vec_perm(c3[1], c4[1], swiz1); - t4 = vec_perm(c3[1], c4[1], swiz2); - t5 = vec_perm(t1, t3, swiz3); - t6 = vec_perm(t1, t3, swiz4); - t7 = vec_perm(t2, t4, swiz3); - t8 = vec_perm(t2, t4, swiz4); - vec_xst(t5, 0, vecOffset+64); - vec_xst(t6, 0, vecOffset+80); - vec_xst(t7, 0, vecOffset+96); - vec_xst(t8, 0, vecOffset+112); - + c1[1] = reinterpret_cast(vec_xl(0, aoffset1->qs)); + c2[1] = reinterpret_cast(vec_xl(0, aoffset2->qs)); + c3[1] = reinterpret_cast(vec_xl(0, aoffset3->qs)); + c4[1] = reinterpret_cast(vec_xl(0, aoffset4->qs)); + + process_q4_elements(c1, &comparray[0]); + process_q4_elements(c2, &comparray[1]); + process_q4_elements(c3, &comparray[2]); + process_q4_elements(c4, &comparray[3]); + vector_permute_store(c1[0], c2[0], c3[0], c4[0], vecOffset, false); + vector_permute_store(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false); aoffset1 += lda; aoffset2 += lda; aoffset3 += lda; @@ -1918,80 +1761,17 @@ class tinyBLAS_Q0_PPC { if (i > 0) { do { switch(rows) { - case 3: c3[1] = reinterpret_cast(vec_xl(0, aoffset3->qs)); - case 2: c2[1] = reinterpret_cast(vec_xl(0, aoffset2->qs)); - case 1: c1[1] = reinterpret_cast(vec_xl(0, aoffset1->qs)); + case 3: c3[1] = reinterpret_cast(vec_xl(0, aoffset3->qs)); + case 2: c2[1] = reinterpret_cast(vec_xl(0, aoffset2->qs)); + case 1: c1[1] = reinterpret_cast(vec_xl(0, aoffset1->qs)); break; } - c1[0] = vec_and(c1[1], lowMask); - c1[1] = vec_sr(c1[1], v4); - c1[0] = vec_sub(c1[0], v8); - c1[1] = vec_sub(c1[1], v8); - vsum = vec_sum4s(c1[0], vsum); - vsum2 = vec_sum4s(c1[1], vsum2); - vsum = vec_add(vsum, vsum2); - comparray[0] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; - vsum = vec_splats(0); - vsum2 = vec_splats(0); - - c2[0] = vec_and(c2[1], lowMask); - c2[1] = vec_sr(c2[1], v4); - c2[0] = vec_sub(c2[0], v8); - c2[1] = vec_sub(c2[1], v8); - vsum = vec_sum4s(c2[0], vsum); - vsum2 = vec_sum4s(c2[1], vsum2); - vsum = vec_add(vsum, vsum2); - comparray[1] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; - vsum = vec_splats(0); - vsum2 = vec_splats(0); - - c3[0] = vec_and(c3[1], lowMask); - c3[1] = vec_sr(c3[1], v4); - c3[0] = vec_sub(c3[0], v8); - c3[1] = vec_sub(c3[1], v8); - vsum = vec_sum4s(c3[0], vsum); - vsum2 = vec_sum4s(c3[1], 
vsum2); - vsum = vec_add(vsum, vsum2); - comparray[2] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; - vsum = vec_splats(0); - vsum2 = vec_splats(0); - - c4[0] = vec_and(c4[1], lowMask); - c4[1] = vec_sr(c4[1], v4); - c4[0] = vec_sub(c4[0], v8); - c4[1] = vec_sub(c4[1], v8); - vsum = vec_sum4s(c4[0], vsum); - vsum2 = vec_sum4s(c4[1], vsum2); - vsum = vec_add(vsum, vsum2); - comparray[3] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; - vsum = vec_splats(0); - vsum2 = vec_splats(0); - - t1 = vec_perm(c1[0], c2[0], swiz1); - t2 = vec_perm(c1[0], c2[0], swiz2); - t3 = vec_perm(c3[0], c4[0], swiz1); - t4 = vec_perm(c3[0], c4[0], swiz2); - t5 = vec_perm(t1, t3, swiz3); - t6 = vec_perm(t1, t3, swiz4); - t7 = vec_perm(t2, t4, swiz3); - t8 = vec_perm(t2, t4, swiz4); - vec_xst(t5, 0, vecOffset); - vec_xst(t6, 0, vecOffset+16); - vec_xst(t7, 0, vecOffset+32); - vec_xst(t8, 0, vecOffset+48); - - t1 = vec_perm(c1[1], c2[1], swiz1); - t2 = vec_perm(c1[1], c2[1], swiz2); - t3 = vec_perm(c3[1], c4[1], swiz1); - t4 = vec_perm(c3[1], c4[1], swiz2); - t5 = vec_perm(t1, t3, swiz3); - t6 = vec_perm(t1, t3, swiz4); - t7 = vec_perm(t2, t4, swiz3); - t8 = vec_perm(t2, t4, swiz4); - vec_xst(t5, 0, vecOffset+64); - vec_xst(t6, 0, vecOffset+80); - vec_xst(t7, 0, vecOffset+96); - vec_xst(t8, 0, vecOffset+112); + process_q4_elements(c1, &comparray[0]); + process_q4_elements(c2, &comparray[1]); + process_q4_elements(c3, &comparray[2]); + process_q4_elements(c4, &comparray[3]); + vector_permute_store(c1[0], c2[0], c3[0], c4[0], vecOffset, false); + vector_permute_store(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false); aoffset1 += lda; aoffset2 += lda; aoffset3 += lda; @@ -2001,146 +1781,40 @@ class tinyBLAS_Q0_PPC { } } } - template - void packNormal(const TB* a, int64_t lda, int rows, int cols, VA* vec, bool flip) { + void packNormal(const block_q8_0* a, int64_t lda, int rows, int cols, VA* vec, bool flip) { int64_t i, j; - TB *aoffset = NULL; + block_q8_0 *aoffset = NULL; VA *vecOffset = NULL; - TB *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL; - TB *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL; - __vector_pair C1, C2, C3, C4, C5, C6, C7, C8; - VB c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2]={0}; - VB c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2]={0}; - VB t1, t2, t3, t4, t5, t6, t7, t8; - vector unsigned char xor_vector; - uint8_t flip_vec = 0x80; - xor_vector = vec_splats(flip_vec); - vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23}; - vector unsigned char swiz2 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31}; - vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27}; - vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31}; - - aoffset = const_cast(a); + block_q8_0* aoffsets[8]; + __vector_pair arr[8]; + VB c[8][2] = {0}; + VB c1[8] = {0}; VB c2[8] = {0}; + aoffset = const_cast(a); vecOffset = vec; j = (rows >> 3); if (j > 0) { do { - aoffset1 = aoffset; - aoffset2 = aoffset1 + lda; - aoffset3 = aoffset2 + lda; - aoffset4 = aoffset3 + lda; - aoffset5 = aoffset4 + lda; - aoffset6 = aoffset5 + lda; - aoffset7 = aoffset6 + lda; - aoffset8 = aoffset7 + lda; + aoffsets[0] = aoffset; + for (int it = 1; it < 8; it++) + aoffsets[it] = aoffsets[it-1] + lda; aoffset += 8 * lda; i = (cols >> 3); if (i > 0) { do { - C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1->qs); - C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2->qs); - C3 = __builtin_vsx_lxvp(0, 
(__vector_pair*)aoffset3->qs); - C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4->qs); - C5 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset5->qs); - C6 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset6->qs); - C7 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset7->qs); - C8 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset8->qs); - - __builtin_vsx_disassemble_pair(c1, &C1); - __builtin_vsx_disassemble_pair(c2, &C2); - __builtin_vsx_disassemble_pair(c3, &C3); - __builtin_vsx_disassemble_pair(c4, &C4); - __builtin_vsx_disassemble_pair(c5, &C5); - __builtin_vsx_disassemble_pair(c6, &C6); - __builtin_vsx_disassemble_pair(c7, &C7); - __builtin_vsx_disassemble_pair(c8, &C8); - - t1 = vec_perm(c1[0], c2[0], swiz1); - t2 = vec_perm(c1[0], c2[0], swiz2); - t3 = vec_perm(c3[0], c4[0], swiz1); - t4 = vec_perm(c3[0], c4[0], swiz2); - t5 = vec_perm(t1, t3, swiz3); - t6 = vec_perm(t1, t3, swiz4); - t7 = vec_perm(t2, t4, swiz3); - t8 = vec_perm(t2, t4, swiz4); - if (flip == true) { - t5 = vec_xor(t5, xor_vector); - t6 = vec_xor(t6, xor_vector); - t7 = vec_xor(t7, xor_vector); - t8 = vec_xor(t8, xor_vector); - } - vec_xst(t5, 0, vecOffset); - vec_xst(t6, 0, vecOffset+16); - vec_xst(t7, 0, vecOffset+32); - vec_xst(t8, 0, vecOffset+48); - - t1 = vec_perm(c1[1], c2[1], swiz1); - t2 = vec_perm(c1[1], c2[1], swiz2); - t3 = vec_perm(c3[1], c4[1], swiz1); - t4 = vec_perm(c3[1], c4[1], swiz2); - t5 = vec_perm(t1, t3, swiz3); - t6 = vec_perm(t1, t3, swiz4); - t7 = vec_perm(t2, t4, swiz3); - t8 = vec_perm(t2, t4, swiz4); - if (flip == true) { - t5 = vec_xor(t5, xor_vector); - t6 = vec_xor(t6, xor_vector); - t7 = vec_xor(t7, xor_vector); - t8 = vec_xor(t8, xor_vector); - } - vec_xst(t5, 0, vecOffset+64); - vec_xst(t6, 0, vecOffset+80); - vec_xst(t7, 0, vecOffset+96); - vec_xst(t8, 0, vecOffset+112); - - t1 = vec_perm(c5[0], c6[0], swiz1); - t2 = vec_perm(c5[0], c6[0], swiz2); - t3 = vec_perm(c7[0], c8[0], swiz1); - t4 = vec_perm(c7[0], c8[0], swiz2); - t5 = vec_perm(t1, t3, swiz3); - t6 = vec_perm(t1, t3, swiz4); - t7 = vec_perm(t2, t4, swiz3); - t8 = vec_perm(t2, t4, swiz4); - if (flip == true) { - t5 = vec_xor(t5, xor_vector); - t6 = vec_xor(t6, xor_vector); - t7 = vec_xor(t7, xor_vector); - t8 = vec_xor(t8, xor_vector); - } - vec_xst(t5, 0, vecOffset+128); - vec_xst(t6, 0, vecOffset+144); - vec_xst(t7, 0, vecOffset+160); - vec_xst(t8, 0, vecOffset+176); - - t1 = vec_perm(c5[1], c6[1], swiz1); - t2 = vec_perm(c5[1], c6[1], swiz2); - t3 = vec_perm(c7[1], c8[1], swiz1); - t4 = vec_perm(c7[1], c8[1], swiz2); - t5 = vec_perm(t1, t3, swiz3); - t6 = vec_perm(t1, t3, swiz4); - t7 = vec_perm(t2, t4, swiz3); - t8 = vec_perm(t2, t4, swiz4); - if (flip == true) { - t5 = vec_xor(t5, xor_vector); - t6 = vec_xor(t6, xor_vector); - t7 = vec_xor(t7, xor_vector); - t8 = vec_xor(t8, xor_vector); + for (int it = 0; it < 8; it++) { + arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]->qs); + __builtin_vsx_disassemble_pair(c[it], &arr[it]); + c1[it] = c[it][0]; + c2[it] = c[it][1]; } - vec_xst(t5, 0, vecOffset+192); - vec_xst(t6, 0, vecOffset+208); - vec_xst(t7, 0, vecOffset+224); - vec_xst(t8, 0, vecOffset+240); - - aoffset1 += lda; - aoffset2 += lda; - aoffset3 += lda; - aoffset4 += lda; - aoffset5 += lda; - aoffset6 += lda; - aoffset7 += lda; - aoffset8 += lda; + vector_permute_store(c1[0], c1[1], c1[2], c1[3], vecOffset, flip); + vector_permute_store(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip); + vector_permute_store(c1[4], c1[5], c1[6], c1[7], vecOffset+128, flip); + vector_permute_store(c2[4], c2[5], c2[6], c2[7], 
vecOffset+192, flip); + for (int it = 0; it < 8; it++) + aoffsets[it] += lda; vecOffset += 256; i--; } while(i > 0); @@ -2150,129 +1824,53 @@ class tinyBLAS_Q0_PPC { } if (rows & 4) { - aoffset1 = aoffset; - aoffset2 = aoffset1 + lda; - aoffset3 = aoffset2 + lda; - aoffset4 = aoffset3 + lda; - aoffset += 4 * lda; - + aoffsets[0] = aoffset; + for (int it = 1; it < 4; it++ ) + aoffsets[it] = aoffsets[it-1] + lda; + aoffset += 4 * lda; i = (cols >> 3); if (i > 0) { do { - C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1->qs); - C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2->qs); - C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3->qs); - C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4->qs); - - __builtin_vsx_disassemble_pair(c1, &C1); - __builtin_vsx_disassemble_pair(c2, &C2); - __builtin_vsx_disassemble_pair(c3, &C3); - __builtin_vsx_disassemble_pair(c4, &C4); - - t1 = vec_perm(c1[0], c2[0], swiz1); - t2 = vec_perm(c1[0], c2[0], swiz2); - t3 = vec_perm(c3[0], c4[0], swiz1); - t4 = vec_perm(c3[0], c4[0], swiz2); - t5 = vec_perm(t1, t3, swiz3); - t6 = vec_perm(t1, t3, swiz4); - t7 = vec_perm(t2, t4, swiz3); - t8 = vec_perm(t2, t4, swiz4); - if (flip == true) { - t5 = vec_xor(t5, xor_vector); - t6 = vec_xor(t6, xor_vector); - t7 = vec_xor(t7, xor_vector); - t8 = vec_xor(t8, xor_vector); + for (int it = 0; it < 4; it++) { + arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]->qs); + __builtin_vsx_disassemble_pair(c[it], &arr[it]); + c1[it] = c[it][0]; + c2[it] = c[it][1]; } - vec_xst(t5, 0, vecOffset); - vec_xst(t6, 0, vecOffset+16); - vec_xst(t7, 0, vecOffset+32); - vec_xst(t8, 0, vecOffset+48); - - t1 = vec_perm(c1[1], c2[1], swiz1); - t2 = vec_perm(c1[1], c2[1], swiz2); - t3 = vec_perm(c3[1], c4[1], swiz1); - t4 = vec_perm(c3[1], c4[1], swiz2); - t5 = vec_perm(t1, t3, swiz3); - t6 = vec_perm(t1, t3, swiz4); - t7 = vec_perm(t2, t4, swiz3); - t8 = vec_perm(t2, t4, swiz4); - if (flip == true) { - t5 = vec_xor(t5, xor_vector); - t6 = vec_xor(t6, xor_vector); - t7 = vec_xor(t7, xor_vector); - t8 = vec_xor(t8, xor_vector); + vector_permute_store(c1[0], c1[1], c1[2], c1[3], vecOffset, flip); + vector_permute_store(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip); + for (int it = 0; it < 4; it++) { + aoffsets[it] += lda; } - vec_xst(t5, 0, vecOffset+64); - vec_xst(t6, 0, vecOffset+80); - vec_xst(t7, 0, vecOffset+96); - vec_xst(t8, 0, vecOffset+112); - - aoffset1 += lda; - aoffset2 += lda; - aoffset3 += lda; - aoffset4 += lda; vecOffset += 128; i--; } while(i > 0); } } + if (rows & 3) { - aoffset1 = aoffset; - aoffset2 = aoffset1 + lda; - aoffset3 = aoffset2 + lda; + aoffsets[0] = aoffset; + for (int it = 1; it < 3; it++ ) + aoffsets[it] = aoffsets[it-1] + lda; i = (cols >> 3); if (i > 0) { do { switch(rows) { - case 3: C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3->qs); - __builtin_vsx_disassemble_pair(c3, &C3); - case 2: C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2->qs); - __builtin_vsx_disassemble_pair(c2, &C2); - case 1: C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1->qs); - __builtin_vsx_disassemble_pair(c1, &C1); + case 3: arr[2] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[2]->qs); + __builtin_vsx_disassemble_pair(c[2], &arr[2]); + c1[2] = c[2][0]; c2[2] = c[2][1]; + case 2: arr[1] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[1]->qs); + __builtin_vsx_disassemble_pair(c[1], &arr[1]); + c1[1] = c[1][0]; c2[1] = c[1][1]; + case 1: arr[0] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[0]->qs); + __builtin_vsx_disassemble_pair(c[0], &arr[0]); + 
c1[0] = c[0][0]; c2[0] = c[0][1]; break; } - t1 = vec_perm(c1[0], c2[0], swiz1); - t2 = vec_perm(c1[0], c2[0], swiz2); - t3 = vec_perm(c3[0], c4[0], swiz1); - t4 = vec_perm(c3[0], c4[0], swiz2); - t5 = vec_perm(t1, t3, swiz3); - t6 = vec_perm(t1, t3, swiz4); - t7 = vec_perm(t2, t4, swiz3); - t8 = vec_perm(t2, t4, swiz4); - if (flip == true) { - t5 = vec_xor(t5, xor_vector); - t6 = vec_xor(t6, xor_vector); - t7 = vec_xor(t7, xor_vector); - t8 = vec_xor(t8, xor_vector); - } - vec_xst(t5, 0, vecOffset); - vec_xst(t6, 0, vecOffset+16); - vec_xst(t7, 0, vecOffset+32); - vec_xst(t8, 0, vecOffset+48); - - t1 = vec_perm(c1[1], c2[1], swiz1); - t2 = vec_perm(c1[1], c2[1], swiz2); - t3 = vec_perm(c3[1], c4[1], swiz1); - t4 = vec_perm(c3[1], c4[1], swiz2); - t5 = vec_perm(t1, t3, swiz3); - t6 = vec_perm(t1, t3, swiz4); - t7 = vec_perm(t2, t4, swiz3); - t8 = vec_perm(t2, t4, swiz4); - if (flip == true) { - t5 = vec_xor(t5, xor_vector); - t6 = vec_xor(t6, xor_vector); - t7 = vec_xor(t7, xor_vector); - t8 = vec_xor(t8, xor_vector); - } - vec_xst(t5, 0, vecOffset+64); - vec_xst(t6, 0, vecOffset+80); - vec_xst(t7, 0, vecOffset+96); - vec_xst(t8, 0, vecOffset+112); - - aoffset1 += lda; - aoffset2 += lda; - aoffset3 += lda; + vector_permute_store(c1[0], c1[1], c1[2], c1[3], vecOffset, flip); + vector_permute_store(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip); + for (int it = 0; it < 3; it++) + aoffsets[it] += lda; vecOffset += 128; i--; } while(i > 0); @@ -2281,159 +1879,42 @@ class tinyBLAS_Q0_PPC { } void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) { - int64_t mc, nc, mp, np; - int m_rem = MIN(m - m0, 8); - int n_rem = MIN(n - n0, 8); - // TO-DO: KERNEL_16x8 and KERNEL_8x16 are having some performance - // issues. After resolving them, below code will be enabled. 
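The rewritten mnpack below collapses the removed per-(m_rem, n_rem) switch ladder into a computed tile size plus a runtime gemm_small fallback. A minimal sketch of the recursion it converges on, using the names from the diff (the worked numbers are illustrative, not taken from the patch):

    // After the chosen mc x nc kernel covers the aligned region, the two
    // leftover strips are handled by recursing:
    //   int64_t mp = m0 + ((m - m0) / mc) * mc;  // first row past the tiles
    //   int64_t np = n0 + ((n - n0) / nc) * nc;  // first col past the tiles
    //   mnpack(mp, m, n0, np);                   // bottom strip
    //   mnpack(m0, m, np, n);                    // right strip + corner
    // Example: m = 13, n = 10 runs the 8x8 kernel over rows 0..7 x cols 0..7,
    // then recurses as mnpack(8, 13, 0, 8) for the 5 leftover rows and
    // mnpack(0, 13, 8, 10) for the 2 leftover cols, so the smaller kernels
    // and gemm_small only ever see sub-tile remainders.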
- /*if (m_rem >= 16 && n_rem >= 8) { - mc = 16; - nc = 8; - gemm<16,8>(m0, m, n0, n); - } else if(m_rem >= 8 && n_rem >= 16) { - mc = 8; - nc = 16; - gemm<8,16>(m0, m, n0, n); - }*/ + int m_rem = MIN(m - m0, 16); + int n_rem = MIN(n - n0, 16); + + int mc = 0, nc = 0; + if (m_rem >= 8 && n_rem >= 8) { - mc = 8; - nc = 8; - gemm<8,8>(m0, m, n0, n); + mc = 8; + nc = 8; + gemm<8, 8>(m0, m, n0, n); } else if (m_rem >= 4 && n_rem >= 8) { mc = 4; nc = 8; - gemm<4,8>(m0, m, n0, n); + gemm<4, 8>(m0, m, n0, n); } else if (m_rem >= 8 && n_rem >= 4) { mc = 8; nc = 4; - gemm<8,4>(m0, m, n0, n); + gemm<8, 4>(m0, m, n0, n); } else if (m_rem >= 4 && n_rem >= 4) { mc = 4; nc = 4; - gemm_small<4, 4>(m0, m, n0, n); - } else if ((m_rem < 4) && (n_rem > 4)) { - nc = 4; - switch(m_rem) { - case 1: - mc = 1; - gemm_small<1, 4>(m0, m, n0, n); - break; - case 2: - mc = 2; - gemm_small<2, 4>(m0, m, n0, n); - break; - case 3: - mc = 3; - gemm_small<3, 4>(m0, m, n0, n); - break; - default: - return; - } - } else if ((m_rem > 4) && (n_rem < 4)) { - mc = 4; - switch(n_rem) { - case 1: - nc = 1; - gemm_small<4, 1>(m0, m, n0, n); - break; - case 2: - nc = 2; - gemm_small<4, 2>(m0, m, n0, n); - break; - case 3: - nc = 3; - gemm_small<4, 3>(m0, m, n0, n); - break; - default: - return; - } + gemm_small(m0, m, n0, n, mc, nc); } else { - switch((m_rem << 4) | n_rem) { - case 0x43: - mc = 4; - nc = 3; - gemm_small<4, 3>(m0, m, n0, n); - break; - case 0x42: - mc = 4; - nc = 2; - gemm_small<4, 2>(m0, m, n0, n); - break; - case 0x41: - mc = 4; - nc = 1; - gemm_small<4, 1>(m0, m, n0, n); - break; - case 0x34: - mc = 3; - nc = 4; - gemm_small<3, 4>(m0, m, n0, n); - break; - case 0x33: - mc = 3; - nc = 3; - gemm_small<3, 3>(m0, m, n0, n); - break; - case 0x32: - mc = 3; - nc = 2; - gemm_small<3, 2>(m0, m, n0, n); - break; - case 0x31: - mc = 3; - nc = 1; - gemm_small<3, 1>(m0, m, n0, n); - break; - case 0x24: - mc = 2; - nc = 4; - gemm_small<2, 4>(m0, m, n0, n); - break; - case 0x23: - mc = 2; - nc = 3; - gemm_small<2, 3>(m0, m, n0, n); - break; - case 0x22: - mc = 2; - nc = 2; - gemm_small<2, 2>(m0, m, n0, n); - break; - case 0x21: - mc = 2; - nc = 1; - gemm_small<2, 1>(m0, m, n0, n); - break; - case 0x14: - mc = 1; - nc = 4; - gemm_small<1, 4>(m0, m, n0, n); - break; - case 0x13: - mc = 1; - nc = 3; - gemm_small<1, 3>(m0, m, n0, n); - break; - case 0x12: - mc = 1; - nc = 2; - gemm_small<1, 2>(m0, m, n0, n); - break; - case 0x11: - mc = 1; - nc = 1; - gemm_small<1, 1>(m0, m, n0, n); - break; - default: - return; - } + mc = (m_rem >= 4) ? 4 : m_rem; + nc = (n_rem >= 4) ? 
4 : n_rem; + if (mc == 0 || nc == 0) + return; + gemm_small(m0, m, n0, n, mc, nc); } - mp = m0 + (m - m0) / mc * mc; - np = n0 + (n - n0) / nc * nc; + + int64_t mp = m0 + ((m - m0) / mc) * mc; + int64_t np = n0 + ((n - n0) / nc) * nc; mnpack(mp, m, n0, np); mnpack(m0, m, np, n); } + void KERNEL_4x8(int64_t ii, int64_t jj) { vec_t vec_A[8], vec_B[16] = {0}; acc_t acc_0, acc_1; @@ -2445,9 +1926,9 @@ class tinyBLAS_Q0_PPC { __builtin_mma_xxsetaccz(&acc_0); __builtin_mma_xxsetaccz(&acc_1); if (std::is_same_v) { - packNormalInt4((A+(ii*lda)+l), lda, 4, 4, (int8_t*)vec_A, comparray); + packNormalInt4<4>((A+(ii*lda)+l), lda, 4, 4, (int8_t*)vec_A, comparray); } else { - packNormal((const TB*)(A+(ii*lda)+l), lda, 4, 8, (int8_t*)vec_A, false); + packNormal((const block_q8_0*)(A+(ii*lda)+l), lda, 4, 8, (int8_t*)vec_A, false); } packNormal((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true); for(int x = 0; x < 8; x++) { @@ -2475,8 +1956,8 @@ class tinyBLAS_Q0_PPC { compute<4>(&acc_0, 0, 0, comparray, vs, fin_res); compute<4>(&acc_1, 0, 4, comparray, vs, fin_res); } - save_res<4, 4>(ii, jj, 0, fin_res); - save_res<4, 4>(ii, jj+4, 4, fin_res); + save_res(ii, jj, 0, fin_res); + save_res(ii, jj+4, 4, fin_res); } void KERNEL_8x4(int64_t ii, int64_t jj) { @@ -2490,9 +1971,9 @@ class tinyBLAS_Q0_PPC { __builtin_mma_xxsetaccz(&acc_0); __builtin_mma_xxsetaccz(&acc_1); if (std::is_same_v) { - packNormalInt4((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray); + packNormalInt4<8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray); } else { - packNormal((const TB*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false); + packNormal((const block_q8_0*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false); } packNormal((B+(jj*ldb)+l), ldb, 4, 8, (uint8_t*)vec_B, true); for(int x = 0; x < 8; x++) { @@ -2519,8 +2000,8 @@ class tinyBLAS_Q0_PPC { compute<8>(&acc_0, 0, 0, comparray, vs, fin_res); compute<8>(&acc_1, 4, 4, comparray, vs, fin_res); } - save_res<4, 4>(ii, jj, 0, fin_res); - save_res<4, 4>(ii+4, jj, 4, fin_res); + save_res(ii, jj, 0, fin_res); + save_res(ii+4, jj, 4, fin_res); } void KERNEL_8x8(int64_t ii, int64_t jj) { @@ -2536,9 +2017,9 @@ class tinyBLAS_Q0_PPC { __builtin_mma_xxsetaccz(&acc_2); __builtin_mma_xxsetaccz(&acc_3); if (std::is_same_v) { - packNormalInt4((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray); + packNormalInt4<8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray); } else { - packNormal((const TB*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false); + packNormal((const block_q8_0*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false); } packNormal((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true); for(int x = 0; x < 8; x++) { @@ -2570,14 +2051,13 @@ class tinyBLAS_Q0_PPC { compute<8>(&acc_2, 0, 8, comparray, vs, fin_res); compute<8>(&acc_3, 4, 12, comparray, vs, fin_res); } - save_res<4, 4>(ii, jj, 0, fin_res); - save_res<4, 4>(ii+4, jj, 4, fin_res); - save_res<4, 4>(ii, jj+4, 8, fin_res); - save_res<4, 4>(ii+4, jj+4, 12, fin_res); + save_res(ii, jj, 0, fin_res); + save_res(ii+4, jj, 4, fin_res); + save_res(ii, jj+4, 8, fin_res); + save_res(ii+4, jj+4, 12, fin_res); } - template - void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n) { + void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) { int64_t ytiles = (m - m0) / RM; int64_t xtiles = (n - n0) / RN; int64_t tiles = xtiles * ytiles; @@ -2606,9 +2086,9 @@ class tinyBLAS_Q0_PPC { __builtin_prefetch((B+(jj*ldb)+(l+1))->qs, 0, 1); // prefetch one loop ahead __builtin_mma_xxsetaccz(&acc_0); if (isAblock_q4) 
{ - packNormalInt4((A+(ii*lda)+l), lda, RM, 4, (int8_t*)vec_A, comparray); + packNormalInt4<4>((A+(ii*lda)+l), lda, RM, 4, (int8_t*)vec_A, comparray); } else { - packNormal((const TB*)(A+(ii*lda)+l), lda, RM, 8, (int8_t*)vec_A, false); + packNormal((const block_q8_0*)(A+(ii*lda)+l), lda, RM, 8, (int8_t*)vec_A, false); } packNormal((B+(jj*ldb)+l), ldb, RN, 8, (uint8_t*)vec_B, true); for(int x = 0; x < 8; x+=4) { @@ -2641,7 +2121,7 @@ class tinyBLAS_Q0_PPC { fin_res[i] = vec_madd(res[i], vs[i], fin_res[i]); } } - save_res(ii, jj, 0, fin_res); + save_res(ii, jj, 0, fin_res, RM, RN); } } @@ -2654,7 +2134,7 @@ class tinyBLAS_Q0_PPC { } else if constexpr(RM == 8 && RN == 8) { KERNEL_8x8(ii,jj); } else { - static_assert(false, "RN/RM values not supported"); + assert(false && "RN/RM values not supported"); } } @@ -2676,10 +2156,8 @@ class tinyBLAS_Q0_PPC { } const TA *const A; - const TB *const B; - TC *C; - TA *At; - TB *Bt; + const block_q8_0 *const B; + float *C; const int64_t k; const int64_t lda; const int64_t ldb; @@ -2688,13 +2166,12 @@ class tinyBLAS_Q0_PPC { const int nth; }; -template class tinyBLAS_PPC { public: tinyBLAS_PPC(int64_t k, - const TA *A, int64_t lda, - const TB *B, int64_t ldb, - TC *C, int64_t ldc, + const float *A, int64_t lda, + const float *B, int64_t ldb, + float *C, int64_t ldc, int ith, int nth) : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) { } @@ -2707,247 +2184,139 @@ class tinyBLAS_PPC { void (tinyBLAS_PPC::*kernel)(int64_t, int64_t); - template - void packTranspose(const TA* a, int64_t lda, int rows, int cols, TA* vec) { + inline void vector_permute_store_4(vector float *src, float *vecOffset) { + vector float t1, t2, t3, t4, t5, t6, t7, t8; + t1 = vec_mergeh(src[0], src[1]); + t2 = vec_mergeh(src[2], src[3]); + t3 = vec_mergel(src[0], src[1]); + t4 = vec_mergel(src[2], src[3]); + + t5 = vec_xxpermdi(t1, t2, 0); + t6 = vec_xxpermdi(t1, t2, 3); + t7 = vec_xxpermdi(t3, t4, 0); + t8 = vec_xxpermdi(t3, t4, 3); + + vec_xst(t5, 0, vecOffset); + vec_xst(t6, 0, vecOffset + 4); + vec_xst(t7, 0, vecOffset + 8); + vec_xst(t8, 0, vecOffset + 12); + } + + inline void vector_permute_store_8(vector float *src, float *vecOffset) { + vector float t1, t2, t3, t4, t5, t6, t7, t8; + t1 = vec_mergeh(src[0], src[1]); + t2 = vec_mergeh(src[2], src[3]); + t3 = vec_mergeh(src[4], src[5]); + t4 = vec_mergeh(src[6], src[7]); + + t5 = vec_xxpermdi(t1, t2, 0); + t6 = vec_xxpermdi(t3, t4, 0); + t7 = vec_xxpermdi(t1, t2, 3); + t8 = vec_xxpermdi(t3, t4, 3); + + vec_xst(t5, 0, vecOffset); + vec_xst(t6, 0, vecOffset + 4); + vec_xst(t7, 0, vecOffset + 8); + vec_xst(t8, 0, vecOffset + 12); + + t1 = vec_mergel(src[0], src[1]); + t2 = vec_mergel(src[2], src[3]); + t3 = vec_mergel(src[4], src[5]); + t4 = vec_mergel(src[6], src[7]); + + t5 = vec_xxpermdi(t1, t2, 0); + t6 = vec_xxpermdi(t3, t4, 0); + t7 = vec_xxpermdi(t1, t2, 3); + t8 = vec_xxpermdi(t3, t4, 3); + + vec_xst(t5, 0, vecOffset + 16); + vec_xst(t6, 0, vecOffset + 20); + vec_xst(t7, 0, vecOffset + 24); + vec_xst(t8, 0, vecOffset + 28); + } + + void packTranspose(const float* a, int64_t lda, int rows, int cols, float* vec) { int64_t i, j; - TA *aoffset = NULL, *boffset = NULL; - TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL; - TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL; - __vector_pair C1, C2, C3, C4, C5, C6, C7, C8; - VA c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0}; - VA c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0}; - VA t1, t2, t3, t4, t5, t6, 
t7, t8; - aoffset = const_cast(a); + float * aoffsets[8]; + float *aoffset = NULL, *boffset = NULL; + __vector_pair arr[8]; + vector float c[8][2] = {0}; + vector float c1[8] = {0}; + vector float c2[8] = {0}; + aoffset = const_cast(a); boffset = vec; j = (rows >> 3); if (j > 0) { do { - aoffset1 = aoffset; - aoffset2 = aoffset1 + lda; - aoffset3 = aoffset2 + lda; - aoffset4 = aoffset3 + lda; - aoffset5 = aoffset4 + lda; - aoffset6 = aoffset5 + lda; - aoffset7 = aoffset6 + lda; - aoffset8 = aoffset7 + lda; + aoffsets[0] = aoffset; + for (int it = 1; it< 8; it++) + aoffsets[it] = aoffsets[it-1] + lda; aoffset += 8 * lda; i = (cols >> 3); if (i > 0) { do { - C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1); - C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2); - C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3); - C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4); - C5 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset5); - C6 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset6); - C7 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset7); - C8 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset8); - __builtin_vsx_disassemble_pair(c1, &C1); - __builtin_vsx_disassemble_pair(c2, &C2); - __builtin_vsx_disassemble_pair(c3, &C3); - __builtin_vsx_disassemble_pair(c4, &C4); - __builtin_vsx_disassemble_pair(c5, &C5); - __builtin_vsx_disassemble_pair(c6, &C6); - __builtin_vsx_disassemble_pair(c7, &C7); - __builtin_vsx_disassemble_pair(c8, &C8); - - t1 = vec_mergeh(c1[0], c2[0]); - t2 = vec_mergeh(c3[0], c4[0]); - t3 = vec_mergeh(c5[0], c6[0]); - t4 = vec_mergeh(c7[0], c8[0]); - t5 = vec_xxpermdi(t1, t2, 0); - t6 = vec_xxpermdi(t3, t4, 0); - t7 = vec_xxpermdi(t1, t2, 3); - t8 = vec_xxpermdi(t3, t4, 3); - vec_xst(t5, 0, boffset); - vec_xst(t6, 0, boffset+4); - vec_xst(t7, 0, boffset+8); - vec_xst(t8, 0, boffset+12); - - t1 = vec_mergel(c1[0], c2[0]); - t2 = vec_mergel(c3[0], c4[0]); - t3 = vec_mergel(c5[0], c6[0]); - t4 = vec_mergel(c7[0], c8[0]); - t5 = vec_xxpermdi(t1, t2, 0); - t6 = vec_xxpermdi(t3, t4, 0); - t7 = vec_xxpermdi(t1, t2, 3); - t8 = vec_xxpermdi(t3, t4, 3); - vec_xst(t5, 0, boffset+16); - vec_xst(t6, 0, boffset+20); - vec_xst(t7, 0, boffset+24); - vec_xst(t8, 0, boffset+28); - - t1 = vec_mergeh(c1[1], c2[1]); - t2 = vec_mergeh(c3[1], c4[1]); - t3 = vec_mergeh(c5[1], c6[1]); - t4 = vec_mergeh(c7[1], c8[1]); - t5 = vec_xxpermdi(t1, t2, 0); - t6 = vec_xxpermdi(t3, t4, 0); - t7 = vec_xxpermdi(t1, t2, 3); - t8 = vec_xxpermdi(t3, t4, 3); - vec_xst(t5, 0, boffset+32); - vec_xst(t6, 0, boffset+36); - vec_xst(t7, 0, boffset+40); - vec_xst(t8, 0, boffset+44); - - t1 = vec_mergel(c1[1], c2[1]); - t2 = vec_mergel(c3[1], c4[1]); - t3 = vec_mergel(c5[1], c6[1]); - t4 = vec_mergel(c7[1], c8[1]); - t5 = vec_xxpermdi(t1, t2, 0); - t6 = vec_xxpermdi(t3, t4, 0); - t7 = vec_xxpermdi(t1, t2, 3); - t8 = vec_xxpermdi(t3, t4, 3); - vec_xst(t5, 0, boffset+48); - vec_xst(t6, 0, boffset+52); - vec_xst(t7, 0, boffset+56); - vec_xst(t8, 0, boffset+60); - - aoffset1 += 8*lda; - aoffset2 += 8*lda; - aoffset3 += 8*lda; - aoffset4 += 8*lda; + for (int it = 0; it< 8; it++) { + arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]); + __builtin_vsx_disassemble_pair(c[it], &arr[it]); + c1[it] = c[it][0]; + c2[it] = c[it][1]; + } + + vector_permute_store_8(c1, boffset); + vector_permute_store_8(c2, boffset+32); + for (int it = 0; it < 4; it++) + aoffsets[it] = aoffsets[it] + 8*lda; boffset += 64; i--; } while(i > 0); } if (cols & 4) { - c1[0] = vec_xl(0, aoffset1); - c2[0] = vec_xl(0, aoffset2); - c3[0] = vec_xl(0, 
aoffset3); - c4[0] = vec_xl(0, aoffset4); - c5[0] = vec_xl(0, aoffset5); - c6[0] = vec_xl(0, aoffset6); - c7[0] = vec_xl(0, aoffset7); - c8[0] = vec_xl(0, aoffset8); - - t1 = vec_mergeh(c1[0], c2[0]); - t2 = vec_mergeh(c3[0], c4[0]); - t3 = vec_mergeh(c5[0], c6[0]); - t4 = vec_mergeh(c7[0], c8[0]); - t5 = vec_xxpermdi(t1, t2, 0); - t6 = vec_xxpermdi(t3, t4, 0); - t7 = vec_xxpermdi(t1, t2, 3); - t8 = vec_xxpermdi(t3, t4, 3); - vec_xst(t5, 0, boffset); - vec_xst(t6, 0, boffset+4); - vec_xst(t7, 0, boffset+8); - vec_xst(t8, 0, boffset+12); - - t1 = vec_mergel(c1[0], c2[0]); - t2 = vec_mergel(c3[0], c4[0]); - t3 = vec_mergel(c5[0], c6[0]); - t4 = vec_mergel(c7[0], c8[0]); - t5 = vec_xxpermdi(t1, t2, 0); - t6 = vec_xxpermdi(t3, t4, 0); - t7 = vec_xxpermdi(t1, t2, 3); - t8 = vec_xxpermdi(t3, t4, 3); - vec_xst(t5, 0, boffset+16); - vec_xst(t6, 0, boffset+20); - vec_xst(t7, 0, boffset+24); - vec_xst(t8, 0, boffset+28); + for (int it = 0; it < 8 ; it++) + c1[it] = vec_xl(0, aoffsets[it]); + vector_permute_store_8(c1, boffset); } j--; } while(j > 0); } if (rows & 4) { - aoffset1 = aoffset; - aoffset2 = aoffset1 + lda; - aoffset3 = aoffset2 + lda; - aoffset4 = aoffset3 + lda; + aoffsets[0] = aoffset; + for (int it = 1; it < 4; it++) + aoffsets[it] = aoffsets[it-1] + lda; aoffset += 4 * lda; i = (cols >> 3); if (i > 0) { do { - C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1); - C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2); - C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3); - C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4); - __builtin_vsx_disassemble_pair(c1, &C1); - __builtin_vsx_disassemble_pair(c2, &C2); - __builtin_vsx_disassemble_pair(c3, &C3); - __builtin_vsx_disassemble_pair(c4, &C4); - - t1 = vec_mergeh(c1[0], c2[0]); - t2 = vec_mergeh(c3[0], c4[0]); - t3 = vec_mergel(c1[0], c2[0]); - t4 = vec_mergel(c3[0], c4[0]); - t5 = vec_xxpermdi(t1, t2, 0); - t6 = vec_xxpermdi(t1, t2, 3); - t7 = vec_xxpermdi(t3, t4, 0); - t8 = vec_xxpermdi(t3, t4, 3); - vec_xst(t5, 0, boffset); - vec_xst(t6, 0, boffset+4); - vec_xst(t7, 0, boffset+8); - vec_xst(t8, 0, boffset+12); - - t1 = vec_mergeh(c1[1], c2[1]); - t2 = vec_mergeh(c3[1], c4[1]); - t3 = vec_mergel(c1[1], c2[1]); - t4 = vec_mergel(c3[1], c4[1]); - t5 = vec_xxpermdi(t1, t2, 0); - t6 = vec_xxpermdi(t1, t2, 3); - t7 = vec_xxpermdi(t3, t4, 0); - t8 = vec_xxpermdi(t3, t4, 3); - vec_xst(t5, 0, boffset+16); - vec_xst(t6, 0, boffset+20); - vec_xst(t7, 0, boffset+24); - vec_xst(t8, 0, boffset+28); - - aoffset1 += 8*lda; - aoffset2 += 8*lda; - aoffset3 += 8*lda; - aoffset4 += 8*lda; + for (int it = 0; it < 4; it++) { + arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]); + __builtin_vsx_disassemble_pair(c[it], &arr[it]); + c1[it] = c[it][0]; + c2[it] = c[it][1]; + } + vector_permute_store_4(c1, boffset); + vector_permute_store_4(c2, boffset+16); + for (int it = 0; it < 4; it++) + aoffsets[it] += 8*lda; boffset += 32; i--; } while(i > 0); } if (cols & 4) { - c1[0] = vec_xl(0, aoffset1); - c2[0] = vec_xl(0, aoffset2); - c3[0] = vec_xl(0, aoffset3); - c4[0] = vec_xl(0, aoffset4); - - t1 = vec_mergeh(c1[0], c2[0]); - t2 = vec_mergeh(c3[0], c4[0]); - t3 = vec_xxpermdi(t1, t2, 0); - t4 = vec_xxpermdi(t1, t2, 3); - vec_xst(t3, 0, boffset); - vec_xst(t4, 0, boffset+4); - - t1 = vec_mergel(c1[0], c2[0]); - t2 = vec_mergel(c3[0], c4[0]); - t3 = vec_xxpermdi(t1, t2, 0); - t4 = vec_xxpermdi(t1, t2, 3); - vec_xst(t3, 0, boffset+8); - vec_xst(t4, 0, boffset+12); + for (int it = 0; it < 4; it++) + c1[it] = vec_xl(0, aoffsets[it]); + 
vector_permute_store_4(c1, boffset); } } if (rows & 3) { - aoffset1 = aoffset; - aoffset2 = aoffset1 + lda; - aoffset3 = aoffset2 + lda; + aoffsets[0] = aoffset; + for (int it = 1; it < 3; it++) + aoffsets[it] = aoffsets[it-1] + lda; if (cols & 4) { - c1[0] = vec_xl(0, aoffset1); - c2[0] = vec_xl(0, aoffset2); - c3[0] = vec_xl(0, aoffset3); - - t1 = vec_mergeh(c1[0], c2[0]); - t2 = vec_mergeh(c3[0], c4[0]); - t3 = vec_xxpermdi(t1, t2, 0); - t4 = vec_xxpermdi(t1, t2, 3); - vec_xst(t3, 0, boffset); - vec_xst(t4, 0, boffset+4); - - t1 = vec_mergel(c1[0], c2[0]); - t2 = vec_mergel(c3[0], c4[0]); - t3 = vec_xxpermdi(t1, t2, 0); - t4 = vec_xxpermdi(t1, t2, 3); - vec_xst(t3, 0, boffset+8); - vec_xst(t4, 0, boffset+12); + for (int it = 0; it < 3; it++) + c1[it] = vec_xl(0, aoffsets[it]); + vector_permute_store_4(c1, boffset); } } } @@ -2957,8 +2326,8 @@ class tinyBLAS_PPC { acc_t acc_0; __builtin_mma_xxsetaccz(&acc_0); for (int l = 0; l < k; l+=4) { - packTranspose(A+(ii*lda)+l, lda, 4, 4, (TA*)vec_A); - packTranspose(B+(jj*ldb)+l, ldb, 4, 4, (TA*)vec_B); + packTranspose(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A); + packTranspose(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B); __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]); __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]); __builtin_mma_xvf32gerpp(&acc_0, vec_A[2], vec_B[2]); @@ -2973,8 +2342,8 @@ class tinyBLAS_PPC { __builtin_mma_xxsetaccz(&acc_0); __builtin_mma_xxsetaccz(&acc_1); for (int64_t l = 0; l < k; l+=4) { - packTranspose(A+(ii*lda)+l, lda, 4, 4, (TA*)vec_A); - packTranspose(B+(jj*ldb)+l, ldb, 8, 4, (TA*)vec_B); + packTranspose(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A); + packTranspose(B+(jj*ldb)+l, ldb, 8, 4, (float*)vec_B); __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], (vec_t)vec_B[0]); __builtin_mma_xvf32gerpp(&acc_1, vec_A[0], (vec_t)vec_B[1]); __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], (vec_t)vec_B[2]); @@ -2994,8 +2363,8 @@ class tinyBLAS_PPC { __builtin_mma_xxsetaccz(&acc_0); __builtin_mma_xxsetaccz(&acc_1); for (int64_t l = 0; l < k; l+=4) { - packTranspose(A+(ii*lda)+l, lda, 8, 4, (TA*)vec_A); - packTranspose(B+(jj*ldb)+l, ldb, 4, 4, (TA*)vec_B); + packTranspose(A+(ii*lda)+l, lda, 8, 4, (float*)vec_A); + packTranspose(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B); __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[0], vec_B[0]); __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[1], vec_B[0]); __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[2], vec_B[1]); @@ -3017,8 +2386,8 @@ class tinyBLAS_PPC { __builtin_mma_xxsetaccz(&acc_2); __builtin_mma_xxsetaccz(&acc_3); for (int l = 0; l < k; l+=8) { - packTranspose(A+(ii*lda)+l, lda, 8, 8, (TA*)vec_A); - packTranspose(B+(jj*ldb)+l, ldb, 8, 8, (TA*)vec_B); + packTranspose(A+(ii*lda)+l, lda, 8, 8, (float*)vec_A); + packTranspose(B+(jj*ldb)+l, ldb, 8, 8, (float*)vec_B); for(int x = 0; x < 16; x+=2) { __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[x], vec_B[x]); __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[x], vec_B[x+1]); @@ -3033,155 +2402,37 @@ class tinyBLAS_PPC { } void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) { - int64_t mc, nc, mp, np; - int m_rem = MIN(m - m0, 16); - int n_rem = MIN(n - n0, 16); - if (m_rem >= 16 && n_rem >= 8) { - mc = 8; - nc = 8; - gemm<8,8>(m0, m, n0, n); - } else if(m_rem >= 8 && n_rem >= 16) { - mc = 8; - nc = 8; - gemm<8,8>(m0, m, n0, n); - } else if (m_rem >= 8 && n_rem >= 8) { - mc = 8; - nc = 8; - gemm<8,8>(m0, m, n0, n); + int m_rem = MIN(m - m0, 8); + int n_rem = MIN(n - n0, 8); + int mc = 0, nc = 0; + if (m_rem >= 8 && n_rem >= 8) { + mc = 8; + nc = 8; + gemm<8, 
8>(m0, m, n0, n); } else if (m_rem >= 4 && n_rem >= 8) { - mc = 4; - nc = 8; - gemm<4,8>(m0, m, n0, n); + mc = 4; + nc = 8; + gemm<4, 8>(m0, m, n0, n); } else if (m_rem >= 8 && n_rem >= 4) { - mc = 8; - nc = 4; - gemm<8,4>(m0, m, n0, n); + mc = 8; + nc = 4; + gemm<8, 4>(m0, m, n0, n); } else if (m_rem >= 4 && n_rem >= 4) { - mc = 4; - nc = 4; - gemm<4,4>(m0, m, n0, n); - } else if ((m_rem < 4) && (n_rem > 4)) { - nc = 4; - switch(m_rem) { - case 1: - mc = 1; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 2: - mc = 2; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 3: - mc = 3; - gemm_small(m0, m, n0, n, mc, nc); - break; - default: - return; - } - } else if ((m_rem > 4) && (n_rem < 4)) { - mc = 4; - switch(n_rem) { - case 1: - nc = 1; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 2: - nc = 2; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 3: - nc = 3; - gemm_small(m0, m, n0, n, mc, nc); - break; - default: - return; - } + mc = 4; + nc = 4; + gemm<4, 4>(m0, m, n0, n); } else { - switch((m_rem << 4) | n_rem) { - case 0x43: - mc = 4; - nc = 3; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 0x42: - mc = 4; - nc = 2; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 0x41: - mc = 4; - nc = 1; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 0x34: - mc = 3; - nc = 4; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 0x33: - mc = 3; - nc = 3; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 0x32: - mc = 3; - nc = 2; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 0x31: - mc = 3; - nc = 1; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 0x24: - mc = 2; - nc = 4; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 0x23: - mc = 2; - nc = 3; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 0x22: - mc = 2; - nc = 2; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 0x21: - mc = 2; - nc = 1; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 0x14: - mc = 1; - nc = 4; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 0x13: - mc = 1; - nc = 3; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 0x12: - mc = 1; - nc = 2; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 0x11: - mc = 1; - nc = 1; - gemm_small(m0, m, n0, n, mc, nc); - break; - default: - return; - } + mc = (m_rem >= 4) ? 4 : m_rem; + nc = (n_rem >= 4) ? 4 : n_rem; + if (mc == 0 || nc == 0) + return; + gemm_small(m0, m, n0, n, mc, nc); } - mp = m0 + (m - m0) / mc * mc; - np = n0 + (n - n0) / nc * nc; + int64_t mp = m0 + ((m - m0) / mc) * mc; + int64_t np = n0 + ((n - n0) / nc) * nc; mnpack(mp, m, n0, np); mnpack(m0, m, np, n); - } + } void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) { int64_t ytiles = (m - m0) / RM; @@ -3206,22 +2457,22 @@ class tinyBLAS_PPC { * matrix elements. 
*/ if (RM == 1) { - TA* a = const_cast(A+(ii)*lda+l); - packTranspose(B+(jj*ldb)+l, ldb, RN, 4, (TA*)vec_B); + float* a = const_cast(A+(ii)*lda+l); + packTranspose(B+(jj*ldb)+l, ldb, RN, 4, (float*)vec_B); vec_A[0] = (vec_t)vec_xl(0,a); - vec_A[1] = (vec_t)vec_splats(*((TA*)&vec_A+1)); - vec_A[2] = (vec_t)vec_splats(*((TA*)&vec_A+2)); - vec_A[3] = (vec_t)vec_splats(*((TA*)&vec_A+3)); + vec_A[1] = (vec_t)vec_splats(*((float*)&vec_A+1)); + vec_A[2] = (vec_t)vec_splats(*((float*)&vec_A+2)); + vec_A[3] = (vec_t)vec_splats(*((float*)&vec_A+3)); } else if (RN == 1) { - packTranspose(A+(ii*lda)+l, lda, RM, 4, (TA*)vec_A); - TB* b = const_cast(B+(jj)*ldb+l); + packTranspose(A+(ii*lda)+l, lda, RM, 4, (float*)vec_A); + float* b = const_cast(B+(jj)*ldb+l); vec_B[0] = (vec_t)vec_xl(0,b); - vec_B[1] = (vec_t)vec_splats(*((TB*)&vec_B+1)); - vec_B[2] = (vec_t)vec_splats(*((TB*)&vec_B+2)); - vec_B[3] = (vec_t)vec_splats(*((TB*)&vec_B+3)); + vec_B[1] = (vec_t)vec_splats(*((float*)&vec_B+1)); + vec_B[2] = (vec_t)vec_splats(*((float*)&vec_B+2)); + vec_B[3] = (vec_t)vec_splats(*((float*)&vec_B+3)); } else { - packTranspose(A+(ii*lda)+l, lda, RM, 4, (TA*)vec_A); - packTranspose(B+(jj*ldb)+l, ldb, RN, 4, (TA*)vec_B); + packTranspose(A+(ii*lda)+l, lda, RM, 4, (float*)vec_A); + packTranspose(B+(jj*ldb)+l, ldb, RN, 4, (float*)vec_B); } __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]); __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]); @@ -3231,7 +2482,7 @@ class tinyBLAS_PPC { __builtin_mma_disassemble_acc(vec_C, &acc_0); for (int I = 0; I < RM; I++) { for (int J = 0; J < RN; J++) { - *((TC*)(C+ii+((jj+J)*ldc)+I)) = *((TC*)&vec_C[I]+J); + *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&vec_C[I]+J); } } } @@ -3263,11 +2514,9 @@ class tinyBLAS_PPC { } } - const TA *const A; - const TB *const B; - TC *C; - TA *At; - TB *Bt; + const float *const A; + const float *const B; + float *C; const int64_t k; const int64_t lda; const int64_t ldb; @@ -3366,7 +2615,7 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64 #elif defined(__MMA__) if (k % 8) return false; - tinyBLAS_PPC tb{ + tinyBLAS_PPC tb{ k, (const float *)A, lda, (const float *)B, ldb, (float *)C, ldc, @@ -3493,7 +2742,7 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64 return false; if (m < 8 && m != 4) return false; - tinyBLAS_Q0_PPC tb{ + tinyBLAS_Q0_PPC tb{ k, (const block_q8_0 *)A, lda, (const block_q8_0 *)B, ldb, (float *)C, ldc, @@ -3530,7 +2779,7 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64 return false; if (m < 8 && m != 4) return false; - tinyBLAS_Q0_PPC tb{ + tinyBLAS_Q0_PPC tb{ k, (const block_q4_0 *)A, lda, (const block_q8_0 *)B, ldb, (float *)C, ldc, From 72dae6b9f21561120a073da9f2b34a56ff6ccbdc Mon Sep 17 00:00:00 2001 From: Anton Mitkov Date: Mon, 14 Jul 2025 18:12:42 +0100 Subject: [PATCH 09/20] sycl: Hotfix for non dnnl codepath (llama/14677) --- ggml/src/ggml-sycl/ggml-sycl.cpp | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index cf46012be81..a6f9af0c86e 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -2875,12 +2875,20 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons } } +#if GGML_SYCL_DNNL + // oneDNN handles strided data and does not need overhead of get_to_fp16_nc_sycl const int64_t ne_src1 = src1->nb[last_str] * src1->ne[last_dim] / type_size_src1; 
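A note on why the DNNL branch above sizes the staging buffer from strides rather than ggml_nelements(): oneDNN keeps reading src1 at its original (possibly padded) strides, so the fp16 copy has to span the full extent of the outermost dimension. A rough sketch with assumed sizes (the ne/nb fields follow ggml's usual convention; the concrete numbers are illustrative):

    // Padded 4x4 f32 view whose rows are spaced 32 bytes apart:
    //   logical elements: ggml_nelements(src1)         = 4 * 4      = 16
    //   stride span:      nb[last] * ne[last] / tsize  = 32 * 4 / 4 = 32
    // ne_src1 uses the second form, so to_fp16_sycl converts the whole
    // span and the padding holes stay addressable for oneDNN.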
src1_f16_alloc.alloc(ne_src1); - const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type, dst); GGML_ASSERT(to_fp16_sycl != nullptr); to_fp16_sycl(src1_f16, src1_f16_alloc.get(), ne_src1, queue); +# else + const int64_t ne_src1 = ggml_nelements(src1); + src1_f16_alloc.alloc(ne_src1); + const to_fp16_nc_sycl_t to_fp16_nc_sycl = get_to_fp16_nc_sycl(src1->type); + GGML_ASSERT(to_fp16_nc_sycl != nullptr); + to_fp16_nc_sycl(src1_f16, src1_f16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, queue); +#endif src1_f16 = src1_f16_alloc.get(); s11 = ne10; From 591bc24690c7422ed3962970a8f9741a77a1da3c Mon Sep 17 00:00:00 2001 From: R0CKSTAR Date: Tue, 15 Jul 2025 15:28:53 +0800 Subject: [PATCH 10/20] cuda: fix build warnings in set-rows.cu (unused variable) (llama/14687) Signed-off-by: Xiaodong Ye --- ggml/src/ggml-cuda/set-rows.cu | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-cuda/set-rows.cu b/ggml/src/ggml-cuda/set-rows.cu index 3fade72b84e..58cee924401 100644 --- a/ggml/src/ggml-cuda/set-rows.cu +++ b/ggml/src/ggml-cuda/set-rows.cu @@ -3,7 +3,10 @@ typedef void (*set_rows_kernel_t)(const char * src, char * dst); template -__device__ void set_rows_1(const src_t * src_f, dst_t * dst_f) {} +__device__ void set_rows_1(const src_t * src_f, dst_t * dst_f) { + GGML_UNUSED(src_f); + GGML_UNUSED(dst_f); +} template<> __device__ __forceinline__ void set_rows_1(const float * src_f, half * dst_h) { @@ -53,6 +56,9 @@ static __global__ void k_set_rows( const src_t* src_elem = src0_row + i00; dst_t* dst_elem = dst_row_ptr + i00; set_rows_1(src_elem, dst_elem); + + GGML_UNUSED(ne10); + GGML_UNUSED(ne13); } template From b2653a91f686ece3800c91b799bfbc02ab0a30b9 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Tue, 15 Jul 2025 14:32:11 -0500 Subject: [PATCH 11/20] vulkan: add RTE variants for glu/add/sub/mul/div (llama/14653) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 12 ++- .../vulkan-shaders/copy_to_quant.comp | 6 +- .../vulkan-shaders/generic_binary_head.comp | 2 + .../ggml-vulkan/vulkan-shaders/glu_head.comp | 2 + .../ggml-vulkan/vulkan-shaders/im2col.comp | 5 +- .../ggml-vulkan/vulkan-shaders/rope_head.comp | 5 +- ggml/src/ggml-vulkan/vulkan-shaders/rte.comp | 5 ++ .../vulkan-shaders/vulkan-shaders-gen.cpp | 85 +++++++++++++++---- 8 files changed, 90 insertions(+), 32 deletions(-) create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/rte.comp diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 416ee3bd3f7..9f5646bf29d 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -2835,10 +2835,11 @@ static void ggml_vk_load_shaders(vk_device& device) { return s; }; + bool rte = device->float_controls_rte_fp16; #define CREATE_BINARY(name, namemod, spec) \ for (int s0 : {0,1}) for (int s1 : {0,1}) for (int d : {0,1}) \ ggml_vk_create_pipeline(device, device->pipeline_ ## name ## namemod[s0][s1][d], \ - #name + get_suffix(s0, s1, d) + #namemod, name ## _len[s0][s1][d], name ## _data[s0][s1][d], \ + #name + get_suffix(s0, s1, d) + #namemod, name ## _len[s0][s1][d][rte], name ## _data[s0][s1][d][rte], \ "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, spec, 1); CREATE_BINARY(add, , {0}) @@ -2890,8 +2891,13 @@ static void ggml_vk_load_shaders(vk_device& device) { #undef CREATE_UNARY #define CREATE_GLU(name) \ - ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 
1, 1}, {}, 1, true); \ - ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); + if (device->float_controls_rte_fp16) { \ + ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32_rte", name ## _f32_rte_len, name ## _f32_rte_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16_rte", name ## _f16_rte_len, name ## _f16_rte_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \ + } else { \ + ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \ + } CREATE_GLU(geglu) CREATE_GLU(reglu) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp b/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp index e06547e48f7..27d6b7464f6 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp @@ -1,10 +1,6 @@ #version 450 -#if RTE16 -#extension GL_EXT_spirv_intrinsics : enable -spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits -#endif // RTE16 - +#include "rte.comp" #include "types.comp" #if defined(SET_ROWS) && QUANT_K == 1 diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp b/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp index 062e2a4cdf2..4b4316cf3d9 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp @@ -1,6 +1,8 @@ #extension GL_EXT_shader_16bit_storage : require #extension GL_EXT_control_flow_attributes : require +#include "rte.comp" + layout (push_constant) uniform parameter { uint ne; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp b/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp index 41a29889075..004a61fc162 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp @@ -1,5 +1,7 @@ #extension GL_EXT_shader_16bit_storage : require +#include "rte.comp" + layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp b/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp index 09aa849e881..17c7ccb90d0 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp @@ -1,12 +1,9 @@ #version 450 #extension GL_EXT_shader_16bit_storage : require -#extension GL_EXT_spirv_intrinsics: enable #extension GL_EXT_control_flow_attributes : require -#if RTE16 -spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits -#endif +#include "rte.comp" layout (push_constant) uniform parameter { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp index 96c9c4cbd30..00e203e73bd 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp @@ -1,11 +1,8 @@ #include "types.comp" #extension 
GL_EXT_shader_16bit_storage : require -#extension GL_EXT_spirv_intrinsics: enable -#if RTE16 -spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits -#endif +#include "rte.comp" layout(local_size_x = 1, local_size_y = 256, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp new file mode 100644 index 00000000000..ad51c1e80b8 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp @@ -0,0 +1,5 @@ + +#if RTE16 +#extension GL_EXT_spirv_intrinsics : enable +spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits +#endif // RTE16 diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index d4a4e4c5290..809c0bd9bd3 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -537,8 +537,10 @@ void process_shaders() { for (auto src0_f16 : {false, true}) { for (auto src1_f16 : {false, true}) { for (auto dst_f16 : {false, true}) { - auto name = op + get_suffix(src0_f16, src1_f16, dst_f16); - string_to_spv(name.c_str(), op + ".comp", {{"A_TYPE", get_type_str(src0_f16)}, {"B_TYPE", get_type_str(src1_f16)}, {"D_TYPE", get_type_str(dst_f16)}, {"FLOAT_TYPE", "float"}}); + for (auto rte : {false, true}) { + auto name = op + get_suffix(src0_f16, src1_f16, dst_f16) + (rte ? "_rte" : ""); + string_to_spv(name.c_str(), op + ".comp", {{"A_TYPE", get_type_str(src0_f16)}, {"B_TYPE", get_type_str(src1_f16)}, {"D_TYPE", get_type_str(dst_f16)}, {"FLOAT_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}}); + } } } } @@ -592,16 +594,19 @@ void process_shaders() { string_to_spv("sigmoid_f16", "sigmoid.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); string_to_spv("sigmoid_f32", "sigmoid.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); - string_to_spv("geglu_f16", "geglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); - string_to_spv("geglu_f32", "geglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); - string_to_spv("reglu_f16", "reglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); - string_to_spv("reglu_f32", "reglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); - string_to_spv("swiglu_f16", "swiglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); - string_to_spv("swiglu_f32", "swiglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); - string_to_spv("geglu_erf_f16", "geglu_erf.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); - string_to_spv("geglu_erf_f32", "geglu_erf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); - string_to_spv("geglu_quick_f16","geglu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); - string_to_spv("geglu_quick_f32","geglu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + for (auto rte : {false, true}) { + std::string suffix = rte ? "_rte" : ""; + string_to_spv("geglu_f16" + suffix, "geglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}}); + string_to_spv("geglu_f32" + suffix, "geglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}}); + string_to_spv("reglu_f16" + suffix, "reglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}}); + string_to_spv("reglu_f32" + suffix, "reglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? 
"1" : "0"}}); + string_to_spv("swiglu_f16" + suffix, "swiglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}}); + string_to_spv("swiglu_f32" + suffix, "swiglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}}); + string_to_spv("geglu_erf_f16" + suffix, "geglu_erf.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}}); + string_to_spv("geglu_erf_f32" + suffix, "geglu_erf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}}); + string_to_spv("geglu_quick_f16" + suffix,"geglu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}}); + string_to_spv("geglu_quick_f32" + suffix,"geglu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}}); + } string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); @@ -709,11 +714,59 @@ void write_output_files() { std::remove(path.c_str()); } } + + std::string suffixes[2] = {"_f32", "_f16"}; for (const char *op : {"add", "sub", "mul", "div"}) { - fprintf(hdr, "extern unsigned char *%s_data[2][2][2];\n", op); - fprintf(hdr, "extern uint64_t %s_len[2][2][2];\n", op); - fprintf(src, "unsigned char *%s_data[2][2][2] = {{{%s_f32_f32_f32_data, %s_f32_f32_f16_data}, {%s_f32_f16_f32_data, %s_f32_f16_f16_data}}, {{%s_f16_f32_f32_data, %s_f16_f32_f16_data}, {%s_f16_f16_f32_data, %s_f16_f16_f16_data}}};\n", op, op, op, op, op, op, op, op, op); - fprintf(src, "uint64_t %s_len[2][2][2] = {{{%s_f32_f32_f32_len, %s_f32_f32_f16_len}, {%s_f32_f16_f32_len, %s_f32_f16_f16_len}}, {{%s_f16_f32_f32_len, %s_f16_f32_f16_len}, {%s_f16_f16_f32_len, %s_f16_f16_f16_len}}};\n", op, op, op, op, op, op, op, op, op); + fprintf(hdr, "extern unsigned char *%s_data[2][2][2][2];\n", op); + fprintf(hdr, "extern uint64_t %s_len[2][2][2][2];\n", op); + std::string data = "unsigned char *" + std::string(op) + "_data[2][2][2][2] = "; + std::string len = "uint64_t " + std::string(op) + "_len[2][2][2][2] = "; + for (uint32_t t0 = 0; t0 < 2; ++t0) { + if (t0 == 0) { + data += "{"; + len += "{"; + } + for (uint32_t t1 = 0; t1 < 2; ++t1) { + if (t1 == 0) { + data += "{"; + len += "{"; + } + for (uint32_t t2 = 0; t2 < 2; ++t2) { + if (t2 == 0) { + data += "{"; + len += "{"; + } + for (uint32_t rte = 0; rte < 2; ++rte) { + if (rte == 0) { + data += "{"; + len += "{"; + } + data += op + suffixes[t0] + suffixes[t1] + suffixes[t2] + ((rte != 0) ? "_rte" : ""); + len += op + suffixes[t0] + suffixes[t1] + suffixes[t2] + ((rte != 0) ? 
"_rte" : ""); + data += "_data,"; + len += "_len,"; + if (rte == 1) { + data += "}, "; + len += "}, "; + } + } + if (t2 == 1) { + data += "}, "; + len += "}, "; + } + } + if (t1 == 1) { + data += "}, "; + len += "}, "; + } + } + if (t0 == 1) { + data += "};\n"; + len += "};\n"; + } + } + fprintf(src, data.c_str()); + fprintf(src, len.c_str()); } fclose(hdr); fclose(src); From bf13b8239da739edaa6525f6a216695996a6a8aa Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Tue, 15 Jul 2025 14:51:09 -0500 Subject: [PATCH 12/20] vulkan: fix noncontig check for mat_mul_id splitting (llama/14683) * vulkan: fix noncontig check for mat_mul_id splitting Remove supports_op check for > 4096 (splitting fixes this) * vulkan: fix batched matmul dequant for Q*_K --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 6 +----- ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp | 2 +- ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp | 2 +- ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp | 2 +- ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp | 2 +- ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp | 2 +- 6 files changed, 6 insertions(+), 10 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 9f5646bf29d..3019a545d58 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -4922,7 +4922,7 @@ static bool ggml_vk_dim01_contiguous(const ggml_tensor * tensor) { return tensor->nb[0] == ggml_type_size(tensor->type) && tensor->nb[1] == (tensor->nb[0]*tensor->ne[0])/ggml_blck_size(tensor->type) && - tensor->nb[3] == tensor->nb[2]*tensor->ne[2]; + (tensor->ne[3] == 1 || tensor->nb[3] == tensor->nb[2]*tensor->ne[2]); } static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src, const ggml_tensor * dst, ggml_type to) { @@ -10356,10 +10356,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm // If there's not enough shared memory for row_ids and the result tile, fallback to CPU return false; } - // Check against size of shared memory variable - if (op->src[2]->ne[0] > 4096) { - return false; - } } switch (src0_type) { case GGML_TYPE_F32: diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp index 157154af3a3..d4e4e6bae63 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp @@ -10,7 +10,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; void main() { [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) { const uint i = gl_WorkGroupID.x * 256 + wgy; - if (i >= p.M * p.K / QUANT_K) { + if (i >= p.nel / QUANT_K) { return; } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp index c17dd0d9991..3661f771c74 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp @@ -10,7 +10,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; void main() { [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) { const uint i = uint(gl_WorkGroupID.x * 256 + wgy); - if (i >= p.M * p.K / QUANT_K) { + if (i >= p.nel / QUANT_K) { return; } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp index 987f113a35a..1370db3654d 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +++ 
b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp @@ -10,7 +10,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; void main() { [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) { const uint ib = gl_WorkGroupID.x * 256 + wgy; - if (ib >= p.M * p.K / QUANT_K) { + if (ib >= p.nel / QUANT_K) { return; } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp index 6db5403b661..3f3b839e118 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp @@ -10,7 +10,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; void main() { [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) { const uint ib = gl_WorkGroupID.x * 256 + wgy; - if (ib >= p.M * p.K / QUANT_K) { + if (ib >= p.nel / QUANT_K) { return; } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp index 0b91317550f..9cf34256e8c 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp @@ -10,7 +10,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; void main() { [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) { const uint i = gl_WorkGroupID.x * 256 + wgy; - if (i >= p.M * p.K / QUANT_K) { + if (i >= p.nel / QUANT_K) { return; } const uint tid = gl_LocalInvocationID.x; From 26f3de16caba6bec41d84dcf198cefd227279a38 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 16 Jul 2025 14:43:32 +0300 Subject: [PATCH 13/20] ggml : add asserts (llama/14720) * ggml : add asserts ggml-ci * cont : fix constant type Co-authored-by: Diego Devesa --------- Co-authored-by: Diego Devesa --- ggml/src/ggml-cpu/ops.cpp | 3 +++ ggml/src/ggml-cpu/vec.cpp | 3 +++ 2 files changed, 6 insertions(+) diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index fd77e9a6aba..6581d27adde 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -4015,6 +4015,9 @@ static void ggml_compute_forward_rms_norm_f32( const float scale = 1.0f/sqrtf(mean + eps); + // if you hit this, likely you got an inf somewhere earlier + assert(scale > 0.0f); + ggml_vec_scale_f32(ne00, y, scale); } } diff --git a/ggml/src/ggml-cpu/vec.cpp b/ggml/src/ggml-cpu/vec.cpp index a8156011eba..07b377bdd82 100644 --- a/ggml/src/ggml-cpu/vec.cpp +++ b/ggml/src/ggml-cpu/vec.cpp @@ -221,6 +221,9 @@ void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * G for (int i = np; i < n; ++i) { sumf += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[i])*GGML_CPU_FP16_TO_FP32(y[i])); } + + // if you hit this, you are likely running outside the FP range + assert(!isnan(sumf) && !isinf(sumf)); #else for (int i = 0; i < n; ++i) { sumf += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[i])*GGML_CPU_FP16_TO_FP32(y[i])); From 9b27887e3544b0f360a95a572aaf0917020a96e6 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 16 Jul 2025 16:35:42 +0300 Subject: [PATCH 14/20] llama : add high-throughput mode (llama/14363) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * kv-cache : prepare K/V buffers for separation ggml-ci * batched-bench : fix oob write ggml-ci * llama : add "virtual sequences" ggml-ci * llama : use "stream" vs "virtual sequence" ggml-ci * graph : fix stream splitting when KV cache is not used ggml-ci * kv-cache : add multi-stream save/load support ggml-ci * llama : add "--attn-streams" flag ggml-ci * kv-cache : fix handling 
when find_slot fails ggml-ci * kv-cache : restore find_slot impl ggml-ci * kv-cache : add comments * kv-cache : add bounds checks for sequence id ggml-ci * cont : add n_seq_max to batch allocr ggml-ci * kv-cache : perform stream copies lazily after llama_synchronize ggml-ci * kv-cache : avoid throwing exceptions across the C boundary ggml-ci * CUDA: 4D FlashAttention support (llama/14628) * CUDA: 4D FlashAttention support * CUDA: fix WMMA FA kernel * llama : rename attn_streams -> kv_unified ggml-ci * common : rename kv_split -> kv_unified ggml-ci --------- Co-authored-by: Johannes Gäßler --- ggml/src/ggml-cuda/fattn-common.cuh | 54 ++++++++++++++++++---------- ggml/src/ggml-cuda/fattn-mma-f16.cuh | 40 +++++++++++---------- ggml/src/ggml-cuda/fattn-tile-f16.cu | 34 ++++++++++-------- ggml/src/ggml-cuda/fattn-tile-f32.cu | 30 +++++++++------- ggml/src/ggml-cuda/fattn-vec-f16.cuh | 23 ++++++------ ggml/src/ggml-cuda/fattn-vec-f32.cuh | 27 +++++++------- ggml/src/ggml-cuda/fattn-wmma-f16.cu | 24 ++++++++----- ggml/src/ggml-cuda/ggml-cuda.cu | 9 ++--- 8 files changed, 141 insertions(+), 100 deletions(-) diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index 075f14a49e9..9122fca6cf9 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -33,8 +33,10 @@ typedef void (* fattn_kernel_t)( const int ne13, const int ne31, const int ne32, + const int ne33, const int nb31, const int nb32, + const int nb33, const int nb01, const int nb02, const int nb03, @@ -521,7 +523,7 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) { template // D == head size __launch_bounds__(D, 1) static __global__ void flash_attn_stream_k_fixup( - float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) { + float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, const int ne11) { constexpr int ncols = ncols1*ncols2; const int bidx0 = blockIdx.x; @@ -535,8 +537,8 @@ static __global__ void flash_attn_stream_k_fixup( const int iter_k = ne11 / FATTN_KQ_STRIDE; const int iter_j = (ne01 + (ncols1 - 1)) / ncols1; - const int kbc0 = (bidx0 + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x; - const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x; + const int kbc0 = (bidx0 + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x; + const int kbc0_stop = (bidx0 + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x; const bool did_not_have_any_data = kbc0 == kbc0_stop; const bool wrote_beginning_of_tile = kbc0 % iter_k == 0; @@ -545,14 +547,15 @@ static __global__ void flash_attn_stream_k_fixup( return; } - const int channel = kbc0 / (iter_k*iter_j); - const int jt = (kbc0 - channel*iter_k*iter_j) / iter_k; + const int sequence = kbc0 / (iter_k*iter_j*(ne02/ncols2)); + const int head = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); + const int jt = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile. 
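With the batch (ne03) dimension folded into the stream-k work index, the fixup kernel now peels out the sequence first and then recovers the head group and j-tile by progressively smaller divisions. A small worked example (names match the hunk; the numeric values are assumptions for illustration):

    // per_seq  = iter_k * iter_j * (ne02/ncols2);   // work blocks per sequence
    // sequence = kbc0 / per_seq;
    // head     = (kbc0 - per_seq*sequence) / (iter_k*iter_j);
    // jt       = (kbc0 - per_seq*sequence - iter_k*iter_j*head) / iter_k;
    // e.g. iter_k = 4, iter_j = 2, ne02/ncols2 = 3 gives per_seq = 24, so
    // kbc0 = 30 unpacks to sequence = 1, head = 0, jt = 1.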
if (jt*ncols1 + j >= ne01) { return; } - dst += jt*ne02*(ncols1*D) + channel*(ncols2*D) + (j*ne02 + c)*D + tid; + dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + head*(ncols2*D) + (j*ne02 + c)*D + tid; // Load the partial result that needs a fixup: float dst_val = 0.0f; @@ -571,7 +574,7 @@ static __global__ void flash_attn_stream_k_fixup( int bidx = bidx0 - 1; int kbc_stop = kbc0; while(true) { - const int kbc = bidx*iter_k*iter_j*(ne02/ncols2) / gridDim.x; + const int kbc = bidx*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x; if (kbc == kbc_stop) { // Did not have any data. bidx--; kbc_stop = kbc; @@ -617,16 +620,31 @@ static __global__ void flash_attn_combine_results( const float2 * __restrict__ VKQ_meta, float * __restrict__ dst, const int parallel_blocks) { - VKQ_parts += parallel_blocks*D * gridDim.z*blockIdx.x; - VKQ_meta += parallel_blocks * gridDim.z*blockIdx.x; - dst += D * gridDim.z*blockIdx.x; + // Dimension 0: threadIdx.x + // Dimension 1: blockIdx.x + // Dimension 2: blockIdx.y + // Dimension 3: blockIdx.z + // Memory layout is permuted with [0, 2, 1, 3] + + const int ne01 = gridDim.x; + const int ne02 = gridDim.y; + + const int col = blockIdx.x; + const int head = blockIdx.y; + const int sequence = blockIdx.z; + + const int j_dst_unrolled = (sequence*ne01 + col)*ne02 + head; + + VKQ_parts += j_dst_unrolled * parallel_blocks*D; + VKQ_meta += j_dst_unrolled * parallel_blocks; + dst += j_dst_unrolled * D; const int tid = threadIdx.x; __builtin_assume(tid < D); extern __shared__ float2 meta[]; for (int i = tid; i < 2*parallel_blocks; i += D) { - ((float *) meta)[i] = ((const float *)VKQ_meta) [blockIdx.z*(2*parallel_blocks) + i]; + ((float *) meta)[i] = ((const float *)VKQ_meta) [i]; } __syncthreads(); @@ -644,11 +662,11 @@ static __global__ void flash_attn_combine_results( const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD); *((uint32_t *) &KQ_max_scale) &= ftz_mask; - VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.z*D + blockIdx.z*D + tid]; + VKQ_numerator += KQ_max_scale * VKQ_parts[l*D + tid]; VKQ_denominator += KQ_max_scale * meta[l].y; } - dst[blockIdx.z*D + tid] = VKQ_numerator / VKQ_denominator; + dst[tid] = VKQ_numerator / VKQ_denominator; } [[noreturn]] @@ -705,8 +723,6 @@ void launch_fattn( GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding."); - GGML_ASSERT(Q->ne[3] == 1); - ggml_cuda_pool & pool = ctx.pool(); cudaStream_t main_stream = ctx.stream(); const int id = ggml_cuda_get_device(); @@ -853,8 +869,8 @@ void launch_fattn( scale, max_bias, m0, m1, n_head_log2, logit_softcap, Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0, - mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0, + mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0, + mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0, mask ? 
mask->nb[3] : 0, Q->nb[1], Q->nb[2], Q->nb[3], nb11, nb12, nb13, nb21, nb22, nb23, @@ -869,11 +885,11 @@ void launch_fattn( flash_attn_stream_k_fixup <<>> - ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]); + ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1]); } } else if (parallel_blocks > 1) { const dim3 block_dim_combine(DV, 1, 1); - const dim3 blocks_num_combine(Q->ne[1], 1, blocks_num.z); + const dim3 blocks_num_combine(Q->ne[1], Q->ne[2], Q->ne[3]); const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2); flash_attn_combine_results diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index 709589854f0..6fa2e77299e 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -1224,8 +1224,10 @@ static __global__ void flash_attn_ext_f16( const int ne13, const int ne31, const int ne32, + const int ne33, const int nb31, const int nb32, + const int nb33, const int nb01, const int nb02, const int nb03, @@ -1274,8 +1276,8 @@ static __global__ void flash_attn_ext_f16( constexpr int kb_niter = FATTN_KQ_STRIDE / c::nbatch_fa; // Number of kernel iterations per assigned KQ slice. // kbc == k block continuous, current index in continuous ijk space. - int kbc = (blockIdx.x + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x; - const int kbc_stop = (blockIdx.x + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x; + int kbc = (blockIdx.x + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x; + const int kbc_stop = (blockIdx.x + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x; // If the seams of 2 CUDA blocks fall within an output tile their results need to be combined. // For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup). @@ -1285,18 +1287,19 @@ static __global__ void flash_attn_ext_f16( int kb0_start = kbc % iter_k; int kb0_stop = min(iter_k, kb0_start + kbc_stop - kbc); while (kbc < kbc_stop && kb0_stop == iter_k) { - const int channel = kbc / (iter_k*iter_j); - const int jt = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile. + const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2)); + const int head = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); + const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile. - const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2); - const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio)); + const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*(head*ncols2)); + const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head*ncols2 / gqa_ratio)); const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr : - (const half2 *) (mask + nb32*(channel % ne32) + nb31*jt*ncols1); - float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2); + (const half2 *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1); + float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head*ncols2) * (DV/2); - const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio)); + const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head*ncols2 / gqa_ratio)); - const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f; + const float slope = ncols2 == 1 ? 
get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f; const int kb0_start_kernel = kb0_start * kb_niter; const int kb0_stop_kernel = kb0_stop * kb_niter; @@ -1325,18 +1328,19 @@ static __global__ void flash_attn_ext_f16( return; } - const int channel = kbc / (iter_k*iter_j); - const int jt = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile. + const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2)); + const int head = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); + const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile. - const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2); - const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio)); + const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*(head*ncols2)); + const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head*ncols2 / gqa_ratio)); const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr : - (const half2 *) (mask + nb32*(channel % ne32) + nb31*jt*ncols1); - float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2); + (const half2 *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1); + float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head*ncols2) * (DV/2); - const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio)); + const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head*ncols2 / gqa_ratio)); - const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f; + const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f; const int kb0_start_kernel = kb0_start * kb_niter; const int kb0_stop_kernel = kb0_stop * kb_niter; diff --git a/ggml/src/ggml-cuda/fattn-tile-f16.cu b/ggml/src/ggml-cuda/fattn-tile-f16.cu index 0c967f178e7..1f141328845 100644 --- a/ggml/src/ggml-cuda/fattn-tile-f16.cu +++ b/ggml/src/ggml-cuda/fattn-tile-f16.cu @@ -31,8 +31,10 @@ static __global__ void flash_attn_tile_ext_f16( const int ne13, const int ne31, const int ne32, + const int ne33, const int nb31, const int nb32, + const int nb33, const int nb01, const int nb02, const int nb03, @@ -62,15 +64,17 @@ static __global__ void flash_attn_tile_ext_f16( const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on. + const int sequence = blockIdx.z / ne02; + const int head = blockIdx.z - sequence*ne02; const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. 
- const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.z + nb01*ic0); - const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.z / gqa_ratio)); - const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape - const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0); + const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0); + const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio)); + const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape + const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0); const int stride_KV2 = nb11 / sizeof(half2); - const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1); + const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1); const half slopeh = __float2half(slopef); static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64."); @@ -255,6 +259,8 @@ static __global__ void flash_attn_tile_ext_f16( __syncthreads(); } + float2 * dst2 = (float2 *) dst; + #pragma unroll for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) { const int j_VKQ = j_VKQ_0 + threadIdx.y; @@ -266,21 +272,21 @@ static __global__ void flash_attn_tile_ext_f16( half kqsum_j = __low2half(kqsum[j_VKQ_0/nwarps]) + __high2half(kqsum[j_VKQ_0/nwarps]); kqsum_j = warp_reduce_sum((float)kqsum_j); + const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y; + #pragma unroll - for (int i00 = 0; i00 < D; i00 += 2*WARP_SIZE) { - const int i0 = i00 + 2*threadIdx.x; + for (int i00 = 0; i00 < D/2; i00 += WARP_SIZE) { + const int i0 = i00 + threadIdx.x; - half2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)]; + half2 dst_val = VKQ[j_VKQ_0/nwarps][i0/WARP_SIZE]; if (gridDim.y == 1) { dst_val /= __half2half2(kqsum_j); } - const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y; - dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 0] = __low2float(dst_val); - dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 1] = __high2float(dst_val); + dst2[j_dst_unrolled*(D/2) + i0] = __half22float2(dst_val); } if (gridDim.y != 1 && threadIdx.x == 0) { - dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j); + dst_meta[j_dst_unrolled] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j); } } #else @@ -290,8 +296,8 @@ static __global__ void flash_attn_tile_ext_f16( GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11); - GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); - GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02); + GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33); + GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1); diff --git a/ggml/src/ggml-cuda/fattn-tile-f32.cu b/ggml/src/ggml-cuda/fattn-tile-f32.cu index 908c76dbdd2..a4965583cef 100644 --- a/ggml/src/ggml-cuda/fattn-tile-f32.cu +++ b/ggml/src/ggml-cuda/fattn-tile-f32.cu @@ -31,8 +31,10 @@ static __global__ void flash_attn_tile_ext_f32( const int ne13, const int ne31, const int ne32, + const int ne33, 
const int nb31, const int nb32, + const int nb33, const int nb01, const int nb02, const int nb03, @@ -74,15 +76,17 @@ static __global__ void flash_attn_tile_ext_f32( const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on. + const int sequence = blockIdx.z / ne02; + const int head = blockIdx.z - sequence*ne02; const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. - const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.z + nb01*ic0); - const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.z / gqa_ratio)); - const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape - const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0); + const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0); + const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio)); + const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape + const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0); const int stride_KV2 = nb11 / sizeof(half2); - const float slope = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1); + const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1); static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64."); @@ -265,6 +269,8 @@ static __global__ void flash_attn_tile_ext_f32( __syncthreads(); } + float2 * dst2 = (float2 *) dst; + #pragma unroll for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) { const int j_VKQ = j_VKQ_0 + threadIdx.y; @@ -276,22 +282,22 @@ static __global__ void flash_attn_tile_ext_f32( float kqsum_j = kqsum[j_VKQ_0/nwarps]; kqsum_j = warp_reduce_sum(kqsum_j); + const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y; + #pragma unroll - for (int i00 = 0; i00 < D; i00 += 2*WARP_SIZE) { - const int i0 = i00 + 2*threadIdx.x; + for (int i00 = 0; i00 < D/2; i00 += WARP_SIZE) { + const int i0 = i00 + threadIdx.x; - float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)]; + float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/WARP_SIZE]; if (gridDim.y == 1) { dst_val.x /= kqsum_j; dst_val.y /= kqsum_j; } - const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y; - dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 0] = dst_val.x; - dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 1] = dst_val.y; + dst2[j_dst_unrolled*(D/2) + i0] = dst_val; } if (gridDim.y != 1 && threadIdx.x == 0) { - dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j); + dst_meta[j_dst_unrolled] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j); } } #else diff --git a/ggml/src/ggml-cuda/fattn-vec-f16.cuh b/ggml/src/ggml-cuda/fattn-vec-f16.cuh index e78fb181919..b2d469938ab 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f16.cuh @@ -28,8 +28,10 @@ static __global__ void flash_attn_vec_ext_f16( const int ne13, const int ne31, const int ne32, + const int ne33, const int nb31, const int nb32, + const int nb33, const int nb01, const int nb02, const int nb03, @@ -65,14 +67,16 @@ static __global__ void flash_attn_vec_ext_f16( const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on. + const int sequence = blockIdx.z / ne02; + const int head = blockIdx.z - sequence*ne02; const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. 
- Q += nb02* blockIdx.z + nb01*ic0; - K += nb12*(blockIdx.z / gqa_ratio); - V += nb22*(blockIdx.z / gqa_ratio); + Q += nb03*sequence + nb02* head + nb01*ic0; + K += nb13*sequence + nb12*(head / gqa_ratio); + V += nb23*sequence + nb22*(head / gqa_ratio); - const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0); + const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0); - const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1); + const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1); const half slopeh = __float2half(slopef); static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64."); @@ -330,12 +334,11 @@ static __global__ void flash_attn_vec_ext_f16( if (gridDim.y == 1) { dst_val /= kqsum[j_VKQ]; } - const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y; - dst[j_dst*D*gridDim.z + D*blockIdx.z + tid] = dst_val; + dst[(((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*D + tid] = dst_val; } if (gridDim.y != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) { - dst_meta[((ic0 + tid)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]); + dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]); } #else GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); @@ -344,8 +347,8 @@ static __global__ void flash_attn_vec_ext_f16( GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11); - GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); - GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02); + GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne32); + GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1); diff --git a/ggml/src/ggml-cuda/fattn-vec-f32.cuh b/ggml/src/ggml-cuda/fattn-vec-f32.cuh index b2f1724c955..405b6f5106e 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f32.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f32.cuh @@ -28,8 +28,10 @@ static __global__ void flash_attn_vec_ext_f32( const int ne13, const int ne31, const int ne32, + const int ne33, const int nb31, const int nb32, + const int nb33, const int nb01, const int nb02, const int nb03, @@ -53,8 +55,8 @@ static __global__ void flash_attn_vec_ext_f32( GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11); - GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); - GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02); + GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33); + GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1); @@ -77,14 +79,16 @@ static __global__ void flash_attn_vec_ext_f32( const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on. 
+ const int sequence = blockIdx.z / ne02; + const int head = blockIdx.z - sequence*ne02; const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. - Q += nb02* blockIdx.z + nb01*ic0; - K += nb12*(blockIdx.z / gqa_ratio); - V += nb22*(blockIdx.z / gqa_ratio); // K and V have same shape + Q += nb03*sequence + nb02* head + nb01*ic0; + K += nb13*sequence + nb12*(head / gqa_ratio); + V += nb23*sequence + nb22*(head / gqa_ratio); - const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0); + const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0); - const float slope = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1); + const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1); static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64."); constexpr int nwarps = D / WARP_SIZE; @@ -326,12 +330,11 @@ static __global__ void flash_attn_vec_ext_f32( if (gridDim.y == 1) { dst_val /= kqsum[j_VKQ]; } - const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y; - dst[j_dst*D*gridDim.z + D*blockIdx.z + tid] = dst_val; + dst[(((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*D + tid] = dst_val; } if (gridDim.y != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) { - dst_meta[((ic0 + tid)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]); + dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]); } #else GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); @@ -340,8 +343,8 @@ static __global__ void flash_attn_vec_ext_f32( GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); - GGML_UNUSED(ne31); GGML_UNUSED(ne32); - GGML_UNUSED(nb31); GGML_UNUSED(nb32); + GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33); + GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23); diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ggml/src/ggml-cuda/fattn-wmma-f16.cu index c95ca7b1f28..741b8781d29 100644 --- a/ggml/src/ggml-cuda/fattn-wmma-f16.cu +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cu @@ -47,8 +47,10 @@ static __global__ void flash_attn_ext_f16( const int ne13, const int ne31, const int ne32, + const int ne33, const int nb31, const int nb32, + const int nb33, const int nb01, const int nb02, const int nb03, @@ -95,17 +97,19 @@ static __global__ void flash_attn_ext_f16( constexpr int kqs_padded = FATTN_KQ_STRIDE + 8; constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half); + const int sequence = blockIdx.z / ne02; + const int head = blockIdx.z - sequence*ne02; const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. 
- const float * Q_f = (const float *) (Q + nb02* blockIdx.z + nb01*ic0); - const half * K_h = (const half *) (K + nb12*(blockIdx.z / gqa_ratio)); - const half * V_h = (const half *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape - const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0); + const float * Q_f = (const float *) (Q + nb03* sequence + nb02* head + nb01*ic0); + const half * K_h = (const half *) (K + nb13* sequence + nb12*(head / gqa_ratio)); + const half * V_h = (const half *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape + const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0); const half2 * mask2 = (const half2 *) maskh; const int stride_Q = nb01 / sizeof(float); const int stride_KV = nb11 / sizeof(half); - const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1); + const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1); const half slopeh = __float2half(slopef); const half2 slope2 = make_half2(slopef, slopef); @@ -400,7 +404,6 @@ static __global__ void flash_attn_ext_f16( if (ic0 + j_VKQ >= ne01) { return; } - const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y; float KQ_rowsum_j; if (std::is_same::value) { @@ -409,6 +412,8 @@ static __global__ void flash_attn_ext_f16( KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]); } + const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y; + #pragma unroll for (int i0 = 0; i0 < D; i0 += warp_size) { const int i = i0 + threadIdx.x; @@ -419,7 +424,7 @@ static __global__ void flash_attn_ext_f16( if (gridDim.y == 1) { dst_val /= KQ_rowsum_j; } - dst[j_dst*gridDim.z*D + blockIdx.z*D + i] = dst_val; + dst[j_dst_unrolled*D + i] = dst_val; } if (gridDim.y == 1 || threadIdx.x != 0) { @@ -433,7 +438,7 @@ static __global__ void flash_attn_ext_f16( dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]); } dst_meta_val.y = KQ_rowsum_j; - dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = dst_meta_val; + dst_meta[j_dst_unrolled] = dst_meta_val; } #else GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); @@ -442,7 +447,8 @@ static __global__ void flash_attn_ext_f16( GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); - GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02); + GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33); GGML_UNUSED(nb31); + GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 8015b0d4e8d..778d5a48bd9 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3413,12 +3413,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g if (op->src[0]->ne[0] == 192) { return false; } - // TODO: support broadcast - // note: this was initially implemented in https://github.com/ggml-org/llama.cpp/pull/14500, but - // the interface of ggml_flash_attn_ext() changed in https://github.com/ggml-org/llama.cpp/pull/14505 - 
if (op->src[0]->ne[3] != 1) { - return false; - } if (op->src[1]->type == GGML_TYPE_BF16 || op->src[2]->type == GGML_TYPE_BF16) { return false; } @@ -3431,6 +3425,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g if (op->src[0]->ne[0] == 256 && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) { return true; } + if (op->src[3] && op->src[3]->ne[2] != 1) { + return false; + } return fp16_mma_available(ggml_cuda_info().devices[dev_ctx->device].cc) && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16; } From c3f7c10edc0b34cf75ae4d3d49de0fb0c67cdf97 Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Wed, 16 Jul 2025 08:18:51 -0700 Subject: [PATCH 15/20] ggml: Add initial WebGPU backend (llama/14521) * Minimal setup of webgpu backend with dawn. Just prints out the adapter and segfaults * Initialize webgpu device * Making progress on setting up the backend * Finish more boilerplate/utility functions * Organize file and work on alloc buffer * Add webgpu_context to prepare for actually running some shaders * Work on memset and add shader loading * Work on memset polyfill * Implement set_tensor as webgpu WriteBuffer, remove host_buffer stubs since webgpu doesn't support it * Implement get_tensor and buffer_clear * Finish rest of setup * Start work on compute graph * Basic mat mul working * Work on emscripten build * Basic WebGPU backend instructions * Use EMSCRIPTEN flag * Work on passing ci, implement 4d tensor multiplication * Pass thread safety test * Implement permuting for mul_mat and cpy * minor cleanups * Address feedback * Remove division by type size in cpy op * Fix formatting and add github action workflows for vulkan and metal (m-series) webgpu backends * Fix name * Fix macos dawn prefix path --- ggml/CMakeLists.txt | 3 +++ ggml/include/ggml-webgpu.h | 19 +++++++++++++++++++ ggml/src/CMakeLists.txt | 1 + ggml/src/ggml-backend-reg.cpp | 7 +++++++ 4 files changed, 30 insertions(+) create mode 100644 ggml/include/ggml-webgpu.h diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index eaba9c70469..de6d789c98a 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -181,6 +181,8 @@ option(GGML_VULKAN_MEMORY_DEBUG "ggml: enable Vulkan memory debug ou option(GGML_VULKAN_SHADER_DEBUG_INFO "ggml: enable Vulkan shader debug info" OFF) option(GGML_VULKAN_VALIDATE "ggml: enable Vulkan validation" OFF) option(GGML_VULKAN_RUN_TESTS "ggml: run Vulkan tests" OFF) +option(GGML_WEBGPU "ggml: use WebGPU" OFF) +option(GGML_WEBGPU_DEBUG "ggml: enable WebGPU debug output" OFF) option(GGML_METAL "ggml: use Metal" ${GGML_METAL_DEFAULT}) option(GGML_METAL_USE_BF16 "ggml: use bfloat if available" OFF) option(GGML_METAL_NDEBUG "ggml: disable Metal debugging" OFF) @@ -270,6 +272,7 @@ set(GGML_PUBLIC_HEADERS include/ggml-rpc.h include/ggml-sycl.h include/ggml-vulkan.h + include/ggml-webgpu.h include/gguf.h) set_target_properties(ggml PROPERTIES PUBLIC_HEADER "${GGML_PUBLIC_HEADERS}") diff --git a/ggml/include/ggml-webgpu.h b/ggml/include/ggml-webgpu.h new file mode 100644 index 00000000000..65b8ed9bb66 --- /dev/null +++ b/ggml/include/ggml-webgpu.h @@ -0,0 +1,19 @@ +#pragma once + +#include "ggml.h" +#include "ggml-backend.h" + +#ifdef __cplusplus +extern "C" { +#endif + +#define GGML_WEBGPU_NAME "WebGPU" + +// Needed for examples in ggml +GGML_BACKEND_API ggml_backend_t ggml_backend_webgpu_init(void); + +GGML_BACKEND_API ggml_backend_reg_t ggml_backend_webgpu_reg(void); + +#ifdef __cplusplus +} +#endif diff --git 
a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index 8760c2d35ec..0425fd60a94 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -370,6 +370,7 @@ ggml_add_backend(MUSA) ggml_add_backend(RPC) ggml_add_backend(SYCL) ggml_add_backend(Vulkan) +ggml_add_backend(WebGPU) ggml_add_backend(OpenCL) foreach (target ggml-base ggml) diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp index 042ea77aca7..f0cdac31eae 100644 --- a/ggml/src/ggml-backend-reg.cpp +++ b/ggml/src/ggml-backend-reg.cpp @@ -45,6 +45,10 @@ #include "ggml-vulkan.h" #endif +#ifdef GGML_USE_WEBGPU +#include "ggml-webgpu.h" +#endif + #ifdef GGML_USE_OPENCL #include "ggml-opencl.h" #endif @@ -173,6 +177,9 @@ struct ggml_backend_registry { #ifdef GGML_USE_VULKAN register_backend(ggml_backend_vk_reg()); #endif +#ifdef GGML_USE_WEBGPU + register_backend(ggml_backend_webgpu_reg()); +#endif #ifdef GGML_USE_OPENCL register_backend(ggml_backend_opencl_reg()); #endif From 966462949547aa6e2a68140559a368296bcd890f Mon Sep 17 00:00:00 2001 From: Neo Zhang Jianyu Date: Fri, 18 Jul 2025 10:23:14 +0800 Subject: [PATCH 16/20] use max work group size for device to replace the magic number (llama/14732) --- ggml/src/ggml-sycl/ggml-sycl.cpp | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index a6f9af0c86e..872eb4b052d 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -3530,8 +3530,11 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, SYCL_CHECK(CHECK_TRY_ERROR( stream->memset(dev_cur_src1_row.get(), 0, sizeof(int)))); + const unsigned int max_work_group_size = ggml_sycl_info().max_work_group_sizes[ctx.device]; + assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0); + { - sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne10, 768u)); + sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne10, max_work_group_size)); sycl::range<3> grid_dims(1, n_ids, ids->ne[1]); sycl_launch(stream, [&](sycl::handler & cgh) { sycl::local_accessor src1_row_acc(cgh); @@ -3575,7 +3578,7 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, ggml_sycl_mul_mat(ctx, &src0_row, &src1_row, &dst_row); { - sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne0, 768u)); + sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne0, max_work_group_size)); sycl::range<3> grid_dims(1, 1, num_src1_rows); sycl_launch(stream, [&](sycl::handler & cgh) { const char *__restrict dst_contiguous_get = From 7df67c29761fdb3d04e5ab7daa3b1a3bfc1c3bd3 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Fri, 18 Jul 2025 14:54:18 +0800 Subject: [PATCH 17/20] CUDA: set_rows + cpy.cu refactor (llama/14712) --- ggml/src/ggml-cuda/cpy-utils.cuh | 251 +++++++++++++++++++++++++++++++ ggml/src/ggml-cuda/cpy.cu | 239 +---------------------------- ggml/src/ggml-cuda/ggml-cuda.cu | 5 +- ggml/src/ggml-cuda/set-rows.cu | 145 +++++++++++++++++- 4 files changed, 396 insertions(+), 244 deletions(-) create mode 100644 ggml/src/ggml-cuda/cpy-utils.cuh diff --git a/ggml/src/ggml-cuda/cpy-utils.cuh b/ggml/src/ggml-cuda/cpy-utils.cuh new file mode 100644 index 00000000000..e7a0bd2f1a0 --- /dev/null +++ b/ggml/src/ggml-cuda/cpy-utils.cuh @@ -0,0 +1,251 @@ +#pragma once + +#include "ggml-common.h" + +static __device__ __forceinline__ void convert_f32_f32(const float * src, float * dst) { + *dst = *src; +} + +static __device__ __forceinline__ void convert_f32_f16(const float * src, half * dst) { + *dst = 
__float2half(*src); +} + +static __device__ __forceinline__ void convert_f32_bf16(const float * src, nv_bfloat16 * dst) { + *dst = *src; +} + +static __device__ __forceinline__ void convert_f16_f16(const half * src, half * dst) { + *dst = *src; +} + +static __device__ __forceinline__ void convert_f16_f32(const half * src, float * dst) { + *dst = *src; +} + +static __device__ __forceinline__ int best_index_int8(int n, const int8_t * val, float x) { + if (x <= val[0]) return 0; + if (x >= val[n-1]) return n-1; + int ml = 0, mu = n-1; + while (mu-ml > 1) { + int mav = (ml+mu)/2; + if (x < val[mav]) mu = mav; else ml = mav; + } + return x - val[mu-1] < val[mu] - x ? mu-1 : mu; +} + +static __device__ void quantize_f32_q4_0_block(const float * __restrict__ x, block_q4_0 * __restrict__ y) { + float amax = 0.0f; + float vmax = 0.0f; + + for (int j = 0; j < QK4_0; ++j) { + const float v = x[j]; + if (amax < fabsf(v)) { + amax = fabsf(v); + vmax = v; + } + } + + const float d = vmax / -8; + const float id = d ? 1.0f/d : 0.0f; + + y->d = d; + + for (int j = 0; j < QK4_0/2; ++j) { + const float x0 = x[0 + j]*id; + const float x1 = x[QK4_0/2 + j]*id; + + const uint8_t xi0 = min(15, (int8_t)(x0 + 8.5f)); + const uint8_t xi1 = min(15, (int8_t)(x1 + 8.5f)); + + y->qs[j] = xi0; + y->qs[j] |= xi1 << 4; + } +} + +static __device__ void quantize_f32_q4_1_block(const float * __restrict__ x, block_q4_1 * __restrict__ y) { + float vmin = FLT_MAX; + float vmax = -FLT_MAX; + + for (int j = 0; j < QK4_1; ++j) { + const float v = x[j]; + if (v < vmin) vmin = v; + if (v > vmax) vmax = v; + } + + const float d = (vmax - vmin) / ((1 << 4) - 1); + const float id = d ? 1.0f/d : 0.0f; + + y->dm.x = d; + y->dm.y = vmin; + + for (int j = 0; j < QK4_1/2; ++j) { + const float x0 = (x[0 + j] - vmin)*id; + const float x1 = (x[QK4_1/2 + j] - vmin)*id; + + const uint8_t xi0 = min(15, (int8_t)(x0 + 0.5f)); + const uint8_t xi1 = min(15, (int8_t)(x1 + 0.5f)); + + y->qs[j] = xi0; + y->qs[j] |= xi1 << 4; + } +} + +static __device__ void quantize_f32_q5_0_block(const float * __restrict__ x, block_q5_0 * __restrict__ y) { + float amax = 0.0f; + float vmax = 0.0f; + + for (int j = 0; j < QK5_0; ++j) { + const float v = x[j]; + if (amax < fabsf(v)) { + amax = fabsf(v); + vmax = v; + } + } + + const float d = vmax / -16; + const float id = d ? 1.0f/d : 0.0f; + + y->d = d; + + uint32_t qh = 0; + for (int j = 0; j < QK5_0/2; ++j) { + const float x0 = x[0 + j]*id; + const float x1 = x[QK5_0/2 + j]*id; + + const uint8_t xi0 = min(31, (int8_t)(x0 + 16.5f)); + const uint8_t xi1 = min(31, (int8_t)(x1 + 16.5f)); + + y->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4); + qh |= ((xi0 & 0x10u) >> 4) << (j + 0); + qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2); + } + memcpy(y->qh, &qh, sizeof(qh)); +} + +static __device__ void quantize_f32_q5_1_block(const float * __restrict__ x, block_q5_1 * __restrict__ y) { + float min = x[0]; + float max = x[0]; + + for (int j = 1; j < QK5_1; ++j) { + const float v = x[j]; + min = v < min ? v : min; + max = v > max ? v : max; + } + + const float d = (max - min) / 31; + const float id = d ? 
1.0f/d : 0.0f; + + y->dm.x = d; + y->dm.y = min; + + uint32_t qh = 0; + for (int j = 0; j < QK5_1/2; ++j) { + const float x0 = (x[0 + j] - min)*id; + const float x1 = (x[QK5_1/2 + j] - min)*id; + + const uint8_t xi0 = (uint8_t)(x0 + 0.5f); + const uint8_t xi1 = (uint8_t)(x1 + 0.5f); + + y->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4); + qh |= ((xi0 & 0x10u) >> 4) << (j + 0); + qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2); + } + memcpy(y->qh, &qh, sizeof(qh)); +} + +static __device__ void quantize_f32_q8_0_block(const float * __restrict__ x, block_q8_0 * __restrict__ y) { + float amax = 0.0f; // absolute max + + for (int j = 0; j < QK8_0; j++) { + const float v = x[j]; + amax = fmaxf(amax, fabsf(v)); + } + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f/d : 0.0f; + + y->d = d; + + for (int j = 0; j < QK8_0; ++j) { + const float x0 = x[j]*id; + y->qs[j] = roundf(x0); + } +} + +static __device__ void quantize_f32_iq4_nl_block(const float * __restrict__ x, block_iq4_nl * __restrict__ y) { + float amax = 0.0f; + float vmax = 0.0f; + + for (int j = 0; j < QK4_NL; ++j) { + const float v = x[j]; + if (amax < fabsf(v)) { + amax = fabsf(v); + vmax = v; + } + } + + float d = vmax / kvalues_iq4nl[0]; + const float id = d ? 1.0f/d : 0.0f; + + float sumqx = 0, sumq2 = 0; + for (int j = 0; j < QK4_NL/2; ++j) { + const float x0 = x[0 + j]*id; + const float x1 = x[QK4_NL/2 + j]*id; + const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl, x0); + const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl, x1); + y->qs[j] = xi0 | (xi1 << 4); + const float v0 = kvalues_iq4nl[xi0]; + const float v1 = kvalues_iq4nl[xi1]; + const float w0 = x[0 + j]*x[0 + j]; + const float w1 = x[QK4_NL/2 + j]*x[QK4_NL/2 + j]; + sumqx += w0*v0*x[j] + w1*v1*x[QK4_NL/2 + j]; + sumq2 += w0*v0*v0 + w1*v1*v1; + } + + y->d = sumq2 > 0 ? 
sumqx/sumq2 : d; +} + +// Wrapper functions for cpy.cu compatibility +static __device__ void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) { + quantize_f32_q4_0_block((const float *)cxi, (block_q4_0 *)cdsti); +} + +static __device__ void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) { + quantize_f32_q4_1_block((const float *)cxi, (block_q4_1 *)cdsti); +} + +static __device__ void cpy_blck_f32_q5_0(const char * cxi, char * cdsti) { + quantize_f32_q5_0_block((const float *)cxi, (block_q5_0 *)cdsti); +} + +static __device__ void cpy_blck_f32_q5_1(const char * cxi, char * cdsti) { + quantize_f32_q5_1_block((const float *)cxi, (block_q5_1 *)cdsti); +} + +static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) { + quantize_f32_q8_0_block((const float *)cxi, (block_q8_0 *)cdsti); +} + +static __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) { + quantize_f32_iq4_nl_block((const float *)cxi, (block_iq4_nl *)cdsti); +} + +static __device__ void cpy_1_f32_f32(const char * cxi, char * cdsti) { + convert_f32_f32((const float *)cxi, (float *)cdsti); +} + +static __device__ void cpy_1_f32_f16(const char * cxi, char * cdsti) { + convert_f32_f16((const float *)cxi, (half *)cdsti); +} + +static __device__ void cpy_1_f32_bf16(const char * cxi, char * cdsti) { + convert_f32_bf16((const float *)cxi, (nv_bfloat16 *)cdsti); +} + +static __device__ void cpy_1_f16_f16(const char * cxi, char * cdsti) { + convert_f16_f16((const half *)cxi, (half *)cdsti); +} + +static __device__ void cpy_1_f16_f32(const char * cxi, char * cdsti) { + convert_f16_f32((const half *)cxi, (float *)cdsti); +} diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu index 2c55d2149b2..e7d0da08705 100644 --- a/ggml/src/ggml-cuda/cpy.cu +++ b/ggml/src/ggml-cuda/cpy.cu @@ -1,46 +1,12 @@ #include "cpy.cuh" #include "dequantize.cuh" +#include "cpy-utils.cuh" #ifdef GGML_USE_MUSA #include "ggml-musa/mudnn.cuh" #endif // GGML_USE_MUSA typedef void (*cpy_kernel_t)(const char * cx, char * cdst); -static __device__ void cpy_1_f32_f32(const char * cxi, char * cdsti) { - const float * xi = (const float *) cxi; - float * dsti = (float *) cdsti; - - *dsti = *xi; -} - -static __device__ void cpy_1_f32_bf16(const char * cxi, char * cdsti) { - const float * xi = (const float *) cxi; - nv_bfloat16 * dsti = (nv_bfloat16 *) cdsti; - - *dsti = *xi; -} - -static __device__ void cpy_1_f32_f16(const char * cxi, char * cdsti) { - const float * xi = (const float *) cxi; - half * dsti = (half *) cdsti; - - *dsti = __float2half(*xi); -} - -static __device__ void cpy_1_f16_f16(const char * cxi, char * cdsti) { - const half * xi = (const half *) cxi; - half * dsti = (half *) cdsti; - - *dsti = *xi; -} - -static __device__ void cpy_1_f16_f32(const char * cxi, char * cdsti) { - const half * xi = (const half *) cxi; - float * dsti = (float *) cdsti; - - *dsti = *xi; -} - template static __global__ void cpy_f32_f16(const char * cx, char * cdst_direct, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, @@ -71,29 +37,6 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst_direct, const in cpy_1(cx + x_offset, cdst + dst_offset); } -static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) { - const float * xi = (const float *) cxi; - block_q8_0 * dsti = (block_q8_0 *) cdsti; - - float amax = 0.0f; // absolute max - - for (int j = 0; j < QK8_0; j++) { - const float v = xi[j]; - amax = fmaxf(amax, fabsf(v)); - } - - const float d = amax / 
((1 << 7) - 1); - const float id = d ? 1.0f/d : 0.0f; - - dsti->d = d; - - for (int j = 0; j < QK8_0; ++j) { - const float x0 = xi[j]*id; - - dsti->qs[j] = roundf(x0); - } -} - static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) { float * cdstf = (float *)(cdsti); @@ -106,139 +49,6 @@ static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) { } } -static __device__ void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) { - const float * xi = (const float *) cxi; - block_q4_0 * dsti = (block_q4_0 *) cdsti; - - float amax = 0.0f; - float vmax = 0.0f; - - for (int j = 0; j < QK4_0; ++j) { - const float v = xi[j]; - if (amax < fabsf(v)) { - amax = fabsf(v); - vmax = v; - } - } - - const float d = vmax / -8; - const float id = d ? 1.0f/d : 0.0f; - - dsti->d = d; - - for (int j = 0; j < QK4_0/2; ++j) { - const float x0 = xi[0 + j]*id; - const float x1 = xi[QK4_0/2 + j]*id; - - const uint8_t xi0 = min(15, (int8_t)(x0 + 8.5f)); - const uint8_t xi1 = min(15, (int8_t)(x1 + 8.5f)); - - dsti->qs[j] = xi0; - dsti->qs[j] |= xi1 << 4; - } -} - -static __device__ void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) { - const float * xi = (const float *) cxi; - block_q4_1 * dsti = (block_q4_1 *) cdsti; - - float vmin = FLT_MAX; - float vmax = -FLT_MAX; - - for (int j = 0; j < QK4_1; ++j) { - const float v = xi[j]; - - if (v < vmin) vmin = v; - if (v > vmax) vmax = v; - } - - const float d = (vmax - vmin) / ((1 << 4) - 1); - const float id = d ? 1.0f/d : 0.0f; - - dsti->dm.x = d; - dsti->dm.y = vmin; - - for (int j = 0; j < QK4_1/2; ++j) { - const float x0 = (xi[0 + j] - vmin)*id; - const float x1 = (xi[QK4_1/2 + j] - vmin)*id; - - const uint8_t xi0 = min(15, (int8_t)(x0 + 0.5f)); - const uint8_t xi1 = min(15, (int8_t)(x1 + 0.5f)); - - dsti->qs[j] = xi0; - dsti->qs[j] |= xi1 << 4; - } -} - -static __device__ void cpy_blck_f32_q5_0(const char * cxi, char * cdsti) { - const float * xi = (const float *) cxi; - block_q5_0 * dsti = (block_q5_0 *) cdsti; - - float amax = 0.0f; - float vmax = 0.0f; - - for (int j = 0; j < QK5_0; ++j) { - const float v = xi[j]; - if (amax < fabsf(v)) { - amax = fabsf(v); - vmax = v; - } - } - - const float d = vmax / -16; - const float id = d ? 1.0f/d : 0.0f; - - dsti->d = d; - - uint32_t qh = 0; - for (int j = 0; j < QK5_0/2; ++j) { - const float x0 = xi[0 + j]*id; - const float x1 = xi[QK5_0/2 + j]*id; - - const uint8_t xi0 = min(31, (int8_t)(x0 + 16.5f)); - const uint8_t xi1 = min(31, (int8_t)(x1 + 16.5f)); - - dsti->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4); - qh |= ((xi0 & 0x10u) >> 4) << (j + 0); - qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2); - } - memcpy(dsti->qh, &qh, sizeof(qh)); -} - -static __device__ void cpy_blck_f32_q5_1(const char * cxi, char * cdsti) { - const float * xi = (const float *) cxi; - block_q5_1 * dsti = (block_q5_1 *) cdsti; - - float min = xi[0]; - float max = xi[0]; - - for (int j = 1; j < QK5_1; ++j) { - const float v = xi[j]; - min = v < min ? v : min; - max = v > max ? v : max; - } - - const float d = (max - min) / 31; - const float id = d ? 
1.0f/d : 0.0f; - - dsti->dm.x = d; - dsti->dm.y = min; - - uint32_t qh = 0; - for (int j = 0; j < QK5_1/2; ++j) { - const float x0 = (xi[0 + j] - min)*id; - const float x1 = (xi[QK5_1/2 + j] - min)*id; - - const uint8_t xi0 = (uint8_t)(x0 + 0.5f); - const uint8_t xi1 = (uint8_t)(x1 + 0.5f); - - dsti->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4); - qh |= ((xi0 & 0x10u) >> 4) << (j + 0); - qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2); - } - memcpy(dsti->qh, &qh, sizeof(qh)); -} - template static __device__ void cpy_blck_q_f32(const char * cxi, char * cdsti) { float * cdstf = (float *)(cdsti); @@ -252,53 +62,6 @@ static __device__ void cpy_blck_q_f32(const char * cxi, char * cdsti) { } } -static __device__ __forceinline__ int best_index_int8(int n, const int8_t * val, float x) { - if (x <= val[0]) return 0; - if (x >= val[n-1]) return n-1; - int ml = 0, mu = n-1; - while (mu-ml > 1) { - int mav = (ml+mu)/2; - if (x < val[mav]) mu = mav; else ml = mav; - } - return x - val[mu-1] < val[mu] - x ? mu-1 : mu; -} - -static __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) { - const float * xi = (const float *) cxi; - block_iq4_nl * dsti = (block_iq4_nl *) cdsti; - - float amax = 0.0f; - float vmax = 0.0f; - - for (int j = 0; j < QK4_NL; ++j) { - const float v = xi[j]; - if (amax < fabsf(v)) { - amax = fabsf(v); - vmax = v; - } - } - - float d = vmax / kvalues_iq4nl[0]; - const float id = d ? 1.0f/d : 0.0f; - - float sumqx = 0, sumq2 = 0; - for (int j = 0; j < QK4_NL/2; ++j) { - const float x0 = xi[0 + j]*id; - const float x1 = xi[QK4_NL/2 + j]*id; - const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl, x0); - const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl, x1); - dsti->qs[j] = xi0 | (xi1 << 4); - const float v0 = kvalues_iq4nl[xi0]; - const float v1 = kvalues_iq4nl[xi1]; - const float w0 = xi[0 + j]*xi[0 + j]; - const float w1 = xi[QK4_NL/2 + j]*xi[QK4_NL/2 + j]; - sumqx += w0*v0*xi[j] + w1*v1*xi[QK4_NL/2 + j]; - sumq2 += w0*v0*v0 + w1*v1*v1; - } - - dsti->d = sumq2 > 0 ? 
sumqx/sumq2 : d; -} - template static __global__ void cpy_f32_q(const char * cx, char * cdst_direct, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 778d5a48bd9..50a977c3076 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3226,8 +3226,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g } break; case GGML_OP_SET_ROWS: { -#pragma message("TODO: implement Q4_0, Q4_1, Q5_0, Q5_1, Q8_0, IQ4_NL support (https://github.com/ggml-org/llama.cpp/pull/14661)") - return (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_BF16) && + return (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_BF16 || + op->type == GGML_TYPE_Q4_0 || op->type == GGML_TYPE_Q4_1 || op->type == GGML_TYPE_Q5_0 || + op->type == GGML_TYPE_Q5_1 || op->type == GGML_TYPE_Q8_0 || op->type == GGML_TYPE_IQ4_NL) && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_I64; } break; diff --git a/ggml/src/ggml-cuda/set-rows.cu b/ggml/src/ggml-cuda/set-rows.cu index 58cee924401..560604d095f 100644 --- a/ggml/src/ggml-cuda/set-rows.cu +++ b/ggml/src/ggml-cuda/set-rows.cu @@ -1,4 +1,5 @@ #include "set-rows.cuh" +#include "cpy-utils.cuh" typedef void (*set_rows_kernel_t)(const char * src, char * dst); @@ -10,17 +11,93 @@ __device__ void set_rows_1(const src_t * src_f, dst_t * dst_f) { template<> __device__ __forceinline__ void set_rows_1(const float * src_f, half * dst_h) { - *dst_h = __float2half(*src_f); + convert_f32_f16(src_f, dst_h); } template<> __device__ __forceinline__ void set_rows_1(const float * src_f, nv_bfloat16 * dst_b) { - *dst_b = *src_f; + convert_f32_bf16(src_f, dst_b); } template<> __device__ __forceinline__ void set_rows_1(const float * src_f, float * dst_f) { - *dst_f = *src_f; + convert_f32_f32(src_f, dst_f); +} + +// Generic quantized set_rows kernel template +template +static __global__ void k_set_rows_quant( + const float * __restrict__ src0, const int64_t * __restrict__ src1, block_type * __restrict__ dst, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, + const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13, + const int64_t s01, const int64_t s02, const int64_t s03, + const int64_t s10, const int64_t s11, const int64_t s12, + const int64_t s1, const int64_t s2, const int64_t s3) { + + const int64_t i = int64_t(blockDim.x) * blockIdx.x + threadIdx.x; + const int64_t ne_total = (ne00 * ne01 * ne02 * ne03) / qk; + + if (i >= ne_total) { + return; + } + + const int64_t i_base = i * qk; + const int64_t i03 = i_base / (ne00 * ne01 * ne02); + const int64_t i02 = (i_base - i03 * ne00 * ne01 * ne02) / (ne00 * ne01); + const int64_t i01 = (i_base - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01) / ne00; + const int64_t i00 = i_base - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01 - i01 * ne00; + + const int64_t i12 = i03 % ne12; + const int64_t i11 = i02 % ne11; + const int64_t i10 = i01; + + const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12); + + const float * src0_row = src0 + i01*s01 + i02*s02 + i03*s03; + block_type * dst_row_ptr = dst + (dst_row*s1 + i02*s2 + i03*s3) / sizeof(block_type); + + const float * src_block = src0_row + i00; + block_type * dst_block = dst_row_ptr + i00 / qk; + + quantize_func(src_block, dst_block); +} + +// Template dispatch function for 
quantized set_rows +template +static void set_rows_cuda_quant( + const float * src0_d, const int64_t * src1_d, block_type * dst_d, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, + const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13, + const size_t nb01, const size_t nb02, const size_t nb03, + const size_t nb10, const size_t nb11, const size_t nb12, + const size_t nb1, const size_t nb2, const size_t nb3, + cudaStream_t stream) { + + GGML_ASSERT(ne00 % qk == 0); + const int64_t ne_total = (ne00 * ne01 * ne02 * ne03) / qk; + const int num_blocks = (ne_total + CUDA_SET_ROWS_BLOCK_SIZE - 1) / CUDA_SET_ROWS_BLOCK_SIZE; + const dim3 block_size(CUDA_SET_ROWS_BLOCK_SIZE); + const dim3 grid_size(num_blocks); + + const int64_t s01 = nb01/sizeof(float); + const int64_t s02 = nb02/sizeof(float); + const int64_t s03 = nb03/sizeof(float); + const int64_t s10 = nb10/sizeof(int64_t); + const int64_t s11 = nb11/sizeof(int64_t); + const int64_t s12 = nb12/sizeof(int64_t); + const int64_t s1 = nb1; + const int64_t s2 = nb2; + const int64_t s3 = nb3; + + if (ne_total > 0) { + k_set_rows_quant<<>>( + src0_d, src1_d, dst_d, + ne00, ne01, ne02, ne03, + ne10, ne11, ne12, ne13, + s01, s02, s03, + s10, s11, s12, + s1, s2, s3); + } } template @@ -145,7 +222,67 @@ void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { nb1, nb2, nb3, stream ); + } else if (dst->type == GGML_TYPE_Q4_0) { + set_rows_cuda_quant( + src0_d, src1_d, (block_q4_0*)dst->data, + ne00, ne01, ne02, ne03, + ne10, ne11, ne12, ne13, + nb01, nb02, nb03, + nb10, nb11, nb12, + nb1, nb2, nb3, + stream + ); + } else if (dst->type == GGML_TYPE_Q4_1) { + set_rows_cuda_quant( + src0_d, src1_d, (block_q4_1*)dst->data, + ne00, ne01, ne02, ne03, + ne10, ne11, ne12, ne13, + nb01, nb02, nb03, + nb10, nb11, nb12, + nb1, nb2, nb3, + stream + ); + } else if (dst->type == GGML_TYPE_Q5_0) { + set_rows_cuda_quant( + src0_d, src1_d, (block_q5_0*)dst->data, + ne00, ne01, ne02, ne03, + ne10, ne11, ne12, ne13, + nb01, nb02, nb03, + nb10, nb11, nb12, + nb1, nb2, nb3, + stream + ); + } else if (dst->type == GGML_TYPE_Q5_1) { + set_rows_cuda_quant( + src0_d, src1_d, (block_q5_1*)dst->data, + ne00, ne01, ne02, ne03, + ne10, ne11, ne12, ne13, + nb01, nb02, nb03, + nb10, nb11, nb12, + nb1, nb2, nb3, + stream + ); + } else if (dst->type == GGML_TYPE_Q8_0) { + set_rows_cuda_quant( + src0_d, src1_d, (block_q8_0*)dst->data, + ne00, ne01, ne02, ne03, + ne10, ne11, ne12, ne13, + nb01, nb02, nb03, + nb10, nb11, nb12, + nb1, nb2, nb3, + stream + ); + } else if (dst->type == GGML_TYPE_IQ4_NL) { + set_rows_cuda_quant( + src0_d, src1_d, (block_iq4_nl*)dst->data, + ne00, ne01, ne02, ne03, + ne10, ne11, ne12, ne13, + nb01, nb02, nb03, + nb10, nb11, nb12, + nb1, nb2, nb3, + stream + ); } else { - GGML_ABORT("unsupported type"); + GGML_ABORT("unsupported type %s", ggml_type_name(dst->type)); } } From c0f6c31a7fdf587c7ff479ac8c626b14b522ddb9 Mon Sep 17 00:00:00 2001 From: Oliver Simons Date: Fri, 18 Jul 2025 13:35:32 +0200 Subject: [PATCH 18/20] cuda : Fix Gemma3n not executed as CUDA_GRAPH on NVGPUs (llama/14741) * Fix Gemma3n not executed as CUDA_GRAPH on NVGPUs Gemma3n uses Matrix-Matrix addition as part of their input processing, wrongly triggering CUDA_GRAPH disablement on NVGPUs even when batch-size of 1 is used. * Exclude `project_per_layer_input` by matching node names This ensures that all other graphs which don't exhibit this pattern do not have their behavior changed. 
* Revert unnecessary formatting changes --- ggml/src/ggml-cuda/ggml-cuda.cu | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 50a977c3076..dfc50ef0daf 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2590,6 +2590,9 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud // Loop over nodes in GGML graph to obtain info needed for CUDA graph cuda_ctx->cuda_graph->cpy_dest_ptrs.clear(); + const std::string gemma3n_per_layer_proj_src0_name = "inp_per_layer_selected"; + const std::string gemma3n_per_layer_proj_src1_name = "per_layer_proj"; + for (int i = 0; i < cgraph->n_nodes; i++) { ggml_tensor * node = cgraph->nodes[i]; @@ -2611,9 +2614,12 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud #endif } - if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1) { - // disable CUDA graphs for batch size > 1 for now. - // Changes in batch size or context size can cause changes to the grid size of some kernels. + if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1 && (node->src[0] ? node->src[0]->name != gemma3n_per_layer_proj_src0_name : true) && (node->src[1] ? node->src[1]->name != gemma3n_per_layer_proj_src1_name : true)) { + // disable CUDA graphs for batch size > 1 for now while excluding the matrix-matrix addition as part of Gemma3n's `project_per_layer_input` operation + // by means of matching node names. See + // https://github.com/ggml-org/llama.cpp/blob/f9a31eea06a859e34cecb88b4d020c7f03d86cc4/src/llama-model.cpp#L10199-L10241 and + // https://github.com/huggingface/transformers/blob/bda75b4011239d065de84aa3e744b67ebfa7b245/src/transformers/models/gemma3n/modeling_gemma3n.py#L1773, + // Generally, changes in batch size or context size can cause changes to the grid size of some kernels. 
use_cuda_graph = false; #ifndef NDEBUG GGML_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]); From f476d9bac3c0f17c466e8441b9d0b26d88b28032 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 18 Jul 2025 20:37:26 +0300 Subject: [PATCH 19/20] metal : fuse add, mul + add tests (llama/14596) ggml-ci --- ggml/src/ggml-alloc.c | 15 -- ggml/src/ggml-backend.cpp | 15 -- ggml/src/ggml-impl.h | 16 ++ ggml/src/ggml-metal/ggml-metal-impl.h | 15 +- ggml/src/ggml-metal/ggml-metal.m | 364 +++++++++++++++++++++----- ggml/src/ggml-metal/ggml-metal.metal | 236 ++++++++++++++--- 6 files changed, 518 insertions(+), 143 deletions(-) diff --git a/ggml/src/ggml-alloc.c b/ggml/src/ggml-alloc.c index 5fd379f6a94..fcc552da519 100644 --- a/ggml/src/ggml-alloc.c +++ b/ggml/src/ggml-alloc.c @@ -22,21 +22,6 @@ static bool ggml_is_view(const struct ggml_tensor * t) { return t->view_src != NULL; } -static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) { - if (a->type != b->type) { - return false; - } - for (int i = 0; i < GGML_MAX_DIMS; i++) { - if (a->ne[i] != b->ne[i]) { - return false; - } - if (a->nb[i] != b->nb[i]) { - return false; - } - } - return true; -} - // ops that return true for this function must not use restrict pointers for their backend implementations static bool ggml_op_can_inplace(enum ggml_op op) { switch (op) { diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index 788861a365f..b7498b8d402 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -352,21 +352,6 @@ ggml_backend_dev_t ggml_backend_get_device(ggml_backend_t backend) { // backend copy -static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) { - if (a->type != b->type) { - return false; - } - for (int i = 0; i < GGML_MAX_DIMS; i++) { - if (a->ne[i] != b->ne[i]) { - return false; - } - if (a->nb[i] != b->nb[i]) { - return false; - } - } - return true; -} - void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst) { GGML_ASSERT(ggml_are_same_layout(src, dst) && "cannot copy tensors with different layouts"); diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h index 4972558c98b..a2e30994c46 100644 --- a/ggml/src/ggml-impl.h +++ b/ggml/src/ggml-impl.h @@ -73,6 +73,22 @@ static inline int ggml_up(int n, int m) { return (n + m - 1) & ~(m - 1); } +// TODO: move to ggml.h? 
+static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) { + if (a->type != b->type) { + return false; + } + for (int i = 0; i < GGML_MAX_DIMS; i++) { + if (a->ne[i] != b->ne[i]) { + return false; + } + if (a->nb[i] != b->nb[i]) { + return false; + } + } + return true; +} + // // logging // diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index 752d55c2166..b7b3fc49af3 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -126,6 +126,7 @@ typedef struct { uint64_t nb2; uint64_t nb3; uint64_t offs; + uint64_t o1[8]; } ggml_metal_kargs_bin; typedef struct { @@ -240,7 +241,7 @@ typedef struct { float max_bias; float m0; float m1; - uint16_t n_head_log2; + int32_t n_head_log2; float logit_softcap; } ggml_metal_kargs_flash_attn_ext; @@ -377,8 +378,16 @@ typedef struct { typedef struct { int32_t ne00; int32_t ne00_4; - uint64_t nb01; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; float eps; + int32_t nef1[3]; + int32_t nef2[3]; + int32_t nef3[3]; + uint64_t nbf1[3]; + uint64_t nbf2[3]; + uint64_t nbf3[3]; } ggml_metal_kargs_rms_norm; typedef struct { @@ -484,7 +493,7 @@ typedef struct { float max_bias; float m0; float m1; - uint32_t n_head_log2; + int32_t n_head_log2; } ggml_metal_kargs_soft_max; typedef struct { diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index 44ddc69d08f..dc391a0d4d5 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -55,6 +55,12 @@ bool has_residency_sets; bool has_bfloat; bool use_bfloat; + bool use_fusion; + + int debug_fusion; + + // how many times a given op was fused + uint64_t fuse_cnt[GGML_OP_COUNT]; size_t max_size; @@ -69,6 +75,9 @@ /*.has_residency_sets =*/ false, /*.has_bfloat =*/ false, /*.use_bfloat =*/ false, + /*.use_fusion =*/ true, + /*.debug_fusion =*/ 0, + /*.fuse_cnt =*/ { 0 }, /*.max_size =*/ 0, /*.name =*/ "", }; @@ -83,16 +92,14 @@ if (ctx->mtl_device == nil) { ctx->mtl_device = MTLCreateSystemDefaultDevice(); - } - if (ctx->mtl_device) { ctx->has_simdgroup_reduction = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7]; ctx->has_simdgroup_reduction |= [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML]; ctx->has_simdgroup_mm = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7]; #if defined(GGML_METAL_HAS_RESIDENCY_SETS) - ctx->has_residency_sets = getenv("GGML_METAL_NO_RESIDENCY") == NULL; + ctx->has_residency_sets = getenv("GGML_METAL_NO_RESIDENCY") == nil; #endif ctx->has_bfloat = [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML]; @@ -103,6 +110,14 @@ #else ctx->use_bfloat = false; #endif + ctx->use_fusion = getenv("GGML_METAL_FUSION_DISABLE") == nil; + + { + const char * val = getenv("GGML_METAL_FUSION_DEBUG"); + ctx->debug_fusion = val ? 
atoi(val) : 0; + } + + memset(ctx->fuse_cnt, 0, sizeof(ctx->fuse_cnt)); ctx->max_size = ctx->mtl_device.maxBufferLength; @@ -122,6 +137,18 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte ctx->mtl_device_ref_count--; if (ctx->mtl_device_ref_count == 0) { + if (ctx->debug_fusion > 0) { + fprintf(stderr, "%s: fusion stats:\n", __func__); + for (int i = 0; i < GGML_OP_COUNT; i++) { + if (ctx->fuse_cnt[i] == 0) { + continue; + } + + // note: cannot use ggml_log here + fprintf(stderr, "%s: - %s: %" PRIu64 "\n", __func__, ggml_op_name((enum ggml_op) i), ctx->fuse_cnt[i]); + } + } + if (ctx->mtl_lock) { [ctx->mtl_lock release]; ctx->mtl_lock = nil; @@ -147,13 +174,27 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_ADD, - GGML_METAL_KERNEL_TYPE_ADD_ROW, + GGML_METAL_KERNEL_TYPE_ADD_FUSE_2, + GGML_METAL_KERNEL_TYPE_ADD_FUSE_3, + GGML_METAL_KERNEL_TYPE_ADD_FUSE_4, + GGML_METAL_KERNEL_TYPE_ADD_FUSE_5, + GGML_METAL_KERNEL_TYPE_ADD_FUSE_6, + GGML_METAL_KERNEL_TYPE_ADD_FUSE_7, + GGML_METAL_KERNEL_TYPE_ADD_FUSE_8, + GGML_METAL_KERNEL_TYPE_ADD_ROW_C4, + GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2, + GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3, + GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4, + GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5, + GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6, + GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7, + GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8, GGML_METAL_KERNEL_TYPE_SUB, - GGML_METAL_KERNEL_TYPE_SUB_ROW, + GGML_METAL_KERNEL_TYPE_SUB_ROW_C4, GGML_METAL_KERNEL_TYPE_MUL, - GGML_METAL_KERNEL_TYPE_MUL_ROW, + GGML_METAL_KERNEL_TYPE_MUL_ROW_C4, GGML_METAL_KERNEL_TYPE_DIV, - GGML_METAL_KERNEL_TYPE_DIV_ROW, + GGML_METAL_KERNEL_TYPE_DIV_ROW_C4, GGML_METAL_KERNEL_TYPE_REPEAT_F32, GGML_METAL_KERNEL_TYPE_REPEAT_F16, GGML_METAL_KERNEL_TYPE_REPEAT_I32, @@ -218,6 +259,8 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1, GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL, GGML_METAL_KERNEL_TYPE_RMS_NORM, + GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL, + GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD, GGML_METAL_KERNEL_TYPE_L2_NORM, GGML_METAL_KERNEL_TYPE_GROUP_NORM, GGML_METAL_KERNEL_TYPE_NORM, @@ -1135,13 +1178,27 @@ @implementation GGMLMetalClass // simd_sum and simd_max requires MTLGPUFamilyApple7 GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_2, add_fuse_2, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_3, add_fuse_3, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_4, add_fuse_4, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_5, add_fuse_5, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_6, add_fuse_6, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_7, add_fuse_7, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_8, add_fuse_8, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4, add_row_c4, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2, add_row_c4_fuse_2, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3, add_row_c4_fuse_3, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4, add_row_c4_fuse_4, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5, add_row_c4_fuse_5, true); + 
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6, add_row_c4_fuse_6, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7, add_row_c4_fuse_7, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8, add_row_c4_fuse_8, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB, sub, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB_ROW, sub_row, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB_ROW_C4, sub_row_c4, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW_C4, mul_row_c4, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW_C4, div_row_c4, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F32, repeat_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I32, repeat_i32, true); @@ -1206,6 +1263,8 @@ @implementation GGMLMetalClass GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1, set_rows_q5_1, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL, set_rows_iq4_nl, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL, rms_norm_mul, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD, rms_norm_mul_add, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_L2_NORM, l2_norm, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true); @@ -1893,7 +1952,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex } } -static bool ggml_metal_encode_node( +static int ggml_metal_encode_node( ggml_backend_t backend, int idx, id encoder, @@ -1903,7 +1962,10 @@ static bool ggml_metal_encode_node( struct ggml_cgraph * gf = ctx->gf; - struct ggml_tensor * node = ggml_graph_node(gf, idx); + enum ggml_op ops[8]; + + struct ggml_tensor ** nodes = ggml_graph_nodes(gf) + idx; + struct ggml_tensor * node = nodes[0]; //GGML_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, idx, ggml_op_name(node->op)); @@ -1913,7 +1975,7 @@ static bool ggml_metal_encode_node( struct ggml_tensor * dst = node; if (ggml_is_empty(dst)) { - return true; + return 1; } switch (dst->op) { @@ -1924,7 +1986,7 @@ static bool ggml_metal_encode_node( case GGML_OP_PERMUTE: { // noop -> next node - } return true; + } return 1; default: { } break; @@ -1991,6 +2053,8 @@ static bool ggml_metal_encode_node( id id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil; id id_dst = dst ? 
ggml_metal_get_buffer(dst, &offs_dst) : nil; + int n_fuse = 1; + #if 0 GGML_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op)); if (src0) { @@ -2062,37 +2126,15 @@ static bool ggml_metal_encode_node( GGML_ASSERT(src0t == GGML_TYPE_F32); GGML_ASSERT(src1t == GGML_TYPE_F32); + GGML_ASSERT(ggml_is_contiguous_rows(src0)); + GGML_ASSERT(ggml_is_contiguous_rows(src1)); + const size_t offs = 0; bool bcast_row = false; id pipeline = nil; - if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) { - GGML_ASSERT(ggml_is_contiguous(src0)); - - // src1 is a row - GGML_ASSERT(ne11 == 1); - - switch (dst->op) { - case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break; - case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB_ROW].pipeline; break; - case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline; break; - case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW].pipeline; break; - default: GGML_ABORT("fatal error"); - } - - bcast_row = true; - } else { - switch (dst->op) { - case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; break; - case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB].pipeline; break; - case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break; - case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break; - default: GGML_ABORT("fatal error"); - } - } - ggml_metal_kargs_bin args = { /*.ne00 =*/ ne00, /*.ne01 =*/ ne01, @@ -2119,12 +2161,117 @@ static bool ggml_metal_encode_node( /*.nb2 =*/ nb2, /*.nb3 =*/ nb3, /*.offs =*/ offs, + /*.o1 =*/ { offs_src1 }, }; + // c[0] = add(a, b[0]) + // c[1] = add(c[0], b[1]) + // c[2] = add(c[1], b[2]) + // ... + if (ctx_dev->use_fusion) { + ops[0] = GGML_OP_ADD; + ops[1] = GGML_OP_ADD; + ops[2] = GGML_OP_ADD; + ops[3] = GGML_OP_ADD; + ops[4] = GGML_OP_ADD; + ops[5] = GGML_OP_ADD; + ops[6] = GGML_OP_ADD; + ops[7] = GGML_OP_ADD; + + size_t offs_fuse; + id id_fuse; + + for (n_fuse = 0; n_fuse <= 6; ++n_fuse) { + if (!ggml_can_fuse(gf, idx + n_fuse, ops + n_fuse, 2)) { + break; + } + + if (nodes[n_fuse] != nodes[n_fuse + 1]->src[0]) { + break; + } + + // b[0] === b[1] === ... 
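+            // (the fused kernel reads every b[j] through the same ne1x/nb1x strides,
+            //  so all fused src1 operands must agree in type, shape and layout)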
+ if (!ggml_are_same_layout(nodes[n_fuse]->src[1], nodes[n_fuse + 1]->src[1])) { + break; + } + + // only fuse nodes if src1 is in the same Metal buffer + id_fuse = ggml_metal_get_buffer(nodes[n_fuse + 1]->src[1], &offs_fuse); + if (id_fuse != id_src1) { + break; + } + + ctx_dev->fuse_cnt[nodes[n_fuse + 1]->op]++; + + args.o1[n_fuse + 1] = offs_fuse; + } + + ++n_fuse; + + if (ctx_dev->debug_fusion > 1 && n_fuse > 1) { + GGML_LOG_DEBUG("%s: fuse: ADD x %d\n", __func__, n_fuse); + } + } + + if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) { + GGML_ASSERT(ggml_is_contiguous(src0)); + + // src1 is a row + GGML_ASSERT(ne11 == 1); + + switch (dst->op) { + case GGML_OP_ADD: + { + switch (n_fuse) { + case 1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4 ].pipeline; break; + case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2].pipeline; break; + case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3].pipeline; break; + case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4].pipeline; break; + case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5].pipeline; break; + case 6: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6].pipeline; break; + case 7: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7].pipeline; break; + case 8: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8].pipeline; break; + default: GGML_ABORT("fatal error"); + } + } break; + case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB_ROW_C4].pipeline; break; + case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW_C4].pipeline; break; + case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW_C4].pipeline; break; + default: GGML_ABORT("fatal error"); + } + + bcast_row = true; + } else { + switch (dst->op) { + case GGML_OP_ADD: + { + switch (n_fuse) { + case 1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD ].pipeline; break; + case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_2].pipeline; break; + case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_3].pipeline; break; + case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_4].pipeline; break; + case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_5].pipeline; break; + case 6: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_6].pipeline; break; + case 7: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_7].pipeline; break; + case 8: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_8].pipeline; break; + default: GGML_ABORT("fatal error"); + } + } break; + case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB].pipeline; break; + case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break; + case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break; + default: GGML_ABORT("fatal error"); + } + } + + if (n_fuse > 1) { + id_dst = ggml_metal_get_buffer(nodes[n_fuse - 1], &offs_dst); + } + [encoder setComputePipelineState:pipeline]; [encoder setBytes:&args length:sizeof(args) atIndex:0]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_src1 offset:0 atIndex:2]; [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; if (bcast_row) { @@ -2132,7 +2279,11 @@ static bool ggml_metal_encode_node( [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) 
threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } else { - const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0); + int nth = 32; + + while (16*nth < ne0 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) { + nth *= 2; + } [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } @@ -2257,12 +2408,13 @@ static bool ggml_metal_encode_node( /*.nb2 =*/ pnb2, /*.nb3 =*/ pnb3, /*.offs =*/ offs, + /*.o1 =*/ { offs_src1}, }; [encoder setComputePipelineState:pipeline]; [encoder setBytes:&args length:sizeof(args) atIndex:0]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_src1 offset:0 atIndex:2]; [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00); @@ -2764,7 +2916,7 @@ static bool ggml_metal_encode_node( id h_src0 = h_src0 = ggml_metal_mem_pool_alloc(mem_pool, ggml_nbytes(src0)); if (!h_src0) { GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, ggml_nbytes(src0)); - return false; + return 0; } offs_src0 = 0; @@ -3640,7 +3792,7 @@ static bool ggml_metal_encode_node( id h_src1 = ggml_metal_mem_pool_alloc(mem_pool, s_src1); if (!h_src1) { GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_src1); - return false; + return 0; } const int64_t neh0 = ne0; @@ -3656,7 +3808,7 @@ static bool ggml_metal_encode_node( id h_dst = ggml_metal_mem_pool_alloc(mem_pool, s_dst); if (!h_dst) { GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_dst); - return false; + return 0; } // tokens per expert @@ -3664,7 +3816,7 @@ static bool ggml_metal_encode_node( id h_tpe = ggml_metal_mem_pool_alloc(mem_pool, s_tpe); if (!h_tpe) { GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_tpe); - return false; + return 0; } // id map @@ -3673,7 +3825,7 @@ static bool ggml_metal_encode_node( id h_ids = ggml_metal_mem_pool_alloc(mem_pool, s_ids); if (!h_ids) { GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_ids); - return false; + return 0; } { @@ -4105,12 +4257,95 @@ static bool ggml_metal_encode_node( case GGML_OP_RMS_NORM: { GGML_ASSERT(ne00 % 4 == 0); - GGML_ASSERT(ggml_is_contiguous_1(src0)); + GGML_ASSERT(ggml_is_contiguous_rows(src0)); float eps; memcpy(&eps, dst->op_params, sizeof(float)); - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline; + ggml_metal_kargs_rms_norm args = { + /*.ne00 =*/ ne00, + /*.ne00_4 =*/ ne00/4, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + /*.eps =*/ eps, + /*.nef1 =*/ { ne01 }, + /*.nef2 =*/ { ne02 }, + /*.nef3 =*/ { ne03 }, + /*.nbf1 =*/ { nb01 }, + /*.nbf2 =*/ { nb02 }, + /*.nbf3 =*/ { nb03 }, + }; + + size_t offs_fuse[2] = { 0, 0 }; + id id_fuse[2] = { id_src0, id_src0 }; + + // d[0] = rms_norm(a) + // d[1] = mul(d[0], b) + // d[2] = add(d[1], c) + if (ctx_dev->use_fusion) { + ops[0] = GGML_OP_RMS_NORM; + ops[1] = GGML_OP_MUL; + ops[2] = GGML_OP_ADD; + + for (n_fuse = 0; n_fuse <= 1; ++n_fuse) { + if (!ggml_can_fuse(gf, idx + n_fuse, ops + n_fuse, 2)) { + break; + } + + if (nodes[n_fuse] != nodes[n_fuse + 1]->src[0]) { + break; + } + + if (nodes[n_fuse + 1]->src[1]->ne[0] != node->ne[0]) { + break; + } + + if (!ggml_is_contiguous_rows(nodes[n_fuse + 1]->src[1])) { + break; + } + + if (nodes[n_fuse + 1]->type != GGML_TYPE_F32) { + break; + } + 
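+                    // all preconditions hold: the extra src1 matches the row size, is f32
+                    // and has contiguous rows — record the fusion and pass its strides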
+ ctx_dev->fuse_cnt[nodes[n_fuse + 1]->op]++; + + id_fuse[n_fuse] = ggml_metal_get_buffer(nodes[n_fuse + 1]->src[1], &offs_fuse[n_fuse]); + + args.nef1[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->ne[1]; + args.nef2[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->ne[2]; + args.nef3[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->ne[3]; + + args.nbf1[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->nb[1]; + args.nbf2[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->nb[2]; + args.nbf3[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->nb[3]; + } + + ++n_fuse; + + if (ctx_dev->debug_fusion > 1 && n_fuse > 1) { + if (n_fuse == 2) { + GGML_LOG_DEBUG("%s: fuse: RMS_NORM + MUL\n", __func__); + } + if (n_fuse == 3) { + GGML_LOG_DEBUG("%s: fuse: RMS_NORM + MUL + ADD\n", __func__); + } + } + } + + if (n_fuse > 1) { + id_dst = ggml_metal_get_buffer(nodes[n_fuse - 1], &offs_dst); + } + + id pipeline; + + switch (n_fuse) { + case 1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM ].pipeline; break; + case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL ].pipeline; break; + case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD].pipeline; break; + default: GGML_ABORT("unsupported n_fuse = %d\n", n_fuse); + } int nth = 32; // SIMD width @@ -4121,23 +4356,16 @@ static bool ggml_metal_encode_node( nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup); nth = MIN(nth, ne00/4); - ggml_metal_kargs_rms_norm args = { - /*.ne00 =*/ ne00, - /*.ne00_4 =*/ ne00/4, - /*.nb01 =*/ nb01, - /*.eps =*/ eps, - }; - [encoder setComputePipelineState:pipeline]; - [encoder setBytes:&args length:sizeof(args) atIndex:0]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_fuse[0] offset:offs_fuse[0] atIndex:2]; + [encoder setBuffer:id_fuse[1] offset:offs_fuse[1] atIndex:3]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:4]; [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; - const int64_t nrows = ggml_nrows(src0); - - [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } break; case GGML_OP_L2_NORM: { @@ -5532,7 +5760,7 @@ static bool ggml_metal_encode_node( } } - return true; + return n_fuse; } static enum ggml_status ggml_metal_graph_compute( @@ -6038,20 +6266,22 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) { struct ggml_metal_mem_pool * mem_pool = ctx->cmd_bufs[cb_idx].mem_pool; ggml_metal_mem_pool_reset(mem_pool); - for (int idx = node_start; idx < node_end; ++idx) { + for (int idx = node_start; idx < node_end;) { if (should_capture) { [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]]; } - const bool res = ggml_metal_encode_node(backend, idx, encoder, mem_pool); + const int res = ggml_metal_encode_node(backend, idx, encoder, mem_pool); if (should_capture) { [encoder popDebugGroup]; } - if (!res) { + if (res == 0) { break; } + + idx += res; } [encoder endEncoding]; diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 13235e28852..f62b9ad548e 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -832,7 +832,8 @@ enum ggml_sort_order { // general-purpose kernel 
for addition, subtraction, multiplication and division of two tensors // pros: works for non-contiguous tensors, supports broadcast across all dims // cons: not very efficient -kernel void kernel_add( +template +kernel void kernel_add_fuse_impl( constant ggml_metal_kargs_bin & args, device const char * src0, device const char * src1, @@ -848,16 +849,39 @@ kernel void kernel_add( const int i12 = i02%args.ne12; const int i11 = i01%args.ne11; - device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs; - device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11; - device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs; + device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs); + device float * dst_ptr = (device float *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs); + + device const float * src1_ptr[F]; + for (short j = 0; j < F; ++j) { + src1_ptr[j] = (device const float *) (src1 + args.o1[j] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11); + } for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { const int i10 = i0%args.ne10; - *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) + *((device float *)(src1_ptr + i10*args.nb10)); + + float res = src0_ptr[i0]; + +#pragma unroll + for (short j = 0; j < F; ++j) { + res += src1_ptr[j][i10]; + } + + dst_ptr[i0] = res; } } +typedef decltype(kernel_add_fuse_impl<2>) kernel_add_fuse_t; + +template [[host_name("kernel_add")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<1>; +template [[host_name("kernel_add_fuse_2")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<2>; +template [[host_name("kernel_add_fuse_3")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<3>; +template [[host_name("kernel_add_fuse_4")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<4>; +template [[host_name("kernel_add_fuse_5")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<5>; +template [[host_name("kernel_add_fuse_6")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<6>; +template [[host_name("kernel_add_fuse_7")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<7>; +template [[host_name("kernel_add_fuse_8")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<8>; + kernel void kernel_sub( constant ggml_metal_kargs_bin & args, device const char * src0, @@ -875,7 +899,7 @@ kernel void kernel_sub( const int i11 = i01%args.ne11; device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs; - device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11; + device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0]; device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs; for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { @@ -900,9 +924,9 @@ kernel void kernel_mul( const int i12 = i02%args.ne12; const int i11 = i01%args.ne11; - device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01; - device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11; - device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1; + device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs; + device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0]; + device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs; for 
(int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { const int i10 = i0%args.ne10; @@ -926,9 +950,9 @@ kernel void kernel_div( const int i12 = i02%args.ne12; const int i11 = i01%args.ne11; - device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01; - device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11; - device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1; + device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs; + device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0]; + device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs; for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { const int i10 = i0%args.ne10; @@ -970,46 +994,145 @@ template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat // assumption: src1 is a row // broadcast src1 into src0 -kernel void kernel_add_row( +template +kernel void kernel_add_row_c4_fuse_impl( constant ggml_metal_kargs_bin & args, - device const float4 * src0, - device const float4 * src1, - device float4 * dst, + device const char * src0, + device const char * src1, + device char * dst, uint tpig[[thread_position_in_grid]]) { + const uint nb = args.ne00/4; - dst[tpig] = src0[tpig] + src1[tpig % nb]; + const uint i = tpig % nb; + + device const float4 * src0_row = (device const float4 *) (src0); + device float4 * dst_row = (device float4 *) (dst); + + device const float4 * src1_row[F]; + for (short j = 0; j < F; ++j) { + src1_row[j] = (device const float4 *) (src1 + args.o1[j]); + } + + float4 res = src0_row[tpig]; + +#pragma unroll(F) + for (short j = 0; j < F; ++j) { + res += src1_row[j][i]; + } + + dst_row[tpig] = res; } -kernel void kernel_sub_row( +typedef decltype(kernel_add_row_c4_fuse_impl<1>) kernel_add_row_c4_fuse_t; + +template [[host_name("kernel_add_row_c4")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<1>; +template [[host_name("kernel_add_row_c4_fuse_2")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<2>; +template [[host_name("kernel_add_row_c4_fuse_3")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<3>; +template [[host_name("kernel_add_row_c4_fuse_4")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<4>; +template [[host_name("kernel_add_row_c4_fuse_5")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<5>; +template [[host_name("kernel_add_row_c4_fuse_6")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<6>; +template [[host_name("kernel_add_row_c4_fuse_7")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<7>; +template [[host_name("kernel_add_row_c4_fuse_8")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<8>; + +template +kernel void kernel_sub_row_c4_fuse_impl( constant ggml_metal_kargs_bin & args, - device const float4 * src0, - device const float4 * src1, - device float4 * dst, + device const char * src0, + device const char * src1, + device char * dst, uint tpig[[thread_position_in_grid]]) { + const uint nb = args.ne00/4; - dst[tpig] = src0[tpig] - src1[tpig % nb]; + const uint i = tpig % nb; + + device const float4 * src0_row = (device const float4 *) (src0); + device float4 * dst_row = (device float4 *) (dst); + + device const float4 * src1_row[F]; + for (short j = 0; j < F; ++j) { + src1_row[j] = (device const float4 *) (src1 + args.o1[j]); + } + + float4 res = src0_row[tpig]; + +#pragma unroll(F) + for (short j = 0; j < 
F; ++j) { + res -= src1_row[j][i]; + } + + dst_row[tpig] = res; } -kernel void kernel_mul_row( +typedef decltype(kernel_sub_row_c4_fuse_impl<1>) kernel_sub_row_c4_fuse_t; + +template [[host_name("kernel_sub_row_c4")]] kernel kernel_sub_row_c4_fuse_t kernel_sub_row_c4_fuse_impl<1>; + +template +kernel void kernel_mul_row_c4_fuse_impl( constant ggml_metal_kargs_bin & args, - device const float4 * src0, - device const float4 * src1, - device float4 * dst, + device const char * src0, + device const char * src1, + device char * dst, uint tpig[[thread_position_in_grid]]) { + const uint nb = args.ne00/4; - dst[tpig] = src0[tpig] * src1[tpig % nb]; + const uint i = tpig % nb; + + device const float4 * src0_row = (device const float4 *) (src0); + device float4 * dst_row = (device float4 *) (dst); + + device const float4 * src1_row[F]; + for (short j = 0; j < F; ++j) { + src1_row[j] = (device const float4 *) (src1 + args.o1[j]); + } + + float4 res = src0_row[tpig]; + +#pragma unroll(F) + for (short j = 0; j < F; ++j) { + res *= src1_row[j][i]; + } + + dst_row[tpig] = res; } -kernel void kernel_div_row( +typedef decltype(kernel_mul_row_c4_fuse_impl<1>) kernel_mul_row_c4_fuse_t; + +template [[host_name("kernel_mul_row_c4")]] kernel kernel_mul_row_c4_fuse_t kernel_mul_row_c4_fuse_impl<1>; + +template +kernel void kernel_div_row_c4_fuse_impl( constant ggml_metal_kargs_bin & args, - device const float4 * src0, - device const float4 * src1, - device float4 * dst, + device const char * src0, + device const char * src1, + device char * dst, uint tpig[[thread_position_in_grid]]) { + const uint nb = args.ne00/4; - dst[tpig] = src0[tpig] / src1[tpig % nb]; + const uint i = tpig % nb; + + device const float4 * src0_row = (device const float4 *) (src0); + device float4 * dst_row = (device float4 *) (dst); + + device const float4 * src1_row[F]; + for (short j = 0; j < F; ++j) { + src1_row[j] = (device const float4 *) (src1 + args.o1[j]); + } + + float4 res = src0_row[tpig]; + +#pragma unroll(F) + for (short j = 0; j < F; ++j) { + res /= src1_row[j][i]; + } + + dst_row[tpig] = res; } +typedef decltype(kernel_div_row_c4_fuse_impl<1>) kernel_div_row_c4_fuse_t; + +template [[host_name("kernel_div_row_c4")]] kernel kernel_div_row_c4_fuse_t kernel_div_row_c4_fuse_impl<1>; + kernel void kernel_scale( device const float * src0, device float * dst, @@ -2116,26 +2239,39 @@ kernel void kernel_norm( } } -kernel void kernel_rms_norm( +// F == 1 : rms_norm (no fuse) +// F == 2 : rms_norm + mul +// F == 3 : rms_norm + mul + add +template +kernel void kernel_rms_norm_fuse_impl( constant ggml_metal_kargs_rms_norm & args, device const char * src0, + device const char * src1_0, + device const char * src1_1, device char * dst, threadgroup float * shmem_f32 [[threadgroup(0)]], - uint tgpig[[threadgroup_position_in_grid]], - ushort tpitg[[thread_position_in_threadgroup]], - ushort sgitg[[simdgroup_index_in_threadgroup]], - ushort tiisg[[thread_index_in_simdgroup]], - ushort ntg[[threads_per_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { if (sgitg == 0) { shmem_f32[tiisg] = 0.0f; } - device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01); + const int i01 = tgpig.x; + const int i02 = tgpig.y; + const int i03 = tgpig.z; + + device const float4 * x = (device const float4 *) (src0 + i03*args.nbf3[0] + 
i02*args.nbf2[0] + i01*args.nbf1[0]); + + device const float4 * f0 = (device const float4 *) (src1_0 + (i03%args.nef3[1])*args.nbf3[1] + (i02%args.nef2[1])*args.nbf2[1] + (i01%args.nef1[1])*args.nbf1[1]); + device const float4 * f1 = (device const float4 *) (src1_1 + (i03%args.nef3[2])*args.nbf3[2] + (i02%args.nef2[2])*args.nbf2[2] + (i01%args.nef1[2])*args.nbf1[2]); float sumf = 0.0f; // parallel sum - for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) { + for (int i00 = tpitg.x; i00 < args.ne00_4; i00 += ntg.x) { sumf += dot(x[i00], x[i00]); } sumf = simd_sum(sumf); @@ -2154,12 +2290,26 @@ kernel void kernel_rms_norm( const float mean = sumf/args.ne00; const float scale = 1.0f/sqrt(mean + args.eps); - device float4 * y = (device float4 *) dst + tgpig*args.ne00_4; - for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) { - y[i00] = x[i00] * scale; + device float4 * y = (device float4 *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1); + for (int i00 = tpitg.x; i00 < args.ne00_4; i00 += ntg.x) { + if (F == 1) { + y[i00] = (x[i00]*scale); + } + if (F == 2) { + y[i00] = (x[i00]*scale)*f0[i00]; + } + if (F == 3) { + y[i00] = (x[i00]*scale)*f0[i00] + f1[i00]; + } } } +typedef decltype(kernel_rms_norm_fuse_impl<1>) kernel_rms_norm_fuse_t; + +template [[host_name("kernel_rms_norm")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<1>; +template [[host_name("kernel_rms_norm_mul")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<2>; +template [[host_name("kernel_rms_norm_mul_add")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<3>; + kernel void kernel_l2_norm( constant ggml_metal_kargs_l2_norm & args, device const char * src0, From 74b5c27837ad21fe4cd6657838f51fd498a7173b Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 19 Jul 2025 17:48:07 +0300 Subject: [PATCH 20/20] sync : ggml ggml-ci --- scripts/sync-ggml.last | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index ca009adb83b..9b223827afb 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -d62df60a07ba3deeb85e5cfc9b1ee07645ff35e2 +a0361ace408ba2c30820deb39e793ad9ed787a85
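
For reference, a minimal C++ sketch of the per-row math computed by the fused
kernel_rms_norm_fuse_impl<3> ("rms_norm_mul_add") path above — an illustrative
oracle, not part of the patches; the function name and the flat, broadcast-free
layout are assumptions:

    #include <cmath>
    #include <cstddef>

    // y[i] = (x[i] * scale) * g[i] + b[i], with scale = 1/sqrt(mean(x^2) + eps)
    static void rms_norm_mul_add_row(const float * x, const float * g, const float * b,
                                     float * y, size_t n, float eps) {
        double sum = 0.0;
        for (size_t i = 0; i < n; ++i) {
            sum += (double) x[i] * x[i];   // the Metal kernel accumulates this via simd_sum
        }
        const float scale = 1.0f / std::sqrt((float) (sum / (double) n) + eps);
        for (size_t i = 0; i < n; ++i) {
            y[i] = (x[i] * scale) * g[i] + b[i];   // F == 3 case; F == 2 drops "+ b[i]"
        }
    }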