From e407bccae7ef583001c4348673b1ae2b8129b5bb Mon Sep 17 00:00:00 2001 From: Patrick Peng Date: Thu, 6 Feb 2025 09:29:13 -0500 Subject: [PATCH 01/58] rpc: fix known RCE in rpc-server (ggml/1103) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add bounds checking in `rpc_server::copy_tensor` to prevent out-of-bounds writes + Check if `(uint8_t *)dst->data + ggml_nbytes(src)` remains within the destination buffer’s allocated region. --- ggml/src/ggml-rpc/ggml-rpc.cpp | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp index 3d0c465780a..97873acc77d 100644 --- a/ggml/src/ggml-rpc/ggml-rpc.cpp +++ b/ggml/src/ggml-rpc/ggml-rpc.cpp @@ -1045,7 +1045,28 @@ bool rpc_server::copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_co ggml_free(ctx); return false; } - GGML_PRINT_DEBUG("[%s] src->buffer: %p, dst->buffer: %p\n", __func__, (void*)src->buffer, (void*)dst->buffer); + + uint64_t src_size = (uint64_t) ggml_nbytes(src); + uint64_t dst_data = (uint64_t) dst->data; + uint64_t dst_base = (uint64_t) ggml_backend_buffer_get_base(dst->buffer); + uint64_t dst_buf_sz = (uint64_t) ggml_backend_buffer_get_size(dst->buffer); + + if (dst_data + src_size > dst_base + dst_buf_sz) { + GGML_PRINT_DEBUG("[%s] out-of-bounds write in rpc_server::copy_tensor:\n" + " write range : [0x%" PRIx64 ", 0x%" PRIx64 "]\n" + " buffer base: [0x%" PRIx64 ", 0x%" PRIx64 "]\n", + __func__, + dst_data, + dst_data + src_size, + dst_base, + dst_base + dst_buf_sz); + ggml_free(ctx); + return false; + } + + GGML_PRINT_DEBUG("[%s] src->buffer: %p, dst->buffer: %p\n", + __func__, (void*) src->buffer, (void*) dst->buffer); + response.result = ggml_backend_buffer_copy_tensor(src, dst); ggml_free(ctx); return true; From 982535fb0d37105ba42bbb7ed82cc1378fbdc1bd Mon Sep 17 00:00:00 2001 From: Jhen-Jie Hong Date: Tue, 4 Feb 2025 19:07:18 +0800 Subject: [PATCH 02/58] metal : use residency set for other platforms (llama/11648) --- ggml/src/ggml-metal/ggml-metal.m | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index 76f8e429178..9605914ffa4 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -20,7 +20,10 @@ #define GGML_METAL_MAX_COMMAND_BUFFERS 8 // create residency sets only on macOS >= 15.0 -#if TARGET_OS_OSX && __MAC_OS_X_VERSION_MAX_ALLOWED >= 150000 +#if TARGET_OS_OSX && __MAC_OS_X_VERSION_MAX_ALLOWED >= 150000 || \ + TARGET_OS_IOS && __IPHONE_OS_VERSION_MAX_ALLOWED >= 180000 || \ + TARGET_OS_TV && __TV_OS_VERSION_MAX_ALLOWED >= 180000 || \ + TARGET_OS_VISION && __VISION_OS_VERSION_MAX_ALLOWED >= 200000 #define GGML_METAL_HAS_RESIDENCY_SETS 1 #endif @@ -1071,7 +1074,7 @@ static bool ggml_backend_metal_buffer_rset_init( } #if defined(GGML_METAL_HAS_RESIDENCY_SETS) - if (@available(macOS 15.0, *)) { + if (@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, *)) { MTLResidencySetDescriptor * desc = [[MTLResidencySetDescriptor alloc] init]; desc.label = @"ggml_backend_metal"; desc.initialCapacity = ctx->n_buffers; @@ -1106,7 +1109,7 @@ static bool ggml_backend_metal_buffer_rset_init( // rset free static void ggml_backend_metal_buffer_rset_free(struct ggml_backend_metal_buffer_context * ctx) { #if defined(GGML_METAL_HAS_RESIDENCY_SETS) - if (@available(macOS 15.0, *)) { + if (@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, *)) { if (ctx->rset) { [ctx->rset endResidency]; [ctx->rset removeAllAllocations]; From 8440a75da20d2fec73c5ae7bb360e6e0adbc0d01 Mon Sep 17 00:00:00 2001 From: fxzjshm <11426482+fxzjshm@users.noreply.github.com> Date: Wed, 5 Feb 2025 02:18:38 +0800 Subject: [PATCH 03/58] HIP: force max threads per block to be 1024 (llama/11621) Some old/vendor forked version of llvm still use 256. Explicitly set it to 1024 to align with upstream llvm. Signed-off-by: fxzjshm --- ggml/src/ggml-hip/CMakeLists.txt | 3 +++ 1 file changed, 3 insertions(+) diff --git a/ggml/src/ggml-hip/CMakeLists.txt b/ggml/src/ggml-hip/CMakeLists.txt index eb03e10fa48..f4a4683639f 100644 --- a/ggml/src/ggml-hip/CMakeLists.txt +++ b/ggml/src/ggml-hip/CMakeLists.txt @@ -46,6 +46,9 @@ endif() message(STATUS "HIP and hipBLAS found") +# Workaround old compilers +set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} --gpu-max-threads-per-block=1024") + file(GLOB GGML_HEADERS_ROCM "../ggml-cuda/*.cuh") list(APPEND GGML_HEADERS_ROCM "../../include/ggml-cuda.h") From 53ad347c0a0370e1c48243328773285d6355f4e1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Tue, 4 Feb 2025 22:21:42 +0100 Subject: [PATCH 04/58] CUDA: non-contiguous (RMS) norm support (llama/11659) * CUDA: non-contiguous (RMS) norm support --------- Co-authored-by: Georgi Gerganov --- ggml/src/ggml-cuda/ggml-cuda.cu | 4 ++ ggml/src/ggml-cuda/norm.cu | 89 ++++++++++++++++++---------- ggml/src/ggml-metal/ggml-metal.m | 5 +- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 2 + 4 files changed, 66 insertions(+), 34 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index bda10aec118..70a5980998c 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -38,6 +38,7 @@ #include "ggml-cuda/upscale.cuh" #include "ggml-cuda/wkv6.cuh" #include "ggml-cuda/gla.cuh" +#include "ggml.h" #include #include @@ -3139,6 +3140,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g break; case GGML_OP_NORM: case GGML_OP_RMS_NORM: + return true; case GGML_OP_RMS_NORM_BACK: return ggml_is_contiguous(op->src[0]) && op->ne[0] % WARP_SIZE == 0; break; @@ -3181,7 +3183,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_SUM_ROWS: case GGML_OP_ARGSORT: case GGML_OP_ACC: + return true; case GGML_OP_GROUP_NORM: + return ggml_is_contiguous(op->src[0]); case GGML_OP_UPSCALE: case GGML_OP_PAD: case GGML_OP_ARANGE: diff --git a/ggml/src/ggml-cuda/norm.cu b/ggml/src/ggml-cuda/norm.cu index d991ec97281..f127616edda 100644 --- a/ggml/src/ggml-cuda/norm.cu +++ b/ggml/src/ggml-cuda/norm.cu @@ -1,12 +1,20 @@ #include "norm.cuh" +#include template -static __global__ void norm_f32(const float * x, float * dst, const int ncols, const float eps) { - const int row = blockIdx.x*blockDim.y + threadIdx.y; - const int tid = threadIdx.x; +static __global__ void norm_f32( + const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel, + const int64_t stride_sample, const float eps) { + const int nrows = gridDim.x; + const int nchannels = gridDim.y; - x += int64_t(row)*ncols; - dst += int64_t(row)*ncols; + const int row = blockIdx.x; + const int channel = blockIdx.y; + const int sample = blockIdx.z; + const int tid = threadIdx.x; + + x += sample*stride_sample + channel*stride_channel + row*stride_row; + dst += ((sample*nchannels + channel)*nrows + row)*ncols; float2 mean_var = make_float2(0.0f, 0.0f); @@ -97,12 +105,19 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr } template -static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) { - const int row = blockIdx.x*blockDim.y + threadIdx.y; - const int tid = threadIdx.x; +static __global__ void rms_norm_f32( + const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel, + const int64_t stride_sample, const float eps) { + const int nrows = gridDim.x; + const int nchannels = gridDim.y; + + const int row = blockIdx.x; + const int channel = blockIdx.y; + const int sample = blockIdx.z; + const int tid = threadIdx.x; - x += int64_t(row)*ncols; - dst += int64_t(row)*ncols; + x += sample*stride_sample + channel*stride_channel + row*stride_row; + dst += ((sample*nchannels + channel)*nrows + row)*ncols; float tmp = 0.0f; // partial sum for thread in warp @@ -186,13 +201,16 @@ static __global__ void rms_norm_back_f32( } } -static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) { +static void norm_f32_cuda( + const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples, + const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) { + const dim3 blocks_num(nrows, nchannels, nsamples); if (ncols < 1024) { const dim3 block_dims(WARP_SIZE, 1, 1); - norm_f32<<>>(x, dst, ncols, eps); + norm_f32<<>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); } else { const dim3 block_dims(1024, 1, 1); - norm_f32<1024><<>>(x, dst, ncols, eps); + norm_f32<1024><<>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); } } @@ -207,13 +225,16 @@ static void group_norm_f32_cuda( } } -static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) { +static void rms_norm_f32_cuda( + const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples, + const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) { + const dim3 blocks_num(nrows, nchannels, nsamples); if (ncols < 1024) { const dim3 block_dims(WARP_SIZE, 1, 1); - rms_norm_f32<<>>(x, dst, ncols, eps); + rms_norm_f32<<>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); } else { const dim3 block_dims(1024, 1, 1); - rms_norm_f32<1024><<>>(x, dst, ncols, eps); + rms_norm_f32<1024><<>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); } } @@ -229,23 +250,26 @@ static void rms_norm_back_f32_cuda(const float * grad, const float * xf, float * void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; - const float * src0_d = (const float *)src0->data; - float * dst_d = (float *)dst->data; + const float * src0_d = (const float *) src0->data; + float * dst_d = (float *) dst->data; cudaStream_t stream = ctx.stream(); - GGML_ASSERT(ggml_is_contiguous(src0)); - GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); - const int64_t ne00 = src0->ne[0]; - const int64_t nrows = ggml_nrows(src0); + GGML_TENSOR_UNARY_OP_LOCALS; float eps; memcpy(&eps, dst->op_params, sizeof(float)); GGML_ASSERT(eps >= 0.0f); - norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream); + const size_t ts0 = ggml_type_size(src0->type); + GGML_ASSERT(nb00 == ts0); + const int64_t s01 = nb01 / ts0; + const int64_t s02 = nb02 / ts0; + const int64_t s03 = nb03 / ts0; + + norm_f32_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream); } void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { @@ -254,8 +278,6 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) float * dst_d = (float *)dst->data; cudaStream_t stream = ctx.stream(); - GGML_ASSERT(ggml_is_contiguous(src0)); - GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); @@ -271,23 +293,26 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; - const float * src0_d = (const float *)src0->data; - float * dst_d = (float *)dst->data; + const float * src0_d = (const float *) src0->data; + float * dst_d = (float *) dst->data; cudaStream_t stream = ctx.stream(); - GGML_ASSERT(ggml_is_contiguous(src0)); - GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); - const int64_t ne00 = src0->ne[0]; - const int64_t nrows = ggml_nrows(src0); + GGML_TENSOR_UNARY_OP_LOCALS; float eps; memcpy(&eps, dst->op_params, sizeof(float)); GGML_ASSERT(eps >= 0.0f); - rms_norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream); + const size_t ts0 = ggml_type_size(src0->type); + GGML_ASSERT(nb00 == ts0); + const int64_t s01 = nb01 / ts0; + const int64_t s02 = nb02 / ts0; + const int64_t s03 = nb03 / ts0; + + rms_norm_f32_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream); } void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index 9605914ffa4..0a264be371e 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -1206,10 +1206,11 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex case GGML_OP_GROUP_NORM: return has_simdgroup_reduction; case GGML_OP_RMS_NORM: - return has_simdgroup_reduction && (op->ne[0] % 4 == 0); + return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0])); case GGML_OP_ARGMAX: - case GGML_OP_NORM: return true; + case GGML_OP_NORM: + return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]); case GGML_OP_ROPE: { const int mode = ((const int32_t *) op->op_params)[2]; diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 9ca3959abf1..48ac489a655 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -8182,9 +8182,11 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_VIEW: case GGML_OP_PERMUTE: case GGML_OP_TRANSPOSE: + return true; case GGML_OP_NORM: case GGML_OP_GROUP_NORM: case GGML_OP_RMS_NORM: + return ggml_is_contiguous(op->src[0]); case GGML_OP_ADD: case GGML_OP_ACC: case GGML_OP_MUL: From 88ce1320f5352d887cc0aa99057d25ac3c7a3481 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Wed, 5 Feb 2025 08:58:31 +0100 Subject: [PATCH 05/58] CUDA: support for mat. mul. with ne03 != ne13 (llama/11656) --- ggml/src/ggml-cuda/ggml-cuda.cu | 27 +++----- ggml/src/ggml-cuda/mmv.cu | 114 ++++++++++++++++++++------------ 2 files changed, 81 insertions(+), 60 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 70a5980998c..4dbaefdbafd 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -1366,8 +1366,6 @@ static void ggml_cuda_op_mul_mat( const int64_t ne13 = src1->ne[3]; const int64_t nrows1 = ggml_nrows(src1); - GGML_ASSERT(ne03 == ne13); - const int64_t ne0 = dst->ne[0]; const int64_t ne1 = dst->ne[1]; @@ -1381,9 +1379,11 @@ static void ggml_cuda_op_mul_mat( GGML_ASSERT(src1->type == GGML_TYPE_F32 || (src1->ne[2] == 1 && src1->ne[3] == 1)); - GGML_ASSERT(ne12 >= ne02 && ne12 % ne02 == 0); + GGML_ASSERT(ne12 % ne02 == 0); + GGML_ASSERT(ne13 % ne03 == 0); const int64_t i02_divisor = ne12 / ne02; + const int64_t i03_divisor = ne13 / ne03; const size_t src0_ts = ggml_type_size(src0->type); const size_t src0_bs = ggml_blck_size(src0->type); @@ -1399,6 +1399,7 @@ static void ggml_cuda_op_mul_mat( GGML_ASSERT(!(split && ne02 > 1)); GGML_ASSERT(!(split && ne03 > 1)); GGML_ASSERT(!(split && ne02 < ne12)); + GGML_ASSERT(!(split && ne03 < ne13)); ggml_tensor_extra_gpu * src0_extra = split ? (ggml_tensor_extra_gpu *) src0->extra : nullptr; @@ -1562,7 +1563,8 @@ static void ggml_cuda_op_mul_mat( } // for split tensors the data begins at i0 == i0_offset_low - char * src0_dd_i = dev[id].src0_dd + (i0/i02_divisor) * (ne01*ne00*src0_ts)/src0_bs; + const size_t nbytes_src0_matrix = ne01*ne00*src0_ts / src0_bs; + char * src0_dd_i = dev[id].src0_dd + ((i03/i03_divisor)*ne02 + (i02/i02_divisor)) * nbytes_src0_matrix; float * src1_ddf_i = dev[id].src1_ddf + (i0*ne11 + src1_col_0) * ne10; char * src1_ddq_i = dev[id].src1_ddq + src1_ddq_i_offset; float * dst_dd_i = dev[id].dst_dd + (i0*ne1 + src1_col_0) * (dst_on_device ? ne0 : row_diff); @@ -1606,8 +1608,9 @@ static void ggml_cuda_op_mul_mat( CUDA_CHECK(cudaGetLastError()); } - if (src1_col_0 == 0 && !src0_is_contiguous && i02 % i02_divisor == 0) { - CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_dd_i, src0, i03, i02/i02_divisor, dev[id].row_low, dev[id].row_high, stream)); + if (src1_col_0 == 0 && !src0_is_contiguous && i03 % i03_divisor == 0 && i02 % i02_divisor == 0) { + CUDA_CHECK(ggml_cuda_cpy_tensor_2d( + src0_dd_i, src0, i03/i03_divisor, i02/i02_divisor, dev[id].row_low, dev[id].row_high, stream)); } // do the computation @@ -1882,7 +1885,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name); //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name); - if (!split && use_mul_mat_vec && dst->ne[3] == 1 && (src0->ne[1] < MMV_MAX_ROWS || any_gpus_without_fp16_mma)) { + if (!split && use_mul_mat_vec && (src0->ne[1] < MMV_MAX_ROWS || any_gpus_without_fp16_mma)) { // the custom F16 vector kernel can be used over batched cuBLAS GEMM // but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention) ggml_cuda_mul_mat_vec(ctx, src0, src1, dst); @@ -2216,12 +2219,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg ggml_cuda_op_rms_norm_back(ctx, dst); break; case GGML_OP_MUL_MAT: - if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) { - GGML_LOG_ERROR("%s: cannot compute %s: src0->ne[3] = %" PRId64 ", src1->ne[3] = %" PRId64 " - fallback to CPU\n", __func__, dst->name, dst->src[0]->ne[3], dst->src[1]->ne[3]); - return false; - } else { - ggml_cuda_mul_mat(ctx, dst->src[0], dst->src[1], dst); - } + ggml_cuda_mul_mat(ctx, dst->src[0], dst->src[1], dst); break; case GGML_OP_MUL_MAT_ID: ggml_cuda_mul_mat_id(ctx, dst); @@ -2998,9 +2996,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g if (b->type == GGML_TYPE_F16 && a->type != GGML_TYPE_F16) { return false; } - if (op->op == GGML_OP_MUL_MAT && a->ne[3] != b->ne[3]) { - return false; - } #ifdef GGML_USE_MUSA if (b->type == GGML_TYPE_F16 && b->ne[2]*b->ne[3] > 1 && !ggml_is_transposed(a) && !ggml_is_transposed(b)) { diff --git a/ggml/src/ggml-cuda/mmv.cu b/ggml/src/ggml-cuda/mmv.cu index 5a9ddd9580a..f89ed03b578 100644 --- a/ggml/src/ggml-cuda/mmv.cu +++ b/ggml/src/ggml-cuda/mmv.cu @@ -1,18 +1,21 @@ +#include "ggml.h" #include "common.cuh" #include "mmv.cuh" template static __global__ void mul_mat_vec( const T * __restrict__ x, const float * __restrict__ y, float * __restrict__ dst, const int64_t ncols2, const int64_t stride_row, - const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst) { + const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, + const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst) { const int64_t row = blockIdx.x; - const int64_t channel = blockIdx.z; + const int64_t channel = blockIdx.y; + const int64_t sample = blockIdx.z; const int tid = threadIdx.x; constexpr int warp_size = ggml_cuda_get_physical_warp_size(); - x += (channel/channel_ratio)*stride_channel_x + row*stride_row; - y += channel *stride_channel_y; - dst += channel *stride_channel_dst; + x += (sample/sample_ratio)*stride_sample_x + (channel/channel_ratio)*stride_channel_x + row*stride_row; + y += sample *stride_sample_y + channel *stride_channel_y; + dst += sample *stride_sample_dst + channel *stride_channel_dst; const float2 * y2 = (const float2 *) y; @@ -91,12 +94,15 @@ template static void launch_mul_mat_vec_cuda( const T * x, const float * y, float * dst, const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y, - const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, + const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, + const int64_t nsamples_y, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, cudaStream_t stream) { GGML_ASSERT(ncols % 2 == 0); GGML_ASSERT(stride_row % 2 == 0); GGML_ASSERT(nchannels_y % nchannels_x == 0); + GGML_ASSERT(nsamples_y % nsamples_x == 0); const int64_t channel_ratio = nchannels_y / nchannels_x; + const int64_t sample_ratio = nsamples_y / nsamples_x; int device; int warp_size; @@ -118,40 +124,48 @@ static void launch_mul_mat_vec_cuda( } const int smem = warp_size*sizeof(float); - const dim3 block_nums(nrows, 1, nchannels_y); + const dim3 block_nums(nrows, nchannels_y, nsamples_y); const dim3 block_dims(block_size_best, 1, 1); switch (block_size_best) { case 32: { mul_mat_vec<<>> - (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst); + (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); } break; case 64: { mul_mat_vec<<>> - (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst); + (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); } break; case 96: { mul_mat_vec<<>> - (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst); + (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); } break; case 128: { mul_mat_vec<<>> - (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst); + (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); } break; case 160: { mul_mat_vec<<>> - (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst); + (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); } break; case 192: { mul_mat_vec<<>> - (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst); + (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); } break; case 224: { mul_mat_vec<<>> - (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst); + (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); } break; case 256: { mul_mat_vec<<>> - (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst); + (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); } break; default: { GGML_ABORT("fatal error"); @@ -163,16 +177,19 @@ template static void mul_mat_vec_cuda( const T * x, const float * y, float * dst, const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y, - const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, + const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, + const int64_t nsamples_y, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, enum ggml_prec prec, cudaStream_t stream) { switch (prec) { case GGML_PREC_DEFAULT: { - launch_mul_mat_vec_cuda(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, - stride_channel_x, stride_channel_y, stride_channel_dst, stream); + launch_mul_mat_vec_cuda + (x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_y, stride_sample_x, stride_sample_y, stride_sample_dst, stream); } break; case GGML_PREC_F32: { - launch_mul_mat_vec_cuda(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, - stride_channel_x, stride_channel_y, stride_channel_dst, stream); + launch_mul_mat_vec_cuda + (x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_y, stride_sample_x, stride_sample_y, stride_sample_dst, stream); } break; } } @@ -181,10 +198,19 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * GGML_ASSERT(src1->type == GGML_TYPE_F32); GGML_ASSERT(dst->type == GGML_TYPE_F32); - const int64_t ne00 = src0->ne[0]; - const int64_t ne01 = src0->ne[1]; + GGML_TENSOR_BINARY_OP_LOCALS; + + const size_t ts_src0 = ggml_type_size(src0->type); + const size_t ts_src1 = ggml_type_size(src1->type); + const size_t ts_dst = ggml_type_size(dst->type); + + GGML_ASSERT(ne11 == 1); + GGML_ASSERT(ne12 == ne2); + GGML_ASSERT(ne13 == ne3); - GGML_ASSERT(src1->ne[1] == 1); + GGML_ASSERT(nb00 == ts_src0); + GGML_ASSERT(nb10 == ts_src1); + GGML_ASSERT(nb0 == ts_dst); const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32; @@ -192,29 +218,22 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * const float * src1_d = (const float *) src1->data; float * dst_d = (float *) dst->data; - const int64_t ne02 = src0->ne[2]; - const int64_t ne12 = src1->ne[2]; - GGML_ASSERT(dst->ne[2] == ne12); - - GGML_ASSERT(src0->ne[3] == 1); - GGML_ASSERT(src1->ne[3] == 1); - GGML_ASSERT( dst->ne[3] == 1); - - const int64_t stride_row = src0->nb[1] / ggml_type_size(src0->type); - const int64_t channel_stride_x = src0->nb[2] / ggml_type_size(src0->type); - const int64_t channel_stride_y = src1->nb[2] / ggml_type_size(src1->type); - const int64_t channel_stride_dst = dst->nb[2] / ggml_type_size( dst->type); + const int64_t s01 = src0->nb[1] / ts_src0; + const int64_t s02 = src0->nb[2] / ts_src0; + const int64_t s12 = src1->nb[2] / ts_src1; + const int64_t s2 = dst->nb[2] / ts_dst; + const int64_t s03 = src0->nb[3] / ts_src0; + const int64_t s13 = src1->nb[3] / ts_src1; + const int64_t s3 = dst->nb[3] / ts_dst; switch (src0->type) { case GGML_TYPE_F16: { const half * src0_d = (const half *) src0->data; - mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, stride_row, ne02, ne12, - channel_stride_x, channel_stride_y, channel_stride_dst, prec, ctx.stream()); + mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, s01, ne02, ne12, s02, s12, s2, ne03, ne13, s03, s13, s3, prec, ctx.stream()); } break; case GGML_TYPE_BF16: { const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data; - mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, stride_row, ne02, ne12, - channel_stride_x, channel_stride_y, channel_stride_dst, prec, ctx.stream()); + mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, s01, ne02, ne12, s02, s12, s2, ne03, ne13, s03, s13, s3, prec, ctx.stream()); } break; default: GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type)); @@ -243,20 +262,27 @@ void ggml_cuda_op_mul_mat_vec( const int64_t stride_row = ne00; const int64_t nchannels_x = 1; const int64_t nchannels_y = 1; - const int64_t channel_stride_x = 0; - const int64_t channel_stride_y = 0; - const int64_t channel_stride_dst = 0; + const int64_t stride_channel_x = 0; + const int64_t stride_channel_y = 0; + const int64_t stride_channel_dst = 0; + const int64_t nsamples_x = 1; + const int64_t nsamples_y = 1; + const int64_t stride_sample_x = 0; + const int64_t stride_sample_y = 0; + const int64_t stride_sample_dst = 0; switch (src0->type) { case GGML_TYPE_F16: { const half * src0_d = (const half *) src0_dd_i; mul_mat_vec_cuda(src0_d, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row, - nchannels_x, nchannels_y, channel_stride_x, channel_stride_y, channel_stride_dst, prec, stream); + nchannels_x, nchannels_y, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_y, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream); } break; case GGML_TYPE_BF16: { const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i; mul_mat_vec_cuda(src0_d, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row, - nchannels_x, nchannels_y, channel_stride_x, channel_stride_y, channel_stride_dst, prec, stream); + nchannels_x, nchannels_y, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_y, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream); } break; default: GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type)); From 26ab6ec977883f61b69b0522ae1abbc1c0269c9f Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 5 Feb 2025 10:57:42 +0200 Subject: [PATCH 06/58] metal : adjust support conditions for norm operators (llama/11671) cont #11659 ggml-ci --- ggml/src/ggml-metal/ggml-metal.m | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index 0a264be371e..c63dbad201b 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -1204,13 +1204,13 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex case GGML_OP_SUM_ROWS: case GGML_OP_SOFT_MAX: case GGML_OP_GROUP_NORM: - return has_simdgroup_reduction; + return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]); case GGML_OP_RMS_NORM: return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0])); case GGML_OP_ARGMAX: return true; case GGML_OP_NORM: - return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]); + return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0])); case GGML_OP_ROPE: { const int mode = ((const int32_t *) op->op_params)[2]; From c2b8bf0bbce4ca4533db73843b71ca33251e7f95 Mon Sep 17 00:00:00 2001 From: Charles Duffy Date: Wed, 5 Feb 2025 19:52:31 -0600 Subject: [PATCH 07/58] metal : avoid breaking build when metal API predates TARGET_OS_VISION (llama/11690) Avoids breakage in nix flake build introduced by b0569130c5e9c671152c913d82803b7c2f014ff9 --- ggml/src/ggml-metal/ggml-metal.m | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index c63dbad201b..944d90af344 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -19,6 +19,10 @@ // max number of MTLCommandBuffer used to submit a graph for processing #define GGML_METAL_MAX_COMMAND_BUFFERS 8 +#ifndef TARGET_OS_VISION +#define TARGET_OS_VISION 0 +#endif + // create residency sets only on macOS >= 15.0 #if TARGET_OS_OSX && __MAC_OS_X_VERSION_MAX_ALLOWED >= 150000 || \ TARGET_OS_IOS && __IPHONE_OS_VERSION_MAX_ALLOWED >= 180000 || \ From 34a9e8ad5c74d9783d619630b22cea698f6f0fbf Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Thu, 6 Feb 2025 00:02:18 -0600 Subject: [PATCH 08/58] vulkan: use smaller combined allocations to avoid fragmentation (llama/11551) --- ggml/src/ggml-alloc.c | 14 +------------- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 19 +++++++++++++++++-- 2 files changed, 18 insertions(+), 15 deletions(-) diff --git a/ggml/src/ggml-alloc.c b/ggml/src/ggml-alloc.c index 9a3bf9f2923..7244a9cbb06 100644 --- a/ggml/src/ggml-alloc.c +++ b/ggml/src/ggml-alloc.c @@ -989,19 +989,7 @@ ggml_backend_buffer_t ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_conte this_size = GGML_PAD(ggml_backend_buft_get_alloc_size(buft, t), alignment); } - if (this_size > max_size) { - GGML_LOG_ERROR("%s: tensor %s is too large to fit in a %s buffer (tensor size: %zu, max buffer size: %zu)\n", - __func__, t->name, - ggml_backend_buft_name(buft), - this_size, max_size); - for (size_t i = 0; i < n_buffers; i++) { - ggml_backend_buffer_free(buffers[i]); - } - free(buffers); - return NULL; - } - - if ((cur_buf_size + this_size) > max_size) { + if (cur_buf_size > 0 && (cur_buf_size + this_size) > max_size) { // allocate tensors in the current buffer if (!alloc_tensor_range(ctx, first, t, buft, cur_buf_size, &buffers, &n_buffers)) { return NULL; diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 48ac489a655..2e1bcf691b3 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -156,6 +156,7 @@ struct vk_device_struct { vk::PhysicalDeviceProperties properties; std::string name; uint64_t max_memory_allocation_size; + uint64_t suballocation_block_size; bool fp16; bool pipeline_robustness; vk::Device device; @@ -2269,6 +2270,7 @@ static vk_device ggml_vk_get_device(size_t idx) { device->physical_device.getProperties2(&props2); device->properties = props2.properties; + device->vendor_id = device->properties.vendorID; const char* GGML_VK_FORCE_MAX_ALLOCATION_SIZE = getenv("GGML_VK_FORCE_MAX_ALLOCATION_SIZE"); @@ -2280,7 +2282,20 @@ static vk_device ggml_vk_get_device(size_t idx) { device->max_memory_allocation_size = props3.maxMemoryAllocationSize; } - device->vendor_id = device->properties.vendorID; + const char* GGML_VK_SUBALLOCATION_BLOCK_SIZE = getenv("GGML_VK_SUBALLOCATION_BLOCK_SIZE"); + + if (GGML_VK_SUBALLOCATION_BLOCK_SIZE != nullptr) { + device->suballocation_block_size = std::stoul(GGML_VK_SUBALLOCATION_BLOCK_SIZE); +#if defined(_WIN32) + } else if (device->vendor_id == VK_VENDOR_ID_NVIDIA) { + // Limit batching of allocations to 1GB by default to avoid fragmentation issues + device->suballocation_block_size = 1024*1024*1024; +#endif + } else { + device->suballocation_block_size = device->max_memory_allocation_size; + } + device->suballocation_block_size = std::min(device->suballocation_block_size, device->max_memory_allocation_size); + device->subgroup_size = subgroup_props.subgroupSize; device->uma = device->properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu; if (sm_builtins) { @@ -7561,7 +7576,7 @@ static size_t ggml_backend_vk_buffer_type_get_alignment(ggml_backend_buffer_type static size_t ggml_backend_vk_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) { ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *) buft->context; - return ctx->device->max_memory_allocation_size; + return ctx->device->suballocation_block_size; } static size_t ggml_backend_vk_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) { From 96bd41f343e8b7ebc95ef4326c575553fc9983eb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9my=20O?= Date: Thu, 6 Feb 2025 07:09:59 +0100 Subject: [PATCH 09/58] vulkan: initial support for IQ4_XS quantization (llama/11501) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 25 ++++++++++++ .../vulkan-shaders/copy_from_quant.comp | 2 +- .../vulkan-shaders/copy_to_quant.comp | 2 +- .../vulkan-shaders/dequant_funcs.comp | 38 ++++++++++++++++++- .../vulkan-shaders/dequant_funcs_cm2.comp | 23 +++++++++++ .../vulkan-shaders/dequant_iq4_xs.comp | 34 +++++++++++++++++ .../vulkan-shaders/flash_attn_cm2.comp | 2 +- .../vulkan-shaders/get_rows_quant.comp | 2 +- .../vulkan-shaders/mul_mat_vec.comp | 2 +- .../ggml-vulkan/vulkan-shaders/mul_mm.comp | 21 +++++++++- .../vulkan-shaders/mul_mm_cm2.comp | 2 +- .../src/ggml-vulkan/vulkan-shaders/types.comp | 28 +++++++++++--- .../vulkan-shaders/vulkan-shaders-gen.cpp | 1 + 13 files changed, 169 insertions(+), 13 deletions(-) create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 2e1bcf691b3..1c99ebe2e2c 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -1622,6 +1622,7 @@ static void ggml_vk_load_shaders(vk_device& device) { //CREATE_FA(GGML_TYPE_IQ2_S, iq2_s) //CREATE_FA(GGML_TYPE_IQ3_XXS, iq3_xxs) //CREATE_FA(GGML_TYPE_IQ3_S, iq3_s) + //CREATE_FA(GGML_TYPE_IQ4_XS, iq4_xs) CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl) #undef CREATE_FA @@ -1655,6 +1656,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_XS].f16acc, matmul_iq4_xs_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4) @@ -1673,6 +1675,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) #undef CREATE_MM #undef CREATE_MM2 @@ -1726,6 +1729,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc, matmul_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); } else { CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); @@ -1744,6 +1748,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); } @@ -1770,6 +1775,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); } else { CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); @@ -1788,6 +1794,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); } } @@ -1837,6 +1844,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc, matmul_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); // If there's not enough shared memory for row_ids and the result tile, don't create these pipelines. @@ -1861,6 +1869,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); } #undef CREATE_MM2 @@ -1902,6 +1911,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f32acc, matmul_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f32acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); // If there's not enough shared memory for row_ids and the result tile, don't create these pipelines. @@ -1926,6 +1936,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); } #undef CREATE_MM @@ -1962,6 +1973,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_s_f32_f32_len, mul_mat_vec_iq2_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ3_XXS][i], "mul_mat_vec_iq3_xxs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq3_xxs_f32_f32_len, mul_mat_vec_iq3_xxs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq3_s_f32_f32_len, mul_mat_vec_iq3_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_xs_f32_f32_len, mul_mat_vec_iq4_xs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq, i+1}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32_"+std::to_string(i+1), mul_mat_vec_f32_f16_f32_len, mul_mat_vec_f32_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1); @@ -1981,6 +1993,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_s_f16_f32_len, mul_mat_vec_iq2_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ3_XXS][i], "mul_mat_vec_iq3_xxs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq3_xxs_f16_f32_len, mul_mat_vec_iq3_xxs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq3_s_f16_f32_len, mul_mat_vec_iq3_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_xs_f16_f32_len, mul_mat_vec_iq4_xs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq, i+1}, 1, true); } @@ -2001,6 +2014,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_S], "mul_mat_vec_id_iq2_s_f32", mul_mat_vec_id_iq2_s_f32_len, mul_mat_vec_id_iq2_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_XXS], "mul_mat_vec_id_iq3_xxs_f32", mul_mat_vec_id_iq3_xxs_f32_len, mul_mat_vec_id_iq3_xxs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_S], "mul_mat_vec_id_iq3_s_f32", mul_mat_vec_id_iq3_s_f32_len, mul_mat_vec_id_iq3_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_XS], "mul_mat_vec_id_iq4_xs_f32", mul_mat_vec_id_iq4_xs_f32_len, mul_mat_vec_id_iq4_xs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq}, 1, true); // dequant shaders @@ -2020,6 +2034,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ2_S], "dequant_iq2_s", dequant_iq2_s_len, dequant_iq2_s_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ3_XXS], "dequant_iq3_xxs", dequant_iq3_xxs_len, dequant_iq3_xxs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ3_S], "dequant_iq3_s", dequant_iq3_s_len, dequant_iq3_s_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_XS], "dequant_iq4_xs", dequant_iq4_xs_len, dequant_iq4_xs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_NL], "dequant_iq4_nl", dequant_iq4_nl_len, dequant_iq4_nl_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); // get_rows @@ -2035,6 +2050,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ2_S], "get_rows_iq2_s", get_rows_iq2_s_len, get_rows_iq2_s_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ3_XXS], "get_rows_iq3_xxs", get_rows_iq3_xxs_len, get_rows_iq3_xxs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ3_S], "get_rows_iq3_s", get_rows_iq3_s_len, get_rows_iq3_s_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_XS], "get_rows_iq4_xs", get_rows_iq4_xs_len, get_rows_iq4_xs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl", get_rows_iq4_nl_len, get_rows_iq4_nl_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f32_f32", get_rows_f32_f32_len, get_rows_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); @@ -2049,6 +2065,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ2_S], "get_rows_iq2_s_f32", get_rows_iq2_s_f32_len, get_rows_iq2_s_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ3_XXS], "get_rows_iq3_xxs_f32", get_rows_iq3_xxs_f32_len, get_rows_iq3_xxs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ3_S], "get_rows_iq3_s_f32", get_rows_iq3_s_f32_len, get_rows_iq3_s_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_XS], "get_rows_iq4_xs_f32", get_rows_iq4_xs_f32_len, get_rows_iq4_xs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1); @@ -2995,6 +3012,7 @@ static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type case GGML_TYPE_IQ2_S: case GGML_TYPE_IQ3_XXS: case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: break; default: @@ -3048,6 +3066,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte case GGML_TYPE_IQ2_S: case GGML_TYPE_IQ3_XXS: case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: break; default: @@ -3084,6 +3103,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * case GGML_TYPE_IQ2_S: case GGML_TYPE_IQ3_XXS: case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: break; default: @@ -3132,6 +3152,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co case GGML_TYPE_IQ2_S: case GGML_TYPE_IQ3_XXS: case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: break; default: @@ -3163,6 +3184,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context case GGML_TYPE_IQ2_S: case GGML_TYPE_IQ3_XXS: case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: break; default: @@ -8037,6 +8059,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_TYPE_IQ2_S: case GGML_TYPE_IQ3_XXS: case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: break; default: @@ -8110,6 +8133,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm //case GGML_TYPE_IQ2_S: //case GGML_TYPE_IQ3_XXS: //case GGML_TYPE_IQ3_S: + //case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: break; default: @@ -8132,6 +8156,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_TYPE_IQ2_S: case GGML_TYPE_IQ3_XXS: case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: return true; default: diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp b/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp index aeae5400dfc..9c9fe9626db 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp @@ -12,7 +12,7 @@ layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; #endif void main() { -#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL) +#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL) init_iq_shmem(gl_WorkGroupSize); if (gl_LocalInvocationIndex.x != 0) { return; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp b/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp index d4b068e6186..660811086d6 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp @@ -217,7 +217,7 @@ void quantize(uint dst_idx, uint src_idx) #endif void main() { -#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL) +#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL) init_iq_shmem(gl_WorkGroupSize); if (gl_LocalInvocationIndex.x != 0) { return; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp index ee68775317b..ecfdbfaa88c 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp @@ -304,6 +304,42 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) { } #endif +#if defined(DATA_A_IQ4_XS) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint ib32 = iqs / 32; + const uint iq = 16 * ib32 + (iqs % 16); + + const uint sl = (data_a[a_offset + ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF; + const uint sh = (data_a[a_offset + ib].scales_h >> (2 * ib32)) & 3; + const uint qshift = (iqs & 16) >> 2; + u8vec2 qs = u8vec2(data_a[a_offset + ib].qs[iq], data_a[a_offset + ib].qs[iq + 1]); + qs = (qs >> qshift) & uint8_t(0xF); + + const float dl = float(int(sl | (sh << 4)) - 32); + return dl * vec2(kvalues_iq4nl[qs.x], kvalues_iq4nl[qs.y]); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint ib32 = iqs / 32; + const uint iq = 16 * ib32 + (iqs % 16); + + const uint sl = (data_a[a_offset + ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF; + const uint sh = (data_a[a_offset + ib].scales_h >> (2 * ib32)) & 3; + const uint qshift = (iqs & 16) >> 2; + u8vec4 qs = u8vec4( + data_a[a_offset + ib].qs[iq + 0], + data_a[a_offset + ib].qs[iq + 1], + data_a[a_offset + ib].qs[iq + 2], + data_a[a_offset + ib].qs[iq + 3] + ); + qs = (qs >> qshift) & uint8_t(0xF); + + const float dl = float(int(sl | (sh << 4)) - 32); + return dl * vec4( + kvalues_iq4nl[qs.x], kvalues_iq4nl[qs.y], + kvalues_iq4nl[qs.z], kvalues_iq4nl[qs.w]); +} +#endif + #if defined(DATA_A_IQ4_NL) vec2 dequantize(uint ib, uint iqs, uint a_offset) { const uint vui = uint(data_a[a_offset + ib].qs[iqs]); @@ -321,7 +357,7 @@ vec2 get_dm(uint ib, uint a_offset) { } #endif -#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL) +#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL) vec2 get_dm(uint ib, uint a_offset) { return vec2(float(data_a[a_offset + ib].d), 0); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp index 974efd3f9a6..78c3bddf227 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp @@ -454,6 +454,27 @@ float16_t dequantFuncIQ3_S(const in decodeBufIQ3_S bl, const in uint blockCoords } #endif +#if defined(DATA_A_IQ4_XS) +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ4_XS { + block_iq4_xs block; +}; + +float16_t dequantFuncIQ4_XS(const in decodeBufIQ4_XS bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float16_t d = bl.block.d; + const uint idx = coordInBlock[1]; + + const uint ib32 = (idx & 0xE0) >> 5; // 0..7 + + const uint sl = (bl.block.scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF; + const uint sh = ((bl.block.scales_h) >> (2 * ib32)) & 3; + const uint qshift = (idx & 16) >> 2; + const uint q = (bl.block.qs[16 * ib32 + (idx % 16)] >> qshift) & 0xF; + + float16_t ret = d * float16_t(int(sl | (sh << 4)) - 32) * float16_t(kvalues_iq4nl[q]); + return ret; +} +#endif #if defined(DATA_A_IQ4_NL) layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ4_NL { @@ -504,6 +525,8 @@ float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoor #define dequantFuncA dequantFuncIQ3_XXS #elif defined(DATA_A_IQ3_S) #define dequantFuncA dequantFuncIQ3_S +#elif defined(DATA_A_IQ4_XS) +#define dequantFuncA dequantFuncIQ4_XS #elif defined(DATA_A_IQ4_NL) #define dequantFuncA dequantFuncIQ4_NL #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp new file mode 100644 index 00000000000..f930852a48a --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp @@ -0,0 +1,34 @@ +#version 450 + +#include "dequant_head.comp" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_iq4_xs data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + // Each thread handles 1 subblock (1 scale and 32 quantized values) + const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8; + + init_iq_shmem(gl_WorkGroupSize); + + if (ib >= p.nel / 256) { + return; + } + + const uint ib32 = gl_LocalInvocationID.x % 8; + + const float d = float(data_a[ib].d); + // Scales are 6 bits + const uint scale = ((data_a[ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF) + | (((data_a[ib].scales_h >> (2 * ib32)) & 3) << 4); + const float dl = d * (int(scale) - 32); + + const uint b_idx = 256 * ib + 32 * ib32; + const uint q_idx = 16 * ib32; + [[unroll]] for (uint l = 0; l < 16; ++l) { + data_b[b_idx + l + 0] = D_TYPE(dl * kvalues_iq4nl[data_a[ib].qs[q_idx + l] & 0xF]); + data_b[b_idx + l + 16] = D_TYPE(dl * kvalues_iq4nl[data_a[ib].qs[q_idx + l] >> 4]); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp index 043a5302388..ba88ce79a21 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp @@ -104,7 +104,7 @@ ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE ele #endif void main() { -#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL) +#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL) init_iq_shmem(gl_WorkGroupSize); #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp b/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp index 09dc43d8dc3..c16a2a9f605 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp @@ -12,7 +12,7 @@ void main() { const uint i11 = (gl_GlobalInvocationID.z)/p.ne12; const uint i12 = (gl_GlobalInvocationID.z)%p.ne12; -#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL) +#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL) init_iq_shmem(gl_WorkGroupSize); #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp index 48156e7bab6..d7e99727db1 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp @@ -133,7 +133,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { void main() { const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); -#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL) +#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL) init_iq_shmem(gl_WorkGroupSize); #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp index d0559aac8ec..33b2234e71d 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp @@ -95,7 +95,7 @@ shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS]; #endif void main() { -#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL) +#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL) init_iq_shmem(gl_WorkGroupSize); #endif @@ -547,6 +547,25 @@ void main() { const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)] >> (16 * (idx % 2)); const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); + buf_a[buf_idx ] = FLOAT_TYPE(v.x); + buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); +#elif defined(DATA_A_IQ4_XS) + const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; + + const uint ib = idx / 128; // 2 values per idx + const uint ib32 = (idx % 128) / 16; // 0..7 + const uint iq = 16 * ib32 + 2 * (idx % 8); + + const uint sl = (data_a[ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF; + const uint sh = ((data_a[ib].scales_h) >> (2 * ib32)) & 3; + const uint qshift = (idx & 8) >> 1; + u8vec2 qs = u8vec2(data_a[ib].qs[iq], data_a[ib].qs[iq + 1]); + qs = (qs >> qshift) & uint8_t(0xF); + + const float d = float(data_a[ib].d); + const vec2 v = d * float(int(sl | (sh << 4)) - 32) * vec2(kvalues_iq4nl[qs.x], kvalues_iq4nl[qs.y]); + buf_a[buf_idx ] = FLOAT_TYPE(v.x); buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); #elif defined(DATA_A_IQ4_NL) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp index 27c5d68b3d9..7e29bbfec7b 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp @@ -106,7 +106,7 @@ D_TYPE perElemOpD(const in uint32_t r, const in uint32_t c, const in D_TYPE elem #endif void main() { -#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL) +#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL) init_iq_shmem(gl_WorkGroupSize); #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/types.comp b/ggml/src/ggml-vulkan/vulkan-shaders/types.comp index 9e56a35300b..db643a54c8e 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/types.comp @@ -1026,6 +1026,23 @@ void init_iq_shmem(uvec3 wgsize) #define A_TYPE_PACKED16 block_iq3_s_packed16 #endif +#define QUANT_K_IQ4_XS 256 +#define QUANT_R_IQ4_XS 1 + +struct block_iq4_xs +{ + float16_t d; + uint16_t scales_h; + uint8_t scales_l[QUANT_K_IQ4_XS/64]; + uint8_t qs[QUANT_K_IQ4_XS/2]; +}; + +#if defined(DATA_A_IQ4_XS) +#define QUANT_K QUANT_K_IQ4_XS +#define QUANT_R QUANT_R_IQ4_XS +#define A_TYPE block_iq4_xs +#endif + #define QUANT_K_IQ4_NL 32 #define QUANT_R_IQ4_NL 2 @@ -1042,7 +1059,13 @@ struct block_iq4_nl_packed16 }; #if defined(DATA_A_IQ4_NL) +#define QUANT_K QUANT_K_IQ4_NL +#define QUANT_R QUANT_R_IQ4_NL +#define A_TYPE block_iq4_nl +#define A_TYPE_PACKED16 block_iq4_nl_packed16 +#endif +#if defined(DATA_A_IQ4_NL) || defined(DATA_A_IQ4_XS) const int8_t kvalues_iq4nl_const[16] = { int8_t(-127), int8_t(-104), int8_t(-83), int8_t(-65), int8_t(-49), int8_t(-35), int8_t(-22), int8_t(-10), int8_t(1), int8_t(13), int8_t(25), int8_t(38), int8_t(53), int8_t(69), int8_t(89), int8_t(113) @@ -1058,11 +1081,6 @@ void init_iq_shmem(uvec3 wgsize) } barrier(); } - -#define QUANT_K QUANT_K_IQ4_NL -#define QUANT_R QUANT_R_IQ4_NL -#define A_TYPE block_iq4_nl -#define A_TYPE_PACKED16 block_iq4_nl_packed16 #endif #endif // !defined(GGML_TYPES_COMP) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 93ddbfadc5f..77e7e1148b4 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -60,6 +60,7 @@ const std::vector type_names = { "iq2_s", "iq3_xxs", "iq3_s", + "iq4_xs", "iq4_nl" }; From 6724a2ab1373203b86bb5948cd91fbfde6ee73a1 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Thu, 6 Feb 2025 00:15:30 -0600 Subject: [PATCH 10/58] vulkan: optimize coopmat2 iq2/iq3 callbacks (llama/11521) * vulkan: optimize coopmat2 iq2/iq3 callbacks * build: trigger CI on GLSL compute shader changes --- .../vulkan-shaders/dequant_funcs_cm2.comp | 79 +++++++++---------- 1 file changed, 38 insertions(+), 41 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp index 78c3bddf227..0eba3742011 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp @@ -323,15 +323,16 @@ float16_t dequantFuncIQ2_XXS(const in decodeBufIQ2_XXS bl, const in uint blockCo const uint8_t qs = bl.block.qs[iqs]; const uint signscale = pack32(u16vec2(bl16.block.qs[4*ib32+2], bl16.block.qs[4*ib32+3])); - const float16_t dscale = bl.block.d * 0.25hf * (0.5hf + float16_t(signscale >> 28)); + const float dscale = float(bl.block.d) * 0.25 * (0.5 + float(signscale >> 28)); uint sign = bitfieldExtract(signscale, 7 * int(ib8), 7); sign |= bitCount(sign) << 7; - const uint8_t g = unpack8(iq2xxs_grid[qs][(idx & 4) >> 2])[idx & 3]; + uint g2 = iq2xxs_grid[qs][(idx & 4) >> 2]; + g2 >>= (idx & 2) * 8; + const vec2 g = vec2(unpack8(g2)); - float16_t ret = dscale * float16_t(g) * ((sign & (1 << (idx & 7))) != 0 ? -1.0hf : 1.0hf); - - return ret; + vec2 ret = dscale * g * ((sign & (1 << (idx & 7))) != 0 ? -1.0hf : 1.0hf); + return float16_t(ret[idx & 1]); } #endif @@ -350,14 +351,16 @@ float16_t dequantFuncIQ2_XS(const in decodeBufIQ2_XS bl, const in uint blockCoor const uint iqs = (idx & 0xF8) >> 3; // 0..63 const uint16_t qs = bl.block.qs[iqs]; - const float16_t dscale = bl.block.d * 0.25hf * (0.5hf + float16_t((bl.block.scales[is] >> sshift) & 0xF)); + const float dscale = float(bl.block.d) * 0.25 * (0.5 + float((bl.block.scales[is] >> sshift) & 0xF)); uint sign = uint(qs >> 9); sign |= bitCount(sign) << 7; - const uint8_t g = unpack8(iq2xs_grid[qs & 0x1FF][(idx & 4) >> 2])[idx & 3]; + uint g2 = iq2xs_grid[qs & 0x1FF][(idx & 4) >> 2]; + g2 >>= (idx & 2) * 8; + const vec2 g = vec2(unpack8(g2)); - float16_t ret = dscale * float16_t(g) * ((sign & (1 << (idx & 7))) != 0 ? -1.0hf : 1.0hf); - return ret; + vec2 ret = dscale * g * ((sign & (1 << (idx & 7))) != 0 ? -1.0hf : 1.0hf); + return float16_t(ret[idx & 1]); } #endif @@ -369,24 +372,23 @@ layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ2 float16_t dequantFuncIQ2_S(const in decodeBufIQ2_S bl, const in uint blockCoords[2], const in uint coordInBlock[2]) { uint idx = coordInBlock[1]; - uint lsb = idx & 1; - idx /= 2; - const uint ib8 = (idx % 128) / 4; // 0..31 - const uint ib32 = ib8 / 4; // 0..7 + const uint ib32 = (idx & 0xE0) >> 5; // 0..7 + const uint ib8 = (idx & 0xF8) >> 3; // 0..31 + const uint qhshift = 2 * (ib8 % 4); - const uint scale = (bl.block.scales[ib32] >> (2 * (ib8 & 2))) & 0xf; + const uint scale = (bl.block.scales[ib32] >> ((idx & 0x10) >> 2)) & 0xf; const uint qs = bl.block.qs[ib8]; const uint qh = bl.block.qh[ib32]; - const uint qhshift = 2 * (ib8 % 4); - const uint sign = bl.block.qs[QUANT_K / 8 + ib8] >> (2 * (idx % 4)); + const uint sign = bl.block.qs[QUANT_K / 8 + ib8] >> (idx & 0x6); const float d = float(bl.block.d); const float db = d * 0.25 * (0.5 + scale); - const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign)))); - const uint16_t grid = unpack16(iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(idx & 2) >> 1])[idx & 1]; - const vec2 v = db * vec2(sign01) * vec2(unpack8(grid)); - return float16_t(v[lsb]); + const ivec2 sign01 = 1 - (2 & ivec2(sign << 1, sign)); + uint g2 = iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(idx & 4) >> 2]; + g2 >>= (idx & 2) * 8; + const vec2 v = db * vec2(sign01) * vec2(unpack8(g2)); + return float16_t(v[idx & 1]); } #endif @@ -401,28 +403,25 @@ layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ3 float16_t dequantFuncIQ3_XXS(const in decodeBufIQ3_XXS bl, const in uint blockCoords[2], const in uint coordInBlock[2]) { + decodeBufIQ3_XXS_packed16 bl16 = decodeBufIQ3_XXS_packed16(bl); uint idx = coordInBlock[1]; - uint lsb = idx & 1; - idx /= 2; - const uint iqs = (idx % 128) / 2; // 0..63 - const uint is = QUANT_K / 4 + 4 * (iqs / 8); // 8 values + const uint iqs = (idx & 0xFC) >> 2; // 0..63 + const uint is = QUANT_K / 4 + ((idx & 0xE0) >> 3);// 8 values const float d = float(bl.block.d); const uint qs = bl.block.qs[iqs]; - const uint signs = pack32(u8vec4( - bl.block.qs[is+0], - bl.block.qs[is+1], - bl.block.qs[is+2], - bl.block.qs[is+3] + const uint signs = pack32(u16vec2( + bl16.block.qs[is/2+0], + bl16.block.qs[is/2+1] )); const float db = d * 0.5 * (0.5 + (signs >> 28)); const uint32_t sign7 = bitfieldExtract(signs, 7 * (int(iqs / 2) % 4), 7); - const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4)); - const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign)))); - const uint grid = iq3xxs_grid[qs] >> (16 * (idx & 1)); + const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (idx & 0x6); + const ivec2 sign01 = ivec2(1 - (2 & ivec2(sign << 1, sign))); + const uint grid = iq3xxs_grid[qs] >> (16 * ((idx & 2) >> 1)); const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); - return float16_t(v[lsb]); + return float16_t(v[idx & 1]); } #endif @@ -434,23 +433,21 @@ layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ3 float16_t dequantFuncIQ3_S(const in decodeBufIQ3_S bl, const in uint blockCoords[2], const in uint coordInBlock[2]) { uint idx = coordInBlock[1]; - uint lsb = idx & 1; - idx /= 2; - const uint iqs = (idx % 128) / 2; // 0..63 - const uint iqh = iqs / 8; + const uint iqs = (idx & 0xFC) >> 2; // 0..63 + const uint iqh = (idx & 0xE0) >> 5; const float d = float(bl.block.d); const uint qs = bl.block.qs[iqs]; const uint qh = bl.block.qh[iqh]; - const int8_t sign = int8_t(bl.block.signs[iqs / 2] >> (2 * (idx % 4))); + const int8_t sign = int8_t(bl.block.signs[iqs / 2] >> (idx & 0x6)); const uint scale = bl.block.scales[iqs / 16]; - const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(sign << 1, sign))); + const ivec2 sign01 = ivec2(1 - (2 & ivec2(sign << 1, sign))); const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf)); - const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)] >> (16 * (idx % 2)); + const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)] >> ((idx & 2) << 3); const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); - return float16_t(v[lsb]); + return float16_t(v[idx & 1]); } #endif From f62a1526a616f689d286064515716950832b22a4 Mon Sep 17 00:00:00 2001 From: junchao-zhao <68935141+junchao-loongson@users.noreply.github.com> Date: Thu, 6 Feb 2025 17:20:00 +0800 Subject: [PATCH 11/58] ggml : fix LoongArch compile error with 128-bit SIMD (llama/11701) --- ggml/src/ggml-cpu/ggml-cpu-quants.c | 169 +++++++++++++++------------- 1 file changed, 91 insertions(+), 78 deletions(-) diff --git a/ggml/src/ggml-cpu/ggml-cpu-quants.c b/ggml/src/ggml-cpu/ggml-cpu-quants.c index 88303ff0e61..72ec58ceef6 100644 --- a/ggml/src/ggml-cpu/ggml-cpu-quants.c +++ b/ggml/src/ggml-cpu/ggml-cpu-quants.c @@ -297,6 +297,90 @@ static const uint64_t table_b2b_0[1 << 8] = { B8(00, 10) }; // ( b) << 4 static const uint64_t table_b2b_1[1 << 8] = { B8(10, 00) }; // (!b) << 4 #endif +#if defined(__loongarch_sx) + +static __m128i lsx_packs_w(__m128i a, __m128i b) { + __m128i tmp, tmp1; + tmp = __lsx_vsat_w(a, 15); + tmp1 = __lsx_vsat_w(b, 15); + return __lsx_vpickev_h(tmp1, tmp); +} + +static __m128i lsx_packs_h(__m128i a, __m128i b) { + __m128i tmp, tmp1; + tmp = __lsx_vsat_h(a, 7); + tmp1 = __lsx_vsat_h(b, 7); + return __lsx_vpickev_b(tmp1, tmp); +} + +static __m128i lsx_packus_h(__m128i a, __m128i b) { + __m128i tmp, tmp1; + tmp = __lsx_vsat_hu(a, 7); + tmp1 = __lsx_vsat_hu(b, 7); + return __lsx_vpickev_b(tmp1, tmp); +} + +static __m128i lsx_maddubs_h(__m128i a, __m128i b) { + __m128i tmp1, tmp2; + tmp1 = __lsx_vmulwev_h_b(a, b); + tmp2 = __lsx_vmulwod_h_b(a, b); + return __lsx_vsadd_h(tmp1, tmp2); +} + +static __m128i lsx_madd_h(__m128i a, __m128i b) { + __m128i tmp1, tmp2; + tmp1 = __lsx_vmulwev_w_h(a, b); + tmp2 = __lsx_vmulwod_w_h(a, b); + return __lsx_vadd_w(tmp1, tmp2); +} + +static __m128i lsx_set_w(int32_t a, int32_t b, int32_t c, int32_t d) { + v4i32 __ret = {d, c, b, a}; + return (__m128i)__ret; +} + +static __m128i lsx_shuffle_b(__m128i a, __m128i b) { + __m128i mask_f, zero, tmp0, tmp2, mask; + int f = 0x8f; + mask_f = __lsx_vreplgr2vr_b(f); + zero = __lsx_vldi(0); + tmp0 = __lsx_vand_v(b, mask_f); // get mask with low 4 bit and sign bits + tmp0 = __lsx_vori_b(tmp0, 0x10); // make each mask or with 0x10 prepare for positive + mask = __lsx_vsle_b(zero, tmp0); // if mask >= 0, set mask + tmp2 = __lsx_vand_v(tmp0, mask); // maskout the in2 < ones + return __lsx_vshuf_b(a, zero, tmp2); +} + +static __m128i lsx_hadd_h(__m128i a, __m128i b) { + __m128i tmp1 = __lsx_vpickev_h(b, a); + __m128i tmp2 = __lsx_vpickod_h(b, a); + return __lsx_vadd_h(tmp1, tmp2); +} + +static __m128i lsx_hadd_w(__m128i a, __m128i b) { + __m128i tmp1 = __lsx_vpickev_w(b, a); + __m128i tmp2 = __lsx_vpickod_w(b, a); + return __lsx_vadd_w(tmp1, tmp2); +} + +static __m128 lsx_hadd_s(__m128 a, __m128 b) { + __m128 tmp1 = (__m128)__lsx_vpickev_w((__m128i)b, (__m128i)a); + __m128 tmp2 = (__m128)__lsx_vpickod_w((__m128i)b, (__m128i)a); + + return __lsx_vfadd_s(tmp1, tmp2); +} + +static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128 c, const __m128 d) { + __m128 res_0 =lsx_hadd_s(a, b); + __m128 res_1 =lsx_hadd_s(c, d); + __m128 res =lsx_hadd_s(res_0, res_1); + res =lsx_hadd_s(res, res); + res =lsx_hadd_s(res, res); + + return ((v4f32)res)[0]; +} +#endif + #if defined(__loongarch_asx) #ifdef __clang__ @@ -395,11 +479,6 @@ static __m256i lasx_set_w(int e7, int e6, int e5, int e4, int e3, int e2, int e1 return (__m256i)__ret; } -static __m128i lsx_set_w(int32_t a, int32_t b, int32_t c, int32_t d) { - v4i32 __ret = {d, c, b, a}; - return (__m128i)__ret; -} - static __m256i lasx_set_d(int64_t a, int64_t b, int64_t c, int64_t d) { v4i64 __ret = {d, c, b, a}; return (__m256i)__ret; @@ -409,18 +488,6 @@ static __m256i lasx_insertf128( __m128i x, __m128i y) { return lasx_set_q(x, y); } -static __m128i lsx_shuffle_b(__m128i a, __m128i b) { - __m128i mask_f, zero, tmp0, tmp2, mask; - int f = 0x8f; - mask_f = __lsx_vreplgr2vr_b(f); - zero = __lsx_vldi(0); - tmp0 = __lsx_vand_v(b, mask_f); // get mask with low 4 bit and sign bits - tmp0 = __lsx_vori_b(tmp0, 0x10); // make each mask or with 0x10 prepare for positive - mask = __lsx_vsle_b(zero, tmp0); // if mask >= 0, set mask - tmp2 = __lsx_vand_v(tmp0, mask); // maskout the in2 < ones - return __lsx_vshuf_b(a, zero, tmp2); -} - static __m256i lasx_shuffle_b(__m256i a, __m256i b) { __m256i mask_f, zero, tmp0, tmp2, mask; int f = 0x8f; @@ -482,25 +549,6 @@ static __m128 lasx_extractf128( __m256 a, int pos) { return ret; } -static __m128i lsx_hadd_h(__m128i a, __m128i b) { - __m128i tmp1 = __lsx_vpickev_h(b, a); - __m128i tmp2 = __lsx_vpickod_h(b, a); - return __lsx_vadd_h(tmp1, tmp2); -} - -static __m128i lsx_hadd_w(__m128i a, __m128i b) { - __m128i tmp1 = __lsx_vpickev_w(b, a); - __m128i tmp2 = __lsx_vpickod_w(b, a); - return __lsx_vadd_w(tmp1, tmp2); -} - -static __m128 lsx_hadd_s(__m128 a, __m128 b) { - __m128 tmp1 = (__m128)__lsx_vpickev_w((__m128i)b, (__m128i)a); - __m128 tmp2 = (__m128)__lsx_vpickod_w((__m128i)b, (__m128i)a); - - return __lsx_vfadd_s(tmp1, tmp2); -} - static __m256i lasx_maddubs_h(__m256i a, __m256i b) { __m256i tmp1, tmp2; tmp1 = __lasx_xvmulwev_h_b(a, b); @@ -529,42 +577,6 @@ static __m256i lasx_packs_h(__m256i a, __m256i b) { return __lasx_xvpickev_b(tmp1, tmp); } -static __m128i lsx_packs_w(__m128i a, __m128i b) { - __m128i tmp, tmp1; - tmp = __lsx_vsat_w(a, 15); - tmp1 = __lsx_vsat_w(b, 15); - return __lsx_vpickev_h(tmp1, tmp); -} - -static __m128i lsx_packs_h(__m128i a, __m128i b) { - __m128i tmp, tmp1; - tmp = __lsx_vsat_h(a, 7); - tmp1 = __lsx_vsat_h(b, 7); - return __lsx_vpickev_b(tmp1, tmp); -} - -static __m128i lsx_packus_h(__m128i a, __m128i b) { - __m128i tmp, tmp1; - tmp = __lsx_vsat_hu(a, 7); - tmp1 = __lsx_vsat_hu(b, 7); - return __lsx_vpickev_b(tmp1, tmp); -} - - -static __m128i lsx_maddubs_h(__m128i a, __m128i b) { - __m128i tmp1, tmp2; - tmp1 = __lsx_vmulwev_h_b(a, b); - tmp2 = __lsx_vmulwod_h_b(a, b); - return __lsx_vsadd_h(tmp1, tmp2); -} - -static __m128i lsx_madd_h(__m128i a, __m128i b) { - __m128i tmp1, tmp2; - tmp1 = __lsx_vmulwev_w_h(a, b); - tmp2 = __lsx_vmulwod_w_h(a, b); - return __lsx_vadd_w(tmp1, tmp2); -} - // multiply int8_t, add results pairwise twice static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) { // Get absolute values of x vectors @@ -2232,21 +2244,22 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r } sumf = hsum_float_8(acc); + #elif defined(__loongarch_sx) // set constants const __m128i low_mask = __lsx_vreplgr2vr_b(0xF); const __m128i off = __lsx_vreplgr2vr_b(8); // Initialize accumulator with zeros - __m128 acc_0 = __lsx_vldi(0); - __m128 acc_1 = __lsx_vldi(0); - __m128 acc_2 = __lsx_vldi(0); - __m128 acc_3 = __lsx_vldi(0); + __m128 acc_0 = (__m128)__lsx_vldi(0); + __m128 acc_1 = (__m128)__lsx_vldi(0); + __m128 acc_2 = (__m128)__lsx_vldi(0); + __m128 acc_3 = (__m128)__lsx_vldi(0); for (; ib + 1 < nb; ib += 2) { // Compute combined scale for the block 0 and 1 - const __m128 d_0_1 = __lsx_vreplgr2vr_w( GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d) ); + const __m128 d_0_1 = (__m128)__lsx_vreplgr2vr_w( GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d) ); const __m128i tmp_0_1 = __lsx_vld((const __m128i *)x[ib].qs, 0); @@ -2264,7 +2277,7 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r //_mm_prefetch(&y[ib] + 2 * sizeof(block_q8_0), _MM_HINT_T0); // Compute combined scale for the block 2 and 3 - const __m128 d_2_3 = __lsx_vreplgr2vr_w( GGML_FP16_TO_FP32(x[ib + 1].d) * GGML_FP16_TO_FP32(y[ib + 1].d) ); + const __m128 d_2_3 = (__m128)__lsx_vreplgr2vr_w( GGML_FP16_TO_FP32(x[ib + 1].d) * GGML_FP16_TO_FP32(y[ib + 1].d) ); const __m128i tmp_2_3 = __lsx_vld((const __m128i *)x[ib + 1].qs, 0); From 1f1ddf8160336a7218c09a8c802bf137a7d3d370 Mon Sep 17 00:00:00 2001 From: Akarshan Biswas Date: Thu, 6 Feb 2025 17:12:35 +0530 Subject: [PATCH 12/58] SYCL: Adjust support condition for norm operators (llama/11674) SYCL does not support non contiguous tensors for norm operations --- ggml/src/ggml-sycl/ggml-sycl.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 2984ed82e8a..aab34a752d4 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -4537,14 +4537,17 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_VIEW: case GGML_OP_PERMUTE: case GGML_OP_TRANSPOSE: - case GGML_OP_NORM: case GGML_OP_ADD: case GGML_OP_ADD1: case GGML_OP_LOG: case GGML_OP_SUB: case GGML_OP_MUL: case GGML_OP_DIV: + return true; + case GGML_OP_NORM: case GGML_OP_RMS_NORM: + case GGML_OP_GROUP_NORM: + return ggml_is_contiguous(op->src[0]); case GGML_OP_SCALE: case GGML_OP_SQR: case GGML_OP_SQRT: @@ -4576,7 +4579,6 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_SUM_ROWS: case GGML_OP_ARGSORT: case GGML_OP_ACC: - case GGML_OP_GROUP_NORM: case GGML_OP_UPSCALE: case GGML_OP_PAD: case GGML_OP_LEAKY_RELU: From e4c89e5200ae38bce00cb2518e3fbe2a4cd6c6b7 Mon Sep 17 00:00:00 2001 From: Jinyang He Date: Fri, 7 Feb 2025 15:38:31 +0800 Subject: [PATCH 13/58] ggml : optimize and build warning fix for LoongArch (llama/11709) * ggml : optimize convert f32<->f16 for loongarch_asx * ggml : optimize loongarch_asx extend i16,i8,u8 to i32,i16 * ggml : Fix warnings when run cpu CI locally on LoongArch --- ggml/src/ggml-cpu/ggml-cpu-impl.h | 18 +++++--------- ggml/src/ggml-cpu/ggml-cpu-quants.c | 37 ++++++----------------------- ggml/src/ggml-cpu/ggml-cpu.c | 24 +++++++------------ 3 files changed, 22 insertions(+), 57 deletions(-) diff --git a/ggml/src/ggml-cpu/ggml-cpu-impl.h b/ggml/src/ggml-cpu/ggml-cpu-impl.h index d71076ad12b..9ddd972a5cf 100644 --- a/ggml/src/ggml-cpu/ggml-cpu-impl.h +++ b/ggml/src/ggml-cpu/ggml-cpu-impl.h @@ -360,21 +360,15 @@ inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) #endif #if defined(__loongarch_asx) - -typedef union { - int32_t i; - float f; -} ft_union; - /* float type data load instructions */ -static __m128 __lsx_vreplfr2vr_s(float val) { - ft_union fi_tmpval = {.f = val}; - return (__m128)__lsx_vreplgr2vr_w(fi_tmpval.i); +static __m128 __lsx_vreplfr2vr_s(const float val) { + v4f32 res = {val, val, val, val}; + return (__m128)res; } -static __m256 __lasx_xvreplfr2vr_s(float val) { - ft_union fi_tmpval = {.f = val}; - return (__m256)__lasx_xvreplgr2vr_w(fi_tmpval.i); +static __m256 __lasx_xvreplfr2vr_s(const float val) { + v8f32 res = {val, val, val, val, val, val, val, val}; + return (__m256)res; } #endif diff --git a/ggml/src/ggml-cpu/ggml-cpu-quants.c b/ggml/src/ggml-cpu/ggml-cpu-quants.c index 72ec58ceef6..27ec1493565 100644 --- a/ggml/src/ggml-cpu/ggml-cpu-quants.c +++ b/ggml/src/ggml-cpu/ggml-cpu-quants.c @@ -501,30 +501,15 @@ static __m256i lasx_shuffle_b(__m256i a, __m256i b) { } static __m256i lasx_extu8_16(__m128i a) { - __m128i zero = __lsx_vldi(0); - __m128i vlo = __lsx_vilvl_b(zero, a); - __m128i vhi = __lsx_vilvh_b(zero, a); - return lasx_set_q(vhi, vlo); + return __lasx_vext2xv_hu_bu(____m256i(a)); } static __m256i lasx_ext8_16(__m128i a) { - __m128i sign = __lsx_vslti_b(a, 0); - __m128i vlo = __lsx_vilvl_b(sign, a); - __m128i vhi = __lsx_vilvh_b(sign, a); - return lasx_set_q(vhi, vlo); + return __lasx_vext2xv_h_b(____m256i(a)); } static __m256i lasx_ext16_32(__m128i a) { - __m256i tmp1; - tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 0), 0); - tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 1), 1); - tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 2), 2); - tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 3), 3); - tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 4), 4); - tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 5), 5); - tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 6), 6); - tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 7), 7); - return tmp1; + return __lasx_vext2xv_w_h(____m256i(a)); } static __m128i lasx_extracti128( __m256i a, int pos) { @@ -592,12 +577,10 @@ static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) { // horizontally add 8 floats static inline float hsum_float_8(const __m256 x) { __m128 res = lasx_extractf128(x, 1); - ft_union tmp; res = __lsx_vfadd_s(res, lasx_extractf128(x, 0)); res = __lsx_vfadd_s(res, (__m128)__lsx_vpickod_d((__m128i)res, (__m128i)res)); res = __lsx_vfadd_s(res, (__m128)__lsx_vinsgr2vr_w(__lsx_vldi(0), __lsx_vpickve2gr_w(res, 1), 0)); - tmp.i = __lsx_vpickve2gr_w(res, 0); - return tmp.f; + return ((v4f32)res)[0]; } // horizontally add 8 int32_t @@ -939,7 +922,6 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k) #elif defined(__loongarch_asx) for (int i = 0; i < nb; i++) { - ft_union fi; __m256 v0 = (__m256)__lasx_xvld( x , 0); __m256 v1 = (__m256)__lasx_xvld( x , 32); __m256 v2 = (__m256)__lasx_xvld( x , 64); @@ -957,8 +939,7 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k) max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vpickod_d((__m128i) max4, (__m128i)max4 ) ); __m128 tmp = max4; max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vinsgr2vr_w(tmp, __lsx_vpickve2gr_w( max4, 1 ), 0 )); - fi.i = __lsx_vpickve2gr_w( (__m128i)max4, 0 ); - const float max_scalar = fi.f; + const float max_scalar = ((v4f32)max4)[0]; // Quantize these floats const float d = max_scalar / 127.f; @@ -1263,7 +1244,6 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k) #elif defined(__loongarch_asx) for (int i = 0; i < nb; i++) { - ft_union ft; __m256 v0 = (__m256)__lasx_xvld( x , 0 ); __m256 v1 = (__m256)__lasx_xvld( x , 32 ); __m256 v2 = (__m256)__lasx_xvld( x , 64 ); @@ -1281,8 +1261,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k) max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vpickod_d((__m128i) max4, (__m128i)max4 ) ); __m128 tmp = max4; max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vextrins_w((__m128i)tmp, (__m128i)max4, 0x10 )); - ft.i = __lsx_vpickve2gr_w( (__m128i)max4, 0 ); - const float max_scalar = ft.f; + const float max_scalar = ((v4f32)max4)[0]; // Quantize these floats const float d = max_scalar / 127.f; @@ -6154,9 +6133,7 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r acc_m = __lsx_vfadd_s(acc_m, (__m128)tmp1); - ft_union fi; - fi.i = __lsx_vpickve2gr_w(acc_m, 0); - *s = hsum_float_8(acc) + fi.f ; + *s = hsum_float_8(acc) + ((v4f32)acc_m)[0]; #else const uint8_t * scales = (const uint8_t*)&utmp[0]; diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index e809f05d217..59efaeb7129 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -1078,29 +1078,23 @@ do { \ #define GGML_F16_STEP 32 #define GGML_F16_EPR 8 -// F16 arithmetic is not supported by AVX, so we use F32 instead +// F16 arithmetic is not supported by LASX, so we use F32 instead #define GGML_F32Cx8 __m256 #define GGML_F32Cx8_ZERO (__m256)__lasx_xvldi(0) #define GGML_F32Cx8_SET1(x) (__m256)__lasx_xvreplgr2vr_w((x)) static inline __m256 __lasx_f32cx8_load(const ggml_fp16_t * x) { - float tmp[8]; - - for (int i = 0; i < 8; i++) { - tmp[i] = GGML_FP16_TO_FP32(x[i]); - } - - return (__m256)__lasx_xvld(tmp, 0); + __m256i a; + memcpy(&a, x, sizeof(ggml_fp16_t) * 8); + a = __lasx_xvpermi_d(a, 0 | (1 << 4)); + return __lasx_xvfcvtl_s_h(a); } -static inline void __lasx_f32cx8_store(ggml_fp16_t * x, __m256 y) { - float arr[8]; - __lasx_xvst(y, arr, 0); - - for (int i = 0; i < 8; i++) { - x[i] = GGML_FP32_TO_FP16(arr[i]); - } +static inline void __lasx_f32cx8_store(ggml_fp16_t * x, __m256 y) { + __m256i a = __lasx_xvfcvt_h_s(y, y); + a = __lasx_xvpermi_d(a, 0 | (2 << 2)); + memcpy(x, &a, sizeof(ggml_fp16_t) * 8); } #define GGML_F32Cx8_LOAD(x) __lasx_f32cx8_load(x) #define GGML_F32Cx8_STORE(x, y) __lasx_f32cx8_store(x, y) From 52d3ac83f537439fd71215c8df043d82a99d2a76 Mon Sep 17 00:00:00 2001 From: Akarshan Biswas Date: Fri, 7 Feb 2025 14:57:53 +0530 Subject: [PATCH 14/58] SYCL: remove XMX info from print devices (llama/11712) --- ggml/src/ggml-sycl/ggml-sycl.cpp | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index aab34a752d4..3d24d216548 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -103,11 +103,10 @@ void print_device_detail(int id, sycl::device &device, std::string device_type) name = std::regex_replace(name, std::regex("\\(TM\\)"), ""); auto global_mem_size = prop.get_global_mem_size()/1000000; - std::string xmx = gpu_has_xmx(device) ? "yes" : "no"; - GGML_LOG_INFO("|%2d|%19s|%39s|%7s|%7d|%8d|%5d|%6luM|%21s|%14s|\n", id, device_type.c_str(), + GGML_LOG_INFO("|%2d|%19s|%39s|%7s|%7d|%8d|%5d|%6luM|%21s|\n", id, device_type.c_str(), name.c_str(), version.c_str(), prop.get_max_compute_units(), prop.get_max_work_group_size(), prop.get_max_sub_group_size(), - global_mem_size, device.get_info().c_str(), xmx.c_str()); + global_mem_size, device.get_info().c_str()); } void ggml_backend_sycl_print_sycl_devices() { @@ -118,16 +117,16 @@ void ggml_backend_sycl_print_sycl_devices() { GGML_LOG_INFO( "| | | | " - " |Max | |Max |Global | | XMX |\n"); + " |Max | |Max |Global | |\n"); GGML_LOG_INFO( "| | | | " - " |compute|Max work|sub |mem | | or |\n"); + " |compute|Max work|sub |mem | |\n"); GGML_LOG_INFO( "|ID| Device Type| " - "Name|Version|units |group |group|size | Driver version| Tensor Cores |\n"); + "Name|Version|units |group |group|size | Driver version|\n"); GGML_LOG_INFO( "|--|-------------------|---------------------------------------|------" - "-|-------|--------|-----|-------|---------------------|--------------|\n"); + "-|-------|--------|-----|-------|---------------------|\n"); for (int id = 0; id < device_count; ++id) { sycl::device device = dpct::dev_mgr::instance().get_device(id); From 2ea46bc5a23ace1c57a6b2fffb5d89f9ada065ca Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Fri, 7 Feb 2025 04:26:03 -0600 Subject: [PATCH 15/58] vulkan: print shared memory size (llama/11719) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 1c99ebe2e2c..4c962fde90d 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -2780,8 +2780,9 @@ static void ggml_vk_print_gpu_info(size_t idx) { std::string matrix_cores = coopmat2_support ? "NV_coopmat2" : coopmat_support ? "KHR_coopmat" : "none"; std::string device_name = props2.properties.deviceName.data(); - GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | warp size: %zu | matrix cores: %s\n", - idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, subgroup_size, matrix_cores.c_str()); + GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | warp size: %zu | shared memory: %d | matrix cores: %s\n", + idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, subgroup_size, + props2.properties.limits.maxComputeSharedMemorySize, matrix_cores.c_str()); if (props2.properties.deviceType == vk::PhysicalDeviceType::eCpu) { GGML_LOG_DEBUG("ggml_vulkan: Warning: Device type is CPU. This is probably not the device you want.\n"); From 03aba11da63ff610ba3ab6c9b22a5ac9e2dd6d23 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Sat, 8 Feb 2025 10:46:07 +0100 Subject: [PATCH 16/58] CUDA: fix min. version for movmatrix (llama/11751) --- ggml/src/ggml-cuda/mma.cuh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh index 9788a1389a3..bbc0a35ae56 100644 --- a/ggml/src/ggml-cuda/mma.cuh +++ b/ggml/src/ggml-cuda/mma.cuh @@ -16,7 +16,7 @@ #include "common.cuh" -#if CUDART_VERSION >= 11800 +#if CUDART_VERSION >= 11080 static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) { int ret = 0; @@ -50,7 +50,7 @@ static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) { return ret_low | ret_high; } -#endif // CUDART_VERSION >= 11800 +#endif // CUDART_VERSION >= 11080 template From a616ea5d37b090f3cf2e4ecdafa4b52181b9c9a7 Mon Sep 17 00:00:00 2001 From: Karol Kontny <82021046+kkontny@users.noreply.github.com> Date: Sat, 8 Feb 2025 15:30:53 +0100 Subject: [PATCH 17/58] ggml: Fix data race in ggml threadpool (llama/11736) After the barrier in last iteration is executed, still the loop termination condition will be executed. However main thread can destroy the cgraph object and its nodes already, then another thread will access it, but the thing is already gone. Also trouble can happen when n_nodes == 0 or abort is called, but I'm not sure if the prior situation is possible. Last syncronization should be done after the loop to ensure the cgraph/cplan won't be accessed after the main thread exits from the function. --- ggml/src/ggml-cpu/ggml-cpu.c | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 59efaeb7129..fdb430a43cc 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -13856,9 +13856,13 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { tp->ec = GGML_STATUS_ABORTED; } - ggml_barrier(state->threadpool); + if (node_n + 1 < cgraph->n_nodes) { + ggml_barrier(state->threadpool); + } } + ggml_barrier(state->threadpool); + return 0; } From d196c351c36a776fd8b15b41b258765000d588b7 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Sun, 9 Feb 2025 01:43:51 -0600 Subject: [PATCH 18/58] vulkan: account for lookup tables when checking shared memory size (llama/11502) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 648 +++++++++++++-------------- 1 file changed, 322 insertions(+), 326 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 4c962fde90d..d32ba4efbc9 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -184,12 +184,12 @@ struct vk_device_struct { size_t idx; - bool mul_mat_l; - bool mul_mat_m; - bool mul_mat_s; - bool mul_mat_id_l; - bool mul_mat_id_m; - bool mul_mat_id_s; + bool mul_mat_l[GGML_TYPE_COUNT]; + bool mul_mat_m[GGML_TYPE_COUNT]; + bool mul_mat_s[GGML_TYPE_COUNT]; + bool mul_mat_id_l[GGML_TYPE_COUNT]; + bool mul_mat_id_m[GGML_TYPE_COUNT]; + bool mul_mat_id_s[GGML_TYPE_COUNT]; // set to true to indicate that some shaders need to be compiled after the dryrun bool need_compiles {}; @@ -1378,7 +1378,33 @@ static std::array fa_rows_cols(uint32_t D, uint32_t clamp, ggml_typ return {64, 64}; }; -static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector& warptile, bool mul_mat_id) { +static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector& warptile, bool mul_mat_id, ggml_type src0_type) { + + uint32_t lut_size = 0; + switch (src0_type) { + case GGML_TYPE_IQ2_XXS: + lut_size = 8*256; + break; + case GGML_TYPE_IQ2_XS: + lut_size = 8*512; + break; + case GGML_TYPE_IQ2_S: + lut_size = 8*1024; + break; + case GGML_TYPE_IQ3_XXS: + lut_size = 4*256; + break; + case GGML_TYPE_IQ3_S: + lut_size = 4*512; + break; + case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ4_XS: + lut_size = 4*16; + break; + default: + break; + } + // Needs to be kept up to date on shader changes const uint32_t bank_conflict_offset = device->coopmat_support ? 8 : 1; const uint32_t type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float); @@ -1388,7 +1414,13 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec const uint32_t mmid_row_ids = mul_mat_id ? 3072 * sizeof(uint32_t) : 0; const uint32_t coopmat_stage = device->coopmat_support ? warptile[7] * warptile[8] / warps * sizeof(float) : 0; - return (load_bufs + mmid_row_ids + coopmat_stage) <= device->properties.limits.maxComputeSharedMemorySize; + const uint32_t total_size = load_bufs + mmid_row_ids + coopmat_stage + lut_size; + const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize; + + VK_LOG_DEBUG("ggml_vk_matmul_shmem_support(warptile=(" << warptile[0] << "," << warptile[1] << "," << warptile[2] << "), " + "mul_mat_id=" << mul_mat_id << ", src0_type=" << ggml_type_name(src0_type) << ", supported=" << supported); + + return supported; } static void ggml_vk_load_shaders(vk_device& device) { @@ -1472,62 +1504,32 @@ static void ggml_vk_load_shaders(vk_device& device) { m_align = 64; s_align = 32; - // Fallback to smaller sizes if there's not enough shared memory. Given the current shaders - // and tile sizes, this should handle 16KB, 32KB, and 48KB+. - // This logic doesn't explicitly account for the 12KB row_ids in the mul_mat_mat_id shaders. - // But the numbers happen to work out for 32KB shared memory size that when using the medium - // size there's enough room for everything, and we assert for this. - uint32_t shmem_needed = (l_warptile[1] + l_warptile[2]) * (l_warptile[3] + 1) * sizeof(float); - if (shmem_needed > device->properties.limits.maxComputeSharedMemorySize) { - l_warptile = m_warptile; - l_wg_denoms = m_wg_denoms; - shmem_needed = (l_warptile[1] + l_warptile[2]) * (l_warptile[3] + 1) * sizeof(float); - GGML_ASSERT(shmem_needed <= device->properties.limits.maxComputeSharedMemorySize); - } - if (device->properties.limits.maxComputeSharedMemorySize >= 32768) { - // assert mul_mat_mat_id shaders will fit. - GGML_ASSERT(shmem_needed + 3072*4 <= device->properties.limits.maxComputeSharedMemorySize); - } - - shmem_needed = (l_warptile_mmq[1] + l_warptile_mmq[2]) * (l_warptile_mmq[3] + 1) * sizeof(float); - if (shmem_needed > device->properties.limits.maxComputeSharedMemorySize) { - if (device->properties.limits.maxComputeSharedMemorySize == 32768) { - l_warptile_mmq = m_warptile_mmq; - l_mmq_wg_denoms = m_mmq_wg_denoms; - } else { - l_warptile_mmq = s_warptile_mmq; - l_mmq_wg_denoms = s_mmq_wg_denoms; + for (uint32_t i = 0; i < GGML_TYPE_COUNT; ++i) { + ggml_type t = (ggml_type)i; + // Disable medium and large matrix multiplication if not enough shared memory is available + // Check mmq warptiles as the largest configuration + // Throw an error if not enough for any matrix multiplication is available + if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmq, false, t)) { + std::cerr << "ggml_vulkan: Error: Shared memory size too small for matrix multiplication." << std::endl; + throw std::runtime_error("Shared memory size too small for matrix multiplication."); + } else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmq, false, t)) { + device->mul_mat_m[i] = false; + device->mul_mat_l[i] = false; + } else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmq, false, t)) { + device->mul_mat_l[i] = false; + } + + // Disable mul_mat_id if not enough shared memory is available + if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmq, true, t)) { + device->mul_mat_id_s[i] = false; + device->mul_mat_id_m[i] = false; + device->mul_mat_id_l[i] = false; + } else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmq, true, t)) { + device->mul_mat_id_m[i] = false; + device->mul_mat_id_l[i] = false; + } else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmq, true, t)) { + device->mul_mat_id_l[i] = false; } - shmem_needed = (l_warptile_mmq[1] + l_warptile_mmq[2]) * (l_warptile_mmq[3] + 1) * sizeof(float); - GGML_ASSERT(shmem_needed <= device->properties.limits.maxComputeSharedMemorySize); - } - if (device->properties.limits.maxComputeSharedMemorySize >= 32768) { - // assert mul_mat_mat_id shaders will fit. - GGML_ASSERT(shmem_needed + 3072*4 <= device->properties.limits.maxComputeSharedMemorySize); - } - // Disable medium and large matrix multiplication if not enough shared memory is available - // Check mmq warptiles as the largest configuration - // Throw an error if not enough for any matrix multiplication is available - if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmq, false)) { - std::cerr << "ggml_vulkan: Error: Shared memory size too small for matrix multiplication." << std::endl; - throw std::runtime_error("Shared memory size too small for matrix multiplication."); - } else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmq, false)) { - device->mul_mat_m = false; - device->mul_mat_l = false; - } else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmq, false)) { - device->mul_mat_l = false; - } - - // Disable mul_mat_id if not enough shared memory is available - if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmq, true)) { - device->mul_mat_id_s = false; - device->mul_mat_id_m = false; - device->mul_mat_id_l = false; - } else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmq, true)) { - device->mul_mat_id_m = false; - device->mul_mat_id_l = false; - } else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmq, true)) { - device->mul_mat_id_l = false; } } @@ -1684,119 +1686,116 @@ static void ggml_vk_load_shaders(vk_device& device) { #if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) if (device->coopmat_support) { // Create 6 variants, {s,m,l}x{unaligned,aligned} -#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ - if (device->mul_mat ## ID ## _l) \ +#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ + if (device->mul_mat ## ID ## _l[TYPE]) \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, true); \ - if (device->mul_mat ## ID ## _m) \ + if (device->mul_mat ## ID ## _m[TYPE]) \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, true); \ - if (device->mul_mat ## ID ## _s) \ + if (device->mul_mat ## ID ## _s[TYPE]) \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, true); \ - if (device->mul_mat ## ID ## _l) \ + if (device->mul_mat ## ID ## _l[TYPE]) \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, true); \ - if (device->mul_mat ## ID ## _m) \ + if (device->mul_mat ## ID ## _m[TYPE]) \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, true); \ - if (device->mul_mat ## ID ## _s) \ + if (device->mul_mat ## ID ## _s[TYPE]) \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, true); \ // Create 2 variants, {f16,f32} accumulator -#define CREATE_MM2(PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ +#define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ if (device->coopmat_acc_f16_support) { \ - CREATE_MM(PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ + CREATE_MM(TYPE, PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ } \ if (device->coopmat_acc_f32_support) { \ - CREATE_MM(PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ + CREATE_MM(TYPE, PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ } \ - CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); - CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); - CREATE_MM2(pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); if (device->coopmat_acc_f16_support) { - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc, matmul_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + + CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc, matmul_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); } else { - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - } - - // If there's not enough shared memory for row_ids and the result tile, don't create these pipelines. - if (device->mul_mat_id_s || device->mul_mat_id_m || device->mul_mat_id_l) { - CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); - CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); - CREATE_MM2(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); - - if (device->coopmat_acc_f16_support) { - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - } else { - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - } + CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + + CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + } + + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); + + if (device->coopmat_acc_f16_support) { + CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + + CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + } else { + CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + + CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); } #undef CREATE_MM2 #undef CREATE_MM @@ -1804,141 +1803,135 @@ static void ggml_vk_load_shaders(vk_device& device) { #endif // defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) if (device->fp16) { // Create 6 variants, {s,m,l}x{unaligned,aligned} -#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ - if (device->mul_mat ## ID ## _l) \ +#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ + if (device->mul_mat ## ID ## _l[TYPE]) \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \ - if (device->mul_mat ## ID ## _m) \ + if (device->mul_mat ## ID ## _m[TYPE]) \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \ - if (device->mul_mat ## ID ## _s) \ + if (device->mul_mat ## ID ## _s[TYPE]) \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \ - if (device->mul_mat ## ID ## _l) \ + if (device->mul_mat ## ID ## _l[TYPE]) \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \ - if (device->mul_mat ## ID ## _m) \ + if (device->mul_mat ## ID ## _m[TYPE]) \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \ - if (device->mul_mat ## ID ## _s) \ + if (device->mul_mat ## ID ## _s[TYPE]) \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \ // Create 2 variants, {f16,f32} accumulator -#define CREATE_MM2(PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ - CREATE_MM(PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ - CREATE_MM(PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ - - CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); - CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); - CREATE_MM2(pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); - - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc, matmul_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - - // If there's not enough shared memory for row_ids and the result tile, don't create these pipelines. - if (device->mul_mat_id_s || device->mul_mat_id_m || device->mul_mat_id_l) { - CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); - CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); - CREATE_MM2(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); - - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - } +#define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ + CREATE_MM(TYPE, PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ + CREATE_MM(TYPE, PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ + + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + + CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + + CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc, matmul_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); + + CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + + CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); #undef CREATE_MM2 #undef CREATE_MM } else { // Create 6 variants, {s,m,l}x{unaligned,aligned} -#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ - if (device->mul_mat ## ID ## _l) \ +#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ + if (device->mul_mat ## ID ## _l[TYPE]) \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \ - if (device->mul_mat ## ID ## _m) \ + if (device->mul_mat ## ID ## _m[TYPE]) \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \ - if (device->mul_mat ## ID ## _s) \ + if (device->mul_mat ## ID ## _s[TYPE]) \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \ - if (device->mul_mat ## ID ## _l) \ + if (device->mul_mat ## ID ## _l[TYPE]) \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \ - if (device->mul_mat ## ID ## _m) \ + if (device->mul_mat ## ID ## _m[TYPE]) \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \ - if (device->mul_mat ## ID ## _s) \ + if (device->mul_mat ## ID ## _s[TYPE]) \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \ - CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_matmul_f16_f32.f32acc, matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); - - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f32acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f32acc, matmul_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f32acc, matmul_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f32acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - - // If there's not enough shared memory for row_ids and the result tile, don't create these pipelines. - if (device->mul_mat_id_s || device->mul_mat_id_m || device->mul_mat_id_l) { - CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); - CREATE_MM(pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); - CREATE_MM(pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); - - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc, matmul_id_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - } + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16_f32.f32acc, matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + + CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + + CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f32acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f32acc, matmul_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f32acc, matmul_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f32acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); + + CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + + CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc, matmul_id_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); #undef CREATE_MM } @@ -2623,34 +2616,36 @@ static vk_device ggml_vk_get_device(size_t idx) { // Shaders // Disable matmul tile sizes early if performance low or not supported - switch (device->vendor_id) { + for (uint32_t i = 0; i < GGML_TYPE_COUNT; ++i) { + switch (device->vendor_id) { #ifndef GGML_VULKAN_RUN_TESTS - case VK_VENDOR_ID_AMD: - case VK_VENDOR_ID_INTEL: - device->mul_mat_l = false; - device->mul_mat_m = true; - device->mul_mat_s = true; - device->mul_mat_id_l = false; - device->mul_mat_id_m = true; - device->mul_mat_id_s = true; - break; - case VK_VENDOR_ID_APPLE: - device->mul_mat_l = false; - device->mul_mat_m = true; - device->mul_mat_s = false; - device->mul_mat_id_l = false; - device->mul_mat_id_m = true; - device->mul_mat_id_s = false; - break; + case VK_VENDOR_ID_AMD: + case VK_VENDOR_ID_INTEL: + device->mul_mat_l[i] = false; + device->mul_mat_m[i] = true; + device->mul_mat_s[i] = true; + device->mul_mat_id_l[i] = false; + device->mul_mat_id_m[i] = true; + device->mul_mat_id_s[i] = true; + break; + case VK_VENDOR_ID_APPLE: + device->mul_mat_l[i] = false; + device->mul_mat_m[i] = true; + device->mul_mat_s[i] = false; + device->mul_mat_id_l[i] = false; + device->mul_mat_id_m[i] = true; + device->mul_mat_id_s[i] = false; + break; #endif - default: - device->mul_mat_l = true; - device->mul_mat_m = true; - device->mul_mat_s = true; - device->mul_mat_id_l = true; - device->mul_mat_id_m = true; - device->mul_mat_id_s = true; - break; + default: + device->mul_mat_l[i] = true; + device->mul_mat_m[i] = true; + device->mul_mat_s[i] = true; + device->mul_mat_id_l[i] = true; + device->mul_mat_id_m[i] = true; + device->mul_mat_id_s[i] = true; + break; + } } ggml_vk_load_shaders(device); @@ -3756,31 +3751,31 @@ static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, int m, int return split_k; } -static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned) { - VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ")"); +static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned, ggml_type src0_type) { + VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")"); if (ctx->device->coopmat2) { - if ((ctx->device->mul_mat_l && (m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_m && !ctx->device->mul_mat_s)) { + if ((ctx->device->mul_mat_l[src0_type] && (m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_s[src0_type])) { return aligned ? mmp->a_l : mmp->l; } - if ((ctx->device->mul_mat_m && (m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) || !ctx->device->mul_mat_s) { + if ((ctx->device->mul_mat_m[src0_type] && (m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) || !ctx->device->mul_mat_s[src0_type]) { return aligned ? mmp->a_m : mmp->m; } return aligned ? mmp->a_s : mmp->s; } - if ((ctx->device->mul_mat_s && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_m && !ctx->device->mul_mat_l)) { + if ((ctx->device->mul_mat_s[src0_type] && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_l[src0_type])) { return aligned ? mmp->a_s : mmp->s; } - if ((ctx->device->mul_mat_m && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_l) { + if ((ctx->device->mul_mat_m[src0_type] && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_l[src0_type]) { return aligned ? mmp->a_m : mmp->m; } return aligned ? mmp->a_l : mmp->l; } -static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n) { - VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ")"); - return ggml_vk_guess_matmul_pipeline(ctx, mmp, m, n, true)->align; +static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type) { + VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ", " << ggml_type_name(src0_type) << ")"); + return ggml_vk_guess_matmul_pipeline(ctx, mmp, m, n, true, src0_type)->align; } static void ggml_vk_matmul( @@ -3807,31 +3802,31 @@ static void ggml_vk_matmul( ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2.size() * sizeof(uint32_t), pc2.data(), { m * n * batch, 1, 1 }); } -static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned) { - VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ")"); +static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned, ggml_type src0_type) { + VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")"); if (ctx->device->coopmat2) { - if ((ctx->device->mul_mat_id_l && (m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_id_m && !ctx->device->mul_mat_id_s)) { + if ((ctx->device->mul_mat_id_l[src0_type] && (m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_s[src0_type])) { return aligned ? mmp->a_l : mmp->l; } - if ((ctx->device->mul_mat_id_m && (m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) || !ctx->device->mul_mat_id_s) { + if ((ctx->device->mul_mat_id_m[src0_type] && (m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) || !ctx->device->mul_mat_id_s[src0_type]) { return aligned ? mmp->a_m : mmp->m; } return aligned ? mmp->a_s : mmp->s; } - if ((ctx->device->mul_mat_id_s && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_id_m && !ctx->device->mul_mat_id_l)) { + if ((ctx->device->mul_mat_id_s[src0_type] && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_l[src0_type])) { return aligned ? mmp->a_s : mmp->s; } - if ((ctx->device->mul_mat_id_m && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_id_l) { + if ((ctx->device->mul_mat_id_m[src0_type] && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_id_l[src0_type]) { return aligned ? mmp->a_m : mmp->m; } return aligned ? mmp->a_l : mmp->l; } -static uint32_t ggml_vk_guess_matmul_id_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n) { - VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ")"); - return ggml_vk_guess_matmul_id_pipeline(ctx, mmp, m, n, true)->align; +static uint32_t ggml_vk_guess_matmul_id_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type) { + VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ", " << ggml_type_name(src0_type) << ")"); + return ggml_vk_guess_matmul_id_pipeline(ctx, mmp, m, n, true, src0_type)->align; } static void ggml_vk_matmul_id( @@ -4012,10 +4007,10 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub const int y_ne = ne11 * ne10; const int d_ne = ne11 * ne01; - const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11)); + const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? GGML_TYPE_F16 : src0->type)); const bool aligned = ne10 == kpad && ne01 > 8 && ne11 > 8; - vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned); + vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? GGML_TYPE_F16 : src0->type); const uint32_t split_k = ggml_vk_guess_split_k(ctx, ne01, ne11, ne10, pipeline); @@ -4594,10 +4589,10 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& const uint64_t y_ne = ne11 * ne10; const uint64_t d_ne = ne21 * ne20; - const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1)); + const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? GGML_TYPE_F16 : src0->type)); const bool aligned = ne10 == kpad && ne01 > 8 && nei1 > 8; - vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned); + vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? GGML_TYPE_F16 : src0->type); const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type); const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type); @@ -8036,13 +8031,14 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: { + ggml_type src0_type = op->src[0]->type; ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; const vk_device& device = ggml_vk_get_device(ctx->device); - if (op->op == GGML_OP_MUL_MAT_ID && !device->mul_mat_id_s && !device->mul_mat_id_m && !device->mul_mat_id_l) { + if (op->op == GGML_OP_MUL_MAT_ID && !device->mul_mat_id_s[src0_type] && !device->mul_mat_id_m[src0_type] && !device->mul_mat_id_l[src0_type]) { // If there's not enough shared memory for row_ids and the result tile, fallback to CPU return false; } - switch (op->src[0]->type) { + switch (src0_type) { case GGML_TYPE_F32: case GGML_TYPE_F16: case GGML_TYPE_Q4_0: From 7da8fa6ffafea7bc4dbcd5bdd4ce85800e4b881c Mon Sep 17 00:00:00 2001 From: Wagner Bruna Date: Mon, 10 Feb 2025 03:08:22 -0300 Subject: [PATCH 19/58] vulkan: add environment variable GGML_VK_PREFER_HOST_MEMORY to avoid VRAM allocation (llama/11592) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index d32ba4efbc9..512d3341ec0 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -167,6 +167,7 @@ struct vk_device_struct { uint32_t subgroup_size; uint32_t shader_core_count; bool uma; + bool prefer_host_memory; bool float_controls_rte_fp16; bool subgroup_size_control; @@ -1294,7 +1295,9 @@ static vk_buffer ggml_vk_create_buffer_check(vk_device& device, size_t size, vk: static vk_buffer ggml_vk_create_buffer_device(vk_device& device, size_t size) { vk_buffer buf; try { - if (device->uma) { + if (device->prefer_host_memory) { + buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent, vk::MemoryPropertyFlagBits::eDeviceLocal); + } else if (device->uma) { // Fall back to host memory type buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eDeviceLocal, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent); } else { @@ -2199,6 +2202,9 @@ static vk_device ggml_vk_get_device(size_t idx) { device->physical_device = physical_devices[dev_num]; const std::vector ext_props = device->physical_device.enumerateDeviceExtensionProperties(); + const char* GGML_VK_PREFER_HOST_MEMORY = getenv("GGML_VK_PREFER_HOST_MEMORY"); + device->prefer_host_memory = GGML_VK_PREFER_HOST_MEMORY != nullptr; + bool fp16_storage = false; bool fp16_compute = false; bool maintenance4_support = false; From 94eaf9741878d14399c674a60916587eaee64c68 Mon Sep 17 00:00:00 2001 From: Danny Milosavljevic Date: Mon, 10 Feb 2025 07:17:21 +0100 Subject: [PATCH 20/58] vulkan: Make Vulkan optional at runtime (ggml/11493). (llama/11494) Co-authored-by: Jeff Bolz --- ggml/include/ggml-vulkan.h | 2 -- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 16 ++++++++++------ 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/ggml/include/ggml-vulkan.h b/ggml/include/ggml-vulkan.h index 53cdba072c2..ed5ea5f798c 100644 --- a/ggml/include/ggml-vulkan.h +++ b/ggml/include/ggml-vulkan.h @@ -10,8 +10,6 @@ extern "C" { #define GGML_VK_NAME "Vulkan" #define GGML_VK_MAX_DEVICES 16 -GGML_BACKEND_API void ggml_vk_instance_init(void); - // backend API GGML_BACKEND_API ggml_backend_t ggml_backend_vk_init(size_t dev_num); diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 512d3341ec0..bffe95086af 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -2793,14 +2793,12 @@ static void ggml_vk_print_gpu_info(size_t idx) { static bool ggml_vk_instance_validation_ext_available(const std::vector& instance_extensions); static bool ggml_vk_instance_portability_enumeration_ext_available(const std::vector& instance_extensions); -void ggml_vk_instance_init() { +static void ggml_vk_instance_init() { if (vk_instance_initialized) { return; } VK_LOG_DEBUG("ggml_vk_instance_init()"); - vk_instance_initialized = true; - uint32_t api_version = vk::enumerateInstanceVersion(); if (api_version < VK_API_VERSION_1_2) { @@ -2851,6 +2849,7 @@ void ggml_vk_instance_init() { GGML_LOG_DEBUG("ggml_vulkan: Validation layers enabled\n"); } vk_instance.instance = vk::createInstance(instance_create_info); + vk_instance_initialized = true; size_t num_available_devices = vk_instance.instance.enumeratePhysicalDevices().size(); @@ -2875,7 +2874,7 @@ void ggml_vk_instance_init() { // Make sure at least one device exists if (devices.empty()) { std::cerr << "ggml_vulkan: Error: No devices found." << std::endl; - GGML_ABORT("fatal error"); + return; } // Default to using all dedicated GPUs @@ -8350,8 +8349,13 @@ ggml_backend_reg_t ggml_backend_vk_reg() { /* .iface = */ ggml_backend_vk_reg_i, /* .context = */ nullptr, }; - - return ® + try { + ggml_vk_instance_init(); + return ® + } catch (const vk::SystemError& e) { + VK_LOG_DEBUG("ggml_backend_vk_reg() -> Error: System error: " << e.what()); + return nullptr; + } } // Extension availability From 06db520a1eb4779c2a2f46ee1c5a4b55204215de Mon Sep 17 00:00:00 2001 From: Maxim Evtush <154841002+maximevtush@users.noreply.github.com> Date: Mon, 10 Feb 2025 23:21:31 +0100 Subject: [PATCH 21/58] fix: typos in documentation files (llama/11791) * Update ggml.c * Update arg.cpp * Update speculative.h --- ggml/src/ggml.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 3b486154211..e9f3420c294 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -1379,7 +1379,7 @@ bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tenso (t0->nb[3] == t1->nb[3]); } -// check if t1 can be represented as a repeatition of t0 +// check if t1 can be represented as a repetition of t0 bool ggml_can_repeat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) { static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); From 8903621cf26ba5e28f3696954579c20950d13959 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Tue, 11 Feb 2025 00:17:22 +0100 Subject: [PATCH 22/58] CUDA: use arch list for compatibility check (llama/11775) * CUDA: use arch list for feature availability check --------- Co-authored-by: Diego Devesa --- ggml/src/ggml-common.h | 2 - ggml/src/ggml-cuda/common.cuh | 65 ++++++++++++++++++++++++++++++--- ggml/src/ggml-cuda/convert.cu | 2 +- ggml/src/ggml-cuda/ggml-cuda.cu | 12 +++--- ggml/src/ggml-cuda/mmq.cu | 9 +++-- ggml/src/ggml-cuda/mmq.cuh | 14 ++++--- 6 files changed, 80 insertions(+), 24 deletions(-) diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index f13fd4dea6f..6c02b69ea23 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -473,7 +473,6 @@ GGML_TABLE_BEGIN(uint8_t, ksigns_iq2xs, 128) 240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255, GGML_TABLE_END() -//#if __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A // lowest compute capability for integer intrinsics GGML_TABLE_BEGIN(uint64_t, ksigns64, 128) 0x0000000000000000, 0xff000000000000ff, 0xff0000000000ff00, 0x000000000000ffff, 0xff00000000ff0000, 0x0000000000ff00ff, 0x0000000000ffff00, 0xff00000000ffffff, @@ -508,7 +507,6 @@ GGML_TABLE_BEGIN(uint64_t, ksigns64, 128) 0x00ffffffff000000, 0xffffffffff0000ff, 0xffffffffff00ff00, 0x00ffffffff00ffff, 0xffffffffffff0000, 0x00ffffffffff00ff, 0x00ffffffffffff00, 0xffffffffffffffff, GGML_TABLE_END() -//#endif GGML_TABLE_BEGIN(uint64_t, iq2xxs_grid, 256) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 174916bc970..2a324442852 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -71,6 +71,47 @@ #define GGML_CUDA_CC_QY1 210 #define GGML_CUDA_CC_QY2 220 +#ifdef __CUDA_ARCH_LIST__ +constexpr bool ggml_cuda_has_arch_impl(int) { + return false; +} + +template +constexpr bool ggml_cuda_has_arch_impl(const int arch, const int first, Archs... rest) { + return arch == first || ggml_cuda_has_arch_impl(arch, rest...); +} + +constexpr bool ggml_cuda_has_arch(const int arch) { + return ggml_cuda_has_arch_impl(arch, __CUDA_ARCH_LIST__); +} + +constexpr int ggml_cuda_highest_compiled_arch_impl(const int arch, const int cur) { + if (cur == 0) { + GGML_ABORT("ggml was not compiled with any CUDA arch <= %d", arch); + } + return cur; +} + +template +constexpr int ggml_cuda_highest_compiled_arch_impl(const int arch, const int cur, const int first, Archs... rest) { + if (first <= arch && first > cur) { + return ggml_cuda_highest_compiled_arch_impl(arch, first, rest...); + } else { + return ggml_cuda_highest_compiled_arch_impl(arch, cur, rest...); + } +} + +constexpr int ggml_cuda_highest_compiled_arch(const int arch) { + return ggml_cuda_highest_compiled_arch_impl(arch, 0, __CUDA_ARCH_LIST__); +} +#else +static int ggml_cuda_highest_compiled_arch(const int arch) { + return arch; +} +#endif // __CUDA_ARCH_LIST__ + +// --------------------------------------------------------------------------------------------------------- + #define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses #if defined(_MSC_VER) @@ -162,18 +203,32 @@ typedef float2 dfloat2; #define FLASH_ATTN_AVAILABLE #endif // !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1) -static constexpr bool fast_fp16_available(const int cc) { +static bool fp16_available(const int cc) { + return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL; +} + +static bool fast_fp16_available(const int cc) { + return fp16_available(cc) && cc != 610; +} + +// To be used for feature selection of external libraries, e.g. cuBLAS. +static bool fast_fp16_hardware_available(const int cc) { return cc >= GGML_CUDA_CC_PASCAL && cc != 610; } -// Any FP16 tensor cores are available. -static constexpr bool fp16_mma_available(const int cc) { +// Any FP16 tensor core instructions are available for ggml code. +static bool fp16_mma_available(const int cc) { + return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA; +} + +// To be used for feature selection of external libraries, e.g. cuBLAS. +static bool fp16_mma_hardware_available(const int cc) { return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_VOLTA; } // Volta technically had FP16 tensor cores but they work very differently compared to Turing and later. -static constexpr bool new_mma_available(const int cc) { - return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_TURING; +static bool new_mma_available(const int cc) { + return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING; } static constexpr __device__ int ggml_cuda_get_physical_warp_size() { diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu index 5b0dfacefc9..795b720d60b 100644 --- a/ggml/src/ggml-cuda/convert.cu +++ b/ggml/src/ggml-cuda/convert.cu @@ -599,7 +599,7 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) { case GGML_TYPE_Q5_1: return dequantize_block_cuda; case GGML_TYPE_Q8_0: - if (ggml_cuda_info().devices[ggml_cuda_get_device()].cc >= GGML_CUDA_CC_PASCAL) { + if (fp16_available(ggml_cuda_info().devices[ggml_cuda_get_device()].cc)) { return dequantize_block_q8_0_f16_cuda; } return dequantize_block_cuda; diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 4dbaefdbafd..c95728b08bf 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -1867,14 +1867,14 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor const int cc = ggml_cuda_info().devices[id].cc; use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]); - any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_available(cc); - any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_available(cc); + any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc); + any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_hardware_available(cc); } } else { const int cc = ggml_cuda_info().devices[ctx.device].cc; use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]); - any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_available(cc); - any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_available(cc); + any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc); + any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_hardware_available(cc); } // debug helpers @@ -3205,8 +3205,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g if (op->src[0]->ne[0] == 256 && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) { return true; } - const int cc = ggml_cuda_info().devices[dev_ctx->device].cc; - return cc >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16; + return fp16_mma_available(ggml_cuda_info().devices[dev_ctx->device].cc) && + op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16; } case GGML_OP_CROSS_ENTROPY_LOSS: case GGML_OP_CROSS_ENTROPY_LOSS_BACK: diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu index 45212f66c00..5dacd131ed5 100644 --- a/ggml/src/ggml-cuda/mmq.cu +++ b/ggml/src/ggml-cuda/mmq.cu @@ -18,7 +18,7 @@ void ggml_cuda_op_mul_mat_q( const int64_t stride00 = ne00 / ggml_blck_size(src0->type); int id = ggml_cuda_get_device(); - const int compute_capability = ggml_cuda_info().devices[id].cc; + const int cc = ggml_cuda_info().devices[id].cc; // the main device has a larger memory buffer to hold the results from all GPUs // nrows_dst == nrows of the matrix that the kernel writes into @@ -27,7 +27,8 @@ void ggml_cuda_op_mul_mat_q( // The stream-k decomposition is only faster for recent NVIDIA GPUs. // Also its fixup needs to allocate a temporary buffer in the memory pool. // There are multiple parallel CUDA streams for src1_ncols != ne11 which would introduce a race condition for this buffer. - const bool use_stream_k = compute_capability >= GGML_CUDA_CC_VOLTA && compute_capability < GGML_CUDA_CC_OFFSET_AMD && src1_ncols == ne11; + const bool use_stream_k = ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && + cc < GGML_CUDA_CC_OFFSET_AMD && src1_ncols == ne11; const mmq_args args = {src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stride00, src1_padded_row_size, src1_ncols, ne11, nrows_dst, use_stream_k}; switch (src0->type) { @@ -136,7 +137,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) { return true; } - if (cc < GGML_CUDA_CC_DP4A) { + if (ggml_cuda_highest_compiled_arch(cc) < GGML_CUDA_CC_DP4A) { return false; } @@ -145,7 +146,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) { #endif //GGML_CUDA_FORCE_MMQ if (cc < GGML_CUDA_CC_OFFSET_AMD) { - return cc < GGML_CUDA_CC_VOLTA || ne11 < MMQ_DP4A_MAX_BATCH_SIZE; + return !fp16_mma_hardware_available(cc) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE; } return (!GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc) && !GGML_CUDA_CC_IS_GCN(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE; diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 7a2c4d85b79..5391542086c 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -86,12 +86,13 @@ struct tile_x_sizes { int sc; }; -static constexpr int get_mmq_x_max_host(const int cc) { +static int get_mmq_x_max_host(const int cc) { return new_mma_available(cc) ? 128 : + ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD ? #ifdef GGML_CUDA_FORCE_MMQ - cc >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD ? 128 : 64; + 128 : 64; #else - cc >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD ? MMQ_DP4A_MAX_BATCH_SIZE : 64; + MMQ_DP4A_MAX_BATCH_SIZE : 64; #endif // GGML_CUDA_FORCE_MMQ } @@ -119,8 +120,9 @@ static constexpr __device__ int get_mmq_x_max_device() { #endif // NEW_MMA_AVAILABLE } -static constexpr int get_mmq_y_host(const int cc) { - return cc >= GGML_CUDA_CC_OFFSET_AMD ? (GGML_CUDA_CC_IS_RDNA1(cc) ? 64 : 128) : (cc >= GGML_CUDA_CC_VOLTA ? 128 : 64); +static int get_mmq_y_host(const int cc) { + return cc >= GGML_CUDA_CC_OFFSET_AMD ? (GGML_CUDA_CC_IS_RDNA1(cc) ? 64 : 128) : + (ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ? 128 : 64); } static constexpr __device__ int get_mmq_y_device() { @@ -2828,7 +2830,7 @@ void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cuda const int mmq_x_max = get_mmq_x_max_host(cc); const int mmq_y = get_mmq_y_host(cc); const int block_num_y = (args.ne01 + mmq_y - 1) / mmq_y; - const bool use_stream_k = cc >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD; + const bool use_stream_k = ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD; int mmq_x_best = 0; int nparts_best = INT_MAX; From 7ebc835a1f470ff51da8bb1d09843aeeac968901 Mon Sep 17 00:00:00 2001 From: Sheldon Robinson Date: Tue, 11 Feb 2025 10:55:45 -0500 Subject: [PATCH 23/58] Fix #11802: Compile bug - RegQueryValueExA changed to RegQueryValueEx (llama/11803) * Fix #11802: Compile bug - RegQueryValueExA changed to RegQueryValueEx * Fix #11802: PR #11803 - keep RegQueryValueExA, remove TEXT macro, description needs to be ANSI string --- ggml/src/ggml-cpu/ggml-cpu.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cpu/ggml-cpu.cpp b/ggml/src/ggml-cpu/ggml-cpu.cpp index 2ccb4b472d6..87d7ce530ce 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.cpp +++ b/ggml/src/ggml-cpu/ggml-cpu.cpp @@ -284,14 +284,14 @@ struct ggml_backend_cpu_device_context { &hKey) == ERROR_SUCCESS) { DWORD cpu_brand_size = 0; if (RegQueryValueExA(hKey, - TEXT("ProcessorNameString"), + "ProcessorNameString", NULL, NULL, NULL, &cpu_brand_size) == ERROR_SUCCESS) { description.resize(cpu_brand_size); if (RegQueryValueExA(hKey, - TEXT("ProcessorNameString"), + "ProcessorNameString", NULL, NULL, (LPBYTE)&description[0], // NOLINT From 523f00bd13cfc55365ba7582324b2159584eccd1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Wed, 12 Feb 2025 13:16:39 +0100 Subject: [PATCH 24/58] CUDA: fix CUDART_VERSION checks (llama/11821) --- ggml/src/ggml-cuda/common.cuh | 4 ++-- ggml/src/ggml-cuda/ggml-cuda.cu | 6 ++++-- ggml/src/ggml-cuda/sum.cu | 4 ++-- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 2a324442852..fd4dcfa941d 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -165,11 +165,11 @@ static const char * cu_get_error_str(CUresult err) { #define CU_CHECK(err) CUDA_CHECK_GEN(err, CUDA_SUCCESS, cu_get_error_str) #endif -#if CUDART_VERSION >= 11100 || defined(GGML_USE_MUSA) +#if CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA) #define GGML_CUDA_ASSUME(x) __builtin_assume(x) #else #define GGML_CUDA_ASSUME(x) -#endif // CUDART_VERSION >= 11100 +#endif // CUDART_VERSION >= 11010 #ifdef GGML_CUDA_F16 typedef half dfloat; // dequantize float diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index c95728b08bf..6d5d9aa5470 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2840,7 +2840,7 @@ bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size) { return false; } -#if CUDART_VERSION >= 11100 || defined(GGML_USE_MUSA) +#if CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA) cudaError_t err = cudaHostRegister(buffer, size, cudaHostRegisterPortable | cudaHostRegisterReadOnly); if (err != cudaSuccess) { // clear the error @@ -2852,8 +2852,10 @@ bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size) { } return true; #else + GGML_UNUSED(buffer); + GGML_UNUSED(size); return false; -#endif +#endif // CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA) } void ggml_backend_cuda_unregister_host_buffer(void * buffer) { diff --git a/ggml/src/ggml-cuda/sum.cu b/ggml/src/ggml-cuda/sum.cu index e0dafc1d204..f9589080a0c 100644 --- a/ggml/src/ggml-cuda/sum.cu +++ b/ggml/src/ggml-cuda/sum.cu @@ -1,6 +1,6 @@ -#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11700 +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070 #define USE_CUB -#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11700 +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070 #ifdef USE_CUB #include From 2686930ca4c59533e0262e185a60c45e7d6aee5f Mon Sep 17 00:00:00 2001 From: Weizhao Ouyang Date: Wed, 12 Feb 2025 20:22:58 +0800 Subject: [PATCH 25/58] ggml-cpu: Fix duplicate MATMUL_INT8 (llama/11817) Signed-off-by: Weizhao Ouyang --- ggml/src/ggml-cpu/ggml-cpu.cpp | 3 --- 1 file changed, 3 deletions(-) diff --git a/ggml/src/ggml-cpu/ggml-cpu.cpp b/ggml/src/ggml-cpu/ggml-cpu.cpp index 87d7ce530ce..be4eadcd021 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.cpp +++ b/ggml/src/ggml-cpu/ggml-cpu.cpp @@ -534,9 +534,6 @@ static ggml_backend_feature * ggml_backend_cpu_get_features(ggml_backend_reg_t r if (ggml_cpu_has_dotprod()) { features.push_back({ "DOTPROD", "1" }); } - if (ggml_cpu_has_matmul_int8()) { - features.push_back({ "MATMUL_INT8", "1" }); - } if (ggml_cpu_get_sve_cnt() > 0) { static std::string sve_cnt = std::to_string(ggml_cpu_get_sve_cnt()); features.push_back({ "SVE_CNT", sve_cnt.c_str() }); From 04063cbff21a03aa2453a8ae741b1dfb02950dc6 Mon Sep 17 00:00:00 2001 From: Richard Date: Wed, 12 Feb 2025 13:57:33 +0000 Subject: [PATCH 26/58] ggml : fix multi-threaded clamp_f32 (llama/11824) * Bug fix for clamp_f32 When using tensors larger than 1d clamp operation does not work due to the restriction of returning if ith is not 0. * Bug fix for clamp_f32 * Bug fix for clamp_f32 --- ggml/src/ggml-cpu/ggml-cpu.c | 4 ---- 1 file changed, 4 deletions(-) diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index fdb430a43cc..fcbb5c233f0 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -9074,10 +9074,6 @@ static void ggml_compute_forward_clamp_f32( const struct ggml_tensor * src0 = dst->src[0]; - if (params->ith != 0) { - return; - } - float min; float max; memcpy(&min, (float *) dst->op_params + 0, sizeof(float)); From e288e300838595b2b16ed2429afff72b42d63244 Mon Sep 17 00:00:00 2001 From: bandoti <141645996+bandoti@users.noreply.github.com> Date: Wed, 12 Feb 2025 10:06:53 -0400 Subject: [PATCH 27/58] cleanup: fix compile warnings associated with gnu_printf (llama/11811) --- ggml/include/ggml.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 5bd8d9c8b50..dd0c6a96eae 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -198,7 +198,7 @@ #ifndef __GNUC__ # define GGML_ATTRIBUTE_FORMAT(...) -#elif defined(__MINGW32__) +#elif defined(__MINGW32__) && !defined(__clang__) # define GGML_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__))) #else # define GGML_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__))) From 5590b2bbc499111b3107b481be8fc4cc1f0abdf1 Mon Sep 17 00:00:00 2001 From: uvos Date: Wed, 12 Feb 2025 17:25:03 +0100 Subject: [PATCH 28/58] HIP: Switch to std::vector in rocblas version check (llama/11820) --- ggml/src/ggml-cuda/ggml-cuda.cu | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 6d5d9aa5470..6ea41577768 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -178,11 +178,11 @@ static ggml_cuda_device_info ggml_cuda_init() { int major_version = 0; size_t version_length = 0; if (rocblas_get_version_string_size(&version_length) == rocblas_status_success) { - std::string version(version_length, '\0'); + std::vector version(version_length+1, '\0'); if (rocblas_get_version_string(version.data(), version.size()) == rocblas_status_success) { - version.resize(::strlen(version.c_str())); + version.resize(::strlen(version.data())); int parsed_value = 0; - if (std::from_chars(version.c_str(), version.c_str() + version.length(), parsed_value).ec == std::errc()) { + if (std::from_chars(version.data(), version.data() + version.size(), parsed_value).ec == std::errc()) { major_version = parsed_value; } } From 5b621837cb9e8825c18058739f2cf1630ac854e8 Mon Sep 17 00:00:00 2001 From: uvos Date: Wed, 12 Feb 2025 22:25:28 +0100 Subject: [PATCH 29/58] HIP: Remove GCN from list of devices that avoid MMQ (llama/11831) --- ggml/src/ggml-cuda/mmq.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu index 5dacd131ed5..10f2ebb1cb4 100644 --- a/ggml/src/ggml-cuda/mmq.cu +++ b/ggml/src/ggml-cuda/mmq.cu @@ -149,5 +149,5 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) { return !fp16_mma_hardware_available(cc) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE; } - return (!GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc) && !GGML_CUDA_CC_IS_GCN(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE; + return (!GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE; } From 4151a81836727253cfe2c1e7d455a84c7408512d Mon Sep 17 00:00:00 2001 From: Xuan-Son Nguyen Date: Thu, 13 Feb 2025 00:33:45 +0100 Subject: [PATCH 30/58] ggml : x2 speed for WASM by optimizing SIMD (llama/11453) * ggml : x2 speed for WASM by optimizing SIMD * fix bad merging * rm trailing spaces * rm redundant clamp * better quantize_row_q8_K Co-authored-by: camel-cdr * remove memset that causes buffer overflow Co-authored-by: camel-cdr --------- Co-authored-by: camel-cdr --- ggml/src/ggml-cpu/ggml-cpu-quants.c | 721 +++++++++++++++++++++++++++- 1 file changed, 704 insertions(+), 17 deletions(-) diff --git a/ggml/src/ggml-cpu/ggml-cpu-quants.c b/ggml/src/ggml-cpu/ggml-cpu-quants.c index 27ec1493565..1b4bd66e80c 100644 --- a/ggml/src/ggml-cpu/ggml-cpu-quants.c +++ b/ggml/src/ggml-cpu/ggml-cpu-quants.c @@ -742,7 +742,7 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k) y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3); } } -#elif defined(__wasm_simd128__) +#elif defined __wasm_simd128__ for (int i = 0; i < nb; i++) { v128_t srcv [8]; v128_t asrcv[8]; @@ -1030,7 +1030,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k) y[i].s = GGML_FP32_TO_FP16(d * vaddvq_s32(accv)); } -#elif defined(__wasm_simd128__) +#elif defined __wasm_simd128__ for (int i = 0; i < nb; i++) { v128_t srcv [8]; v128_t asrcv[8]; @@ -1644,7 +1644,87 @@ static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -1 //===================================== Q8_K ============================================== void quantize_row_q8_K(const float * restrict x, void * restrict y, int64_t k) { +#ifdef __wasm_simd128__ + assert(k % QK_K == 0); + const int64_t nb = k / QK_K; + block_q8_K * restrict yc = y; // Cast to proper type + + for (int i = 0; i < nb; i++) { + const float * x_block = x + i * QK_K; + + v128_t min_vec = wasm_v128_load(x_block); + v128_t max_vec = min_vec; + + for (int j = 4; j < QK_K; j += 4) { + v128_t x_vec = wasm_v128_load(x_block + j); + max_vec = wasm_f32x4_pmax(max_vec, x_vec); + min_vec = wasm_f32x4_pmin(min_vec, x_vec); + } + max_vec = wasm_f32x4_pmax(max_vec, wasm_i32x4_shuffle(max_vec, max_vec, 2, 3, 0, 1)); + max_vec = wasm_f32x4_pmax(max_vec, wasm_i32x4_shuffle(max_vec, max_vec, 1, 0, 3, 2)); + min_vec = wasm_f32x4_pmin(min_vec, wasm_i32x4_shuffle(min_vec, min_vec, 2, 3, 0, 1)); + min_vec = wasm_f32x4_pmin(min_vec, wasm_i32x4_shuffle(min_vec, min_vec, 1, 0, 3, 2)); + float max = wasm_f32x4_extract_lane(max_vec, 0); + float min = wasm_f32x4_extract_lane(min_vec, 0); + float amax = -min > max ? min : max; + + if (amax == 0.0f) { + yc[i].d = 0.0f; + const v128_t zero = wasm_i8x16_splat(0); + for (int j = 0; j < QK_K; j += 16) { + wasm_v128_store(yc[i].qs + j, zero); + } + continue; + } + + const float iscale = -127.0f / amax; + const v128_t scale_vec = wasm_f32x4_splat(iscale); + + // Process 16 elements per iteration + for (int j = 0, jb = 0; j < QK_K; j += 16, jb++) { + // Load and quantize 16 floats + v128_t x0 = wasm_v128_load(x_block + j); + v128_t x1 = wasm_v128_load(x_block + j + 4); + v128_t x2 = wasm_v128_load(x_block + j + 8); + v128_t x3 = wasm_v128_load(x_block + j + 12); + + v128_t q0 = wasm_f32x4_nearest(wasm_f32x4_mul(x0, scale_vec)); + v128_t q1 = wasm_f32x4_nearest(wasm_f32x4_mul(x1, scale_vec)); + v128_t q2 = wasm_f32x4_nearest(wasm_f32x4_mul(x2, scale_vec)); + v128_t q3 = wasm_f32x4_nearest(wasm_f32x4_mul(x3, scale_vec)); + + // Convert to i32 with saturation + v128_t i0 = wasm_i32x4_trunc_sat_f32x4(q0); + v128_t i1 = wasm_i32x4_trunc_sat_f32x4(q1); + v128_t i2 = wasm_i32x4_trunc_sat_f32x4(q2); + v128_t i3 = wasm_i32x4_trunc_sat_f32x4(q3); + + // Pack into 16 i8 values + v128_t i8 = wasm_i8x16_narrow_i16x8( + wasm_i16x8_narrow_i32x4(i0, i1), + wasm_i16x8_narrow_i32x4(i2, i3) + ); + wasm_v128_store(yc[i].qs + j, i8); + + // Calculate bsums using SIMD + v128_t sum16 = wasm_i16x8_add( + wasm_i16x8_extend_low_i8x16(i8), + wasm_i16x8_extend_high_i8x16(i8) + ); + v128_t sum32 = wasm_i32x4_add( + wasm_i32x4_extend_low_i16x8(sum16), + wasm_i32x4_extend_high_i16x8(sum16) + ); + sum32 = wasm_i32x4_add(sum32, wasm_i32x4_shuffle(sum32, sum32, 2, 3, 0, 1)); + sum32 = wasm_i32x4_add(sum32, wasm_i32x4_shuffle(sum32, sum32, 1, 0, 3, 2)); + yc[i].bsums[jb] = wasm_i32x4_extract_lane(sum32, 0); + } + + yc[i].d = 1.0f / iscale; + } +#else quantize_row_q8_K_ref(x, y, k); +#endif } //===================================== Dot products ================================= @@ -2002,6 +2082,94 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r } sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); +#elif defined __wasm_simd128__ + v128_t sumv = wasm_f32x4_splat(0.0f); + + const v128_t m4b = wasm_i8x16_splat(0x0F); + const v128_t s8b = wasm_i8x16_splat(0x8); + + for (; ib + 1 < nb; ib += 2) { + const block_q4_0 * restrict x0 = &x[ib]; + const block_q4_0 * restrict x1 = &x[ib + 1]; + const block_q8_0 * restrict y0 = &y[ib]; + const block_q8_0 * restrict y1 = &y[ib + 1]; + + // Load and process x0 + v128_t v0_0 = wasm_v128_load(x0->qs); + v128_t v0_0l = wasm_v128_and(v0_0, m4b); + v128_t v0_0h = wasm_u8x16_shr(v0_0, 4); + v128_t v0_0ls = wasm_i8x16_sub(v0_0l, s8b); + v128_t v0_0hs = wasm_i8x16_sub(v0_0h, s8b); + + // Load y0 vectors + v128_t y0_l = wasm_v128_load(y0->qs); + v128_t y0_h = wasm_v128_load(y0->qs + 16); + + // Extend to i16x8 and compute dot products + v128_t dx0l = wasm_i16x8_extend_low_i8x16(v0_0ls); + v128_t dx0h = wasm_i16x8_extend_high_i8x16(v0_0ls); + v128_t dx0hl = wasm_i16x8_extend_low_i8x16(v0_0hs); + v128_t dx0hh = wasm_i16x8_extend_high_i8x16(v0_0hs); + + v128_t dy0ll = wasm_i16x8_extend_low_i8x16(y0_l); + v128_t dy0lh = wasm_i16x8_extend_high_i8x16(y0_l); + v128_t dy0hl = wasm_i16x8_extend_low_i8x16(y0_h); + v128_t dy0hh = wasm_i16x8_extend_high_i8x16(y0_h); + + v128_t dp0 = wasm_i32x4_add( + wasm_i32x4_add( + wasm_i32x4_dot_i16x8(dx0l, dy0ll), + wasm_i32x4_dot_i16x8(dx0h, dy0lh) + ), + wasm_i32x4_add( + wasm_i32x4_dot_i16x8(dx0hl, dy0hl), + wasm_i32x4_dot_i16x8(dx0hh, dy0hh) + ) + ); + + // Load and process x1 + v128_t v0_1 = wasm_v128_load(x1->qs); + v128_t v0_1l = wasm_v128_and(v0_1, m4b); + v128_t v0_1h = wasm_u8x16_shr(v0_1, 4); + v128_t v0_1ls = wasm_i8x16_sub(v0_1l, s8b); + v128_t v0_1hs = wasm_i8x16_sub(v0_1h, s8b); + + // Load y1 vectors + v128_t y1_l = wasm_v128_load(y1->qs); + v128_t y1_h = wasm_v128_load(y1->qs + 16); + + // Extend to i16x8 and compute dot products + v128_t dx1l = wasm_i16x8_extend_low_i8x16(v0_1ls); + v128_t dx1h = wasm_i16x8_extend_high_i8x16(v0_1ls); + v128_t dx1hl = wasm_i16x8_extend_low_i8x16(v0_1hs); + v128_t dx1hh = wasm_i16x8_extend_high_i8x16(v0_1hs); + + v128_t dy1ll = wasm_i16x8_extend_low_i8x16(y1_l); + v128_t dy1lh = wasm_i16x8_extend_high_i8x16(y1_l); + v128_t dy1hl = wasm_i16x8_extend_low_i8x16(y1_h); + v128_t dy1hh = wasm_i16x8_extend_high_i8x16(y1_h); + + v128_t dp1 = wasm_i32x4_add( + wasm_i32x4_add( + wasm_i32x4_dot_i16x8(dx1l, dy1ll), + wasm_i32x4_dot_i16x8(dx1h, dy1lh) + ), + wasm_i32x4_add( + wasm_i32x4_dot_i16x8(dx1hl, dy1hl), + wasm_i32x4_dot_i16x8(dx1hh, dy1hh) + ) + ); + + // Accumulate results with scaling + float scale0 = GGML_FP16_TO_FP32(x0->d) * GGML_FP16_TO_FP32(y0->d); + float scale1 = GGML_FP16_TO_FP32(x1->d) * GGML_FP16_TO_FP32(y1->d); + + sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(dp0), wasm_f32x4_splat(scale0))); + sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(dp1), wasm_f32x4_splat(scale1))); + } + + sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) + + wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3); #elif defined(__AVX2__) // Initialize accumulator with zeros __m256 acc = _mm256_setzero_ps(); @@ -2688,10 +2856,10 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * r } sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); -#elif defined(__wasm_simd128__) +#elif defined __wasm_simd128__ v128_t sumv = wasm_f32x4_splat(0.0f); - uint32_t qh; + uint32_t qh_; uint64_t tmp[4]; // TODO: check if unrolling this is better @@ -2702,12 +2870,12 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * r const v128_t m4b = wasm_i8x16_splat(0x0F); // extract the 5th bit - memcpy(&qh, x0->qh, sizeof(qh)); + memcpy(&qh_, x0->qh, sizeof(qh_)); - tmp[0] = table_b2b_1[(qh >> 0) & 0xFF]; - tmp[1] = table_b2b_1[(qh >> 8) & 0xFF]; - tmp[2] = table_b2b_1[(qh >> 16) & 0xFF]; - tmp[3] = table_b2b_1[(qh >> 24) ]; + tmp[0] = table_b2b_1[(qh_ >> 0) & 0xFF]; + tmp[1] = table_b2b_1[(qh_ >> 8) & 0xFF]; + tmp[2] = table_b2b_1[(qh_ >> 16) & 0xFF]; + tmp[3] = table_b2b_1[(qh_ >> 24) ]; const v128_t qhl = wasm_v128_load(tmp + 0); const v128_t qhh = wasm_v128_load(tmp + 2); @@ -3049,12 +3217,12 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * r } sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs0 + summs1; -#elif defined(__wasm_simd128__) +#elif defined __wasm_simd128__ v128_t sumv = wasm_f32x4_splat(0.0f); float summs = 0.0f; - uint32_t qh; + uint32_t qh_; uint64_t tmp[4]; // TODO: check if unrolling this is better @@ -3067,12 +3235,12 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * r const v128_t m4b = wasm_i8x16_splat(0x0F); // extract the 5th bit - memcpy(&qh, x0->qh, sizeof(qh)); + memcpy(&qh_, x0->qh, sizeof(qh_)); - tmp[0] = table_b2b_0[(qh >> 0) & 0xFF]; - tmp[1] = table_b2b_0[(qh >> 8) & 0xFF]; - tmp[2] = table_b2b_0[(qh >> 16) & 0xFF]; - tmp[3] = table_b2b_0[(qh >> 24) ]; + tmp[0] = table_b2b_0[(qh_ >> 0) & 0xFF]; + tmp[1] = table_b2b_0[(qh_ >> 8) & 0xFF]; + tmp[2] = table_b2b_0[(qh_ >> 16) & 0xFF]; + tmp[3] = table_b2b_0[(qh_ >> 24) ]; const v128_t qhl = wasm_v128_load(tmp + 0); const v128_t qhh = wasm_v128_load(tmp + 2); @@ -3565,6 +3733,45 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r } sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); +#elif defined __wasm_simd128__ + v128_t sumv = wasm_f32x4_splat(0.0f); + + for (; ib < nb; ++ib) { + const block_q8_0 * restrict x0 = &x[ib]; + const block_q8_0 * restrict y0 = &y[ib]; + + const v128_t x0_0 = wasm_v128_load(x0->qs); + const v128_t x0_1 = wasm_v128_load(x0->qs + 16); + const v128_t y0_0 = wasm_v128_load(y0->qs); + const v128_t y0_1 = wasm_v128_load(y0->qs + 16); + + // Extend 8-bit to 16-bit + const v128_t x0_0l = wasm_i16x8_extend_low_i8x16(x0_0); + const v128_t x0_0h = wasm_i16x8_extend_high_i8x16(x0_0); + const v128_t x0_1l = wasm_i16x8_extend_low_i8x16(x0_1); + const v128_t x0_1h = wasm_i16x8_extend_high_i8x16(x0_1); + + const v128_t y0_0l = wasm_i16x8_extend_low_i8x16(y0_0); + const v128_t y0_0h = wasm_i16x8_extend_high_i8x16(y0_0); + const v128_t y0_1l = wasm_i16x8_extend_low_i8x16(y0_1); + const v128_t y0_1h = wasm_i16x8_extend_high_i8x16(y0_1); + + // Compute dot products + const v128_t dx0_0 = wasm_i32x4_dot_i16x8(x0_0l, y0_0l); + const v128_t dx0_1 = wasm_i32x4_dot_i16x8(x0_0h, y0_0h); + const v128_t dx1_0 = wasm_i32x4_dot_i16x8(x0_1l, y0_1l); + const v128_t dx1_1 = wasm_i32x4_dot_i16x8(x0_1h, y0_1h); + + // Sum all dot products + const v128_t sum_dots = wasm_i32x4_add(wasm_i32x4_add(dx0_0, dx0_1), wasm_i32x4_add(dx1_0, dx1_1)); + + // Convert to float and accumulate + const float scale = GGML_FP16_TO_FP32(x0->d) * GGML_FP16_TO_FP32(y0->d); + sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(sum_dots), wasm_f32x4_splat(scale))); + } + + sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) + + wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3); #elif defined(__AVX2__) // Initialize accumulator with zeros __m256 acc = _mm256_setzero_ps(); @@ -4439,6 +4646,106 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r *s = hsum_float_8(acc); +#elif defined __wasm_simd128__ + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + const uint8_t * q2 = x[i].qs; + const int8_t * q8 = y[i].qs; + const uint8_t * sc = x[i].scales; + + // Vectorized summs calculation + v128_t summs_vec = wasm_i32x4_splat(0); + { + v128_t sc_vec = wasm_v128_load(sc); + v128_t sc_upper = wasm_u8x16_shr(sc_vec, 4); + + v128_t sc_low = wasm_u16x8_extend_low_u8x16(sc_upper); + v128_t sc_high = wasm_u16x8_extend_high_u8x16(sc_upper); + + v128_t bsums1 = wasm_v128_load(&y[i].bsums[0]); + v128_t bsums2 = wasm_v128_load(&y[i].bsums[8]); + + summs_vec = wasm_i32x4_add( + wasm_i32x4_add(wasm_i32x4_dot_i16x8(sc_low, bsums1), + wasm_i32x4_dot_i16x8(sc_high, bsums2)), + summs_vec + ); + + summs_vec = wasm_i32x4_add(summs_vec, wasm_i32x4_shuffle(summs_vec, summs_vec, 2, 3, 0, 1)); + summs_vec = wasm_i32x4_add(summs_vec, wasm_i32x4_shuffle(summs_vec, summs_vec, 1, 0, 3, 2)); + } + int32_t summs = wasm_i32x4_extract_lane(summs_vec, 0); + + // Vectorized isum calculation + int32_t isum = 0; + const uint8_t * sc_ptr = sc; + const int k_iters = QK_K/128; + + for (int k = 0; k < k_iters; ++k) { + v128_t isum_vec = wasm_i32x4_splat(0); + int shift = 0; + + for (int j = 0; j < 4; ++j) { + const int d0 = (sc_ptr[0] & 0xF); + const int d1 = (sc_ptr[1] & 0xF); + sc_ptr += 2; + + // Process first 16 elements + v128_t q2_0 = wasm_v128_load(q2); + v128_t q8_0 = wasm_v128_load(q8); + v128_t q2_shift_0 = wasm_u8x16_shr(q2_0, shift); + v128_t q2_bits_0 = wasm_v128_and(q2_shift_0, wasm_i8x16_splat(0x03)); + + // Process next 16 elements + v128_t q2_1 = wasm_v128_load(q2 + 16); + v128_t q8_1 = wasm_v128_load(q8 + 16); + v128_t q2_shift_1 = wasm_u8x16_shr(q2_1, shift); + v128_t q2_bits_1 = wasm_v128_and(q2_shift_1, wasm_i8x16_splat(0x03)); + + // Calculate dot products + v128_t p0 = wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_low_i8x16(q8_0), + wasm_i16x8_extend_low_i8x16(q2_bits_0) + ); + v128_t p1 = wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_high_i8x16(q8_0), + wasm_i16x8_extend_high_i8x16(q2_bits_0) + ); + v128_t p2 = wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_low_i8x16(q8_1), + wasm_i16x8_extend_low_i8x16(q2_bits_1) + ); + v128_t p3 = wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_high_i8x16(q8_1), + wasm_i16x8_extend_high_i8x16(q2_bits_1) + ); + + // Accumulate scaled results + v128_t scaled = wasm_i32x4_add( + wasm_i32x4_mul(wasm_i32x4_add(p0, p1), wasm_i32x4_splat(d0)), + wasm_i32x4_mul(wasm_i32x4_add(p2, p3), wasm_i32x4_splat(d1)) + ); + + isum_vec = wasm_i32x4_add(isum_vec, scaled); + q8 += 32; + shift += 2; + } + q2 += 32; + + // Horizontal sum of isum_vec + isum_vec = wasm_i32x4_add(isum_vec, wasm_i32x4_shuffle(isum_vec, isum_vec, 2, 3, 0, 1)); + isum_vec = wasm_i32x4_add(isum_vec, wasm_i32x4_shuffle(isum_vec, isum_vec, 1, 0, 3, 2)); + isum += wasm_i32x4_extract_lane(isum_vec, 0); + } + + const float dall = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d; + sumf += dall * isum - dmin * summs; + } + + *s = sumf; + #elif defined __riscv_v_intrinsic float sumf = 0; @@ -5121,6 +5428,94 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r *s = hsum_float_8(acc); +#elif defined __wasm_simd128__ + int8_t aux8[QK_K]; + float sums[8] = {0}; + uint32_t auxs[4]; + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * restrict q3 = x[i].qs; + const uint8_t * restrict hm = x[i].hmask; + const int8_t * restrict q8 = y[i].qs; + + // Process blocks with SIMD + int8_t * a = aux8; + uint8_t m = 1; + for (int j = 0; j < QK_K; j += 128) { + for (int shift = 0; shift <= 6; shift += 2) { + v128_t v_m = wasm_i8x16_splat(m); + for (int l = 0; l < 32; l += 16) { + v128_t v_q3 = wasm_v128_load(q3 + l); + v128_t v_shift = wasm_i8x16_shr(v_q3, shift); + v128_t v_low2 = wasm_v128_and(v_shift, wasm_i8x16_splat(0x03)); + + v128_t v_hm = wasm_v128_load(hm + l); + v128_t v_mask = wasm_v128_and(v_hm, v_m); + v_mask = wasm_i8x16_ne(v_mask, wasm_i8x16_splat(0)); + + v_low2 = wasm_i8x16_sub(v_low2, wasm_v128_and(wasm_i8x16_splat(4), wasm_v128_not(v_mask))); + wasm_v128_store(a + l, v_low2); + } + a += 32; + m <<= 1; + } + q3 += 32; + } + + // Extract scales + memcpy(auxs, x[i].scales, 12); + uint32_t tmp = auxs[2]; + auxs[2] = ((auxs[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4); + auxs[3] = ((auxs[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4); + auxs[0] = (auxs[0] & kmask2) | (((tmp >> 0) & kmask1) << 4); + auxs[1] = (auxs[1] & kmask2) | (((tmp >> 2) & kmask1) << 4); + const int8_t * scales = (const int8_t *)auxs; + + // SIMD dot product with register accumulators + v128_t v_acc0 = wasm_i32x4_splat(0); + v128_t v_acc1 = wasm_i32x4_splat(0); + a = aux8; + for (int j = 0; j < QK_K/16; ++j) { + const v128_t v_scale = wasm_i16x8_splat(scales[j] - 32); + + // Process 16 elements per iteration + for (int k = 0; k < 2; ++k) { + const v128_t v_q8 = wasm_i16x8_load8x8(q8); + const v128_t v_a = wasm_i16x8_load8x8(a); + + v128_t v_prod = wasm_i16x8_mul(v_q8, v_a); + v_prod = wasm_i16x8_mul(v_prod, v_scale); + + v_acc0 = wasm_i32x4_add(v_acc0, wasm_i32x4_extend_low_i16x8(v_prod)); + v_acc1 = wasm_i32x4_add(v_acc1, wasm_i32x4_extend_high_i16x8(v_prod)); + + q8 += 8; + a += 8; + } + } + + // Accumulate results + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const v128_t v_d = wasm_f32x4_splat(d); + v128_t v_sum = wasm_f32x4_add( + wasm_f32x4_mul(wasm_f32x4_convert_i32x4(v_acc0), v_d), + wasm_f32x4_mul(wasm_f32x4_convert_i32x4(v_acc1), v_d) + ); + + // Accumulate into sums vector + wasm_v128_store(sums, wasm_f32x4_add(wasm_v128_load(sums), v_sum)); + } + + // Horizontal sum + v128_t v_sum = wasm_f32x4_add(wasm_v128_load(sums), wasm_v128_load(sums + 4)); + sumf = wasm_f32x4_extract_lane(v_sum, 0) + + wasm_f32x4_extract_lane(v_sum, 1) + + wasm_f32x4_extract_lane(v_sum, 2) + + wasm_f32x4_extract_lane(v_sum, 3); + + *s = sumf; + #elif defined __riscv_v_intrinsic uint32_t aux[3]; @@ -5646,7 +6041,7 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r } } *s = sumf; -#elif __ARM_NEON +#elif defined __ARM_NEON const uint8x16_t m4b = vdupq_n_u8(0xf); const int32x4_t mzero = vdupq_n_s32(0); @@ -5709,6 +6104,107 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r *s = sumf; +#elif defined __wasm_simd128__ + const uint8_t * scales = (const uint8_t*)&utmp[0]; + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); // Corrected sign + + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + // Process scales and mins + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + // Sum mins * q8sums + int32_t sumi = 0; + const int16_t * restrict q8sums = y[i].bsums; + const uint8_t * m = (const uint8_t *)&utmp[2]; + for (int j = 0; j < 16; j += 2) { + sumi += (q8sums[j] + q8sums[j+1]) * m[j/2]; + } + sumf -= dmin * sumi; + + int32_t sumi1 = 0; + int32_t sumi2 = 0; + + for (int j = 0; j < QK_K/64; ++j) { + // Load 64 4-bit weights (32 bytes) + const v128_t q4x0 = wasm_v128_load(q4); + const v128_t q4x1 = wasm_v128_load(q4 + 16); + q4 += 32; + + // Split into low/high nibbles + const v128_t q4l0 = wasm_v128_and(q4x0, wasm_i8x16_splat(0x0F)); + const v128_t q4h0 = wasm_u8x16_shr(q4x0, 4); + const v128_t q4l1 = wasm_v128_and(q4x1, wasm_i8x16_splat(0x0F)); + const v128_t q4h1 = wasm_u8x16_shr(q4x1, 4); + + // Load 64 8-bit values (64 bytes) + const v128_t q8x0 = wasm_v128_load(q8); + const v128_t q8x1 = wasm_v128_load(q8 + 16); + const v128_t q8x2 = wasm_v128_load(q8 + 32); + const v128_t q8x3 = wasm_v128_load(q8 + 48); + q8 += 64; + + // Low nibble products + v128_t vacc1 = wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_low_i8x16(q4l0), + wasm_i16x8_extend_low_i8x16(q8x0) + ); + vacc1 = wasm_i32x4_add(vacc1, wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_high_i8x16(q4l0), + wasm_i16x8_extend_high_i8x16(q8x0) + )); + vacc1 = wasm_i32x4_add(vacc1, wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_low_i8x16(q4l1), + wasm_i16x8_extend_low_i8x16(q8x1) + )); + vacc1 = wasm_i32x4_add(vacc1, wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_high_i8x16(q4l1), + wasm_i16x8_extend_high_i8x16(q8x1) + )); + + // High nibble products + v128_t vacc2 = wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_low_i8x16(q4h0), + wasm_i16x8_extend_low_i8x16(q8x2) + ); + vacc2 = wasm_i32x4_add(vacc2, wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_high_i8x16(q4h0), + wasm_i16x8_extend_high_i8x16(q8x2) + )); + vacc2 = wasm_i32x4_add(vacc2, wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_low_i8x16(q4h1), + wasm_i16x8_extend_low_i8x16(q8x3) + )); + vacc2 = wasm_i32x4_add(vacc2, wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_high_i8x16(q4h1), + wasm_i16x8_extend_high_i8x16(q8x3) + )); + + // Accumulate scaled results + int32_t vacc1_sum = wasm_i32x4_extract_lane(vacc1, 0) + wasm_i32x4_extract_lane(vacc1, 1) + + wasm_i32x4_extract_lane(vacc1, 2) + wasm_i32x4_extract_lane(vacc1, 3); + sumi1 += vacc1_sum * scales[2*j]; + + int32_t vacc2_sum = wasm_i32x4_extract_lane(vacc2, 0) + wasm_i32x4_extract_lane(vacc2, 1) + + wasm_i32x4_extract_lane(vacc2, 2) + wasm_i32x4_extract_lane(vacc2, 3); + sumi2 += vacc2_sum * scales[2*j+1]; + } + + sumf += d * (sumi1 + sumi2); + } + + *s = sumf; + #elif defined __AVX2__ const __m256i m4 = _mm256_set1_epi8(0xF); @@ -6459,6 +6955,118 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r *s = hsum_float_8(acc) + summs; +#elif defined __wasm_simd128__ + //const uint8_t * scales = (const uint8_t*)&utmp[0]; + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); // Fixed sign + + const uint8_t * restrict q5 = x[i].qs; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + + // Process scales and mins + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + // Sum mins * q8sums + int32_t sumi_mins = 0; + const int16_t * restrict q8sums = y[i].bsums; + const uint8_t * m = (const uint8_t *)&utmp[2]; + for (int j = 0; j < 16; j += 2) { + sumi_mins += (q8sums[j] + q8sums[j+1]) * m[j/2]; + } + sumf -= dmin * sumi_mins; // Correct subtraction + + v128_t qh0 = wasm_v128_load(qh); + v128_t qh1 = wasm_v128_load(qh + 16); + const uint8_t * sc = (const uint8_t *)utmp; + + int32_t sumi = 0; + + for (int j = 0; j < QK_K/64; ++j) { + const int shift = j * 2; + v128_t qh_shift0 = wasm_u8x16_shr(qh0, shift); + v128_t qh_shift1 = wasm_u8x16_shr(qh1, shift); + + v128_t qh_low0 = wasm_i8x16_shl(wasm_v128_and(qh_shift0, wasm_i8x16_splat(0x01)), 4); + v128_t qh_high0 = wasm_i8x16_shl(wasm_v128_and(qh_shift0, wasm_i8x16_splat(0x02)), 3); + v128_t qh_low1 = wasm_i8x16_shl(wasm_v128_and(qh_shift1, wasm_i8x16_splat(0x01)), 4); + v128_t qh_high1 = wasm_i8x16_shl(wasm_v128_and(qh_shift1, wasm_i8x16_splat(0x02)), 3); + + v128_t q5_0 = wasm_v128_load(q5); + v128_t q5_1 = wasm_v128_load(q5 + 16); + q5 += 32; + + v128_t q5l_0 = wasm_v128_or(wasm_v128_and(q5_0, wasm_i8x16_splat(0x0F)), qh_low0); + v128_t q5h_0 = wasm_v128_or(wasm_u8x16_shr(q5_0, 4), qh_high0); + v128_t q5l_1 = wasm_v128_or(wasm_v128_and(q5_1, wasm_i8x16_splat(0x0F)), qh_low1); + v128_t q5h_1 = wasm_v128_or(wasm_u8x16_shr(q5_1, 4), qh_high1); + + v128_t q8_0 = wasm_v128_load(q8); + v128_t q8_1 = wasm_v128_load(q8 + 16); + v128_t q8_2 = wasm_v128_load(q8 + 32); + v128_t q8_3 = wasm_v128_load(q8 + 48); + q8 += 64; + + // Process low quants + v128_t pl0 = wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_low_i8x16(q5l_0), + wasm_i16x8_extend_low_i8x16(q8_0) + ); + pl0 = wasm_i32x4_add(pl0, wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_high_i8x16(q5l_0), + wasm_i16x8_extend_high_i8x16(q8_0) + )); + v128_t pl1 = wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_low_i8x16(q5l_1), + wasm_i16x8_extend_low_i8x16(q8_1) + ); + pl1 = wasm_i32x4_add(pl1, wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_high_i8x16(q5l_1), + wasm_i16x8_extend_high_i8x16(q8_1) + )); + v128_t sum_low = wasm_i32x4_add(pl0, pl1); + + // Process high quants + v128_t ph0 = wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_low_i8x16(q5h_0), + wasm_i16x8_extend_low_i8x16(q8_2) + ); + ph0 = wasm_i32x4_add(ph0, wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_high_i8x16(q5h_0), + wasm_i16x8_extend_high_i8x16(q8_2) + )); + v128_t ph1 = wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_low_i8x16(q5h_1), + wasm_i16x8_extend_low_i8x16(q8_3) + ); + ph1 = wasm_i32x4_add(ph1, wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_high_i8x16(q5h_1), + wasm_i16x8_extend_high_i8x16(q8_3) + )); + v128_t sum_high = wasm_i32x4_add(ph0, ph1); + + // Accumulate with scale factors + int32_t sl = wasm_i32x4_extract_lane(sum_low, 0) + wasm_i32x4_extract_lane(sum_low, 1) + + wasm_i32x4_extract_lane(sum_low, 2) + wasm_i32x4_extract_lane(sum_low, 3); + int32_t sh = wasm_i32x4_extract_lane(sum_high, 0) + wasm_i32x4_extract_lane(sum_high, 1) + + wasm_i32x4_extract_lane(sum_high, 2) + wasm_i32x4_extract_lane(sum_high, 3); + + sumi += sl * sc[2*j] + sh * sc[2*j+1]; + } + + sumf += d * sumi; + } + + *s = sumf; + #elif defined __riscv_v_intrinsic const uint8_t * scales = (const uint8_t*)&utmp[0]; @@ -7122,6 +7730,85 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r *s = hsum_float_8(acc); +#elif defined __wasm_simd128__ + int8_t aux8[QK_K] __attribute__((aligned(16))); + int32_t aux32[8] __attribute__((aligned(16))) = {0}; + float sums[8] __attribute__((aligned(16))) = {0}; + + for (int i = 0; i < nb; ++i) { + // Unpack 6-bit quantized data into aux8 (unchanged) + const uint8_t * restrict q4 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + int8_t * a = aux8; + for (int j = 0; j < QK_K; j += 128) { + for (int l = 0; l < 32; ++l) { + a[l + 0] = (int8_t)((q4[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; + a[l + 32] = (int8_t)((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; + a[l + 64] = (int8_t)((q4[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; + a[l + 96] = (int8_t)((q4[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; + } + a += 128; + q4 += 64; + qh += 32; + } + + const int8_t * restrict a_ptr = aux8; + const int8_t * restrict q8 = y[i].qs; + v128_t acc0 = wasm_i32x4_splat(0); + v128_t acc1 = wasm_i32x4_splat(0); + + for (int j = 0; j < QK_K/16; ++j) { + const int scale = x[i].scales[j]; + const v128_t vscale = wasm_i32x4_splat(scale); + + // Load 16 elements from a and q8 + const v128_t a_vec = wasm_v128_load(a_ptr); + const v128_t q8_vec = wasm_v128_load(q8); + + // Process low 8 elements + v128_t a_low = wasm_i16x8_extend_low_i8x16(a_vec); + v128_t q8_low = wasm_i16x8_extend_low_i8x16(q8_vec); + v128_t prod_low = wasm_i16x8_mul(a_low, q8_low); + v128_t prod_lo_lo = wasm_i32x4_extend_low_i16x8(prod_low); + v128_t prod_lo_hi = wasm_i32x4_extend_high_i16x8(prod_low); + + // Process high 8 elements + v128_t a_high = wasm_i16x8_extend_high_i8x16(a_vec); + v128_t q8_high = wasm_i16x8_extend_high_i8x16(q8_vec); + v128_t prod_high = wasm_i16x8_mul(a_high, q8_high); + v128_t prod_hi_lo = wasm_i32x4_extend_low_i16x8(prod_high); + v128_t prod_hi_hi = wasm_i32x4_extend_high_i16x8(prod_high); + + // Scale and accumulate + prod_lo_lo = wasm_i32x4_mul(prod_lo_lo, vscale); + prod_lo_hi = wasm_i32x4_mul(prod_lo_hi, vscale); + prod_hi_lo = wasm_i32x4_mul(prod_hi_lo, vscale); + prod_hi_hi = wasm_i32x4_mul(prod_hi_hi, vscale); + + acc0 = wasm_i32x4_add(acc0, wasm_i32x4_add(prod_lo_lo, prod_hi_lo)); + acc1 = wasm_i32x4_add(acc1, wasm_i32x4_add(prod_lo_hi, prod_hi_hi)); + + a_ptr += 16; + q8 += 16; + } + + // Store accumulated results + wasm_v128_store(&aux32[0], acc0); + wasm_v128_store(&aux32[4], acc1); + + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + for (int l = 0; l < 8; ++l) { + sums[l] += d * aux32[l]; + } + } + + // Sum final results + float sumf = 0; + for (int l = 0; l < 8; ++l) { + sumf += sums[l]; + } + *s = sumf; + #elif defined __riscv_v_intrinsic float sumf = 0; From 322965d87c1726ec24c5540991f84cb2266e76f0 Mon Sep 17 00:00:00 2001 From: Diego Devesa Date: Thu, 13 Feb 2025 01:02:38 +0100 Subject: [PATCH 31/58] ggml-cpu : add chunking support to mul_mat_id (llama/11666) * ggml-cpu : add chunking support to mul_mat_id * allocate chunk counter in wdata parallelize src1 quantization by column to allows parallelization even when there is only one row * disable for arm * cleanup * better way to disable for arm * fix uninitialized counter when using 1 thread only * revert test-backend-ops changes --- ggml/src/ggml-cpu/ggml-cpu.c | 269 ++++++++++++++++++++++++----------- 1 file changed, 184 insertions(+), 85 deletions(-) diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index fcbb5c233f0..0cbf8318bed 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -7,10 +7,8 @@ #include "ggml-cpu-impl.h" #include "ggml-cpu.h" #include "ggml-impl.h" -#include "ggml-quants.h" #include "ggml-cpu-quants.h" #include "ggml-threading.h" -#include "amx/amx.h" #include "ggml.h" #if defined(_MSC_VER) || defined(__MINGW32__) @@ -1291,7 +1289,7 @@ struct ggml_threadpool { atomic_int n_graph; // incremented when there is work to be done (i.e each graph) atomic_int GGML_CACHE_ALIGN n_barrier; atomic_int GGML_CACHE_ALIGN n_barrier_passed; - atomic_int current_chunk; // currently processing chunk during Mat_Mul, shared between all the threads. + atomic_int GGML_CACHE_ALIGN current_chunk; // currently processing chunk during Mat_Mul, shared between all the threads. // these are atomic as an annotation for thread-sanitizer atomic_bool stop; // Used for stopping the threadpool altogether @@ -7490,6 +7488,7 @@ UseGgmlGemm1:; if (src1->type != vec_dot_type) { char * wdata = params->wdata; + const size_t nbw0 = ggml_type_size(vec_dot_type); const size_t nbw1 = ggml_row_size(vec_dot_type, ne10); const size_t nbw2 = nbw1*ne11; const size_t nbw3 = nbw2*ne12; @@ -7497,6 +7496,7 @@ UseGgmlGemm1:; assert(params->wsize >= ne13*nbw3); GGML_ASSERT(src1->type == GGML_TYPE_F32); + #if 0 for (int64_t i13 = 0; i13 < ne13; ++i13) { for (int64_t i12 = 0; i12 < ne12; ++i12) { for (int64_t i11 = ith; i11 < ne11; i11 += nth) { @@ -7506,6 +7506,20 @@ UseGgmlGemm1:; } } } + #else + for (int64_t i13 = 0; i13 < ne13; ++i13) { + for (int64_t i12 = 0; i12 < ne12; ++i12) { + for (int64_t i11 = 0; i11 < ne11; ++i11) { + size_t bs = ggml_blck_size(vec_dot_type); + int64_t ne10_block_start = (ith * ne10/bs) / nth; + int64_t ne10_block_end = ((ith + 1) * ne10/bs) / nth; + from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + ne10_block_start*bs*nb10), + (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1 + ne10_block_start*nbw0), + (ne10_block_end - ne10_block_start) * bs); + } + } + } + #endif } if (ith == 0) { @@ -7593,7 +7607,6 @@ UseGgmlGemm2:; if ((nr0 % 2 != 0) || (ne11 % 2 != 0) || ((ir0_end - ir0_start) % 2 != 0) || ((ir1_end - ir1_start) % 2 != 0)) { num_rows_per_vec_dot = 1; } - ggml_compute_forward_mul_mat_one_chunk(params, dst, src0->type, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end); if (nth >= nchunk0 * nchunk1) { @@ -7606,6 +7619,84 @@ UseGgmlGemm2:; // ggml_compute_forward_mul_mat_id +#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ids->ne[0]*ids->ne[1] + (i1)] + +struct mmid_row_mapping { + int32_t i1; + int32_t i2; +}; + +static void ggml_compute_forward_mul_mat_id_one_chunk( + struct ggml_tensor * dst, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + const struct ggml_tensor * ids, + const int64_t cur_a, + const int64_t ir0_start, + const int64_t ir0_end, + const int64_t ir1_start, + const int64_t ir1_end, + const char * src0_cur, + const struct mmid_row_mapping * matrix_rows, + const size_t row_size, + const bool src1_cont, + const void * wdata) { + + GGML_TENSOR_BINARY_OP_LOCALS + + const enum ggml_type type = src0->type; + + ggml_vec_dot_t const vec_dot = type_traits_cpu[type].vec_dot; + enum ggml_type const vec_dot_type = type_traits_cpu[type].vec_dot_type; + + const int64_t blck_0 = 16; + const int64_t blck_1 = 16; + + float tmp[16]; + + for (int64_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) { + for (int64_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) { + for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir1_end; ++ir1) { + const int64_t _i12 = ir1; // logical row index for this expert + + struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, _i12); + const int id = row_mapping.i1; // selected expert index + + const int64_t i11 = id % ne11; + const int64_t i12 = row_mapping.i2; // row index in src1 + + const int64_t i1 = id; // selected expert index + const int64_t i2 = i12; // row + + // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides + // if it is, then we have either copied the data to params->wdata and made it contiguous or we are using + // the original src1 data pointer, so we should index using the indices directly + // TODO: this is a bit of a hack, we should probably have a better way to handle this + const char * src1_col = (const char *) wdata + + (src1_cont || src1->type != vec_dot_type + ? (i11 + i12*ne11)*row_size + : (i11*nb11 + i12*nb12)); + + float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2)); + + for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ++ir0) { + vec_dot(ne00, &tmp[ir0 - iir0], 0, src0_cur + ir0*nb01, 0, src1_col, 0, 1); + } + + memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir0_end) - iir0)*sizeof(float)); + } + } + } +} + +static void * incr_ptr_aligned(void ** p, size_t size, size_t align) { + + void * ptr = *p; + ptr = (void *) GGML_PAD((uintptr_t) ptr, align); + *p = (void *) ((char *) ptr + size); + return ptr; +} + static void ggml_compute_forward_mul_mat_id( const struct ggml_compute_params * params, struct ggml_tensor * dst) { @@ -7623,7 +7714,6 @@ static void ggml_compute_forward_mul_mat_id( const bool src1_cont = ggml_is_contiguous(src1); - ggml_vec_dot_t const vec_dot = type_traits_cpu[type].vec_dot; enum ggml_type const vec_dot_type = type_traits_cpu[type].vec_dot_type; ggml_from_float_t const from_float = type_traits_cpu[vec_dot_type].from_float; @@ -7641,21 +7731,27 @@ static void ggml_compute_forward_mul_mat_id( const int n_ids = ids->ne[0]; // n_expert_used const int n_as = ne02; // n_expert - char * wdata_src1_end = (src1->type == vec_dot_type) ? - (char *) params->wdata : - (char *) params->wdata + GGML_PAD(ggml_row_size(vec_dot_type, ggml_nelements(src1)), sizeof(int64_t)); + void * wdata_cur = params->wdata; - struct mmid_row_mapping { - int32_t i1; - int32_t i2; - }; + if (src1->type != vec_dot_type) { + incr_ptr_aligned(&wdata_cur, ggml_row_size(vec_dot_type, ggml_nelements(src1)), sizeof(int64_t)); + } + + int64_t * matrix_row_counts = // [n_as] + incr_ptr_aligned(&wdata_cur, n_as*sizeof(int64_t), sizeof(int64_t)); + + struct mmid_row_mapping * matrix_rows = // [n_as][ids->ne[0]*ids->ne[1]] + incr_ptr_aligned(&wdata_cur, n_as*ids->ne[0]*ids->ne[1]*sizeof(struct mmid_row_mapping), sizeof(int64_t)); - int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as] - struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *)(matrix_row_counts + n_as); // [n_as][ne11] + char (*atomic_current_chunk)[CACHE_LINE_SIZE] = // [n_as] + incr_ptr_aligned(&wdata_cur, CACHE_LINE_SIZE * n_as, CACHE_LINE_SIZE); + + GGML_ASSERT(params->wsize >= (size_t)((char *) wdata_cur - (char *) params->wdata)); if (src1->type != vec_dot_type) { char * wdata = params->wdata; + const size_t nbw0 = ggml_type_size(vec_dot_type); const size_t nbw1 = ggml_row_size(vec_dot_type, ne10); const size_t nbw2 = nbw1*ne11; const size_t nbw3 = nbw2*ne12; @@ -7663,19 +7759,32 @@ static void ggml_compute_forward_mul_mat_id( assert(params->wsize >= ne13*nbw3); GGML_ASSERT(src1->type == GGML_TYPE_F32); +#if 0 for (int64_t i13 = 0; i13 < ne13; ++i13) { - for (int64_t i12 = 0; i12 < ne12; ++i12) { - for (int64_t i11 = ith; i11 < ne11; i11 += nth) { + for (int64_t i12 = ith; i12 < ne12; i12 += nth) { + for (int64_t i11 = 0; i11 < ne11; ++i11) { from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1), ne10); } } } +#else + for (int64_t i13 = 0; i13 < ne13; ++i13) { + for (int64_t i12 = 0; i12 < ne12; ++i12) { + for (int64_t i11 = 0; i11 < ne11; ++i11) { + size_t bs = ggml_blck_size(vec_dot_type); + int64_t ne10_block_start = (ith * ne10/bs) / nth; + int64_t ne10_block_end = ((ith + 1) * ne10/bs) / nth; + from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + ne10_block_start*bs*nb10), + (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1 + ne10_block_start*nbw0), + (ne10_block_end - ne10_block_start) * bs); + } + } + } +#endif } -#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ne12 + (i1)] - if (ith == 0) { // initialize matrix_row_counts memset(matrix_row_counts, 0, n_as*sizeof(int64_t)); @@ -7693,9 +7802,14 @@ static void ggml_compute_forward_mul_mat_id( } } + // reset current_chunk + for (int cur_a = ith; cur_a < n_as; cur_a += nth) { + atomic_int * current_chunk_ctr = (atomic_int *)(atomic_current_chunk + cur_a); + *current_chunk_ctr = nth; + } + ggml_barrier(params->threadpool); - // compute each matrix multiplication in sequence for (int cur_a = 0; cur_a < n_as; ++cur_a) { const int64_t cne1 = matrix_row_counts[cur_a]; @@ -7703,84 +7817,64 @@ static void ggml_compute_forward_mul_mat_id( continue; } - const char * src0_cur = (const char *) src0->data + cur_a*nb02; - - const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata; + const char * src0_cur = (const char *) src0->data + cur_a * nb02; + const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata; const size_t row_size = ggml_row_size(vec_dot_type, ne10); - const int64_t nr0 = ne01; // src0 rows - const int64_t nr1 = cne1; // src1 rows - - // distribute the thread work across the inner or outer loop based on which one is larger - - const int64_t nth0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows - const int64_t nth1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows - - const int64_t ith0 = ith % nth0; - const int64_t ith1 = ith / nth0; - - const int64_t dr0 = (nr0 + nth0 - 1)/nth0; - const int64_t dr1 = (nr1 + nth1 - 1)/nth1; - - const int64_t ir010 = dr0*ith0; - const int64_t ir011 = MIN(ir010 + dr0, nr0); + const int64_t nr0 = ne01; + const int64_t nr1 = cne1; - const int64_t ir110 = dr1*ith1; - const int64_t ir111 = MIN(ir110 + dr1, nr1); - - // threads with no work simply yield (not sure if it helps) - //if (ir010 >= ir011 || ir110 >= ir111) { - // sched_yield(); - // continue; - //} + int chunk_size = 16; + if (nr0 == 1 || nr1 == 1) { + chunk_size = 64; + } - // block-tiling attempt - const int64_t blck_0 = 16; - const int64_t blck_1 = 16; +#if defined(__aarch64__) + // disable for ARM + const bool disable_chunking = true; +#else + // disable for NUMA + const bool disable_chunking = ggml_is_numa(); +#endif // defined(__aarch64__) - // attempt to reduce false-sharing (does not seem to make a difference) - float tmp[16]; + int64_t nchunk0 = (nr0 + chunk_size - 1) / chunk_size; + int64_t nchunk1 = (nr1 + chunk_size - 1) / chunk_size; - for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) { - for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) { - for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) { - const int64_t _i12 = ir1; // logical row index for this expert + if (nchunk0 * nchunk1 < nth * 4 || disable_chunking) { + nchunk0 = nr0 > nr1 ? nth : 1; + nchunk1 = nr0 > nr1 ? 1 : nth; + } - struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, _i12); - const int id = row_mapping.i1; // selected expert index + const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0; + const int64_t dr1 = (nr1 + nchunk1 - 1) / nchunk1; - const int64_t i11 = id % ne11; - const int64_t i12 = row_mapping.i2; // row index in src1 + int current_chunk = ith; - const int64_t i1 = id; // selected expert index - const int64_t i2 = i12; // row + atomic_int * current_chunk_ctr = (atomic_int *)(atomic_current_chunk + cur_a); - // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides - // if it is, then we have either copied the data to params->wdata and made it contiguous or we are using - // the original src1 data pointer, so we should index using the indices directly - // TODO: this is a bit of a hack, we should probably have a better way to handle this - const char * src1_col = (const char *) wdata + - (src1_cont || src1->type != vec_dot_type - ? (i11 + i12*ne11)*row_size - : (i11*nb11 + i12*nb12)); + while (current_chunk < nchunk0 * nchunk1) { + const int64_t ith0 = current_chunk % nchunk0; + const int64_t ith1 = current_chunk / nchunk0; - float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2)); + const int64_t ir0_start = dr0 * ith0; + const int64_t ir0_end = MIN(ir0_start + dr0, nr0); - //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) { - // vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col); - //} + const int64_t ir1_start = dr1 * ith1; + const int64_t ir1_end = MIN(ir1_start + dr1, nr1); - for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) { - vec_dot(ne00, &tmp[ir0 - iir0], 0, src0_cur + ir0*nb01, 0, src1_col, 0, 1); - } + ggml_compute_forward_mul_mat_id_one_chunk( + dst, src0, src1, ids, cur_a, + ir0_start, ir0_end, ir1_start, ir1_end, + src0_cur, matrix_rows, row_size, src1_cont, wdata + ); - memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float)); - } + if (nth >= nchunk0 * nchunk1) { + break; } + + current_chunk = atomic_fetch_add_explicit(current_chunk_ctr, 1, memory_order_relaxed); } } - -#undef MMID_MATRIX_ROW } // ggml_compute_forward_out_prod @@ -13713,14 +13807,19 @@ struct ggml_cplan ggml_graph_plan( cur = 0; const struct ggml_tensor * src0 = node->src[0]; const struct ggml_tensor * src1 = node->src[1]; + const struct ggml_tensor * ids = node->src[2]; const enum ggml_type vec_dot_type = type_traits_cpu[src0->type].vec_dot_type; + const int n_as = src0->ne[2]; + // src1 if (src1->type != vec_dot_type) { - cur += ggml_row_size(vec_dot_type, ggml_nelements(src1)); + cur += ggml_row_size(vec_dot_type, ggml_nelements(src1)) + sizeof(int64_t); } - const int n_as = src0->ne[2]; - cur += GGML_PAD(cur, sizeof(int64_t)); // align - cur += n_as * sizeof(int64_t); // matrix_row_counts - cur += n_as * src1->ne[2] * sizeof(int64_t); // matrix_rows + // matrix_row_counts + cur += n_as * sizeof(int64_t) + sizeof(int64_t); + // matrix_rows + cur += n_as*ids->ne[0]*ids->ne[1]*sizeof(struct mmid_row_mapping) + sizeof(int64_t); + // atomic_current_chunk + cur += CACHE_LINE_SIZE*n_as + CACHE_LINE_SIZE; } break; case GGML_OP_OUT_PROD: { From 07bbd8e7eadf56063008713a312051b710358c30 Mon Sep 17 00:00:00 2001 From: R0CKSTAR Date: Thu, 13 Feb 2025 20:28:18 +0800 Subject: [PATCH 32/58] musa: bump MUSA SDK version to rc3.1.1 (llama/11822) * musa: Update MUSA SDK version to rc3.1.1 Signed-off-by: Xiaodong Ye * musa: Remove workaround in PR #10042 Signed-off-by: Xiaodong Ye --------- Signed-off-by: Xiaodong Ye --- ggml/src/ggml-cuda/ggml-cuda.cu | 5 ----- 1 file changed, 5 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 6ea41577768..093ad70991b 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -1480,12 +1480,7 @@ static void ggml_cuda_op_mul_mat( const size_t nbytes_data = ggml_nbytes(src0); const size_t nbytes_padding = ggml_row_size(src0->type, MATRIX_ROW_PADDING - ne00 % MATRIX_ROW_PADDING); dev[id].src0_dd = dev[id].src0_dd_alloc.alloc(ctx.pool(id), nbytes_data + nbytes_padding); - // TODO: remove this for MUSA once the Guilty Lockup issue is resolved -#ifndef GGML_USE_MUSA CUDA_CHECK(cudaMemsetAsync(dev[id].src0_dd, 0, nbytes_data + nbytes_padding, stream)); -#else // GGML_USE_MUSA - CUDA_CHECK(cudaMemsetAsync(dev[id].src0_dd + nbytes_data, 0, nbytes_padding, stream)); -#endif // !GGML_USE_MUSA } // If src0 is on a temporary compute buffer (partial offloading) there may be some padding that needs to be cleared: From f7244e022b5216bd5ff5c847f350ab41e0fa6257 Mon Sep 17 00:00:00 2001 From: Jeffrey Morgan Date: Thu, 13 Feb 2025 09:05:04 -0800 Subject: [PATCH 33/58] llamafile: use member variable instead of constant for iq4nlt (llama/11780) --- ggml/src/ggml-cpu/llamafile/sgemm.cpp | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/ggml/src/ggml-cpu/llamafile/sgemm.cpp b/ggml/src/ggml-cpu/llamafile/sgemm.cpp index c22a662876c..e0482c59377 100644 --- a/ggml/src/ggml-cpu/llamafile/sgemm.cpp +++ b/ggml/src/ggml-cpu/llamafile/sgemm.cpp @@ -280,14 +280,6 @@ template <> inline __m256bh load(const float *p) { } #endif -//////////////////////////////////////////////////////////////////////////////////////////////////// -// CONSTANTS - -#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) -static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113}; -static const __m128i iq4nlt = _mm_loadu_si128((const __m128i *) kvalues_iq4nl); -#endif - //////////////////////////////////////////////////////////////////////////////////////////////////// // FLOATING POINT MATRIX MULTIPLICATION @@ -614,6 +606,14 @@ class tinyBLAS_Q0_AVX { TC *C, int64_t ldc, int ith, int nth) : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) { + const int8_t kvalues_iq4nl[16] = { + -127, -104, -83, -65, + -49, -35, -22, -10, + 1, 13, 25, 38, + 53, 69, 89, 113 + }; + + iq4nlt = _mm_loadu_si128((const __m128i *)kvalues_iq4nl); } void matmul(int64_t m, int64_t n) { @@ -1038,6 +1038,7 @@ class tinyBLAS_Q0_AVX { const int64_t ldc; const int ith; const int nth; + __m128i iq4nlt; }; #endif // __AVX__ From 0fc51b7bc48e087799a5266c5db489d46dc556dd Mon Sep 17 00:00:00 2001 From: Eve <139727413+netrunnereve@users.noreply.github.com> Date: Fri, 14 Feb 2025 02:59:40 +0000 Subject: [PATCH 34/58] vulkan: linux builds + small subgroup size fixes (llama/11767) * mm subgroup size * upload vulkan x86 builds --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index bffe95086af..99d50afda2d 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -1430,6 +1430,7 @@ static void ggml_vk_load_shaders(vk_device& device) { VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")"); // some shaders have a minimum subgroup size + const uint32_t subgroup_size_8 = std::max(device->subgroup_size, 8u); const uint32_t subgroup_size_16 = std::max(device->subgroup_size, 16u); const uint32_t subgroup_size_32 = std::max(device->subgroup_size, 32u); @@ -1492,13 +1493,13 @@ static void ggml_vk_load_shaders(vk_device& device) { const uint32_t tk_m = device->coopmat_support ? device->coopmat_k : 1; const uint32_t tk_s = device->coopmat_support ? device->coopmat_k : 1; - l_warptile = { 128, 128, 128, 16, device->subgroup_size * 2, 64, 2, tm_l, tn_l, tk_l, device->subgroup_size }; - m_warptile = { 128, 64, 64, 16, device->subgroup_size, 32, 2, tm_m, tn_m, tk_m, device->subgroup_size }; - s_warptile = { subgroup_size_16, 32, 32, 16, 32, 32, 2, tm_s, tn_s, tk_s, device->subgroup_size }; + l_warptile = { 128, 128, 128, 16, subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, subgroup_size_8 }; + m_warptile = { 128, 64, 64, 16, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 }; + s_warptile = { subgroup_size_16, 32, 32, 16, 32, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 }; - l_warptile_mmq = { 128, 128, 128, 32, device->subgroup_size * 2, 64, 2, tm_l, tn_l, tk_l, device->subgroup_size }; - m_warptile_mmq = { 128, 64, 64, 32, device->subgroup_size, 32, 2, tm_m, tn_m, tk_m, device->subgroup_size }; - s_warptile_mmq = { subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, device->subgroup_size }; + l_warptile_mmq = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, subgroup_size_8 }; + m_warptile_mmq = { 128, 64, 64, 32, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 }; + s_warptile_mmq = { subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 }; l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 }; m_mmq_wg_denoms = m_wg_denoms = { 64, 64, 1 }; From 74a285ad0e67b001a9c49614bbd6be5f0864521c Mon Sep 17 00:00:00 2001 From: Jinyang He Date: Fri, 14 Feb 2025 16:54:27 +0800 Subject: [PATCH 35/58] ggml: optimize some vec dot functions for LoongArch ASX (llama/11842) * Optimize ggml_vec_dot_q3_K_q8_K for LoongArch ASX * Optimize ggml_vec_dot_q4_K_q8_K for LoongArch ASX * Optimize ggml_vec_dot_q6_K_q8_K for LoongArch ASX * Optimize ggml_vec_dot_q5_K_q8_K for LoongArch ASX * Optimize ggml_vec_dot_q2_K_q8_K for LoongArch ASX * Optimize mul_sum_i8_pairs_float for LoongArch ASX * Optimize ggml_vec_dot_iq4_xs_q8_K for LoongArch ASX --- ggml/src/ggml-cpu/ggml-cpu-quants.c | 362 +++++++++++----------------- 1 file changed, 141 insertions(+), 221 deletions(-) diff --git a/ggml/src/ggml-cpu/ggml-cpu-quants.c b/ggml/src/ggml-cpu/ggml-cpu-quants.c index 1b4bd66e80c..0315dc2575e 100644 --- a/ggml/src/ggml-cpu/ggml-cpu-quants.c +++ b/ggml/src/ggml-cpu/ggml-cpu-quants.c @@ -562,6 +562,41 @@ static __m256i lasx_packs_h(__m256i a, __m256i b) { return __lasx_xvpickev_b(tmp1, tmp); } +static inline __m256i lasx_madd_h_b(__m256i a, __m256i b) { + __m256i tmp1, tmp2; + tmp1 = __lasx_xvmulwev_h_b(a, b); + tmp2 = __lasx_xvmulwod_h_b(a, b); + return __lasx_xvadd_h(tmp1, tmp2); +} + +static inline __m256i lasx_xvrepl128vei_h(__m256i a, const unsigned int b) { + switch (b) { + case 0: return __lasx_xvrepl128vei_h(a, 0); + case 1: return __lasx_xvrepl128vei_h(a, 1); + case 2: return __lasx_xvrepl128vei_h(a, 2); + case 3: return __lasx_xvrepl128vei_h(a, 3); + case 4: return __lasx_xvrepl128vei_h(a, 4); + case 5: return __lasx_xvrepl128vei_h(a, 5); + case 6: return __lasx_xvrepl128vei_h(a, 6); + case 7: return __lasx_xvrepl128vei_h(a, 7); + default: __builtin_unreachable(); + } +} + +static inline __m256i lasx_xvandi_b_bit(__m256i a, const unsigned int b) { + switch (b) { + case 0: return __lasx_xvandi_b(a, 1 << 0); + case 1: return __lasx_xvandi_b(a, 1 << 1); + case 2: return __lasx_xvandi_b(a, 1 << 2); + case 3: return __lasx_xvandi_b(a, 1 << 3); + case 4: return __lasx_xvandi_b(a, 1 << 4); + case 5: return __lasx_xvandi_b(a, 1 << 5); + case 6: return __lasx_xvandi_b(a, 1 << 6); + case 7: return __lasx_xvandi_b(a, 1 << 7); + default: __builtin_unreachable(); + } +} + // multiply int8_t, add results pairwise twice static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) { // Get absolute values of x vectors @@ -656,13 +691,8 @@ static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) // multiply int8_t, add results pairwise twice and return as float vector static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) { - - // Get absolute values of x vectors - const __m256i ax = __lasx_xvsigncov_b(x, x); - // Sign the values of the y vectors - const __m256i sy = __lasx_xvsigncov_b(x, y); - - return mul_sum_us8_pairs_float(ax, sy); + const __m256i dot = lasx_madd_h_b(x, y); + return sum_i16_pairs_float(dot); } static inline __m128i packNibbles( __m256i bytes ) { @@ -4965,9 +4995,6 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r #elif defined __loongarch_asx - const __m256i m3 = __lasx_xvreplgr2vr_b(3); - const __m128i m4 = __lsx_vreplgr2vr_b(0xF); - __m256 acc = (__m256)__lasx_xvldi(0); for (int i = 0; i < nb; ++i) { @@ -4978,18 +5005,15 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r const uint8_t * restrict q2 = x[i].qs; const int8_t * restrict q8 = y[i].qs; - const __m128i mins_and_scales = __lsx_vld((const __m128i*)x[i].scales, 0); - const __m128i scales8 = __lsx_vand_v(mins_and_scales, m4); - const __m128i mins8 = __lsx_vand_v(__lsx_vsrli_h(mins_and_scales, 4), m4); - const __m256i mins = lasx_ext8_16(mins8); + const __m128i mins_and_scales128 = __lsx_vld((const __m128i*)x[i].scales, 0); + const __m128i scales128 = __lsx_vandi_b(mins_and_scales128, 0xf); + const __m256i mins = lasx_ext8_16(__lsx_vsrli_b(mins_and_scales128, 4)); const __m256i prod = lasx_madd_h(mins, __lasx_xvld((const __m256i*)y[i].bsums, 0)); acc = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(dmin), __lasx_xvffint_s_w(prod), acc); - const __m256i all_scales = lasx_ext8_16(scales8); - const __m128i l_scales = lasx_extracti128(all_scales, 0); - const __m128i h_scales = lasx_extracti128(all_scales, 1); - const __m256i scales[2] = {lasx_insertf128(l_scales, l_scales), lasx_insertf128(h_scales, h_scales)}; + const v16i8 shuffle_mask = {0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15}; + const __m256i scales_shuffled = lasx_ext8_16(__lsx_vshuf_b(scales128, scales128, (__m128i)shuffle_mask)); __m256i sumi = __lasx_xvldi(0); @@ -5002,20 +5026,20 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r const __m256i q8_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; const __m256i q8_3 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; - const __m256i q2_0 = __lasx_xvand_v(q2bits, m3); - const __m256i q2_1 = __lasx_xvand_v(__lasx_xvsrli_h(q2bits, 2), m3); - const __m256i q2_2 = __lasx_xvand_v(__lasx_xvsrli_h(q2bits, 4), m3); - const __m256i q2_3 = __lasx_xvand_v(__lasx_xvsrli_h(q2bits, 6), m3); + const __m256i q2_0 = __lasx_xvandi_b(q2bits, 3); + const __m256i q2_1 = __lasx_xvandi_b(__lasx_xvsrli_b(q2bits, 2), 3); + const __m256i q2_2 = __lasx_xvandi_b(__lasx_xvsrli_b(q2bits, 4), 3); + const __m256i q2_3 = __lasx_xvsrli_b(q2bits, 6); - __m256i p0 = lasx_maddubs_h(q2_0, q8_0); - __m256i p1 = lasx_maddubs_h(q2_1, q8_1); - __m256i p2 = lasx_maddubs_h(q2_2, q8_2); - __m256i p3 = lasx_maddubs_h(q2_3, q8_3); + __m256i p0 = lasx_madd_h_b(q2_0, q8_0); + __m256i p1 = lasx_madd_h_b(q2_1, q8_1); + __m256i p2 = lasx_madd_h_b(q2_2, q8_2); + __m256i p3 = lasx_madd_h_b(q2_3, q8_3); - p0 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(0)), p0); - p1 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(1)), p1); - p2 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(2)), p2); - p3 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(3)), p3); + p0 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 0), p0); + p1 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 1), p1); + p2 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 2), p2); + p3 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 3), p3); p0 = __lasx_xvadd_w(p0, p1); p2 = __lasx_xvadd_w(p2, p3); @@ -5771,8 +5795,6 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r #elif defined __loongarch_asx - const __m256i m3 = __lasx_xvreplgr2vr_b(3); - const __m256i mone = __lasx_xvreplgr2vr_b(1); const __m128i m32 = __lsx_vreplgr2vr_b(32); __m256 acc = (__m256)__lasx_xvldi(0); @@ -5792,10 +5814,9 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4), (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4)); scales128 = __lsx_vsub_b(scales128, m32); - const __m256i all_scales = lasx_ext8_16(scales128); - const __m128i l_scales = lasx_extracti128(all_scales, 0); - const __m128i h_scales = lasx_extracti128(all_scales, 1); - const __m256i scales[2] = {lasx_insertf128(l_scales, l_scales), lasx_insertf128(h_scales, h_scales)}; + + const v16i8 shuffle_mask = {0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15}; + const __m256i scales_shuffled = lasx_ext8_16(__lsx_vshuf_b(scales128, scales128, (__m128i)shuffle_mask)); // high bit const __m256i hbits = __lasx_xvld((const __m256i*)x[i].hmask, 0); @@ -5803,35 +5824,23 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r // integer accumulator __m256i sumi = __lasx_xvldi(0); - int bit = 0; - int is = 0; - __m256i xvbit; - - for (int j = 0; j < QK_K/128; ++j) { // load low 2 bits const __m256i q3bits = __lasx_xvld((const __m256i*)q3, 0); q3 += 32; - xvbit = __lasx_xvreplgr2vr_h(bit); // prepare low and high bits - const __m256i q3l_0 = __lasx_xvand_v(q3bits, m3); - const __m256i q3h_0 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2); - ++bit; - - xvbit = __lasx_xvreplgr2vr_h(bit); - const __m256i q3l_1 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 2), m3); - const __m256i q3h_1 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2); - ++bit; - - xvbit = __lasx_xvreplgr2vr_h(bit); - const __m256i q3l_2 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 4), m3); - const __m256i q3h_2 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2); - ++bit; - - xvbit = __lasx_xvreplgr2vr_h(bit); - const __m256i q3l_3 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 6), m3); - const __m256i q3h_3 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2); - ++bit; + const __m256i q3l_0 = __lasx_xvandi_b(q3bits, 3); + const __m256i q3l_1 = __lasx_xvandi_b(__lasx_xvsrli_b(q3bits, 2), 3); + const __m256i q3l_2 = __lasx_xvandi_b(__lasx_xvsrli_b(q3bits, 4), 3); + const __m256i q3l_3 = __lasx_xvsrli_b(q3bits, 6); + const __m256i q3h_0 = __lasx_xvslli_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 4 * j + 0), 0), 2); + const __m256i q3h_1 = __lasx_xvslli_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 4 * j + 1), 0), 2); + const __m256i q3h_2 = __lasx_xvslli_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 4 * j + 2), 0), 2); + const __m256i q3h_3 = __lasx_xvslli_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 4 * j + 3), 0), 2); + const __m256i q3_0 = __lasx_xvor_v(q3h_0, q3l_0); + const __m256i q3_1 = __lasx_xvor_v(q3h_1, q3l_1); + const __m256i q3_2 = __lasx_xvor_v(q3h_2, q3l_2); + const __m256i q3_3 = __lasx_xvor_v(q3h_3, q3l_3); // load Q8 quants const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; @@ -5839,29 +5848,16 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r const __m256i q8_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; const __m256i q8_3 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; - // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use lasx_maddubs_h, - // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set, - // and 2 if the high bit was set) - __m256i q8s_0 = lasx_maddubs_h(q3h_0, q8_0); - __m256i q8s_1 = lasx_maddubs_h(q3h_1, q8_1); - __m256i q8s_2 = lasx_maddubs_h(q3h_2, q8_2); - __m256i q8s_3 = lasx_maddubs_h(q3h_3, q8_3); - - __m256i p16_0 = lasx_maddubs_h(q3l_0, q8_0); - __m256i p16_1 = lasx_maddubs_h(q3l_1, q8_1); - __m256i p16_2 = lasx_maddubs_h(q3l_2, q8_2); - __m256i p16_3 = lasx_maddubs_h(q3l_3, q8_3); - - p16_0 = __lasx_xvsub_h(p16_0, q8s_0); - p16_1 = __lasx_xvsub_h(p16_1, q8s_1); - p16_2 = __lasx_xvsub_h(p16_2, q8s_2); - p16_3 = __lasx_xvsub_h(p16_3, q8s_3); + __m256i p16_0 = lasx_madd_h_b(q8_0, q3_0); + __m256i p16_1 = lasx_madd_h_b(q8_1, q3_1); + __m256i p16_2 = lasx_madd_h_b(q8_2, q3_2); + __m256i p16_3 = lasx_madd_h_b(q8_3, q3_3); // multiply with scales - p16_0 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(is + 0)), p16_0); - p16_1 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(is + 1)), p16_1); - p16_2 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(is + 2)), p16_2); - p16_3 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(is + 3)), p16_3); + p16_0 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 0), p16_0); + p16_1 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 1), p16_1); + p16_2 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 2), p16_2); + p16_3 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 3), p16_3); // accumulate p16_0 = __lasx_xvadd_w(p16_0, p16_1); @@ -5869,7 +5865,7 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_0, p16_2)); } // multiply with block scale and accumulate - acc = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), acc);//FIXME + acc = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), acc); } *s = hsum_float_8(acc); @@ -6562,11 +6558,6 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r *s = vec_extract(vsumf0, 0); #elif defined __loongarch_asx - GGML_UNUSED(kmask1); - GGML_UNUSED(kmask2); - GGML_UNUSED(kmask3); - - const __m256i m4 = __lasx_xvreplgr2vr_b(0xF); __m256 acc = (__m256)__lasx_xvldi(0); __m128 acc_m = (__m128)__lsx_vldi(0); @@ -6586,33 +6577,34 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r const uint8_t * restrict q4 = x[i].qs; const int8_t * restrict q8 = y[i].qs; - const __m256i mins_and_scales = lasx_extu8_16(lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0])); + const __m128i mins_and_scales128 = lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0]); + const __m128i mins128 = __lsx_vexth_h_b(mins_and_scales128); + const __m128i scales128 = __lsx_vsllwil_h_b(mins_and_scales128, 0); const __m256i q8sums = __lasx_xvld((const __m256i*)y[i].bsums, 0); const __m128i q8s = lsx_hadd_h(lasx_extracti128(q8sums, 0), lasx_extracti128(q8sums, 1)); - const __m128i prod = lsx_madd_h(lasx_extracti128(mins_and_scales, 1), q8s); + const __m128i prod = lsx_madd_h(mins128, q8s); acc_m = __lsx_vfmadd_s(__lsx_vreplfr2vr_s(dmin), __lsx_vffint_s_w(prod), acc_m); - const __m128i sc128 = lasx_extracti128(mins_and_scales, 0); - const __m256i scales = lasx_insertf128(sc128, sc128); + const __m256i scales = lasx_insertf128(scales128, scales128); __m256i sumi = __lasx_xvldi(0); for (int j = 0; j < QK_K/64; ++j) { - const __m256i scale_l = lasx_shuffle_b(scales, get_scale_shuffle_k4(2*j+0)); - const __m256i scale_h = lasx_shuffle_b(scales, get_scale_shuffle_k4(2*j+1)); + const __m256i scale_l = lasx_xvrepl128vei_h(scales, 2 * j + 0); + const __m256i scale_h = lasx_xvrepl128vei_h(scales, 2 * j + 1); const __m256i q4bits = __lasx_xvld((const __m256i*)q4, 0); q4 += 32; - const __m256i q4l = __lasx_xvand_v(q4bits, m4); - const __m256i q4h = __lasx_xvand_v(__lasx_xvsrli_h(q4bits, 4), m4); + const __m256i q4l = __lasx_xvandi_b(q4bits, 0xf); + const __m256i q4h = __lasx_xvsrli_b(q4bits, 4); const __m256i q8l = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; - __m256i p16l = lasx_maddubs_h(q4l, q8l); + __m256i p16l = lasx_madd_h_b(q4l, q8l); p16l = lasx_madd_h(scale_l, p16l); const __m256i q8h = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; - __m256i p16h = lasx_maddubs_h(q4h, q8h); + __m256i p16h = lasx_madd_h_b(q4h, q8h); p16h = lasx_madd_h(scale_h, p16h); const __m256i sumj = __lasx_xvadd_w(p16l, p16h); @@ -7289,19 +7281,11 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r *s = vec_extract(vsumf0, 0); #elif defined __loongarch_asx - GGML_UNUSED(kmask1); - GGML_UNUSED(kmask2); - GGML_UNUSED(kmask3); - - const __m256i m4 = __lasx_xvreplgr2vr_b(0xF); - const __m128i mzero = __lsx_vldi(0); - const __m256i mone = __lasx_xvreplgr2vr_b(1); __m256 acc = (__m256)__lasx_xvldi(0); + __m128 acc_m = (__m128)__lsx_vldi(0); - float summs = 0.f; - - for (int i = 0; i < nb; ++i) { + for (int i = 0; i < nb; ++i) { const uint8_t * restrict q5 = x[i].qs; const int8_t * restrict q8 = y[i].qs; @@ -7316,49 +7300,40 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r utmp[2] = uaux; utmp[0] &= kmask1; - const __m256i mins_and_scales = lasx_extu8_16(lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0])); + const __m128i mins_and_scales128 = lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0]); + const __m128i mins128 = __lsx_vexth_h_b(mins_and_scales128); + const __m128i scales128 = __lsx_vsllwil_h_b(mins_and_scales128, 0); const __m256i q8sums = __lasx_xvld((const __m256i*)y[i].bsums, 0); const __m128i q8s = lsx_hadd_h(lasx_extracti128(q8sums, 0), lasx_extracti128(q8sums, 1)); - const __m128i prod = lsx_madd_h(lasx_extracti128(mins_and_scales, 1), q8s); - const __m128i hsum = lsx_hadd_w(lsx_hadd_w(prod, mzero), mzero); - summs += dmin * __lsx_vpickve2gr_w(hsum, 0); //TODO check + const __m128i prod = lsx_madd_h(mins128, q8s); + acc_m = __lsx_vfmadd_s(__lsx_vreplfr2vr_s(dmin), __lsx_vffint_s_w(prod), acc_m); - const __m128i sc128 = lasx_extracti128(mins_and_scales, 0); - const __m256i scales = lasx_insertf128(sc128, sc128); + const __m256i scales = lasx_insertf128(scales128, scales128); const __m256i hbits = __lasx_xvld((const __m256i*)x[i].qh, 0); - __m256i hmask = mone; __m256i sumi = __lasx_xvldi(0); - int bit = 0; - __m256i xvbit; - for (int j = 0; j < QK_K/64; ++j) { - const __m256i scale_0 = lasx_shuffle_b(scales, get_scale_shuffle_k4(2*j+0)); - const __m256i scale_1 = lasx_shuffle_b(scales, get_scale_shuffle_k4(2*j+1)); + const __m256i scale_0 = lasx_xvrepl128vei_h(scales, 2 * j + 0); + const __m256i scale_1 = lasx_xvrepl128vei_h(scales, 2 * j + 1); const __m256i q5bits = __lasx_xvld((const __m256i*)q5, 0); q5 += 32; - xvbit = __lasx_xvreplgr2vr_h(bit++); - const __m256i q5l_0 = __lasx_xvand_v(q5bits, m4); - const __m256i q5h_0 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvand_v(hbits, hmask), xvbit), 4); - const __m256i q5_0 = __lasx_xvadd_b(q5l_0, q5h_0); - hmask = __lasx_xvslli_h(hmask, 1); - - xvbit = __lasx_xvreplgr2vr_h(bit++); - const __m256i q5l_1 = __lasx_xvand_v(__lasx_xvsrli_h(q5bits, 4), m4); - const __m256i q5h_1 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvand_v(hbits, hmask), xvbit), 4); - const __m256i q5_1 = __lasx_xvadd_b(q5l_1, q5h_1); - hmask = __lasx_xvslli_h(hmask, 1); + const __m256i q5l_0 = __lasx_xvandi_b(q5bits, 0xf); + const __m256i q5l_1 = __lasx_xvsrli_b(q5bits, 4); + const __m256i q5h_0 = __lasx_xvnori_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 2 * j + 0), 0), 0xef); + const __m256i q5h_1 = __lasx_xvnori_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 2 * j + 1), 0), 0xef); + const __m256i q5_0 = __lasx_xvor_v(q5l_0, q5h_0); + const __m256i q5_1 = __lasx_xvor_v(q5l_1, q5h_1); const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; - __m256i p16_0 = lasx_maddubs_h(q5_0, q8_0); - __m256i p16_1 = lasx_maddubs_h(q5_1, q8_1); + __m256i p16_0 = lasx_madd_h_b(q5_0, q8_0); + __m256i p16_1 = lasx_madd_h_b(q5_1, q8_1); p16_0 = lasx_madd_h(scale_0, p16_0); p16_1 = lasx_madd_h(scale_1, p16_1); @@ -7372,7 +7347,10 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r } - *s = hsum_float_8(acc) + summs; + acc_m = __lsx_vfadd_s(acc_m, (__m128)__lsx_vbsrl_v(acc_m, 8)); + acc_m = __lsx_vfadd_s(acc_m, (__m128)__lsx_vbsrl_v(acc_m, 4)); + + *s = hsum_float_8(acc) + ((v4f32)acc_m)[0]; #else @@ -8033,8 +8011,6 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r #elif defined __loongarch_asx - const __m256i m4 = __lasx_xvreplgr2vr_b(0xF); - const __m256i m2 = __lasx_xvreplgr2vr_b(3); const __m256i m32s = __lasx_xvreplgr2vr_b(32); __m256 acc = (__m256)__lasx_xvldi(0); @@ -8047,58 +8023,42 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r const uint8_t * restrict qh = x[i].qh; const int8_t * restrict q8 = y[i].qs; - const __m128i scales = __lsx_vld((const __m128i*)x[i].scales, 0); + const __m128i scales128 = __lsx_vld((const __m128i*)x[i].scales, 0); + const v16i8 shuffle_mask = {0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15}; + const __m256i scales_shuffled = lasx_ext8_16(__lsx_vshuf_b(scales128, scales128, (__m128i)shuffle_mask)); __m256i sumi = __lasx_xvldi(0); - int is = 0; - for (int j = 0; j < QK_K/128; ++j) { - const __m128i scale_0 = lsx_shuffle_b(scales, get_scale_shuffle(is + 0)); - const __m128i scale_1 = lsx_shuffle_b(scales, get_scale_shuffle(is + 1)); - const __m128i scale_2 = lsx_shuffle_b(scales, get_scale_shuffle(is + 2)); - const __m128i scale_3 = lsx_shuffle_b(scales, get_scale_shuffle(is + 3)); - is += 4; - const __m256i q4bits1 = __lasx_xvld((const __m256i*)q4, 0); q4 += 32; const __m256i q4bits2 = __lasx_xvld((const __m256i*)q4, 0); q4 += 32; const __m256i q4bitsH = __lasx_xvld((const __m256i*)qh, 0); qh += 32; - const __m256i q4h_0 = __lasx_xvslli_h(__lasx_xvand_v(q4bitsH, m2), 4); - const __m256i q4h_1 = __lasx_xvslli_h(__lasx_xvand_v(__lasx_xvsrli_h(q4bitsH, 2), m2), 4); - const __m256i q4h_2 = __lasx_xvslli_h(__lasx_xvand_v(__lasx_xvsrli_h(q4bitsH, 4), m2), 4); - const __m256i q4h_3 = __lasx_xvslli_h(__lasx_xvand_v(__lasx_xvsrli_h(q4bitsH, 6), m2), 4); + const __m256i q4h_0 = __lasx_xvslli_b(__lasx_xvandi_b(q4bitsH, 3), 4); + const __m256i q4h_1 = __lasx_xvslli_b(__lasx_xvandi_b(q4bitsH, 3 << 2), 2); + const __m256i q4h_2 = __lasx_xvandi_b(q4bitsH, 3 << 4); + const __m256i q4h_3 = __lasx_xvsrli_b(__lasx_xvandi_b(q4bitsH, 3 << 6), 2); - const __m256i q4_0 = __lasx_xvor_v(__lasx_xvand_v(q4bits1, m4), q4h_0); - const __m256i q4_1 = __lasx_xvor_v(__lasx_xvand_v(q4bits2, m4), q4h_1); - const __m256i q4_2 = __lasx_xvor_v(__lasx_xvand_v(__lasx_xvsrli_h(q4bits1, 4), m4), q4h_2); - const __m256i q4_3 = __lasx_xvor_v(__lasx_xvand_v(__lasx_xvsrli_h(q4bits2, 4), m4), q4h_3); + const __m256i q4_0 = __lasx_xvor_v(__lasx_xvandi_b(q4bits1, 0xf), q4h_0); + const __m256i q4_1 = __lasx_xvor_v(__lasx_xvandi_b(q4bits2, 0xf), q4h_1); + const __m256i q4_2 = __lasx_xvor_v(__lasx_xvsrli_b(q4bits1, 4), q4h_2); + const __m256i q4_3 = __lasx_xvor_v(__lasx_xvsrli_b(q4bits2, 4), q4h_3); const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; const __m256i q8_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; const __m256i q8_3 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; - __m256i q8s_0 = lasx_maddubs_h(m32s, q8_0); - __m256i q8s_1 = lasx_maddubs_h(m32s, q8_1); - __m256i q8s_2 = lasx_maddubs_h(m32s, q8_2); - __m256i q8s_3 = lasx_maddubs_h(m32s, q8_3); - - __m256i p16_0 = lasx_maddubs_h(q4_0, q8_0); - __m256i p16_1 = lasx_maddubs_h(q4_1, q8_1); - __m256i p16_2 = lasx_maddubs_h(q4_2, q8_2); - __m256i p16_3 = lasx_maddubs_h(q4_3, q8_3); + __m256i p16_0 = lasx_madd_h_b(__lasx_xvsub_b(q4_0, m32s), q8_0); + __m256i p16_1 = lasx_madd_h_b(__lasx_xvsub_b(q4_1, m32s), q8_1); + __m256i p16_2 = lasx_madd_h_b(__lasx_xvsub_b(q4_2, m32s), q8_2); + __m256i p16_3 = lasx_madd_h_b(__lasx_xvsub_b(q4_3, m32s), q8_3); - p16_0 = __lasx_xvsub_h(p16_0, q8s_0); - p16_1 = __lasx_xvsub_h(p16_1, q8s_1); - p16_2 = __lasx_xvsub_h(p16_2, q8s_2); - p16_3 = __lasx_xvsub_h(p16_3, q8s_3); - - p16_0 = lasx_madd_h(lasx_ext8_16(scale_0), p16_0); - p16_1 = lasx_madd_h(lasx_ext8_16(scale_1), p16_1); - p16_2 = lasx_madd_h(lasx_ext8_16(scale_2), p16_2); - p16_3 = lasx_madd_h(lasx_ext8_16(scale_3), p16_3); + p16_0 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 0), p16_0); + p16_1 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 1), p16_1); + p16_2 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 2), p16_2); + p16_3 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 3), p16_3); sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_0, p16_1)); sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_2, p16_3)); @@ -10423,13 +10383,9 @@ static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) { } #elif defined(__loongarch_asx) static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) { - const __m256i ax = __lasx_xvsigncov_b(x, x); - const __m256i sy = __lasx_xvsigncov_b(x, y); - __m256i tmp1, tmp2, tmp3; - tmp1 = __lasx_xvmulwev_h_bu_b(ax, sy); - tmp2 = __lasx_xvmulwod_h_bu_b(ax, sy); - tmp3 = __lasx_xvadd_h(tmp1, tmp2); - return __lasx_xvsat_h(tmp3, 15); + const __m256i a = __lasx_xvmulwev_h_b(x, y); + const __m256i b = __lasx_xvmulwod_h_b(x, y); + return __lasx_xvadd_h(a, b); } #endif @@ -11479,67 +11435,31 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * restrict s, size_t bs, const void * #elif defined(__loongarch_asx) const __m128i values128 = __lsx_vld((const __m128i*)kvalues_iq4nl, 0); - const __m128i m4b = __lsx_vreplgr2vr_b(0x0f); __m256 accum = (__m256)__lasx_xvldi(0); - __m256i tmp1; - __m128i tmp0, tmp2, tmp3, tmp4, mask_8f, mask; - mask_8f = __lsx_vreplgr2vr_b(0x8f); for (int ibl = 0; ibl < nb; ++ibl) { const uint8_t * qs = x[ibl].qs; const int8_t * q8 = y[ibl].qs; uint16_t sh = x[ibl].scales_h; __m256i sumi1 = __lasx_xvldi(0); __m256i sumi2 = __lasx_xvldi(0); - __m128i zero = __lsx_vldi(0); for (int ib = 0; ib < QK_K/32; ib += 2) { - const __m128i q4bits_1 = __lsx_vld((const __m128i*)qs, 0); qs += 16; - const __m128i q4bits_2 = __lsx_vld((const __m128i*)qs, 0); qs += 16; + const __m128i q4bits_1 = __lsx_vld((const __m128i*)qs, 0); qs += 16; + const __m128i q4bits_2 = __lsx_vld((const __m128i*)qs, 0); qs += 16; const __m256i q8b_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; const __m256i q8b_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; - tmp2 = __lsx_vand_v(__lsx_vand_v(__lsx_vsrli_h(q4bits_1, 4), m4b), mask_8f); - tmp0 = __lsx_vori_b(tmp2, 0x10); - mask = __lsx_vsle_b(zero, tmp2); - tmp3 = __lsx_vand_v(tmp0, mask); - tmp3 = __lsx_vshuf_b(values128, zero, tmp3); - - tmp2 = __lsx_vand_v(__lsx_vand_v(q4bits_1, m4b), mask_8f); - tmp0 = __lsx_vori_b(tmp2, 0x10); - mask = __lsx_vsle_b(zero, tmp2); - tmp4 = __lsx_vand_v(tmp0, mask); - tmp4 = __lsx_vshuf_b(values128, zero, tmp4); - - const __m256i q4b_1 = lasx_insertf128(tmp3, tmp4); - - tmp2 = __lsx_vand_v(__lsx_vand_v(__lsx_vsrli_h(q4bits_2, 4), m4b), mask_8f); - tmp0 = __lsx_vori_b(tmp2, 0x10); - mask = __lsx_vsle_b(zero, tmp2); - tmp3 = __lsx_vand_v(tmp0, mask); - tmp3 = __lsx_vshuf_b(values128, zero, tmp3); - - tmp2 = __lsx_vand_v(__lsx_vand_v(q4bits_2, m4b), mask_8f); - tmp0 = __lsx_vori_b(tmp2, 0x10); - mask = __lsx_vsle_b(zero, tmp2); - tmp4 = __lsx_vand_v(tmp0, mask); - tmp4 = __lsx_vshuf_b(values128, zero, tmp4); - - const __m256i q4b_2 = lasx_insertf128(tmp3, tmp4); - + const __m256i q4b_1 = lasx_insertf128(__lsx_vshuf_b(values128, values128, __lsx_vsrli_b(q4bits_1, 4)), + __lsx_vshuf_b(values128, values128, __lsx_vandi_b(q4bits_1, 0xf))); + const __m256i q4b_2 = lasx_insertf128(__lsx_vshuf_b(values128, values128, __lsx_vsrli_b(q4bits_2, 4)), + __lsx_vshuf_b(values128, values128, __lsx_vandi_b(q4bits_2, 0xf))); const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1); const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2); const int16_t ls1 = ((x[ibl].scales_l[ib/2] & 0xf) | ((sh << 4) & 0x30)) - 32; const int16_t ls2 = ((x[ibl].scales_l[ib/2] >> 4) | ((sh << 2) & 0x30)) - 32; sh >>= 4; - __m256i tmp5, tmp6; - tmp1 = __lasx_xvreplgr2vr_h(ls1); - tmp5 = __lasx_xvmulwev_w_h(p16_1, tmp1); - tmp6 = __lasx_xvmulwod_w_h(p16_1, tmp1); - const __m256i p_1 = __lasx_xvadd_w(tmp5, tmp6); - tmp1 = __lasx_xvreplgr2vr_h(ls2); - tmp5 = __lasx_xvmulwev_w_h(p16_2, tmp1); - tmp6 = __lasx_xvmulwod_w_h(p16_2, tmp1); - const __m256i p_2 = __lasx_xvadd_w(tmp5, tmp6); + const __m256i p_1 = lasx_madd_h(p16_1, __lasx_xvreplgr2vr_h(ls1)); + const __m256i p_2 = lasx_madd_h(p16_2, __lasx_xvreplgr2vr_h(ls2)); sumi1 = __lasx_xvadd_w(p_1, sumi1); sumi2 = __lasx_xvadd_w(p_2, sumi2); } From 14b9afe17f0af9f11e7e8c99b9f1b6a9207841cc Mon Sep 17 00:00:00 2001 From: Diego Devesa Date: Fri, 14 Feb 2025 15:33:52 +0100 Subject: [PATCH 36/58] cuda : add ampere to the list of default architectures (llama/11870) --- ggml/src/ggml-cuda/CMakeLists.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cuda/CMakeLists.txt b/ggml/src/ggml-cuda/CMakeLists.txt index 119fd39b8e4..682640b5208 100644 --- a/ggml/src/ggml-cuda/CMakeLists.txt +++ b/ggml/src/ggml-cuda/CMakeLists.txt @@ -15,9 +15,9 @@ if (CUDAToolkit_FOUND) if (GGML_NATIVE AND CUDAToolkit_VERSION VERSION_GREATER_EQUAL "11.6" AND CMAKE_VERSION VERSION_GREATER_EQUAL "3.24") set(CMAKE_CUDA_ARCHITECTURES "native") elseif(GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16) - set(CMAKE_CUDA_ARCHITECTURES "60;61;70;75") + set(CMAKE_CUDA_ARCHITECTURES "60;61;70;75;80") else() - set(CMAKE_CUDA_ARCHITECTURES "52;61;70;75") + set(CMAKE_CUDA_ARCHITECTURES "52;61;70;75;80") endif() endif() message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}") From 3151f2c0b008a65f25a451b8034ee2051ea15853 Mon Sep 17 00:00:00 2001 From: lhez Date: Fri, 14 Feb 2025 11:12:23 -0800 Subject: [PATCH 37/58] opencl: Fix rope and softmax (llama/11833) * opencl: fix `ROPE` * opencl: fix `SOFT_MAX` * Add fp16 variant * opencl: enforce subgroup size for `soft_max` --- ggml/src/ggml-opencl/ggml-opencl.cpp | 30 ++++- ggml/src/ggml-opencl/kernels/ggml-opencl.cl | 138 ++++++++++++++++++++ 2 files changed, 164 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index ed90e471ac0..7a0f94cf24c 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -143,6 +143,7 @@ struct ggml_backend_opencl_context { cl_kernel kernel_rms_norm; cl_kernel kernel_diag_mask_inf, kernel_diag_mask_inf_8; cl_kernel kernel_soft_max, kernel_soft_max_4; + cl_kernel kernel_soft_max_f16, kernel_soft_max_4_f16; cl_kernel kernel_get_rows_f32, kernel_get_rows_f16, kernel_get_rows_q4_0; cl_kernel kernel_rope_norm_f32, kernel_rope_norm_f16, kernel_rope_neox_f32, kernel_rope_neox_f16; cl_kernel kernel_cpy_f16_f16, kernel_cpy_f16_f32, kernel_cpy_f32_f16, kernel_cpy_f32_f32; @@ -614,6 +615,8 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { CL_CHECK((backend_ctx->kernel_diag_mask_inf_8 = clCreateKernel(backend_ctx->program, "kernel_diag_mask_inf_8", &err), err)); CL_CHECK((backend_ctx->kernel_soft_max = clCreateKernel(backend_ctx->program, "kernel_soft_max", &err), err)); CL_CHECK((backend_ctx->kernel_soft_max_4 = clCreateKernel(backend_ctx->program, "kernel_soft_max_4", &err), err)); + CL_CHECK((backend_ctx->kernel_soft_max_f16 = clCreateKernel(backend_ctx->program, "kernel_soft_max_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_soft_max_4_f16 = clCreateKernel(backend_ctx->program, "kernel_soft_max_4_f16", &err), err)); CL_CHECK((backend_ctx->kernel_rope_norm_f32 = clCreateKernel(backend_ctx->program, "kernel_rope_norm_f32", &err), err)); CL_CHECK((backend_ctx->kernel_rope_norm_f16 = clCreateKernel(backend_ctx->program, "kernel_rope_norm_f16", &err), err)); CL_CHECK((backend_ctx->kernel_rope_neox_f32 = clCreateKernel(backend_ctx->program, "kernel_rope_neox_f32", &err), err)); @@ -1044,8 +1047,16 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te return true; case GGML_OP_DIAG_MASK_INF: return op->ne[3] == 1; - case GGML_OP_ROPE: + case GGML_OP_ROPE: { + const int mode = ((const int32_t *) op->op_params)[2]; + if (mode & GGML_ROPE_TYPE_MROPE) { + return false; + } + if (mode & GGML_ROPE_TYPE_VISION) { + return false; + } return true; + } default: return false; } @@ -3666,6 +3677,8 @@ static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, c const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16); + // Local size must be wave size. Each workgroup is a wave, working on a row, // where a row corresponds to leading dimension. int nth = MIN(32, ne00); @@ -3683,9 +3696,17 @@ static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, c cl_kernel kernel; if (ne00%4 == 0) { - kernel = backend_ctx->kernel_soft_max_4; + if (use_f16) { + kernel = backend_ctx->kernel_soft_max_4_f16; + } else { + kernel = backend_ctx->kernel_soft_max_4; + } } else { - kernel = backend_ctx->kernel_soft_max; + if (use_f16) { + kernel = backend_ctx->kernel_soft_max_f16; + } else { + kernel = backend_ctx->kernel_soft_max; + } } CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); @@ -3766,7 +3787,8 @@ static void ggml_cl_rope(ggml_backend_t backend, const ggml_tensor * src0, const const int nb2 = dst ? dst->nb[2] : 0; const int nb3 = dst ? dst->nb[3] : 0; - GGML_ASSERT(ne10 == ne02); + GGML_ASSERT(ne10 % ne02 == 0); + GGML_ASSERT(ne10 >= ne02); int nth = MIN(64, ne00); diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl.cl b/ggml/src/ggml-opencl/kernels/ggml-opencl.cl index d1cdf709bab..d3cfb2f91e1 100644 --- a/ggml/src/ggml-opencl/kernels/ggml-opencl.cl +++ b/ggml/src/ggml-opencl/kernels/ggml-opencl.cl @@ -679,6 +679,9 @@ kernel void kernel_diag_mask_inf_8( //------------------------------------------------------------------------------ // softmax //------------------------------------------------------------------------------ +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif kernel void kernel_soft_max( global float * src0, ulong offset0, @@ -811,6 +814,141 @@ kernel void kernel_soft_max_4( } } +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_soft_max_f16( + global float * src0, + ulong offset0, + global half * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + float scale, + float max_bias, + float m0, + float m1, + int n_head_log2 +) { + src0 = (global float *)((global char *)src0 + offset0); + src1 = (global half *)((global char *)src1 + offset1); + dst = (global float *)((global char *)dst + offsetd); + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + global float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + global half * pmask = (global char *)src1 != (global char *)src0 ? src1 + i01*ne00 : 0; + global float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + float slope = 1.0f; + + // ALiBi + if (max_bias > 0.0f) { + int h = i02; + + float base = h < n_head_log2 ? m0 : m1; + int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + + slope = pow(base, exp); + } + + // parallel max + float lmax = -INFINITY; + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + lmax = fmax(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)); + } + float max = sub_group_reduce_max(lmax); + + // parallel sum + float lsum = 0.0f; + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max); + lsum += exp_psrc0; + // Remember the result of exp here. exp is expensive, so we really do not + // wish to compute it twice. + pdst[i00] = exp_psrc0; + } + + const float sum = sub_group_reduce_add(lsum); + + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + pdst[i00] /= sum; + } +} + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_soft_max_4_f16( + global float * src0, + ulong offset0, + global half * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + float scale, + float max_bias, + float m0, + float m1, + int n_head_log2 +) { + src0 = (global float *)((global char *)src0 + offset0); + src1 = (global half *)((global char *)src1 + offset1); + dst = (global float *)((global char *)dst + offsetd); + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + global float4 * psrc4 = (global float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + global half4 * pmask = (global char *)src1 != (global char *)src0 ? (global half4 *)(src1 + i01*ne00) : 0; + global float4 * pdst4 = (global float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + + float slope = 1.0f; + + // ALiBi + if (max_bias > 0.0f) { + int h = i02; + + float base = h < n_head_log2 ? m0 : m1; + int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + + slope = pow(base, exp); + } + + // parallel max + float4 lmax4 = -INFINITY; + for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { + lmax4 = fmax(lmax4, psrc4[i00]*scale + slope*(pmask ? convert_float4(pmask[i00]) : 0.0f)); + } + float lmax = fmax(fmax(lmax4.s0, lmax4.s1), fmax(lmax4.s2, lmax4.s3)); + + const float max = sub_group_reduce_max(lmax); + + // parallel sum + float4 lsum4 = 0.0f; + for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { + const float4 exp_psrc4 = exp((psrc4[i00]*scale + slope*(pmask ? convert_float4(pmask[i00]) : 0.0f)) - max); + lsum4 += exp_psrc4; + pdst4[i00] = exp_psrc4; + } + float lsum = lsum4.s0 + lsum4.s1 + lsum4.s2 + lsum4.s3; + + const float sum = sub_group_reduce_add(lsum); + + for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { + pdst4[i00] /= sum; + } +} + //------------------------------------------------------------------------------ // kernel_rope //------------------------------------------------------------------------------ From bcc9bef13368bf76350d4aae343ecba1a2fd6b6c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9my=20O?= Date: Sat, 15 Feb 2025 09:01:40 +0100 Subject: [PATCH 38/58] vulkan: initial support for IQ1_S and IQ1_M quantizations (llama/11528) * vulkan: initial support for IQ1_S and IQ1_M quantizations * vulkan: define MMV kernels for IQ1 quantizations * devops: increase timeout of Vulkan tests again * vulkan: simplify ifdef for init_iq_shmem --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 94 +++++++-- .../vulkan-shaders/copy_from_quant.comp | 2 +- .../vulkan-shaders/copy_to_quant.comp | 2 +- .../vulkan-shaders/dequant_funcs.comp | 88 ++++++++- .../vulkan-shaders/dequant_funcs_cm2.comp | 54 +++++ .../vulkan-shaders/dequant_iq1_m.comp | 42 ++++ .../vulkan-shaders/dequant_iq1_s.comp | 35 ++++ .../vulkan-shaders/flash_attn_cm2.comp | 2 +- .../vulkan-shaders/get_rows_quant.comp | 2 +- .../vulkan-shaders/mul_mat_vec.comp | 2 +- .../vulkan-shaders/mul_mat_vec_iq1_m.comp | 82 ++++++++ .../vulkan-shaders/mul_mat_vec_iq1_s.comp | 79 ++++++++ .../ggml-vulkan/vulkan-shaders/mul_mm.comp | 55 +++++- .../vulkan-shaders/mul_mm_cm2.comp | 2 +- .../src/ggml-vulkan/vulkan-shaders/types.comp | 187 ++++++++++++++++++ .../vulkan-shaders/vulkan-shaders-gen.cpp | 11 +- 16 files changed, 710 insertions(+), 29 deletions(-) create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 99d50afda2d..68f2ea14baf 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -1385,6 +1385,10 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec uint32_t lut_size = 0; switch (src0_type) { + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + lut_size = 2*2048; + break; case GGML_TYPE_IQ2_XXS: lut_size = 8*256; break; @@ -1623,6 +1627,8 @@ static void ggml_vk_load_shaders(vk_device& device) { //CREATE_FA(GGML_TYPE_Q4_K, q4_k) //CREATE_FA(GGML_TYPE_Q5_K, q5_k) //CREATE_FA(GGML_TYPE_Q6_K, q6_k) + //CREATE_FA(GGML_TYPE_IQ1_S, iq1_s) + //CREATE_FA(GGML_TYPE_IQ1_M, iq1_m) //CREATE_FA(GGML_TYPE_IQ2_XXS, iq2_xxs) //CREATE_FA(GGML_TYPE_IQ2_XS, iq2_xs) //CREATE_FA(GGML_TYPE_IQ2_S, iq2_s) @@ -1657,6 +1663,8 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) + CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ1_S].f16acc, matmul_iq1_s_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ1_M].f16acc, matmul_iq1_m_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) @@ -1676,6 +1684,8 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f16acc, matmul_id_iq1_s_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f16acc, matmul_id_iq1_m_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) @@ -1730,6 +1740,8 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S].f16acc, matmul_iq1_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M].f16acc, matmul_iq1_m_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); @@ -1749,6 +1761,8 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S].f16acc, matmul_iq1_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M].f16acc, matmul_iq1_m_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); @@ -1774,13 +1788,15 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f16acc, matmul_id_iq1_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f16acc, matmul_id_iq1_m_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); } else { CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); @@ -1793,13 +1809,15 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f16acc, matmul_id_iq1_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f16acc, matmul_id_iq1_m_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); } #undef CREATE_MM2 #undef CREATE_MM @@ -1842,6 +1860,8 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S].f16acc, matmul_iq1_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M].f16acc, matmul_iq1_m_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); @@ -1865,6 +1885,8 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f16acc, matmul_id_iq1_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f16acc, matmul_id_iq1_m_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); @@ -1906,13 +1928,15 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S].f32acc, matmul_iq1_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M].f32acc, matmul_iq1_m_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f32acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f32acc, matmul_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f32acc, matmul_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f32acc, matmul_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f32acc, matmul_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f32acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); @@ -1929,13 +1953,15 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f32acc, matmul_id_iq1_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f32acc, matmul_id_iq1_m_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc, matmul_id_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc, matmul_id_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); #undef CREATE_MM } @@ -1965,6 +1991,8 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_k_f32_f32_len, mul_mat_vec_q4_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q5_k_f32_f32_len, mul_mat_vec_q5_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q6_k_f32_f32_len, mul_mat_vec_q6_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq1_s_f32_f32_len, mul_mat_vec_iq1_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq1_m_f32_f32_len, mul_mat_vec_iq1_m_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xxs_f32_f32_len, mul_mat_vec_iq2_xxs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_XS][i], "mul_mat_vec_iq2_xs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xs_f32_f32_len, mul_mat_vec_iq2_xs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_s_f32_f32_len, mul_mat_vec_iq2_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); @@ -1985,6 +2013,8 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_k_f16_f32_len, mul_mat_vec_q4_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q5_k_f16_f32_len, mul_mat_vec_q5_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q6_k_f16_f32_len, mul_mat_vec_q6_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq1_s_f16_f32_len, mul_mat_vec_iq1_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq1_m_f16_f32_len, mul_mat_vec_iq1_m_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xxs_f16_f32_len, mul_mat_vec_iq2_xxs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_XS][i], "mul_mat_vec_iq2_xs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xs_f16_f32_len, mul_mat_vec_iq2_xs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_s_f16_f32_len, mul_mat_vec_iq2_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); @@ -2006,6 +2036,8 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32", mul_mat_vec_id_q4_k_f32_len, mul_mat_vec_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32", mul_mat_vec_id_q5_k_f32_len, mul_mat_vec_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32", mul_mat_vec_id_q6_k_f32_len, mul_mat_vec_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ1_S], "mul_mat_vec_id_iq1_s_f32", mul_mat_vec_id_iq1_s_f32_len, mul_mat_vec_id_iq1_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ1_M], "mul_mat_vec_id_iq1_m_f32", mul_mat_vec_id_iq1_m_f32_len, mul_mat_vec_id_iq1_m_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_XXS], "mul_mat_vec_id_iq2_xxs_f32", mul_mat_vec_id_iq2_xxs_f32_len, mul_mat_vec_id_iq2_xxs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_XS], "mul_mat_vec_id_iq2_xs_f32", mul_mat_vec_id_iq2_xs_f32_len, mul_mat_vec_id_iq2_xs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_S], "mul_mat_vec_id_iq2_s_f32", mul_mat_vec_id_iq2_s_f32_len, mul_mat_vec_id_iq2_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); @@ -2026,6 +2058,8 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q4_K], "dequant_q4_k", dequant_q4_k_len, dequant_q4_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q5_K], "dequant_q5_k", dequant_q5_k_len, dequant_q5_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q6_K], "dequant_q6_k", dequant_q6_k_len, dequant_q6_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ1_S], "dequant_iq1_s", dequant_iq1_s_len, dequant_iq1_s_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ1_M], "dequant_iq1_m", dequant_iq1_m_len, dequant_iq1_m_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ2_XXS], "dequant_iq2_xxs", dequant_iq2_xxs_len, dequant_iq2_xxs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ2_XS], "dequant_iq2_xs", dequant_iq2_xs_len, dequant_iq2_xs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ2_S], "dequant_iq2_s", dequant_iq2_s_len, dequant_iq2_s_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); @@ -2042,6 +2076,8 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_0], "get_rows_q5_0", get_rows_q5_0_len, get_rows_q5_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_1], "get_rows_q5_1", get_rows_q5_1_len, get_rows_q5_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q8_0], "get_rows_q8_0", get_rows_q8_0_len, get_rows_q8_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ1_S], "get_rows_iq1_s", get_rows_iq1_s_len, get_rows_iq1_s_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ1_M], "get_rows_iq1_m", get_rows_iq1_m_len, get_rows_iq1_m_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ2_XXS], "get_rows_iq2_xxs", get_rows_iq2_xxs_len, get_rows_iq2_xxs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ2_XS], "get_rows_iq2_xs", get_rows_iq2_xs_len, get_rows_iq2_xs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ2_S], "get_rows_iq2_s", get_rows_iq2_s_len, get_rows_iq2_s_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); @@ -2057,6 +2093,8 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_0], "get_rows_q5_0_f32", get_rows_q5_0_f32_len, get_rows_q5_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_1], "get_rows_q5_1_f32", get_rows_q5_1_f32_len, get_rows_q5_1_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q8_0], "get_rows_q8_0_f32", get_rows_q8_0_f32_len, get_rows_q8_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ1_S], "get_rows_iq1_s_f32", get_rows_iq1_s_f32_len, get_rows_iq1_s_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ1_M], "get_rows_iq1_m_f32", get_rows_iq1_m_f32_len, get_rows_iq1_m_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ2_XXS], "get_rows_iq2_xxs_f32", get_rows_iq2_xxs_f32_len, get_rows_iq2_xxs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ2_XS], "get_rows_iq2_xs_f32", get_rows_iq2_xs_f32_len, get_rows_iq2_xs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ2_S], "get_rows_iq2_s_f32", get_rows_iq2_s_f32_len, get_rows_iq2_s_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); @@ -3009,6 +3047,8 @@ static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XS: case GGML_TYPE_IQ2_S: @@ -3063,6 +3103,8 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XS: case GGML_TYPE_IQ2_S: @@ -3100,6 +3142,8 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XS: case GGML_TYPE_IQ2_S: @@ -3149,6 +3193,8 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XS: case GGML_TYPE_IQ2_S: @@ -3181,6 +3227,8 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XS: case GGML_TYPE_IQ2_S: @@ -8057,6 +8105,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XS: case GGML_TYPE_IQ2_S: @@ -8131,6 +8181,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm //case GGML_TYPE_Q4_K: //case GGML_TYPE_Q5_K: //case GGML_TYPE_Q6_K: + //case GGML_TYPE_IQ1_S: + //case GGML_TYPE_IQ1_M: //case GGML_TYPE_IQ2_XXS: //case GGML_TYPE_IQ2_XS: //case GGML_TYPE_IQ2_S: @@ -8154,6 +8206,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XS: case GGML_TYPE_IQ2_S: diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp b/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp index 9c9fe9626db..dbc7daa3328 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp @@ -12,7 +12,7 @@ layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; #endif void main() { -#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL) +#ifdef NEEDS_INIT_IQ_SHMEM init_iq_shmem(gl_WorkGroupSize); if (gl_LocalInvocationIndex.x != 0) { return; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp b/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp index 660811086d6..c813f14044e 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp @@ -217,7 +217,7 @@ void quantize(uint dst_idx, uint src_idx) #endif void main() { -#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL) +#ifdef NEEDS_INIT_IQ_SHMEM init_iq_shmem(gl_WorkGroupSize); if (gl_LocalInvocationIndex.x != 0) { return; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp index ecfdbfaa88c..10318e87660 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp @@ -88,6 +88,83 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) { } #endif +#if defined(DATA_A_IQ1_S) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint ib32 = iqs / 32; + const uint ib8 = iqs / 8; + const int i8 = int(iqs % 8); + const uint qh = data_a[a_offset + ib].qh[ib32]; + const uint qs = data_a[a_offset + ib].qs[ib8]; + const float dl = float(2 * bitfieldExtract(qh, 12, 3) + 1); + const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA; + const uint idxhi = bitfieldExtract(qh, 3 * int(ib8 & 3), 3); + const int16_t grid = int16_t(iq1s_grid[qs | (idxhi << 8)]); + // Signed bitfield extract. + const ivec2 gvec = ivec2( + bitfieldExtract(grid, 2 * (i8), 2), + bitfieldExtract(grid, 2 * (i8 + 1), 2) + ); + return dl * (vec2(gvec) + delta); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint ib32 = iqs / 32; + const uint ib8 = iqs / 8; + const int i8 = int(iqs % 8); + const uint qh = data_a[a_offset + ib].qh[ib32]; + const uint qs = data_a[a_offset + ib].qs[ib8]; + const float dl = 2 * bitfieldExtract(qh, 12, 3) + 1; + const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA; + const int16_t grid = int16_t(iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)]); + // Signed bitfield extract. + const ivec4 gvec = ivec4( + bitfieldExtract(grid, 2 * (i8), 2), + bitfieldExtract(grid, 2 * (i8 + 1), 2), + bitfieldExtract(grid, 2 * (i8 + 2), 2), + bitfieldExtract(grid, 2 * (i8 + 3), 2) + ); + return dl * (vec4(gvec) + delta); +} +#endif + +#if defined(DATA_A_IQ1_M) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint ib8 = iqs / 8; + const uint ib16 = iqs / 16; + const int i8 = int(iqs % 8); + const uint sc = data_a[a_offset + ib].scales[iqs / 64]; + const uint qs = data_a[a_offset + ib].qs[ib8]; + const uint qh = data_a[a_offset + ib].qh[ib16] >> (4 * (ib8 & 1)); + const float dl = 2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1; + const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA; + const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]); + // Signed bitfield extract. + const ivec2 gvec = ivec2( + bitfieldExtract(grid, 2 * (i8), 2), + bitfieldExtract(grid, 2 * (i8 + 1), 2) + ); + return dl * (vec2(gvec) + delta); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint ib8 = iqs / 8; + const uint ib16 = iqs / 16; + const int i8 = int(iqs % 8); + const uint sc = data_a[a_offset + ib].scales[iqs / 64]; + const uint qs = data_a[a_offset + ib].qs[ib8]; + const uint qh = data_a[a_offset + ib].qh[ib16] >> (4 * (ib8 & 1)); + const float dl = 2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1; + const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA; + const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]); + // Signed bitfield extract. + const ivec4 gvec = ivec4( + bitfieldExtract(grid, 2 * (i8), 2), + bitfieldExtract(grid, 2 * (i8 + 1), 2), + bitfieldExtract(grid, 2 * (i8 + 2), 2), + bitfieldExtract(grid, 2 * (i8 + 3), 2) + ); + return dl * (vec4(gvec) + delta); +} +#endif + #if defined(DATA_A_IQ2_XXS) vec2 dequantize(uint ib, uint iqs, uint a_offset) { const uint ib32 = iqs / 32; @@ -357,7 +434,16 @@ vec2 get_dm(uint ib, uint a_offset) { } #endif -#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL) +#if defined(DATA_A_IQ1_M) +vec2 get_dm(uint ib, uint a_offset) { + const uint16_t[4] scales = data_a[a_offset + ib].scales; + const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12; + const float d = float(unpackHalf2x16(s.x | (s.y << 4) | (s.z << 8) | (s.w << 12)).x); + return vec2(d, 0); +} +#endif + +#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL) vec2 get_dm(uint ib, uint a_offset) { return vec2(float(data_a[a_offset + ib].d), 0); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp index 0eba3742011..4770469eddc 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp @@ -301,6 +301,56 @@ float16_t dequantFuncQ6_K(const in decodeBufQ6_K bl, const in uint blockCoords[2 return ret; } +#if defined(DATA_A_IQ1_S) +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ1_S { + block_iq1_s block; +}; + +float16_t dequantFuncIQ1_S(const in decodeBufIQ1_S bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float16_t d = bl.block.d; + const uint idx = coordInBlock[1]; + + const uint ib32 = idx / 32; + const uint ib8 = idx / 8; + + const uint qh = bl.block.qh[ib32]; + const uint qs = bl.block.qs[ib8]; + const float dl = d * float(2 * bitfieldExtract(qh, 12, 3) + 1); + const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA; + const uint grid = iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)]; + + float16_t ret = float16_t(dl) * (float16_t(bitfieldExtract(int(grid), 2 * int(idx % 8), 2)) + float16_t(delta)); + return ret; +} +#endif + +#if defined(DATA_A_IQ1_M) +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ1_M { + block_iq1_m block; +}; + +float16_t dequantFuncIQ1_M(const in decodeBufIQ1_M bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const u16vec4 scales = u16vec4(bl.block.scales[0], bl.block.scales[1], bl.block.scales[2], bl.block.scales[3]) >> 12; + const float16_t d = uint16BitsToHalf(scales.x | (scales.y << 4) | (scales.z << 8) | (scales.w << 12)); + const uint idx = coordInBlock[1]; + + const uint ib8 = idx / 8; + const uint ib16 = idx / 16; + const int i8 = int(idx % 8); + const uint sc = bl.block.scales[ib8 / 8]; + const uint qs = bl.block.qs[ib8]; + const uint qh = bl.block.qh[ib16] >> (4 * (ib8 & 1)); + const float dl = 2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1; + const float delta = ((qh & 8) != 0) ? -IQ1S_DELTA : IQ1S_DELTA; + const uint grid = iq1s_grid[qs | ((qh & 7) << 8)]; + + float16_t ret = d * float16_t(dl) * (float16_t(bitfieldExtract(int(grid), 2 * i8, 2)) + float16_t(delta)); + return ret; +} +#endif + #if defined(DATA_A_IQ2_XXS) layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ2_XXS { block_iq2_xxs block; @@ -512,6 +562,10 @@ float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoor #define dequantFuncA dequantFuncQ5_K #elif defined(DATA_A_Q6_K) #define dequantFuncA dequantFuncQ6_K +#elif defined(DATA_A_IQ1_S) +#define dequantFuncA dequantFuncIQ1_S +#elif defined(DATA_A_IQ1_M) +#define dequantFuncA dequantFuncIQ1_M #elif defined(DATA_A_IQ2_XXS) #define dequantFuncA dequantFuncIQ2_XXS #elif defined(DATA_A_IQ2_XS) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp new file mode 100644 index 00000000000..39184ef5823 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp @@ -0,0 +1,42 @@ +#version 450 + +#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require + +#include "dequant_head.comp" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_iq1_m data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + // Each thread handles 1 subblock (32 values with 2 scales) + const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8; + + init_iq_shmem(gl_WorkGroupSize); + + if (ib >= p.nel / 256) { + return; + } + + const uint ib32 = gl_LocalInvocationID.x % 8; + const uint ib64 = ib32 / 2; + const uint b_idx = 256 * ib + 32 * ib32; + + const uint16_t[4] scales = data_a[ib].scales; + const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12; + const float d = float(unpackHalf2x16(s.x | (s.y << 4) | (s.z << 8) | (s.w << 12)).x); + + const uint sc = data_a[ib].scales[ib64]; + [[unroll]] for (int l = 0; l < 4; ++l) { + const uint ib16 = 2 * ib32 + l / 2; + const float dl = d * (2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1); + const uint qh = data_a[ib].qh[ib16] >> (4 * (l & 1)); + const uint qs = data_a[ib].qs[4 * ib32 + l]; + const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA; + const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]); + [[unroll]] for (int j = 0; j < 8; ++j) { + data_b[b_idx + 8 * l + j] = D_TYPE(dl * (bitfieldExtract(grid, 2*j, 2) + delta)); + } + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp new file mode 100644 index 00000000000..fd1e4e30d25 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp @@ -0,0 +1,35 @@ +#version 450 + +#include "dequant_head.comp" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_iq1_s data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + // Each thread handles 1 subblock (32 values with 2 scales) + const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8; + + init_iq_shmem(gl_WorkGroupSize); + + if (ib >= p.nel / 256) { + return; + } + + const uint ib32 = gl_LocalInvocationID.x % 8; + const uint b_idx = 256 * ib + 32 * ib32; + + uint qh = data_a[ib].qh[ib32]; + const float d = float(data_a[ib].d); + const float dl = d * float(2 * bitfieldExtract(qh, 12, 3) + 1); + const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA; + [[unroll]] for (uint l = 0; l < 4; ++l) { + const uint qs = data_a[ib].qs[4 * ib32 + l]; + const uint hi = bitfieldExtract(qh, 3 * int(l), 3); + const int16_t grid = int16_t(iq1s_grid[qs | (hi << 8)]); + [[unroll]] for (int j = 0; j < 8; ++j) { + data_b[b_idx + 8 * l + j] = D_TYPE(dl * (bitfieldExtract(grid, 2*j, 2) + delta)); + } + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp index ba88ce79a21..df30355f635 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp @@ -104,7 +104,7 @@ ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE ele #endif void main() { -#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL) +#ifdef NEEDS_INIT_IQ_SHMEM init_iq_shmem(gl_WorkGroupSize); #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp b/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp index c16a2a9f605..c9f855687dc 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp @@ -12,7 +12,7 @@ void main() { const uint i11 = (gl_GlobalInvocationID.z)/p.ne12; const uint i12 = (gl_GlobalInvocationID.z)%p.ne12; -#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL) +#ifdef NEEDS_INIT_IQ_SHMEM init_iq_shmem(gl_WorkGroupSize); #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp index d7e99727db1..31ecd9f81a8 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp @@ -133,7 +133,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { void main() { const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); -#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL) +#ifdef NEEDS_INIT_IQ_SHMEM init_iq_shmem(gl_WorkGroupSize); #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp new file mode 100644 index 00000000000..e4acbd4f962 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp @@ -0,0 +1,82 @@ +#version 450 +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#include "mul_mat_vec_base.comp" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + +void calc_superblock(const uint a_offset, const uint b_offset, const uint ib32, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) { + const uint y_idx = i * QUANT_K + 32 * ib32; + + uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i; + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const uint16_t[4] scales = data_a[ibi].scales; + const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12; + const float d = float(unpackHalf2x16(s.x | (s.y << 4) | (s.z << 8) | (s.w << 12)).x); + + const uint sc = data_a[ibi].scales[ib32 / 2] >> (6 * (ib32 & 1)); + [[unroll]] for (uint l = 0; l < 4; ++l) { + const uint qh = data_a[ibi].qh[2 * ib32 + l / 2] >> (4 * (l&1)); + const uint qs = data_a[ibi].qs[4 * ib32 + l]; + const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA; + const float dl = d * (2 * bitfieldExtract(sc, 3 * int(l / 2), 3) + 1); + + const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]); + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]); + vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]); + + FLOAT_TYPE sum = FLOAT_TYPE(0.0); + [[unroll]] for (int k = 0; k < 4; ++k) { + sum = fma(FLOAT_TYPE(b0[k]), bitfieldExtract(grid, 2 * k, 2) + delta, + fma(FLOAT_TYPE(b4[k]), bitfieldExtract(grid, 8 + 2 * k, 2) + delta, sum)); + } + temp[j][n] = fma(dl, sum, temp[j][n]); + } + } + ibi += num_blocks_per_row; + } +} + +void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { + uint a_offset, b_offset, d_offset; + get_offsets(a_offset, b_offset, d_offset); + + const uint num_blocks_per_row = p.ncols / QUANT_K; + + // 8 threads are used to process each block + const uint blocks_per_wg = gl_WorkGroupSize.x/8; + const uint tid = gl_LocalInvocationID.x; + const uint itid = tid % 8; // 0...7 + const uint ix = tid / 8; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } + } + + [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += blocks_per_wg) + calc_superblock(a_offset, b_offset, itid, i, num_blocks_per_row, first_row, num_rows); + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + + init_iq_shmem(gl_WorkGroupSize); + + // do NUM_ROWS at a time, unless there aren't enough remaining rows + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp new file mode 100644 index 00000000000..309da0991ae --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp @@ -0,0 +1,79 @@ +#version 450 +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#include "mul_mat_vec_base.comp" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + +void calc_superblock(const uint a_offset, const uint b_offset, const uint ib32, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) { + const uint y_idx = i * QUANT_K + 32 * ib32; + + uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i; + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const float d = float(data_a[ibi].d); + const uint qh = data_a[ibi].qh[ib32]; + const float dl = d * float(2 * bitfieldExtract(qh, 12, 3) + 1); + const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA; + + [[unroll]] for (uint l = 0; l < 4; ++l) { + const uint qs = data_a[ibi].qs[4 * ib32 + l]; + const uint idxhi = bitfieldExtract(qh, 3 * int(l), 3); + const int16_t grid = int16_t(iq1s_grid[qs | (idxhi << 8)]); + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]); + vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]); + + FLOAT_TYPE sum = FLOAT_TYPE(0.0); + [[unroll]] for (int k = 0; k < 4; ++k) { + sum = fma(FLOAT_TYPE(b0[k]), bitfieldExtract(grid, 2 * k, 2) + delta, + fma(FLOAT_TYPE(b4[k]), bitfieldExtract(grid, 8 + 2 * k, 2) + delta, sum)); + } + temp[j][n] = fma(dl, sum, temp[j][n]); + } + } + ibi += num_blocks_per_row; + } +} + +void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { + uint a_offset, b_offset, d_offset; + get_offsets(a_offset, b_offset, d_offset); + + const uint num_blocks_per_row = p.ncols / QUANT_K; + + // 8 threads are used to process each block + const uint blocks_per_wg = gl_WorkGroupSize.x/8; + const uint tid = gl_LocalInvocationID.x; + const uint itid = tid % 8; // 0...7 + const uint ix = tid / 8; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } + } + + [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += blocks_per_wg) + calc_superblock(a_offset, b_offset, itid, i, num_blocks_per_row, first_row, num_rows); + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + + init_iq_shmem(gl_WorkGroupSize); + + // do NUM_ROWS at a time, unless there aren't enough remaining rows + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp index 33b2234e71d..39657195cfc 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp @@ -6,6 +6,9 @@ #ifdef FLOAT16 #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require #endif +#if defined(DATA_A_IQ1_M) +#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require +#endif #ifdef COOPMAT #extension GL_KHR_cooperative_matrix : enable @@ -95,7 +98,7 @@ shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS]; #endif void main() { -#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL) +#ifdef NEEDS_INIT_IQ_SHMEM init_iq_shmem(gl_WorkGroupSize); #endif @@ -437,6 +440,56 @@ void main() { buf_a[buf_idx ] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi ] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi ] >> qhshift) & 3) << 4)) - 32)); buf_a[buf_idx + 1] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi + 1] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi + 1] >> qhshift) & 3) << 4)) - 32)); +#elif defined(DATA_A_IQ1_S) + const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; + + const uint ib = idx / 128; // 2 values per idx + const uint ib32 = (idx % 128) / 16; // 0..7 + const uint ib8 = (idx % 128) / 4; + const int i8 = 2 * int(idx % 4); + + const float d = float(data_a[ib].d); + const uint qh = data_a[ib].qh[ib32]; + const uint qs = data_a[ib].qs[ib8]; + const float dl = d * (2 * bitfieldExtract(qh, 12, 3) + 1); + const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA; + const int16_t grid = int16_t(iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)]); + + const ivec2 gvec = ivec2( + bitfieldExtract(grid, 2 * (i8), 2), + bitfieldExtract(grid, 2 * (i8 + 1), 2) + ); + const vec2 v = dl * (vec2(gvec) + delta); + + buf_a[buf_idx ] = FLOAT_TYPE(v.x); + buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); +#elif defined(DATA_A_IQ1_M) + const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; + + const uint ib = idx / 128; // 2 values per idx + const uint ib8 = (idx % 128) / 4; + const uint ib16 = ib8 / 2; + const int i8 = 2 * int(idx % 4); + + const uint16_t[4] scales = data_a[ib].scales; + const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12; + const float d = float(unpackHalf2x16(s.x | (s.y << 4) | (s.z << 8) | (s.w << 12)).x); + const uint sc = scales[ib8 / 8]; + const uint qs = data_a[ib].qs[ib8]; + const uint qh = data_a[ib].qh[ib16] >> (4 * (ib8 & 1)); + const float dl = d * (2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1); + const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA; + const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]); + const ivec2 gvec = ivec2( + bitfieldExtract(grid, 2 * (i8), 2), + bitfieldExtract(grid, 2 * (i8 + 1), 2) + ); + const vec2 v = dl * (vec2(gvec) + delta); + + buf_a[buf_idx ] = FLOAT_TYPE(v.x); + buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); #elif defined(DATA_A_IQ2_XXS) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp index 7e29bbfec7b..66dd2c860d8 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp @@ -106,7 +106,7 @@ D_TYPE perElemOpD(const in uint32_t r, const in uint32_t c, const in D_TYPE elem #endif void main() { -#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL) +#ifdef NEEDS_INIT_IQ_SHMEM init_iq_shmem(gl_WorkGroupSize); #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/types.comp b/ggml/src/ggml-vulkan/vulkan-shaders/types.comp index db643a54c8e..dfa16cda516 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/types.comp @@ -294,6 +294,187 @@ struct block_q6_K_packed16 // IQuants +#define QUANT_K_IQ1_S 256 +#define QUANT_R_IQ1_S 1 + +struct block_iq1_s { + float16_t d; + uint8_t qs[QUANT_K_IQ1_S/8]; + uint16_t qh[QUANT_K_IQ1_S/32]; +}; + +#define QUANT_K_IQ1_M 256 +#define QUANT_R_IQ1_M 1 + +struct block_iq1_m { + uint8_t qs[QUANT_K_IQ1_M/8]; + uint8_t qh[QUANT_K_IQ1_M/16]; + uint16_t scales[QUANT_K_IQ1_M/64]; +}; + +#if defined(DATA_A_IQ1_S) +#define QUANT_K QUANT_K_IQ1_S +#define QUANT_R QUANT_R_IQ1_S +#define A_TYPE block_iq1_s +#endif + +#if defined(DATA_A_IQ1_M) +#define QUANT_K QUANT_K_IQ1_M +#define QUANT_R QUANT_R_IQ1_M +#define A_TYPE block_iq1_m +#endif + +#if defined(DATA_A_IQ1_S) || defined(DATA_A_IQ1_M) +#define IQ1S_DELTA 0.125f +#define IQ1M_DELTA 0.125f + +// Packed IQ1S grid where every 2 vec8 are encoded on 32 bits (2 bits per coordinate). +const uint[1024] iq1s_grid_const = { + 0xfffdffff, 0xfff7fff0, 0xffccfff5, 0xffdfffc0, 0xffd7ffdd, 0xff30ffd5, 0xff03ff0c, 0xff10ff01, + 0xff7dff7f, 0xff75ff77, 0xff5fff40, 0xff57ff5d, 0xfcf3ff55, 0xfcccfcf0, 0xfcc1fcc3, 0xfcc5fcc4, + 0xfc3cfcd0, 0xfc34fc31, 0xfc00fc0d, 0xfc1cfc05, 0xfc11fc13, 0xfc70fc17, 0xfc43fc4c, 0xfc50fc41, + 0xfdfdfdff, 0xfdf5fdf7, 0xfddffdc0, 0xfdd7fddd, 0xfd30fdd5, 0xfd04fd0c, 0xfd14fd13, 0xfd7dfd7f, + 0xfd75fd77, 0xfd40fd4c, 0xfd5ffd44, 0xfd57fd5d, 0xf3ccfd55, 0xf3c1f3c3, 0xf33cf3d0, 0xf300f334, + 0xf313f305, 0xf34cf310, 0xf350f344, 0xf0f3f0fc, 0xf0f1f0f0, 0xf0c7f0c0, 0xf0d4f0c5, 0xf030f03f, + 0xf00ff035, 0xf003f00c, 0xf001f000, 0xf01ff004, 0xf010f01d, 0xf015f017, 0xf04cf07c, 0xf047f040, + 0xf05cf045, 0xf050f053, 0xf054f051, 0xf1c4f1c3, 0xf133f13c, 0xf10df10f, 0xf107f100, 0xf11cf11f, + 0xf114f111, 0xf14cf170, 0xf144f143, 0xf7fdf7ff, 0xf7f5f7f7, 0xf7dff7c0, 0xf7d7f7dd, 0xf730f7d5, + 0xf701f70c, 0xf77ff710, 0xf777f77d, 0xf740f775, 0xf75df75f, 0xf755f757, 0xf4ccf4f0, 0xf4c4f4c3, + 0xf4d0f4d3, 0xf40ff43c, 0xf400f40c, 0xf413f41c, 0xf44cf414, 0xf441f443, 0xf450f444, 0xf5fdf5ff, + 0xf5f5f5f7, 0xf5dff5c0, 0xf5d7f5dd, 0xf530f5d5, 0xf504f50c, 0xf510f51c, 0xf57df57f, 0xf577f570, + 0xf540f575, 0xf55df55f, 0xf555f557, 0xcfcccfcf, 0xcfc4cfc3, 0xcfd0cfd3, 0xcf33cf3c, 0xcf00cf0f, + 0xcf1ccf07, 0xcf10cf13, 0xcf4ccf14, 0xcf41cf43, 0xcf50cf5c, 0xccf3ccfc, 0xccf4ccf1, 0xcccdcccf, + 0xccc7ccc0, 0xccd3ccdc, 0xcc30ccd4, 0xcc0fcc35, 0xcc0dcc0c, 0xcc00cc03, 0xcc04cc01, 0xcc10cc1f, + 0xcc4dcc73, 0xcc5ccc40, 0xcdcccc53, 0xcdc1cdc3, 0xcd3fcdd0, 0xcd34cd31, 0xcd00cd0d, 0xcd05cd07, + 0xcd11cd13, 0xcd4ccd70, 0xcd41cd43, 0xc3fccd50, 0xc3f4c3f1, 0xc3c0c3c3, 0xc3c4c3c7, 0xc3d1c3dc, + 0xc330c33c, 0xc337c331, 0xc30cc335, 0xc300c303, 0xc304c301, 0xc310c31d, 0xc373c317, 0xc34fc374, + 0xc340c343, 0xc344c347, 0xc35cc345, 0xc350c353, 0xc0fdc354, 0xc0f5c0f0, 0xc0c3c0cc, 0xc0c1c0c0, + 0xc0dfc0c4, 0xc0d0c0dd, 0xc0d5c0d7, 0xc033c03c, 0xc031c030, 0xc00dc00c, 0xc000c003, 0xc004c001, + 0xc01cc005, 0xc010c013, 0xc014c011, 0xc07dc07f, 0xc070c073, 0xc075c077, 0xc04cc04f, 0xc040c043, + 0xc044c041, 0xc05fc045, 0xc050c05d, 0xc1f3c1fc, 0xc1f1c1f0, 0xc1c1c1c0, 0xc1c5c1c7, 0xc1d1c1dc, + 0xc13dc13f, 0xc130c133, 0xc135c137, 0xc100c10c, 0xc107c101, 0xc11cc104, 0xc110c113, 0xc114c117, + 0xc171c115, 0xc14dc175, 0xc153c140, 0xc7ccc154, 0xc7d0c7c1, 0xc733c73c, 0xc734c731, 0xc700c70f, + 0xc705c707, 0xc71cc71f, 0xc711c713, 0xc770c714, 0xc743c74c, 0xc4cfc750, 0xc4c0c4cd, 0xc4dcc4c5, + 0xc43dc4d0, 0xc430c433, 0xc40cc437, 0xc400c403, 0xc404c401, 0xc41fc405, 0xc415c410, 0xc44cc474, + 0xc440c44d, 0xc45cc447, 0xc454c451, 0xc5c1c5f4, 0xc5d1c5d3, 0xc531c533, 0xc50fc534, 0xc500c50d, + 0xc51cc507, 0xc514c511, 0xc54cc570, 0xc545c541, 0xdffddfff, 0xdff5dff7, 0xdfdfdfc0, 0xdfd0dfdd, + 0xdfd5dfd7, 0xdf0cdf30, 0xdf1cdf04, 0xdf7fdf10, 0xdf77df7d, 0xdf40df75, 0xdf5ddf5f, 0xdf57df50, + 0xdcf0df55, 0xdcc3dccc, 0xdcd0dcc4, 0xdc33dc3d, 0xdc00dc34, 0xdc05dc07, 0xdc13dc1c, 0xdc11dc10, + 0xdc4fdc70, 0xdc44dc41, 0xddfcdc50, 0xddf5ddf7, 0xddc0ddcc, 0xdddddddf, 0xddd5ddd7, 0xdd0cdd30, + 0xdd04dd01, 0xdd7cdd10, 0xdd75dd77, 0xdd40dd4c, 0xdd5ddd5f, 0xdd55dd57, 0xd3c3d3f0, 0xd3c4d3c1, + 0xd333d3d0, 0xd331d330, 0xd30dd334, 0xd307d300, 0xd311d305, 0xd34cd370, 0xd344d343, 0xd350d35c, + 0xd0c0d0f4, 0xd0d4d0dc, 0xd030d03f, 0xd00cd037, 0xd000d003, 0xd01dd004, 0xd017d010, 0xd04fd074, + 0xd040d043, 0xd045d047, 0xd053d05c, 0xd054d051, 0xd1cfd1f0, 0xd1c4d1cd, 0xd13cd1d0, 0xd100d134, + 0xd11cd11f, 0xd173d114, 0xd14fd171, 0xd7ffd145, 0xd7f7d7fd, 0xd7c0d7f5, 0xd7ddd7df, 0xd7d5d7d7, + 0xd70cd730, 0xd710d703, 0xd77dd77f, 0xd775d777, 0xd75dd75f, 0xd755d757, 0xd4ccd4f4, 0xd4c4d4c3, + 0xd431d4d0, 0xd40dd434, 0xd41cd400, 0xd411d413, 0xd470d414, 0xd441d44f, 0xd453d444, 0xd5ffd450, + 0xd5f7d5fd, 0xd5dfd5f5, 0xd5d7d5dd, 0xd530d5d5, 0xd501d50c, 0xd510d504, 0xd57dd57f, 0xd575d577, + 0xd55fd540, 0xd557d55d, 0x3ff0d555, 0x3fc13fcc, 0x3f343fd0, 0x3f003f0d, 0x3f053f07, 0x3f133f1c, + 0x3f433f11, 0x3f5c3f44, 0x3cff3f51, 0x3cf33cfc, 0x3cf43cf1, 0x3cc03ccd, 0x3cc73cc1, 0x3cdc3cc5, + 0x3cd43cd1, 0x3c373c30, 0x3c0c3c35, 0x3c003c03, 0x3c043c01, 0x3c103c05, 0x3c153c17, 0x3c733c7c, + 0x3c4f3c71, 0x3c403c4d, 0x3c5c3c5f, 0x3df03c5d, 0x3dc33dcc, 0x3dd03dc1, 0x3d0d3d3c, 0x3d053d00, + 0x3d143d13, 0x3d433d74, 0x33fc3d50, 0x33c433c0, 0x333033d4, 0x33353337, 0x3303330c, 0x33013300, + 0x331d331c, 0x33173310, 0x337c3315, 0x33743371, 0x334d334f, 0x335f3340, 0x3354335c, 0x30fd30fc, + 0x30f530f0, 0x30c330cc, 0x30c130c0, 0x30df30c4, 0x30d530d0, 0x3033303c, 0x30313030, 0x300f3034, + 0x3003300c, 0x30013000, 0x30043007, 0x3013301c, 0x30113010, 0x307d3014, 0x30703073, 0x304c3077, + 0x30403043, 0x30443041, 0x30503045, 0x30553057, 0x31f031fc, 0x31c331f4, 0x31c731c0, 0x31dc31c5, + 0x31d431d3, 0x313d313f, 0x31373130, 0x310c310f, 0x3100310d, 0x31043101, 0x3110311d, 0x317c3117, + 0x31753170, 0x31403143, 0x3153315c, 0x37f03151, 0x37c037cc, 0x37d037c5, 0x3734373d, 0x3700370f, + 0x371c3707, 0x37113713, 0x37703714, 0x3743374c, 0x37443741, 0x34fc3750, 0x34f134f0, 0x34cf34f5, + 0x34c034c3, 0x34dc34c7, 0x34d134d3, 0x3430343f, 0x340c3435, 0x3403340d, 0x34013400, 0x341f3404, + 0x3410341d, 0x34153411, 0x34743471, 0x3440344d, 0x34473441, 0x3453345c, 0x34543451, 0x353335c1, + 0x35343531, 0x35073500, 0x35133505, 0x35433514, 0x0ffc3550, 0x0ff00ff3, 0x0ff40ff1, 0x0fc00fcd, + 0x0fdc0fc5, 0x0fd40fd3, 0x0f300f3f, 0x0f0c0f37, 0x0f000f03, 0x0f040f01, 0x0f170f10, 0x0f740f71, + 0x0f470f40, 0x0f5c0f5f, 0x0f540f51, 0x0cf70cf0, 0x0cf50cf4, 0x0cc30ccc, 0x0cc10cc0, 0x0cc40cc7, + 0x0cd00cdf, 0x0cd70cd1, 0x0c3c0cd5, 0x0c300c33, 0x0c340c31, 0x0c0c0c0f, 0x0c030c0d, 0x0c010c00, + 0x0c040c07, 0x0c1c0c05, 0x0c100c13, 0x0c140c11, 0x0c700c7d, 0x0c430c4c, 0x0c410c40, 0x0c5f0c44, + 0x0c550c50, 0x0df10dfc, 0x0dc00dcd, 0x0ddc0dc5, 0x0d3d0dd3, 0x0d350d30, 0x0d030d0c, 0x0d010d00, + 0x0d1d0d04, 0x0d700d10, 0x0d4d0d4f, 0x0d440d40, 0x0d530d45, 0x03f003f3, 0x03c303cc, 0x03c103c0, + 0x03c403c7, 0x03d003dc, 0x03d503d7, 0x0333033c, 0x03310330, 0x03350334, 0x030c030f, 0x03000303, + 0x03070301, 0x03050304, 0x031d031c, 0x03100313, 0x03140311, 0x0377037f, 0x034c0375, 0x03400343, + 0x03440341, 0x0353035c, 0x03550350, 0x00fd00fc, 0x00f000f3, 0x00f400f1, 0x00cc00cf, 0x00c300cd, + 0x00c100c0, 0x00c500c4, 0x00d300dc, 0x00d100d0, 0x003f00d4, 0x003d003c, 0x00300033, 0x00370031, + 0x000f0034, 0x000d000c, 0x00000003, 0x00070001, 0x00050004, 0x001c001f, 0x00100013, 0x00170011, + 0x00150014, 0x0073007c, 0x00740070, 0x004f0075, 0x0043004c, 0x00410040, 0x00440047, 0x0053005c, + 0x00510050, 0x01ff0054, 0x01fd01fc, 0x01f101f3, 0x01f401f7, 0x01c301cc, 0x01c701c0, 0x01df01c4, + 0x01dd01dc, 0x01d001d3, 0x01d701d1, 0x013c01d4, 0x01310130, 0x01340137, 0x010f0135, 0x010d010c, + 0x01000103, 0x01070101, 0x01050104, 0x0113011c, 0x01140110, 0x0170017d, 0x01770171, 0x01750174, + 0x0140014c, 0x015d0145, 0x01510150, 0x01540157, 0x07f007f3, 0x07f407f1, 0x07c007cf, 0x07dc07c7, + 0x073007d5, 0x07350737, 0x0703070c, 0x07010700, 0x07040707, 0x071d071f, 0x07100713, 0x0774077d, + 0x074d074f, 0x07470740, 0x0754075c, 0x04fd04fc, 0x04f504f0, 0x04c304cc, 0x04c104c0, 0x04d004c4, + 0x0433043c, 0x04310430, 0x040f0434, 0x040d040c, 0x04000403, 0x04070401, 0x04050404, 0x0413041c, + 0x04110410, 0x047c0414, 0x04740470, 0x0443044c, 0x04410440, 0x04440447, 0x05f30450, 0x05c005f7, + 0x05df05c5, 0x05d105d0, 0x053005d4, 0x05340537, 0x0500050c, 0x05070501, 0x051d0504, 0x05170510, + 0x057c0515, 0x054d0575, 0x05410540, 0x05450547, 0x1ff0055c, 0x1fc11fc3, 0x1fd01fc4, 0x1f0f1f33, + 0x1f011f00, 0x1f051f07, 0x1f131f1c, 0x1f141f11, 0x1f411f7c, 0x1cfc1f50, 0x1cf11cf3, 0x1ccd1cf4, + 0x1cdc1cc0, 0x1cd11cdd, 0x1c301cd4, 0x1c0c1c34, 0x1c011c00, 0x1c101c04, 0x1c151c11, 0x1c751c73, + 0x1c401c4d, 0x1c511c5c, 0x1dcc1c54, 0x1dc41dc1, 0x1d3c1d3f, 0x1d001d31, 0x1d071d01, 0x1d701d1f, + 0x1d411d4c, 0x13cc1d50, 0x13c013cd, 0x13c513c1, 0x13d113dc, 0x133f13d4, 0x1330133d, 0x13351337, + 0x1303130c, 0x13011300, 0x13051304, 0x131d131f, 0x13731310, 0x13741370, 0x134d134f, 0x13401343, + 0x13471341, 0x135c1345, 0x13541353, 0x10f710f0, 0x10cc10f5, 0x10c110c0, 0x103310c4, 0x10311030, + 0x100f1034, 0x1003100c, 0x10011000, 0x101c1004, 0x10101013, 0x10141011, 0x10741071, 0x104c1075, + 0x10411040, 0x10451044, 0x1050105d, 0x10571051, 0x11f411fd, 0x11df11c0, 0x11d711d1, 0x113f11d4, + 0x11371130, 0x110c1135, 0x11001103, 0x11071101, 0x111f1105, 0x11171110, 0x117d117f, 0x11751170, + 0x11411143, 0x11441147, 0x1153115f, 0x11551151, 0x17c417c1, 0x173c17d0, 0x1700170d, 0x171c1705, + 0x17701714, 0x1747174c, 0x14fc1751, 0x14cf14f3, 0x14dc14c0, 0x14d114d3, 0x143f14d4, 0x1430143c, + 0x14371431, 0x1403140c, 0x14011400, 0x141f1404, 0x14151410, 0x1473147d, 0x14401475, 0x1453145c, + 0x14541450, 0x15c115cc, 0x153c15c7, 0x15341533, 0x1500150f, 0x15051507, 0x15101513, 0x15711514, + 0x15471543, 0x15511545, 0x7ffd7fff, 0x7ff57ff7, 0x7fdd7fdf, 0x7fd57fd7, 0x7f0f7f30, 0x7f037f0c, + 0x7f047f01, 0x7f7f7f10, 0x7f777f7d, 0x7f407f75, 0x7f5d7f5f, 0x7f557f57, 0x7ccc7cf0, 0x7cc17cc3, + 0x7cd07cc4, 0x7c337c3c, 0x7c0f7c34, 0x7c007c0d, 0x7c077c01, 0x7c137c04, 0x7c147c11, 0x7c747c70, + 0x7c417c43, 0x7c507c44, 0x7dfd7dff, 0x7df57df7, 0x7ddf7dc0, 0x7dd77ddd, 0x7d0c7dd5, 0x7d047d03, + 0x7d7f7d10, 0x7d777d7d, 0x7d407d75, 0x7d5d7d5f, 0x7d557d57, 0x73c473c3, 0x7333733c, 0x7300730c, + 0x731c7305, 0x73147313, 0x73447343, 0x70f470fc, 0x70c070cd, 0x70d170c5, 0x703f70d4, 0x7030703c, + 0x700c7037, 0x70007003, 0x70047001, 0x70107005, 0x70177011, 0x707c7015, 0x70717073, 0x704f7074, + 0x7040704d, 0x70517047, 0x71c171cc, 0x71d071c4, 0x7133713c, 0x71357134, 0x7100710f, 0x71057104, + 0x7111711c, 0x71707115, 0x7145714c, 0x77ff7153, 0x77f777fd, 0x77c077f5, 0x77dd77df, 0x77d577d7, + 0x7730773c, 0x7703770c, 0x77107704, 0x777f7714, 0x7777777d, 0x77407775, 0x775d775f, 0x77557757, + 0x74f174f0, 0x74c374cc, 0x74d074c1, 0x7433743c, 0x74347431, 0x740d740f, 0x74057400, 0x7413741c, + 0x74417470, 0x74507444, 0x75fd75ff, 0x75f575f7, 0x75df75c0, 0x75d775dd, 0x753075d5, 0x7503750c, + 0x757f7501, 0x7577757d, 0x75407575, 0x755d755f, 0x75557557, 0x4fcc4ff0, 0x4fc74fc1, 0x4fd04fc4, + 0x4f314f3c, 0x4f004f34, 0x4f054f07, 0x4f154f14, 0x4f4c4f70, 0x4f414f43, 0x4f504f44, 0x4cf34cfc, + 0x4cf44cf1, 0x4cc04ccf, 0x4cc54cc7, 0x4cd34cdc, 0x4cd44cd1, 0x4c304c3f, 0x4c0c4c0f, 0x4c004c03, + 0x4c044c01, 0x4c104c1d, 0x4c714c73, 0x4c404c4d, 0x4c5c4c47, 0x4c514c53, 0x4df04c54, 0x4dc34dcc, + 0x4dd04dc4, 0x4d314d33, 0x4d0f4d34, 0x4d004d0d, 0x4d114d07, 0x4d704d14, 0x4d414d43, 0x43fc4d54, + 0x43f143f3, 0x43c043cf, 0x43d143c7, 0x4335433f, 0x4303430c, 0x43014300, 0x43044307, 0x431c431f, + 0x4310431d, 0x43714373, 0x4343434d, 0x43474340, 0x4354435c, 0x40f040ff, 0x40f540f7, 0x40cc40cf, + 0x40c040c3, 0x40c440c1, 0x40d040dc, 0x40d540d4, 0x4033403c, 0x40314030, 0x400f4034, 0x400d400c, + 0x40004003, 0x40074001, 0x40054004, 0x4013401c, 0x40114010, 0x407c4014, 0x40774070, 0x404d404c, + 0x40404043, 0x40444041, 0x405f4045, 0x4050405d, 0x40554057, 0x41f341fc, 0x41c041cf, 0x41df41c4, + 0x41d441d1, 0x41374130, 0x410c4134, 0x4100410d, 0x41044101, 0x41174110, 0x4173417d, 0x41754174, + 0x4143414d, 0x41534140, 0x41544151, 0x47c147f0, 0x47d047c4, 0x4731473c, 0x470d470f, 0x47014700, + 0x47134705, 0x47704710, 0x4741474c, 0x47504744, 0x44f144f3, 0x44cf44f4, 0x44c044cd, 0x44c544c7, + 0x44dc44df, 0x44d144d3, 0x443d443f, 0x44374430, 0x440c4435, 0x44004403, 0x44044401, 0x4410441d, + 0x44154411, 0x4473447c, 0x444d444f, 0x44454440, 0x4451445c, 0x45c045f0, 0x453345d0, 0x45344531, + 0x4500450f, 0x451c4507, 0x454c4570, 0x45404543, 0x5fff4541, 0x5ff75ffd, 0x5fc05ff5, 0x5fdd5fdf, + 0x5fd55fd7, 0x5f0c5f30, 0x5f015f03, 0x5f7f5f04, 0x5f775f7d, 0x5f405f75, 0x5f5d5f5f, 0x5f555f57, + 0x5cf45cf0, 0x5cc35ccc, 0x5cc45cc1, 0x5c315cc5, 0x5c0c5c34, 0x5c075c00, 0x5c1c5c05, 0x5c705c13, + 0x5c4d5c4f, 0x5c445c41, 0x5df75dfd, 0x5dcf5df5, 0x5ddd5dc4, 0x5dd55dd7, 0x5d0c5d30, 0x5d045d01, + 0x5d7f5d10, 0x5d775d7d, 0x5d405d75, 0x5d5d5d5f, 0x5d555d57, 0x53d053c4, 0x5333533c, 0x5303530f, + 0x53075300, 0x531c5305, 0x53115310, 0x53145317, 0x50f15370, 0x50cf50f4, 0x50c050cd, 0x50d150c7, + 0x503d50d4, 0x500c5030, 0x50005003, 0x50045001, 0x50155010, 0x5073507c, 0x50715070, 0x504d5074, + 0x50475040, 0x51cc51f0, 0x51c551c1, 0x51d051dc, 0x51315133, 0x510d5135, 0x51015100, 0x511f5107, + 0x5171511d, 0x5140514f, 0x51445141, 0x5153515c, 0x57ff5151, 0x57f757fd, 0x57df57f5, 0x57d757dd, + 0x570c57d5, 0x57015703, 0x577f5704, 0x5777577d, 0x57405775, 0x575d575f, 0x57555757, 0x54c354f0, + 0x54dc54c4, 0x543c54d0, 0x5400540f, 0x541c5405, 0x54145411, 0x5441544f, 0x55fd55ff, 0x55f555f7, + 0x55dd55df, 0x55d555d7, 0x5503550c, 0x557f5501, 0x5577557d, 0x55405575, 0x555d555f, 0x55555557 +}; + +shared uint16_t iq1s_grid[2048]; + +#define NEEDS_INIT_IQ_SHMEM +void init_iq_shmem(uvec3 wgsize) +{ + // copy the table into shared memory and sync + for (uint i = gl_LocalInvocationIndex.x; i < iq1s_grid_const.length(); i += wgsize.x) { + u16vec2 g = unpack16(iq1s_grid_const[i]); + iq1s_grid[2*i+0] = g.x; + iq1s_grid[2*i+1] = g.y; + } + barrier(); +} +#endif + #define QUANT_K_IQ2_XXS 256 #define QUANT_R_IQ2_XXS 1 @@ -380,6 +561,7 @@ const uvec2[256] iq2xxs_grid_const = { shared uvec2 iq2xxs_grid[256]; +#define NEEDS_INIT_IQ_SHMEM void init_iq_shmem(uvec3 wgsize) { // copy the table into shared memory and sync @@ -547,6 +729,7 @@ const uvec2 iq2xs_grid_const[512] = { shared uvec2 iq2xs_grid[512]; +#define NEEDS_INIT_IQ_SHMEM void init_iq_shmem(uvec3 wgsize) { // copy the table into shared memory and sync @@ -836,6 +1019,7 @@ const uvec2 iq2s_grid_const[1024] = { shared uvec2 iq2s_grid[1024]; +#define NEEDS_INIT_IQ_SHMEM void init_iq_shmem(uvec3 wgsize) { // copy the table into shared memory and sync @@ -904,6 +1088,7 @@ const uint32_t iq3xxs_grid_const[256] = { shared uint32_t iq3xxs_grid[256]; +#define NEEDS_INIT_IQ_SHMEM void init_iq_shmem(uvec3 wgsize) { // copy the table into shared memory and sync @@ -1011,6 +1196,7 @@ const uint32_t iq3s_grid_const[512] = { shared uint32_t iq3s_grid[512]; +#define NEEDS_INIT_IQ_SHMEM void init_iq_shmem(uvec3 wgsize) { // copy the table into shared memory and sync @@ -1073,6 +1259,7 @@ const int8_t kvalues_iq4nl_const[16] = { shared FLOAT_TYPE kvalues_iq4nl[16]; +#define NEEDS_INIT_IQ_SHMEM void init_iq_shmem(uvec3 wgsize) { // copy the table into shared memory and sync diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 77e7e1148b4..601cd4e7d30 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -55,6 +55,8 @@ const std::vector type_names = { "q4_k", "q5_k", "q6_k", + "iq1_s", + "iq1_m", "iq2_xxs", "iq2_xs", "iq2_s", @@ -182,6 +184,13 @@ std::string to_uppercase(const std::string& input) { return result; } +bool string_starts_with(const std::string& str, const std::string& prefix) { + if (prefix.size() > str.size()) { + return false; + } + return std::equal(prefix.begin(), prefix.end(), str.begin()); +} + bool string_ends_with(const std::string& str, const std::string& suffix) { if (suffix.size() > str.size()) { return false; @@ -387,7 +396,7 @@ void process_shaders() { for (const auto& tname : type_names) { // mul mat vec std::string data_a_key = "DATA_A_" + to_uppercase(tname); - std::string shader = (string_ends_with(tname, "_k")) ? "mul_mat_vec_" + tname + ".comp" : "mul_mat_vec.comp"; + std::string shader = (string_ends_with(tname, "_k") || string_starts_with(tname, "iq1_")) ? "mul_mat_vec_" + tname + ".comp" : "mul_mat_vec.comp"; string_to_spv("mul_mat_vec_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}})); string_to_spv("mul_mat_vec_" + tname + "_f16_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}})); From c65a30105a3d17accd5c5c0bf9081b8f8748ffc2 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 15 Feb 2025 16:40:57 +0200 Subject: [PATCH 39/58] repo : update links to new url (llama/11886) * repo : update links to new url ggml-ci * cont : more urls ggml-ci --- ggml/include/ggml-cpu.h | 2 +- ggml/include/ggml-metal.h | 2 +- ggml/src/ggml-cpu/ggml-cpu.c | 4 ++-- ggml/src/ggml-metal/ggml-metal.m | 2 +- ggml/src/ggml-metal/ggml-metal.metal | 4 ++-- 5 files changed, 7 insertions(+), 7 deletions(-) diff --git a/ggml/include/ggml-cpu.h b/ggml/include/ggml-cpu.h index 3aa71badb5f..d23c6b262e2 100644 --- a/ggml/include/ggml-cpu.h +++ b/ggml/include/ggml-cpu.h @@ -8,7 +8,7 @@ extern "C" { #endif // the compute plan that needs to be prepared for ggml_graph_compute() - // since https://github.com/ggerganov/ggml/issues/287 + // since https://github.com/ggml-org/ggml/issues/287 struct ggml_cplan { size_t work_size; // size of work buffer, calculated by `ggml_graph_plan()` uint8_t * work_data; // work buffer, to be allocated by caller before calling to `ggml_graph_compute()` diff --git a/ggml/include/ggml-metal.h b/ggml/include/ggml-metal.h index 669c1f84ae6..a6106944234 100644 --- a/ggml/include/ggml-metal.h +++ b/ggml/include/ggml-metal.h @@ -45,7 +45,7 @@ GGML_BACKEND_API bool ggml_backend_is_metal(ggml_backend_t backend); GGML_DEPRECATED( GGML_BACKEND_API ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size), - "obsoleted by the new device interface - https://github.com/ggerganov/llama.cpp/pull/9713"); + "obsoleted by the new device interface - https://github.com/ggml-org/llama.cpp/pull/9713"); GGML_BACKEND_API void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data); diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 0cbf8318bed..dbef5df2111 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -1816,7 +1816,7 @@ inline static float ggml_silu_f32(float x) { #if __FINITE_MATH_ONLY__ #error "some routines in ggml.c require non-finite math arithmetics -- pass -fno-finite-math-only to the compiler to fix" -#error "ref: https://github.com/ggerganov/llama.cpp/pull/7154#issuecomment-2143844461" +#error "ref: https://github.com/ggml-org/llama.cpp/pull/7154#issuecomment-2143844461" #endif #if defined(__ARM_NEON) && defined(__aarch64__) @@ -7574,7 +7574,7 @@ UseGgmlGemm2:; int64_t nchunk1 = (nr1 + chunk_size - 1) / chunk_size; // If the chunking is poor for the number of threads on this setup, scrap the whole plan. Re-chunk it by thread. - // Also, chunking by thread was measured to have perform better on NUMA systems. See https://github.com/ggerganov/llama.cpp/pull/6915 + // Also, chunking by thread was measured to have perform better on NUMA systems. See https://github.com/ggml-org/llama.cpp/pull/6915 // In theory, chunking should be just as useful on NUMA and non NUMA systems, but testing disagreed with that. if (nchunk0 * nchunk1 < nth * 4 || ggml_is_numa()) { // distribute the thread work across the inner or outer loop based on which one is larger diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index 944d90af344..0add6b51a40 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -1983,7 +1983,7 @@ static void ggml_metal_encode_node( const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); // TODO: add ggml_metal_kargs struct - // TODO: optimize (see https://github.com/ggerganov/llama.cpp/pull/10238/commits/7941b6b9ec29a2866fec6fa6c51612515ca509f6) + // TODO: optimize (see https://github.com/ggml-org/llama.cpp/pull/10238/commits/7941b6b9ec29a2866fec6fa6c51612515ca509f6) [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; if (id_src1) { diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 44f04c909bf..da415184b17 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1058,7 +1058,7 @@ kernel void kernel_soft_max( } // This barrier fixes a failing test - // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335 + // ref: https://github.com/ggml-org/ggml/pull/621#discussion_r1425156335 threadgroup_barrier(mem_flags::mem_none); float sum = simd_sum(lsum); @@ -1163,7 +1163,7 @@ kernel void kernel_soft_max_4( const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3]; // This barrier fixes a failing test - // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335 + // ref: https://github.com/ggml-org/ggml/pull/621#discussion_r1425156335 threadgroup_barrier(mem_flags::mem_none); float sum = simd_sum(lsum); From c9dfdfa4401fd598bc773bcd89bb58f3315b3828 Mon Sep 17 00:00:00 2001 From: Adrian Kretz Date: Sat, 15 Feb 2025 19:39:20 +0100 Subject: [PATCH 40/58] metal : optimize dequant q6_K kernel (llama/11892) --- ggml/src/ggml-metal/ggml-metal.metal | 33 ++++++++++++++++++---------- 1 file changed, 21 insertions(+), 12 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index da415184b17..83e7ac9f411 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -373,24 +373,33 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg template void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) { const half d_all = xb->d; - device const uint8_t * ql = (device const uint8_t *)xb->ql; - device const uint8_t * qh = (device const uint8_t *)xb->qh; + device const uint16_t * ql = (device const uint16_t *)xb->ql; + device const uint16_t * qh = (device const uint16_t *)xb->qh; device const int8_t * scales = (device const int8_t *)xb->scales; - ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1); - qh = qh + 32*(il/8) + 16*(il&1); + ql = ql + 32*(il/8) + 16*((il/2)&1) + 8*(il&1); + qh = qh + 16*(il/8) + 8*(il&1); float sc = scales[(il%2) + 2 * ((il/2))]; il = (il/2) & 3; - const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3); - const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F; - const float coef = il>1 ? 1.f/16.f : 1.f; + const uint32_t kmask1 = il>1 ? (il>2 ? 0xC0C0C0C0 : 0x30303030) : (il>0 ? 0x0C0C0C0C : 0x03030303); + const uint32_t kmask2 = il>1 ? 0xF0F0F0F0 : 0x0F0F0F0F; const float ml = d_all * sc * 32.f; - const float dl = d_all * sc * coef; - for (int i = 0; i < 16; ++i) { - const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2)) - : ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4)); - reg[i/4][i%4] = dl * q - ml; + const float dl0 = d_all * sc; + const float dl1 = dl0 / 256.f; + const float dl2 = dl0 / (256.f * 256.f); + const float dl3 = dl0 / (256.f * 256.f * 256.f); + const uint8_t shr_h = il>2 ? 2 : 0; + const uint8_t shl_h = il>1 ? 0 : (il>0 ? 2 : 4); + const uint8_t shr_l = il>1 ? 4 : 0; + for (int i = 0; i < 4; ++i) { + const uint32_t low = (ql[2*i] | (uint32_t)(ql[2*i+1] << 16)) & kmask2; + const uint32_t high = (qh[2*i] | (uint32_t)(qh[2*i+1] << 16)) & kmask1; + const uint32_t q = ((high << shl_h) >> shr_h) | (low >> shr_l); + reg[i][0] = dl0 * ((half)(q & 0xFF)) - ml; + reg[i][1] = dl1 * ((float)(q & 0xFF00)) - ml; + reg[i][2] = dl2 * ((float)(q & 0xFF0000)) - ml; + reg[i][3] = dl3 * ((float)(q & 0xFF000000)) - ml; } } From 3dba9f7a11c90fe4f103c24d5c7eba25193f7a39 Mon Sep 17 00:00:00 2001 From: Hale Chan Date: Sun, 16 Feb 2025 14:50:26 +0800 Subject: [PATCH 41/58] metal : fix the crash caused by the lack of residency set support on Intel Macs. (llama/11904) --- ggml/src/ggml-metal/ggml-metal.m | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index 0add6b51a40..087e7f58149 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -24,7 +24,7 @@ #endif // create residency sets only on macOS >= 15.0 -#if TARGET_OS_OSX && __MAC_OS_X_VERSION_MAX_ALLOWED >= 150000 || \ +#if !TARGET_CPU_X86_64 && TARGET_OS_OSX && __MAC_OS_X_VERSION_MAX_ALLOWED >= 150000 || \ TARGET_OS_IOS && __IPHONE_OS_VERSION_MAX_ALLOWED >= 180000 || \ TARGET_OS_TV && __TV_OS_VERSION_MAX_ALLOWED >= 180000 || \ TARGET_OS_VISION && __VISION_OS_VERSION_MAX_ALLOWED >= 200000 From 1a9755441445c99a506d068c2379b4bfaed08b26 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Sun, 16 Feb 2025 01:52:23 -0600 Subject: [PATCH 42/58] vulkan: support multi/vision rope, and noncontiguous rope (llama/11902) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 58 ++++++++++++++---- .../ggml-vulkan/vulkan-shaders/rope_head.comp | 4 ++ .../vulkan-shaders/rope_multi.comp | 60 +++++++++++++++++++ .../ggml-vulkan/vulkan-shaders/rope_neox.comp | 34 ++++++----- .../ggml-vulkan/vulkan-shaders/rope_norm.comp | 34 ++++++----- .../vulkan-shaders/rope_vision.comp | 47 +++++++++++++++ .../vulkan-shaders/vulkan-shaders-gen.cpp | 8 +++ 7 files changed, 204 insertions(+), 41 deletions(-) create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 68f2ea14baf..88f31c1ef8b 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -251,6 +251,8 @@ struct vk_device_struct { vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512; vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16; vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16; + vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16; + vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16; vk_pipeline pipeline_argsort_f32; vk_pipeline pipeline_sum_rows_f32; vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16; @@ -494,6 +496,10 @@ struct vk_op_rope_push_constants { float corr_dims[2]; float theta_scale; uint32_t has_ff; + uint32_t ne02; + uint32_t s1; + uint32_t s2; + int32_t sections[4]; }; struct vk_op_soft_max_push_constants { @@ -2180,13 +2186,19 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32, "rope_multi_f32", rope_multi_f32_len, rope_multi_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f32, "rope_vision_f32", rope_vision_f32_len, rope_vision_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); if (device->float_controls_rte_fp16) { ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_rte_len, rope_norm_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_rte_len, rope_neox_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, "rope_multi_f16", rope_multi_f16_rte_len, rope_multi_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_rte_len, rope_vision_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); } else { ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, "rope_multi_f16", rope_multi_f16_len, rope_multi_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_len, rope_vision_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); } ggml_vk_create_pipeline(device, device->pipeline_argsort_f32, "argsort_f32", argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1024, 1, 1}, {}, 1); @@ -5307,6 +5319,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const { const int mode = ((const int32_t *) dst->op_params)[2]; const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; + const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; + const bool is_vision = mode == GGML_ROPE_TYPE_VISION; if (is_neox) { if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { @@ -5315,6 +5329,20 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { return ctx->device->pipeline_rope_neox_f16; } + } else if (is_mrope && !is_vision) { + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_rope_multi_f32; + } + if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { + return ctx->device->pipeline_rope_multi_f16; + } + } else if (is_vision) { + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_rope_vision_f32; + } + if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { + return ctx->device->pipeline_rope_vision_f16; + } } else { if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { return ctx->device->pipeline_rope_norm_f32; @@ -5385,6 +5413,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) { case GGML_OP_CLAMP: case GGML_OP_PAD: case GGML_OP_REPEAT: + case GGML_OP_ROPE: return true; default: return false; @@ -6149,7 +6178,7 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) { const int n_dims = ((int32_t *) dst->op_params)[1]; - // const int mode = ((int32_t *) dst->op_params)[2]; + const int mode = ((int32_t *) dst->op_params)[2]; // const int n_ctx = ((int32_t *) dst->op_params)[3]; const int n_ctx_orig = ((int32_t *) dst->op_params)[4]; const float freq_base = ((float *) dst->op_params)[5]; @@ -6158,16 +6187,24 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons const float attn_factor = ((float *) dst->op_params)[8]; const float beta_fast = ((float *) dst->op_params)[9]; const float beta_slow = ((float *) dst->op_params)[10]; + int sections[4] {}; + if (mode & GGML_ROPE_TYPE_MROPE) { + memcpy(sections, (int32_t *) dst->op_params + 11, sizeof(int)*4); + } float corr_dims[2]; ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); const float theta_scale = powf(freq_base, -2.0f/n_dims); + uint32_t s1 = src0->nb[1] / ggml_type_size(src0->type); + uint32_t s2 = src0->nb[2] / ggml_type_size(src0->type); + ggml_vk_op_f32(ctx, subctx, src0, src1, src2, dst, GGML_OP_ROPE, { (uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1], freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale, - src2 != nullptr, + src2 != nullptr, (uint32_t)src0->ne[2], s1, s2, + sections[0], sections[1], sections[2], sections[3], }, dryrun); } @@ -8264,16 +8301,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_REPEAT: return ggml_type_size(op->type) == sizeof(float) && ggml_type_size(op->src[0]->type) == sizeof(float); case GGML_OP_ROPE: - { - const int mode = ((const int32_t *) op->op_params)[2]; - if (mode & GGML_ROPE_TYPE_MROPE) { - return false; - } - if (mode & GGML_ROPE_TYPE_VISION) { - return false; - } - return ggml_is_contiguous(op->src[0]); - } case GGML_OP_NONE: case GGML_OP_RESHAPE: case GGML_OP_VIEW: @@ -8831,7 +8858,12 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { const float attn_factor = ((float *) tensor->op_params)[8]; const float beta_fast = ((float *) tensor->op_params)[9]; const float beta_slow = ((float *) tensor->op_params)[10]; - tensor_clone = ggml_rope_ext(ggml_ctx, src0_clone, src1_clone, src2_clone, n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); + if (mode & GGML_ROPE_TYPE_MROPE) { + int32_t *sections = ((int32_t *) tensor->op_params) + 11; + tensor_clone = ggml_rope_multi(ggml_ctx, src0_clone, src1_clone, src2_clone, n_dims, sections, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); + } else { + tensor_clone = ggml_rope_ext(ggml_ctx, src0_clone, src1_clone, src2_clone, n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); + } } else if (tensor->op == GGML_OP_UNARY) { switch (ggml_get_unary_op(tensor)) { case GGML_UNARY_OP_SILU: diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp index 574b51ca553..38075b75557 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp @@ -25,6 +25,10 @@ layout (push_constant) uniform parameter { float corr_dims[2]; float theta_scale; uint has_ff; + uint ne02; + uint s1; + uint s2; + int sections[4]; } p; float rope_yarn_ramp(const float low, const float high, const uint i0) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp new file mode 100644 index 00000000000..4f5b1a0ecaf --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp @@ -0,0 +1,60 @@ +#version 450 + +#include "rope_head.comp" + +void main() { + const uint i0 = 2*gl_GlobalInvocationID.y; + uint ne0 = p.ncols; + uint ne1 = p.p_delta_rows; + uint ne2 = p.ne02; + + if (i0 >= ne0) { + return; + } + + const uint row_dst = gl_GlobalInvocationID.x; + + if (i0 >= p.n_dims) { + const uint i = row_dst*ne0 + i0; + + data_d[i + 0] = data_a[i + 0]; + data_d[i + 1] = data_a[i + 1]; + + return; + } + + const uint row_x = row_dst % ne1; + const uint channel_x = row_dst / ne1; + + const uint idst = row_dst*ne0 + i0/2; + const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2; + + const int sect_dims = p.sections[0] + p.sections[1] + p.sections[2] + p.sections[3]; + const int sec_w = p.sections[1] + p.sections[0]; + const uint sector = (i0 / 2) % sect_dims; + + float theta_base = 0.0; + if (sector < p.sections[0]) { + theta_base = data_pos[channel_x]*pow(p.theta_scale, i0/2.0f); + } + else if (sector >= p.sections[0] && sector < sec_w) { + theta_base = data_pos[channel_x + ne2 * 1]*pow(p.theta_scale, i0/2.0f); + } + else if (sector >= sec_w && sector < sec_w + p.sections[2]) { + theta_base = data_pos[channel_x + ne2 * 2]*pow(p.theta_scale, i0/2.0f); + } + else if (sector >= sec_w + p.sections[2]) { + theta_base = data_pos[channel_x + ne2 * 3]*pow(p.theta_scale, i0/2.0f); + } + + const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f; + + float cos_theta, sin_theta; + rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta); + + const float x0 = float(data_a[ix + 0]); + const float x1 = float(data_a[ix + p.n_dims/2]); + + data_d[idst + 0] = D_TYPE(x0*cos_theta - x1*sin_theta); + data_d[idst + p.n_dims/2] = D_TYPE(x0*sin_theta + x1*cos_theta); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp index 83b46b69b2a..db775c456ca 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp @@ -3,15 +3,18 @@ #include "rope_head.comp" void main() { - const uint col = gl_GlobalInvocationID.y * 2; - const uint row = gl_GlobalInvocationID.x; + const uint i0 = 2*gl_GlobalInvocationID.y; + uint ne0 = p.ncols; + uint ne1 = p.p_delta_rows; - if (col >= p.ncols) { + if (i0 >= ne0) { return; } - if (col >= p.n_dims) { - const uint i = row*p.ncols + col; + const uint row_dst = gl_GlobalInvocationID.x; + + if (i0 >= p.n_dims) { + const uint i = row_dst*ne0 + i0; data_d[i + 0] = data_a[i + 0]; data_d[i + 1] = data_a[i + 1]; @@ -19,19 +22,22 @@ void main() { return; } - const uint i = row*p.ncols + col/2; - const uint i2 = row/p.p_delta_rows; + const uint row_x = row_dst % ne1; + const uint channel_x = row_dst / ne1; + + const uint idst = row_dst*ne0 + i0/2; + const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2; - const float theta_base = data_pos[i2] * pow(p.theta_scale, col/2.0f); + const float theta_base = data_pos[channel_x] * pow(p.theta_scale, i0/2.0f); - const float freq_factor = p.has_ff != 0 ? data_ff[col/2] : 1.0f; + const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f; float cos_theta, sin_theta; - rope_yarn(theta_base / freq_factor, col, cos_theta, sin_theta); + rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta); - const float x0 = float(data_a[i + 0]); - const float x1 = float(data_a[i + p.n_dims/2]); + const float x0 = float(data_a[ix + 0]); + const float x1 = float(data_a[ix + p.n_dims/2]); - data_d[i + 0] = D_TYPE(x0*cos_theta - x1*sin_theta); - data_d[i + p.n_dims/2] = D_TYPE(x0*sin_theta + x1*cos_theta); + data_d[idst + 0] = D_TYPE(x0*cos_theta - x1*sin_theta); + data_d[idst + p.n_dims/2] = D_TYPE(x0*sin_theta + x1*cos_theta); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp index e416ad93897..4ad35e549d7 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp @@ -3,15 +3,18 @@ #include "rope_head.comp" void main() { - const uint col = gl_GlobalInvocationID.y * 2; - const uint row = gl_GlobalInvocationID.x; + const uint i0 = 2*gl_GlobalInvocationID.y; + uint ne0 = p.ncols; + uint ne1 = p.p_delta_rows; - if (col >= p.ncols) { + if (i0 >= ne0) { return; } - if (col >= p.n_dims) { - const uint i = row*p.ncols + col; + const uint row_dst = gl_GlobalInvocationID.x; + + if (i0 >= p.n_dims) { + const uint i = row_dst*ne0 + i0; data_d[i + 0] = data_a[i + 0]; data_d[i + 1] = data_a[i + 1]; @@ -19,19 +22,22 @@ void main() { return; } - const uint i = row*p.ncols + col; - const uint i2 = row/p.p_delta_rows; + const uint row_x = row_dst % ne1; + const uint channel_x = row_dst / ne1; + + const uint idst = row_dst*ne0 + i0; + const uint ix = channel_x*p.s2 + row_x*p.s1 + i0; - const float theta_base = data_pos[i2] * pow(p.theta_scale, col/2.0f); + const float theta_base = data_pos[channel_x] * pow(p.theta_scale, i0/2.0f); - const float freq_factor = p.has_ff != 0 ? data_ff[col/2] : 1.0f; + const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f; float cos_theta, sin_theta; - rope_yarn(theta_base / freq_factor, col, cos_theta, sin_theta); + rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta); - const float x0 = float(data_a[i + 0]); - const float x1 = float(data_a[i + 1]); + const float x0 = float(data_a[ix + 0]); + const float x1 = float(data_a[ix + 1]); - data_d[i + 0] = D_TYPE(x0*cos_theta - x1*sin_theta); - data_d[i + 1] = D_TYPE(x0*sin_theta + x1*cos_theta); + data_d[idst + 0] = D_TYPE(x0*cos_theta - x1*sin_theta); + data_d[idst + 1] = D_TYPE(x0*sin_theta + x1*cos_theta); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp new file mode 100644 index 00000000000..cedacc4d144 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp @@ -0,0 +1,47 @@ +#version 450 + +#include "rope_head.comp" + +void main() { + const uint i0 = 2*gl_GlobalInvocationID.y; + uint ne0 = p.ncols; + uint ne1 = p.p_delta_rows; + uint ne2 = p.ne02; + + if (i0 >= ne0) { + return; + } + + const uint row_dst = gl_GlobalInvocationID.x; + + const uint row_x = row_dst % ne1; + const uint channel_x = row_dst / ne1; + + const uint idst = row_dst*ne0 + i0/2; + const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2; + + const int sect_dims = p.sections[0] + p.sections[1]; + const int sec_w = p.sections[1] + p.sections[0]; + const uint sector = (i0 / 2) % sect_dims; + + float theta_base = 0.0; + if (sector < p.sections[0]) { + const uint p0 = sector; + theta_base = data_pos[channel_x]*pow(p.theta_scale, p0); + } + else if (sector >= p.sections[0] && sector < sec_w) { + const uint p0 = sector - p.sections[0]; + theta_base = data_pos[channel_x + ne2]*pow(p.theta_scale, p0); + } + + const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f; + + float cos_theta, sin_theta; + rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta); + + const float x0 = float(data_a[ix + 0]); + const float x1 = float(data_a[ix + p.n_dims]); + + data_d[idst + 0] = D_TYPE(x0*cos_theta - x1*sin_theta); + data_d[idst + p.n_dims] = D_TYPE(x0*sin_theta + x1*cos_theta); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 601cd4e7d30..ba9163af27a 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -491,6 +491,14 @@ void process_shaders() { string_to_spv("rope_neox_f16", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); string_to_spv("rope_neox_f16_rte", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}); + string_to_spv("rope_multi_f32", "rope_multi.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("rope_multi_f16", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("rope_multi_f16_rte", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}); + + string_to_spv("rope_vision_f32", "rope_vision.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("rope_vision_f16", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("rope_vision_f16_rte", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}); + string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}}); string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); From 1c5790a8197294a7539b2c1862bfbe530c6a9e67 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9my=20O?= Date: Mon, 17 Feb 2025 07:55:57 +0100 Subject: [PATCH 43/58] vulkan: implement several ops relevant for ggml_opt (llama/11769) * vulkan: support memset_tensor * vulkan: support GGML_OP_SUM * vulkan: implement GGML_OP_ARGMAX * vulkan: implement GGML_OP_SUB * vulkan: implement GGML_OP_COUNT_EQUAL * vulkan: implement GGML_OP_OPT_STEP_ADAMW * vulkan: fix check_results RWKV_WKV6 crash and memory leaks * vulkan: implement GGML_OP_REPEAT_BACK * tests: remove invalid test-backend-ops REPEAT_BACK tests * vulkan: fix COUNT_EQUAL memset using a fillBuffer command --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 583 +++++++++++------- .../ggml-vulkan/vulkan-shaders/argmax.comp | 51 ++ .../vulkan-shaders/count_equal.comp | 31 + .../vulkan-shaders/opt_step_adamw.comp | 42 ++ .../vulkan-shaders/repeat_back.comp | 37 ++ ggml/src/ggml-vulkan/vulkan-shaders/sub.comp | 29 + .../vulkan-shaders/vulkan-shaders-gen.cpp | 7 + 7 files changed, 563 insertions(+), 217 deletions(-) create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/sub.comp diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 88f31c1ef8b..131ee1ea044 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -222,6 +222,7 @@ struct vk_device_struct { vk_pipeline pipeline_acc_f32; vk_pipeline pipeline_add_f32, pipeline_add_f32_norepeat; vk_pipeline pipeline_add_f16_f32_f16, pipeline_add_f16_f32_f16_norepeat; + vk_pipeline pipeline_sub_f32, pipeline_sub_f32_norepeat; vk_pipeline pipeline_mul_f32, pipeline_mul_f32_norepeat; vk_pipeline pipeline_div_f32, pipeline_div_f32_norepeat; vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32; @@ -232,7 +233,7 @@ struct vk_device_struct { vk_pipeline pipeline_cos_f32; vk_pipeline pipeline_clamp_f32; vk_pipeline pipeline_pad_f32; - vk_pipeline pipeline_repeat_f32; + vk_pipeline pipeline_repeat_f32, pipeline_repeat_back_f32; vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16; vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16; vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT]; @@ -255,10 +256,13 @@ struct vk_device_struct { vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16; vk_pipeline pipeline_argsort_f32; vk_pipeline pipeline_sum_rows_f32; + vk_pipeline pipeline_argmax_f32; + vk_pipeline pipeline_count_equal_i32; vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16; vk_pipeline pipeline_timestep_embedding_f32; vk_pipeline pipeline_pool2d_f32; vk_pipeline pipeline_rwkv_wkv6_f32; + vk_pipeline pipeline_opt_step_adamw_f32; // [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned} vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2]; @@ -2147,6 +2151,8 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_sub_f32, "sub_f32", sub_f32_len, sub_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1); + ggml_vk_create_pipeline(device, device->pipeline_sub_f32_norepeat, "sub_f32_norepeat", sub_f32_len, sub_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1); ggml_vk_create_pipeline(device, device->pipeline_mul_f32, "mul_f32", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1); ggml_vk_create_pipeline(device, device->pipeline_mul_f32_norepeat, "mul_f32_norepeat", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1); ggml_vk_create_pipeline(device, device->pipeline_div_f32, "div_f32", div_f32_len, div_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1); @@ -2169,6 +2175,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_repeat_f32, "repeat_f32", repeat_f32_len, repeat_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_repeat_back_f32, "repeat_back_f32", repeat_back_f32_len, repeat_back_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_gelu_f32, "gelu_f32", gelu_f32_len, gelu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_gelu_quick_f32, "gelu_quick_f32", gelu_quick_f32_len, gelu_quick_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); @@ -2203,8 +2210,12 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_argsort_f32, "argsort_f32", argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_argmax_f32, "argmax_f32", argmax_f32_len, argmax_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1); + ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32_len, im2col_f32_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); if (device->float_controls_rte_fp16) { ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte_len, im2col_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); @@ -2218,6 +2229,8 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1); + ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + for (auto &c : compiles) { c.wait(); } @@ -3783,6 +3796,12 @@ static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& sr } } +static void ggml_vk_buffer_memset_async(vk_context& ctx, vk_buffer& dst, size_t offset, uint32_t c, size_t size) { + VK_LOG_DEBUG("ggml_vk_buffer_memset_async(" << offset << ", " << c << ", " << size << ")"); + + ctx->s->buffer.fillBuffer(dst->buffer, offset, size, c); +} + static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, size_t size) { VK_LOG_DEBUG("ggml_vk_buffer_memset(" << offset << ", " << c << ", " << size << ")"); @@ -5189,6 +5208,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_f16_f32_f16_norepeat : ctx->device->pipeline_add_f16_f32_f16; } return nullptr; + case GGML_OP_SUB: + if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_sub_f32_norepeat : ctx->device->pipeline_sub_f32; + } + return nullptr; case GGML_OP_MUL: if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_mul_f32_norepeat : ctx->device->pipeline_mul_f32; @@ -5250,6 +5274,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_repeat_f32; } return nullptr; + case GGML_OP_REPEAT_BACK: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_repeat_back_f32; + } + return nullptr; case GGML_OP_CPY: case GGML_OP_CONT: case GGML_OP_DUP: @@ -5358,11 +5387,22 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_argsort_f32; } return nullptr; + case GGML_OP_SUM: case GGML_OP_SUM_ROWS: if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { return ctx->device->pipeline_sum_rows_f32; } return nullptr; + case GGML_OP_ARGMAX: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) { + return ctx->device->pipeline_argmax_f32; + } + return nullptr; + case GGML_OP_COUNT_EQUAL: + if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I64) { + return ctx->device->pipeline_count_equal_i32; + } + return nullptr; case GGML_OP_IM2COL: if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { return ctx->device->pipeline_im2col_f32; @@ -5386,6 +5426,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_rwkv_wkv6_f32; } return nullptr; + case GGML_OP_OPT_STEP_ADAMW: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_opt_step_adamw_f32; + } + return nullptr; case GGML_OP_LEAKY_RELU: if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { return ctx->device->pipeline_leaky_relu_f32; @@ -5403,6 +5448,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) { case GGML_OP_CPY: case GGML_OP_GET_ROWS: case GGML_OP_ADD: + case GGML_OP_SUB: case GGML_OP_MUL: case GGML_OP_DIV: case GGML_OP_CONCAT: @@ -5413,6 +5459,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) { case GGML_OP_CLAMP: case GGML_OP_PAD: case GGML_OP_REPEAT: + case GGML_OP_REPEAT_BACK: case GGML_OP_ROPE: return true; default: @@ -5627,6 +5674,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co case GGML_OP_RMS_NORM: case GGML_OP_SOFT_MAX: case GGML_OP_SUM_ROWS: + case GGML_OP_ARGMAX: { const uint32_t nr = ggml_nrows(src0); if (nr > 262144) { @@ -5637,6 +5685,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co elements = { nr, 1, 1 }; } } break; + case GGML_OP_SUM: + // We use GGML_OP_SUM_ROWS with 1 row. + elements = { 1, 1, 1 }; + break; case GGML_OP_GROUP_NORM: { const uint32_t num_groups = dst->op_params[0]; @@ -5683,6 +5735,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co elements = { N * OC * OH * OW, 1, 1}; } break; case GGML_OP_ADD: + case GGML_OP_SUB: case GGML_OP_DIV: case GGML_OP_MUL: case GGML_OP_SCALE: @@ -5692,6 +5745,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co case GGML_OP_CLAMP: case GGML_OP_PAD: case GGML_OP_REPEAT: + case GGML_OP_REPEAT_BACK: case GGML_OP_CPY: case GGML_OP_CONCAT: case GGML_OP_UPSCALE: @@ -5752,6 +5806,12 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co // im2col uses only src1 and dst buffers ggml_vk_sync_buffers(subctx); ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements); + } else if (op == GGML_OP_COUNT_EQUAL) { + ggml_vk_sync_buffers(subctx); + // count_equal assumes that destination buffer is initialized with zeroes + ggml_vk_buffer_memset_async(subctx, d_D, d_buf_offset, 0, d_sz); + ggml_vk_sync_buffers(subctx); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements); } else if (use_src2) { ggml_vk_sync_buffers(subctx); ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements); @@ -5814,6 +5874,21 @@ static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context& subctx, const }, dryrun); } +static void ggml_vk_sub(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t src1_type_size = ggml_type_size(src1->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SUB, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, 0.0f, 0, + }, dryrun); +} + static void ggml_vk_mul(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { const uint32_t src0_type_size = ggml_type_size(src0->type); const uint32_t src1_type_size = ggml_type_size(src1->type); @@ -5972,6 +6047,111 @@ static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ); } +static void ggml_vk_op_f32_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_push_constants&& pc, bool dryrun = false) { + const ggml_tensor * x = dst->src[0]; + const ggml_tensor * g = dst->src[1]; + const ggml_tensor * gm = dst->src[2]; + const ggml_tensor * gv = dst->src[3]; + const ggml_tensor * p = dst->src[4]; + + GGML_ASSERT(x->type == GGML_TYPE_F32); + GGML_ASSERT(g->type == GGML_TYPE_F32); + GGML_ASSERT(gm->type == GGML_TYPE_F32); + GGML_ASSERT(gv->type == GGML_TYPE_F32); + GGML_ASSERT(p->type == GGML_TYPE_F32); + GGML_ASSERT(dst->buffer != nullptr); + GGML_ASSERT(ggml_is_contiguous(x)); + GGML_ASSERT(ggml_is_contiguous(g)); + GGML_ASSERT(ggml_is_contiguous(gm)); + GGML_ASSERT(ggml_is_contiguous(gv)); + GGML_ASSERT(ggml_is_contiguous(p)); + GGML_ASSERT(ggml_are_same_shape(x, g)); + GGML_ASSERT(ggml_are_same_shape(x, gm)); + GGML_ASSERT(ggml_are_same_shape(x, gv)); + GGML_ASSERT(ggml_nelements(p) == 7); + + vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, g, gm, gv, dst, GGML_OP_OPT_STEP_ADAMW); + GGML_ASSERT(pipeline != nullptr); + + if (dryrun) { + ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1); + return; + } + + ggml_backend_vk_buffer_context * x_buf_ctx = (ggml_backend_vk_buffer_context *)x->buffer->context; + ggml_backend_vk_buffer_context * g_buf_ctx = (ggml_backend_vk_buffer_context *)g->buffer->context; + ggml_backend_vk_buffer_context * gm_buf_ctx = (ggml_backend_vk_buffer_context *)gm->buffer->context; + ggml_backend_vk_buffer_context * gv_buf_ctx = (ggml_backend_vk_buffer_context *)gv->buffer->context; + ggml_backend_vk_buffer_context * p_buf_ctx = (ggml_backend_vk_buffer_context *)p->buffer->context; + + ggml_vk_sync_buffers(subctx); + + vk_buffer d_X = nullptr, d_G = nullptr, d_GM = nullptr, d_GV = nullptr, d_P = nullptr; + size_t x_offset = 0, g_offset = 0, gm_offset = 0, gv_offset = 0, p_offset = 0; + bool X_uma = false, G_uma = false, GM_uma = false, GV_uma = false, P_uma = false; + + if (ctx->device->uma) { + ggml_vk_host_get(ctx->device, x->data, d_X, x_offset); + ggml_vk_host_get(ctx->device, g->data, d_G, g_offset); + ggml_vk_host_get(ctx->device, gm->data, d_GM, gm_offset); + ggml_vk_host_get(ctx->device, gv->data, d_GV, gv_offset); + ggml_vk_host_get(ctx->device, p->data, d_P, p_offset); + + X_uma = d_X != nullptr; + G_uma = d_G != nullptr; + GM_uma = d_GM != nullptr; + GV_uma = d_GV != nullptr; + P_uma = d_P != nullptr; + } + + if (!X_uma) { + d_X = x_buf_ctx->dev_buffer; + x_offset = vk_tensor_offset(x) + x->view_offs; + } + if (!G_uma) { + d_G = g_buf_ctx->dev_buffer; + g_offset = vk_tensor_offset(g) + g->view_offs; + } + if (!GM_uma) { + d_GM = gm_buf_ctx->dev_buffer; + gm_offset = vk_tensor_offset(gm) + gm->view_offs; + } + if (!GV_uma) { + d_GV = gv_buf_ctx->dev_buffer; + gv_offset = vk_tensor_offset(gv) + gv->view_offs; + } + if (!P_uma) { + d_P = p_buf_ctx->dev_buffer; + p_offset = vk_tensor_offset(p) + p->view_offs; + } + + const uint64_t x_size = ggml_nbytes(x); + const uint64_t g_size = ggml_nbytes(g); + const uint64_t gm_size = ggml_nbytes(gm); + const uint64_t gv_size = ggml_nbytes(gv); + const uint64_t p_size = ggml_nbytes(p); + + std::array elements = { (uint32_t)ggml_nelements(x), 1, 1 }; + + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { + vk_subbuffer{ d_X, x_offset, x_size }, + vk_subbuffer{ d_G, g_offset, g_size }, + vk_subbuffer{ d_GM, gm_offset, gm_size }, + vk_subbuffer{ d_GV, gv_offset, gv_size }, + vk_subbuffer{ d_P, p_offset, p_size }, + }, sizeof(vk_op_push_constants), &pc, elements); +} + +static void ggml_vk_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) { + const size_t n = ggml_nelements(dst->src[0]); + + ggml_vk_op_f32_opt_step_adamw( + ctx, subctx, dst, + { (uint32_t)n, 0, 0.0f, 0.0f }, + dryrun + ); +} + static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { int * op_params = (int *)dst->op_params; @@ -6105,6 +6285,20 @@ static void ggml_vk_repeat(ggml_backend_vk_context * ctx, vk_context& subctx, co }, dryrun); } +static void ggml_vk_repeat_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT_BACK, { + (uint32_t)ggml_nelements(dst), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, 0.0f, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + }, dryrun); +} + static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { const uint32_t src0_type_size = ggml_type_size(src0->type); const uint32_t dst_type_size = ggml_type_size(dst->type); @@ -6227,10 +6421,22 @@ static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, c }, dryrun); } +static void ggml_vk_sum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun); +} + static void ggml_vk_sum_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM_ROWS, { (uint32_t)src0->ne[0], 0, 0.0f, 0.0f }, dryrun); } +static void ggml_vk_argmax(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ARGMAX, { (uint32_t)src0->ne[0], 0, 0.0f, 0.0f }, dryrun); +} + +static void ggml_vk_count_equal(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_COUNT_EQUAL, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun); +} + static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { const int32_t s0 = dst->op_params[0]; const int32_t s1 = dst->op_params[1]; @@ -7095,9 +7301,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod } break; case GGML_OP_REPEAT: + case GGML_OP_REPEAT_BACK: case GGML_OP_GET_ROWS: case GGML_OP_ADD: case GGML_OP_ACC: + case GGML_OP_SUB: case GGML_OP_MUL: case GGML_OP_DIV: case GGML_OP_CONCAT: @@ -7120,13 +7328,17 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: case GGML_OP_ARGSORT: + case GGML_OP_SUM: case GGML_OP_SUM_ROWS: + case GGML_OP_ARGMAX: + case GGML_OP_COUNT_EQUAL: case GGML_OP_IM2COL: case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_POOL_2D: case GGML_OP_RWKV_WKV6: case GGML_OP_LEAKY_RELU: case GGML_OP_FLASH_ATTN_EXT: + case GGML_OP_OPT_STEP_ADAMW: break; default: std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(node->op) << std::endl; @@ -7147,9 +7359,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod } else { switch (node->op) { case GGML_OP_REPEAT: + case GGML_OP_REPEAT_BACK: case GGML_OP_ACC: case GGML_OP_GET_ROWS: case GGML_OP_ADD: + case GGML_OP_SUB: case GGML_OP_MUL: case GGML_OP_DIV: case GGML_OP_CONCAT: @@ -7171,7 +7385,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_SOFT_MAX: case GGML_OP_ROPE: case GGML_OP_ARGSORT: + case GGML_OP_SUM: case GGML_OP_SUM_ROWS: + case GGML_OP_ARGMAX: + case GGML_OP_COUNT_EQUAL: case GGML_OP_IM2COL: case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_POOL_2D: @@ -7192,6 +7409,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_REPEAT: ggml_vk_repeat(ctx, compute_ctx, src0, node, dryrun); + break; + case GGML_OP_REPEAT_BACK: + ggml_vk_repeat_back(ctx, compute_ctx, src0, node, dryrun); + break; case GGML_OP_ACC: ggml_vk_acc(ctx, compute_ctx, src0, src1, node, dryrun); @@ -7204,6 +7425,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_ADD: ggml_vk_add(ctx, compute_ctx, src0, src1, node, dryrun); + break; + case GGML_OP_SUB: + ggml_vk_sub(ctx, compute_ctx, src0, src1, node, dryrun); + break; case GGML_OP_MUL: ggml_vk_mul(ctx, compute_ctx, src0, src1, node, dryrun); @@ -7291,10 +7516,22 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_ARGSORT: ggml_vk_argsort(ctx, compute_ctx, src0, node, dryrun); + break; + case GGML_OP_SUM: + ggml_vk_sum(ctx, compute_ctx, src0, node, dryrun); + break; case GGML_OP_SUM_ROWS: ggml_vk_sum_rows(ctx, compute_ctx, src0, node, dryrun); + break; + case GGML_OP_ARGMAX: + ggml_vk_argmax(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_COUNT_EQUAL: + ggml_vk_count_equal(ctx, compute_ctx, src0, src1, node, dryrun); + break; case GGML_OP_IM2COL: ggml_vk_im2col(ctx, compute_ctx, src0, src1, node, dryrun); @@ -7329,6 +7566,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_RWKV_WKV6: ggml_vk_rwkv_wkv6(ctx, compute_ctx, node, dryrun); + break; + + case GGML_OP_OPT_STEP_ADAMW: + ggml_vk_opt_step_adamw(ctx, compute_ctx, node, dryrun); + break; default: return false; @@ -7380,6 +7622,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * case GGML_OP_ADD: case GGML_OP_ACC: case GGML_OP_GET_ROWS: + case GGML_OP_SUB: case GGML_OP_MUL: case GGML_OP_DIV: case GGML_OP_CONCAT: @@ -7405,13 +7648,18 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * case GGML_OP_TRANSPOSE: case GGML_OP_NONE: case GGML_OP_ARGSORT: + case GGML_OP_SUM: case GGML_OP_SUM_ROWS: + case GGML_OP_ARGMAX: + case GGML_OP_COUNT_EQUAL: case GGML_OP_IM2COL: case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_POOL_2D: case GGML_OP_RWKV_WKV6: case GGML_OP_LEAKY_RELU: case GGML_OP_REPEAT: + case GGML_OP_REPEAT_BACK: + case GGML_OP_OPT_STEP_ADAMW: buf = tensor->buffer; break; @@ -7603,6 +7851,15 @@ static void ggml_backend_vk_buffer_init_tensor(ggml_backend_buffer_t buffer, ggm } } +static void ggml_backend_vk_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { + VK_LOG_DEBUG("ggml_backend_vk_buffer_memset_tensor(" << buffer << ", " << tensor << ", " << value << ", " << offset << ", " << size << ")"); + ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context; + vk_buffer buf = buf_ctx->dev_buffer; + + uint32_t val32 = (uint32_t)value * 0x01010101; + ggml_vk_buffer_memset(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, val32, size); +} + static void ggml_backend_vk_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { VK_LOG_DEBUG("ggml_backend_vk_buffer_set_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")"); ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context; @@ -7647,7 +7904,7 @@ static ggml_backend_buffer_i ggml_backend_vk_buffer_interface = { /* .free_buffer = */ ggml_backend_vk_buffer_free_buffer, /* .get_base = */ ggml_backend_vk_buffer_get_base, /* .init_tensor = */ ggml_backend_vk_buffer_init_tensor, - /* .memset_tensor = */ NULL, + /* .memset_tensor = */ ggml_backend_vk_buffer_memset_tensor, /* .set_tensor = */ ggml_backend_vk_buffer_set_tensor, /* .get_tensor = */ ggml_backend_vk_buffer_get_tensor, /* .cpy_tensor = */ ggml_backend_vk_buffer_cpy_tensor, @@ -8300,6 +8557,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm } break; case GGML_OP_REPEAT: return ggml_type_size(op->type) == sizeof(float) && ggml_type_size(op->src[0]->type) == sizeof(float); + case GGML_OP_REPEAT_BACK: + return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32; case GGML_OP_ROPE: case GGML_OP_NONE: case GGML_OP_RESHAPE: @@ -8313,6 +8572,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm return ggml_is_contiguous(op->src[0]); case GGML_OP_ADD: case GGML_OP_ACC: + case GGML_OP_SUB: case GGML_OP_MUL: case GGML_OP_DIV: case GGML_OP_CONCAT: @@ -8326,12 +8586,16 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_DIAG_MASK_INF: case GGML_OP_SOFT_MAX: case GGML_OP_ARGSORT: + case GGML_OP_SUM: case GGML_OP_SUM_ROWS: + case GGML_OP_ARGMAX: + case GGML_OP_COUNT_EQUAL: case GGML_OP_IM2COL: case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_POOL_2D: case GGML_OP_RWKV_WKV6: case GGML_OP_LEAKY_RELU: + case GGML_OP_OPT_STEP_ADAMW: return true; default: return false; @@ -8604,8 +8868,6 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { ggml_tensor * src0 = tensor->src[0]; ggml_tensor * src1 = tensor->src[1]; - ggml_tensor * src2 = tensor->src[2]; - ggml_tensor * src3 = tensor->src[3]; struct ggml_init_params iparams = { /*.mem_size =*/ 2ul*1024ul*1024ul*1024ul, @@ -8615,238 +8877,113 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { struct ggml_context * ggml_ctx = ggml_init(iparams); - struct ggml_tensor * src0_clone = nullptr; - struct ggml_tensor * src1_clone = nullptr; - struct ggml_tensor * src2_clone = nullptr; - struct ggml_tensor * src3_clone = nullptr; - struct ggml_tensor * tensor_clone = nullptr; - - size_t src0_size; - size_t src1_size; - size_t src2_size; - size_t src3_size; - - void * src0_buffer = nullptr; - void * src1_buffer = nullptr; - void * src2_buffer = nullptr; - void * src3_buffer = nullptr; - - if (src0 != nullptr) { - src0_clone = ggml_dup_tensor(ggml_ctx, src0); - - src0_size = ggml_nbytes(src0); + std::array src_clone = {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr}; + std::array src_size = {0, 0, 0, 0, 0, 0}; + std::array src_buffer = {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr}; + const char * srci_name[6] = {"src0", "src1", "src2", "src3", "src4", "src5"}; - src0_buffer = malloc(src0_size); - src0_clone->data = src0_buffer; - if (ggml_backend_buffer_is_host(src0->buffer)) { - memcpy(src0_clone->data, src0->data, src0_size); - memcpy(src0_clone->nb, src0->nb, sizeof(size_t) * GGML_MAX_DIMS); - } else if (ggml_backend_buffer_is_vk(src0->buffer)) { - ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context; - vk_buffer& buffer_gpu = buf_ctx->dev_buffer; - uint64_t offset = vk_tensor_offset(src0) + src0->view_offs; - if (!ggml_is_contiguous(src0) && ggml_vk_dim01_contiguous(src0)) { - for (int i3 = 0; i3 < src0->ne[3]; i3++) { - for (int i2 = 0; i2 < src0->ne[2]; i2++) { - const int idx = i3*src0->ne[2] + i2; - ggml_vk_buffer_read(buffer_gpu, offset + idx * src0->nb[2], ((char *)src0_clone->data + idx * src0_clone->nb[2]), src0->ne[1] * src0->nb[1]); - } - } - - src0_clone->nb[0] = src0->nb[0]; - src0_clone->nb[1] = src0->nb[1]; - for (int i = 2; i < GGML_MAX_DIMS; i++) { - src0_clone->nb[i] = src0_clone->nb[i - 1]*src0_clone->ne[i - 1]; - } - } else { - if (offset + src0_size >= buffer_gpu->size) { - src0_size = buffer_gpu->size - offset; - } - ggml_vk_buffer_read(buffer_gpu, offset, src0_clone->data, src0_size); - memcpy(src0_clone->nb, src0->nb, sizeof(size_t) * GGML_MAX_DIMS); - } - } else { - GGML_ABORT("fatal error"); - } - - if (vk_output_tensor > 0 && vk_output_tensor == check_counter) { - ggml_vk_print_tensor(src0, "src0"); - } - } - if (src1 != nullptr) { - src1_clone = ggml_dup_tensor(ggml_ctx, src1); - - src1_size = ggml_nbytes(src1); - - src1_buffer = malloc(src1_size); - src1_clone->data = src1_buffer; - if (ggml_backend_buffer_is_host(src1->buffer)) { - memcpy(src1_clone->data, src1->data, src1_size); - memcpy(src1_clone->nb, src1->nb, sizeof(size_t) * GGML_MAX_DIMS); - } else if (ggml_backend_buffer_is_vk(src1->buffer)) { - ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context; - vk_buffer& buffer_gpu = buf_ctx->dev_buffer; - uint64_t offset = vk_tensor_offset(src1) + src1->view_offs; - if (!ggml_is_contiguous(src1) && ggml_vk_dim01_contiguous(src1)) { - for (int i3 = 0; i3 < src1->ne[3]; i3++) { - for (int i2 = 0; i2 < src1->ne[2]; i2++) { - const int idx = i3*src1->ne[2] + i2; - ggml_vk_buffer_read(buffer_gpu, offset + idx * src1->nb[2], ((char *)src1_clone->data + idx * src1_clone->nb[2]), src1->ne[1] * src1->nb[1]); - } - } - - src1_clone->nb[0] = src1->nb[0]; - src1_clone->nb[1] = src1->nb[1]; - for (int i = 2; i < GGML_MAX_DIMS; i++) { - src1_clone->nb[i] = src1_clone->nb[i - 1]*src1_clone->ne[i - 1]; - } - } else { - if (offset + src1_size >= buffer_gpu->size) { - src1_size = buffer_gpu->size - offset; - } - ggml_vk_buffer_read(buffer_gpu, offset, src1_clone->data, src1_size); - memcpy(src1_clone->nb, src1->nb, sizeof(size_t) * GGML_MAX_DIMS); - } - } else { - GGML_ABORT("fatal error"); - } - - if (vk_output_tensor > 0 && vk_output_tensor == check_counter) { - ggml_vk_print_tensor(src1, "src1"); - } - } - if (src2 != nullptr) { - src2_clone = ggml_dup_tensor(ggml_ctx, src2); - - src2_size = ggml_nbytes(src2); - - src2_buffer = malloc(src2_size); - src2_clone->data = src2_buffer; - if (ggml_backend_buffer_is_host(src2->buffer)) { - memcpy(src2_clone->data, src2->data, src2_size); - memcpy(src2_clone->nb, src2->nb, sizeof(size_t) * GGML_MAX_DIMS); - } else if (ggml_backend_buffer_is_vk(src2->buffer)) { - ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)src2->buffer->context; - vk_buffer& buffer_gpu = buf_ctx->dev_buffer; - uint64_t offset = vk_tensor_offset(src2) + src2->view_offs; - if (!ggml_is_contiguous(src2) && ggml_vk_dim01_contiguous(src2)) { - for (int i3 = 0; i3 < src2->ne[3]; i3++) { - for (int i2 = 0; i2 < src2->ne[2]; i2++) { - const int idx = i3*src2->ne[2] + i2; - ggml_vk_buffer_read(buffer_gpu, offset + idx * src2->nb[2], ((char *)src2_clone->data + idx * src2_clone->nb[2]), src2->ne[1] * src2->nb[1]); - } - } - - src2_clone->nb[0] = src2->nb[0]; - src2_clone->nb[1] = src2->nb[1]; - for (int i = 2; i < GGML_MAX_DIMS; i++) { - src2_clone->nb[i] = src2_clone->nb[i - 1]*src2_clone->ne[i - 1]; - } - } else { - if (offset + src2_size >= buffer_gpu->size) { - src2_size = buffer_gpu->size - offset; - } - ggml_vk_buffer_read(buffer_gpu, offset, src2_clone->data, src2_size); - memcpy(src2_clone->nb, src2->nb, sizeof(size_t) * GGML_MAX_DIMS); - } - } else { - GGML_ABORT("fatal error"); - } + struct ggml_tensor * tensor_clone = nullptr; - if (vk_output_tensor > 0 && vk_output_tensor == check_counter) { - ggml_vk_print_tensor(src2, "src2"); + for (int i = 0; i < 6; i++) { + ggml_tensor * srci = tensor->src[i]; + if (srci == nullptr) { + continue; } - } - if (src3 != nullptr) { - src3_clone = ggml_dup_tensor(ggml_ctx, src3); + ggml_tensor * srci_clone = ggml_dup_tensor(ggml_ctx, srci); + size_t srci_size = ggml_nbytes(srci); - src3_size = ggml_nbytes(src3); + src_clone[i] = srci_clone; + src_size[i] = ggml_nbytes(srci); + src_buffer[i] = malloc(srci_size); - src3_buffer = malloc(src3_size); - src3_clone->data = src3_buffer; - if (ggml_backend_buffer_is_host(src3->buffer)) { - memcpy(src3_clone->data, src3->data, src3_size); - memcpy(src3_clone->nb, src3->nb, sizeof(size_t) * GGML_MAX_DIMS); - } else if (ggml_backend_buffer_is_vk(src3->buffer)) { - ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)src3->buffer->context; + srci_clone->data = src_buffer[i]; + if (ggml_backend_buffer_is_host(srci->buffer)) { + memcpy(srci_clone->data, srci->data, srci_size); + memcpy(srci_clone->nb, srci->nb, sizeof(size_t) * GGML_MAX_DIMS); + } else if (ggml_backend_buffer_is_vk(srci->buffer)) { + ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)srci->buffer->context; vk_buffer& buffer_gpu = buf_ctx->dev_buffer; - uint64_t offset = vk_tensor_offset(src3) + src3->view_offs; - if (!ggml_is_contiguous(src3) && ggml_vk_dim01_contiguous(src3)) { - for (int i3 = 0; i3 < src3->ne[3]; i3++) { - for (int i2 = 0; i2 < src3->ne[2]; i2++) { - const int idx = i3*src3->ne[2] + i2; - ggml_vk_buffer_read(buffer_gpu, offset + idx * src3->nb[2], ((char *)src3_clone->data + idx * src3_clone->nb[2]), src3->ne[1] * src3->nb[1]); + uint64_t offset = vk_tensor_offset(srci) + srci->view_offs; + if (!ggml_is_contiguous(srci) && ggml_vk_dim01_contiguous(srci)) { + for (int i3 = 0; i3 < srci->ne[3]; i3++) { + for (int i2 = 0; i2 < srci->ne[2]; i2++) { + const int idx = i3*srci->ne[2] + i2; + ggml_vk_buffer_read(buffer_gpu, offset + idx * srci->nb[2], ((char *)srci_clone->data + idx * srci_clone->nb[2]), srci->ne[1] * srci->nb[1]); } } - src3_clone->nb[0] = src3->nb[0]; - src3_clone->nb[1] = src3->nb[1]; + srci_clone->nb[0] = srci->nb[0]; + srci_clone->nb[1] = srci->nb[1]; for (int i = 2; i < GGML_MAX_DIMS; i++) { - src3_clone->nb[i] = src3_clone->nb[i - 1]*src3_clone->ne[i - 1]; + srci_clone->nb[i] = srci_clone->nb[i - 1]*srci_clone->ne[i - 1]; } } else { - if (offset + src3_size >= buffer_gpu->size) { - src3_size = buffer_gpu->size - offset; + if (offset + srci_size >= buffer_gpu->size) { + srci_size = buffer_gpu->size - offset; } - ggml_vk_buffer_read(buffer_gpu, offset, src3_clone->data, src3_size); - memcpy(src3_clone->nb, src3->nb, sizeof(size_t) * GGML_MAX_DIMS); + ggml_vk_buffer_read(buffer_gpu, offset, srci_clone->data, srci_size); + memcpy(srci_clone->nb, srci->nb, sizeof(size_t) * GGML_MAX_DIMS); } } else { GGML_ABORT("fatal error"); } if (vk_output_tensor > 0 && vk_output_tensor == check_counter) { - ggml_vk_print_tensor(src3, "src3"); + ggml_vk_print_tensor(srci, srci_name[i]); } } if (tensor->op == GGML_OP_FLASH_ATTN_EXT) { const float *params = (const float *)tensor->op_params; - tensor_clone = ggml_flash_attn_ext(ggml_ctx, src0_clone, src1_clone, src2_clone, src3_clone, params[0], params[1], params[2]); + tensor_clone = ggml_flash_attn_ext(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3], params[0], params[1], params[2]); } else if (tensor->op == GGML_OP_MUL_MAT) { - tensor_clone = ggml_mul_mat(ggml_ctx, src0_clone, src1_clone); + tensor_clone = ggml_mul_mat(ggml_ctx, src_clone[0], src_clone[1]); } else if (tensor->op == GGML_OP_MUL_MAT_ID) { - tensor_clone = ggml_mul_mat_id(ggml_ctx, src0_clone, src1_clone, src2_clone); + tensor_clone = ggml_mul_mat_id(ggml_ctx, src_clone[0], src_clone[1], src_clone[2]); + } else if (tensor->op == GGML_OP_SUB) { + tensor_clone = ggml_sub(ggml_ctx, src_clone[0], src_clone[1]); } else if (tensor->op == GGML_OP_MUL) { - tensor_clone = ggml_mul(ggml_ctx, src0_clone, src1_clone); + tensor_clone = ggml_mul(ggml_ctx, src_clone[0], src_clone[1]); } else if (tensor->op == GGML_OP_DIV) { - tensor_clone = ggml_div(ggml_ctx, src0_clone, src1_clone); + tensor_clone = ggml_div(ggml_ctx, src_clone[0], src_clone[1]); } else if (tensor->op == GGML_OP_CONCAT) { - tensor_clone = ggml_concat(ggml_ctx, src0_clone, src1_clone, *(int *)tensor->op_params); + tensor_clone = ggml_concat(ggml_ctx, src_clone[0], src_clone[1], *(int *)tensor->op_params); } else if (tensor->op == GGML_OP_UPSCALE) { - tensor_clone = ggml_upscale_ext(ggml_ctx, src0_clone, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]); + tensor_clone = ggml_upscale_ext(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]); } else if (tensor->op == GGML_OP_SCALE) { - tensor_clone = ggml_scale(ggml_ctx, src0_clone, ((float *)tensor->op_params)[0]); + tensor_clone = ggml_scale(ggml_ctx, src_clone[0], ((float *)tensor->op_params)[0]); } else if (tensor->op == GGML_OP_SQR) { - tensor_clone = ggml_sqr(ggml_ctx, src0_clone); + tensor_clone = ggml_sqr(ggml_ctx, src_clone[0]); } else if (tensor->op == GGML_OP_SIN) { - tensor_clone = ggml_sin(ggml_ctx, src0_clone); + tensor_clone = ggml_sin(ggml_ctx, src_clone[0]); } else if (tensor->op == GGML_OP_COS) { - tensor_clone = ggml_cos(ggml_ctx, src0_clone); + tensor_clone = ggml_cos(ggml_ctx, src_clone[0]); } else if (tensor->op == GGML_OP_CLAMP) { - tensor_clone = ggml_clamp(ggml_ctx, src0_clone, ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]); + tensor_clone = ggml_clamp(ggml_ctx, src_clone[0], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]); } else if (tensor->op == GGML_OP_PAD) { - tensor_clone = ggml_pad(ggml_ctx, src0_clone, tensor->ne[0] - src0_clone->ne[0], tensor->ne[1] - src0_clone->ne[1], tensor->ne[2] - src0_clone->ne[2], tensor->ne[3] - src0_clone->ne[3]); + tensor_clone = ggml_pad(ggml_ctx, src_clone[0], tensor->ne[0] - src_clone[0]->ne[0], tensor->ne[1] - src_clone[0]->ne[1], tensor->ne[2] - src_clone[0]->ne[2], tensor->ne[3] - src_clone[0]->ne[3]); } else if (tensor->op == GGML_OP_REPEAT) { - tensor_clone = ggml_repeat(ggml_ctx, src0_clone, tensor); + tensor_clone = ggml_repeat(ggml_ctx, src_clone[0], tensor); + } else if (tensor->op == GGML_OP_REPEAT_BACK) { + tensor_clone = ggml_repeat_back(ggml_ctx, src_clone[0], tensor); } else if (tensor->op == GGML_OP_ADD) { - tensor_clone = ggml_add(ggml_ctx, src0_clone, src1_clone); + tensor_clone = ggml_add(ggml_ctx, src_clone[0], src_clone[1]); } else if (tensor->op == GGML_OP_ACC) { - tensor_clone = ggml_acc(ggml_ctx, src0_clone, src1_clone, tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3]); + tensor_clone = ggml_acc(ggml_ctx, src_clone[0], src_clone[1], tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3]); } else if (tensor->op == GGML_OP_NORM) { - tensor_clone = ggml_norm(ggml_ctx, src0_clone, *(float *)tensor->op_params); + tensor_clone = ggml_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params); } else if (tensor->op == GGML_OP_GROUP_NORM) { - tensor_clone = ggml_group_norm(ggml_ctx, src0_clone, *(int *)tensor->op_params, ((float *)tensor->op_params)[1]); + tensor_clone = ggml_group_norm(ggml_ctx, src_clone[0], *(int *)tensor->op_params, ((float *)tensor->op_params)[1]); } else if (tensor->op == GGML_OP_RMS_NORM) { - tensor_clone = ggml_rms_norm(ggml_ctx, src0_clone, *(float *)tensor->op_params); + tensor_clone = ggml_rms_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params); } else if (tensor->op == GGML_OP_SOFT_MAX) { if (src1 != nullptr) { - tensor_clone = ggml_soft_max_ext(ggml_ctx, src0_clone, src1_clone, ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]); + tensor_clone = ggml_soft_max_ext(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]); } else { - tensor_clone = ggml_soft_max(ggml_ctx, src0_clone); + tensor_clone = ggml_soft_max(ggml_ctx, src_clone[0]); } } else if (tensor->op == GGML_OP_DIAG_MASK_INF) { - tensor_clone = ggml_diag_mask_inf(ggml_ctx, src0_clone, *(int *)tensor->op_params); + tensor_clone = ggml_diag_mask_inf(ggml_ctx, src_clone[0], *(int *)tensor->op_params); } else if (tensor->op == GGML_OP_ROPE) { const int n_dims = ((int32_t *) tensor->op_params)[1]; const int mode = ((int32_t *) tensor->op_params)[2]; @@ -8860,26 +8997,26 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { const float beta_slow = ((float *) tensor->op_params)[10]; if (mode & GGML_ROPE_TYPE_MROPE) { int32_t *sections = ((int32_t *) tensor->op_params) + 11; - tensor_clone = ggml_rope_multi(ggml_ctx, src0_clone, src1_clone, src2_clone, n_dims, sections, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); + tensor_clone = ggml_rope_multi(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, sections, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); } else { - tensor_clone = ggml_rope_ext(ggml_ctx, src0_clone, src1_clone, src2_clone, n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); + tensor_clone = ggml_rope_ext(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); } } else if (tensor->op == GGML_OP_UNARY) { switch (ggml_get_unary_op(tensor)) { case GGML_UNARY_OP_SILU: - tensor_clone = ggml_silu(ggml_ctx, src0_clone); + tensor_clone = ggml_silu(ggml_ctx, src_clone[0]); break; case GGML_UNARY_OP_GELU: - tensor_clone = ggml_gelu(ggml_ctx, src0_clone); + tensor_clone = ggml_gelu(ggml_ctx, src_clone[0]); break; case GGML_UNARY_OP_GELU_QUICK: - tensor_clone = ggml_gelu_quick(ggml_ctx, src0_clone); + tensor_clone = ggml_gelu_quick(ggml_ctx, src_clone[0]); break; case GGML_UNARY_OP_RELU: - tensor_clone = ggml_relu(ggml_ctx, src0_clone); + tensor_clone = ggml_relu(ggml_ctx, src_clone[0]); break; case GGML_UNARY_OP_TANH: - tensor_clone = ggml_tanh(ggml_ctx, src0_clone); + tensor_clone = ggml_tanh(ggml_ctx, src_clone[0]); break; default: std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl; @@ -8887,28 +9024,34 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { } } else if (tensor->op == GGML_OP_CPY || tensor->op == GGML_OP_DUP) { if (src1 == nullptr) { - tensor_clone = ggml_dup(ggml_ctx, src0_clone); + tensor_clone = ggml_dup(ggml_ctx, src_clone[0]); tensor_clone->type = tensor->type; } else { - tensor_clone = ggml_cpy(ggml_ctx, src0_clone, src1_clone); + tensor_clone = ggml_cpy(ggml_ctx, src_clone[0], src_clone[1]); } } else if (tensor->op == GGML_OP_CONT) { - tensor_clone = ggml_cont_4d(ggml_ctx, src0_clone, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]); + tensor_clone = ggml_cont_4d(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]); } else if (tensor->op == GGML_OP_RESHAPE) { - tensor_clone = ggml_reshape_4d(ggml_ctx, src0_clone, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]); + tensor_clone = ggml_reshape_4d(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]); } else if (tensor->op == GGML_OP_VIEW) { - tensor_clone = ggml_view_4d(ggml_ctx, src0_clone, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], tensor->nb[1], tensor->nb[2], tensor->nb[3], ((int32_t *) tensor->op_params)[0]); + tensor_clone = ggml_view_4d(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], tensor->nb[1], tensor->nb[2], tensor->nb[3], ((int32_t *) tensor->op_params)[0]); } else if (tensor->op == GGML_OP_PERMUTE) { int32_t * params = (int32_t *)tensor->op_params; - tensor_clone = ggml_permute(ggml_ctx, src0_clone, params[0], params[1], params[2], params[3]); + tensor_clone = ggml_permute(ggml_ctx, src_clone[0], params[0], params[1], params[2], params[3]); } else if (tensor->op == GGML_OP_TRANSPOSE) { - tensor_clone = ggml_transpose(ggml_ctx, src0_clone); + tensor_clone = ggml_transpose(ggml_ctx, src_clone[0]); } else if (tensor->op == GGML_OP_GET_ROWS) { - tensor_clone = ggml_get_rows(ggml_ctx, src0_clone, src1_clone); + tensor_clone = ggml_get_rows(ggml_ctx, src_clone[0], src_clone[1]); } else if (tensor->op == GGML_OP_ARGSORT) { - tensor_clone = ggml_argsort(ggml_ctx, src0_clone, (ggml_sort_order) *(int *)tensor->op_params); + tensor_clone = ggml_argsort(ggml_ctx, src_clone[0], (ggml_sort_order) *(int *)tensor->op_params); + } else if (tensor->op == GGML_OP_SUM) { + tensor_clone = ggml_sum(ggml_ctx, src_clone[0]); } else if (tensor->op == GGML_OP_SUM_ROWS) { - tensor_clone = ggml_sum_rows(ggml_ctx, src0_clone); + tensor_clone = ggml_sum_rows(ggml_ctx, src_clone[0]); + } else if (tensor->op == GGML_OP_ARGMAX) { + tensor_clone = ggml_argmax(ggml_ctx, src_clone[0]); + } else if (tensor->op == GGML_OP_COUNT_EQUAL) { + tensor_clone = ggml_count_equal(ggml_ctx, src_clone[0], src_clone[1]); } else if (tensor->op == GGML_OP_IM2COL) { const int32_t s0 = tensor->op_params[0]; const int32_t s1 = tensor->op_params[1]; @@ -8918,11 +9061,11 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { const int32_t d1 = tensor->op_params[5]; const bool is_2D = tensor->op_params[6] == 1; - tensor_clone = ggml_im2col(ggml_ctx, src0_clone, src1_clone, s0, s1, p0, p1, d0, d1, is_2D, tensor->type); + tensor_clone = ggml_im2col(ggml_ctx, src_clone[0], src_clone[1], s0, s1, p0, p1, d0, d1, is_2D, tensor->type); } else if (tensor->op == GGML_OP_TIMESTEP_EMBEDDING) { const int32_t dim = tensor->op_params[0]; const int32_t max_period = tensor->op_params[1]; - tensor_clone = ggml_timestep_embedding(ggml_ctx, src0_clone, dim, max_period); + tensor_clone = ggml_timestep_embedding(ggml_ctx, src_clone[0], dim, max_period); } else if (tensor->op == GGML_OP_POOL_2D) { enum ggml_op_pool op = static_cast(tensor->op_params[0]); const int32_t k0 = tensor->op_params[1]; @@ -8932,13 +9075,17 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { const int32_t p0 = tensor->op_params[5]; const int32_t p1 = tensor->op_params[6]; - tensor_clone = ggml_pool_2d(ggml_ctx, src0_clone, op, k0, k1, s0, s1, p0, p1); + tensor_clone = ggml_pool_2d(ggml_ctx, src_clone[0], op, k0, k1, s0, s1, p0, p1); } else if (tensor->op == GGML_OP_LEAKY_RELU) { const float * op_params = (const float *)tensor->op_params; - tensor_clone = ggml_leaky_relu(ggml_ctx, src0_clone, op_params[0], false); + tensor_clone = ggml_leaky_relu(ggml_ctx, src_clone[0], op_params[0], false); } else if (tensor->op == GGML_OP_RWKV_WKV6) { - tensor_clone = ggml_rwkv_wkv6(ggml_ctx, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], - tensor->src[4], tensor->src[5]); + tensor_clone = ggml_rwkv_wkv6(ggml_ctx, src_clone[0], src_clone[1], + src_clone[2], src_clone[3], src_clone[4], src_clone[5]); + } else if (tensor->op == GGML_OP_OPT_STEP_ADAMW) { + src_clone[0]->flags = src0->flags; + tensor_clone = ggml_opt_step_adamw(ggml_ctx, src_clone[0], src_clone[1], + src_clone[2], src_clone[3], src_clone[4]); } else { std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl; @@ -8960,11 +9107,10 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { memcpy(comp_result, tensor_clone->data, comp_size); memcpy(comp_nb, tensor_clone->nb, sizeof(size_t) * GGML_MAX_DIMS); - if (src0 != nullptr) { - free(src0_buffer); - } - if (src1 != nullptr) { - free(src1_buffer); + for (int i = 0; i < 6; i++) { + if (src_buffer[i] != nullptr) { + free(src_buffer[i]); + } } ggml_free(ggml_ctx); @@ -9028,6 +9174,9 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) { } else if (tensor->type == GGML_TYPE_I32) { correct = *(int32_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]); result = *(int32_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]); + } else if (tensor->type == GGML_TYPE_I64) { + correct = *(int64_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]); + result = *(int64_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]); } else { std::cerr << "Results check not implemented for type " << ggml_type_name(tensor->type) << std::endl; } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp b/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp new file mode 100644 index 00000000000..eaf4da341e3 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp @@ -0,0 +1,51 @@ +#version 450 + +#include "generic_head.comp" +#include "types.comp" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +layout (constant_id = 0) const uint BLOCK_SIZE = 32; + +shared FLOAT_TYPE tmpmax[BLOCK_SIZE]; +shared uint tmp[BLOCK_SIZE]; + +void main() { + const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; + const uint col = gl_LocalInvocationID.x; + + if (col >= p.KX) { + return; + } + A_TYPE amax = data_a[row*p.KX + col]; + tmp[col] = col; + + for (uint i = col + BLOCK_SIZE; i < p.KX; i += BLOCK_SIZE) { + A_TYPE val = data_a[row*p.KX + i]; + if (val > amax) { + amax = val; + tmp[col] = i; + } + } + tmpmax[col] = amax; + + barrier(); + [[unroll]] for (int s = int(BLOCK_SIZE) / 2; s > 0; s >>= 1) { + if (col < s && col + s < p.KX) { + if (tmpmax[col] < tmpmax[col + s]) { + tmpmax[col] = tmpmax[col + s]; + tmp[col] = tmp[col + s]; + } + } + barrier(); + } + + if (col == 0) { + data_d[row] = D_TYPE(tmp[0]); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp b/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp new file mode 100644 index 00000000000..d9345497c73 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp @@ -0,0 +1,31 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable + +#include "types.comp" +#include "generic_head.comp" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) readonly buffer Y {B_TYPE data_b[];}; +layout (binding = 2) buffer D {D_TYPE data_d[];}; + +const uint CHUNK_SIZE = 512; + +void main() { + const uint base = gl_WorkGroupID.x * CHUNK_SIZE; + const uint col = gl_LocalInvocationID.x; + + uint count = 0; + [[unroll]] + for (uint i = 0; i < CHUNK_SIZE; i += gl_WorkGroupSize.x) { + const uint idx = base + i + col; + if (idx >= p.KX) { + break; + } + count += uint(data_a[idx] == data_b[idx]); + } + + atomicAdd(data_d[0], D_TYPE(count)); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp b/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp new file mode 100644 index 00000000000..e0214fe7645 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp @@ -0,0 +1,42 @@ +#version 450 + +#include "generic_head.comp" +#include "types.comp" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) buffer X {A_TYPE x[];}; +layout (binding = 1) readonly buffer G {A_TYPE grad[];}; +layout (binding = 2) buffer GM {A_TYPE gradm[];}; +layout (binding = 3) buffer GV {A_TYPE gradv[];}; +layout (binding = 4) readonly buffer P {float params[7];}; + +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + const float alpha = params[0]; + const float beta1 = params[1]; + const float beta2 = params[2]; + const float eps = params[3]; + const float wd = params[4]; + const float beta1h = params[5]; + const float beta2h = params[6]; + + const float gi = grad[i]; + const float gmi = gradm[i]*beta1 + gi*(1.0f - beta1); + const float gvi = gradv[i]*beta2 + gi*gi*(1.0f - beta2); + + gradm[i] = gmi; + gradv[i] = gvi; + + const float mh = gmi*beta1h; + const float vh = sqrt(gvi*beta2h) + eps; + + x[i] = x[i]*(1.0f - alpha*wd) - alpha*mh/vh; +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp b/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp new file mode 100644 index 00000000000..d86279934f1 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp @@ -0,0 +1,37 @@ +#version 450 + +#include "types.comp" +#include "generic_unary_head.comp" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +void main() { + const uint idx = get_idx(); + + if (idx >= p.ne) { + return; + } + + // Destination multi-index (inlined dst_idx) + const uint i13 = fastdiv(idx, p.ne1_012mp, p.ne1_012L); + const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10; + const uint i12 = fastdiv(idx - i13_offset, p.ne1_01mp, p.ne1_01L); + const uint i12_offset = i12*p.ne11*p.ne10; + const uint i11 = fastdiv(idx - i13_offset - i12_offset, p.ne1_0mp, p.ne1_0L); + const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10; + const uint d_idx = i13*p.nb13 + i12*p.nb12 + i11*p.nb11 + i10*p.nb10; + + // Accumulate from sources + A_TYPE acc = A_TYPE(0); + for (uint i3 = i13; i3 < p.ne03; i3 += p.ne13) { + for (uint i2 = i12; i2 < p.ne02; i2 += p.ne12) { + for (uint i1 = i11; i1 < p.ne01; i1 += p.ne11) { + for (uint i0 = i10; i0 < p.ne00; i0 += p.ne10) { + acc += data_a[i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0*p.nb00]; + } + } + } + } + + data_d[get_doffset() + d_idx] = D_TYPE(acc); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp b/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp new file mode 100644 index 00000000000..72353cc3296 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp @@ -0,0 +1,29 @@ +#version 450 + +#extension GL_EXT_shader_16bit_storage : require + +#include "types.comp" +#include "generic_binary_head.comp" + +const uint num_threads = 256; + +layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in; + +void main() { + uint idx = get_idx(); + + // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation + const uint num_iter = 2; + + [[unroll]] for (uint i = 0; i < num_iter; ++i) { + if (idx >= p.ne) { + continue; + } + uint i00, i01, i02, i03; + get_indices(idx, i00, i01, i02, i03); + + data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) - FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)])); + + idx += num_threads; + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index ba9163af27a..3128c3d507a 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -443,6 +443,8 @@ void process_shaders() { string_to_spv("add_f32", "add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); string_to_spv("add_f16_f32_f16", "add.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}}); + string_to_spv("sub_f32", "sub.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + string_to_spv("acc_f32", "acc.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {}); @@ -452,6 +454,7 @@ void process_shaders() { string_to_spv("div_f32", "div.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); string_to_spv("repeat_f32", "repeat.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("repeat_back_f32", "repeat_back.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("scale_f32", "scale.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); @@ -501,7 +504,9 @@ void process_shaders() { string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}}); + string_to_spv("argmax_f32", "argmax.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "int"}})); string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("count_equal_i32", "count_equal.comp", merge_maps(base_dict, {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}})); string_to_spv("im2col_f32", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("im2col_f32_f16", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}})); @@ -513,6 +518,8 @@ void process_shaders() { string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); + string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); + for (auto &c : compiles) { c.wait(); } From 06fed447af65f507bd98f5b0dc3ed8770266d147 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Mon, 17 Feb 2025 14:03:24 +0100 Subject: [PATCH 44/58] CUDA: use async data loading for FlashAttention (llama/11894) * CUDA: use async data loading for FlashAttention --------- Co-authored-by: Diego Devesa --- ggml/src/ggml-cuda/common.cuh | 21 +- ggml/src/ggml-cuda/cp-async.cuh | 46 ++ ggml/src/ggml-cuda/fattn-common.cuh | 15 +- ggml/src/ggml-cuda/fattn-mma-f16.cuh | 600 +++++++++++++++------------ ggml/src/ggml-cuda/mma.cuh | 483 ++++++++------------- ggml/src/ggml-cuda/mmq.cuh | 278 +++++++------ 6 files changed, 724 insertions(+), 719 deletions(-) create mode 100644 ggml/src/ggml-cuda/cp-async.cuh diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index fd4dcfa941d..4a92d35f9f4 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -41,12 +41,13 @@ #define CUDART_HMAX 11070 // CUDA 11.7, min. ver. for which __hmax and __hmax2 are known to work (may be higher than needed) #define CUDART_HMASK 12000 // CUDA 12.0, min. ver. for half2 -> uint mask comparisons -#define GGML_CUDA_CC_PASCAL 600 -#define GGML_CUDA_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products -#define GGML_CUDA_CC_VOLTA 700 -#define GGML_CUDA_CC_TURING 750 -#define GGML_CUDA_CC_AMPERE 800 -#define GGML_CUDA_CC_OFFSET_AMD 0x1000000 +#define GGML_CUDA_CC_PASCAL 600 +#define GGML_CUDA_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products +#define GGML_CUDA_CC_VOLTA 700 +#define GGML_CUDA_CC_TURING 750 +#define GGML_CUDA_CC_AMPERE 800 +#define GGML_CUDA_CC_ADA_LOVELACE 890 +#define GGML_CUDA_CC_OFFSET_AMD 0x1000000 // GCN/CNDA, wave size is 64 #define GGML_CUDA_CC_GCN4 (GGML_CUDA_CC_OFFSET_AMD + 0x803) // Tonga, Fiji, Polaris, minimum for fast fp16 @@ -199,6 +200,10 @@ typedef float2 dfloat2; #define NEW_MMA_AVAILABLE #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING +#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE +#define CP_ASYNC_AVAILABLE +#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE + #if !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1) #define FLASH_ATTN_AVAILABLE #endif // !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1) @@ -231,6 +236,10 @@ static bool new_mma_available(const int cc) { return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING; } +static bool cp_async_available(const int cc) { + return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE; +} + static constexpr __device__ int ggml_cuda_get_physical_warp_size() { #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) return __AMDGCN_WAVEFRONT_SIZE; diff --git a/ggml/src/ggml-cuda/cp-async.cuh b/ggml/src/ggml-cuda/cp-async.cuh new file mode 100644 index 00000000000..51aa41e7e60 --- /dev/null +++ b/ggml/src/ggml-cuda/cp-async.cuh @@ -0,0 +1,46 @@ +// Simplified API for asynchronous data loading. + +#include "common.cuh" + +// Copies data from global to shared memory, cg == cache global. +// Both the src and dst pointers must be aligned to 16 bit. +// Shared memory uses 32 bit addressing, the pointer is passed as unsigned int. +// Generic pointers can be converted to 32 bit shared memory pointers using __cvta_generic_to_shared. +// Only the 16 bit copy is exposed because 4 and 8 bit copies did not yield performance improvements. +template +static __device__ __forceinline__ void cp_async_cg_16(const unsigned int dst, const void * src) { + static_assert(preload == 0 || preload == 64 || preload == 128 || preload == 256, "bad preload"); +#ifdef CP_ASYNC_AVAILABLE +#if CUDART_VERSION >= 11040 + if (preload == 256) { + asm volatile("cp.async.cg.shared.global.L2::256B [%0], [%1], 16;" + : : "r"(dst), "l"(src)); + } else if (preload == 128) { + asm volatile("cp.async.cg.shared.global.L2::128B [%0], [%1], 16;" + : : "r"(dst), "l"(src)); + } else if (preload == 64) { + asm volatile("cp.async.cg.shared.global.L2::64B [%0], [%1], 16;" + : : "r"(dst), "l"(src)); + } else +#endif // CUDART_VERSION >= 11040 + { + asm volatile("cp.async.cg.shared.global.L2 [%0], [%1], 16;" + : : "r"(dst), "l"(src)); + } +#else + GGML_UNUSED(dst); + GGML_UNUSED(src); + NO_DEVICE_CODE; +#endif // CP_ASYNC_AVAILABLE +} + +// Makes each thread wait until its asynchronous data copies are done. +// This does NOT provide any additional synchronization. +// In particular, when copying data with multiple warps a call to __syncthreads will be needed. +static __device__ __forceinline__ void cp_async_wait_all() { +#ifdef CP_ASYNC_AVAILABLE + asm volatile("cp.async.wait_all;"); +#else + NO_DEVICE_CODE; +#endif // CP_ASYNC_AVAILABLE +} diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index d40ee2da418..fefbd319baf 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -716,7 +716,9 @@ void launch_fattn( ggml_cuda_pool & pool = ctx.pool(); cudaStream_t main_stream = ctx.stream(); - const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm; + const int id = ggml_cuda_get_device(); + const int cc = ggml_cuda_info().devices[id].cc; + const int nsm = ggml_cuda_info().devices[id].nsm; ggml_cuda_pool_alloc K_f16(pool); ggml_cuda_pool_alloc V_f16(pool); @@ -768,13 +770,14 @@ void launch_fattn( dim3 blocks_num; if (parallel_blocks == 0) { // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup. - const int tiles_nwaves = (ntiles_total - nsm - 1) / nsm; - const bool tiles_inefficient = 3*nsm < 2*tiles_nwaves*ntiles_total; - const bool short_context = K->ne[1] < 4096; + const int tiles_nwaves = (ntiles_total + 2*nsm - 1) / (2*nsm); + const int tiles_efficiency_percent = 100 * ntiles_total / (2*nsm*tiles_nwaves); const int nblocks_stream_k = 2*nsm; - blocks_num.x = short_context && !tiles_inefficient ? ntiles_total : nblocks_stream_k; + const bool use_stream_k = tiles_efficiency_percent < 75 || cc >= GGML_CUDA_CC_ADA_LOVELACE; + + blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_total; blocks_num.y = 1; blocks_num.z = 1; @@ -827,7 +830,7 @@ void launch_fattn( CUDA_CHECK(cudaGetLastError()); if constexpr (parallel_blocks == 0) { - if (blocks_num.x % ntiles_total != 0) { // Fixup is only needed if the SMs work on fractional tiles. + if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles. const dim3 block_dim_combine(D, 1, 1); const dim3 blocks_num_combine = blocks_num; diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index 05bc91a3b8d..d777f5413ed 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -1,242 +1,195 @@ #include "common.cuh" +#include "cp-async.cuh" #include "mma.cuh" #include "fattn-common.cuh" -template -static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( - const float2 * const __restrict__ Q_f2, - const half2 * const __restrict__ K_h2, - const half2 * const __restrict__ V_h2, - const half * const __restrict__ maskh, - float2 * const __restrict__ dstk, - float2 * const __restrict__ dstk_fixup, - const float scale, - const float slope, - const float logit_softcap, - const int ne00, - const int ne01, - const int ne02, - const int ne03, - const int ne10, - const int ne11, - const int ne12, - const int ne13, - const int ne31, - const int nb31, - const int nb01, - const int nb02, - const int nb03, - const int nb11, - const int nb12, - const int nb13, - const int nb21, - const int nb22, - const int nb23, - const int ne0, - const int ne1, - const int ne2, - const int ne3, - const int jt, - const int kb0_start, - const int kb0_stop) { -#ifdef NEW_MMA_AVAILABLE - //In this kernel Q, K, V are matrices while i, j, k are matrix indices. +using namespace ggml_cuda_mma; - typedef mma_A_I16K8 mma_A; - typedef mma_B_J8K8 mma_B; - typedef mma_C_I16J8 mma_C_KQ; - typedef mma_C_I16J8 mma_C_VKQ; - - static_assert(nwarps*mma_B::J % ncols == 0, "bad nwarps"); - constexpr int np = nwarps*mma_B::J / ncols; // Number of parallel CUDA warps per Q column. - - static_assert(D % nwarps == 0, "bad D"); - static_assert(KQ_stride % nwarps == 0, "bad KQ_stride"); +typedef tile<16, 8, half2> tile_A; +typedef tile< 8, 8, half2> tile_B; +typedef tile<16, 8, float> tile_C_KQ; +typedef tile<16, 4, half2> tile_C_VKQ; +template +static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( + const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int stride_KV) { constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts. - extern __shared__ half2 tile_KV[]; // Temporary shared buffer for loading K/V data with KQ_stride*D logical elements. - const int stride_Q = nb01 / sizeof(float2); - const int stride_KV = nb11 / sizeof(half2); - const int stride_mask = nb31 / sizeof(half); + // If cp.async is available, load up to the highest power of 2 in D asynchronously: +#ifdef CP_ASYNC_AVAILABLE + static_assert(D >= 64 && D < 512, "bad D"); + constexpr int k0_sync_start = D/2 < 64 ? 32 : (D/2 < 128 ? 64 : 128); - mma_B Q_B[D/(2*mma_B::K)]; - mma_C_VKQ VKQ_C[D/mma_C_VKQ::I]; + const unsigned int tile_KV_32 = __cvta_generic_to_shared(tile_KV); - float2 KQ_rowsum = {0.0f, 0.0f}; - float2 KQ_max = {-FLT_MAX/2.0f, -FLT_MAX/2.0f}; - float2 KQ_max_scale = {0.0f, 0.0f}; + constexpr int preload = 64; + constexpr int h2_per_chunk = 16/sizeof(half2); + constexpr int chunks_per_row = k0_sync_start / h2_per_chunk; + constexpr int stride_i = WARP_SIZE / chunks_per_row; +#pragma unroll + for (int i0 = 0; i0 < KQ_stride; i0 += nwarps*stride_i) { + const int i = i0 + threadIdx.y*stride_i + (chunks_per_row == WARP_SIZE ? 0 : threadIdx.x / chunks_per_row); + const int k = (chunks_per_row == WARP_SIZE ? threadIdx.x : threadIdx.x % chunks_per_row)*h2_per_chunk; - // Temporarily load Q data into tile_KV, will be loaded into registers afterwards. - // The loading is done with decreasing granularity for D for better memory bandwidth. - const half2 scale_h2 = make_half2(scale, scale); + cp_async_cg_16(tile_KV_32 + (i*D2_padded + k)*sizeof(half2), KV + i*stride_KV + k); + } +#else + constexpr int k0_sync_start = 0; +#endif // CP_ASYNC_AVAILABLE + static_assert(k0_sync_start % WARP_SIZE == 0, "bad k0_sync_start"); + + // If D is not a power of 2, the rest is loaded synchronously. + // K/V data is loaded with decreasing granularity for D for better memory bandwidth. + static_assert(KQ_stride % (4*nwarps) == 0, "out of bounds"); #pragma unroll for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) { - const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k); - const int k0_stop = D/2 - (D/2) % (1*stride_k); - const int stride_j = WARP_SIZE / stride_k; + const int k0_start = stride_k == WARP_SIZE ? k0_sync_start : D/2 - (D/2) % (2*stride_k); + const int k0_stop = D/2 - (D/2) % (1*stride_k); + const int stride_i = WARP_SIZE / stride_k; - if (nwarps*stride_j > ncols && threadIdx.y*stride_j >= ncols) { - break; + if (k0_start == k0_stop || k0_stop <= k0_sync_start) { + continue; } #pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps*stride_j) { - const int j = j0 + threadIdx.y*stride_j + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); + for (int i0 = 0; i0 < KQ_stride; i0 += nwarps*stride_i) { + const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); - if (jt*ncols + j < ne01) { -#pragma unroll - for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { - const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); - - const float2 tmp = Q_f2[(jt*ncols + j)*stride_Q + k]; - tile_KV[j*D2_padded + k] = scale_h2 * make_half2(tmp.x, tmp.y); - } - } else { #pragma unroll - for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { - const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); + for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { + const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); - tile_KV[j*D2_padded + k] = make_half2(0.0f, 0.0f); - } + tile_KV[i*D2_padded + k] = KV[i*stride_KV + k]; } } } +} - __syncthreads(); - - { - const int j0 = (threadIdx.y / np) * mma_B::J; +template +static __device__ __forceinline__ void flash_attn_ext_f16_iter( + const float2 * const __restrict__ Q_f2, + const half2 * const __restrict__ K_h2, + const half2 * const __restrict__ V_h2, + const half * const __restrict__ maskh, + float2 * const __restrict__ dstk, + float2 * const __restrict__ dstk_fixup, + const float scale, + const float slope, + const float logit_softcap, + const int ne01, + const int ne02, + const int stride_Q, + const int stride_KV, + const int stride_mask, + const int jt, + half2 * const __restrict__ tile_K, + half2 * const __restrict__ tile_V, + const tile_B * const __restrict__ Q_B, + tile_C_VKQ * const __restrict__ VKQ_C, + float2 & KQ_max, + float2 & KQ_rowsum, + const int kb0) { +#ifdef NEW_MMA_AVAILABLE + constexpr int np = nwarps*tile_B::I / ncols; // Number of parallel CUDA warps per Q column. + constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts. -#pragma unroll - for (int k0 = 0; k0 < D/2; k0 += mma_B::K) { - Q_B[k0/mma_B::K].load_ldmatrix(tile_KV + j0*D2_padded + k0, D2_padded); - } - } + const int k_VKQ_0 = kb0*KQ_stride; + tile_C_KQ KQ_C[KQ_stride/(np*tile_C_KQ::I)]; +#ifdef CP_ASYNC_AVAILABLE + cp_async_wait_all(); __syncthreads(); + flash_attn_ext_f16_load_tile(V_h2 + k_VKQ_0*stride_KV, tile_V, stride_KV); +#else + flash_attn_ext_f16_load_tile(K_h2 + k_VKQ_0*stride_KV, tile_K, stride_KV); + __syncthreads(); +#endif // CP_ASYNC_AVAILABLE - // Iterate over ne11 == previous tokens: - for (int kb0 = kb0_start; kb0 < kb0_stop; ++kb0) { - const int k_VKQ_0 = kb0*KQ_stride; - mma_C_KQ KQ_C[KQ_stride/(np*mma_C_KQ::I)]; - - // Load K data into tile with decreasing granularity for D for better memory bandwidth: - static_assert(KQ_stride % (4*nwarps) == 0, "out of bounds"); -#pragma unroll - for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) { - const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k); - const int k0_stop = D/2 - (D/2) % (1*stride_k); - const int stride_i = WARP_SIZE / stride_k; - -#pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < KQ_stride; i_KQ_0 += nwarps*stride_i) { - const int i_KQ = i_KQ_0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); - -#pragma unroll - for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += stride_k) { - const int k_KQ = k_KQ_0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); - - tile_KV[i_KQ*D2_padded + k_KQ] = K_h2[(k_VKQ_0 + i_KQ)*stride_KV + k_KQ]; - } - } - } - - __syncthreads(); - - // Calculate tile of KQ: + // Calculate tile of KQ: #pragma unroll - for (int i_KQ_00 = 0; i_KQ_00 < KQ_stride; i_KQ_00 += np*mma_A::I) { - const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*mma_A::I; + for (int i_KQ_00 = 0; i_KQ_00 < KQ_stride; i_KQ_00 += np*tile_A::I) { + const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*tile_A::I; #pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += mma_A::K) { - mma_A K_A; - K_A.load_ldmatrix(tile_KV + i_KQ_0*D2_padded + k_KQ_0, D2_padded); - KQ_C[i_KQ_00/(np*mma_A::I)].mma(K_A, Q_B[k_KQ_0/mma_A::K]); - } + for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += tile_A::J) { + tile_A K_A; + load_ldmatrix(K_A, tile_K + i_KQ_0*D2_padded + k_KQ_0, D2_padded); + mma(KQ_C[i_KQ_00/(np*tile_A::I)], K_A, ((tile_B *) Q_B)[k_KQ_0/tile_A::J]); } + } - __syncthreads(); +#ifndef CP_ASYNC_AVAILABLE + __syncthreads(); // Only needed if tile_K == tile_V. +#endif // CP_ASYNC_AVAILABLE - if (use_logit_softcap) { - static_assert(KQ_stride % (np*mma_C_KQ::I) == 0, "bad loop size"); + if (use_logit_softcap) { + static_assert(KQ_stride % (np*tile_C_KQ::I) == 0, "bad loop size"); #pragma unroll - for (int i = 0; i < KQ_stride/(np*mma_C_KQ::I); ++i) { + for (int i = 0; i < KQ_stride/(np*tile_C_KQ::I); ++i) { #pragma unroll - for (int l = 0; l < mma_C_KQ::ne; ++l) { - KQ_C[i].x[l] = logit_softcap*tanhf(KQ_C[i].x[l]); - } + for (int l = 0; l < tile_C_KQ::ne; ++l) { + KQ_C[i].x[l] = logit_softcap*tanhf(KQ_C[i].x[l]); } } + } - if (maskh) { - static_assert(KQ_stride % (np *mma_C_KQ::I) == 0, "bad loop size"); - static_assert(ncols % (nwarps/np*mma_C_KQ::J) == 0, "bad loop size"); + if (maskh) { + static_assert(KQ_stride % (np *tile_C_KQ::I) == 0, "bad loop size"); + static_assert(ncols % (nwarps/np*tile_C_KQ::J) == 0, "bad loop size"); #pragma unroll - for (int i00 = 0; i00 < KQ_stride; i00 += np*mma_C_KQ::I) { - const int i0 = i00 + (threadIdx.y % np)*mma_C_KQ::I; + for (int i00 = 0; i00 < KQ_stride; i00 += np*tile_C_KQ::I) { + const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ::I; #pragma unroll - for (int l = 0; l < mma_C_KQ::ne; ++l) { - const int i = i0 + mma_C_KQ::get_i(l); - const int j = (threadIdx.y / np)*mma_C_KQ::J + mma_C_KQ::get_j(l); + for (int l = 0; l < tile_C_KQ::ne; ++l) { + const int i = i0 + tile_C_KQ::get_i(l); + const int j = (threadIdx.y / np)*tile_C_KQ::J + tile_C_KQ::get_j(l); - KQ_C[i00/(np*mma_C_KQ::I)].x[l] += slope*__half2float(maskh[j*stride_mask + k_VKQ_0 + i]); - } + KQ_C[i00/(np*tile_C_KQ::I)].x[l] += slope*__half2float(maskh[j*stride_mask + k_VKQ_0 + i]); } } + } - // Calculate softmax for each KQ column using the current max. value. - // The divisor is stored in KQ_rowsum and will be applied at the end. - float2 KQ_max_new = KQ_max; - static_assert(KQ_stride % (np*mma_C_KQ::I) == 0, "bad loop size"); + // Calculate softmax for each KQ column using the current max. value. + // The divisor is stored in KQ_rowsum and will be applied at the end. + float2 KQ_max_new = KQ_max; + static_assert(KQ_stride % (np*tile_C_KQ::I) == 0, "bad loop size"); #pragma unroll - for (int k = 0; k < KQ_stride/(np*mma_C_KQ::I); ++k) { + for (int k = 0; k < KQ_stride/(np*tile_C_KQ::I); ++k) { #pragma unroll - for (int l0 = 0; l0 < mma_C_KQ::ne; l0 += 2) { - KQ_max_new.x = fmaxf(KQ_max_new.x, KQ_C[k].x[l0 + 0]); - KQ_max_new.y = fmaxf(KQ_max_new.y, KQ_C[k].x[l0 + 1]); - } + for (int l0 = 0; l0 < tile_C_KQ::ne; l0 += 2) { + KQ_max_new.x = fmaxf(KQ_max_new.x, KQ_C[k].x[l0 + 0]); + KQ_max_new.y = fmaxf(KQ_max_new.y, KQ_C[k].x[l0 + 1]); } + } - // Values per KQ column are spread across 8 threads, does not need full warp reduce: + // Values per KQ column are spread across 8 threads, does not need full warp reduce: #pragma unroll - for (int offset = 16; offset > 2; offset >>= 1) { - KQ_max_new.x = fmaxf(KQ_max_new.x, __shfl_xor_sync(0xFFFFFFFF, KQ_max_new.x, offset, WARP_SIZE)); - KQ_max_new.y = fmaxf(KQ_max_new.y, __shfl_xor_sync(0xFFFFFFFF, KQ_max_new.y, offset, WARP_SIZE)); - } - - { - const float2 diff = make_float2(KQ_max.x - KQ_max_new.x, KQ_max.y - KQ_max_new.y); - KQ_max_scale = make_float2(expf(diff.x), expf(diff.y)); - if (diff.x <= SOFTMAX_FTZ_THRESHOLD) { - KQ_max_scale.x = 0.0f; - } - if (diff.y <= SOFTMAX_FTZ_THRESHOLD) { - KQ_max_scale.y = 0.0f; - } - KQ_max = KQ_max_new; - } + for (int offset = 16; offset > 2; offset >>= 1) { + KQ_max_new.x = fmaxf(KQ_max_new.x, __shfl_xor_sync(0xFFFFFFFF, KQ_max_new.x, offset, WARP_SIZE)); + KQ_max_new.y = fmaxf(KQ_max_new.y, __shfl_xor_sync(0xFFFFFFFF, KQ_max_new.y, offset, WARP_SIZE)); + } - float2 KQ_rowsum_add = make_float2(0.0f, 0.0f); - static_assert(KQ_stride % (np*mma_C_KQ::I) == 0, "bad loop size"); + float2 KQ_rowsum_add = make_float2(0.0f, 0.0f); + static_assert(KQ_stride % (np*tile_C_KQ::I) == 0, "bad loop size"); #pragma unroll - for (int k = 0; k < KQ_stride/(np*mma_C_KQ::I); ++k) { + for (int k = 0; k < KQ_stride/(np*tile_C_KQ::I); ++k) { #pragma unroll - for (int l = 0; l < mma_C_KQ::ne; ++l) { - const float KQ_max_l = l % 2 == 0 ? KQ_max.x : KQ_max.y; - const float diff = KQ_C[k].x[l] - KQ_max_l; - KQ_C[k].x[l] = expf(diff); - if (diff <= SOFTMAX_FTZ_THRESHOLD) { - KQ_C[k].x[l] = 0.0f; - } + for (int l = 0; l < tile_C_KQ::ne; ++l) { + const float KQ_max_l = l % 2 == 0 ? KQ_max_new.x : KQ_max_new.y; + const float diff = KQ_C[k].x[l] - KQ_max_l; + KQ_C[k].x[l] = expf(diff); - if (l % 2 == 0) { - KQ_rowsum_add.x += KQ_C[k].x[l]; - } else { - KQ_rowsum_add.y += KQ_C[k].x[l]; - } + if (l % 2 == 0) { + KQ_rowsum_add.x += KQ_C[k].x[l]; + } else { + KQ_rowsum_add.y += KQ_C[k].x[l]; } } + } + + { + const float2 diff = make_float2(KQ_max.x - KQ_max_new.x, KQ_max.y - KQ_max_new.y); + const float2 KQ_max_scale = make_float2(expf(diff.x), expf(diff.y)); + KQ_max = KQ_max_new; // Scale previous KQ_rowsum to account for a potential increase in KQ_max: KQ_rowsum.x = KQ_max_scale.x*KQ_rowsum.x + KQ_rowsum_add.x; @@ -244,60 +197,179 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale.x, KQ_max_scale.y); #pragma unroll - for (int i = 0; i < D/mma_C_VKQ::I; ++i) { + for (int i = 0; i < D/tile_C_VKQ::I; ++i) { #pragma unroll - for (int l = 0; l < mma_C_VKQ::ne; ++l) { + for (int l = 0; l < tile_C_VKQ::ne; ++l) { VKQ_C[i].x[l] *= KQ_max_scale_h2; } } + } + + // Convert KQ C tiles into B tiles for VKQ calculation: + tile_B B[KQ_stride/(np*2*tile_B::J)]; + static_assert(KQ_stride % (np*2*tile_B::J) == 0, "bad loop size"); +#pragma unroll + for (int k = 0; k < KQ_stride/(np*2*tile_B::J); ++k) { + B[k] = get_transposed(get_half2(KQ_C[k])); + } - // Convert KQ C tiles into B tiles for VKQ calculation: - mma_B B[KQ_stride/(np*2*mma_B::K)]; - static_assert(KQ_stride % (np*2*mma_B::K) == 0, "bad loop size"); +#ifdef CP_ASYNC_AVAILABLE + cp_async_wait_all(); + __syncthreads(); + if (!last_iter) { + flash_attn_ext_f16_load_tile(K_h2 + (k_VKQ_0 + KQ_stride)*stride_KV, tile_K, stride_KV); + } +#else + flash_attn_ext_f16_load_tile(V_h2 + k_VKQ_0*stride_KV, tile_V, stride_KV); + __syncthreads(); +#endif // CP_ASYNC_AVAILABLE + + // Calculate VKQ tile: +#pragma unroll + for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += tile_C_VKQ::I) { + static_assert((KQ_stride/2) % (np*tile_A::J) == 0, "bad loop size"); #pragma unroll - for (int k = 0; k < KQ_stride/(np*2*mma_B::K); ++k) { - B[k] = KQ_C[k].to_mma_B(); + for (int k00 = 0; k00 < KQ_stride/2; k00 += np*tile_A::J) { + const int k0 = k00 + (threadIdx.y % np)*tile_A::J; + + tile_A A; + load_ldmatrix_trans(A, tile_V + 2*k0*D2_padded + i_VKQ_0/2, D2_padded); + mma(VKQ_C[i_VKQ_0/tile_C_VKQ::I], A, B[k00/(np*tile_A::J)]); + } + } + +#ifndef CP_ASYNC_AVAILABLE + __syncthreads(); // Only needed if tile_K == tile_V. +#endif // CP_ASYNC_AVAILABLE + +#else + NO_DEVICE_CODE; +#endif // NEW_MMA_AVAILABLE +} + +template +static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( + const float2 * const __restrict__ Q_f2, + const half2 * const __restrict__ K_h2, + const half2 * const __restrict__ V_h2, + const half * const __restrict__ maskh, + float2 * const __restrict__ dstk, + float2 * const __restrict__ dstk_fixup, + const float scale, + const float slope, + const float logit_softcap, + const int ne01, + const int ne02, + const int stride_Q, + const int stride_KV, + const int stride_mask, + const int jt, + const int kb0_start, + const int kb0_stop) { +#ifdef NEW_MMA_AVAILABLE + //In this kernel Q, K, V are matrices while i, j, k are matrix indices. + + static_assert(nwarps*tile_B::I % ncols == 0, "bad nwarps"); + constexpr int np = nwarps*tile_B::I / ncols; // Number of parallel CUDA warps per Q column. + + static_assert(D % nwarps == 0, "bad D"); + static_assert(KQ_stride % nwarps == 0, "bad KQ_stride"); + + constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts. + + // Temporary shared buffer for loading K/V data with KQ_stride*D logical elements: + extern __shared__ half2 tile_K[]; +#ifdef CP_ASYNC_AVAILABLE + half2 * tile_V = tile_K + KQ_stride*D2_padded; +#else + half2 * tile_V = tile_K; +#endif // CP_ASYNC_AVAILABLE + + tile_B Q_B[D/(2*tile_B::J)]; + tile_C_VKQ VKQ_C[D/tile_C_VKQ::I]; + + float2 KQ_rowsum = {0.0f, 0.0f}; + float2 KQ_max = {-FLT_MAX/2.0f, -FLT_MAX/2.0f}; + + // Temporarily load Q data into tile_K, will be loaded into registers afterwards. + // The loading is done with decreasing granularity for D for better memory bandwidth. + const half2 scale_h2 = make_half2(scale, scale); +#pragma unroll + for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) { + const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k); + const int k0_stop = D/2 - (D/2) % (1*stride_k); + const int stride_j = WARP_SIZE / stride_k; + + if (k0_start == k0_stop) { + continue; + } + + if (nwarps*stride_j > ncols && threadIdx.y*stride_j >= ncols) { + break; } - // Load V data into tile with decreasing granularity for D for better memory bandwidth: - static_assert(KQ_stride % (4*nwarps) == 0, "out of bounds"); #pragma unroll - for (int stride_i : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) { - const int i0_start = stride_i == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_i); - const int i0_stop = D/2 - (D/2) % (1*stride_i); - const int stride_k = WARP_SIZE / stride_i; + for (int j0 = 0; j0 < ncols; j0 += nwarps*stride_j) { + const int j = j0 + threadIdx.y*stride_j + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); + if (jt*ncols + j < ne01) { #pragma unroll - for (int k_V_0 = 0; k_V_0 < KQ_stride; k_V_0 += nwarps*stride_k) { - const int k_V = k_V_0 + threadIdx.y*stride_k + (stride_i == WARP_SIZE ? 0 : threadIdx.x / stride_i); + for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { + const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); + const float2 tmp = Q_f2[(jt*ncols + j)*stride_Q + k]; + tile_K[j*D2_padded + k] = scale_h2 * make_half2(tmp.x, tmp.y); + } + } else { #pragma unroll - for (int i_V_0 = i0_start; i_V_0 < i0_stop; i_V_0 += stride_i) { - const int i_V = i_V_0 + (stride_i == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_i); + for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { + const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); - tile_KV[k_V*D2_padded + i_V] = V_h2[(k_VKQ_0 + k_V)*stride_KV + i_V]; + tile_K[j*D2_padded + k] = make_half2(0.0f, 0.0f); } } } + } - __syncthreads(); + __syncthreads(); - // Calculate VKQ tile: -#pragma unroll - for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += mma_C_VKQ::I) { - static_assert((KQ_stride/2) % (np*mma_A::K) == 0, "bad loop size"); -#pragma unroll - for (int k00 = 0; k00 < KQ_stride/2; k00 += np*mma_A::K) { - const int k0 = k00 + (threadIdx.y % np)*mma_A::K; + { + const int j0 = (threadIdx.y / np) * tile_B::I; - mma_A A; - A.load_ldmatrix_trans(tile_KV + 2*k0*D2_padded + i_VKQ_0/2, D2_padded); - VKQ_C[i_VKQ_0/mma_C_VKQ::I].mma(A, B[k00/(np*mma_A::K)]); - } +#pragma unroll + for (int k0 = 0; k0 < D/2; k0 += tile_B::J) { + load_ldmatrix(Q_B[k0/tile_B::J], tile_K + j0*D2_padded + k0, D2_padded); } + } + + __syncthreads(); + // Preload K data for first iteration when using cp_async: +#ifdef CP_ASYNC_AVAILABLE + flash_attn_ext_f16_load_tile(K_h2 + kb0_start*KQ_stride*stride_KV, tile_K, stride_KV); +#endif // CP_ASYNC_AVAILABLE + + // Iterate over ne11 == previous tokens: + for (int kb0 = kb0_start; kb0 < kb0_stop-1; ++kb0) { + constexpr bool last_iter = false; + flash_attn_ext_f16_iter + (Q_f2, K_h2, V_h2, maskh, dstk, dstk_fixup, scale, slope, logit_softcap, + ne01, ne02, stride_Q, stride_KV, stride_mask, jt, tile_K, tile_V, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0); + } + { // kb0_start is always < kb0_stop so the last iter can be executed unconditionally. + constexpr bool last_iter = true; + flash_attn_ext_f16_iter + (Q_f2, K_h2, V_h2, maskh, dstk, dstk_fixup, scale, slope, logit_softcap, + ne01, ne02, stride_Q, stride_KV, stride_mask, jt, tile_K, tile_V, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1); + } + + // With cp_async there is no __syncthreads at the end of the iter, + // there can be a race condition on shared memory access for combining/writing back results. +#ifdef CP_ASYNC_AVAILABLE + if (nwarps*tile_B::I > KQ_stride) { __syncthreads(); } +#endif // CP_ASYNC_AVAILABLE // Finally, sum up partial KQ rowsums. // The partial sums are spread across 8 threads each, does not need full reduce. @@ -310,26 +382,26 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( // Write VKQ accumulators to shared memory in column-major format. // It's faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM. // Also for np > 1 the combination is done via these values in shared memory. - const int j_cwd = threadIdx.y*mma_B::J + mma_B::get_j(-1); // j combine write data + const int j_cwd = threadIdx.y*tile_B::I + tile_B::get_i(-1); // j combine write data #pragma unroll - for (int k0 = 0; k0 < D/2; k0 += mma_B::K) { - const mma_B B = VKQ_C[k0/mma_B::K].to_mma_B(); // Conversion of C to B matrix puts it in column-major format. + for (int k0 = 0; k0 < D/2; k0 += tile_B::J) { + const tile_B B = get_transposed(VKQ_C[k0/tile_B::J]); // Conversion of C to B matrix puts it in column-major format. #pragma unroll - for (int l = 0; l < mma_B::ne; ++l) { - const int k = k0 + mma_B::get_k(l); + for (int l = 0; l < tile_B::ne; ++l) { + const int k = k0 + tile_B::get_j(l); - tile_KV[j_cwd*D2_padded + k] = B.x[l]; + tile_K[j_cwd*D2_padded + k] = B.x[l]; } } - const int j_cwmo = (threadIdx.x % (2*mma_C_VKQ::J)) / mma_C_VKQ::J; // j combine write meta offset - const int j_cwm = threadIdx.y*(2*mma_C_VKQ::J) + 2*mma_C_VKQ::get_j(-1) + j_cwmo; // j combine write meta + const int j_cwmo = (threadIdx.x % (2*tile_C_VKQ::J)) / tile_C_VKQ::J; // j combine write meta offset + const int j_cwm = threadIdx.y*(2*tile_C_VKQ::J) + 2*tile_C_VKQ::get_j(-1) + j_cwmo; // j combine write meta const float2 KQ_cmr = make_float2(((const float *) &KQ_max)[j_cwmo], ((const float *) &KQ_rowsum)[j_cwmo]); // KQ combine max rowsum - if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*mma_C_VKQ::J) { + if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*tile_C_VKQ::J) { // Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale. - ((float2 *) tile_KV)[j_cwm*(D2_padded/2) + D/4] = KQ_cmr; + ((float2 *) tile_K)[j_cwm*(D2_padded/2) + D/4] = KQ_cmr; } __syncthreads(); @@ -337,11 +409,11 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( static_assert(np == 1 || np == 2 || np == 4, "bad np"); if (np == 1) { // No combination is needed, the meta data can be directly written from registers to VRAM. - if (needs_fixup && threadIdx.x < mma_B::J) { + if (needs_fixup && threadIdx.x < tile_B::I) { float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols; dstk_fixup_meta[j_cwm] = KQ_cmr; } - if (is_fixup && threadIdx.x < mma_B::J) { + if (is_fixup && threadIdx.x < tile_B::I) { float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols; dstk_fixup_meta[j_cwm] = KQ_cmr; } @@ -350,42 +422,40 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( // Warps with threadIdx.y % np != 0 must NOT return early. // All threads must return simultaneously to avoid race conditions with work on the next tile. - float * meta_j = (float *) tile_KV + (threadIdx.y*mma_B::J + threadIdx.x)*D2_padded + D/2; + float * meta_j = (float *) tile_K + (threadIdx.y*tile_B::I + threadIdx.x)*D2_padded + D/2; float KQ_cm = -FLT_MAX/2; // KQ combine max per parallel warp. - if (np*mma_B::J == WARP_SIZE || threadIdx.x < np*mma_B::J) { + if (np*tile_B::I == WARP_SIZE || threadIdx.x < np*tile_B::I) { KQ_cm = meta_j[0]; } float KQ_cmn = KQ_cm; // KQ combine max new, max between all parallel warps. #pragma unroll - for (int offset = np*mma_B::J/2; offset >= mma_B::J; offset >>= 1) { + for (int offset = np*tile_B::I/2; offset >= tile_B::I; offset >>= 1) { KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE)); } const float KQ_cms = expf(KQ_cm - KQ_cmn); // KQ combine max scale per warp. float KQ_crs = 0.0f; // KQ combine rowsum, scaled sum of all parallel warps. - if (np*mma_B::J == WARP_SIZE || threadIdx.x < np*mma_B::J) { + if (np*tile_B::I == WARP_SIZE || threadIdx.x < np*tile_B::I) { KQ_crs = KQ_cms*meta_j[1]; } #pragma unroll - for (int offset = np*mma_B::J/2; offset >= mma_B::J; offset >>= 1) { + for (int offset = np*tile_B::I/2; offset >= tile_B::I; offset >>= 1) { KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE); } // Write back combined meta data: - if (np*mma_B::J == WARP_SIZE || threadIdx.x < np*mma_B::J) { - meta_j[0] = KQ_cmn; // Combined max. KQ values. - meta_j[1] = KQ_crs; // Combined KQ rowsums. - meta_j[2] = KQ_cms; // KQ max scales per parallel warp. + if (np*tile_B::I == WARP_SIZE || threadIdx.x < np*tile_B::I) { + *((float2 *) meta_j) = make_float2(KQ_cms, KQ_crs); // Combined KQ max scale + rowsum. } - if (needs_fixup && threadIdx.x < mma_B::J) { + if (needs_fixup && threadIdx.x < tile_B::I) { float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols; - dstk_fixup_meta[(threadIdx.y/np)*mma_B::J + threadIdx.x] = make_float2(KQ_cmn, KQ_crs); + dstk_fixup_meta[(threadIdx.y/np)*tile_B::I + threadIdx.x] = make_float2(KQ_cmn, KQ_crs); } - if (is_fixup && threadIdx.x < mma_B::J) { + if (is_fixup && threadIdx.x < tile_B::I) { float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols; - dstk_fixup_meta[(threadIdx.y/np)*mma_B::J + threadIdx.x] = make_float2(KQ_cmn, KQ_crs); + dstk_fixup_meta[(threadIdx.y/np)*tile_B::I + threadIdx.x] = make_float2(KQ_cmn, KQ_crs); } } @@ -404,6 +474,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const int k0_stop = D/2 - (D/2) % (1*stride_k); const int stride_j = WARP_SIZE / stride_k; + if (k0_start == k0_stop) { + continue; + } + if (nwarps*stride_j > ncols && threadIdx.y*stride_j >= ncols) { break; } @@ -411,12 +485,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( #pragma unroll for (int j0_dst = 0; j0_dst < ncols; j0_dst += (nwarps/np)*stride_j) { const int j_dst = j0_dst + (threadIdx.y/np)*stride_j + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); - const int j_tile_KV = (j_dst/mma_B::J)*(np*mma_B::J) + j_dst % mma_B::J; + const int j_tile_K = (j_dst/tile_B::I)*(np*tile_B::I) + j_dst % tile_B::I; if (!is_fixup && jt*ncols + j_dst >= ne01) { continue; } - const float * meta_j = (const float *) tile_KV + j_tile_KV*D2_padded + D/2; + const float * meta_j = (const float *) tile_K + j_tile_K*D2_padded + D/2; #pragma unroll for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); @@ -424,8 +498,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( float2 dstk_val = make_float2(0.0f, 0.0f); #pragma unroll for (int ip = 0; ip < np; ++ip) { - const float KQ_crs = np == 1 ? 1.0f : meta_j[ip*mma_B::J*D2_padded + 2]; - const float2 dstk_val_add = __half22float2(tile_KV[(j_tile_KV + ip*mma_B::J)*D2_padded + k]); + const float KQ_crs = np == 1 ? 1.0f : meta_j[ip*tile_B::I*D2_padded + 0]; + const float2 dstk_val_add = __half22float2(tile_K[(j_tile_K + ip*tile_B::I)*D2_padded + k]); dstk_val.x += dstk_val_add.x*KQ_crs; dstk_val.y += dstk_val_add.y*KQ_crs; } @@ -450,7 +524,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( __syncthreads(); } #else - NO_DEVICE_CODE; + NO_DEVICE_CODE; #endif // NEW_MMA_AVAILABLE } @@ -494,6 +568,11 @@ static __global__ void flash_attn_ext_f16( const int ne1, const int ne2, const int ne3) { +#ifndef NEW_MMA_AVAILABLE + NO_DEVICE_CODE; + return; +#endif // NEW_MMA_AVAILABLE + // Skip unused kernel variants for faster compilation: if (use_logit_softcap && !(D == 128 || D == 256)) { NO_DEVICE_CODE; @@ -504,6 +583,10 @@ static __global__ void flash_attn_ext_f16( const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. + const int stride_Q = nb01 / sizeof(float2); + const int stride_KV = nb11 / sizeof(half2); + const int stride_mask = nb31 / sizeof(half); + const int iter_k = ne11 / KQ_stride; const int iter_j = (ne01 + (ncols - 1)) / ncols; @@ -535,14 +618,12 @@ static __global__ void flash_attn_ext_f16( constexpr bool needs_fixup = false; // CUDA block is working on an entire tile. flash_attn_ext_f16_process_tile (Q_f2, K_h2, V_h2, maskh, dstk, dst_meta, scale, slope, logit_softcap, - ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne31, nb31, nb01, nb02, nb03, nb11, nb12, nb13, nb21, nb22, nb23, ne0, ne1, ne2, ne3, - jt, kb0_start, kb0_stop); + ne01, ne02, stride_Q, stride_KV, stride_mask, jt, kb0_start, kb0_stop); } else { constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile. flash_attn_ext_f16_process_tile (Q_f2, K_h2, V_h2, maskh, dstk, dst_meta, scale, slope, logit_softcap, - ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne31, nb31, nb01, nb02, nb03, nb11, nb12, nb13, nb21, nb22, nb23, ne0, ne1, ne2, ne3, - jt, kb0_start, kb0_stop); + ne01, ne02, stride_Q, stride_KV, stride_mask, jt, kb0_start, kb0_stop); } kbc += iter_k; @@ -571,24 +652,27 @@ static __global__ void flash_attn_ext_f16( constexpr bool needs_fixup = false; flash_attn_ext_f16_process_tile (Q_f2, K_h2, V_h2, maskh, dstk, dst_meta, scale, slope, logit_softcap, - ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne31, nb31, nb01, nb02, nb03, nb11, nb12, nb13, nb21, nb22, nb23, ne0, ne1, ne2, ne3, - jt, kb0_start, kb0_stop); + ne01, ne02, stride_Q, stride_KV, stride_mask, jt, kb0_start, kb0_stop); } template void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - typedef mma_A_I16K8 mma_A; - typedef mma_B_J8K8 mma_B; + typedef tile<16, 8, half2> tile_A; + typedef tile< 8, 8, half2> tile_B; - static_assert(D % mma_B::K == 0, "bad D"); - static_assert(cols_per_block % mma_B::J == 0, "bad cols_per_block"); + static_assert(D % tile_B::J == 0, "bad D"); + static_assert(cols_per_block % tile_B::I == 0, "bad cols_per_block"); const ggml_tensor * KQV = dst; + const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; + + constexpr int KQ_stride = D <= 128 ? 64 : 32; + constexpr int nwarps = (KQ_stride == 32 && cols_per_block <= 16) ? + cols_per_block/tile_B::J * KQ_stride/tile_A::I : (cols_per_block <= 8 ? 4 : 8); - constexpr int KQ_stride = D <= 128 ? 64 : 32; - constexpr int nwarps = (KQ_stride == 32 && cols_per_block <= 16) ? - cols_per_block/mma_B::J * KQ_stride/mma_A::I : (cols_per_block <= 8 ? 4 : 8); - constexpr size_t nbytes_shared = std::max(KQ_stride, nwarps*mma_B::J) * (D + 8) * sizeof(half); + const int nrows_KQ = cp_async_available(cc) ? 2*KQ_stride : KQ_stride; + const int nrows_combine = nwarps*tile_B::J; + const size_t nbytes_shared = std::max(nrows_KQ, nrows_combine) * (D + 8) * sizeof(half); float logit_softcap; memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh index bbc0a35ae56..0a5656e4cb3 100644 --- a/ggml/src/ggml-cuda/mma.cuh +++ b/ggml/src/ggml-cuda/mma.cuh @@ -4,11 +4,12 @@ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-multiply-accumulate-operation-using-mma-instruction // // Like with nvcuda::wmma there are three types of matrix tiles: A, B, and C with A @ B = C. -// A is a row-major matrix with shape I x K. -// B is a column-major matrix with shape K x J. -// C is a column-major matrix with shape I x J. -// Note that along their lowest dimension I, J, and K are measured in physical 32 bit elements instead of logical elements. -// The functions get_i, get_j, and get_k can be used to get the physical 32 bit index of the lth element of a thread within a tile. +// A is a row-major matrix with shape M x K. +// B is a column-major matrix with shape K x N. +// C is a column-major matrix with shape M x N. +// A, B, and C are represented using the same fundamental data type: a row-major matrix with I rows and J columns. +// Note that J is measured in physical 32 bit elements instead of logical elements. +// The methods get_i and get_j can be used to get the physical 32 bit index of the lth element of a thread within a tile. // All matrix tiles have ne physical 32 bit elements per warp. // // As described in the documentation, all pointers for load_ldmatrix must be to shared memory and aligned to 16 bytes. @@ -23,7 +24,7 @@ static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) { #ifdef NEW_MMA_AVAILABLE asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;" - : "+r"(ret) : "r"(x)); + : "=r"(ret) : "r"(x)); #else NO_DEVICE_CODE; #endif // defined(NEW_MMA_AVAILABLE) @@ -52,407 +53,267 @@ static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) { #endif // CUDART_VERSION >= 11080 +static __device__ __forceinline__ half2 ggml_cuda_movmatrix(const half2 x) { + half2 ret; + *((int *) &ret) = ggml_cuda_movmatrix(*((const int *) &x)); + return ret; +} -template -struct mma_A_I16K4 { - static_assert(sizeof(T) == 4, "bad type size"); - - static constexpr int I = 16; - static constexpr int K = 4; - static constexpr int ne = 2; - - T x[ne]; +namespace ggml_cuda_mma { + + template + struct tile { + static constexpr int I = I_; + static constexpr int J = J_; + static constexpr int ne = I * J / WARP_SIZE; + T x[ne] = {0}; + + static __device__ __forceinline__ int get_i(const int l) { + if constexpr (I == 8 && (J == 4 || J == 8)) { + return threadIdx.x / 4; + } else if constexpr (I == 16 && J == 8) { + return (l / 2) * 8 + threadIdx.x / 4; + } else { + static_assert(I == -1 && J == -1, "template specialization not implemented"); + } + } - static __device__ __forceinline__ int get_i(const int l) { - const int ret = (l%2) * (I/2) + threadIdx.x / K; - GGML_CUDA_ASSUME(ret >= 0); - GGML_CUDA_ASSUME(ret < I); - return ret; - } + static __device__ __forceinline__ int get_j(const int l) { + if constexpr (I == 8 && J == 4) { + return threadIdx.x % 4; + } else if constexpr (I == 8 && J == 8) { + return 4 * l + threadIdx.x % 4; + } else if constexpr (I == 16 && J == 8) { + return 2 * (threadIdx.x % 4) + l % 2; + } else { + static_assert(I == -1 && J == -1, "template specialization not implemented"); + } + } + }; + + template + struct tile { + static constexpr int I = I_; + static constexpr int J = J_; + static constexpr int ne = I * J / WARP_SIZE; + half2 x[ne] = {{0.0f, 0.0f}}; + + static __device__ __forceinline__ int get_i(const int l) { + if constexpr (I == 8 && J == 8) { + return threadIdx.x / 4; + } else if constexpr (I == 16 && J == 4) { + return l * 8 + threadIdx.x / 4; + } else if constexpr (I == 16 && J == 8) { + return (l % 2) * 8 + threadIdx.x / 4; + } else { + static_assert(I == -1 && J == -1, "template specialization not implemented"); + } + } - static __device__ __forceinline__ int get_k(const int /* l */) { - const int ret = threadIdx.x % K; - GGML_CUDA_ASSUME(ret >= 0); - GGML_CUDA_ASSUME(ret < K); - return ret; - } + static __device__ __forceinline__ int get_j(const int l) { + if constexpr (I == 8 && J == 8) { + return l * 4 + threadIdx.x % 4; + } else if constexpr (I == 16 && J == 4) { + return threadIdx.x % 4; + } else if constexpr (I == 16 && J == 8) { + return (l / 2) * 4 + threadIdx.x % 4; + } else { + static_assert(I == -1 && J == -1, "template specialization not implemented"); + } + } + }; - __device__ __forceinline__ void load_generic(const T * __restrict__ xs0, const int & stride) { + template + static __device__ __forceinline__ tile get_half2(const tile & tile_float) { + tile ret; #pragma unroll - for (int l = 0; l < ne; ++l) { - x[l] = xs0[get_i(l)*stride + get_k(l)]; + for (int l0 = 0; l0 < tile_float.ne; l0 += 2) { + ret.x[l0/2] = make_half2(tile_float.x[l0 + 0], tile_float.x[l0 + 1]); } - } - - __device__ __forceinline__ void load_ldmatrix(const T * __restrict__ xs0, const int & stride) { -#ifdef NEW_MMA_AVAILABLE - int * xi = (int *) x; - const int * xs = (const int *) xs0 + (threadIdx.x%I)*stride; - asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];" - : "+r"(xi[0]), "+r"(xi[1]) - : "l"(xs)); -#else - load_generic(xs0, stride); -#endif // NEW_MMA_AVAILABLE - } -}; - -template -struct mma_A_I16K8 { - static_assert(sizeof(T) == 4, "bad type size"); - - static constexpr int I = 16; - static constexpr int K = 8; - static constexpr int ne = 4; - - T x[ne]; - - static __device__ __forceinline__ int get_i(const int l) { - const int ret = (l%2) * (I/2) + threadIdx.x / (K/2); - GGML_CUDA_ASSUME(ret >= 0); - GGML_CUDA_ASSUME(ret < I); return ret; } - static __device__ __forceinline__ int get_k(const int l) { - const int ret = (l/2) * (K/2) + threadIdx.x % (K/2); - GGML_CUDA_ASSUME(ret >= 0); - GGML_CUDA_ASSUME(ret < K); + static __device__ __forceinline__ tile<8, 8, half2> get_transposed(const tile<16, 4, half2> & t) { + tile<8, 8, half2> ret; + ret.x[0] = ggml_cuda_movmatrix(t.x[0]); + ret.x[1] = ggml_cuda_movmatrix(t.x[1]); + return ret; } - __device__ __forceinline__ void load_generic(const T * __restrict__ xs0, const int & stride) { + template + static __device__ __forceinline__ void load_generic(tile & t, const T * __restrict__ xs0, const int stride) { #pragma unroll - for (int l = 0; l < ne; ++l) { - x[l] = xs0[get_i(l)*stride + get_k(l)]; + for (int l = 0; l < t.ne; ++l) { + t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)]; } } - __device__ __forceinline__ void load_ldmatrix(const T * __restrict__ xs0, const int & stride) { + template + static __device__ __forceinline__ void load_ldmatrix( + tile<8, 8, T> & t, const T * __restrict__ xs0, const int stride) { #ifdef NEW_MMA_AVAILABLE - int * xi = (int * ) x; - const int * xs = (const int *) xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2); - asm("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];" - : "+r"(xi[0]), "+r"(xi[1]), "+r"(xi[2]), "+r"(xi[3]) + int * xi = (int *) t.x; + const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + ((threadIdx.x / t.I) * (t.J / 2)) % t.J; + asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];" + : "=r"(xi[0]), "=r"(xi[1]) : "l"(xs)); #else - GGML_UNUSED(xs0); - GGML_UNUSED(stride); - NO_DEVICE_CODE; + load_generic(t, xs0, stride); #endif // NEW_MMA_AVAILABLE } - __device__ __forceinline__ void load_ldmatrix_trans(const T * __restrict__ xs0, const int & stride) { + template + static __device__ __forceinline__ void load_ldmatrix( + tile<16, 4, T> & t, const T * __restrict__ xs0, const int stride) { #ifdef NEW_MMA_AVAILABLE - int * xi = (int * ) x; - const int * xs = (const int *) xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2); - asm("ldmatrix.sync.aligned.m8n8.x4.trans.b16 {%0, %1, %2, %3}, [%4];" - : "+r"(xi[0]), "+r"(xi[2]), "+r"(xi[1]), "+r"(xi[3]) + int * xi = (int *) t.x; + const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride; + asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];" + : "=r"(xi[0]), "=r"(xi[1]) : "l"(xs)); #else - GGML_UNUSED(xs0); - GGML_UNUSED(stride); - NO_DEVICE_CODE; + load_generic(xs0, stride); #endif // NEW_MMA_AVAILABLE } - __device__ __forceinline__ void transpose() { - int * xi = (int *) x; - xi[0] = ggml_cuda_movmatrix(xi[0]); - - const int tmp = ggml_cuda_movmatrix(xi[1]); - xi[1] = ggml_cuda_movmatrix(xi[2]); - xi[2] = tmp; - - xi[3] = ggml_cuda_movmatrix(xi[3]); - } -}; - -template -struct mma_B_J8K4 { - static_assert(sizeof(T) == 4, "bad type size"); - - static constexpr int J = 8; - static constexpr int K = 4; - static constexpr int ne = 1; - - T x[ne]; - - static __device__ __forceinline__ int get_j(const int /* l */) { - const int ret = threadIdx.x / K; - GGML_CUDA_ASSUME(ret >= 0); - GGML_CUDA_ASSUME(ret < J); - return ret; - } - - static __device__ __forceinline__ int get_k(const int /* l */) { - const int ret = threadIdx.x % K; - GGML_CUDA_ASSUME(ret >= 0); - GGML_CUDA_ASSUME(ret < K); - return ret; - } - - __device__ __forceinline__ void load_generic(const T * __restrict__ xs0, const int & stride) { -#pragma unroll - for (int l = 0; l < ne; ++l) { - x[l] = xs0[get_j(l)*stride + get_k(l)]; - } - } - - __device__ __forceinline__ void load_ldmatrix(const T * __restrict__ xs0, const int & stride) { + template + static __device__ __forceinline__ void load_ldmatrix( + tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) { #ifdef NEW_MMA_AVAILABLE - int * xi = (int *) x; - const int * xs = (const int *) xs0 + (threadIdx.x%J)*stride; - asm("ldmatrix.sync.aligned.m8n8.x1.b16 {%0}, [%1];" - : "+r"(xi[0]) : "l"(xs)); + int * xi = (int * ) t.x; + const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2); + asm volatile("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];" + : "=r"(xi[0]), "=r"(xi[1]), "=r"(xi[2]), "=r"(xi[3]) + : "l"(xs)); #else - load_generic(xs0, stride); + load_generic(t, xs0, stride); #endif // NEW_MMA_AVAILABLE } -}; - -template -struct mma_B_J8K8 { - static_assert(sizeof(T) == 4, "bad type size"); - - static constexpr int J = 8; - static constexpr int K = 8; - static constexpr int ne = 2; - T x[ne]; - - static __device__ __forceinline__ int get_j(const int /* l */) { - const int ret = threadIdx.x / (K/2); - GGML_CUDA_ASSUME(ret >= 0); - GGML_CUDA_ASSUME(ret < J); - return ret; - } - - static __device__ __forceinline__ int get_k(const int l) { - const int ret = l * (K/2) + threadIdx.x % (K/2); - GGML_CUDA_ASSUME(ret >= 0); - GGML_CUDA_ASSUME(ret < K); - return ret; - } - - __device__ __forceinline__ void load_generic(const T * __restrict__ xs0, const int & stride) { -#pragma unroll - for (int l = 0; l < ne; ++l) { - x[l] = xs0[get_j(l)*stride + get_k(l)]; - } - } - - __device__ __forceinline__ void load_ldmatrix(const T * __restrict__ xs0, const int & stride) { + template + static __device__ __forceinline__ void load_ldmatrix_trans( + tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) { #ifdef NEW_MMA_AVAILABLE - int * xi = (int *) x; - const int * xs = (const int *) xs0 + (threadIdx.x%J)*stride + ((threadIdx.x/J)*(K/2)) % K; - asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];" - : "+r"(xi[0]), "+r"(xi[1]) + int * xi = (int * ) t.x; + const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2); + asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.b16 {%0, %1, %2, %3}, [%4];" + : "=r"(xi[0]), "=r"(xi[2]), "=r"(xi[1]), "=r"(xi[3]) : "l"(xs)); #else - load_generic(xs0, stride); + GGML_UNUSED(t); + GGML_UNUSED(xs0); + GGML_UNUSED(stride); + NO_DEVICE_CODE; #endif // NEW_MMA_AVAILABLE } -}; - -template -struct mma_C_I16J8 {}; - -template <> -struct mma_C_I16J8 { - static constexpr int I = 16; - static constexpr int J = 8; - static constexpr int ne = 4; - int x[ne] = {0}; - - static __device__ __forceinline__ int get_i(const int l) { - const int ret = (l/2) * (I/2) + threadIdx.x / (J/2); - GGML_CUDA_ASSUME(ret >= 0); - GGML_CUDA_ASSUME(ret < I); - return ret; - } - - static __device__ __forceinline__ int get_j(const int l) { - const int ret = 2 * (threadIdx.x % (J/2)) + l%2; - GGML_CUDA_ASSUME(ret >= 0); - GGML_CUDA_ASSUME(ret < J); - return ret; - } - - __device__ __forceinline__ void mma(const mma_A_I16K4 & mma_A, const mma_B_J8K4 & mma_B) { + static __device__ __forceinline__ void mma( + tile<16, 8, int> & D, const tile<16, 4, int> & A, const tile<8, 4, int> & B) { #ifdef NEW_MMA_AVAILABLE #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};" - : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3]) - : "r"(mma_A.x[0]), "r"(mma_A.x[1]), "r"(mma_B.x[0])); + : "+r"(D.x[0]), "+r"(D.x[1]), "+r"(D.x[2]), "+r"(D.x[3]) + : "r"(A.x[0]), "r"(A.x[1]), "r"(B.x[0])); #else // On Turing m16n8k16 mma is not available, use 2x m8n8k16 mma instead: asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};" - : "+r"(x[0]), "+r"(x[1]) - : "r"(mma_A.x[0]), "r"(mma_B.x[0])); + : "+r"(D.x[0]), "+r"(D.x[1]) + : "r"(A.x[0]), "r"(B.x[0])); asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};" - : "+r"(x[2]), "+r"(x[3]) - : "r"(mma_A.x[1]), "r"(mma_B.x[0])); + : "+r"(D.x[2]), "+r"(D.x[3]) + : "r"(A.x[1]), "r"(B.x[0])); #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE #else - GGML_UNUSED(mma_A); - GGML_UNUSED(mma_B); + GGML_UNUSED(D); + GGML_UNUSED(A); + GGML_UNUSED(B); NO_DEVICE_CODE; #endif // NEW_MMA_AVAILABLE } - __device__ __forceinline__ void mma(const mma_A_I16K8 & mma_A, const mma_B_J8K8 & mma_B) { + static __device__ __forceinline__ void mma( + tile<16, 8, int> & D, const tile<16, 8, int> & A, const tile<8, 8, int> & B) { #ifdef NEW_MMA_AVAILABLE #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE asm("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};" - : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3]) - : "r"(mma_A.x[0]), "r"(mma_A.x[1]), "r"(mma_A.x[2]), "r"(mma_A.x[3]), "r"(mma_B.x[0]), "r"(mma_B.x[1])); + : "+r"(D.x[0]), "+r"(D.x[1]), "+r"(D.x[2]), "+r"(D.x[3]) + : "r"(A.x[0]), "r"(A.x[1]), "r"(A.x[2]), "r"(A.x[3]), "r"(B.x[0]), "r"(B.x[1])); #else // On Turing m16n8k32 mma is not available, use 4x m8n8k16 mma instead: asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};" - : "+r"(x[0]), "+r"(x[1]) - : "r"(mma_A.x[0]), "r"(mma_B.x[0])); + : "+r"(D.x[0]), "+r"(D.x[1]) + : "r"(A.x[0]), "r"(B.x[0])); asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};" - : "+r"(x[2]), "+r"(x[3]) - : "r"(mma_A.x[1]), "r"(mma_B.x[0])); + : "+r"(D.x[2]), "+r"(D.x[3]) + : "r"(A.x[1]), "r"(B.x[0])); asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};" - : "+r"(x[0]), "+r"(x[1]) - : "r"(mma_A.x[2]), "r"(mma_B.x[1])); + : "+r"(D.x[0]), "+r"(D.x[1]) + : "r"(A.x[2]), "r"(B.x[1])); asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};" - : "+r"(x[2]), "+r"(x[3]) - : "r"(mma_A.x[3]), "r"(mma_B.x[1])); + : "+r"(D.x[2]), "+r"(D.x[3]) + : "r"(A.x[3]), "r"(B.x[1])); #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE #else - GGML_UNUSED(mma_A); - GGML_UNUSED(mma_B); + GGML_UNUSED(D); + GGML_UNUSED(A); + GGML_UNUSED(B); NO_DEVICE_CODE; #endif // NEW_MMA_AVAILABLE } -}; - -template <> -struct mma_C_I16J8 { - static constexpr int I = 16; - static constexpr int J = 4; - static constexpr int ne = 2; - - half2 x[ne] = {{0.0f, 0.0f}, {0.0f, 0.0f}}; - - static __device__ __forceinline__ int get_i(const int l) { - const int ret = l * (I/2) + threadIdx.x / J; - GGML_CUDA_ASSUME(ret >= 0); - GGML_CUDA_ASSUME(ret < I); - return ret; - } - static __device__ __forceinline__ int get_j(const int /* l */) { - const int ret = threadIdx.x % J; - GGML_CUDA_ASSUME(ret >= 0); - GGML_CUDA_ASSUME(ret < J); - return ret; - } - - __device__ __forceinline__ void mma(const mma_A_I16K8 & mma_A, const mma_B_J8K8 & mma_B) { + static __device__ __forceinline__ void mma( + tile<16, 4, half2> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) { #ifdef NEW_MMA_AVAILABLE - int * Axi = (int *) mma_A.x; - int * Bxi = (int *) mma_B.x; - int * xi = (int *) x; + const int * Axi = (const int *) A.x; + const int * Bxi = (const int *) B.x; + int * Dxi = (int *) D.x; #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};" - : "+r"(xi[0]), "+r"(xi[1]) + : "+r"(Dxi[0]), "+r"(Dxi[1]) : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1])); #else // On Turing m16n8k16 mma is not available, use 2x m8n8k8 mma instead: asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};" - : "+r"(xi[0]), "+r"(xi[1]) + : "+r"(Dxi[0]), "+r"(Dxi[1]) : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0])); asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};" - : "+r"(xi[0]), "+r"(xi[1]) + : "+r"(Dxi[0]), "+r"(Dxi[1]) : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1])); #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE #else - GGML_UNUSED(mma_A); - GGML_UNUSED(mma_B); + GGML_UNUSED(D); + GGML_UNUSED(A); + GGML_UNUSED(B); NO_DEVICE_CODE; #endif // NEW_MMA_AVAILABLE } - __device__ __forceinline__ mma_B_J8K8 to_mma_B() { - mma_B_J8K8 mma_B; - - int * xi = (int *) x; - int * Bxi = (int *) mma_B.x; - Bxi[0] = ggml_cuda_movmatrix(xi[0]); - Bxi[1] = ggml_cuda_movmatrix(xi[1]); - - return mma_B; - } -}; - -template <> -struct mma_C_I16J8 { - static constexpr int I = 16; - static constexpr int J = 8; - static constexpr int ne = 4; - - float x[ne] = {0.0f, 0.0f, 0.0f, 0.0f}; - - static __device__ __forceinline__ int get_i(const int l) { - const int ret = (l/2) * (I/2) + threadIdx.x / (J/2); - GGML_CUDA_ASSUME(ret >= 0); - GGML_CUDA_ASSUME(ret < I); - return ret; - } - - static __device__ __forceinline__ int get_j(const int l) { - const int ret = 2 * (threadIdx.x % (J/2)) + l%2; - GGML_CUDA_ASSUME(ret >= 0); - GGML_CUDA_ASSUME(ret < J); - return ret; - } - - __device__ __forceinline__ void mma(const mma_A_I16K8 & mma_A, const mma_B_J8K8 & mma_B) { + static __device__ __forceinline__ void mma( + tile<16, 8, float> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) { #ifdef NEW_MMA_AVAILABLE - int * Axi = (int *) mma_A.x; - int * Bxi = (int *) mma_B.x; - int * xi = (int *) x; + const int * Axi = (const int *) A.x; + const int * Bxi = (const int *) B.x; + int * Dxi = (int *) D.x; #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};" - : "+r"(xi[0]), "+r"(xi[1]), "+r"(xi[2]), "+r"(xi[3]) + : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]) : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1])); #else // On Turing m16n8k16 mma is not available, use 2x m8n8k8 mma instead: asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};" - : "+r"(xi[0]), "+r"(xi[1]), "+r"(xi[2]), "+r"(xi[3]) + : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]) : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0])); asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};" - : "+r"(xi[0]), "+r"(xi[1]), "+r"(xi[2]), "+r"(xi[3]) + : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]) : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1])); #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE #else - GGML_UNUSED(mma_A); - GGML_UNUSED(mma_B); + GGML_UNUSED(D); + GGML_UNUSED(A); + GGML_UNUSED(B); NO_DEVICE_CODE; #endif // NEW_MMA_AVAILABLE } - __device__ __forceinline__ mma_B_J8K8 to_mma_B() { - mma_B_J8K8 mma_B; - mma_B.x[0] = make_half2(x[0], x[1]); - mma_B.x[1] = make_half2(x[2], x[3]); - - int * Bxi = (int *) mma_B.x; - Bxi[0] = ggml_cuda_movmatrix(Bxi[0]); - Bxi[1] = ggml_cuda_movmatrix(Bxi[1]); - - return mma_B; - } - - __device__ __forceinline__ void load_generic(const float * __restrict__ xs0, const int & stride) { -#pragma unroll - for (int l = 0; l < ne; ++l) { - x[l] = xs0[get_j(l)*stride + get_i(l)]; - } - } -}; +} diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 5391542086c..0451c65f302 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -7,6 +7,8 @@ #include #include +using namespace ggml_cuda_mma; + #define MMQ_DP4A_MAX_BATCH_SIZE 64 // Max. batch size to use for dp4a MMQ kernels when FP16 tensor cores are available. #define MMQ_ITER_K 256 #define MMQ_NWARPS 8 @@ -647,15 +649,15 @@ template static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { - typedef mma_A_I16K8 mma_A; - typedef mma_B_J8K8 mma_B; - typedef mma_C_I16J8 mma_C; + typedef tile<16, 8, int> tile_A; + typedef tile< 8, 8, int> tile_B; + typedef tile<16, 8, int> tile_C; constexpr int granularity = mmq_get_granularity_device(mmq_x); constexpr int rows_per_warp = 2 * granularity; - constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp. + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. - y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K); + y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K); const int * x_qs = (const int *) x; const float * x_df = (const float *) x_qs + 2*WARP_SIZE; @@ -663,8 +665,8 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( const float * y_df = (const float *) y; const half2 * y_ds = (const half2 *) y; - mma_A A[ntx][WARP_SIZE/QI8_0]; - float dA[ntx][mma_C::ne/2][WARP_SIZE/QI8_0]; + tile_A A[ntx][WARP_SIZE/QI8_0]; + float dA[ntx][tile_C::ne/2][WARP_SIZE/QI8_0]; const int i0 = (threadIdx.y/ntx)*rows_per_warp; @@ -674,12 +676,12 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) { const int k0 = k00 + k01; - A[n][k01/QI8_0].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0); + load_ldmatrix(A[n][k01/QI8_0], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0); } #pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int i = i0 + n*mma_A::I + mma_C::get_i(2*l); + for (int l = 0; l < tile_C::ne/2; ++l) { + const int i = i0 + n*tile_A::I + tile_C::get_i(2*l); #pragma unroll for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) { @@ -691,17 +693,17 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( } #pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) { + for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { #pragma unroll for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) { - mma_B B; - float dB[mma_C::ne/2]; + tile_B B; + float dB[tile_C::ne/2]; - B.load_generic(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix + load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix #pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int j = j0 + mma_C::get_j(l); + for (int l = 0; l < tile_C::ne/2; ++l) { + const int j = j0 + tile_C::get_j(l); if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) { dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1]; @@ -712,12 +714,12 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( #pragma unroll for (int n = 0; n < ntx; ++n) { - mma_C C; - C.mma(A[n][k01/QI8_0], B); + tile_C C; + mma(C, A[n][k01/QI8_0], B); #pragma unroll - for (int l = 0; l < mma_C::ne; ++l) { - sum[(j0/mma_C::J + n)*mma_C::ne + l] += C.x[l]*dA[n][l/2][k01/QI8_0]*dB[l%2]; + for (int l = 0; l < tile_C::ne; ++l) { + sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l]*dA[n][l/2][k01/QI8_0]*dB[l%2]; } } } @@ -758,23 +760,23 @@ template static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { - typedef mma_A_I16K8 mma_A; - typedef mma_B_J8K8 mma_B; - typedef mma_C_I16J8 mma_C; + typedef tile<16, 8, int> tile_A; + typedef tile< 8, 8, int> tile_B; + typedef tile<16, 8, int> tile_C; constexpr int granularity = mmq_get_granularity_device(mmq_x); constexpr int rows_per_warp = 2 * granularity; - constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp. + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. - y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K); + y += (threadIdx.y % ntx) * (tile_B::J*MMQ_TILE_Y_K); const int * x_qs = (const int *) x; const half2 * x_dm = (const half2 *) x_qs + 2*WARP_SIZE; const int * y_qs = (const int *) y + 4; const half2 * y_dm = (const half2 *) y; - mma_A A[ntx][WARP_SIZE/QI8_1]; - float2 dmA[ntx][mma_C::ne/2][WARP_SIZE/QI8_1]; + tile_A A[ntx][WARP_SIZE/QI8_1]; + float2 dmA[ntx][tile_C::ne/2][WARP_SIZE/QI8_1]; const int i0 = (threadIdx.y/ntx)*rows_per_warp; @@ -784,12 +786,12 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma( for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) { const int k0 = k00 + k01; - A[n][k01/QI8_1].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1); + load_ldmatrix(A[n][k01/QI8_1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1); } #pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int i = i0 + n*mma_A::I + mma_C::get_i(2*l); + for (int l = 0; l < tile_C::ne/2; ++l) { + const int i = i0 + n*tile_A::I + tile_C::get_i(2*l); #pragma unroll for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) { @@ -801,30 +803,30 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma( } #pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) { + for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { #pragma unroll for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) { - mma_B B; - float2 dsB[mma_C::ne/2]; + tile_B B; + float2 dsB[tile_C::ne/2]; - B.load_generic(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix + load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix #pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int j = j0 + mma_C::get_j(l); + for (int l = 0; l < tile_C::ne/2; ++l) { + const int j = j0 + tile_C::get_j(l); dsB[l] = __half22float2(y_dm[j*MMQ_TILE_Y_K + k01/QI8_1]); } #pragma unroll for (int n = 0; n < ntx; ++n) { - mma_C C; - C.mma(A[n][k01/QI8_1], B); + tile_C C; + mma(C, A[n][k01/QI8_1], B); #pragma unroll - for (int l = 0; l < mma_C::ne; ++l) { - sum[(j0/mma_C::J + n)*mma_C::ne + l] += dmA[n][l/2][k01/QI8_1].x*dsB[l%2].x*C.x[l]; - sum[(j0/mma_C::J + n)*mma_C::ne + l] += dmA[n][l/2][k01/QI8_1].y*dsB[l%2].y; + for (int l = 0; l < tile_C::ne; ++l) { + sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA[n][l/2][k01/QI8_1].x*dsB[l%2].x*C.x[l]; + sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA[n][l/2][k01/QI8_1].y*dsB[l%2].y; } } } @@ -868,26 +870,26 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { #ifdef NEW_MMA_AVAILABLE - typedef mma_A_I16K4 mma_A; - typedef mma_A_I16K8 mma_A_K8; - typedef mma_B_J8K4 mma_B; - typedef mma_C_I16J8 mma_C; + typedef tile<16, 4, int> tile_A; + typedef tile<16, 8, int> tile_A_8; + typedef tile< 8, 4, int> tile_B; + typedef tile<16, 8, int> tile_C; constexpr int granularity = mmq_get_granularity_device(mmq_x); constexpr int rows_per_warp = 2 * granularity; - constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp. + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. - y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K); + y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K); const int * x_qs = (const int *) x; const float * x_df = (const float *) x_qs + WARP_SIZE*2; const int * y_qs = (const int *) y + 4; const float * y_df = (const float *) y; - const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I); + const int i0 = (threadIdx.y / ntx) * (ntx*tile_A::I); - mma_A A[ntx][8]; - float dA[ntx][mma_C::ne/2][8]; + tile_A A[ntx][8]; + float dA[ntx][tile_C::ne/2][8]; #pragma unroll for (int n = 0; n < ntx; ++n) { @@ -895,12 +897,12 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) { const int k0 = k00 + k01; - ((mma_A_K8 *) A[n])[k01/8].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K); + load_ldmatrix(((tile_A_8 *) A[n])[k01/8], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K); } #pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int i = i0 + n*mma_C::I + mma_C::get_i(2*l); + for (int l = 0; l < tile_C::ne/2; ++l) { + const int i = i0 + n*tile_C::I + tile_C::get_i(2*l); #pragma unroll for (int k01 = 0; k01 < WARP_SIZE; k01 += 4) { @@ -912,32 +914,32 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( } #pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) { + for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { #pragma unroll for (int k01 = 0; k01 < WARP_SIZE; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) { - mma_B B[2]; - float dB[mma_C::ne/2]; + tile_B B[2]; + float dB[tile_C::ne/2]; // Here load_generic is faster than load_ldmatrix. - B[0].load_generic(y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K); - B[1].load_generic(y_qs + j0*MMQ_TILE_Y_K + (k01 + mma_B::K), MMQ_TILE_Y_K); + load_generic(B[0], y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K); + load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + (k01 + tile_B::J), MMQ_TILE_Y_K); #pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int j = j0 + mma_C::get_j(l); + for (int l = 0; l < tile_C::ne/2; ++l) { + const int j = j0 + tile_C::get_j(l); dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1]; } #pragma unroll for (int n = 0; n < ntx; ++n) { - mma_C C[2]; - C[0].mma(A[n][k01/4 + 0], B[0]); - C[1].mma(A[n][k01/4 + 1], B[1]); + tile_C C[2]; + mma(C[0], A[n][k01/4 + 0], B[0]); + mma(C[1], A[n][k01/4 + 1], B[1]); #pragma unroll - for (int l = 0; l < mma_C::ne; ++l) { - sum[(j0/mma_C::J + n)*mma_C::ne + l] += dB[l%2]*(C[0].x[l]*dA[n][l/2][k01/4 + 0] + C[1].x[l]*dA[n][l/2][k01/4 + 1]); + for (int l = 0; l < tile_C::ne; ++l) { + sum[(j0/tile_C::J + n)*tile_C::ne + l] += dB[l%2]*(C[0].x[l]*dA[n][l/2][k01/4 + 0] + C[1].x[l]*dA[n][l/2][k01/4 + 1]); } } } @@ -1056,27 +1058,27 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { #ifdef NEW_MMA_AVAILABLE - typedef mma_A_I16K4 mma_A; - typedef mma_A_I16K8 mma_A_K8; - typedef mma_B_J8K4 mma_B; - typedef mma_C_I16J8 mma_C; + typedef tile<16, 4, int> tile_A; + typedef tile<16, 8, int> tile_A_8; + typedef tile< 8, 4, int> tile_B; + typedef tile<16, 8, int> tile_C; constexpr int granularity = mmq_get_granularity_device(mmq_x); constexpr int rows_per_warp = 2 * granularity; - constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp. + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. - y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K); + y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K); const int * x_qs = (const int *) x; const half2 * x_dm = (const half2 *) x_qs + WARP_SIZE*2; const int * y_qs = (const int *) y + 4; const half2 * y_ds = (const half2 *) y; - const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I); + const int i0 = (threadIdx.y / ntx) * (ntx*tile_A::I); - mma_A A[ntx][8]; - float dA[ntx][mma_C::ne/2][8]; - float mA[ntx][mma_C::ne/2][8]; + tile_A A[ntx][8]; + float dA[ntx][tile_C::ne/2][8]; + float mA[ntx][tile_C::ne/2][8]; #pragma unroll for (int n = 0; n < ntx; ++n) { @@ -1084,15 +1086,15 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) { const int k0 = k00 + k01; - ((mma_A_K8 *) A[n])[k01/QI8_1].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K); + load_ldmatrix(((tile_A_8 *) A[n])[k01/QI8_1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K); } } #pragma unroll for (int n = 0; n < ntx; ++n) { #pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int i = i0 + n*mma_C::I + mma_C::get_i(2*l); + for (int l = 0; l < tile_C::ne/2; ++l) { + const int i = i0 + n*tile_C::I + tile_C::get_i(2*l); #pragma unroll for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1/2) { @@ -1107,58 +1109,58 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( } #pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) { - float2 dB[mma_C::ne/2]; + for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { + float2 dB[tile_C::ne/2]; #pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int j = j0 + mma_C::get_j(l); + for (int l = 0; l < tile_C::ne/2; ++l) { + const int j = j0 + tile_C::get_j(l); dB[l] = __half22float2(y_ds[j*MMQ_TILE_Y_K]); } #pragma unroll for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) { - mma_B B[2]; + tile_B B[2]; // Here load_generic is faster than load_ldmatrix. - B[0].load_generic(y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K); - B[1].load_generic(y_qs + j0*MMQ_TILE_Y_K + (k01 + mma_B::K), MMQ_TILE_Y_K); + load_generic(B[0], y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K); + load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + (k01 + tile_B::J), MMQ_TILE_Y_K); - mma_C Cm[2]; + tile_C Cm[2]; if (k01 >= WARP_SIZE * 3/4) { - mma_A A1; + tile_A A1; A1.x[0] = 0x01010101; A1.x[1] = 0x01010101; - Cm[0].mma(A1, B[0]); - Cm[1].mma(A1, B[1]); + mma(Cm[0], A1, B[0]); + mma(Cm[1], A1, B[1]); } #pragma unroll for (int n = 0; n < ntx; ++n) { - mma_C Cd[2]; + tile_C Cd[2]; - Cd[0].mma(A[n][k01/4 + 0], B[0]); - Cd[1].mma(A[n][k01/4 + 1], B[1]); + mma(Cd[0], A[n][k01/4 + 0], B[0]); + mma(Cd[1], A[n][k01/4 + 1], B[1]); #pragma unroll - for (int l = 0; l < mma_C::ne; ++l) { + for (int l = 0; l < tile_C::ne; ++l) { float tmp = Cd[0].x[l]*dA[n][l/2][k01/4 + 0] + Cd[1].x[l]*dA[n][l/2][k01/4 + 1]; if (k01 >= WARP_SIZE * 3/4) { tmp -= Cm[0].x[l]*mA[n][l/2][k01/4 + 0] + Cm[1].x[l]*mA[n][l/2][k01/4 + 1]; } - sum[(j0/mma_C::J + n)*mma_C::ne + l] += tmp*(k01 < WARP_SIZE/2 ? dB[l%2].x : dB[l%2].y); + sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*(k01 < WARP_SIZE/2 ? dB[l%2].x : dB[l%2].y); } } } #pragma unroll for (int k01 = 0; k01 < WARP_SIZE * 3/4; k01 += QI8_1) { - float2 sB[mma_C::ne/2]; + float2 sB[tile_C::ne/2]; #pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int j = j0 + mma_C::get_j(l); + for (int l = 0; l < tile_C::ne/2; ++l) { + const int j = j0 + tile_C::get_j(l); sB[l] = __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]); } @@ -1166,9 +1168,9 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( #pragma unroll for (int n = 0; n < ntx; ++n) { #pragma unroll - for (int l = 0; l < mma_C::ne; ++l) { - sum[(j0/mma_C::J + n)*mma_C::ne + l] -= mA[n][l/2][k01/4 + 0]*sB[l%2].x; - sum[(j0/mma_C::J + n)*mma_C::ne + l] -= mA[n][l/2][k01/4 + 1]*sB[l%2].y; + for (int l = 0; l < tile_C::ne; ++l) { + sum[(j0/tile_C::J + n)*tile_C::ne + l] -= mA[n][l/2][k01/4 + 0]*sB[l%2].x; + sum[(j0/tile_C::J + n)*tile_C::ne + l] -= mA[n][l/2][k01/4 + 1]*sB[l%2].y; } } } @@ -1708,15 +1710,15 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { #ifdef NEW_MMA_AVAILABLE - typedef mma_A_I16K4 mma_A; - typedef mma_B_J8K4 mma_B; - typedef mma_C_I16J8 mma_C; + typedef tile<16, 4, int> tile_A; + typedef tile< 8, 4, int> tile_B; + typedef tile<16, 8, int> tile_C; constexpr int granularity = mmq_get_granularity_device(mmq_x); constexpr int rows_per_warp = 2 * granularity; - constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp. + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. - y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K); + y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K); const int * x_qs = (const int *) x; const float * x_df = (const float *) x_qs + WARP_SIZE*2; @@ -1724,11 +1726,11 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( const int * y_qs = (const int *) y + 4; const float * y_df = (const float *) y; - const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I); + const int i0 = (threadIdx.y / ntx) * (ntx*tile_A::I); - mma_A A[ntx][8]; - int scA[ntx][mma_C::ne/2][8]; - float dA[ntx][mma_C::ne/2]; + tile_A A[ntx][8]; + int scA[ntx][tile_C::ne/2][8]; + float dA[ntx][tile_C::ne/2]; #pragma unroll for (int n = 0; n < ntx; ++n) { @@ -1736,8 +1738,8 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) { const int k0 = k00 + k01; - A[n][k01/4 + 0].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + 0), MMQ_MMA_TILE_X_K_Q6_K); - A[n][k01/4 + 1].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + mma_A::K), MMQ_MMA_TILE_X_K_Q6_K); + load_ldmatrix(A[n][k01/4 + 0], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + 0), MMQ_MMA_TILE_X_K_Q6_K); + load_ldmatrix(A[n][k01/4 + 1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + tile_A::J), MMQ_MMA_TILE_X_K_Q6_K); } #pragma unroll @@ -1745,8 +1747,8 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( const int k0 = k00 + k01; #pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int i = i0 + n*mma_C::I + mma_C::get_i(2*l); + for (int l = 0; l < tile_C::ne/2; ++l) { + const int i = i0 + n*tile_C::I + tile_C::get_i(2*l); const int sc_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + k0/16]; const int8_t * sc = (const int8_t *) &sc_packed; @@ -1759,41 +1761,41 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( } #pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int i = i0 + n*mma_C::I + mma_C::get_i(2*l); + for (int l = 0; l < tile_C::ne/2; ++l) { + const int i = i0 + n*tile_C::I + tile_C::get_i(2*l); dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q6_K]; } } #pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) { - float tmp[ntx][mma_C::ne] = {{0.0f}}; + for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { + float tmp[ntx][tile_C::ne] = {{0.0f}}; #pragma unroll for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) { - mma_B B[2]; - float dB[mma_C::ne/2]; + tile_B B[2]; + float dB[tile_C::ne/2]; // Here load_generic is faster than load_ldmatrix. - B[0].load_generic(y_qs + j0*MMQ_TILE_Y_K + 0 + k01, MMQ_TILE_Y_K); - B[1].load_generic(y_qs + j0*MMQ_TILE_Y_K + mma_B::K + k01, MMQ_TILE_Y_K); + load_generic(B[0], y_qs + j0*MMQ_TILE_Y_K + 0 + k01, MMQ_TILE_Y_K); + load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + tile_B::J + k01, MMQ_TILE_Y_K); #pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int j = j0 + mma_C::get_j(l); + for (int l = 0; l < tile_C::ne/2; ++l) { + const int j = j0 + tile_C::get_j(l); dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1]; } #pragma unroll for (int n = 0; n < ntx; ++n) { - mma_C C[2]; - C[0].mma(A[n][k01/4 + 0], B[0]); - C[1].mma(A[n][k01/4 + 1], B[1]); + tile_C C[2]; + mma(C[0], A[n][k01/4 + 0], B[0]); + mma(C[1], A[n][k01/4 + 1], B[1]); #pragma unroll - for (int l = 0; l < mma_C::ne; ++l) { + for (int l = 0; l < tile_C::ne; ++l) { tmp[n][l] += (C[0].x[l]*scA[n][l/2][k01/4 + 0] + C[1].x[l]*scA[n][l/2][k01/4 + 1])*dB[l%2]; } } @@ -1802,8 +1804,8 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( #pragma unroll for (int n = 0; n < ntx; ++n) { #pragma unroll - for (int l = 0; l < mma_C::ne; ++l) { - sum[(j0/mma_C::J + n)*mma_C::ne + l] += tmp[n][l]*dA[n][l/2]; + for (int l = 0; l < tile_C::ne; ++l) { + sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp[n][l]*dA[n][l/2]; } } } @@ -2312,36 +2314,36 @@ template static __device__ __forceinline__ void mmq_write_back_mma( const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max) { - typedef mma_C_I16J8 mma_C; + typedef tile<16, 8, int> tile_C; constexpr int granularity = mmq_get_granularity_device(mmq_x); constexpr int rows_per_warp = 2 * granularity; - constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp. + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. - const int i0 = (threadIdx.y / ntx) * (ntx*mma_C::I); + const int i0 = (threadIdx.y / ntx) * (ntx*tile_C::I); #ifdef NEW_MMA_AVAILABLE - static_assert(nwarps*mma_C::I == mmq_y, "nwarps*mma_C::I != mmq_y"); + static_assert(nwarps*tile_C::I == mmq_y, "nwarps*tile_C::I != mmq_y"); #endif // NEW_MMA_AVAILABLE #pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) { + for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { #pragma unroll for (int n = 0; n < ntx; ++n) { #pragma unroll - for (int l = 0; l < mma_C::ne; ++l) { - const int j = j0 + (threadIdx.y % ntx) * mma_C::J + mma_C::get_j(l); + for (int l = 0; l < tile_C::ne; ++l) { + const int j = j0 + (threadIdx.y % ntx) * tile_C::J + tile_C::get_j(l); if (j > j_max) { continue; } - const int i = i0 + n*mma_C::I + mma_C::get_i(l); + const int i = i0 + n*tile_C::I + tile_C::get_i(l); if (need_check && i > i_max) { continue; } - dst[j*stride + i] = sum[(j0/mma_C::J + n)*mma_C::ne + l]; + dst[j*stride + i] = sum[(j0/tile_C::J + n)*tile_C::ne + l]; } } } From 38542f339404f73a33ac5c7eefb3e6564a41d4d7 Mon Sep 17 00:00:00 2001 From: Prashant Vithule <119530321+Vithulep@users.noreply.github.com> Date: Thu, 20 Feb 2025 15:38:32 +0530 Subject: [PATCH 45/58] ggml: aarch64: implement SVE kernels for q3_K_q8_K vector dot (llama/11917) * Added SVE Implementation for Q3_K Kernel in ggml-cpu-quants.c file * Improved Formating of code in ggml-cpu-quants.c file * style : minor fixes * style : less whitespaces * style : ptr spaceing --------- Co-authored-by: vithulep Co-authored-by: Georgi Gerganov --- ggml/src/ggml-cpu/ggml-cpu-quants.c | 177 +++++++++++++++++++++++++++- 1 file changed, 176 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-cpu/ggml-cpu-quants.c b/ggml/src/ggml-cpu/ggml-cpu-quants.c index 0315dc2575e..14ba288fe19 100644 --- a/ggml/src/ggml-cpu/ggml-cpu-quants.c +++ b/ggml/src/ggml-cpu/ggml-cpu-quants.c @@ -5112,7 +5112,182 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r const int nb = n / QK_K; -#ifdef __ARM_NEON +#if defined(__ARM_FEATURE_SVE) + + uint32_t utmp[4]; + + const int8_t m32 = 32; + const int vector_length = svcntb()*8; + const svuint8_t m3b_sv = svdup_n_u8(0x3); + const svint32_t vzero_sv = svdup_n_s32(0); + + const svuint8_t m0_sv = svdup_n_u8(1); + const svuint8_t m1_sv = svlsl_n_u8_x(svptrue_b8(), m0_sv, 1); + const svuint8_t m2_sv = svlsl_n_u8_x(svptrue_b8(), m0_sv, 2); + const svuint8_t m3_sv = svlsl_n_u8_x(svptrue_b8(), m0_sv, 3); + svbool_t pred_s32 = svnot_b_z (svptrue_b32(), svptrue_pat_b32(SV_VL4)); + + float sum = 0; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + + const uint8_t * restrict q3_sv = x[i].qs; + const uint8_t * restrict qh_sv = x[i].hmask; + const int8_t * restrict q8_sv = y[i].qs; + + // Set up scales + uint32_t * aux = &x[i].scales; + utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4); + utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4); + utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4); + utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4); + + int8_t * scale = (int8_t *)utmp; + + for (int j = 0; j < 16; ++j) scale[j] -= m32; + + switch (vector_length) { + case 128: + { + svuint8_t qhbits_sv_1 = svld1_u8(svptrue_b8(), qh_sv); + svuint8_t qhbits_sv_2 = svld1_u8(svptrue_b8(), qh_sv+16); + svuint8_t q3h_sv; + + svint32_t sumi1_1 = svdup_n_s32(0); + svint8_t q3bytes_sv; + + for (int j = 0; j < QK_K/128; ++j) { + + const svuint8_t q3bits_sv = svld1_u8(svptrue_b8(), q3_sv); q3_sv += 16; + const svuint8_t q3bits_sv_1 = svld1_u8(svptrue_b8(), q3_sv); q3_sv += 16; + svint8_t q8bytes_1_sv_1 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; + svint8_t q8bytes_1_sv_2 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; + + q3h_sv = svlsl_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m0_sv, qhbits_sv_1), 2); + q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), q3bits_sv, m3b_sv)), svreinterpret_s8_u8(q3h_sv)); + + sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), svdup_n_s32((int32_t)scale[0])); + + q3h_sv = svlsl_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m0_sv, qhbits_sv_2), 2); + q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), q3bits_sv_1, m3b_sv)), svreinterpret_s8_u8(q3h_sv)); + + sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), svdup_n_s32((int32_t)scale[1])); + + q8bytes_1_sv_1 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; + q8bytes_1_sv_2 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; + + q3h_sv = svlsl_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m1_sv, qhbits_sv_1), 1); + q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv, 2), m3b_sv)), svreinterpret_s8_u8(q3h_sv)); + + sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), svdup_n_s32((int32_t)scale[2])); + + q3h_sv = svlsl_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m1_sv, qhbits_sv_2), 1); + q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv_1, 2), m3b_sv)), svreinterpret_s8_u8(q3h_sv)); + + sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), svdup_n_s32((int32_t)scale[3])); + + + scale += 4; + q8bytes_1_sv_1 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; + q8bytes_1_sv_2 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; + + q3h_sv = svbic_u8_x(svptrue_b8(), m2_sv, qhbits_sv_1); + q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv, 4), m3b_sv)), svreinterpret_s8_u8(q3h_sv)); + + sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), svdup_n_s32((int32_t)scale[0])); + + q3h_sv = svbic_u8_x(svptrue_b8(), m2_sv, qhbits_sv_2); + q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv_1, 4), m3b_sv)), svreinterpret_s8_u8(q3h_sv)); + + sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), svdup_n_s32((int32_t)scale[1])); + + + q8bytes_1_sv_1 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; + q8bytes_1_sv_2 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; + + q3h_sv = svlsr_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m3_sv, qhbits_sv_1), 1); + q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv, 6), m3b_sv)), svreinterpret_s8_u8(q3h_sv)); + + sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), svdup_n_s32((int32_t)scale[2])); + + q3h_sv = svlsr_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m3_sv, qhbits_sv_2), 1); + q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv_1, 6), m3b_sv)), svreinterpret_s8_u8(q3h_sv)); + + sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), svdup_n_s32((int32_t)scale[3])); + + if (j == 0) { + qhbits_sv_1 = svlsr_n_u8_x(svptrue_b8(), qhbits_sv_1, 4); + qhbits_sv_2 = svlsr_n_u8_x(svptrue_b8(), qhbits_sv_2, 4); + } + + scale += 4; + } + + sum += d * (svaddv_s32(svptrue_b32(), sumi1_1)); + } break; + case 256: + case 512: + { + svuint8_t qhbits_sv = svld1_u8(svptrue_pat_b8(SV_VL32), qh_sv); + svuint8_t q3h_sv; + + svint32_t sumi1_1 = svdup_n_s32(0); + svint8_t q3bytes_sv; + + for (int j = 0; j < QK_K/128; ++j) { + + const svuint8_t q3bits_sv = svld1_u8(svptrue_pat_b8(SV_VL32), q3_sv); q3_sv += 32; + svint8_t q8bytes_1_sv_1 = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32; + svint8_t q8bytes_1_sv_2 = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32; + + q3h_sv = svlsl_n_u8_x(svptrue_pat_b8(SV_VL32), svbic_u8_x(svptrue_pat_b8(SV_VL32), m0_sv, qhbits_sv), 2); + q3bytes_sv = svsub_s8_x(svptrue_pat_b8(SV_VL32), svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), q3bits_sv, m3b_sv)), svreinterpret_s8_u8(q3h_sv)); + + + svint32_t scale_1 = svsel_s32(svptrue_pat_b32(SV_VL4), svdup_n_s32((int32_t)scale[0]), svdup_n_s32((int32_t)scale[1])); + sumi1_1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), scale_1); + + q3h_sv = svlsl_n_u8_x(svptrue_pat_b8(SV_VL32), svbic_u8_x(svptrue_pat_b8(SV_VL32), m1_sv, qhbits_sv), 1); + q3bytes_sv = svsub_s8_x(svptrue_pat_b8(SV_VL32), svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q3bits_sv, 2), m3b_sv)), svreinterpret_s8_u8(q3h_sv)); + + scale_1 = svsel_s32(svptrue_pat_b32(SV_VL4), svdup_n_s32((int32_t)scale[2]), svdup_n_s32((int32_t)scale[3])); + sumi1_1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), scale_1); + + scale += 4; + q8bytes_1_sv_1 = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32; + q8bytes_1_sv_2 = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32; + + q3h_sv = svbic_u8_x(svptrue_pat_b8(SV_VL32), m2_sv, qhbits_sv); + q3bytes_sv = svsub_s8_x(svptrue_pat_b8(SV_VL32), svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q3bits_sv, 4), m3b_sv)), svreinterpret_s8_u8(q3h_sv)); + + scale_1 = svsel_s32(svptrue_pat_b32(SV_VL4), svdup_n_s32((int32_t)scale[0]), svdup_n_s32((int32_t)scale[1])); + sumi1_1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), scale_1); + + q3h_sv = svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), svbic_u8_x(svptrue_pat_b8(SV_VL32), m3_sv, qhbits_sv), 1); + q3bytes_sv = svsub_s8_x(svptrue_pat_b8(SV_VL32), svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q3bits_sv, 6), m3b_sv)), svreinterpret_s8_u8(q3h_sv)); + + scale_1 = svsel_s32(svptrue_pat_b32(SV_VL4), svdup_n_s32((int32_t)scale[2]), svdup_n_s32((int32_t)scale[3])); + sumi1_1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), scale_1); + + if (j == 0) { + qhbits_sv = svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), qhbits_sv, 4); + } + + scale += 4; + } + + sum += d * (svaddv_s32(svptrue_pat_b32(SV_VL8), sumi1_1)); + } break; + default: + assert(false && "Unsupported vector length"); + break; + } + } + *s = sum; + +#elif __ARM_NEON uint32_t aux[3]; uint32_t utmp[4]; From bb7026d12690a47c835af930faceb4dd1e3f32e0 Mon Sep 17 00:00:00 2001 From: Charles Xu Date: Thu, 20 Feb 2025 14:06:51 +0100 Subject: [PATCH 46/58] ggml-cpu: Add CPU backend support for KleidiAI library (llama/11390) * ggml-cpu: Add CPU backend support for KleidiAI library * Add environmental variable GGML_KLEIDIAI_SME * Add support for multithread LHS conversion * Switch kernel selection order to dotprod and i8mm * updates for review comments * More updates for review comments * Reorganize and rename KleidiAI files * Move ggml-cpu-traits.h to source file * Update cmake for SME build and add alignment for SME * Remove append GGML_USE_CPU_KLEIDIAI to the GGML_CDEF_PUBLIC list --- ggml/CMakeLists.txt | 1 + ggml/include/ggml-cpu.h | 1 + ggml/src/ggml-cpu/CMakeLists.txt | 102 ++++++++- ggml/src/ggml-cpu/ggml-cpu.c | 33 ++- ggml/src/ggml-cpu/ggml-cpu.cpp | 16 ++ ggml/src/ggml-cpu/kleidiai/kernels.cpp | 259 +++++++++++++++++++++ ggml/src/ggml-cpu/kleidiai/kernels.h | 61 +++++ ggml/src/ggml-cpu/kleidiai/kleidiai.cpp | 287 ++++++++++++++++++++++++ ggml/src/ggml-cpu/kleidiai/kleidiai.h | 17 ++ 9 files changed, 767 insertions(+), 10 deletions(-) create mode 100644 ggml/src/ggml-cpu/kleidiai/kernels.cpp create mode 100644 ggml/src/ggml-cpu/kleidiai/kernels.h create mode 100644 ggml/src/ggml-cpu/kleidiai/kleidiai.cpp create mode 100644 ggml/src/ggml-cpu/kleidiai/kleidiai.h diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 75b5ea3b439..fc5eac151b9 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -102,6 +102,7 @@ endif() option(GGML_CPU_HBM "ggml: use memkind for CPU HBM" OFF) option(GGML_CPU_AARCH64 "ggml: use runtime weight conversion of Q4_0 to Q4_X_X" ON) +option(GGML_CPU_KLEIDIAI "ggml: use KleidiAI optimized kernels if applicable" OFF) option(GGML_AVX "ggml: enable AVX" ${INS_ENB}) option(GGML_AVX_VNNI "ggml: enable AVX-VNNI" OFF) option(GGML_AVX2 "ggml: enable AVX2" ${INS_ENB}) diff --git a/ggml/include/ggml-cpu.h b/ggml/include/ggml-cpu.h index d23c6b262e2..9b8a697546e 100644 --- a/ggml/include/ggml-cpu.h +++ b/ggml/include/ggml-cpu.h @@ -95,6 +95,7 @@ extern "C" { GGML_BACKEND_API int ggml_cpu_has_matmul_int8(void); GGML_BACKEND_API int ggml_cpu_has_sve (void); GGML_BACKEND_API int ggml_cpu_get_sve_cnt (void); // sve vector length in bytes + GGML_BACKEND_API int ggml_cpu_has_sme (void); // other GGML_BACKEND_API int ggml_cpu_has_riscv_v (void); GGML_BACKEND_API int ggml_cpu_has_vsx (void); diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt index 26533e512ae..826d65cece0 100644 --- a/ggml/src/ggml-cpu/CMakeLists.txt +++ b/ggml/src/ggml-cpu/CMakeLists.txt @@ -111,14 +111,15 @@ function(ggml_add_cpu_backend_variant_impl tag_name) function(check_arm_feature tag code) set(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS}) set(CMAKE_REQUIRED_FLAGS "${ARM_MCPU_FLAG}+${tag}") - check_cxx_source_runs( - "${code}" - GGML_MACHINE_SUPPORTS_${tag} - ) + check_cxx_source_runs("${code}" GGML_MACHINE_SUPPORTS_${tag}) if (GGML_MACHINE_SUPPORTS_${tag}) set(ARM_MCPU_FLAG_FIX "${ARM_MCPU_FLAG_FIX}+${tag}" PARENT_SCOPE) else() - set(ARM_MCPU_FLAG_FIX "${ARM_MCPU_FLAG_FIX}+no${tag}" PARENT_SCOPE) + set(CMAKE_REQUIRED_FLAGS "${ARM_MCPU_FLAG}+no${tag}") + check_cxx_source_compiles("int main() { return 0; }" GGML_MACHINE_SUPPORTS_no${tag}) + if (GGML_MACHINE_SUPPORTS_no${tag}) + set(ARM_MCPU_FLAG_FIX "${ARM_MCPU_FLAG_FIX}+no${tag}" PARENT_SCOPE) + endif() endif() set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_SAVE}) endfunction() @@ -126,6 +127,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name) check_arm_feature(dotprod "#include \nint main() { int8x16_t _a, _b; volatile int32x4_t _s = vdotq_s32(_s, _a, _b); return 0; }") check_arm_feature(i8mm "#include \nint main() { int8x16_t _a, _b; volatile int32x4_t _s = vmmlaq_s32(_s, _a, _b); return 0; }") check_arm_feature(sve "#include \nint main() { svfloat32_t _a, _b; volatile svfloat32_t _c = svadd_f32_z(svptrue_b8(), _a, _b); return 0; }") + check_arm_feature(sme "#include \n__arm_locally_streaming int main() { __asm__ volatile(\"smstart; smstop;\"); return 0; }") list(APPEND ARCH_FLAGS "${ARM_MCPU_FLAG}${ARM_MCPU_FLAG_FIX}") else() @@ -150,7 +152,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name) if (ARM_FEATURE_RESULT) message(WARNING "Failed to get ARM features") else() - foreach(feature DOTPROD SVE MATMUL_INT8 FMA FP16_VECTOR_ARITHMETIC) + foreach(feature DOTPROD SVE MATMUL_INT8 FMA FP16_VECTOR_ARITHMETIC SME) string(FIND "${ARM_FEATURE}" "__ARM_FEATURE_${feature} 1" feature_pos) if (NOT ${feature_pos} EQUAL -1) message(STATUS "ARM feature ${feature} enabled") @@ -312,6 +314,94 @@ function(ggml_add_cpu_backend_variant_impl tag_name) target_compile_definitions(${GGML_CPU_NAME} PRIVATE GGML_USE_CPU_AARCH64) endif() + if (GGML_CPU_KLEIDIAI) + message(STATUS "Using KleidiAI optimized kernels if applicable") + + # Disable the KleidiAI tests + set(KLEIDIAI_BUILD_TESTS OFF) + + # Fetch KleidiAI sources: + include(FetchContent) + set(KLEIDIAI_COMMIT_TAG "v1.3.0") + set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz") + set(KLEIDIAI_ARCHIVE_MD5 "060bd2dc64642b091f461cc8dd7426d9") + + if (POLICY CMP0135) + cmake_policy(SET CMP0135 NEW) + endif() + + FetchContent_Declare(KleidiAI_Download + URL ${KLEIDIAI_DOWNLOAD_URL} + DOWNLOAD_EXTRACT_TIMESTAMP NEW + URL_HASH MD5=${KLEIDIAI_ARCHIVE_MD5}) + + FetchContent_MakeAvailable(KleidiAI_Download) + FetchContent_GetProperties(KleidiAI_Download + SOURCE_DIR KLEIDIAI_SRC + POPULATED KLEIDIAI_POPULATED) + + if (NOT KLEIDIAI_POPULATED) + message(FATAL_ERROR "KleidiAI source downloaded failed.") + endif() + + add_compile_definitions(GGML_USE_CPU_KLEIDIAI) + + # Remove kleidiai target after fetching it + if (TARGET kleidiai) + set_target_properties(kleidiai PROPERTIES EXCLUDE_FROM_ALL TRUE) + endif() + + list(APPEND GGML_CPU_SOURCES + ggml-cpu/kleidiai/kleidiai.cpp + ggml-cpu/kleidiai/kernels.cpp + ggml-cpu/kleidiai/kleidiai.h + ggml-cpu/kleidiai/kernels.h + ) + + # KleidiAI + include_directories( + ${KLEIDIAI_SRC}/ + ${KLEIDIAI_SRC}/kai/ + ${KLEIDIAI_SRC}/kai/ukernels/ + ${KLEIDIAI_SRC}/kai/ukernels/matmul/ + ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/ + ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/) + + set(ARCH_FLAGS_TEMP "${ARCH_FLAGS}") + if (NOT ARCH_FLAGS_TEMP) + string(REGEX MATCH "-march=[^ ]+" ARCH_FLAGS_TEMP "${CMAKE_C_FLAGS}") + endif() + string(FIND "${ARCH_FLAGS_TEMP}" "+dotprod" DOTPROD_ENABLED) + string(FIND "${ARCH_FLAGS_TEMP}" "+i8mm" I8MM_ENABLED) + string(FIND "${ARCH_FLAGS_TEMP}" "+sme" SME_ENABLED) + + set(PRIVATE_ARCH_FLAGS ${ARCH_FLAGS}) + + list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32.c) + list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.c) + list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32_neon.c) + list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.c) + + if (NOT DOTPROD_ENABLED MATCHES -1) + list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c) + list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.c) + list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.c) + endif() + + if (NOT I8MM_ENABLED MATCHES -1) + list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm.c) + endif() + + if (NOT SME_ENABLED MATCHES -1) + list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c) + list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c) + set(PRIVATE_ARCH_FLAGS "${PRIVATE_ARCH_FLAGS}+sve+sve2") + endif() + + set_source_files_properties(${GGML_KLEIDIAI_SOURCES} PROPERTIES COMPILE_OPTIONS "${PRIVATE_ARCH_FLAGS}") + list(APPEND GGML_CPU_SOURCES ${GGML_KLEIDIAI_SOURCES}) + endif() + message(STATUS "Adding CPU backend variant ${GGML_CPU_NAME}: ${ARCH_FLAGS} ${ARCH_DEFINITIONS}") target_sources(${GGML_CPU_NAME} PRIVATE ${GGML_CPU_SOURCES}) target_compile_options(${GGML_CPU_NAME} PRIVATE ${ARCH_FLAGS}) diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index dbef5df2111..f27b981715a 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -112,7 +112,8 @@ struct ggml_arm_arch_features_type { int has_i8mm; int has_sve; int sve_cnt; -} ggml_arm_arch_features = {-1, -1, -1, -1, 0}; + int has_sme; +} ggml_arm_arch_features = {-1, -1, -1, -1, 0, -1}; #endif @@ -2381,15 +2382,20 @@ bool ggml_is_numa(void) { #define HWCAP2_I8MM (1 << 13) #endif +#if !defined(HWCAP2_SME) +#define HWCAP2_SME (1 << 23) +#endif + static void ggml_init_arm_arch_features(void) { #if defined(__linux__) && defined(__aarch64__) uint32_t hwcap = getauxval(AT_HWCAP); uint32_t hwcap2 = getauxval(AT_HWCAP2); - ggml_arm_arch_features.has_neon = !!(hwcap & HWCAP_ASIMD); + ggml_arm_arch_features.has_neon = !!(hwcap & HWCAP_ASIMD); ggml_arm_arch_features.has_dotprod = !!(hwcap & HWCAP_ASIMDDP); - ggml_arm_arch_features.has_i8mm = !!(hwcap2 & HWCAP2_I8MM); - ggml_arm_arch_features.has_sve = !!(hwcap & HWCAP_SVE); + ggml_arm_arch_features.has_i8mm = !!(hwcap2 & HWCAP2_I8MM); + ggml_arm_arch_features.has_sve = !!(hwcap & HWCAP_SVE); + ggml_arm_arch_features.has_sme = !!(hwcap2 & HWCAP2_SME); #if defined(__ARM_FEATURE_SVE) ggml_arm_arch_features.sve_cnt = PR_SVE_VL_LEN_MASK & prctl(PR_SVE_GET_VL); @@ -2412,6 +2418,11 @@ static void ggml_init_arm_arch_features(void) { } ggml_arm_arch_features.has_i8mm = oldp; + if (sysctlbyname("hw.optional.arm.FEAT_SME", &oldp, &size, NULL, 0) != 0) { + oldp = 0; + } + ggml_arm_arch_features.has_sme = oldp; + ggml_arm_arch_features.has_sve = 0; ggml_arm_arch_features.sve_cnt = 0; #else @@ -2435,6 +2446,12 @@ static void ggml_init_arm_arch_features(void) { ggml_arm_arch_features.has_sve = 0; ggml_arm_arch_features.sve_cnt = 0; #endif + +#if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_SME2) + ggml_arm_arch_features.has_sme = 1; +#else + ggml_arm_arch_features.has_sme = 0; +#endif #endif } #endif @@ -14442,6 +14459,14 @@ int ggml_cpu_get_sve_cnt(void) { #endif } +int ggml_cpu_has_sme(void) { +#if defined(__ARM_ARCH) && defined(__ARM_FEATURE_SME) + return ggml_arm_arch_features.has_sme; +#else + return 0; +#endif +} + void ggml_cpu_init(void) { // needed to initialize f16 tables { diff --git a/ggml/src/ggml-cpu/ggml-cpu.cpp b/ggml/src/ggml-cpu/ggml-cpu.cpp index be4eadcd021..d0ae10ee376 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.cpp +++ b/ggml/src/ggml-cpu/ggml-cpu.cpp @@ -14,6 +14,10 @@ #include "ggml-cpu-hbm.h" #endif +#ifdef GGML_USE_CPU_KLEIDIAI +#include "kleidiai/kleidiai.h" +#endif + #if defined(__APPLE__) #include #include @@ -39,6 +43,12 @@ std::vector& ggml_backend_cpu_get_extra_buffers_type } #endif +#ifdef GGML_USE_CPU_KLEIDIAI + if (ggml_backend_cpu_kleidiai_buffer_type()) { + bufts.push_back(ggml_backend_cpu_kleidiai_buffer_type()); + } +#endif + #ifdef GGML_USE_CPU_AARCH64 if (ggml_backend_cpu_aarch64_buffer_type()) { bufts.push_back(ggml_backend_cpu_aarch64_buffer_type()); @@ -538,6 +548,9 @@ static ggml_backend_feature * ggml_backend_cpu_get_features(ggml_backend_reg_t r static std::string sve_cnt = std::to_string(ggml_cpu_get_sve_cnt()); features.push_back({ "SVE_CNT", sve_cnt.c_str() }); } + if (ggml_cpu_has_sme()) { + features.push_back({ "SME", "1" }); + } if (ggml_cpu_has_riscv_v()) { features.push_back({ "RISCV_V", "1" }); } @@ -559,6 +572,9 @@ static ggml_backend_feature * ggml_backend_cpu_get_features(ggml_backend_reg_t r #ifdef GGML_USE_OPENMP features.push_back({ "OPENMP", "1" }); #endif + #ifdef GGML_USE_CPU_KLEIDIAI + features.push_back({ "KLEIDIAI", "1" }); + #endif #ifdef GGML_USE_CPU_AARCH64 features.push_back({ "AARCH64_REPACK", "1" }); #endif diff --git a/ggml/src/ggml-cpu/kleidiai/kernels.cpp b/ggml/src/ggml-cpu/kleidiai/kernels.cpp new file mode 100644 index 00000000000..a8a59a887cb --- /dev/null +++ b/ggml/src/ggml-cpu/kleidiai/kernels.cpp @@ -0,0 +1,259 @@ +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// SPDX-License-Identifier: MIT +// + +// KleidiAI micro-kernels +#include "kai_matmul_clamp_f32_qsi8d32p_qsi4c32p_interface.h" +#include "kai_lhs_quant_pack_qsi8d32p_f32.h" +#include "kai_lhs_quant_pack_qsi8d32p_f32_neon.h" +#include "kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.h" +#include "kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.h" +#include "kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h" +#include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.h" +#include "kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.h" +#include "kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm.h" +#include "kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.h" +#include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.h" +#include "kai_common.h" + +#include "kernels.h" + +#define NELEMS(x) sizeof(x) / sizeof(*x) +static ggml_kleidiai_kernels gemm_gemv_kernels[] = { +#if defined(__ARM_FEATURE_SME) + { + /* SME GEMM */ + /* .kern_info = */ { + /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, + /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, + /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, + /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, + /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, + /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, + /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, + /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, + /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, + /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, + /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, + }, + /* SME GEMV */ + /* .kern_info = */ { + /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, + /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, + /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, + /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, + /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, + /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, + /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, + /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, + /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, + /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, + /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, + }, + /* .lhs_info = */ { + /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32, + /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32, + /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32_neon, + /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32_neon, + /* .require_aligned_m_idx = */ true, + }, + /* .rhs_info = */ { + /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon, + /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon, + }, + /* .required_cpu = */ CPU_FEATURE_SME, + }, +#endif +#if defined(__APPLE__) +#if defined(__ARM_FEATURE_DOTPROD) + { + /* DOTPROD GEMM */ + /* .kern_info = */ { + /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, + /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, + /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, + /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, + /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, + /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, + /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, + /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, + /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, + /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, + /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, + }, + /* DOTPROD GEMV */ + /* .kern_info = */ { + /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, + /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, + /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, + /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, + /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, + /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, + /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, + /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, + /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, + /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, + /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, + }, + /* .lhs_info = */ { + /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32, + /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32, + /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32, + /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32, + /* .require_aligned_m_idx = */ false, + }, + /* .rhs_info = */ { + /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, + /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, + }, + /* .required_cpu = */ CPU_FEATURE_DOTPROD, + }, +#endif +#if defined(__ARM_FEATURE_MATMUL_INT8) + { + /* i8mm GEMM */ + /* .kern_info = */ { + /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, + /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, + /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, + /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, + /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, + /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, + /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, + /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, + /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, + /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, + /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, + }, + /* i8mm GEMV */ + /* .kern_info = */ { + /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + }, + /* .lhs_info = */ { + /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32, + /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32, + /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32, + /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32, + /* .require_aligned_m_idx = */ false, + }, + /* .rhs_info = */ { + /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, + /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, + }, + /* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM, + }, +#endif +#else +#if defined(__ARM_FEATURE_MATMUL_INT8) + { + /* i8mm GEMM */ + /* .kern_info = */ { + /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, + /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, + /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, + /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, + /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, + /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, + /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, + /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, + /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, + /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, + /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, + }, + /* i8mm GEMV */ + /* .kern_info = */ { + /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + }, + /* .lhs_info = */ { + /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32, + /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32, + /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32, + /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32, + /* .require_aligned_m_idx = */ false, + }, + /* .rhs_info = */ { + /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, + /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, + }, + /* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM, + }, +#endif +#if defined(__ARM_FEATURE_DOTPROD) + { + /* DOTPROD GEMM */ + /* .kern_info = */ { + /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, + /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, + /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, + /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, + /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, + /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, + /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, + /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, + /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, + /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, + /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, + }, + /* DOTPROD GEMV */ + /* .kern_info = */ { + /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, + /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, + /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, + /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, + /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, + /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, + /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, + /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, + /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, + /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, + /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, + }, + /* .lhs_info = */ { + /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32, + /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32, + /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32, + /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32, + /* .require_aligned_m_idx = */ false, + }, + /* .rhs_info = */ { + /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, + /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, + }, + /* .required_cpu = */ CPU_FEATURE_DOTPROD, + }, +#endif +#endif +}; + +ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature features) { + ggml_kleidiai_kernels * kernels = nullptr; + + for (size_t i = 0; i < NELEMS(gemm_gemv_kernels); ++i) { + if ((features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu) { + kernels = &gemm_gemv_kernels[i]; + break; + } + } + + return kernels; +} diff --git a/ggml/src/ggml-cpu/kleidiai/kernels.h b/ggml/src/ggml-cpu/kleidiai/kernels.h new file mode 100644 index 00000000000..a0b0d149344 --- /dev/null +++ b/ggml/src/ggml-cpu/kleidiai/kernels.h @@ -0,0 +1,61 @@ +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// SPDX-License-Identifier: MIT +// + +#pragma once + +enum cpu_feature { + CPU_FEATURE_NONE = 0, + CPU_FEATURE_DOTPROD = 1, + CPU_FEATURE_I8MM = 2, + CPU_FEATURE_SVE = 4, + CPU_FEATURE_SME = 8 +}; +inline cpu_feature& operator|=(cpu_feature& lhs, cpu_feature rhs) { + lhs = static_cast(lhs | rhs); + return lhs; +} +inline cpu_feature operator|(cpu_feature lhs, cpu_feature rhs) { + return static_cast(static_cast(lhs) | static_cast(rhs)); +} + +struct kernel_info { + size_t (*get_m_step)(void); + size_t (*get_n_step)(void); + size_t (*get_mr)(void); + size_t (*get_nr)(void); + size_t (*get_kr)(void); + size_t (*get_sr)(void); + size_t (*get_lhs_offset)(size_t m_idx, size_t k, size_t bl); + size_t (*get_rhs_packed_offset)(size_t n_idx, size_t k, size_t bl); + size_t (*get_dst_offset)(size_t m_idx, size_t n_idx, size_t stride); + size_t (*get_dst_size)(size_t m, size_t n); + void (*run_kernel)(size_t m, size_t n, size_t k, size_t bl, const void* lhs_packed, const void* rhs_packed, + float* dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max); +}; + +struct lhs_packing_info { + size_t (*get_offset)(size_t m_idx, size_t lhs_stride); + size_t (*get_packed_offset)(size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr); + size_t (*packed_size)(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr); + void (*pack_func)(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const float* lhs, + size_t lhs_stride, void* lhs_packed); + bool require_aligned_m_idx; +}; + +struct rhs_packing_info { + size_t (*packed_size)(size_t n, size_t k, size_t nr, size_t kr, size_t bl); + void (*pack_func)(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl, const uint8_t* rhs, + const float* bias, void* rhs_packed, size_t extra_bytes, const struct kai_rhs_pack_qs4cxs1s0_param* params); +}; + +struct ggml_kleidiai_kernels { + kernel_info gemm; + kernel_info gemv; + lhs_packing_info lhs_info; + rhs_packing_info rhs_info; + + cpu_feature required_cpu; +}; + +ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features); diff --git a/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp b/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp new file mode 100644 index 00000000000..66685fd1661 --- /dev/null +++ b/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp @@ -0,0 +1,287 @@ +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// SPDX-License-Identifier: MIT +// +#include +#include +#include +#include +#include +#if defined(__linux__) +#include +#include +#elif defined(__APPLE__) +#include +#include +#include +#elif defined(_WIN32) +#include +#include +#endif + +#include "kleidiai.h" + +#include "ggml-cpu.h" +#include "ggml-impl.h" +#include "ggml-backend-impl.h" +#include "ggml-threading.h" +#include "ggml-cpu-traits.h" + +#include "kernels.h" + +#include "kai_common.h" + +#define GGML_COMMON_DECL_CPP +#include "ggml-common.h" + +struct ggml_kleidiai_context { + ggml_kleidiai_kernels * kernels; +} static ctx = { NULL }; + +static void init_kleidiai_context(void) { + + ggml_critical_section_start(); + static bool initialized = false; + + if (!initialized) { + initialized = true; + const char *env_var = getenv("GGML_KLEIDIAI_SME"); + int sme_enabled = 0; + + cpu_feature features = (ggml_cpu_has_dotprod() ? CPU_FEATURE_DOTPROD : CPU_FEATURE_NONE) | + (ggml_cpu_has_matmul_int8() ? CPU_FEATURE_I8MM : CPU_FEATURE_NONE) | + (ggml_cpu_has_sve() ? CPU_FEATURE_SVE : CPU_FEATURE_NONE); + + if (env_var) { + sme_enabled = atoi(env_var); + } + + if (sme_enabled != 0) { + features |= ggml_cpu_has_sme() ? CPU_FEATURE_SME : CPU_FEATURE_NONE; + } + ctx.kernels = ggml_kleidiai_select_kernels(features); + } + ggml_critical_section_end(); +} + +static inline int64_t ggml_ne(const ggml_tensor * tensor, int dim) { + GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS); + return tensor->ne[dim]; +} + +namespace ggml::cpu::kleidiai { +class tensor_traits : public ggml::cpu::tensor_traits { + bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override { + GGML_ASSERT(ctx.kernels); + kernel_info * kernel = op->src[1]->ne[1] == 1 ? &ctx.kernels->gemv : &ctx.kernels->gemm; + + size_t k = op->src[0]->ne[0]; + size_t m = op->src[1]->ne[1]; + + size_t mr = kernel->get_mr(); + size_t kr = kernel->get_kr(); + size_t sr = kernel->get_sr(); + + size = ctx.kernels->lhs_info.packed_size(m, k, QK4_0, mr, kr, sr); + + return true; + } + + bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * dst) override { + if (dst->op == GGML_OP_MUL_MAT) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + + GGML_TENSOR_BINARY_OP_LOCALS + + GGML_ASSERT(ctx.kernels); + kernel_info * kernel = src1->ne[1] == 1 ? &ctx.kernels->gemv : &ctx.kernels->gemm; + lhs_packing_info * lhs_info = &ctx.kernels->lhs_info; + + GGML_ASSERT(kernel); + + const int ith = params->ith; + const int nth = params->nth; + + const size_t k = ne00; + const size_t m = ne11; + const size_t n = ne01; + + const size_t n_step = kernel->get_n_step(); + const size_t num_n_per_thread = kai_roundup(kai_roundup(n, nth) / nth, n_step); + const size_t n_start = ith * num_n_per_thread; + + size_t n_to_process = num_n_per_thread; + if ((n_start + n_to_process) > n) { + n_to_process = n - n_start; + } + + const uint8_t * lhs = static_cast(src1->data); + uint8_t * lhs_packed = (uint8_t*)params->wdata; + const uint8_t * rhs_packed = static_cast(src0->data); + + size_t mr = kernel->get_mr(); + size_t kr = kernel->get_kr(); + size_t sr = kernel->get_sr(); + + // Calculate number of columns to be processed per thread + const bool use_multithread = lhs_info->require_aligned_m_idx && m <= mr ? false : true; + const size_t num_m_per_thread = use_multithread ? kai_roundup(m, nth) / nth : m; + const size_t m_start = ith * num_m_per_thread; + size_t m_to_process = num_m_per_thread; + if ((m_start + m_to_process) > m) { + m_to_process = m - m_start; + } + + if(m_start < m) { + // Transform LHS + const size_t src_stride = src1->nb[1]; + const float * src_ptr = reinterpret_cast(lhs + lhs_info->get_offset(0, dst->src[1]->nb[1])); + const size_t lhs_packed_offset = lhs_info->get_packed_offset(m_start, k, QK4_0, mr, kr, sr); + void * lhs_packed_ptr = static_cast(lhs_packed + lhs_packed_offset); + + lhs_info->pack_func(m_to_process, k, QK4_0, mr, kr, sr, m_start, src_ptr, src_stride, lhs_packed_ptr); + } + + ggml_barrier(params->threadpool); + + // Perform the operation + const size_t dst_stride = dst->nb[1]; + const size_t lhs_packed_offset = lhs_info->get_packed_offset(0, k, QK4_0, mr, kr, sr); + const size_t rhs_packed_offset = kernel->get_rhs_packed_offset(n_start, k, QK4_0); + const size_t dst_offset = kernel->get_dst_offset(0, n_start, dst_stride); + const void * rhs_ptr = static_cast(rhs_packed + rhs_packed_offset); + const void* lhs_ptr = (const void*)((const char *)lhs_packed + lhs_packed_offset); + float *dst_ptr = reinterpret_cast(static_cast(dst->data) + dst_offset); + + kernel->run_kernel(m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr, + dst_stride, sizeof(float), -FLT_MAX, FLT_MAX); + return true; + } + return false; + } + +public: + int repack(struct ggml_tensor * tensor, const void * data, size_t data_size) { + GGML_ASSERT(ctx.kernels); + const size_t n = tensor->ne[1]; + const size_t k = tensor->ne[0]; + size_t nr = ctx.kernels->gemm.get_nr(); + size_t kr = ctx.kernels->gemm.get_kr(); + size_t sr = ctx.kernels->gemm.get_sr(); + +#ifndef NDEBUG + const size_t repacked_size = ctx.kernels->rhs_info.packed_size(n, k, nr, kr, QK4_0); + GGML_ASSERT(repacked_size <= data_size && "repacked size larger than the packed size!"); +#endif + struct kai_rhs_pack_qs4cxs1s0_param params; + params.lhs_zero_point = 1; + params.rhs_zero_point = 8; + ctx.kernels->rhs_info.pack_func(1, n, k, nr, kr, sr, QK4_0, (const uint8_t *)data, NULL, tensor->data, 0, ¶ms); + + return 0; + + GGML_UNUSED(data_size); + } +}; + +static ggml::cpu::tensor_traits * get_tensor_traits(ggml_backend_buffer_t, struct ggml_tensor *) { + static tensor_traits traits; + return &traits; +} +} // namespace ggml::cpu::kleidiai + +static void ggml_backend_cpu_kleidiai_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) { + tensor->extra = (void *) ggml::cpu::kleidiai::get_tensor_traits(buffer, tensor); + + GGML_UNUSED(buffer); +} + +static void ggml_backend_cpu_kleidiai_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, + const void * data, size_t offset, size_t size) { + GGML_ASSERT(offset == 0); + GGML_ASSERT(size == ggml_nbytes(tensor)); + + auto tensor_traits = (ggml::cpu::kleidiai::tensor_traits *) tensor->extra; + auto OK = tensor_traits->repack(tensor, data, size); + + GGML_ASSERT(OK == 0); + GGML_UNUSED(buffer); +} + +static const char * ggml_backend_cpu_kleidiai_buffer_type_get_name(ggml_backend_buffer_type_t buft) { + return "CPU_KLEIDIAI"; + + GGML_UNUSED(buft); +} + +static ggml_backend_buffer_t ggml_backend_cpu_kleidiai_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { + ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size); + + if (buffer == nullptr) { + return nullptr; + } + + buffer->buft = buft; + buffer->iface.init_tensor = ggml_backend_cpu_kleidiai_buffer_init_tensor; + buffer->iface.set_tensor = ggml_backend_cpu_kleidiai_buffer_set_tensor; + buffer->iface.get_tensor = nullptr; + buffer->iface.cpy_tensor = nullptr; + return buffer; +} + +static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { + return TENSOR_ALIGNMENT; + + GGML_UNUSED(buft); +} + +namespace ggml::cpu::kleidiai { +class extra_buffer_type : ggml::cpu::extra_buffer_type { + bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override { + if ( op->op == GGML_OP_MUL_MAT && + op->src[0]->type == GGML_TYPE_Q4_0 && + op->src[0]->buffer && + (ggml_n_dims(op->src[0]) == 2) && + op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() && ctx.kernels + ) { + if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) { + return false; + } + if (op->src[1]->type == GGML_TYPE_F32 && + ggml_ne(op->src[1], 2) == 1 && ggml_ne(op->src[1], 3) == 1) { + return true; + } + } + return false; + } + + ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override { + if (op->op == GGML_OP_MUL_MAT) { + if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) { + return (ggml::cpu::tensor_traits *) op->src[0]->extra; + } + } + return nullptr; + } +}; +} // namespace ggml::cpu::kleidiai + +ggml_backend_buffer_type_t ggml_backend_cpu_kleidiai_buffer_type(void) { + static ggml::cpu::kleidiai::extra_buffer_type ctx; + static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type_kleidiai = { + /* .iface = */ { + /* .get_name = */ ggml_backend_cpu_kleidiai_buffer_type_get_name, + /* .alloc_buffer = */ ggml_backend_cpu_kleidiai_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_cpu_kleidiai_buffer_type_get_alignment, + /* .get_max_size = */ nullptr, // defaults to SIZE_MAX + /* .get_alloc_size = */ nullptr, // defaults to ggml_nbytes + /* .is_host = */ nullptr, + }, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0), + /* .context = */ &ctx, + }; + + init_kleidiai_context(); + + return &ggml_backend_cpu_buffer_type_kleidiai; +} diff --git a/ggml/src/ggml-cpu/kleidiai/kleidiai.h b/ggml/src/ggml-cpu/kleidiai/kleidiai.h new file mode 100644 index 00000000000..38eac58f7c2 --- /dev/null +++ b/ggml/src/ggml-cpu/kleidiai/kleidiai.h @@ -0,0 +1,17 @@ +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include "ggml-alloc.h" + +#ifdef __cplusplus +extern "C" { +#endif + +ggml_backend_buffer_type_t ggml_backend_cpu_kleidiai_buffer_type(void); + +#ifdef __cplusplus +} +#endif From 1d80839d4f4c5c1db1797135767fecb9a28aa815 Mon Sep 17 00:00:00 2001 From: Bodhi <3882561+BodhiHu@users.noreply.github.com> Date: Fri, 21 Feb 2025 15:46:23 +0800 Subject: [PATCH 47/58] MUSA: support ARM64 and enable dp4a .etc (llama/11843) * MUSA: support ARM64 and enable __dp4a .etc * fix cross entropy loss op for musa * update * add cc info log for musa * add comment for the MUSA .cc calculation block --------- Co-authored-by: Bodhi Hu --- ggml/src/ggml-cuda/common.cuh | 6 +++--- ggml/src/ggml-cuda/cross-entropy-loss.cu | 10 +++++----- ggml/src/ggml-cuda/ggml-cuda.cu | 10 ++++++---- ggml/src/ggml-impl.h | 2 +- ggml/src/ggml-musa/CMakeLists.txt | 2 +- 5 files changed, 16 insertions(+), 14 deletions(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 4a92d35f9f4..7e99838c092 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -411,13 +411,13 @@ static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, i #else // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) -#if __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A +#if __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A || defined(GGML_USE_MUSA) return __dp4a(a, b, c); -#else // __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A +#else // __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A || defined(GGML_USE_MUSA) const int8_t * a8 = (const int8_t *) &a; const int8_t * b8 = (const int8_t *) &b; return c + a8[0]*b8[0] + a8[1]*b8[1] + a8[2]*b8[2] + a8[3]*b8[3]; -#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A +#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A || defined(GGML_USE_MUSA) #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) } diff --git a/ggml/src/ggml-cuda/cross-entropy-loss.cu b/ggml/src/ggml-cuda/cross-entropy-loss.cu index 27599a2b038..0ce4afbb222 100644 --- a/ggml/src/ggml-cuda/cross-entropy-loss.cu +++ b/ggml/src/ggml-cuda/cross-entropy-loss.cu @@ -123,13 +123,13 @@ void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor * ggml_cuda_pool_alloc dst_tmp(pool, blocks_num.x); if (nbytes_shared <= smpbo) { -#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) +#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; if (!shared_memory_limit_raised[id]) { - CUDA_CHECK(cudaFuncSetAttribute(cross_entropy_loss_back_f32, cudaFuncAttributeMaxDynamicSharedMemorySize, smpbo)); + CUDA_CHECK(cudaFuncSetAttribute(cross_entropy_loss_f32, cudaFuncAttributeMaxDynamicSharedMemorySize, smpbo)); shared_memory_limit_raised[id] = true; } -#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) +#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) cross_entropy_loss_f32<<>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows); } else { cross_entropy_loss_f32<<>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows); @@ -175,13 +175,13 @@ void ggml_cuda_cross_entropy_loss_back(ggml_backend_cuda_context & ctx, ggml_ten const size_t smpbo = ggml_cuda_info().devices[id].smpbo; if (nbytes_shared <= smpbo) { -#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) +#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; if (!shared_memory_limit_raised[id]) { CUDA_CHECK(cudaFuncSetAttribute(cross_entropy_loss_back_f32, cudaFuncAttributeMaxDynamicSharedMemorySize, smpbo)); shared_memory_limit_raised[id] = true; } -#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) +#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) cross_entropy_loss_back_f32<<>>(grad_d, src0f_d, src1f_d, dst_d, ne00); } else { cross_entropy_loss_back_f32<<>>(grad_d, src0f_d, src1f_d, dst_d, ne00); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 093ad70991b..cc772801e03 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -261,6 +261,12 @@ static ggml_cuda_device_info ggml_cuda_init() { GGML_LOG_INFO(" Device %d: %s, %s (0x%x), VMM: %s, Wave Size: %d\n", id, prop.name, prop.gcnArchName, info.devices[id].cc & 0xffff, device_vmm ? "yes" : "no", prop.warpSize); +#elif defined(GGML_USE_MUSA) + // TODO: refine the .cc to reflect MUSA's actual CC capabilities + info.devices[id].smpbo = prop.sharedMemPerBlockOptin; + info.devices[id].cc = 100*prop.major + 10*prop.minor; + GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s\n", + id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no"); #else info.devices[id].smpbo = prop.sharedMemPerBlockOptin; info.devices[id].cc = 100*prop.major + 10*prop.minor; @@ -1782,9 +1788,6 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co } } #else -#ifdef GGML_USE_MUSA - GGML_ASSERT(false); -#else // !GGML_USE_MUSA if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) { // there is no broadcast and src0, src1 are contiguous across dims 2, 3 // use cublasGemmStridedBatchedEx @@ -1827,7 +1830,6 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co cu_compute_type, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); } -#endif // GGML_USE_MUSA #endif if (dst->op_params[0] == GGML_PREC_DEFAULT) { diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h index eab017889c9..1fbcbd0456e 100644 --- a/ggml/src/ggml-impl.h +++ b/ggml/src/ggml-impl.h @@ -16,7 +16,7 @@ #include #endif // __ARM_FEATURE_SVE -#if defined(__ARM_NEON) && !defined(__CUDACC__) +#if defined(__ARM_NEON) && !defined(__CUDACC__) && !defined(__MUSACC__) // if YCM cannot find , make a symbolic link to it, for example: // // $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/ diff --git a/ggml/src/ggml-musa/CMakeLists.txt b/ggml/src/ggml-musa/CMakeLists.txt index 2f555416e62..1bfc07c5d71 100644 --- a/ggml/src/ggml-musa/CMakeLists.txt +++ b/ggml/src/ggml-musa/CMakeLists.txt @@ -49,7 +49,7 @@ if (MUSAToolkit_FOUND) set_source_files_properties(${GGML_SOURCES_MUSA} PROPERTIES LANGUAGE CXX) foreach(SOURCE ${GGML_SOURCES_MUSA}) - set(COMPILE_FLAGS "-x musa -mtgpu") + set(COMPILE_FLAGS "-fsigned-char -x musa -mtgpu") foreach(ARCH ${MUSA_ARCHITECTURES}) set(COMPILE_FLAGS "${COMPILE_FLAGS} --cuda-gpu-arch=mp_${ARCH}") endforeach() From 3852692c27c500ab3aaf8f66dd64c9402d9dac21 Mon Sep 17 00:00:00 2001 From: PureJourney Date: Fri, 21 Feb 2025 19:21:05 +0800 Subject: [PATCH 48/58] CUDA: correct the lowest Maxwell supported by CUDA 12 (llama/11984) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * CUDA: correct the lowest Maxwell supported by CUDA 12 --------- Co-authored-by: Johannes Gäßler --- ggml/src/ggml-cuda/CMakeLists.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cuda/CMakeLists.txt b/ggml/src/ggml-cuda/CMakeLists.txt index 682640b5208..e63ede2fbe3 100644 --- a/ggml/src/ggml-cuda/CMakeLists.txt +++ b/ggml/src/ggml-cuda/CMakeLists.txt @@ -7,7 +7,7 @@ if (CUDAToolkit_FOUND) if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES) # native == GPUs available at build time - # 52 == Maxwell, lowest CUDA 12 standard + # 50 == Maxwell, lowest CUDA 12 standard # 60 == P100, FP16 CUDA intrinsics # 61 == Pascal, __dp4a instruction (per-byte integer dot product) # 70 == V100, FP16 tensor cores @@ -17,7 +17,7 @@ if (CUDAToolkit_FOUND) elseif(GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16) set(CMAKE_CUDA_ARCHITECTURES "60;61;70;75;80") else() - set(CMAKE_CUDA_ARCHITECTURES "52;61;70;75;80") + set(CMAKE_CUDA_ARCHITECTURES "50;61;70;75;80") endif() endif() message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}") From 4fdcb182a61e062fb9c9b3339833757a95aa542e Mon Sep 17 00:00:00 2001 From: Gian-Carlo Pascutto Date: Sat, 22 Feb 2025 09:43:24 +0100 Subject: [PATCH 49/58] cuda: Add Q5_1, Q5_0, Q4_1 and Q4_0 to F32 conversion support. (llama/12000) --- ggml/src/ggml-cuda/cpy.cu | 99 ++++++++++++++++++++++++++++++--- ggml/src/ggml-cuda/ggml-cuda.cu | 12 ++++ 2 files changed, 104 insertions(+), 7 deletions(-) diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu index 54c0f66d2df..cca2bee0b27 100644 --- a/ggml/src/ggml-cuda/cpy.cu +++ b/ggml/src/ggml-cuda/cpy.cu @@ -1,4 +1,5 @@ #include "cpy.cuh" +#include "dequantize.cuh" typedef void (*cpy_kernel_t)(const char * cx, char * cdst); @@ -82,13 +83,14 @@ static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) { } static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) { - const block_q8_0 * xi = (const block_q8_0 *) cxi; - float * dsti = (float *) cdsti; - - const float d = (float)xi->d; - - for (int j = 0; j < QK8_0; j++) { - dsti[j] = xi->qs[j] * d; + float * cdstf = (float *)(cdsti); + +#pragma unroll + for (int j = 0; j < QK8_0; j += 2) { + dfloat2 dq; + dequantize_q8_0(cxi, 0, j, dq); + *(cdstf + j) = dq.x; + *(cdstf + j + 1) = dq.y; } } @@ -225,6 +227,18 @@ static __device__ void cpy_blck_f32_q5_1(const char * cxi, char * cdsti) { memcpy(dsti->qh, &qh, sizeof(qh)); } +template +static __device__ void cpy_blck_q_f32(const char * cxi, char * cdsti) { + float * cdstf = (float *)(cdsti); + +#pragma unroll + for (int j = 0; j < qk/2; j++) { + dfloat2 dq; + dequant(cxi, 0, j, dq); + *(cdstf + j) = dq.x; + *(cdstf + j + qk/2) = dq.y; + } +} static __device__ __forceinline__ int best_index_int8(int n, const int8_t * val, float x) { if (x <= val[0]) return 0; @@ -387,6 +401,19 @@ static void ggml_cpy_f32_q4_0_cuda( (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } +static void ggml_cpy_q4_0_f32_cuda( + const char * cx, char * cdst, const int ne, + const int ne00, const int ne01, const int ne02, + const int nb00, const int nb01, const int nb02, + const int nb03, const int ne10, const int ne11, const int ne12, + const int nb10, const int nb11, const int nb12, const int nb13, + cudaStream_t stream) { + const int num_blocks = ne; + cpy_q_f32, QK4_0><<>>( + cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, + ne10, ne11, ne12, nb10, nb11, nb12, nb13); +} + static void ggml_cpy_f32_q4_1_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, @@ -398,6 +425,19 @@ static void ggml_cpy_f32_q4_1_cuda( (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } +static void ggml_cpy_q4_1_f32_cuda( + const char * cx, char * cdst, const int ne, + const int ne00, const int ne01, const int ne02, + const int nb00, const int nb01, const int nb02, + const int nb03, const int ne10, const int ne11, const int ne12, + const int nb10, const int nb11, const int nb12, const int nb13, + cudaStream_t stream) { + const int num_blocks = ne; + cpy_q_f32, QK4_1><<>>( + cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, + ne10, ne11, ne12, nb10, nb11, nb12, nb13); +} + static void ggml_cpy_f32_q5_0_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, @@ -409,6 +449,19 @@ static void ggml_cpy_f32_q5_0_cuda( (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } +static void ggml_cpy_q5_0_f32_cuda( + const char * cx, char * cdst, const int ne, + const int ne00, const int ne01, const int ne02, + const int nb00, const int nb01, const int nb02, + const int nb03, const int ne10, const int ne11, const int ne12, + const int nb10, const int nb11, const int nb12, const int nb13, + cudaStream_t stream) { + const int num_blocks = ne; + cpy_q_f32, QK5_0><<>>( + cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, + ne10, ne11, ne12, nb10, nb11, nb12, nb13); +} + static void ggml_cpy_f32_q5_1_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, @@ -420,6 +473,19 @@ static void ggml_cpy_f32_q5_1_cuda( (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } +static void ggml_cpy_q5_1_f32_cuda( + const char * cx, char * cdst, const int ne, + const int ne00, const int ne01, const int ne02, + const int nb00, const int nb01, const int nb02, + const int nb03, const int ne10, const int ne11, const int ne12, + const int nb10, const int nb11, const int nb12, const int nb13, + cudaStream_t stream) { + const int num_blocks = ne; + cpy_q_f32, QK5_1><<>>( + cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, + ne10, ne11, ne12, nb10, nb11, nb12, nb13); +} + static void ggml_cpy_f32_iq4_nl_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, @@ -488,14 +554,25 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) { ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + } else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) { + ggml_cpy_q4_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, + nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) { ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + } else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) { + ggml_cpy_q4_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, + nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) { ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + } else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) { + ggml_cpy_q5_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, + nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) { ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) { ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) { + ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) { ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) { @@ -524,14 +601,22 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) { return (void*) cpy_q_f32; } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) { return (void*) cpy_f32_q; + } else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) { + return (void*) cpy_q_f32, QK4_0>; } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) { return (void*) cpy_f32_q; + } else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) { + return (void*) cpy_q_f32, QK4_1>; } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) { return (void*) cpy_f32_q; + } else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) { + return (void*) cpy_q_f32, QK5_0>; } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) { return (void*) cpy_f32_q; } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) { return (void*) cpy_f32_q; + } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) { + return (void*) cpy_q_f32, QK5_1>; } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) { return (void*) cpy_f32_f16; } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) { diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index cc772801e03..f685423215b 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3075,15 +3075,27 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_0) { return true; } + if (src0_type == GGML_TYPE_Q4_0 && src1_type == GGML_TYPE_F32) { + return true; + } if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_1) { return true; } + if (src0_type == GGML_TYPE_Q4_1 && src1_type == GGML_TYPE_F32) { + return true; + } if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_0) { return true; } + if (src0_type == GGML_TYPE_Q5_0 && src1_type == GGML_TYPE_F32) { + return true; + } if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_1) { return true; } + if (src0_type == GGML_TYPE_Q5_1 && src1_type == GGML_TYPE_F32) { + return true; + } if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_IQ4_NL) { return true; } From e4f3c48104bba774640c4d22504afa942fd79a1e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Sat, 22 Feb 2025 12:20:17 +0100 Subject: [PATCH 50/58] CUDA: optimize FA for GQA + large batches (llama/12014) --- ggml/src/ggml-cuda/cp-async.cuh | 2 +- ggml/src/ggml-cuda/fattn-common.cuh | 122 ++- ggml/src/ggml-cuda/fattn-mma-f16.cuh | 804 ++++++++++++------ ggml/src/ggml-cuda/fattn-tile-f16.cu | 4 +- ggml/src/ggml-cuda/fattn-tile-f32.cu | 4 +- ggml/src/ggml-cuda/fattn-vec-f16.cuh | 2 +- ggml/src/ggml-cuda/fattn-vec-f32.cuh | 2 +- ggml/src/ggml-cuda/fattn-wmma-f16.cu | 6 +- ggml/src/ggml-cuda/fattn.cu | 73 +- ggml/src/ggml-cuda/mma.cuh | 75 ++ .../fattn-mma-f16-instance-cpb16.cu | 10 - .../fattn-mma-f16-instance-cpb32.cu | 10 - .../fattn-mma-f16-instance-cpb64.cu | 10 - .../fattn-mma-f16-instance-cpb8.cu | 10 - ...attn-mma-f16-instance-ncols1_1-ncols2_8.cu | 10 + ...ttn-mma-f16-instance-ncols1_16-ncols2_1.cu | 10 + ...ttn-mma-f16-instance-ncols1_16-ncols2_2.cu | 10 + ...ttn-mma-f16-instance-ncols1_16-ncols2_4.cu | 10 + ...attn-mma-f16-instance-ncols1_2-ncols2_4.cu | 10 + ...attn-mma-f16-instance-ncols1_2-ncols2_8.cu | 10 + ...ttn-mma-f16-instance-ncols1_32-ncols2_1.cu | 10 + ...ttn-mma-f16-instance-ncols1_32-ncols2_2.cu | 10 + ...attn-mma-f16-instance-ncols1_4-ncols2_2.cu | 10 + ...attn-mma-f16-instance-ncols1_4-ncols2_4.cu | 10 + ...attn-mma-f16-instance-ncols1_4-ncols2_8.cu | 10 + ...ttn-mma-f16-instance-ncols1_64-ncols2_1.cu | 10 + ...attn-mma-f16-instance-ncols1_8-ncols2_1.cu | 10 + ...attn-mma-f16-instance-ncols1_8-ncols2_2.cu | 10 + ...attn-mma-f16-instance-ncols1_8-ncols2_4.cu | 10 + ...attn-mma-f16-instance-ncols1_8-ncols2_8.cu | 10 + .../template-instances/generate_cu_files.py | 20 +- 31 files changed, 919 insertions(+), 395 deletions(-) delete mode 100644 ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb16.cu delete mode 100644 ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb32.cu delete mode 100644 ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb64.cu delete mode 100644 ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb8.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu diff --git a/ggml/src/ggml-cuda/cp-async.cuh b/ggml/src/ggml-cuda/cp-async.cuh index 51aa41e7e60..ecb659997ba 100644 --- a/ggml/src/ggml-cuda/cp-async.cuh +++ b/ggml/src/ggml-cuda/cp-async.cuh @@ -24,7 +24,7 @@ static __device__ __forceinline__ void cp_async_cg_16(const unsigned int dst, co } else #endif // CUDART_VERSION >= 11040 { - asm volatile("cp.async.cg.shared.global.L2 [%0], [%1], 16;" + asm volatile("cp.async.cg.shared.global [%0], [%1], 16;" : : "r"(dst), "l"(src)); } #else diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index fefbd319baf..7b9566fb4be 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -516,27 +516,25 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) { nullptr; } -// The HIP compiler for some reason complains that it can't unroll a loop because of the jt*ncols + j >= ne01 conditional. -#ifdef __clang__ -#pragma clang diagnostic push -#pragma clang diagnostic ignored "-Wpass-failed" -#endif // __clang__ - -template // D == head size -#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) +template // D == head size __launch_bounds__(D, 1) -#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) static __global__ void flash_attn_stream_k_fixup( float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) { - const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols); - - const int iter_k = ne11 / KQ_stride; - const int iter_j = (ne01 + (ncols - 1)) / ncols; + constexpr int ncols = ncols1*ncols2; const int bidx0 = blockIdx.x; + const int j = blockIdx.y; + const int c = blockIdx.z; + const int jc = j*ncols2 + c; + const int tid = threadIdx.x; + + const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols); + + const int iter_k = ne11 / FATTN_KQ_STRIDE; + const int iter_j = (ne01 + (ncols1 - 1)) / ncols1; - const int kbc0 = (bidx0 + 0)*iter_k*iter_j*ne02 / gridDim.x; - const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*ne02 / gridDim.x; + const int kbc0 = (bidx0 + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x; + const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x; const bool did_not_have_any_data = kbc0 == kbc0_stop; const bool wrote_beginning_of_tile = kbc0 % iter_k == 0; @@ -548,22 +546,22 @@ static __global__ void flash_attn_stream_k_fixup( const int channel = kbc0 / (iter_k*iter_j); const int jt = (kbc0 - channel*iter_k*iter_j) / iter_k; - dst += jt*ncols*ne02*D + channel*D; + if (jt*ncols1 + j >= ne01) { + return; + } - // Load the partial result that needs a fixup: - float dst_val[ncols] = {0.0f}; - float max_val[ncols] = {0.0f}; - float rowsum[ncols] = {0.0f}; -#pragma unroll - for (int j = 0; j < ncols; ++j) { - if (jt*ncols + j >= ne01) { - break; - } - dst_val[j] = dst[j*ne02*D + threadIdx.x]; + dst += jt*ne02*(ncols1*D) + channel*(ncols2*D) + (j*ne02 + c)*D + tid; - const float2 tmp = dst_fixup[bidx0*ncols + j]; - max_val[j] = tmp.x; - rowsum[j] = tmp.y; + // Load the partial result that needs a fixup: + float dst_val = 0.0f; + float max_val = 0.0f; + float rowsum = 0.0f; + { + dst_val = *dst; + + const float2 tmp = dst_fixup[bidx0*ncols + jc]; + max_val = tmp.x; + rowsum = tmp.y; } // Iterate over previous blocks and compute the combined results. @@ -571,36 +569,30 @@ static __global__ void flash_attn_stream_k_fixup( int bidx = bidx0 - 1; int kbc_stop = kbc0; while(true) { - const int kbc = bidx*iter_k*iter_j*ne02 / gridDim.x; + const int kbc = bidx*iter_k*iter_j*(ne02/ncols2) / gridDim.x; if (kbc == kbc_stop) { // Did not have any data. bidx--; kbc_stop = kbc; continue; } -#pragma unroll - for (int j = 0; j < ncols; ++j) { - if (jt*ncols + j >= ne01) { - break; - } - const float dst_add = dst_fixup_data[bidx*ncols*D + j*D + threadIdx.x]; + const float dst_add = dst_fixup_data[bidx*ncols*D + jc*D + tid]; - const float2 tmp = dst_fixup[(gridDim.x + bidx)*ncols + j]; + const float2 tmp = dst_fixup[(gridDim.x + bidx)*ncols + jc]; - // Scale the current and new value accumulators depending on the max. values. - const float max_val_new = fmaxf(max_val[j], tmp.x); + // Scale the current and new value accumulators depending on the max. values. + const float max_val_new = fmaxf(max_val, tmp.x); - const float diff_val = max_val[j] - max_val_new; - const float diff_add = tmp.x - max_val_new; + const float diff_val = max_val - max_val_new; + const float diff_add = tmp.x - max_val_new; - const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_val) : 0.0f; - const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_add) : 0.0f; + const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_val) : 0.0f; + const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_add) : 0.0f; - dst_val[j] = scale_val*dst_val[j] + scale_add*dst_add; - rowsum[j] = scale_val*rowsum[j] + scale_add*tmp.y; + dst_val = scale_val*dst_val + scale_add*dst_add; + rowsum = scale_val*rowsum + scale_add*tmp.y; - max_val[j] = max_val_new; - } + max_val = max_val_new; // If this block started in a previous tile we are done and don't need to combine additional partial results. if (kbc % iter_k == 0 || kbc/iter_k < kbc0/iter_k) { @@ -611,19 +603,9 @@ static __global__ void flash_attn_stream_k_fixup( } // Write back final result: -#pragma unroll - for (int j = 0; j < ncols; ++j) { - if (jt*ncols + j >= ne01) { - return; - } - dst[j*ne02*D + threadIdx.x] = dst_val[j] / rowsum[j]; - } + *dst = dst_val / rowsum; } -#ifdef __clang__ -#pragma clang diagnostic pop -#endif // __clang__ - template // D == head size #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) __launch_bounds__(D, 1) @@ -690,11 +672,13 @@ static void on_no_fattn_vec_case(const int D) { } // parallel_blocks == 0 is stream-k decomposition -template +template void launch_fattn( ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared, const bool need_f16_K, const bool need_f16_V ) { + constexpr int ncols = ncols1 * ncols2; + const ggml_tensor * Q = dst->src[0]; const ggml_tensor * K = dst->src[1]; const ggml_tensor * V = dst->src[2]; @@ -763,25 +747,26 @@ void launch_fattn( nb23 = nb23*bs*sizeof(half)/ts; } - const int ntiles_x = ((Q->ne[1] + cols_per_block - 1) / cols_per_block); - const int ntiles_total = ntiles_x*Q->ne[2]*Q->ne[3]; + const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1); + const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3]; const dim3 block_dim(WARP_SIZE, nwarps, 1); dim3 blocks_num; if (parallel_blocks == 0) { // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup. - const int tiles_nwaves = (ntiles_total + 2*nsm - 1) / (2*nsm); - const int tiles_efficiency_percent = 100 * ntiles_total / (2*nsm*tiles_nwaves); + const int max_blocks = 2*nsm; + const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks; + const int tiles_efficiency_percent = 100 * ntiles_total / (max_blocks*tiles_nwaves); - const int nblocks_stream_k = 2*nsm; + const int nblocks_stream_k = max_blocks; - const bool use_stream_k = tiles_efficiency_percent < 75 || cc >= GGML_CUDA_CC_ADA_LOVELACE; + const bool use_stream_k = cc >= GGML_CUDA_CC_ADA_LOVELACE || tiles_efficiency_percent < 75; blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_total; blocks_num.y = 1; blocks_num.z = 1; - dst_tmp_meta.alloc(blocks_num.x*cols_per_block * (2*2 + D) * sizeof(float)); + dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + D) * sizeof(float)); } else { blocks_num.x = parallel_blocks*ntiles_x; blocks_num.y = Q->ne[2]; @@ -793,7 +778,6 @@ void launch_fattn( } } - float scale = 1.0f; float max_bias = 0.0f; float logit_softcap = 0.0f; @@ -832,9 +816,9 @@ void launch_fattn( if constexpr (parallel_blocks == 0) { if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles. const dim3 block_dim_combine(D, 1, 1); - const dim3 blocks_num_combine = blocks_num; + const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2}; - flash_attn_stream_k_fixup + flash_attn_stream_k_fixup <<>> ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]); } diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index d777f5413ed..b2e0db9a2cc 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -5,12 +5,15 @@ using namespace ggml_cuda_mma; -typedef tile<16, 8, half2> tile_A; -typedef tile< 8, 8, half2> tile_B; -typedef tile<16, 8, float> tile_C_KQ; -typedef tile<16, 4, half2> tile_C_VKQ; - -template +typedef tile<16, 8, half2> tile_A; +typedef tile< 8, 8, half2> tile_B; +typedef tile<16, 8, half2> tile_B_16; +typedef tile<16, 8, float> tile_C_KQ; +typedef tile<16, 16, float> tile_C_KQ_16; +typedef tile<16, 4, half2> tile_C_VKQ; +typedef tile<16, 8, half2> tile_C_VKQ_16; + +template static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int stride_KV) { constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts. @@ -27,7 +30,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( constexpr int chunks_per_row = k0_sync_start / h2_per_chunk; constexpr int stride_i = WARP_SIZE / chunks_per_row; #pragma unroll - for (int i0 = 0; i0 < KQ_stride; i0 += nwarps*stride_i) { + for (int i0 = 0; i0 < KQ_per_iter; i0 += nwarps*stride_i) { const int i = i0 + threadIdx.y*stride_i + (chunks_per_row == WARP_SIZE ? 0 : threadIdx.x / chunks_per_row); const int k = (chunks_per_row == WARP_SIZE ? threadIdx.x : threadIdx.x % chunks_per_row)*h2_per_chunk; @@ -40,7 +43,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( // If D is not a power of 2, the rest is loaded synchronously. // K/V data is loaded with decreasing granularity for D for better memory bandwidth. - static_assert(KQ_stride % (4*nwarps) == 0, "out of bounds"); + static_assert(KQ_per_iter % (4*nwarps) == 0, "out of bounds"); #pragma unroll for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) { const int k0_start = stride_k == WARP_SIZE ? k0_sync_start : D/2 - (D/2) % (2*stride_k); @@ -52,7 +55,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( } #pragma unroll - for (int i0 = 0; i0 < KQ_stride; i0 += nwarps*stride_i) { + for (int i0 = 0; i0 < KQ_per_iter; i0 += nwarps*stride_i) { const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); #pragma unroll @@ -65,12 +68,54 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( } } -template +template +static __device__ __forceinline__ void flash_attn_ext_f16_load_mask( + const half2 * const __restrict__ mask_h2, half2 * const __restrict__ tile_mask, const int stride_mask) { + static_assert(KQ_per_iter == 2*WARP_SIZE || KQ_per_iter == WARP_SIZE, "bad KQ_per_iter"); +#ifdef CP_ASYNC_AVAILABLE + constexpr int preload = KQ_per_iter * sizeof(half); + constexpr int cols_per_warp = 8*WARP_SIZE/KQ_per_iter; + constexpr int stride_j = nwarps * cols_per_warp; + + const unsigned int tile_mask_32 = __cvta_generic_to_shared(tile_mask); + +#pragma unroll + for (int j0 = 0; j0 < ncols1; j0 += stride_j) { + const int j = j0 + threadIdx.y*cols_per_warp + + (KQ_per_iter == 2*WARP_SIZE ? threadIdx.x / (WARP_SIZE/4) : threadIdx.x / (WARP_SIZE/8)); + + if (j0 + stride_j > ncols1 && j >= ncols1) { + break; + } + + const int i = 4 * (KQ_per_iter == 2*WARP_SIZE ? threadIdx.x % (WARP_SIZE/4) : threadIdx.x % (WARP_SIZE/8)); + + cp_async_cg_16(tile_mask_32 + j*(KQ_per_iter*sizeof(half) + 16) + i*sizeof(half2), mask_h2 + j*stride_mask + i); + } +#else + constexpr int cols_per_warp = 2*WARP_SIZE/KQ_per_iter; + constexpr int stride_j = nwarps * cols_per_warp; +#pragma unroll + for (int j0 = 0; j0 < ncols1; j0 += stride_j) { + const int j = j0 + threadIdx.y*cols_per_warp + (KQ_per_iter == 2*WARP_SIZE ? 0 : threadIdx.x / (WARP_SIZE/2)); + + if (j0 + stride_j > ncols1 && j >= ncols1) { + break; + } + + const int i = KQ_per_iter == 2*WARP_SIZE ? threadIdx.x : threadIdx.x % (WARP_SIZE/2); + + tile_mask[j*(KQ_per_iter/2 + 4) + i] = mask_h2[j*stride_mask + i]; + } +#endif // CP_ASYNC_AVAILABLE +} + +template static __device__ __forceinline__ void flash_attn_ext_f16_iter( const float2 * const __restrict__ Q_f2, const half2 * const __restrict__ K_h2, const half2 * const __restrict__ V_h2, - const half * const __restrict__ maskh, + const half2 * const __restrict__ mask_h2, float2 * const __restrict__ dstk, float2 * const __restrict__ dstk_fixup, const float scale, @@ -78,42 +123,60 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( const float logit_softcap, const int ne01, const int ne02, - const int stride_Q, const int stride_KV, const int stride_mask, const int jt, half2 * const __restrict__ tile_K, half2 * const __restrict__ tile_V, + half2 * const __restrict__ tile_mask, const tile_B * const __restrict__ Q_B, tile_C_VKQ * const __restrict__ VKQ_C, - float2 & KQ_max, - float2 & KQ_rowsum, + float * const __restrict__ KQ_max, + float * const __restrict__ KQ_rowsum, const int kb0) { #ifdef NEW_MMA_AVAILABLE - constexpr int np = nwarps*tile_B::I / ncols; // Number of parallel CUDA warps per Q column. - constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts. + constexpr int cols_per_warp = ntiles * tile_B::I; + constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles; + constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column. + constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts. + + const int k_VKQ_0 = kb0 * KQ_per_iter; + tile_C_KQ KQ_C[KQ_per_iter/(np*tile_C_KQ::I) * ntiles]; - const int k_VKQ_0 = kb0*KQ_stride; - tile_C_KQ KQ_C[KQ_stride/(np*tile_C_KQ::I)]; + // Use wide variants of tiles if ntiles >= 2. + tile_B_16 * Q_B_16 = (tile_B_16 *) Q_B; + tile_C_VKQ_16 * VKQ_C_16 = (tile_C_VKQ_16 *) VKQ_C; + tile_C_KQ_16 * KQ_C_16 = (tile_C_KQ_16 *) KQ_C; #ifdef CP_ASYNC_AVAILABLE cp_async_wait_all(); __syncthreads(); - flash_attn_ext_f16_load_tile(V_h2 + k_VKQ_0*stride_KV, tile_V, stride_KV); + flash_attn_ext_f16_load_tile(V_h2 + k_VKQ_0*stride_KV, tile_V, stride_KV); #else - flash_attn_ext_f16_load_tile(K_h2 + k_VKQ_0*stride_KV, tile_K, stride_KV); + if (ncols2 > 1 || mask_h2) { + flash_attn_ext_f16_load_mask(mask_h2 + k_VKQ_0/2, tile_mask, stride_mask); + } + flash_attn_ext_f16_load_tile(K_h2 + k_VKQ_0*stride_KV, tile_K, stride_KV); __syncthreads(); #endif // CP_ASYNC_AVAILABLE // Calculate tile of KQ: #pragma unroll - for (int i_KQ_00 = 0; i_KQ_00 < KQ_stride; i_KQ_00 += np*tile_A::I) { + for (int i_KQ_00 = 0; i_KQ_00 < KQ_per_iter; i_KQ_00 += np*tile_A::I) { const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*tile_A::I; #pragma unroll for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += tile_A::J) { tile_A K_A; load_ldmatrix(K_A, tile_K + i_KQ_0*D2_padded + k_KQ_0, D2_padded); - mma(KQ_C[i_KQ_00/(np*tile_A::I)], K_A, ((tile_B *) Q_B)[k_KQ_0/tile_A::J]); + if (ntiles == 1) { + mma(KQ_C[i_KQ_00/(np*tile_A::I)], K_A, Q_B[k_KQ_0/tile_A::J]); + } else { +#pragma unroll + for (int t = 0; t < ntiles/2; ++t) { + // Wide version of KQ_C is column-major => swap A and B. + mma(KQ_C_16[i_KQ_00/(np*tile_A::I) * ntiles/2 + t], Q_B_16[k_KQ_0/tile_A::J * ntiles/2 + t], K_A); + } + } } } @@ -122,9 +185,9 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( #endif // CP_ASYNC_AVAILABLE if (use_logit_softcap) { - static_assert(KQ_stride % (np*tile_C_KQ::I) == 0, "bad loop size"); + static_assert(KQ_per_iter % (np*tile_C_KQ::I) == 0, "bad loop size"); #pragma unroll - for (int i = 0; i < KQ_stride/(np*tile_C_KQ::I); ++i) { + for (int i = 0; i < KQ_per_iter/(np*tile_C_KQ::I) * ntiles; ++i) { #pragma unroll for (int l = 0; l < tile_C_KQ::ne; ++l) { KQ_C[i].x[l] = logit_softcap*tanhf(KQ_C[i].x[l]); @@ -132,109 +195,209 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } } - if (maskh) { - static_assert(KQ_stride % (np *tile_C_KQ::I) == 0, "bad loop size"); - static_assert(ncols % (nwarps/np*tile_C_KQ::J) == 0, "bad loop size"); + float KQ_max_new[cols_per_thread]; +#pragma unroll + for (int col = 0; col < cols_per_thread; ++col) { + KQ_max_new[col] = KQ_max[col]; + } + float KQ_rowsum_add[cols_per_thread] = {0.0f}; + + if (ntiles == 1) { + if (ncols2 > 1 || mask_h2) { +#pragma unroll + for (int i00 = 0; i00 < KQ_per_iter; i00 += np*tile_C_KQ::I) { + const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ::I; +#pragma unroll + for (int l = 0; l < tile_C_KQ::ne; ++l) { + const int i = i0 + tile_C_KQ::get_i(l); + const int j = ((threadIdx.y / np)*tile_C_KQ::J + tile_C_KQ::get_j(l)) / ncols2; + + KQ_C[i00/(np*tile_C_KQ::I)].x[l] += slope * + __half2float(((const half *) tile_mask)[j*(KQ_per_iter + 8) + i]); + } + } + } + + // Calculate softmax for each KQ column using the current max. value. + // The divisor is stored in KQ_rowsum and will be applied at the end. + static_assert(KQ_per_iter % (np*tile_C_KQ::I) == 0, "bad loop size"); +#pragma unroll + for (int k = 0; k < KQ_per_iter/(np*tile_C_KQ::I); ++k) { +#pragma unroll + for (int l = 0; l < tile_C_KQ::ne; ++l) { + KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k].x[l]); + } + } + + // Values per KQ column are spread across 8 threads, does not need full warp reduce: #pragma unroll - for (int i00 = 0; i00 < KQ_stride; i00 += np*tile_C_KQ::I) { - const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ::I; + for (int col = 0; col < cols_per_thread; ++col) { +#pragma unroll + for (int offset = 16; offset >= 4; offset >>= 1) { + KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE)); + } + } + + static_assert(KQ_per_iter % (np*tile_C_KQ::I) == 0, "bad loop size"); + +#pragma unroll + for (int k = 0; k < KQ_per_iter/(np*tile_C_KQ::I); ++k) { #pragma unroll for (int l = 0; l < tile_C_KQ::ne; ++l) { - const int i = i0 + tile_C_KQ::get_i(l); - const int j = (threadIdx.y / np)*tile_C_KQ::J + tile_C_KQ::get_j(l); + KQ_C[k].x[l] = expf(KQ_C[k].x[l] - KQ_max_new[l % 2]); - KQ_C[i00/(np*tile_C_KQ::I)].x[l] += slope*__half2float(maskh[j*stride_mask + k_VKQ_0 + i]); + KQ_rowsum_add[l % 2] += KQ_C[k].x[l]; + } + } + } else { // ntiles > 1 + if (ncols2 > 1 || mask_h2) { +#pragma unroll + for (int i00 = 0; i00 < KQ_per_iter; i00 += np*tile_C_KQ_16::J) { + const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ_16::J; +#pragma unroll + for (int t = 0; t < ntiles/2; ++t) { +#pragma unroll + for (int l0 = 0; l0 < tile_C_KQ_16::ne; l0 += 2) { + const int i = (i0 + tile_C_KQ_16::get_j(l0)) / 2; + const int j = ((threadIdx.y / np)*cols_per_warp + t*tile_C_KQ_16::I + tile_C_KQ_16::get_i(l0)) / ncols2; + + const float2 tmp = __half22float2(tile_mask[j*(KQ_per_iter/2 + 4) + i]); + const int KQ_index = i00/(np*tile_C_KQ_16::J) * ntiles/2 + t; + KQ_C_16[KQ_index].x[l0 + 0] += slope*tmp.x; + KQ_C_16[KQ_index].x[l0 + 1] += slope*tmp.y; + } + } } } - } - // Calculate softmax for each KQ column using the current max. value. - // The divisor is stored in KQ_rowsum and will be applied at the end. - float2 KQ_max_new = KQ_max; - static_assert(KQ_stride % (np*tile_C_KQ::I) == 0, "bad loop size"); + // Calculate softmax for each KQ column using the current max. value. + // The divisor is stored in KQ_rowsum and will be applied at the end. + static_assert(KQ_per_iter % (np*tile_C_KQ::I) == 0, "bad loop size"); +#pragma unroll + for (int k = 0; k < KQ_per_iter/(np*tile_C_KQ_16::J); ++k) { #pragma unroll - for (int k = 0; k < KQ_stride/(np*tile_C_KQ::I); ++k) { + for (int t = 0; t < ntiles/2; ++t) { #pragma unroll - for (int l0 = 0; l0 < tile_C_KQ::ne; l0 += 2) { - KQ_max_new.x = fmaxf(KQ_max_new.x, KQ_C[k].x[l0 + 0]); - KQ_max_new.y = fmaxf(KQ_max_new.y, KQ_C[k].x[l0 + 1]); + for (int l = 0; l < tile_C_KQ_16::ne; ++l) { + const int KQ_index = 2*t + (l/2) % 2; + KQ_max_new[KQ_index] = fmaxf(KQ_max_new[KQ_index], KQ_C_16[k*ntiles/2 + t].x[l]); + } + } } - } - // Values per KQ column are spread across 8 threads, does not need full warp reduce: + // Values per KQ column are spread across 4 threads, does not need full warp reduce: #pragma unroll - for (int offset = 16; offset > 2; offset >>= 1) { - KQ_max_new.x = fmaxf(KQ_max_new.x, __shfl_xor_sync(0xFFFFFFFF, KQ_max_new.x, offset, WARP_SIZE)); - KQ_max_new.y = fmaxf(KQ_max_new.y, __shfl_xor_sync(0xFFFFFFFF, KQ_max_new.y, offset, WARP_SIZE)); - } + for (int col = 0; col < cols_per_thread; ++col) { +#pragma unroll + for (int offset = 2; offset >= 1; offset >>= 1) { + KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE)); + } + } - float2 KQ_rowsum_add = make_float2(0.0f, 0.0f); - static_assert(KQ_stride % (np*tile_C_KQ::I) == 0, "bad loop size"); + static_assert(KQ_per_iter % (np*tile_C_KQ_16::J) == 0, "bad loop size"); +#pragma unroll + for (int k = 0; k < KQ_per_iter/(np*tile_C_KQ_16::J); ++k) { #pragma unroll - for (int k = 0; k < KQ_stride/(np*tile_C_KQ::I); ++k) { + for (int t = 0; t < ntiles/2; ++t) { #pragma unroll - for (int l = 0; l < tile_C_KQ::ne; ++l) { - const float KQ_max_l = l % 2 == 0 ? KQ_max_new.x : KQ_max_new.y; - const float diff = KQ_C[k].x[l] - KQ_max_l; - KQ_C[k].x[l] = expf(diff); + for (int l = 0; l < tile_C_KQ_16::ne; ++l) { + const int KQ_index = 2*t + (l/2) % 2; - if (l % 2 == 0) { - KQ_rowsum_add.x += KQ_C[k].x[l]; - } else { - KQ_rowsum_add.y += KQ_C[k].x[l]; + KQ_C_16[k*ntiles/2 + t].x[l] = expf(KQ_C_16[k*ntiles/2 + t].x[l] - KQ_max_new[KQ_index]); + + KQ_rowsum_add[KQ_index] += KQ_C_16[k*ntiles/2 + t].x[l]; + } } } } { - const float2 diff = make_float2(KQ_max.x - KQ_max_new.x, KQ_max.y - KQ_max_new.y); - const float2 KQ_max_scale = make_float2(expf(diff.x), expf(diff.y)); - KQ_max = KQ_max_new; + float KQ_max_scale[cols_per_thread]; +#pragma unroll + for (int col = 0; col < cols_per_thread; ++col) { + KQ_max_scale[col] = expf(KQ_max[col] - KQ_max_new[col]); + KQ_max[col] = KQ_max_new[col]; - // Scale previous KQ_rowsum to account for a potential increase in KQ_max: - KQ_rowsum.x = KQ_max_scale.x*KQ_rowsum.x + KQ_rowsum_add.x; - KQ_rowsum.y = KQ_max_scale.y*KQ_rowsum.y + KQ_rowsum_add.y; + // Scale previous KQ_rowsum to account for a potential increase in KQ_max: + KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_rowsum_add[col]; + } - const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale.x, KQ_max_scale.y); + if (ntiles == 1) { + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]); #pragma unroll - for (int i = 0; i < D/tile_C_VKQ::I; ++i) { + for (int i = 0; i < D/tile_C_VKQ::I; ++i) { #pragma unroll - for (int l = 0; l < tile_C_VKQ::ne; ++l) { - VKQ_C[i].x[l] *= KQ_max_scale_h2; + for (int l = 0; l < tile_C_VKQ::ne; ++l) { + VKQ_C[i].x[l] *= KQ_max_scale_h2; + } + } + } else { +#pragma unroll + for (int col = 0; col < cols_per_thread; ++col) { + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]); +#pragma unroll + for (int i = 0; i < D/tile_C_VKQ_16::J; ++i) { +#pragma unroll + for (int l0 = 0; l0 < tile_C_VKQ_16::ne; l0 += 2) { + VKQ_C_16[i*ntiles/2 + col/2].x[l0 + col % 2] *= KQ_max_scale_h2; + } + } } } } // Convert KQ C tiles into B tiles for VKQ calculation: - tile_B B[KQ_stride/(np*2*tile_B::J)]; - static_assert(KQ_stride % (np*2*tile_B::J) == 0, "bad loop size"); + tile_B B[KQ_per_iter/(np*2*tile_B::J) * ntiles]; + tile_B_16 * B_16 = (tile_B_16 *) B; + static_assert(KQ_per_iter % (np*2*tile_B::J) == 0, "bad loop size"); + if (ntiles == 1) { #pragma unroll - for (int k = 0; k < KQ_stride/(np*2*tile_B::J); ++k) { - B[k] = get_transposed(get_half2(KQ_C[k])); + for (int k = 0; k < KQ_per_iter/(np*2*tile_B::J); ++k) { + B[k] = get_transposed(get_half2(KQ_C[k])); + } + } else { + for (int k = 0; k < KQ_per_iter/(np*2*tile_B_16::J); ++k) { +#pragma unroll + for (int t = 0; t < ntiles/2; ++t) { + B_16[k*ntiles/2 + t] = get_half2(KQ_C_16[k*ntiles/2 + t]); + } + } } #ifdef CP_ASYNC_AVAILABLE + // Preload K tile for next iteration: cp_async_wait_all(); __syncthreads(); if (!last_iter) { - flash_attn_ext_f16_load_tile(K_h2 + (k_VKQ_0 + KQ_stride)*stride_KV, tile_K, stride_KV); + if (ncols2 > 1 || mask_h2) { + flash_attn_ext_f16_load_mask(mask_h2 + (k_VKQ_0 + KQ_per_iter)/2, tile_mask, stride_mask); + } + flash_attn_ext_f16_load_tile(K_h2 + (k_VKQ_0 + KQ_per_iter)*stride_KV, tile_K, stride_KV); } #else - flash_attn_ext_f16_load_tile(V_h2 + k_VKQ_0*stride_KV, tile_V, stride_KV); + flash_attn_ext_f16_load_tile(V_h2 + k_VKQ_0*stride_KV, tile_V, stride_KV); __syncthreads(); #endif // CP_ASYNC_AVAILABLE // Calculate VKQ tile: #pragma unroll for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += tile_C_VKQ::I) { - static_assert((KQ_stride/2) % (np*tile_A::J) == 0, "bad loop size"); + static_assert((KQ_per_iter/2) % (np*tile_A::J) == 0, "bad loop size"); #pragma unroll - for (int k00 = 0; k00 < KQ_stride/2; k00 += np*tile_A::J) { + for (int k00 = 0; k00 < KQ_per_iter/2; k00 += np*tile_A::J) { const int k0 = k00 + (threadIdx.y % np)*tile_A::J; tile_A A; load_ldmatrix_trans(A, tile_V + 2*k0*D2_padded + i_VKQ_0/2, D2_padded); - mma(VKQ_C[i_VKQ_0/tile_C_VKQ::I], A, B[k00/(np*tile_A::J)]); + if (ntiles == 1) { + mma(VKQ_C[i_VKQ_0/tile_C_VKQ::I], A, B[k00/(np*tile_A::J)]); + } else { +#pragma unroll + for (int t = 0; t < ntiles/2; ++t) { + // Wide version of VKQ_C is column-major => swap A and B. + mma(VKQ_C_16[i_VKQ_0/tile_C_VKQ::I * ntiles/2 + t], B_16[k00/(np*tile_A::J) * ntiles/2 + t], A); + } + } } } @@ -247,12 +410,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( #endif // NEW_MMA_AVAILABLE } -template +template static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const float2 * const __restrict__ Q_f2, const half2 * const __restrict__ K_h2, const half2 * const __restrict__ V_h2, - const half * const __restrict__ maskh, + const half2 * const __restrict__ mask_h2, float2 * const __restrict__ dstk, float2 * const __restrict__ dstk_fixup, const float scale, @@ -260,7 +423,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const float logit_softcap, const int ne01, const int ne02, - const int stride_Q, + const int stride_Q1, + const int stride_Q2, const int stride_KV, const int stride_mask, const int jt, @@ -269,63 +433,78 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( #ifdef NEW_MMA_AVAILABLE //In this kernel Q, K, V are matrices while i, j, k are matrix indices. - static_assert(nwarps*tile_B::I % ncols == 0, "bad nwarps"); - constexpr int np = nwarps*tile_B::I / ncols; // Number of parallel CUDA warps per Q column. + constexpr int ncols = ncols1 * ncols2; + constexpr int cols_per_warp = ntiles * tile_B::I; + constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles; + constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column. + + static_assert(nwarps * (cols_per_warp/ncols2) % ncols1 == 0, "bad nwarps"); - static_assert(D % nwarps == 0, "bad D"); - static_assert(KQ_stride % nwarps == 0, "bad KQ_stride"); + static_assert(D % nwarps == 0, "bad D"); + static_assert(KQ_per_iter % nwarps == 0, "bad KQ_per_iter"); constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts. - // Temporary shared buffer for loading K/V data with KQ_stride*D logical elements: + // Temporary shared buffer for loading K/V data with KQ_per_iter*D logical elements: extern __shared__ half2 tile_K[]; #ifdef CP_ASYNC_AVAILABLE - half2 * tile_V = tile_K + KQ_stride*D2_padded; + half2 * tile_V = tile_K + KQ_per_iter*D2_padded; #else - half2 * tile_V = tile_K; + half2 * tile_V = tile_K; #endif // CP_ASYNC_AVAILABLE + half2 * tile_mask = tile_V + KQ_per_iter*D2_padded; - tile_B Q_B[D/(2*tile_B::J)]; - tile_C_VKQ VKQ_C[D/tile_C_VKQ::I]; + tile_B Q_B[D/(2*tile_B::J) * ntiles]; + tile_C_VKQ VKQ_C[D/tile_C_VKQ::I * ntiles]; - float2 KQ_rowsum = {0.0f, 0.0f}; - float2 KQ_max = {-FLT_MAX/2.0f, -FLT_MAX/2.0f}; + tile_B_16 * Q_B_16 = (tile_B_16 *) Q_B; + tile_C_VKQ_16 * VKQ_C_16 = (tile_C_VKQ_16 *) VKQ_C; + + float KQ_rowsum[cols_per_thread] = {0.0f}; + float KQ_max[cols_per_thread]; +#pragma unroll + for (int col = 0; col < cols_per_thread; ++col) { + KQ_max[col] = -FLT_MAX/2.0f; + } // Temporarily load Q data into tile_K, will be loaded into registers afterwards. // The loading is done with decreasing granularity for D for better memory bandwidth. const half2 scale_h2 = make_half2(scale, scale); #pragma unroll for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) { - const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k); - const int k0_stop = D/2 - (D/2) % (1*stride_k); - const int stride_j = WARP_SIZE / stride_k; + const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k); + const int k0_stop = D/2 - (D/2) % (1*stride_k); + const int stride_jc = WARP_SIZE / stride_k; if (k0_start == k0_stop) { continue; } - if (nwarps*stride_j > ncols && threadIdx.y*stride_j >= ncols) { - break; - } - #pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps*stride_j) { - const int j = j0 + threadIdx.y*stride_j + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); + for (int jc0 = 0; jc0 < ncols; jc0 += nwarps*stride_jc) { + const int jc = jc0 + threadIdx.y*stride_jc + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); + + if (jc0 + nwarps*stride_jc > ncols && jc >= ncols) { + break; + } + + const int j = jc / ncols2; + const int c = jc % ncols2; - if (jt*ncols + j < ne01) { + if (jt*ncols1 + j < ne01) { #pragma unroll for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); - const float2 tmp = Q_f2[(jt*ncols + j)*stride_Q + k]; - tile_K[j*D2_padded + k] = scale_h2 * make_half2(tmp.x, tmp.y); + const float2 tmp = Q_f2[(jt*ncols1 + j)*stride_Q1 + c*stride_Q2 + k]; + tile_K[jc*D2_padded + k] = scale_h2 * make_half2(tmp.x, tmp.y); } } else { #pragma unroll for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); - tile_K[j*D2_padded + k] = make_half2(0.0f, 0.0f); + tile_K[jc*D2_padded + k] = make_half2(0.0f, 0.0f); } } } @@ -334,128 +513,217 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( __syncthreads(); { - const int j0 = (threadIdx.y / np) * tile_B::I; + const int j0 = (threadIdx.y / np) * cols_per_warp; #pragma unroll for (int k0 = 0; k0 < D/2; k0 += tile_B::J) { - load_ldmatrix(Q_B[k0/tile_B::J], tile_K + j0*D2_padded + k0, D2_padded); + if (ntiles == 1) { + load_ldmatrix(Q_B[k0/tile_B::J], tile_K + j0*D2_padded + k0, D2_padded); + } else { +#pragma unroll + for (int t = 0; t < ntiles/2; ++t) { + load_ldmatrix(Q_B_16[k0/tile_B_16::J * ntiles/2 + t], + tile_K + (j0 + t*tile_B_16::I)*D2_padded + k0, D2_padded); + } + } } } __syncthreads(); - // Preload K data for first iteration when using cp_async: + // Preload mask and K data for first iteration when using cp_async: #ifdef CP_ASYNC_AVAILABLE - flash_attn_ext_f16_load_tile(K_h2 + kb0_start*KQ_stride*stride_KV, tile_K, stride_KV); + if (ncols2 > 1 || mask_h2) { + flash_attn_ext_f16_load_mask(mask_h2 + kb0_start*KQ_per_iter/2, tile_mask, stride_mask); + } + flash_attn_ext_f16_load_tile(K_h2 + kb0_start*KQ_per_iter*stride_KV, tile_K, stride_KV); #endif // CP_ASYNC_AVAILABLE // Iterate over ne11 == previous tokens: for (int kb0 = kb0_start; kb0 < kb0_stop-1; ++kb0) { constexpr bool last_iter = false; - flash_attn_ext_f16_iter - (Q_f2, K_h2, V_h2, maskh, dstk, dstk_fixup, scale, slope, logit_softcap, - ne01, ne02, stride_Q, stride_KV, stride_mask, jt, tile_K, tile_V, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0); + flash_attn_ext_f16_iter + (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap, + ne01, ne02, stride_KV, stride_mask, jt, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0); } { // kb0_start is always < kb0_stop so the last iter can be executed unconditionally. constexpr bool last_iter = true; - flash_attn_ext_f16_iter - (Q_f2, K_h2, V_h2, maskh, dstk, dstk_fixup, scale, slope, logit_softcap, - ne01, ne02, stride_Q, stride_KV, stride_mask, jt, tile_K, tile_V, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1); + flash_attn_ext_f16_iter + (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap, + ne01, ne02, stride_KV, stride_mask, jt, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1); } // With cp_async there is no __syncthreads at the end of the iter, // there can be a race condition on shared memory access for combining/writing back results. #ifdef CP_ASYNC_AVAILABLE - if (nwarps*tile_B::I > KQ_stride) { + if (nwarps*cols_per_warp > KQ_per_iter) { __syncthreads(); } #endif // CP_ASYNC_AVAILABLE // Finally, sum up partial KQ rowsums. - // The partial sums are spread across 8 threads each, does not need full reduce. + // The partial sums are spread across 8/4 threads each, does not need full reduce. + { + constexpr int offset_first = ntiles == 1 ? 16 : 2; + constexpr int offset_last = ntiles == 1 ? 4 : 1; +#pragma unroll + for (int col = 0; col < cols_per_thread; ++col) { #pragma unroll - for (int offset = 16; offset > 2; offset >>= 1) { - KQ_rowsum.x += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum.x, offset, WARP_SIZE); - KQ_rowsum.y += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum.y, offset, WARP_SIZE); + for (int offset = offset_first; offset >= offset_last; offset >>= 1) { + KQ_rowsum[col] += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum[col], offset, WARP_SIZE); + } + } } // Write VKQ accumulators to shared memory in column-major format. // It's faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM. // Also for np > 1 the combination is done via these values in shared memory. - const int j_cwd = threadIdx.y*tile_B::I + tile_B::get_i(-1); // j combine write data + if (ntiles == 1) { + const int jc_cwd = threadIdx.y*tile_B::I + tile_B::get_i(-1); // jc combine write data #pragma unroll - for (int k0 = 0; k0 < D/2; k0 += tile_B::J) { - const tile_B B = get_transposed(VKQ_C[k0/tile_B::J]); // Conversion of C to B matrix puts it in column-major format. + for (int k0 = 0; k0 < D/2; k0 += tile_B::J) { + const tile_B B = get_transposed(VKQ_C[k0/tile_B::J]); // Conversion of C to B matrix puts it in column-major format. #pragma unroll - for (int l = 0; l < tile_B::ne; ++l) { - const int k = k0 + tile_B::get_j(l); + for (int l = 0; l < tile_B::ne; ++l) { + const int k = k0 + tile_B::get_j(l); - tile_K[j_cwd*D2_padded + k] = B.x[l]; + tile_K[jc_cwd*D2_padded + k] = B.x[l]; + } + } + } else { +#pragma unroll + for (int t = 0; t < ntiles/2; ++t) { + const int j0 = threadIdx.y*cols_per_warp + t*tile_C_VKQ_16::I; +#pragma unroll + for (int k0 = 0; k0 < D/2; k0 += tile_C_VKQ_16::J) { +#pragma unroll + for (int l = 0; l < tile_C_VKQ_16::ne; ++l) { + const int j = j0 + tile_C_VKQ_16::get_i(l); + const int k = k0 + tile_C_VKQ_16::get_j(l); + + tile_K[j*D2_padded + k] = VKQ_C_16[k0/tile_C_VKQ_16::J * ntiles/2 + t].x[l]; + } + } } } - const int j_cwmo = (threadIdx.x % (2*tile_C_VKQ::J)) / tile_C_VKQ::J; // j combine write meta offset - const int j_cwm = threadIdx.y*(2*tile_C_VKQ::J) + 2*tile_C_VKQ::get_j(-1) + j_cwmo; // j combine write meta - const float2 KQ_cmr = make_float2(((const float *) &KQ_max)[j_cwmo], ((const float *) &KQ_rowsum)[j_cwmo]); // KQ combine max rowsum + if constexpr (ntiles == 1) { + const int jc_cwmo = (threadIdx.x % (2*tile_C_VKQ::J)) / tile_C_VKQ::J; // jc combine write meta offset + const int jc_cwm = threadIdx.y*(2*tile_C_VKQ::J) + 2*tile_C_VKQ::get_j(-1) + jc_cwmo; // jc combine write meta + const float2 KQ_cmr = make_float2(KQ_max[jc_cwmo], KQ_rowsum[jc_cwmo]); // KQ combine max rowsum - if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*tile_C_VKQ::J) { - // Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale. - ((float2 *) tile_K)[j_cwm*(D2_padded/2) + D/4] = KQ_cmr; - } + if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*tile_C_VKQ::J) { + // Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale. + ((float2 *) tile_K)[jc_cwm*(D2_padded/2) + D/4] = KQ_cmr; + } - __syncthreads(); + __syncthreads(); - static_assert(np == 1 || np == 2 || np == 4, "bad np"); - if (np == 1) { - // No combination is needed, the meta data can be directly written from registers to VRAM. - if (needs_fixup && threadIdx.x < tile_B::I) { - float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols; - dstk_fixup_meta[j_cwm] = KQ_cmr; + if (np == 1) { + // No combination is needed, the meta data can be directly written from registers to VRAM. + if (needs_fixup && threadIdx.x < tile_B::I) { + float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols; + dstk_fixup_meta[jc_cwm] = KQ_cmr; + } + if (is_fixup && threadIdx.x < tile_B::I) { + float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols; + dstk_fixup_meta[jc_cwm] = KQ_cmr; + } } - if (is_fixup && threadIdx.x < tile_B::I) { - float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols; - dstk_fixup_meta[j_cwm] = KQ_cmr; + } else { + static_assert(ntiles == 2 || ntiles == 4, "bad ntiles"); + const int jc_cwm = threadIdx.y*cols_per_warp // jc combine write meta + + (ntiles == 4 ? ((threadIdx.x % 4) / 2) * tile_C_VKQ_16::I : 0) + + tile_C_VKQ_16::get_i(threadIdx.x % 4); + const float2 KQ_cmr = make_float2(KQ_max[threadIdx.x % cols_per_thread], KQ_rowsum[threadIdx.x % cols_per_thread]); // KQ combine max rowsum + + if (((!needs_fixup && !is_fixup) || np > 1) && (ntiles == 4 || threadIdx.x % 4 < cols_per_thread)) { + // Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale. + ((float2 *) tile_K)[jc_cwm*(D2_padded/2) + D/4] = KQ_cmr; + } + + __syncthreads(); + + if (np == 1) { + // No combination is needed, the meta data can be directly written from registers to VRAM. + if (needs_fixup && (ntiles == 4 || threadIdx.x % 4 < ntiles)) { + float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols; + dstk_fixup_meta[jc_cwm] = KQ_cmr; + } + if (is_fixup && (ntiles == 4 || threadIdx.x % 4 < ntiles)) { + float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols; + dstk_fixup_meta[jc_cwm] = KQ_cmr; + } } - } else if (threadIdx.y % np == 0) { + } + + static_assert(np == 1 || ntiles == 1 || ntiles == 2, "bad ntiles"); + if (np > 1 && threadIdx.y % np == 0) { // Combine the meta data for parallel warps via shared memory. // Warps with threadIdx.y % np != 0 must NOT return early. // All threads must return simultaneously to avoid race conditions with work on the next tile. - float * meta_j = (float *) tile_K + (threadIdx.y*tile_B::I + threadIdx.x)*D2_padded + D/2; + constexpr int nmeta = np*cols_per_warp >= WARP_SIZE ? np*cols_per_warp/WARP_SIZE : 1; - float KQ_cm = -FLT_MAX/2; // KQ combine max per parallel warp. - if (np*tile_B::I == WARP_SIZE || threadIdx.x < np*tile_B::I) { - KQ_cm = meta_j[0]; + const int jc_meta = threadIdx.y*cols_per_warp + (np*cols_per_warp < WARP_SIZE ? threadIdx.x % (np*cols_per_warp) : threadIdx.x); + float2 * const meta_ptr = ((float2 *) tile_K) + jc_meta*(D2_padded/2) + D/4; + float2 meta[nmeta]; +#pragma unroll + for (int imeta = 0; imeta < nmeta; ++imeta) { + meta[imeta] = meta_ptr[imeta * WARP_SIZE * D2_padded/2]; } - float KQ_cmn = KQ_cm; // KQ combine max new, max between all parallel warps. + float KQ_cmn = meta[0].x; // KQ combine max new, max between all parallel warps. +#pragma unroll + for (int imeta = 1; imeta < nmeta; ++imeta) { + KQ_cmn = fmaxf(KQ_cmn, meta[imeta].x); + } #pragma unroll - for (int offset = np*tile_B::I/2; offset >= tile_B::I; offset >>= 1) { + for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) { + if (offset >= WARP_SIZE) { + continue; + } KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE)); } - const float KQ_cms = expf(KQ_cm - KQ_cmn); // KQ combine max scale per warp. - float KQ_crs = 0.0f; // KQ combine rowsum, scaled sum of all parallel warps. - if (np*tile_B::I == WARP_SIZE || threadIdx.x < np*tile_B::I) { - KQ_crs = KQ_cms*meta_j[1]; + float KQ_cms[nmeta]; // KQ combine max scale per warp. +#pragma unroll + for (int imeta = 0; imeta < nmeta; ++imeta) { + KQ_cms[imeta] = expf(meta[imeta].x - KQ_cmn); } + + float KQ_crs = KQ_cms[0]*meta[0].y; // KQ combine rowsum, scaled sum of all parallel warps. #pragma unroll - for (int offset = np*tile_B::I/2; offset >= tile_B::I; offset >>= 1) { + for (int imeta = 1; imeta < nmeta; ++imeta) { + KQ_crs += KQ_cms[imeta]*meta[imeta].y; + } +#pragma unroll + for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) { + if (offset >= WARP_SIZE) { + continue; + } KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE); } // Write back combined meta data: - if (np*tile_B::I == WARP_SIZE || threadIdx.x < np*tile_B::I) { - *((float2 *) meta_j) = make_float2(KQ_cms, KQ_crs); // Combined KQ max scale + rowsum. +#pragma unroll + for (int imeta = 0; imeta < nmeta; ++imeta) { + if (np*cols_per_warp >= WARP_SIZE || threadIdx.x < np*cols_per_warp) { + // Combined KQ max scale + rowsum. + meta_ptr[imeta * WARP_SIZE * D2_padded/2] = make_float2(KQ_cms[imeta], KQ_crs); + } } - if (needs_fixup && threadIdx.x < tile_B::I) { + + // Combined KQ max + rowsum. + static_assert(cols_per_warp <= WARP_SIZE); + if (needs_fixup && (cols_per_warp == WARP_SIZE || threadIdx.x < cols_per_warp)) { float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols; - dstk_fixup_meta[(threadIdx.y/np)*tile_B::I + threadIdx.x] = make_float2(KQ_cmn, KQ_crs); + dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs); } - if (is_fixup && threadIdx.x < tile_B::I) { + if (is_fixup && (cols_per_warp == WARP_SIZE || threadIdx.x < cols_per_warp)) { float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols; - dstk_fixup_meta[(threadIdx.y/np)*tile_B::I + threadIdx.x] = make_float2(KQ_cmn, KQ_crs); + dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs); } } @@ -470,27 +738,32 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( #pragma unroll for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) { - const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k); - const int k0_stop = D/2 - (D/2) % (1*stride_k); - const int stride_j = WARP_SIZE / stride_k; + const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k); + const int k0_stop = D/2 - (D/2) % (1*stride_k); + const int stride_jc = WARP_SIZE / stride_k; if (k0_start == k0_stop) { continue; } - if (nwarps*stride_j > ncols && threadIdx.y*stride_j >= ncols) { - break; - } - #pragma unroll - for (int j0_dst = 0; j0_dst < ncols; j0_dst += (nwarps/np)*stride_j) { - const int j_dst = j0_dst + (threadIdx.y/np)*stride_j + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); - const int j_tile_K = (j_dst/tile_B::I)*(np*tile_B::I) + j_dst % tile_B::I; + for (int jc0_dst = 0; jc0_dst < ncols; jc0_dst += (nwarps/np)*stride_jc) { + const int jc_dst = jc0_dst + (threadIdx.y/np)*stride_jc + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); + + if (jc0_dst + (nwarps/np)*stride_jc > ncols && jc_dst >= ncols) { + break; + } + + const int jc_tile_K = (jc_dst/cols_per_warp)*(np*cols_per_warp) + jc_dst % cols_per_warp; + + const int j_dst = jc_dst / ncols2; + const int c_dst = jc_dst % ncols2; - if (!is_fixup && jt*ncols + j_dst >= ne01) { + if (!is_fixup && jt*ncols1 + j_dst >= ne01) { continue; } - const float * meta_j = (const float *) tile_K + j_tile_K*D2_padded + D/2; + + const float * meta_j = (const float *) tile_K + jc_tile_K*D2_padded + D/2; #pragma unroll for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); @@ -498,8 +771,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( float2 dstk_val = make_float2(0.0f, 0.0f); #pragma unroll for (int ip = 0; ip < np; ++ip) { - const float KQ_crs = np == 1 ? 1.0f : meta_j[ip*tile_B::I*D2_padded + 0]; - const float2 dstk_val_add = __half22float2(tile_K[(j_tile_K + ip*tile_B::I)*D2_padded + k]); + const float KQ_crs = np == 1 ? 1.0f : meta_j[ip*cols_per_warp * D2_padded + 0]; + const float2 dstk_val_add = __half22float2(tile_K[(jc_tile_K + ip*cols_per_warp) * D2_padded + k]); dstk_val.x += dstk_val_add.x*KQ_crs; dstk_val.y += dstk_val_add.y*KQ_crs; } @@ -511,9 +784,9 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( } if (is_fixup) { - dstk_fixup_data[j_dst*(D/2) + k] = dstk_val; + dstk_fixup_data[jc_dst*(D/2) + k] = dstk_val; } else { - dstk[(jt*ncols + j_dst)*ne02*(D/2) + k] = dstk_val; + dstk[((jt*ncols1 + j_dst)*ne02 + c_dst)*(D/2) + k] = dstk_val; } } } @@ -528,10 +801,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( #endif // NEW_MMA_AVAILABLE } -template -#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) +template __launch_bounds__(nwarps*WARP_SIZE, 2) -#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) static __global__ void flash_attn_ext_f16( const char * __restrict__ Q, const char * __restrict__ K, @@ -579,20 +850,23 @@ static __global__ void flash_attn_ext_f16( return; } - static_assert(FATTN_KQ_STRIDE % KQ_stride == 0, "bad KQ_stride"); + static_assert(FATTN_KQ_STRIDE % KQ_per_iter == 0, "bad KQ_per_iter"); const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. - const int stride_Q = nb01 / sizeof(float2); + const int stride_Q1 = nb01 / sizeof(float2); + const int stride_Q2 = nb02 / sizeof(float2); const int stride_KV = nb11 / sizeof(half2); - const int stride_mask = nb31 / sizeof(half); + const int stride_mask = nb31 / sizeof(half2); + + const int iter_k = ne11 / FATTN_KQ_STRIDE; + const int iter_j = (ne01 + (ncols1 - 1)) / ncols1; - const int iter_k = ne11 / KQ_stride; - const int iter_j = (ne01 + (ncols - 1)) / ncols; + constexpr int kb_niter = FATTN_KQ_STRIDE / KQ_per_iter; // Number of kernel iterations per assigned KQ slice. // kbc == k block continuous, current index in continuous ijk space. - int kbc = (blockIdx.x + 0)*iter_k*iter_j*ne02 / gridDim.x; - const int kbc_stop = (blockIdx.x + 1)*iter_k*iter_j*ne02 / gridDim.x; + int kbc = (blockIdx.x + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x; + const int kbc_stop = (blockIdx.x + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x; // If the seams of 2 CUDA blocks fall within an output tile their results need to be combined. // For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup). @@ -605,25 +879,28 @@ static __global__ void flash_attn_ext_f16( const int channel = kbc / (iter_k*iter_j); const int jt = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile. - const float2 * Q_f2 = (const float2 *) (Q + nb02* channel); - const half2 * K_h2 = (const half2 *) (K + nb12*(channel / gqa_ratio)); - const half2 * V_h2 = (const half2 *) (V + nb12*(channel / gqa_ratio)); // K and V have same shape - const half * maskh = mask ? (const half *) mask + (nb31/sizeof(half))*jt*ncols : nullptr; - float2 * dstk = ((float2 *) dst) + channel*(D/2); + const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2); + const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio)); + const half2 * V_h2 = (const half2 *) (V + nb12*(channel*ncols2 / gqa_ratio)); // K and V have same shape + const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr; + float2 * dstk = ((float2 *) dst) + channel*(ncols2 * D/2); - const float slope = get_alibi_slope(max_bias, channel, n_head_log2, m0, m1); + const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f; + + const int kb0_start_kernel = kb0_start * kb_niter; + const int kb0_stop_kernel = kb0_stop * kb_niter; constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer. if (kb0_start == 0) { constexpr bool needs_fixup = false; // CUDA block is working on an entire tile. - flash_attn_ext_f16_process_tile - (Q_f2, K_h2, V_h2, maskh, dstk, dst_meta, scale, slope, logit_softcap, - ne01, ne02, stride_Q, stride_KV, stride_mask, jt, kb0_start, kb0_stop); + flash_attn_ext_f16_process_tile + (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap, + ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); } else { constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile. - flash_attn_ext_f16_process_tile - (Q_f2, K_h2, V_h2, maskh, dstk, dst_meta, scale, slope, logit_softcap, - ne01, ne02, stride_Q, stride_KV, stride_mask, jt, kb0_start, kb0_stop); + flash_attn_ext_f16_process_tile + (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap, + ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); } kbc += iter_k; @@ -640,39 +917,46 @@ static __global__ void flash_attn_ext_f16( const int channel = kbc / (iter_k*iter_j); const int jt = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile. - const float2 * Q_f2 = (const float2 *) (Q + nb02* channel); - const half2 * K_h2 = (const half2 *) (K + nb12*(channel / gqa_ratio)); - const half2 * V_h2 = (const half2 *) (V + nb12*(channel / gqa_ratio)); // K and V have same shape - const half * maskh = mask ? (const half *) mask + (nb31/sizeof(half))*jt*ncols : nullptr; - float2 * dstk = ((float2 *) dst) + channel*(D/2); + const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2); + const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio)); + const half2 * V_h2 = (const half2 *) (V + nb12*(channel*ncols2 / gqa_ratio)); // K and V have same shape + const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr; + float2 * dstk = ((float2 *) dst) + channel*(ncols2 * D/2); + + const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f; - const float slope = get_alibi_slope(max_bias, channel, n_head_log2, m0, m1); + const int kb0_start_kernel = kb0_start * kb_niter; + const int kb0_stop_kernel = kb0_stop * kb_niter; constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks. constexpr bool needs_fixup = false; - flash_attn_ext_f16_process_tile - (Q_f2, K_h2, V_h2, maskh, dstk, dst_meta, scale, slope, logit_softcap, - ne01, ne02, stride_Q, stride_KV, stride_mask, jt, kb0_start, kb0_stop); + flash_attn_ext_f16_process_tile + (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap, + ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); } -template +template void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - typedef tile<16, 8, half2> tile_A; - typedef tile< 8, 8, half2> tile_B; + constexpr int ncols = ncols1 * ncols2; + constexpr int KQ_per_iter = D <= 128 && ncols1 <= 64 ? 64 : 32; + constexpr int nwarps = (KQ_per_iter == 32 && ncols <= 16) ? 2 : 4; + constexpr int ntiles = ncols <= 8 ? 1 : (ncols <= 64 ? 2 : 4); + constexpr int cols_per_warp = ntiles * tile_B::I; - static_assert(D % tile_B::J == 0, "bad D"); - static_assert(cols_per_block % tile_B::I == 0, "bad cols_per_block"); + static_assert(D % tile_B::J == 0, "bad D"); + static_assert(ncols % cols_per_warp == 0, "bad ncols"); const ggml_tensor * KQV = dst; - const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; + const int id = ggml_cuda_get_device(); + const int cc = ggml_cuda_info().devices[id].cc; + + const int KQ_shared_rows = cp_async_available(cc) ? 2*KQ_per_iter : KQ_per_iter; - constexpr int KQ_stride = D <= 128 ? 64 : 32; - constexpr int nwarps = (KQ_stride == 32 && cols_per_block <= 16) ? - cols_per_block/tile_B::J * KQ_stride/tile_A::I : (cols_per_block <= 8 ? 4 : 8); + const size_t nbytes_shared_KV = KQ_shared_rows * (D + 8) * sizeof(half); + const size_t nbytes_shared_mask = ncols1 * (KQ_per_iter + 8) * sizeof(half); + const size_t nbytes_shared_combine = nwarps*cols_per_warp * (D + 8) * sizeof(half); - const int nrows_KQ = cp_async_available(cc) ? 2*KQ_stride : KQ_stride; - const int nrows_combine = nwarps*tile_B::J; - const size_t nbytes_shared = std::max(nrows_KQ, nrows_combine) * (D + 8) * sizeof(half); + const size_t nbytes_shared_total = std::max(nbytes_shared_KV + nbytes_shared_mask, nbytes_shared_combine); float logit_softcap; memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); @@ -680,42 +964,58 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml fattn_kernel_t fattn_kernel; if (logit_softcap == 0.0f) { constexpr bool use_logit_softcap = false; - fattn_kernel = flash_attn_ext_f16; + fattn_kernel = flash_attn_ext_f16; } else { constexpr bool use_logit_softcap = true; - fattn_kernel = flash_attn_ext_f16; + fattn_kernel = flash_attn_ext_f16; } - launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true); + + launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, true, true); } -#define DECL_FATTN_MMA_F16_CASE(D, cols_per_block) \ + +#define DECL_FATTN_MMA_F16_CASE(D, ncols1, ncols2) \ template void ggml_cuda_flash_attn_ext_mma_f16_case \ - (ggml_backend_cuda_context & ctx, ggml_tensor * dst) \ - -extern DECL_FATTN_MMA_F16_CASE( 64, 8); -extern DECL_FATTN_MMA_F16_CASE( 80, 8); -extern DECL_FATTN_MMA_F16_CASE( 96, 8); -extern DECL_FATTN_MMA_F16_CASE(112, 8); -extern DECL_FATTN_MMA_F16_CASE(128, 8); -extern DECL_FATTN_MMA_F16_CASE(256, 8); - -extern DECL_FATTN_MMA_F16_CASE( 64, 16); -extern DECL_FATTN_MMA_F16_CASE( 80, 16); -extern DECL_FATTN_MMA_F16_CASE( 96, 16); -extern DECL_FATTN_MMA_F16_CASE(112, 16); -extern DECL_FATTN_MMA_F16_CASE(128, 16); -extern DECL_FATTN_MMA_F16_CASE(256, 16); - -extern DECL_FATTN_MMA_F16_CASE( 64, 32); -extern DECL_FATTN_MMA_F16_CASE( 80, 32); -extern DECL_FATTN_MMA_F16_CASE( 96, 32); -extern DECL_FATTN_MMA_F16_CASE(112, 32); -extern DECL_FATTN_MMA_F16_CASE(128, 32); -extern DECL_FATTN_MMA_F16_CASE(256, 32); - -extern DECL_FATTN_MMA_F16_CASE( 64, 64); -extern DECL_FATTN_MMA_F16_CASE( 80, 64); -extern DECL_FATTN_MMA_F16_CASE( 96, 64); -extern DECL_FATTN_MMA_F16_CASE(112, 64); -extern DECL_FATTN_MMA_F16_CASE(128, 64); -extern DECL_FATTN_MMA_F16_CASE(256, 64); + (ggml_backend_cuda_context & ctx, ggml_tensor * dst) \ + +#define DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(D, ncols) \ + extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/1, 1); \ + extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/2, 2); \ + extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/4, 4); \ + extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/8, 8); \ + +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 8); +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 8); +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 8); +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 8); +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 8); +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 8); + +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 16); +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 16); +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 16); +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 16); +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 16); +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 16); + +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 32); +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 32); +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 32); +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 32); +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 32); +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 32); + +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 64); +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 64); +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 64); +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 64); +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 64); +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 64); + +// Kernels with ncols == 128 are only 4% faster due to register pressure. +// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 128); +// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 128); +// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 128); +// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 128); +// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128); +// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 128); // Needs too much shared memory. diff --git a/ggml/src/ggml-cuda/fattn-tile-f16.cu b/ggml/src/ggml-cuda/fattn-tile-f16.cu index d4edbad07f2..b8b415effb7 100644 --- a/ggml/src/ggml-cuda/fattn-tile-f16.cu +++ b/ggml/src/ggml-cuda/fattn-tile-f16.cu @@ -302,14 +302,14 @@ void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * constexpr int nwarps = 8; constexpr size_t nbytes_shared = 0; fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16; - launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true); + launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true); } break; case 128: { constexpr int D = 128; constexpr int nwarps = 8; constexpr size_t nbytes_shared = 0; fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16; - launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true); + launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true); } break; default: { GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128."); diff --git a/ggml/src/ggml-cuda/fattn-tile-f32.cu b/ggml/src/ggml-cuda/fattn-tile-f32.cu index 0d274f33255..4352a284464 100644 --- a/ggml/src/ggml-cuda/fattn-tile-f32.cu +++ b/ggml/src/ggml-cuda/fattn-tile-f32.cu @@ -296,14 +296,14 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * constexpr int nwarps = 8; constexpr size_t nbytes_shared = 0; fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32; - launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true); + launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true); } break; case 128: { constexpr int D = 128; constexpr int nwarps = 8; constexpr size_t nbytes_shared = 0; fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32; - launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true); + launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true); } break; default: { GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128."); diff --git a/ggml/src/ggml-cuda/fattn-vec-f16.cuh b/ggml/src/ggml-cuda/fattn-vec-f16.cuh index d9ac4424606..e758a0f6ec2 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f16.cuh @@ -310,7 +310,7 @@ void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx, constexpr bool need_f16_K = D != 128; constexpr bool need_f16_V = D != 128 && D != 64; constexpr size_t nbytes_shared = 0; - launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V); + launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V); } template diff --git a/ggml/src/ggml-cuda/fattn-vec-f32.cuh b/ggml/src/ggml-cuda/fattn-vec-f32.cuh index 6ef8f9dcc27..134144a383f 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f32.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f32.cuh @@ -290,7 +290,7 @@ void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx, constexpr bool need_f16_K = D != 128; constexpr bool need_f16_V = D != 128 && D != 64; constexpr size_t nbytes_shared = 0; - launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V); + launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V); } template diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ggml/src/ggml-cuda/fattn-wmma-f16.cu index 45702ad651f..de38470abec 100644 --- a/ggml/src/ggml-cuda/fattn-wmma-f16.cu +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cu @@ -478,7 +478,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm fattn_kernel = flash_attn_ext_f16< D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>; } - launch_fattn(ctx, dst, fattn_kernel, nwarps, 0, true, true); + launch_fattn(ctx, dst, fattn_kernel, nwarps, 0, true, true); return; } if (2*blocks_num_pb1 < 2*nsm) { @@ -493,7 +493,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm fattn_kernel = flash_attn_ext_f16< D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>; } - launch_fattn(ctx, dst, fattn_kernel, nwarps, 0, true, true); + launch_fattn(ctx, dst, fattn_kernel, nwarps, 0, true, true); return; } constexpr int parallel_blocks = 1; @@ -507,7 +507,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm fattn_kernel = flash_attn_ext_f16< D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>; } - launch_fattn(ctx, dst, fattn_kernel, nwarps, 0, true, true); + launch_fattn(ctx, dst, fattn_kernel, nwarps, 0, true, true); } void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index b0cf152f52c..b1becccb4de 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -8,28 +8,50 @@ #include "fattn-wmma-f16.cuh" #include "fattn.cuh" -template +template +static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * Q = dst->src[0]; + + if (Q->ne[1] <= 8/ncols2) { + ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); + return; + } + + if (Q->ne[1] <= 16/ncols2) { + ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); + return; + } + + if (Q->ne[1] <= 32/ncols2) { + ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); + return; + } + + ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); +} + +template static void ggml_cuda_flash_attn_ext_mma_f16_switch_hs(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * Q = dst->src[0]; switch (Q->ne[0]) { case 64: - ggml_cuda_flash_attn_ext_mma_f16_case< 64, cols_per_block>(ctx, dst); + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 64, ncols2>(ctx, dst); break; case 80: - ggml_cuda_flash_attn_ext_mma_f16_case< 80, cols_per_block>(ctx, dst); + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 80, ncols2>(ctx, dst); break; case 96: - ggml_cuda_flash_attn_ext_mma_f16_case< 96, cols_per_block>(ctx, dst); + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 96, ncols2>(ctx, dst); break; case 112: - ggml_cuda_flash_attn_ext_mma_f16_case<112, cols_per_block>(ctx, dst); + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<112, ncols2>(ctx, dst); break; case 128: - ggml_cuda_flash_attn_ext_mma_f16_case<128, cols_per_block>(ctx, dst); + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<128, ncols2>(ctx, dst); break; case 256: - ggml_cuda_flash_attn_ext_mma_f16_case<256, cols_per_block>(ctx, dst); + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<256, ncols2>(ctx, dst); break; default: GGML_ABORT("fatal error"); @@ -38,24 +60,35 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_hs(ggml_backend_cuda_context } static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * KQV = dst; + const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * mask = dst->src[3]; + + float max_bias = 0.0f; + memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float)); + + const float use_gqa_opt = mask && max_bias == 0.0f; + + GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); + const int gqa_ratio = Q->ne[2] / K->ne[2]; - if (Q->ne[1] <= 8) { + if (use_gqa_opt && gqa_ratio % 8 == 0) { ggml_cuda_flash_attn_ext_mma_f16_switch_hs<8>(ctx, dst); return; } - if (Q->ne[1] <= 16) { - ggml_cuda_flash_attn_ext_mma_f16_switch_hs<16>(ctx, dst); + if (use_gqa_opt && gqa_ratio == 4) { + ggml_cuda_flash_attn_ext_mma_f16_switch_hs<4>(ctx, dst); return; } - if (Q->ne[1] <= 32) { - ggml_cuda_flash_attn_ext_mma_f16_switch_hs<32>(ctx, dst); + if (use_gqa_opt && gqa_ratio == 2) { + ggml_cuda_flash_attn_ext_mma_f16_switch_hs<2>(ctx, dst); return; } - ggml_cuda_flash_attn_ext_mma_f16_switch_hs<64>(ctx, dst); + ggml_cuda_flash_attn_ext_mma_f16_switch_hs<1>(ctx, dst); } #define FATTN_VEC_F16_CASE(D, type_K, type_V) \ @@ -209,8 +242,11 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg } void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * KQV = dst; - const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * KQV = dst; + const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * V = dst->src[2]; + const ggml_tensor * mask = dst->src[3]; ggml_cuda_set_device(ctx.device); const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; @@ -252,7 +288,10 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst return; } - if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0) { + const int gqa_ratio = Q->ne[2] / K->ne[2]; + const bool mma_fast_for_bs1 = fp16_mma_available(cc) && gqa_ratio % 2 == 0 && + K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16 && mask; + if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0 && !mma_fast_for_bs1) { if (prec == GGML_PREC_DEFAULT) { ggml_cuda_flash_attn_ext_vec_f16(ctx, dst); return; diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh index 0a5656e4cb3..9206bfeba3d 100644 --- a/ggml/src/ggml-cuda/mma.cuh +++ b/ggml/src/ggml-cuda/mma.cuh @@ -73,6 +73,8 @@ namespace ggml_cuda_mma { return threadIdx.x / 4; } else if constexpr (I == 16 && J == 8) { return (l / 2) * 8 + threadIdx.x / 4; + } else if constexpr (I == 16 && J == 16) { + return ((l / 2) % 2) * 8 + threadIdx.x / 4; } else { static_assert(I == -1 && J == -1, "template specialization not implemented"); } @@ -85,6 +87,8 @@ namespace ggml_cuda_mma { return 4 * l + threadIdx.x % 4; } else if constexpr (I == 16 && J == 8) { return 2 * (threadIdx.x % 4) + l % 2; + } else if constexpr (I == 16 && J == 16) { + return 8 * (l / 4) + 2 * (threadIdx.x % 4) + l % 2; } else { static_assert(I == -1 && J == -1, "template specialization not implemented"); } @@ -289,6 +293,42 @@ namespace ggml_cuda_mma { #endif // NEW_MMA_AVAILABLE } + static __device__ __forceinline__ void mma( + tile<16, 8, half2> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) { +#ifdef NEW_MMA_AVAILABLE + const int * Axi = (const int *) A.x; + const int * Bxi = (const int *) B.x; + int * Dxi = (int *) D.x; +#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE + asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};" + : "+r"(Dxi[0]), "+r"(Dxi[1]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[2])); + asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};" + : "+r"(Dxi[2]), "+r"(Dxi[3]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]), "r"(Bxi[3])); +#else + // On Turing m16n8k16 mma is not available, use 4x m8n8k8 mma instead: + asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};" + : "+r"(Dxi[0]), "+r"(Dxi[1]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0])); + asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};" + : "+r"(Dxi[0]), "+r"(Dxi[1]) + : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2])); + asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};" + : "+r"(Dxi[2]), "+r"(Dxi[3]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[1])); + asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};" + : "+r"(Dxi[2]), "+r"(Dxi[3]) + : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3])); +#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE +#else + GGML_UNUSED(D); + GGML_UNUSED(A); + GGML_UNUSED(B); + NO_DEVICE_CODE; +#endif // NEW_MMA_AVAILABLE + } + static __device__ __forceinline__ void mma( tile<16, 8, float> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) { #ifdef NEW_MMA_AVAILABLE @@ -316,4 +356,39 @@ namespace ggml_cuda_mma { #endif // NEW_MMA_AVAILABLE } + static __device__ __forceinline__ void mma( + tile<16, 16, float> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) { +#ifdef NEW_MMA_AVAILABLE + const int * Axi = (const int *) A.x; + const int * Bxi = (const int *) B.x; + int * Dxi = (int *) D.x; +#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE + asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};" + : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[2])); + asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};" + : "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]), "r"(Bxi[3])); +#else + // On Turing m16n8k16 mma is not available, use 4x m8n8k8 mma instead: + asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};" + : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0])); + asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};" + : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]) + : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2])); + asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};" + : "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[1])); + asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};" + : "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7]) + : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3])); +#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE +#else + GGML_UNUSED(D); + GGML_UNUSED(A); + GGML_UNUSED(B); + NO_DEVICE_CODE; +#endif // NEW_MMA_AVAILABLE + } } diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb16.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb16.cu deleted file mode 100644 index f09bdeff79a..00000000000 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb16.cu +++ /dev/null @@ -1,10 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-mma-f16.cuh" - -DECL_FATTN_MMA_F16_CASE(64, 16); -DECL_FATTN_MMA_F16_CASE(80, 16); -DECL_FATTN_MMA_F16_CASE(96, 16); -DECL_FATTN_MMA_F16_CASE(112, 16); -DECL_FATTN_MMA_F16_CASE(128, 16); -DECL_FATTN_MMA_F16_CASE(256, 16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb32.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb32.cu deleted file mode 100644 index 221108873a0..00000000000 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb32.cu +++ /dev/null @@ -1,10 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-mma-f16.cuh" - -DECL_FATTN_MMA_F16_CASE(64, 32); -DECL_FATTN_MMA_F16_CASE(80, 32); -DECL_FATTN_MMA_F16_CASE(96, 32); -DECL_FATTN_MMA_F16_CASE(112, 32); -DECL_FATTN_MMA_F16_CASE(128, 32); -DECL_FATTN_MMA_F16_CASE(256, 32); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb64.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb64.cu deleted file mode 100644 index d24b085758d..00000000000 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb64.cu +++ /dev/null @@ -1,10 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-mma-f16.cuh" - -DECL_FATTN_MMA_F16_CASE(64, 64); -DECL_FATTN_MMA_F16_CASE(80, 64); -DECL_FATTN_MMA_F16_CASE(96, 64); -DECL_FATTN_MMA_F16_CASE(112, 64); -DECL_FATTN_MMA_F16_CASE(128, 64); -DECL_FATTN_MMA_F16_CASE(256, 64); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb8.cu deleted file mode 100644 index bdf86c0eaba..00000000000 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb8.cu +++ /dev/null @@ -1,10 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-mma-f16.cuh" - -DECL_FATTN_MMA_F16_CASE(64, 8); -DECL_FATTN_MMA_F16_CASE(80, 8); -DECL_FATTN_MMA_F16_CASE(96, 8); -DECL_FATTN_MMA_F16_CASE(112, 8); -DECL_FATTN_MMA_F16_CASE(128, 8); -DECL_FATTN_MMA_F16_CASE(256, 8); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu new file mode 100644 index 00000000000..80108615ac8 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 1, 8); +DECL_FATTN_MMA_F16_CASE(80, 1, 8); +DECL_FATTN_MMA_F16_CASE(96, 1, 8); +DECL_FATTN_MMA_F16_CASE(112, 1, 8); +DECL_FATTN_MMA_F16_CASE(128, 1, 8); +DECL_FATTN_MMA_F16_CASE(256, 1, 8); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu new file mode 100644 index 00000000000..66161c0abeb --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 16, 1); +DECL_FATTN_MMA_F16_CASE(80, 16, 1); +DECL_FATTN_MMA_F16_CASE(96, 16, 1); +DECL_FATTN_MMA_F16_CASE(112, 16, 1); +DECL_FATTN_MMA_F16_CASE(128, 16, 1); +DECL_FATTN_MMA_F16_CASE(256, 16, 1); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu new file mode 100644 index 00000000000..ee88c72aa04 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 16, 2); +DECL_FATTN_MMA_F16_CASE(80, 16, 2); +DECL_FATTN_MMA_F16_CASE(96, 16, 2); +DECL_FATTN_MMA_F16_CASE(112, 16, 2); +DECL_FATTN_MMA_F16_CASE(128, 16, 2); +DECL_FATTN_MMA_F16_CASE(256, 16, 2); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu new file mode 100644 index 00000000000..d888a5a423a --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 16, 4); +DECL_FATTN_MMA_F16_CASE(80, 16, 4); +DECL_FATTN_MMA_F16_CASE(96, 16, 4); +DECL_FATTN_MMA_F16_CASE(112, 16, 4); +DECL_FATTN_MMA_F16_CASE(128, 16, 4); +DECL_FATTN_MMA_F16_CASE(256, 16, 4); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu new file mode 100644 index 00000000000..d93a2d08ed7 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 2, 4); +DECL_FATTN_MMA_F16_CASE(80, 2, 4); +DECL_FATTN_MMA_F16_CASE(96, 2, 4); +DECL_FATTN_MMA_F16_CASE(112, 2, 4); +DECL_FATTN_MMA_F16_CASE(128, 2, 4); +DECL_FATTN_MMA_F16_CASE(256, 2, 4); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu new file mode 100644 index 00000000000..617464c9456 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 2, 8); +DECL_FATTN_MMA_F16_CASE(80, 2, 8); +DECL_FATTN_MMA_F16_CASE(96, 2, 8); +DECL_FATTN_MMA_F16_CASE(112, 2, 8); +DECL_FATTN_MMA_F16_CASE(128, 2, 8); +DECL_FATTN_MMA_F16_CASE(256, 2, 8); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu new file mode 100644 index 00000000000..970d2b68696 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 32, 1); +DECL_FATTN_MMA_F16_CASE(80, 32, 1); +DECL_FATTN_MMA_F16_CASE(96, 32, 1); +DECL_FATTN_MMA_F16_CASE(112, 32, 1); +DECL_FATTN_MMA_F16_CASE(128, 32, 1); +DECL_FATTN_MMA_F16_CASE(256, 32, 1); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu new file mode 100644 index 00000000000..65cd377c395 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 32, 2); +DECL_FATTN_MMA_F16_CASE(80, 32, 2); +DECL_FATTN_MMA_F16_CASE(96, 32, 2); +DECL_FATTN_MMA_F16_CASE(112, 32, 2); +DECL_FATTN_MMA_F16_CASE(128, 32, 2); +DECL_FATTN_MMA_F16_CASE(256, 32, 2); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu new file mode 100644 index 00000000000..f4a8bf34899 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 4, 2); +DECL_FATTN_MMA_F16_CASE(80, 4, 2); +DECL_FATTN_MMA_F16_CASE(96, 4, 2); +DECL_FATTN_MMA_F16_CASE(112, 4, 2); +DECL_FATTN_MMA_F16_CASE(128, 4, 2); +DECL_FATTN_MMA_F16_CASE(256, 4, 2); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu new file mode 100644 index 00000000000..de191a8ab66 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 4, 4); +DECL_FATTN_MMA_F16_CASE(80, 4, 4); +DECL_FATTN_MMA_F16_CASE(96, 4, 4); +DECL_FATTN_MMA_F16_CASE(112, 4, 4); +DECL_FATTN_MMA_F16_CASE(128, 4, 4); +DECL_FATTN_MMA_F16_CASE(256, 4, 4); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu new file mode 100644 index 00000000000..e8cb0e1b312 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 4, 8); +DECL_FATTN_MMA_F16_CASE(80, 4, 8); +DECL_FATTN_MMA_F16_CASE(96, 4, 8); +DECL_FATTN_MMA_F16_CASE(112, 4, 8); +DECL_FATTN_MMA_F16_CASE(128, 4, 8); +DECL_FATTN_MMA_F16_CASE(256, 4, 8); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu new file mode 100644 index 00000000000..a532e96296b --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 64, 1); +DECL_FATTN_MMA_F16_CASE(80, 64, 1); +DECL_FATTN_MMA_F16_CASE(96, 64, 1); +DECL_FATTN_MMA_F16_CASE(112, 64, 1); +DECL_FATTN_MMA_F16_CASE(128, 64, 1); +DECL_FATTN_MMA_F16_CASE(256, 64, 1); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu new file mode 100644 index 00000000000..bf25181aa76 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 8, 1); +DECL_FATTN_MMA_F16_CASE(80, 8, 1); +DECL_FATTN_MMA_F16_CASE(96, 8, 1); +DECL_FATTN_MMA_F16_CASE(112, 8, 1); +DECL_FATTN_MMA_F16_CASE(128, 8, 1); +DECL_FATTN_MMA_F16_CASE(256, 8, 1); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu new file mode 100644 index 00000000000..378c132e658 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 8, 2); +DECL_FATTN_MMA_F16_CASE(80, 8, 2); +DECL_FATTN_MMA_F16_CASE(96, 8, 2); +DECL_FATTN_MMA_F16_CASE(112, 8, 2); +DECL_FATTN_MMA_F16_CASE(128, 8, 2); +DECL_FATTN_MMA_F16_CASE(256, 8, 2); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu new file mode 100644 index 00000000000..372641be9a0 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 8, 4); +DECL_FATTN_MMA_F16_CASE(80, 8, 4); +DECL_FATTN_MMA_F16_CASE(96, 8, 4); +DECL_FATTN_MMA_F16_CASE(112, 8, 4); +DECL_FATTN_MMA_F16_CASE(128, 8, 4); +DECL_FATTN_MMA_F16_CASE(256, 8, 4); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu new file mode 100644 index 00000000000..9ff5968b6ab --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 8, 8); +DECL_FATTN_MMA_F16_CASE(80, 8, 8); +DECL_FATTN_MMA_F16_CASE(96, 8, 8); +DECL_FATTN_MMA_F16_CASE(112, 8, 8); +DECL_FATTN_MMA_F16_CASE(128, 8, 8); +DECL_FATTN_MMA_F16_CASE(256, 8, 8); diff --git a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py index a2628f16e57..dd373a09d26 100755 --- a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +++ b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py @@ -18,7 +18,7 @@ """ -SOURCE_FATTN_MMA_CASE = "DECL_FATTN_MMA_F16_CASE({head_size}, {cols_per_block});\n" +SOURCE_FATTN_MMA_CASE = "DECL_FATTN_MMA_F16_CASE({head_size}, {ncols1}, {ncols2});\n" TYPES_MMQ = [ "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0", @@ -57,12 +57,18 @@ def get_head_sizes(type_k, type_v): with open(f"fattn-vec-f{vkq_size}-instance-hs{head_size}-{get_short_name(type_k)}-{get_short_name(type_v)}.cu", "w") as f: f.write(SOURCE_FATTN_VEC.format(vkq_size=vkq_size, head_size=head_size, type_k=type_k, type_v=type_v)) -for cols_per_block in [8, 16, 32, 64]: - with open(f"fattn-mma-f16-instance-cpb{cols_per_block}.cu", "w") as f: - f.write(SOURCE_FATTN_MMA_START) - - for head_size in [64, 80, 96, 112, 128, 256]: - f.write(SOURCE_FATTN_MMA_CASE.format(cols_per_block=cols_per_block, head_size=head_size)) +for ncols in [8, 16, 32, 64, 128]: + for ncols2 in [1, 2, 4, 8]: + ncols1 = ncols // ncols2 + if ncols == 128: + continue # Too much register pressure. + with open(f"fattn-mma-f16-instance-ncols1_{ncols1}-ncols2_{ncols2}.cu", "w") as f: + f.write(SOURCE_FATTN_MMA_START) + + for head_size in [64, 80, 96, 112, 128, 256]: + if ncols == 128 and head_size == 256: + continue # Needs too much shared memory. + f.write(SOURCE_FATTN_MMA_CASE.format(ncols1=ncols1, ncols2=ncols2, head_size=head_size)) for type in TYPES_MMQ: with open(f"mmq-instance-{get_short_name(type)}.cu", "w") as f: From 51bb2f9b2ad7f07549199ba4abd20ac6fb5df1d7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Sat, 22 Feb 2025 20:44:34 +0100 Subject: [PATCH 51/58] CUDA: app option to compile without FlashAttention (llama/12025) --- ggml/CMakeLists.txt | 1 + ggml/src/ggml-cuda/CMakeLists.txt | 4 ++++ ggml/src/ggml-cuda/common.cuh | 4 ++-- ggml/src/ggml-cuda/fattn-mma-f16.cuh | 8 ++++---- ggml/src/ggml-cuda/fattn-tile-f16.cu | 9 ++------- ggml/src/ggml-cuda/fattn-tile-f32.cu | 8 ++++---- ggml/src/ggml-cuda/fattn-vec-f16.cuh | 9 ++------- ggml/src/ggml-cuda/fattn-vec-f32.cuh | 8 ++++---- ggml/src/ggml-cuda/fattn-wmma-f16.cu | 4 ++-- ggml/src/ggml-cuda/ggml-cuda.cu | 2 +- ggml/src/ggml-hip/CMakeLists.txt | 4 ++++ ggml/src/ggml-musa/CMakeLists.txt | 4 ++++ 12 files changed, 34 insertions(+), 31 deletions(-) diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index fc5eac151b9..12afe0f25a8 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -151,6 +151,7 @@ set (GGML_CUDA_PEER_MAX_BATCH_SIZE "128" CACHE STRING "ggml: max. batch size for using peer access") option(GGML_CUDA_NO_PEER_COPY "ggml: do not use peer to peer copies" OFF) option(GGML_CUDA_NO_VMM "ggml: do not try to use CUDA VMM" OFF) +option(GGML_CUDA_FA "ggml: compile ggml FlashAttention CUDA kernels" ON) option(GGML_CUDA_FA_ALL_QUANTS "ggml: compile all quants for FlashAttention" OFF) option(GGML_CUDA_GRAPHS "ggml: use CUDA graphs (llama.cpp only)" ${GGML_CUDA_GRAPHS_DEFAULT}) diff --git a/ggml/src/ggml-cuda/CMakeLists.txt b/ggml/src/ggml-cuda/CMakeLists.txt index e63ede2fbe3..96bd5a0be29 100644 --- a/ggml/src/ggml-cuda/CMakeLists.txt +++ b/ggml/src/ggml-cuda/CMakeLists.txt @@ -69,6 +69,10 @@ if (CUDAToolkit_FOUND) add_compile_definitions(GGML_CUDA_NO_VMM) endif() + if (NOT GGML_CUDA_FA) + add_compile_definitions(GGML_CUDA_NO_FA) + endif() + if (GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16) add_compile_definitions(GGML_CUDA_F16) endif() diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 7e99838c092..adf0d3ecb56 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -204,9 +204,9 @@ typedef float2 dfloat2; #define CP_ASYNC_AVAILABLE #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE -#if !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1) +#if !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1) #define FLASH_ATTN_AVAILABLE -#endif // !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1) +#endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1) static bool fp16_available(const int cc) { return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL; diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index b2e0db9a2cc..718ee5402dc 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -839,10 +839,7 @@ static __global__ void flash_attn_ext_f16( const int ne1, const int ne2, const int ne3) { -#ifndef NEW_MMA_AVAILABLE - NO_DEVICE_CODE; - return; -#endif // NEW_MMA_AVAILABLE +#if defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE) // Skip unused kernel variants for faster compilation: if (use_logit_softcap && !(D == 128 || D == 256)) { @@ -933,6 +930,9 @@ static __global__ void flash_attn_ext_f16( flash_attn_ext_f16_process_tile (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap, ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); +#else + NO_DEVICE_CODE; +#endif // defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE) } template diff --git a/ggml/src/ggml-cuda/fattn-tile-f16.cu b/ggml/src/ggml-cuda/fattn-tile-f16.cu index b8b415effb7..ef3569fab27 100644 --- a/ggml/src/ggml-cuda/fattn-tile-f16.cu +++ b/ggml/src/ggml-cuda/fattn-tile-f16.cu @@ -44,12 +44,7 @@ static __global__ void flash_attn_tile_ext_f16( const int ne1, const int ne2, const int ne3) { -#ifdef FP16_AVAILABLE - -#ifndef FLASH_ATTN_AVAILABLE - NO_DEVICE_CODE; - return; -#endif // FLASH_ATTN_AVAILABLE +#if defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE) // Skip unused kernel variants for faster compilation: #ifdef FP16_MMA_AVAILABLE @@ -290,7 +285,7 @@ static __global__ void flash_attn_tile_ext_f16( } #else NO_DEVICE_CODE; -#endif // FP16_AVAILABLE +#endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE) } template diff --git a/ggml/src/ggml-cuda/fattn-tile-f32.cu b/ggml/src/ggml-cuda/fattn-tile-f32.cu index 4352a284464..04b69c83be0 100644 --- a/ggml/src/ggml-cuda/fattn-tile-f32.cu +++ b/ggml/src/ggml-cuda/fattn-tile-f32.cu @@ -44,10 +44,7 @@ static __global__ void flash_attn_tile_ext_f32( const int ne1, const int ne2, const int ne3) { -#ifndef FLASH_ATTN_AVAILABLE - NO_DEVICE_CODE; - return; -#endif // FLASH_ATTN_AVAILABLE +#ifdef FLASH_ATTN_AVAILABLE // Skip unused kernel variants for faster compilation: #ifdef FP16_MMA_AVAILABLE @@ -285,6 +282,9 @@ static __global__ void flash_attn_tile_ext_f32( dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j); } } +#else + NO_DEVICE_CODE; +#endif // FLASH_ATTN_AVAILABLE } template diff --git a/ggml/src/ggml-cuda/fattn-vec-f16.cuh b/ggml/src/ggml-cuda/fattn-vec-f16.cuh index e758a0f6ec2..b7686c1ec3d 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f16.cuh @@ -41,12 +41,7 @@ static __global__ void flash_attn_vec_ext_f16( const int ne1, const int ne2, const int ne3) { -#ifdef FP16_AVAILABLE - -#ifndef FLASH_ATTN_AVAILABLE - NO_DEVICE_CODE; - return; -#endif // FLASH_ATTN_AVAILABLE +#if defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE) // Skip unused kernel variants for faster compilation: if (use_logit_softcap && !(D == 128 || D == 256)) { @@ -300,7 +295,7 @@ static __global__ void flash_attn_vec_ext_f16( } #else NO_DEVICE_CODE; -#endif // FP16_AVAILABLE +#endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE) } template diff --git a/ggml/src/ggml-cuda/fattn-vec-f32.cuh b/ggml/src/ggml-cuda/fattn-vec-f32.cuh index 134144a383f..c1d2dd8d19f 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f32.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f32.cuh @@ -41,10 +41,7 @@ static __global__ void flash_attn_vec_ext_f32( const int ne1, const int ne2, const int ne3) { -#ifndef FLASH_ATTN_AVAILABLE - NO_DEVICE_CODE; - return; -#endif // FLASH_ATTN_AVAILABLE +#ifdef FLASH_ATTN_AVAILABLE // Skip unused kernel variants for faster compilation: if (use_logit_softcap && !(D == 128 || D == 256)) { @@ -281,6 +278,9 @@ static __global__ void flash_attn_vec_ext_f32( if (parallel_blocks != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) { dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]); } +#else + NO_DEVICE_CODE; +#endif // FLASH_ATTN_AVAILABLE } template diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ggml/src/ggml-cuda/fattn-wmma-f16.cu index de38470abec..8828652fb5e 100644 --- a/ggml/src/ggml-cuda/fattn-wmma-f16.cu +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cu @@ -51,7 +51,7 @@ static __global__ void flash_attn_ext_f16( const int ne1, const int ne2, const int ne3) { -#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +#if defined(FLASH_ATTN_AVAILABLE) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA // Skip unused kernel variants for faster compilation: if (use_logit_softcap && !(D == 128 || D == 256)) { NO_DEVICE_CODE; @@ -425,7 +425,7 @@ static __global__ void flash_attn_ext_f16( } #else NO_DEVICE_CODE; -#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +#endif // defined(FLASH_ATTN_AVAILABLE) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA } constexpr int get_max_power_of_2(int x) { diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index f685423215b..ebb2ccae040 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3203,7 +3203,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_FLASH_ATTN_EXT: { #ifndef FLASH_ATTN_AVAILABLE return false; -#endif +#endif // FLASH_ATTN_AVAILABLE if (op->src[1]->type == GGML_TYPE_BF16 || op->src[2]->type == GGML_TYPE_BF16) { return false; } diff --git a/ggml/src/ggml-hip/CMakeLists.txt b/ggml/src/ggml-hip/CMakeLists.txt index f4a4683639f..4a0384dd476 100644 --- a/ggml/src/ggml-hip/CMakeLists.txt +++ b/ggml/src/ggml-hip/CMakeLists.txt @@ -107,6 +107,10 @@ if (GGML_HIP_NO_VMM) add_compile_definitions(GGML_HIP_NO_VMM) endif() +if (NOT GGML_CUDA_FA) + add_compile_definitions(GGML_CUDA_NO_FA) +endif() + if (CXX_IS_HIPCC) set_source_files_properties(${GGML_SOURCES_ROCM} PROPERTIES LANGUAGE CXX) target_link_libraries(ggml-hip PRIVATE hip::device) diff --git a/ggml/src/ggml-musa/CMakeLists.txt b/ggml/src/ggml-musa/CMakeLists.txt index 1bfc07c5d71..2c75abf61d6 100644 --- a/ggml/src/ggml-musa/CMakeLists.txt +++ b/ggml/src/ggml-musa/CMakeLists.txt @@ -83,6 +83,10 @@ if (MUSAToolkit_FOUND) add_compile_definitions(GGML_CUDA_NO_VMM) endif() + if (NOT GGML_CUDA_FA) + add_compile_definitions(GGML_CUDA_NO_FA) + endif() + if (GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16) add_compile_definitions(GGML_CUDA_F16) endif() From c28f40165975d203bf87f4f4b3d0680b37c5ebc3 Mon Sep 17 00:00:00 2001 From: Aaron Teo <57927438+taronaeo@users.noreply.github.com> Date: Sun, 23 Feb 2025 05:39:24 +0800 Subject: [PATCH 52/58] ggml-cpu: Support s390x SIMD Instruction Set (llama/12019) * ggml: add s390x ARCH_FLAGS for compilation Signed-off-by: Aaron Teo * ggml: add SIMD for s390x using vector intrinsics SIMD is activated for: * ggml_vec_dot_f32 * ggml_vec_dot_f16 * ggml_vec_mad_f32 * ggml_vec_mad_f16 * ggml_vec_mad_f32_unroll * ggml_vec_scale_f32 * ggml_vec_scale_f16 SIMD is NOT activated for: * ggml_vec_dot_f16_unroll (pending bugfix) Signed-off-by: Aaron Teo * ggml: fix missing escape character in GGML_F32x4_REDUCE Signed-off-by: Aaron Teo * ggml: add temporary patch for GGML_F32_ARR and GGML_F16_ARR Signed-off-by: Aaron Teo * ggml: fix s390x GGML_F32x4_REDUCE Signed-off-by: Aaron Teo * ggml: full SIMD activation for F32,F16 s390x Signed-off-by: Aaron Teo * ggml: add option to disable s390x VXE/VXE2 Signed-off-by: Aaron Teo * ggml: change vecintrin.h include to ggml-cpu-impl * add __VXE__ and __VXE2__ macros Signed-off-by: Aaron Teo * cmake: add s390x target detection for VX/VXE/VXE2 Signed-off-by: Aaron Teo * ggml: move s390x vector intrinsics to ggml-cpu-impl.h Signed-off-by: Aaron Teo * ggml: s390x Q8_0 SIMD Signed-off-by: Aaron Teo * ggml: correct documentation for Q8_0 Signed-off-by: Aaron Teo * ggml: s390x reduce code complexity Q8_0 Signed-off-by: Aaron Teo * ggml: s390x bugfix typo Q8_0 Signed-off-by: Aaron Teo * ggml: s390x SIMD activated for Q4_1 Signed-off-by: Aaron Teo * ggml: s390x inline vec_reve Signed-off-by: Aaron Teo * ggml: s390x SIMD activation for Q4_0 Signed-off-by: Aaron Teo * ggml: add VXE backend feature Signed-off-by: Aaron Teo * ggml: remove test.py Signed-off-by: Aaron Teo * ggml: s390x SIMD activation for quantize_row_q8_0 Signed-off-by: Aaron Teo * ggml: s390x SIMD activation for quantize_row_q8_1 Signed-off-by: Aaron Teo * ggml: s390x SIMD activation for iq4_xs Signed-off-by: Aaron Teo * ggml: bugfix iq4_xs Signed-off-by: Aaron Teo * ggml: s390x SIMD activation for iq4_nl Signed-off-by: Aaron Teo * ggml: add float, double, and long vector data type Signed-off-by: Aaron Teo * ggml: clean up iq4_xs SIMD Signed-off-by: Aaron Teo * ggml: fix improper use of restrict keyword Signed-off-by: Aaron Teo * ggml: update warning message for ggml_vec_tbl Signed-off-by: Aaron Teo * ggml: untested implementation of ggml_vec_dot_iq2_xxs_q8_K Signed-off-by: Aaron Teo * ggml: update ggml_vec_dot_q4_1_q8_1 to use typedefs Signed-off-by: Aaron Teo * ggml: switch to restrict for iq4_nl Signed-off-by: Aaron Teo * ggml: slight dot product speed improvement for q4_1_q8_1 Signed-off-by: Aaron Teo * ggml: s390x SIMD activation for q6_K Signed-off-by: Aaron Teo * ggml: add missing `_t` to ggml_int8x16x4_t Signed-off-by: Aaron Teo * ggml: fix missing `_t` for ggml_vec_xl_s8x4 Signed-off-by: Aaron Teo * ggml: fix more missing `_t` Signed-off-by: Aaron Teo * ggml: add unroll and prefetch to Q8_0 increase of 3.86% for prompt processing and 32.22% for token generation Signed-off-by: Aaron Teo * ggml: patch Q8_0 to use proper vector sizes Signed-off-by: Aaron Teo * ggml: optimise Q8_0 dot prod compute kernel further Signed-off-by: Aaron Teo * ggml: add unroll and prefetch to Q4_1 Signed-off-by: Aaron Teo * ggml: refactor Q6_K variable naming for readability Signed-off-by: Aaron Teo * ggml: fix Q6_K typos Signed-off-by: Aaron Teo * ggml: s390x SIMD activation for Q5_K Signed-off-by: Aaron Teo * ggml: fix wrong char*x16_t naming Signed-off-by: Aaron Teo * ggml: Q5_K y0 wrong signness Signed-off-by: Aaron Teo * ggml: fix Q5_K invalid uchar type Signed-off-by: Aaron Teo * ggml: fix Q5_K invalid uchar type Signed-off-by: Aaron Teo * ggml: s390x SIMD activation for Q4_K Signed-off-by: Aaron Teo * ggml: fix Q4_K invalid vector intrinsics Signed-off-by: Aaron Teo * ggml: simplify ggml_padd_s16 compute kernel Signed-off-by: Aaron Teo * ggml: correct ggml-cpu vxe wording Signed-off-by: Aaron Teo * ggml: change ggml_aligned_malloc alignment to 256 256 is the cache line size for s390x platforms Signed-off-by: Aaron Teo * ggml: resolve pr merge via cherry-pick 225bbbf Signed-off-by: Aaron Teo * ggml : fix LoongArch compile error with 128-bit SIMD (llama/11701) * ggml: resolve pr merge via cherry-pick 4571953 Signed-off-by: Aaron Teo * ggml: cmake remove fork when determining s390x machine type thank you @ericcurtin Signed-off-by: Aaron Teo --------- Signed-off-by: Aaron Teo Co-authored-by: Jinyang He Co-authored-by: junchao-zhao <68935141+junchao-loongson@users.noreply.github.com> --- ggml/CMakeLists.txt | 1 + ggml/include/ggml-cpu.h | 1 + ggml/src/ggml-cpu/CMakeLists.txt | 21 ++ ggml/src/ggml-cpu/ggml-cpu-impl.h | 151 ++++++++ ggml/src/ggml-cpu/ggml-cpu-quants.c | 555 +++++++++++++++++++++++++++- ggml/src/ggml-cpu/ggml-cpu.c | 91 +++++ ggml/src/ggml-cpu/ggml-cpu.cpp | 3 + ggml/src/ggml.c | 4 + 8 files changed, 826 insertions(+), 1 deletion(-) diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 12afe0f25a8..68b3f148eaf 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -122,6 +122,7 @@ endif() option(GGML_LASX "ggml: enable lasx" ON) option(GGML_LSX "ggml: enable lsx" ON) option(GGML_RVV "ggml: enable rvv" ON) +option(GGML_VXE "ggml: enable vxe" ON) option(GGML_CPU_ALL_VARIANTS "ggml: build all variants of the CPU backend (requires GGML_BACKEND_DL)" OFF) set(GGML_CPU_ARM_ARCH "" CACHE STRING "ggml: CPU architecture for ARM") diff --git a/ggml/include/ggml-cpu.h b/ggml/include/ggml-cpu.h index 9b8a697546e..b48cc560e52 100644 --- a/ggml/include/ggml-cpu.h +++ b/ggml/include/ggml-cpu.h @@ -99,6 +99,7 @@ extern "C" { // other GGML_BACKEND_API int ggml_cpu_has_riscv_v (void); GGML_BACKEND_API int ggml_cpu_has_vsx (void); + GGML_BACKEND_API int ggml_cpu_has_vxe (void); GGML_BACKEND_API int ggml_cpu_has_wasm_simd (void); GGML_BACKEND_API int ggml_cpu_has_llamafile (void); diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt index 826d65cece0..f8836ed61b9 100644 --- a/ggml/src/ggml-cpu/CMakeLists.txt +++ b/ggml/src/ggml-cpu/CMakeLists.txt @@ -306,6 +306,27 @@ function(ggml_add_cpu_backend_variant_impl tag_name) if (GGML_RVV) list(APPEND ARCH_FLAGS -march=rv64gcv -mabi=lp64d) endif() + elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "s390x") + message(STATUS "s390x detected") + file(READ "/proc/cpuinfo" CPUINFO_CONTENTS) + string(REGEX REPLACE "machine[ \t\r\n]*=[ \t\r\n]*([0-9]+)" "\\1" S390X_M ${CPUINFO_CONTENTS}) + + # TODO: Separation to determine activation of VX/VXE/VXE2 + if (${S390X_M} MATCHES "8561|8562") + message(STATUS "z15 target") + list(APPEND ARCH_FLAGS -march=z15 -mtune=z15) + elseif (${S390X_M} MATCHES "3931") + message(STATUS "z16 target") + list(APPEND ARCH_FLAGS -march=z16 -mtune=z16) + else() + message(STATUS "Unknown target") + message(WARNING "Unknown target. If you are compiling for z14 and earlier, you might have to add -DGGML_VXE=OFF.") + list(APPEND ARCH_FLAGS -march=native -mtune=native) + endif() + + if (GGML_VXE) + list(APPEND ARCH_FLAGS -mvx -mzvector) + endif() else() message(STATUS "Unknown architecture") endif() diff --git a/ggml/src/ggml-cpu/ggml-cpu-impl.h b/ggml/src/ggml-cpu/ggml-cpu-impl.h index 9ddd972a5cf..7f7d210cbe5 100644 --- a/ggml/src/ggml-cpu/ggml-cpu-impl.h +++ b/ggml/src/ggml-cpu/ggml-cpu-impl.h @@ -59,6 +59,15 @@ struct ggml_compute_params { #endif #endif +#if defined(__s390x__) && defined(__VEC__) +#ifndef __VXE__ +#define __VXE__ +#endif +#ifndef __VXE2__ +#define __VXE2__ +#endif +#endif + #if defined(__ARM_FEATURE_SVE) #include #include @@ -359,6 +368,148 @@ inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) #endif #endif +#if defined(__VXE__) || defined(__VXE2__) +#include + +#define vec_neg(a) (-(a)) // Vector Negate +#define vec_add(a, b) ((a) + (b)) // Vector Add +#define vec_sub(a, b) ((a) - (b)) // Vector Subtract +#define vec_mul(a, b) ((a) * (b)) // Vector Multiply +#define vec_div(a, b) ((a) / (b)) // Vector Divide +#define vec_sl(a, b) ((a) << (b)) // Vector Shift Left +#define vec_sra(a, b) ((a) >> (b)) // Vector Shift Right +#define vec_sr(a, b) ((a) >> (b)) // Vector Shift Right Algebraic +#define vec_slo(a, b) vec_slb(a, (b) << 64) // Vector Shift Left by Octet +#define vec_sro(a, b) vec_srb(a, (b) << 64) // Vector Shift Right by Octet + +#ifndef vec_and +#define vec_and(a, b) ((a) & (b)) // Vector AND +#endif + +#ifndef vec_or +#define vec_or(a, b) ((a) | (b)) // Vector OR +#endif + +#ifndef vec_xor +#define vec_xor(a, b) ((a) ^ (b)) // Vector XOR +#endif + +typedef signed char char8x16_t __attribute__((vector_size(16))); +typedef unsigned char uchar8x16_t __attribute__((vector_size(16))); + +typedef int8_t int8x16_t __attribute__((vector_size(16))); +typedef int16_t int16x8_t __attribute__((vector_size(16))); +typedef int32_t int32x4_t __attribute__((vector_size(16))); + +typedef uint8_t uint8x16_t __attribute__((vector_size(16))); +typedef uint16_t uint16x8_t __attribute__((vector_size(16))); +typedef uint32_t uint32x4_t __attribute__((vector_size(16))); + +typedef float float32x4_t __attribute__((vector_size(16))); +typedef double double64x2_t __attribute((vector_size(16))); + +typedef signed long long long64x2_t __attribute((vector_size(16))); +typedef unsigned long long ulong64x2_t __attribute__((vector_size(16))); + +typedef struct ggml_uint8x16x2_t { + uint8x16_t val[2]; +} ggml_uint8x16x2_t; + +inline static ggml_uint8x16x2_t ggml_vec_xl_u8x2(const uint8_t * ptr) { + ggml_uint8x16x2_t res; + + res.val[0] = vec_xl( 0, ptr); + res.val[1] = vec_xl(16, ptr); + + return res; +} + +typedef struct ggml_uint8x16x4_t { + uint8x16_t val[4]; +} ggml_uint8x16x4_t; + +inline static ggml_uint8x16x4_t ggml_vec_xl_u8x4(const uint8_t * ptr) { + ggml_uint8x16x4_t res; + + res.val[0] = vec_xl( 0, ptr); + res.val[1] = vec_xl(16, ptr); + res.val[2] = vec_xl(32, ptr); + res.val[3] = vec_xl(48, ptr); + + return res; +} + +typedef struct ggml_int8x16x4_t { + int8x16_t val[4]; +} ggml_int8x16x4_t; + +inline static ggml_int8x16x4_t ggml_vec_xl_s8x4(const int8_t * ptr) { + ggml_int8x16x4_t res; + + res.val[0] = vec_xl( 0, ptr); + res.val[1] = vec_xl(16, ptr); + res.val[2] = vec_xl(32, ptr); + res.val[3] = vec_xl(48, ptr); + + return res; +} + +typedef struct ggml_int16x8x2_t { + int16x8_t val[2]; +} ggml_int16x8x2_t; + +inline static ggml_int16x8x2_t ggml_vec_xl_s16x2(const int16_t * ptr) { + ggml_int16x8x2_t res; + + res.val[0] = vec_xl( 0, ptr); + res.val[1] = vec_xl(16, ptr); + + return res; +} + +/* + ! WARNING: Very slow. Use vec_perm if possible. Refer to iq4_xs + ! or iq4_nl for example implementation. +*/ +inline static int8x16_t ggml_vec_tbl(int8x16_t a, uint8x16_t b) { + int8x16_t res; + + res[ 0] = a[b[ 0]]; + res[ 1] = a[b[ 1]]; + res[ 2] = a[b[ 2]]; + res[ 3] = a[b[ 3]]; + res[ 4] = a[b[ 4]]; + res[ 5] = a[b[ 5]]; + res[ 6] = a[b[ 6]]; + res[ 7] = a[b[ 7]]; + res[ 8] = a[b[ 8]]; + res[ 9] = a[b[ 9]]; + res[10] = a[b[10]]; + res[11] = a[b[11]]; + res[12] = a[b[12]]; + res[13] = a[b[13]]; + res[14] = a[b[14]]; + res[15] = a[b[15]]; + + return res; +} + +inline static int16x8_t vec_padd_s16(int16x8_t a, int16x8_t b) { + const uchar8x16_t v_maske = { 0, 1, 4, 5, 8, 9, 12, 13, + 16, 17, 20, 21, 24, 25, 28, 29 }; + + const int16x8_t v_abo = vec_pack((int32x4_t)a, (int32x4_t)b); + const int16x8_t v_abe = vec_perm(a, b, v_maske); + return v_abo + v_abe; +} + +inline static int32x4_t ggml_vec_dot(int32x4_t acc, int8x16_t a, int8x16_t b) { + const int16x8_t p = vec_mule(a, b) + vec_mulo(a, b); + return acc + (vec_unpackh(p) + vec_unpackl(p)); +} + +#endif + #if defined(__loongarch_asx) /* float type data load instructions */ static __m128 __lsx_vreplfr2vr_s(const float val) { diff --git a/ggml/src/ggml-cpu/ggml-cpu-quants.c b/ggml/src/ggml-cpu/ggml-cpu-quants.c index 14ba288fe19..d0c407bd6eb 100644 --- a/ggml/src/ggml-cpu/ggml-cpu-quants.c +++ b/ggml/src/ggml-cpu/ggml-cpu-quants.c @@ -1011,6 +1011,38 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k) __lsx_vst(ni4, (__m128i *)(y[i].qs + 16), 0); } +#elif defined(__VXE__) || defined(__VXE2__) + for (int i = 0; i < nb; i++) { + __vector float srcv [8]; + __vector float asrcv[8]; + __vector float amaxv[8]; + + for (int j = 0; j < 8; j++) srcv[j] = vec_xl(0, x + i*32 + 4*j); + for (int j = 0; j < 8; j++) asrcv[j] = vec_abs(srcv[j]); + for (int j = 0; j < 4; j++) amaxv[2*j] = vec_max(asrcv[2*j], asrcv[2*j+1]); + for (int j = 0; j < 2; j++) amaxv[4*j] = vec_max(amaxv[4*j], amaxv[4*j+2]); + for (int j = 0; j < 1; j++) amaxv[8*j] = vec_max(amaxv[8*j], amaxv[8*j+4]); + + const float amax = MAX(MAX(vec_extract(amaxv[0], 0), + vec_extract(amaxv[0], 1)), + MAX(vec_extract(amaxv[0], 2), + vec_extract(amaxv[0], 3))); + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f / d : 0.0f; + + y[i].d = GGML_FP32_TO_FP16(d); + + for (int j = 0; j < 8; j++) { + const __vector float v = vec_mul(srcv[j], vec_splats(id)); + const __vector int32_t vi = vec_signed(v); + + y[i].qs[4*j + 0] = vec_extract(vi, 0); + y[i].qs[4*j + 1] = vec_extract(vi, 1); + y[i].qs[4*j + 2] = vec_extract(vi, 2); + y[i].qs[4*j + 3] = vec_extract(vi, 3); + } + } #else GGML_UNUSED(nb); // scalar @@ -1337,6 +1369,44 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k) __lsx_vst(ni0, (__m128i *)(y[i].qs + 0), 0); __lsx_vst(ni4, (__m128i *)(y[i].qs + 16), 0); } +#elif defined(__VXE__) || defined(__VXE2__) + for (int i = 0; i < nb; i++) { + __vector float srcv [8]; + __vector float asrcv[8]; + __vector float amaxv[8]; + + for (int j = 0; j < 8; j++) srcv[j] = vec_xl(0, x + i*32 + 4*j); + for (int j = 0; j < 8; j++) asrcv[j] = vec_abs(srcv[j]); + for (int j = 0; j < 4; j++) amaxv[2*j] = vec_max(asrcv[2*j], asrcv[2*j+1]); + for (int j = 0; j < 2; j++) amaxv[4*j] = vec_max(amaxv[4*j], amaxv[4*j+2]); + for (int j = 0; j < 1; j++) amaxv[8*j] = vec_max(amaxv[8*j], amaxv[8*j+4]); + + const float amax = MAX(MAX(vec_extract(amaxv[0], 0), + vec_extract(amaxv[0], 1)), + MAX(vec_extract(amaxv[0], 2), + vec_extract(amaxv[0], 3))); + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f / d : 0.0f; + + y[i].d = GGML_FP32_TO_FP16(d); + + __vector int32_t acc = vec_splats(0); + + for (int j = 0; j < 8; j++) { + const __vector float v = vec_mul(srcv[j], vec_splats(id)); + const __vector int32_t vi = vec_signed(v); + + y[i].qs[4*j + 0] = vec_extract(vi, 0); + y[i].qs[4*j + 1] = vec_extract(vi, 1); + y[i].qs[4*j + 2] = vec_extract(vi, 2); + y[i].qs[4*j + 3] = vec_extract(vi, 3); + + acc = vec_add(acc, vi); + } + + y[i].s = GGML_FP32_TO_FP16(d * (acc[0] + acc[1] + acc[2] + acc[3])); + } #else GGML_UNUSED(nb); // scalar @@ -2488,6 +2558,37 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r } sumf = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3); +#elif defined(__VXE__) || defined(__VXE2__) + __vector float acc = vec_splats(0.0f); + + const __vector uint8_t v_m = vec_splats((const uint8_t)0x0F); + const __vector int8_t v_s = vec_splats( (const int8_t)0x08); + + for (; ib < nb; ++ib) { + const __vector uint8_t v_x = vec_xl(0, x[ib].qs); + const __vector int8_t v_xl = (const __vector int8_t)(v_x & v_m); + const __vector int8_t v_xh = (const __vector int8_t)(v_x >> 4); + + const __vector int8_t v_xls = vec_sub(v_xl, v_s); + const __vector int8_t v_xhs = vec_sub(v_xh, v_s); + + const __vector int8_t v_yl = vec_xl(0 , y[ib].qs); + const __vector int8_t v_yh = vec_xl(QK8_0/2, y[ib].qs); + + const __vector int16_t v_xylso = vec_mulo(v_xls, v_yl); + const __vector int16_t v_xylse = vec_mule(v_xls, v_yl); + const __vector int16_t v_xyhso = vec_mulo(v_xhs, v_yh); + const __vector int16_t v_xyhse = vec_mule(v_xhs, v_yh); + + __vector int16_t v_xy_ = v_xylso + v_xylse + v_xyhso + v_xyhse; v_xy_ += vec_reve(v_xy_); + + const __vector float v_xy = vec_float(vec_unpackh(v_xy_)); + const __vector float v_d = vec_splats(GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d)); + + acc = vec_madd(v_xy, v_d, acc); + } + + sumf = acc[0] + acc[1] + acc[2] + acc[3]; #endif for (; ib < nb; ++ib) { int sumi0 = 0; @@ -2781,6 +2882,35 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r } sumf = hsum_float_8(acc) + summs; +#elif defined(__VXE__) || defined(__VXE2__) + float summs = 0; + float32x4_t acc = vec_splats(0.0f); + + const uint8x16_t v_m = vec_splat_u8(0x0F); + +#pragma GCC unroll 4 + for (; ib < nb; ++ib) { + __builtin_prefetch(x[ib].qs, 0, 1); + __builtin_prefetch(y[ib].qs, 0, 1); + + summs += GGML_FP16_TO_FP32(x[ib].m) * GGML_FP16_TO_FP32(y[ib].s); + + const uint8x16_t v_x = vec_xl(0, x[ib].qs); + const int8x16_t v_xl = (const int8x16_t)(v_x & v_m); + const int8x16_t v_xh = (const int8x16_t)(v_x >> 4); + + const int8x16_t v_yl = vec_xl(0 , y[ib].qs); + const int8x16_t v_yh = vec_xl(QK8_1/2, y[ib].qs); + + const int32x4_t v_xy_ = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_xl, v_yl), v_xh, v_yh); + const float32x4_t v_xy = vec_float(v_xy_); + + const float32x4_t v_d = vec_splats(GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d)); + + acc = vec_madd(v_xy, v_d, acc); + } + + sumf = acc[0] + acc[1] + acc[2] + acc[3] + summs; #endif for (; ib < nb; ++ib) { int sumi0 = 0; @@ -3915,6 +4045,27 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r } sumf = hsum_float_8(acc); +#elif defined(__VXE__) || defined(__VXE2__) + __vector float acc = vec_splats(0.0f); + +#pragma GCC unroll 8 + for (; ib < nb; ++ib) { + __builtin_prefetch(x[ib].qs, 0, 1); + __builtin_prefetch(y[ib].qs, 0, 1); + + const int8x16_t v_xl = vec_xl(0 , x[ib].qs); + const int8x16_t v_xh = vec_xl(QK8_0/2, x[ib].qs); + const int8x16_t v_yl = vec_xl(0 , y[ib].qs); + const int8x16_t v_yh = vec_xl(QK8_0/2, y[ib].qs); + + const int32x4_t v_xy_ = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_xl, v_yl), v_xh, v_yh); + const float32x4_t v_xy = vec_float(v_xy_); + const float32x4_t v_d = vec_splats(GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d)); + + acc = vec_madd(v_xy, v_d, acc); + } + + sumf = acc[0] + acc[1] + acc[2] + acc[3]; #endif for (; ib < nb; ++ib) { int sumi = 0; @@ -6797,6 +6948,77 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r *s = hsum_float_8(acc) + ((v4f32)acc_m)[0]; +#elif defined(__VXE__) || defined(__VXE2__) + const uint8x16_t v_lm = vec_splat_u8(0x0F); + const int32x4_t v_z = vec_splat_s32(0); + + uint8x16_t v_x[2]; + int8x16_t v_xl[2]; + int8x16_t v_y[2]; + + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + + const int16x8_t v_ysumsl = vec_xl(0 , y[i].bsums); + const int16x8_t v_ysumsh = vec_xl(16, y[i].bsums); + const int16x8_t v_ysums = vec_padd_s16(v_ysumsl, v_ysumsh); + + memcpy(utmp, x[i].scales, 12); + + uint32x4_t v_mins8 = { 0 }; + v_mins8 = vec_insert(utmp[1] & kmask1, v_mins8, 0); + v_mins8 = vec_insert(((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4), v_mins8, 1); + + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[0] &= kmask1; + + const int16x8_t v_minsh = (int16x8_t)vec_unpackh((uint8x16_t)v_mins8); + + const int32x4_t v_minso = vec_mulo(v_ysums, v_minsh); + const int32x4_t v_minse = vec_mule(v_ysums, v_minsh); + const int32x4_t v_mins = v_minso + v_minse; + sumf -= dmin * (v_mins[0] + v_mins[1] + v_mins[2] + v_mins[3]); + + const uint8_t * scales = (const uint8_t *)utmp; + const uint8_t * restrict x0 = x[i].qs; + const int8_t * restrict y0 = y[i].qs; + + int32_t sumi1 = 0; + int32_t sumi2 = 0; + + for (int j = 0; j < QK_K/64; ++j) { + v_x[0] = vec_xl(0 , x0); + v_x[1] = vec_xl(16, x0); + x0 += 32; + + v_y[0] = vec_xl(0 , y0); + v_y[1] = vec_xl(16, y0); + y0 += 32; + + v_xl[0] = (int8x16_t)vec_and(v_x[0], v_lm); + v_xl[1] = (int8x16_t)vec_and(v_x[1], v_lm); + + const int32x4_t p1 = ggml_vec_dot(ggml_vec_dot(v_z, v_xl[0], v_y[0]), v_xl[1], v_y[1]); + sumi1 += (p1[0] + p1[1] + p1[2] + p1[3]) * scales[2*j+0]; + + v_y[0] = vec_xl(0 , y0); + v_y[1] = vec_xl(16, y0); + y0 += 32; + + v_xl[0] = (int8x16_t)vec_sr(v_x[0], 4); + v_xl[1] = (int8x16_t)vec_sr(v_x[1], 4); + + const int32x4_t p2 = ggml_vec_dot(ggml_vec_dot(v_z, v_xl[0], v_y[0]), v_xl[1], v_y[1]); + sumi2 += (p2[0] + p2[1] + p2[2] + p2[3]) * scales[2*j+1]; + } + + sumf += d * (sumi1 + sumi2); + } + + *s = sumf; #else const uint8_t * scales = (const uint8_t*)&utmp[0]; @@ -7526,7 +7748,94 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r acc_m = __lsx_vfadd_s(acc_m, (__m128)__lsx_vbsrl_v(acc_m, 4)); *s = hsum_float_8(acc) + ((v4f32)acc_m)[0]; +#elif defined(__VXE__) || defined(__VXE2__) + const uint8x16_t v_lm = vec_splat_u8(0x0F); + const uint8x16_t v_1m = vec_splat_u8(0x01); + const uint8x16_t v_2m = vec_splat_u8(0x02); + + const int32x4_t v_z = vec_splat_s32(0); + + const uchar8x16_t v_minsm = { + 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF + }; + + int8x16_t q5b[4]; + uint8x16_t q5h[4]; + + uint8x16_t v_xl[2]; + uint8x16_t v_xh[2]; + int8x16_t v_y[4]; + + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + + const int16x8_t v_ysumsl = vec_xl(0 , y[i].bsums); + const int16x8_t v_ysumsh = vec_xl(16, y[i].bsums); + const int16x8_t v_ysums = vec_padd_s16(v_ysumsl, v_ysumsh); + + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + const uint8x16_t v_mins16 = vec_xl(0, (const uint8_t *)utmp); + const uint8x16_t v_mins8 = vec_perm(v_mins16, v_mins16, v_minsm); + const int16x8_t v_minsh = (int16x8_t)vec_unpackh(v_mins8); + + const int32x4_t v_minsho = vec_mulo(v_ysums, v_minsh); + const int32x4_t v_minshe = vec_mule(v_ysums, v_minsh); + const int32x4_t v_mins = vec_add(v_minsho, v_minshe); + const int32_t mins = v_mins[0] + v_mins[1] + v_mins[2] + v_mins[3]; + + const uint8_t * scales = (const uint8_t *)utmp; + const uint8_t * restrict x0l = x[i].qs; + const uint8_t * restrict x0h = x[i].qh; + const int8_t * restrict y0 = y[i].qs; + + v_xh[0] = vec_xl(0 , x0h); + v_xh[1] = vec_xl(16, x0h); + + int32_t sumi = 0; + for (int j = 0; j < QK_K/64; ++j) { + v_xl[0] = vec_xl(0 , x0l); + v_xl[1] = vec_xl(16, x0l); + x0l += 32; + + v_y[0] = vec_xl(0 , y0); + v_y[1] = vec_xl(16, y0); + v_y[2] = vec_xl(32, y0); + v_y[3] = vec_xl(48, y0); + y0 += 64; + q5h[0] = vec_sl(vec_and(v_1m, v_xh[0]), 4); + q5h[1] = vec_sl(vec_and(v_1m, v_xh[1]), 4); + q5h[2] = vec_sl(vec_and(v_2m, v_xh[0]), 3); + q5h[3] = vec_sl(vec_and(v_2m, v_xh[1]), 3); + v_xh[0] = vec_sr(v_xh[0], 2); + v_xh[1] = vec_sr(v_xh[1], 2); + + q5b[0] = (int8x16_t)vec_or(vec_and(v_xl[0], v_lm), q5h[0]); + q5b[1] = (int8x16_t)vec_or(vec_and(v_xl[1], v_lm), q5h[1]); + q5b[2] = (int8x16_t)vec_or(vec_sr(v_xl[0], 4), q5h[2]); + q5b[3] = (int8x16_t)vec_or(vec_sr(v_xl[1], 4), q5h[3]); + + int32x4_t sumi0 = ggml_vec_dot(ggml_vec_dot(v_z, q5b[0], v_y[0]), q5b[1], v_y[1]); + int32x4_t sumi1 = ggml_vec_dot(ggml_vec_dot(v_z, q5b[2], v_y[2]), q5b[3], v_y[3]); + + sumi += (sumi0[0] + sumi0[1] + sumi0[2] + sumi0[3]) * *scales++; + sumi += (sumi1[0] + sumi1[1] + sumi1[2] + sumi1[3]) * *scales++; + } + + sumf += d * sumi - dmin * mins; + } + + *s = sumf; #else const uint8_t * scales = (const uint8_t*)&utmp[0]; @@ -8243,7 +8552,130 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r } *s = hsum_float_8(acc); +#elif defined(__VXE__) || defined(__VXE2__) + float sum = 0; + + // Lower 4-bit and upper 2-bit masks + const uint8x16_t v_lm = vec_splat_u8(0x0F); + const uint8x16_t v_um = vec_splat_u8(0x03); + + const int32x4_t v_z = vec_splat_s32(0); + + int8x16_t q6b[4]; + uint8x16_t q6h[4]; + + uint8x16_t v_xl[4]; + uint8x16_t v_xh[2]; + int8x16_t v_y[4]; + + for (int i = 0; i < nb; ++i) { + const float d_all = GGML_FP16_TO_FP32(x[i].d); + + const uint8_t * restrict x0l = x[i].ql; + const uint8_t * restrict x0h = x[i].qh; + const int8_t * restrict y0 = y[i].qs; + + const int8_t * restrict scale = x[i].scales; + + const int16x8_t v_ysumsl = vec_xl(0 , y[i].bsums); + const int16x8_t v_ysumsh = vec_xl(16, y[i].bsums); + + const int8x16_t v_scale = vec_xl(0, scale); + const int16x8_t v_scalel = vec_unpackh(v_scale); + const int16x8_t v_scaleh = vec_unpackl(v_scale); + + const int32x4_t v_minslo = vec_mulo(v_ysumsl, v_scalel); + const int32x4_t v_minsle = vec_mule(v_ysumsl, v_scalel); + const int32x4_t v_minsho = vec_mulo(v_ysumsh, v_scaleh); + const int32x4_t v_minshe = vec_mule(v_ysumsh, v_scaleh); + const int32x4_t v_mins = v_minslo + v_minsle + v_minsho + v_minshe; + + const int32_t mins = v_mins[0] + v_mins[1] + v_mins[2] + v_mins[3]; + + int32_t isum = 0; + for (int j = 0; j < QK_K/128; ++j) { + // Load model upper 2 bits + v_xh[0] = vec_xl(0 , x0h); + v_xh[1] = vec_xl(16, x0h); + x0h += 32; + + // Load model lower 4 bits + v_xl[0] = vec_xl(0 , x0l); + v_xl[1] = vec_xl(16, x0l); + v_xl[2] = vec_xl(32, x0l); + v_xl[3] = vec_xl(48, x0l); + x0l += 64; + + // Load activation quants + v_y[0] = vec_xl(0 , y0); + v_y[1] = vec_xl(16, y0); + v_y[2] = vec_xl(32, y0); + v_y[3] = vec_xl(48, y0); + y0 += 64; + + q6h[0] = vec_sl(vec_and(v_um, v_xh[0]), 4); + q6h[1] = vec_sl(vec_and(v_um, v_xh[1]), 4); + uint8x16_t shifted = vec_sr(v_xh[0], 2); + q6h[2] = vec_sl(vec_and(v_um, shifted), 4); + shifted = vec_sr(v_xh[1], 2); + q6h[3] = vec_sl(vec_and(v_um, shifted), 4); + + q6b[0] = (int8x16_t)(vec_or(vec_and(v_xl[0], v_lm), q6h[0])); + q6b[1] = (int8x16_t)(vec_or(vec_and(v_xl[1], v_lm), q6h[1])); + q6b[2] = (int8x16_t)(vec_or(vec_and(v_xl[2], v_lm), q6h[2])); + q6b[3] = (int8x16_t)(vec_or(vec_and(v_xl[3], v_lm), q6h[3])); + + int32x4_t summs0 = ggml_vec_dot(v_z, q6b[0], v_y[0]); + int32x4_t summs1 = ggml_vec_dot(v_z, q6b[1], v_y[1]); + int32x4_t summs2 = ggml_vec_dot(v_z, q6b[2], v_y[2]); + int32x4_t summs3 = ggml_vec_dot(v_z, q6b[3], v_y[3]); + + isum += (summs0[0] + summs0[1] + summs0[2] + summs0[3]) * scale[0] + + (summs1[0] + summs1[1] + summs1[2] + summs1[3]) * scale[1] + + (summs2[0] + summs2[1] + summs2[2] + summs2[3]) * scale[2] + + (summs3[0] + summs3[1] + summs3[2] + summs3[3]) * scale[3]; + + scale += 4; + + // Load activation quants + v_y[0] = vec_xl(0 , y0); + v_y[1] = vec_xl(16, y0); + v_y[2] = vec_xl(32, y0); + v_y[3] = vec_xl(48, y0); + y0 += 64; + + shifted = vec_sr(v_xh[0], 4); + q6h[0] = vec_sl(vec_and(v_um, shifted), 4); + shifted = vec_sr(v_xh[1], 4); + q6h[1] = vec_sl(vec_and(v_um, shifted), 4); + shifted = vec_sr(v_xh[0], 6); + q6h[2] = vec_sl(vec_and(v_um, shifted), 4); + shifted = vec_sr(v_xh[1], 6); + q6h[3] = vec_sl(vec_and(v_um, shifted), 4); + + q6b[0] = (int8x16_t)(vec_or(vec_sr(v_xl[0], 4), q6h[0])); + q6b[1] = (int8x16_t)(vec_or(vec_sr(v_xl[1], 4), q6h[1])); + q6b[2] = (int8x16_t)(vec_or(vec_sr(v_xl[2], 4), q6h[2])); + q6b[3] = (int8x16_t)(vec_or(vec_sr(v_xl[3], 4), q6h[3])); + + summs0 = ggml_vec_dot(v_z, q6b[0], v_y[0]); + summs1 = ggml_vec_dot(v_z, q6b[1], v_y[1]); + summs2 = ggml_vec_dot(v_z, q6b[2], v_y[2]); + summs3 = ggml_vec_dot(v_z, q6b[3], v_y[3]); + + isum += (summs0[0] + summs0[1] + summs0[2] + summs0[3]) * scale[0] + + (summs1[0] + summs1[1] + summs1[2] + summs1[3]) * scale[1] + + (summs2[0] + summs2[1] + summs2[2] + summs2[3]) * scale[2] + + (summs3[0] + summs3[1] + summs3[2] + summs3[3]) * scale[3]; + + scale += 4; + } + + sum += d_all * y[i].d * (isum - 32 * mins); + } + + *s = sum; #else int8_t aux8[QK_K]; @@ -8604,7 +9036,57 @@ void ggml_vec_dot_iq2_xxs_q8_K(int n, float * restrict s, size_t bs, const void } *s = 0.125f * hsum_float_8(accumf); - +//#elif defined(__VXE__) || defined(__VXE2__) +// const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; +// +// uint32_t aux32[4]; +// const uint8_t * aux8 = (const uint8_t *)aux32; +// +// float sumf = 0; +// +// for (int i = 0; i < nb; ++i) { +// const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; +// const uint16_t * restrict q2 = x[i].qs; +// const int8_t * restrict q8 = y[i].qs; +// +// float sumf1 = 0, sumf2 = 0; +// +// for (int ib32 = 0; ib32 < QK_K/32; ib += 2) { +// int8x16_t q8b0 = vec_xl( 0, q8); +// int8x16_t qb81 = vec_xl(16, q8); +// int8x16_t q8b2 = vec_xl(32, q8); +// int8x16_t q8b3 = vec_xl(48, q8); +// q8 += 64; +// +// memcpy(aux32, q2, 4 * sizeof(uint32_t)); +// q2 += 8; +// +// int8x16_t q2u0 = { *(const int64_t *)(iq2xxs_grid + aux8[ 0]), *(const int64_t *)(iq2xxs_grid + aux8[ 1]) }; +// int8x16_t q2u1 = { *(const int64_t *)(iq2xxs_grid + aux8[ 2]), *(const int64_t *)(iq2xxs_grid + aux8[ 3]) }; +// int8x16_t q2u2 = { *(const int64_t *)(iq2xxs_grid + aux8[ 8]), *(const int64_t *)(iq2xxs_grid + aux8[ 9]) }; +// int8x16_t q2u3 = { *(const int64_t *)(iq2xxs_grid + aux8[10]), *(const int64_t *)(iq2xxs_grid + aux8[11]) }; +// +// int8x16_t q2s0 = { *(const int64_t *)(signs64 + ((aux32[1] >> 0) & 127)), *(const int64_t *)(signs64 + ((aux32[1] >> 7) & 127)) }; +// int8x16_t q2s1 = { *(const int64_t *)(signs64 + ((aux32[1] >> 14) & 127)), *(const int64_t *)(signs64 + ((aux32[1] >> 21) & 127)) }; +// int8x16_t q2s2 = { *(const int64_t *)(signs64 + ((aux32[3] >> 0) & 127)), *(const int64_t *)(signs64 + ((aux32[3] >> 7) & 127)) }; +// int8x16_t q2s3 = { *(const int64_t *)(signs64 + ((aux32[3] >> 14) & 127)), *(const int64_t *)(signs64 + ((aux32[3] >> 21) & 127)) }; +// +// q2u0 = vec_mul(q2u0, q2s0); +// q2u1 = vec_mul(q2u1, q2s1); +// q2u2 = vec_mul(q2u2, q2s2); +// q2u3 = vec_mul(q2u3, q2s3); +// +// const int32x4_t p1 = ggml_vec_dot(ggml_vec_dot(vec_splat_s32(0), q2u0, q8b0), q2u1, q8b1); +// const int32x4_t p2 = ggml_vec_dot(ggml_vec_dot(vec_splat_s32(0), q2u2, q8b2), q2u3, q8b3); +// +// sumf1 += (p1[0] + p1[1] + p1[2] + p1[3]) * (0.5f + (aux32[1] >> 28)); +// sumf2 += (p2[0] + p2[1] + p2[2] + p2[3]) * (0.5f + (aux32[3] >> 28)); +// } +// +// sumf += d * (sumf1 + sumf2); +// } +// +// *s = 0.25f * sumf; #else uint32_t aux32[2]; @@ -11365,6 +11847,27 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void * sumf = hsum_float_8(__lasx_xvfadd_s(accum1, accum2)); +#elif defined(__VXE__) || defined(__VXE2__) + const int8x16_t v_k = vec_xl(0, kvalues_iq4nl); + const uint8x16_t v_m = vec_splat_u8(0x0F); + + for (; ib < nb; ++ib) { + const block_iq4_nl * restrict x0 = &x[ib]; + const block_q8_0 * restrict y0 = &y[ib]; + + const uint8x16_t v_x = vec_xl(0, x0->qs); + int8x16_t v_xl = (int8x16_t)vec_and(v_x, v_m); + int8x16_t v_xh = (int8x16_t)vec_sr(v_x, 4); + + v_xl = vec_perm(v_k, v_k, (uchar8x16_t)v_xl); + v_xh = vec_perm(v_k, v_k, (uchar8x16_t)v_xh); + + const int8x16_t v_yl = vec_xl(0 , y0->qs); + const int8x16_t v_yh = vec_xl(QK8_0/2, y0->qs); + const int32x4_t v_xy = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_xl, v_yl), v_xh, v_yh); + + sumf += GGML_FP16_TO_FP32(x0->d) * GGML_FP16_TO_FP32(y0->d) * (v_xy[0] + v_xy[1] + v_xy[2] + v_xy[3]); + } #endif for (; ib < nb; ++ib) { const float d = GGML_FP16_TO_FP32(y[ib].d)*GGML_FP16_TO_FP32(x[ib].d); @@ -11643,6 +12146,56 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * restrict s, size_t bs, const void * } *s = hsum_float_8(accum); +#elif defined(__VXE__) || defined(__VXE2__) + const int8x16_t v_k = vec_xl(0, kvalues_iq4nl); + const uint8x16_t v_m = vec_splat_u8(0x0F); + + float sumf = 0; + + for (int ibl = 0; ibl < nb; ++ibl) { + const uint8_t * restrict q4 = x[ibl].qs; + const int8_t * restrict q8 = y[ibl].qs; + + uint16_t h = x[ibl].scales_h; + + int sumi1 = 0, sumi2 = 0; + for (int ib = 0; ib < QK_K/64; ++ib) { + const uint8x16_t v_x0 = vec_xl(0 , q4); + const uint8x16_t v_x1 = vec_xl(QK4_NL/2, q4); + q4 += 32; + + int8x16_t v_x0l = (int8x16_t)vec_and(v_x0, v_m); + int8x16_t v_x0h = (int8x16_t)vec_sr(v_x0, 4); + int8x16_t v_x1l = (int8x16_t)vec_and(v_x1, v_m); + int8x16_t v_x1h = (int8x16_t)vec_sr(v_x1, 4); + + v_x0l = vec_perm(v_k, v_k, (uchar8x16_t)v_x0l); + v_x0h = vec_perm(v_k, v_k, (uchar8x16_t)v_x0h); + v_x1l = vec_perm(v_k, v_k, (uchar8x16_t)v_x1l); + v_x1h = vec_perm(v_k, v_k, (uchar8x16_t)v_x1h); + + const int8x16_t v_y0 = vec_xl( 0, q8); + const int8x16_t v_y1 = vec_xl(16, q8); + const int8x16_t v_y2 = vec_xl(32, q8); + const int8x16_t v_y3 = vec_xl(48, q8); + q8 += 64; + + int32x4_t vsumi0 = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_x0l, v_y0), v_x0h, v_y1); + int32x4_t vsumi1 = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_x1l, v_y2), v_x1h, v_y3); + + int ls1 = ((x[ibl].scales_l[ib] & 0xF) | ((h << 4) & 0x30)) - 32; + int ls2 = ((x[ibl].scales_l[ib] >> 4) | ((h << 2) & 0x30)) - 32; + + h >>= 4; + + sumi1 += (vsumi0[0] + vsumi0[1] + vsumi0[2] + vsumi0[3]) * ls1; + sumi2 += (vsumi1[0] + vsumi1[1] + vsumi1[2] + vsumi1[3]) * ls2; + } + + sumf += GGML_FP16_TO_FP32(x[ibl].d) * y[ibl].d * (sumi1 + sumi2); + } + + *s = sumf; #else float sumf = 0; diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index f27b981715a..723253495a7 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -237,6 +237,8 @@ typedef pthread_t ggml_thread_t; #else #if defined(__POWER9_VECTOR__) #define CACHE_LINE_SIZE 128 +#elif defined(__VXE__) || defined(__VXE2__) +#define CACHE_LINE_SIZE 256 #else #define CACHE_LINE_SIZE 64 #endif @@ -1211,6 +1213,87 @@ static inline void __lsx_f16x4_store(ggml_fp16_t * x, __m128 y) { #define GGML_F16_VEC_MUL GGML_F32Cx4_MUL #define GGML_F16_VEC_REDUCE GGML_F32Cx4_REDUCE +#elif defined(__VXE__) || defined(__VXE2__) + +#define GGML_SIMD + +// F32 s390x + +#define GGML_F32_STEP 32 +#define GGML_F32_EPR 4 + +#define GGML_F32x4 __vector float +#define GGML_F32x4_ZERO vec_splats(0.0f) +#define GGML_F32x4_SET1 vec_splats +#define GGML_F32x4_LOAD(p) vec_xl(0, p) +#define GGML_F32x4_STORE(p, r) vec_xst(r, 0, p) +#define GGML_F32x4_FMA(a, b, c) vec_madd(b, c, a) +#define GGML_F32x4_ADD vec_add +#define GGML_F32x4_MUL vec_mul +#define GGML_F32x4_REDUCE(res, x) \ +{ \ + int offset = GGML_F32_ARR >> 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = vec_add(x[i], x[offset + i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = vec_add(x[i], x[offset + i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = vec_add(x[i], x[offset + i]); \ + } \ + res = vec_extract(x[0], 0) + \ + vec_extract(x[0], 1) + \ + vec_extract(x[0], 2) + \ + vec_extract(x[0], 3); \ +} + +#define GGML_F32_VEC GGML_F32x4 +#define GGML_F32_VEC_ZERO GGML_F32x4_ZERO +#define GGML_F32_VEC_SET1 GGML_F32x4_SET1 +#define GGML_F32_VEC_LOAD GGML_F32x4_LOAD +#define GGML_F32_VEC_STORE GGML_F32x4_STORE +#define GGML_F32_VEC_FMA GGML_F32x4_FMA +#define GGML_F32_VEC_ADD GGML_F32x4_ADD +#define GGML_F32_VEC_MUL GGML_F32x4_MUL +#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE + +// F16 s390x +#define GGML_F16_STEP GGML_F32_STEP +#define GGML_F16_EPR GGML_F32_EPR + +static inline __vector float __lzs_f16cx4_load(const ggml_fp16_t * x) { + float tmp[4]; + + for (int i = 0; i < 4; i++) { + tmp[i] = GGML_FP16_TO_FP32(x[i]); + } + + return vec_xl(0, tmp); +} + +static inline void __lzs_f16cx4_store(ggml_fp16_t * x, __vector float y) { + float arr[4]; + + vec_xst(y, 0, arr); + + for (int i = 0; i < 4; i++) { + x[i] = GGML_FP32_TO_FP16(arr[i]); + } +} + +#define GGML_F16_VEC GGML_F32x4 +#define GGML_F16_VEC_ZERO GGML_F32x4_ZERO +#define GGML_F16_VEC_SET1 GGML_F32x4_SET1 +#define GGML_F16_VEC_LOAD(p, i) __lzs_f16cx4_load(p) +#define GGML_F16_VEC_STORE(p, r, i) __lzs_f16cx4_store(p, r[i]) +#define GGML_F16_VEC_FMA GGML_F32x4_FMA +#define GGML_F16_VEC_ADD GGML_F32x4_ADD +#define GGML_F16_VEC_MUL GGML_F32x4_MUL +#define GGML_F16_VEC_REDUCE GGML_F32x4_REDUCE + #endif // GGML_F32_ARR / GGML_F16_ARR @@ -14419,6 +14502,14 @@ int ggml_cpu_has_vsx(void) { #endif } +int ggml_cpu_has_vxe(void) { +#if defined(__VXE__) || defined(__VXE2__) + return 1; +#else + return 0; +#endif +} + int ggml_cpu_has_neon(void) { #if defined(__ARM_ARCH) && defined(__ARM_NEON) return ggml_arm_arch_features.has_neon; diff --git a/ggml/src/ggml-cpu/ggml-cpu.cpp b/ggml/src/ggml-cpu/ggml-cpu.cpp index d0ae10ee376..a84203f29f2 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.cpp +++ b/ggml/src/ggml-cpu/ggml-cpu.cpp @@ -557,6 +557,9 @@ static ggml_backend_feature * ggml_backend_cpu_get_features(ggml_backend_reg_t r if (ggml_cpu_has_vsx()) { features.push_back({ "VSX", "1" }); } + if (ggml_cpu_has_vxe()) { + features.push_back({ "VXE", "1" }); + } if (ggml_cpu_has_wasm_simd()) { features.push_back({ "WASM_SIMD", "1" }); } diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index e9f3420c294..7fc06724ebd 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -240,7 +240,11 @@ void ggml_log_callback_default(enum ggml_log_level level, const char * text, voi void * ggml_aligned_malloc(size_t size) { +#if defined(__s390x__) + const int alignment = 256; +#else const int alignment = 64; +#endif #if defined(_MSC_VER) || defined(__MINGW32__) return _aligned_malloc(size, alignment); From 56d1ec4b61b273cee93e4063feb0753f46e47451 Mon Sep 17 00:00:00 2001 From: Akarshan Biswas Date: Mon, 24 Feb 2025 15:48:25 +0530 Subject: [PATCH 53/58] SYCL: Fix GGML_SYCL_DEBUG macro (llama/11995) --- ggml/src/ggml-sycl/common.hpp | 2 +- ggml/src/ggml-sycl/ggml-sycl.cpp | 3 ++- ggml/src/ggml-sycl/softmax.cpp | 3 +++ 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp index abad847ca81..a5cab5065fc 100644 --- a/ggml/src/ggml-sycl/common.hpp +++ b/ggml/src/ggml-sycl/common.hpp @@ -35,7 +35,7 @@ void* ggml_sycl_host_malloc(size_t size); void ggml_sycl_host_free(void* ptr); -static int g_ggml_sycl_debug = 0; +extern int g_ggml_sycl_debug; #define GGML_SYCL_DEBUG(...) \ do { \ if (g_ggml_sycl_debug) \ diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 3d24d216548..d4c97ad17b8 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -41,6 +41,7 @@ #include "ggml-sycl/gemm.hpp" static bool g_sycl_loaded = false; +int g_ggml_sycl_debug = 0; static ggml_sycl_device_info ggml_sycl_init() { ggml_sycl_device_info info = {}; @@ -157,8 +158,8 @@ static void ggml_check_sycl() try { static bool initialized = false; if (!initialized) { - GGML_SYCL_DEBUG("[SYCL] call ggml_check_sycl\n"); g_ggml_sycl_debug = get_sycl_env("GGML_SYCL_DEBUG", 0); + GGML_SYCL_DEBUG("[SYCL] call ggml_check_sycl\n"); GGML_LOG_INFO("GGML_SYCL_DEBUG: %d\n", g_ggml_sycl_debug); #if defined(GGML_SYCL_FORCE_MMQ) GGML_LOG_INFO("GGML_SYCL_FORCE_MMQ: yes\n"); diff --git a/ggml/src/ggml-sycl/softmax.cpp b/ggml/src/ggml-sycl/softmax.cpp index 563e0655f55..eb20bd251e1 100644 --- a/ggml/src/ggml-sycl/softmax.cpp +++ b/ggml/src/ggml-sycl/softmax.cpp @@ -249,13 +249,16 @@ void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { if (dst->src[1] && dst->src[1]->type == GGML_TYPE_F16) { const sycl::half * src1_dd = static_cast(dst->src[1]->data); + GGML_SYCL_DEBUG("%s: F16 mask\n", __func__); soft_max_f32_sycl(src0_dd, src1_dd, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device); } else if (dst->src[1] && dst->src[1]->type == GGML_TYPE_F32) { const float * src1_dd = static_cast(dst->src[1]->data); + GGML_SYCL_DEBUG("%s: F32 mask\n", __func__); soft_max_f32_sycl(src0_dd, src1_dd, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device); } else { /* mask unavailable */ + GGML_SYCL_DEBUG("%s: No mask\n", __func__); soft_max_f32_sycl(src0_dd, nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device); } } From 8ba9b117d42a87b80796b50c07497dc071dbc84f Mon Sep 17 00:00:00 2001 From: Neo Zhang Jianyu Date: Mon, 24 Feb 2025 22:33:23 +0800 Subject: [PATCH 54/58] Optimize mul_mat for Q4_0 on Intel GPU (llama/12035) * opt performance by reorder for Intel GPU * detect hw type and save opt feature, and print opt feature * correct name * support optimize graph once when compute graph, record the opt status in tensor->extra, make CI passed * add env variable GGML_SYCL_DISABLE_OPT for debug * use syclex::architecture replace the custom hw define, update the guide for GGML_SYCL_DISABLE_OPT * add performance data * mv getrows functions to separeted files * fix global variables --------- Co-authored-by: arthw <14088817+arthw@users.noreply.github.com> --- ggml/src/ggml-sycl/CMakeLists.txt | 2 + ggml/src/ggml-sycl/common.cpp | 17 ++ ggml/src/ggml-sycl/common.hpp | 58 ++++- ggml/src/ggml-sycl/convert.cpp | 37 ++- ggml/src/ggml-sycl/convert.hpp | 4 +- ggml/src/ggml-sycl/dequantize.hpp | 55 +++++ ggml/src/ggml-sycl/dmmv.cpp | 140 +++++++++++- ggml/src/ggml-sycl/getrows.cpp | 308 +++++++++++++++++++++++++ ggml/src/ggml-sycl/getrows.hpp | 23 ++ ggml/src/ggml-sycl/ggml-sycl.cpp | 369 ++++++++++-------------------- ggml/src/ggml-sycl/sycl_hw.cpp | 13 ++ ggml/src/ggml-sycl/sycl_hw.hpp | 23 ++ 12 files changed, 787 insertions(+), 262 deletions(-) create mode 100644 ggml/src/ggml-sycl/getrows.cpp create mode 100644 ggml/src/ggml-sycl/getrows.hpp create mode 100644 ggml/src/ggml-sycl/sycl_hw.cpp create mode 100644 ggml/src/ggml-sycl/sycl_hw.hpp diff --git a/ggml/src/ggml-sycl/CMakeLists.txt b/ggml/src/ggml-sycl/CMakeLists.txt index 3579a311aac..3ad044432a2 100644 --- a/ggml/src/ggml-sycl/CMakeLists.txt +++ b/ggml/src/ggml-sycl/CMakeLists.txt @@ -1,3 +1,5 @@ +message(STATUS "GGML_SYCL_TARGET=${GGML_SYCL_TARGET}") + if (NOT GGML_SYCL_TARGET MATCHES "^(INTEL|NVIDIA|AMD)$") message(FATAL_ERROR "Invalid backend chosen, supported options are INTEL, NVIDIA, or AMD") endif() diff --git a/ggml/src/ggml-sycl/common.cpp b/ggml/src/ggml-sycl/common.cpp index 022e7b7637b..9069c47865f 100644 --- a/ggml/src/ggml-sycl/common.cpp +++ b/ggml/src/ggml-sycl/common.cpp @@ -99,3 +99,20 @@ catch (sycl::exception const &exc) { << ", line:" << __LINE__ << std::endl; std::exit(1); } + + +void release_extra_gpu(ggml_tensor_extra_gpu * extra, std::vector streams) { + for (int i = 0; i < ggml_sycl_info().device_count; ++i) { + for (int64_t is = 0; is < GGML_SYCL_MAX_STREAMS; ++is) { + if (extra->events[i][is] != nullptr) { + SYCL_CHECK(CHECK_TRY_ERROR(dpct::destroy_event(extra->events[i][is]))); + } + } + if (extra->data_device[i] != nullptr && streams.size()>0) { + ggml_sycl_set_device(i); + SYCL_CHECK( + CHECK_TRY_ERROR(sycl::free(extra->data_device[i], *(streams[i])))); + } + } + delete extra; +} diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp index a5cab5065fc..7c503a1b10e 100644 --- a/ggml/src/ggml-sycl/common.hpp +++ b/ggml/src/ggml-sycl/common.hpp @@ -19,6 +19,9 @@ #include "dpct/helper.hpp" #include "ggml-sycl.h" #include "presets.hpp" +#include "sycl_hw.hpp" + + #if GGML_SYCL_DNNL #include "dnnl.hpp" #include "dnnl_sycl.hpp" @@ -35,7 +38,10 @@ void* ggml_sycl_host_malloc(size_t size); void ggml_sycl_host_free(void* ptr); + extern int g_ggml_sycl_debug; +extern int g_ggml_sycl_disable_optimize; + #define GGML_SYCL_DEBUG(...) \ do { \ if (g_ggml_sycl_debug) \ @@ -182,18 +188,24 @@ inline dpct::err0 ggml_sycl_set_device(const int device) try { } ////////////////////// +struct optimize_feature { + bool reorder=false; +}; + +struct sycl_device_info { + int cc; // compute capability + // int nsm; // number of streaming multiprocessors + // size_t smpb; // max. shared memory per block + bool vmm; // virtual memory support + size_t total_vram; + sycl_hw_info hw_info; + optimize_feature opt_feature; +}; + struct ggml_sycl_device_info { int device_count; - struct sycl_device_info { - int cc; // compute capability - // int nsm; // number of streaming multiprocessors - // size_t smpb; // max. shared memory per block - bool vmm; // virtual memory support - size_t total_vram; - }; - sycl_device_info devices[GGML_SYCL_MAX_DEVICES] = {}; std::array default_tensor_split = {}; @@ -260,17 +272,46 @@ struct ggml_tensor_extra_gpu { // tensors dpct::event_ptr events[GGML_SYCL_MAX_DEVICES] [GGML_SYCL_MAX_STREAMS]; // events for synchronizing multiple GPUs + optimize_feature optimized_feature; }; +void release_extra_gpu(ggml_tensor_extra_gpu * extra, std::vector streams={}); + +inline optimize_feature check_gpu_optimize_feature(syclex::architecture &arch) { + optimize_feature opt; + + opt.reorder = + (arch == syclex::architecture::intel_gpu_dg1 || + arch == syclex::architecture::intel_gpu_acm_g10 || + arch == syclex::architecture::intel_gpu_acm_g11 || + arch == syclex::architecture::intel_gpu_acm_g12 || + arch == syclex::architecture::intel_gpu_pvc || + arch == syclex::architecture::intel_gpu_pvc_vg || + arch == syclex::architecture::intel_gpu_mtl_u || + arch == syclex::architecture::intel_gpu_mtl_s || + arch == syclex::architecture::intel_gpu_mtl_h || + arch == syclex::architecture::intel_gpu_arl_u || + arch == syclex::architecture::intel_gpu_arl_s || + arch == syclex::architecture::intel_gpu_arl_h || + arch == syclex::architecture::intel_gpu_bmg_g21 || + arch == syclex::architecture::intel_gpu_lnl_m + ); + + return opt; +} + struct ggml_backend_sycl_context { int device; std::string name; + optimize_feature opt_feature; + bool optimized_graph=false; queue_ptr qptrs[GGML_SYCL_MAX_DEVICES][GGML_SYCL_MAX_STREAMS] = { { nullptr } }; explicit ggml_backend_sycl_context(int device) : device(device), name(GGML_SYCL_NAME + std::to_string(device)) { + opt_feature = ggml_sycl_info().devices[device].opt_feature; } queue_ptr stream(int device, int stream) { @@ -680,5 +721,4 @@ bool gpu_has_xmx(sycl::device &dev); void ggml_sycl_op_flatten(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const ggml_sycl_op_flatten_t op); - #endif // GGML_SYCL_COMMON_HPP diff --git a/ggml/src/ggml-sycl/convert.cpp b/ggml/src/ggml-sycl/convert.cpp index 05b01db2d8b..86b200e0703 100644 --- a/ggml/src/ggml-sycl/convert.cpp +++ b/ggml/src/ggml-sycl/convert.cpp @@ -125,6 +125,25 @@ static void dequantize_row_q4_0_sycl(const void *vx, dst_t *y, const int64_t k, } } +template +static void dequantize_row_q4_0_sycl_reorder(const void *vx, dst_t *y, const int64_t k, + dpct::queue_ptr stream) { + + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + int constexpr WARP_K = WARP_SIZE * QK4_0; + const int n_warp = (k + WARP_K - 1) / WARP_K; + GGML_ASSERT(k % 2 == 0); + stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, n_warp) * + sycl::range<3>(1, 1, WARP_SIZE), + sycl::range<3>(1, 1, WARP_SIZE)), + [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]]{ + dequantize_block_q4_0_reorder(vx, y, k, item_ct1); + }); + +} + template static void dequantize_row_q4_1_sycl(const void *vx, dst_t *y, const int64_t k, dpct::queue_ptr stream) { @@ -452,10 +471,15 @@ static void convert_unary_sycl(const void *__restrict__ vx, } } -to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type) { +to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor *dst) { switch (type) { case GGML_TYPE_Q4_0: - return dequantize_block_sycl; + if (dst->src[0]->extra && + ((ggml_tensor_extra_gpu*)dst->src[0]->extra)->optimized_feature.reorder) { + return dequantize_row_q4_0_sycl_reorder; + } else { + return dequantize_block_sycl; + } case GGML_TYPE_Q4_1: return dequantize_block_sycl; case GGML_TYPE_Q5_0: @@ -499,10 +523,15 @@ to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type) { } } -to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type) { +to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) { switch (type) { case GGML_TYPE_Q4_0: - return dequantize_row_q4_0_sycl; + if (dst->src[0]->extra && + ((ggml_tensor_extra_gpu*)dst->src[0]->extra)->optimized_feature.reorder) { + return dequantize_row_q4_0_sycl_reorder; + } else { + return dequantize_row_q4_0_sycl; + } case GGML_TYPE_Q4_1: return dequantize_row_q4_1_sycl; case GGML_TYPE_Q5_0: diff --git a/ggml/src/ggml-sycl/convert.hpp b/ggml/src/ggml-sycl/convert.hpp index 0ce2874aaae..355dae22b40 100644 --- a/ggml/src/ggml-sycl/convert.hpp +++ b/ggml/src/ggml-sycl/convert.hpp @@ -21,7 +21,7 @@ using to_t_sycl_t = void (*)(const void *__restrict__ x, T *__restrict__ y, typedef to_t_sycl_t to_fp32_sycl_t; typedef to_t_sycl_t to_fp16_sycl_t; -to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type); -to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type); +to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor *dst); +to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst); #endif // GGML_SYCL_CONVERT_HPP diff --git a/ggml/src/ggml-sycl/dequantize.hpp b/ggml/src/ggml-sycl/dequantize.hpp index b8304c3a274..651c2160d24 100644 --- a/ggml/src/ggml-sycl/dequantize.hpp +++ b/ggml/src/ggml-sycl/dequantize.hpp @@ -16,6 +16,8 @@ #include "common.hpp" typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v); +typedef void (*dequantize_kernel_t_reorder)(const void *d, const int64_t ib, const void *qs, + const int iqs, dfloat2 &v); static __dpct_inline__ void dequantize_q4_0(const void *vx, const int64_t ib, const int iqs, dfloat2 &v) { @@ -40,6 +42,29 @@ static __dpct_inline__ void dequantize_q4_0(const void *vx, const int64_t ib, #endif // GGML_SYCL_F16 } +static __dpct_inline__ void dequantize_q4_0_reorder(const void *d_ptr, const int64_t ib, const void *qs, + const int iqs, dfloat2 &v) { + // const block_q4_0 * x = (const block_q4_0 *) vx; + + const dfloat d = (const dfloat)*((const sycl::half*)d_ptr+ib); + + const int vui = *((const uint8_t *)qs+iqs); + + v.x() = vui & 0xF; + v.y() = vui >> 4; + +#ifdef GGML_SYCL_F16 + // v = v - {8.0f, 8.0f}; + // v = v * {d, d}; + v.s0() = (v.s0() - 8.0f) * d; + v.s1() = (v.s1() - 8.0f) * d; + +#else + v.x() = (v.x() - 8.0f) * d; + v.y() = (v.y() - 8.0f) * d; +#endif // GGML_SYCL_F16 +} + static __dpct_inline__ void dequantize_q4_1(const void *vx, const int64_t ib, const int iqs, dfloat2 &v) { const block_q4_1 * x = (const block_q4_1 *) vx; @@ -167,6 +192,36 @@ static void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restri } } +template +static void dequantize_block_q4_0_reorder(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t nb32, + const sycl::nd_item<3> &item_ct1) { + + const int64_t i = item_ct1.get_group(2); + auto k=nb32; + // assume 32 threads + const int64_t tid = item_ct1.get_local_id(2); + const int lane_ib = i * WARP_SIZE + tid; + + if (lane_ib >= k / QK4_0) { + return; + } + + dst_t * y_ptr = yy + lane_ib * QK4_0; + + auto qs = (const uint8_t*)vx + lane_ib * QK4_0 / 2; + auto s_ptr = (const sycl::half*)((const uint8_t*)vx + k / 2) + lane_ib; + + const float d = float(*s_ptr); + +#pragma unroll + for (int l = 0; l < QK4_0 / 2; ++l) { + int vq = qs[l]; + y_ptr[l + 0] = d * ((vq & 0xF) - 8); + y_ptr[l + 16] = d * ((vq >> 4) - 8); + } + +} + template static void dequantize_block_q4_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t nb32, const sycl::nd_item<3> &item_ct1) { diff --git a/ggml/src/ggml-sycl/dmmv.cpp b/ggml/src/ggml-sycl/dmmv.cpp index 0d097357ce7..99d3859de89 100644 --- a/ggml/src/ggml-sycl/dmmv.cpp +++ b/ggml/src/ggml-sycl/dmmv.cpp @@ -3,7 +3,6 @@ #include "dequantize.hpp" #include "presets.hpp" - static void convert_f16(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){ const sycl::half *x = (const sycl::half *)vx; @@ -91,6 +90,112 @@ static void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * } } +template +static void dequantize_mul_mat_vec_reorder(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows, + const sycl::nd_item<3> &item_ct1) { + // qk = quantized weights per x block + // qr = number of quantized weights per data value in x block + const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + + item_ct1.get_local_id(1); + + if (row >= nrows) { + return; + } + + const int tid = item_ct1.get_local_id(2); + + + const int ncols_left = ncols % (QK4_0*WARP_SIZE); + const int ncols_align = ncols - ncols_left; + const int iter_stride = 8*2*GGML_SYCL_DMMV_X; + const int vals_per_iter = iter_stride / WARP_SIZE; // num quantized vals per thread and i iter //64/16=4, 512/16/2= 16 + const int y_offset = qr == 1 ? 1 : qk/2; + +// partial sum for each thread +#ifdef GGML_SYCL_F16 + sycl::half2 tmp = {0.0f, 0.0f}; // two sums for f16 to take advantage of half2 intrinsics +#else + float tmp = 0.0f; +#endif // GGML_SYCL_F16 + const char *d_ptr = (const char*)vx+ncols*nrows/2; + int i=0; + for (i = 0; i < ncols_align; i += iter_stride) { + const int col = i + vals_per_iter*tid; + const int ib = (row*ncols + col)/qk; // x block index + const int iqs = (col%qk)/qr; // x quant index + const int iybs = col - col%qk; // y block start index + +// processing >2 values per i iter is faster for fast GPUs +#pragma unroll + for (int j = 0; j < vals_per_iter; j += 2) { + // process 2 vals per j iter + + // dequantize + // for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val + dfloat2 v; + dequantize_kernel_reorder((const void *)d_ptr, ib, (const void *)vx, ib * QK4_0 / 2 +iqs+j/qr, v); + + // matrix multiplication + // for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2 +#ifdef GGML_SYCL_F16 + dfloat2 t1{y[iybs + iqs + j / qr + 0], + y[iybs + iqs + j / qr + y_offset]}; + + tmp += v * t1; +#else + tmp += v.x() * y[iybs + iqs + j / qr + 0]; + tmp += v.y() * y[iybs + iqs + j / qr + y_offset]; +#endif // GGML_SYCL_F16 + } + } + + for (; i < ncols; i += iter_stride) { + if (tid>=ncols_left/QK4_0) continue; + const int col = i + vals_per_iter*tid; + const int ib = (row*ncols + col)/qk; // x block index + const int iqs = (col%qk)/qr; // x quant index + const int iybs = col - col%qk; // y block start index + +// processing >2 values per i iter is faster for fast GPUs +#pragma unroll + for (int j = 0; j < vals_per_iter; j += 2) { + // process 2 vals per j iter + + // dequantize + // for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val + dfloat2 v; + dequantize_kernel_reorder((const void *)d_ptr, ib, (const void *)vx, ib * QK4_0 / 2 +iqs+j/qr, v); + + // matrix multiplication + // for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2 +#ifdef GGML_SYCL_F16 + dfloat2 t1{y[iybs + iqs + j / qr + 0], + y[iybs + iqs + j / qr + y_offset]}; + + tmp += v * t1; +#else + tmp += v.x() * y[iybs + iqs + j / qr + 0]; + tmp += v.y() * y[iybs + iqs + j / qr + y_offset]; +#endif // GGML_SYCL_F16 + } + } + + // sum up partial sums and write back result + const int mask_start = ncols > GGML_SYCL_DMMV_X ? WARP_SIZE >> 1 : WARP_SIZE >> 2; + for (int mask = mask_start; mask > 0; mask >>= 1) { + tmp += + dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); + } + + if (tid == 0) { +#ifdef GGML_SYCL_F16 + dst[row] = tmp.x() + tmp.y(); +#else + dst[row] = tmp; +#endif // GGML_SYCL_F16 + } +} + static void convert_mul_mat_vec_f16_sycl(const void *vx, const dfloat *y, float *dst, const int ncols, const int nrows, @@ -759,6 +864,28 @@ static void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const floa } } +static void dequantize_mul_mat_vec_q4_0_sycl_reorder(const void *vx, const dfloat *y, + float *dst, const int ncols, + const int nrows, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + // the number of rows may exceed maximum grid size in the y or z dimensions, use the x dimension instead + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + { + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] { + dequantize_mul_mat_vec_reorder( + vx, y, dst, ncols, nrows, item_ct1); + }); + } +} + static void dequantize_mul_mat_vec_q4_0_sycl(const void *vx, const dfloat *y, float *dst, const int ncols, @@ -953,7 +1080,6 @@ void ggml_sycl_op_dequantize_mul_mat_vec( const int64_t ne00 = src0->ne[0]; const int64_t row_diff = row_high - row_low; - GGML_ASSERT(src1->type == GGML_TYPE_F32); // on some GPUs it is faster to convert src1 to half and to use half precision intrinsics #ifdef GGML_SYCL_F16 @@ -967,7 +1093,7 @@ void ggml_sycl_op_dequantize_mul_mat_vec( if (src1_convert_f16) { src1_dfloat = src1_dfloat_a.alloc(ne00); - const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type); + const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type, dst); GGML_ASSERT(to_fp16_sycl != nullptr); to_fp16_sycl(src1_ddf_i, src1_dfloat, ne00, stream); } @@ -977,7 +1103,12 @@ void ggml_sycl_op_dequantize_mul_mat_vec( switch (src0->type) { case GGML_TYPE_Q4_0: - dequantize_mul_mat_vec_q4_0_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); + if ((ggml_tensor_extra_gpu*)dst->src[0]->extra && + ((ggml_tensor_extra_gpu*)dst->src[0]->extra)->optimized_feature.reorder) { + dequantize_mul_mat_vec_q4_0_sycl_reorder(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); + } else { + dequantize_mul_mat_vec_q4_0_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); + } break; case GGML_TYPE_Q4_1: dequantize_mul_mat_vec_q4_1_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); @@ -1020,4 +1151,5 @@ void ggml_sycl_op_dequantize_mul_mat_vec( GGML_UNUSED(src1_ddq_i); GGML_UNUSED(src1_ncols); GGML_UNUSED(src1_padded_row_size); + GGML_UNUSED(ctx); } diff --git a/ggml/src/ggml-sycl/getrows.cpp b/ggml/src/ggml-sycl/getrows.cpp new file mode 100644 index 00000000000..51c19f6b3b9 --- /dev/null +++ b/ggml/src/ggml-sycl/getrows.cpp @@ -0,0 +1,308 @@ +// +// MIT license +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: MIT +// + +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// + +#include "ggml-impl.h" +#include "common.hpp" +#include "dequantize.hpp" +#include "getrows.hpp" + + +template +static void k_get_rows( + const void * src0, const int32_t * src1, dst_t * dst, + int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/ + /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/ + /*size_t s0,*/ size_t s1, size_t s2, size_t s3, + /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03, + size_t s10, size_t s11, size_t s12, + const sycl::nd_item<3> &item_ct1/*, size_t s13*/) { + + const int i00 = (item_ct1.get_group(2) * item_ct1.get_local_range(2) + + item_ct1.get_local_id(2)) * + 2; + const int i10 = item_ct1.get_local_range(1) * item_ct1.get_group(1) + + item_ct1.get_local_id(1); + const int i11 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) + + item_ct1.get_local_id(0)) / + ne12; + const int i12 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) + + item_ct1.get_local_id(0)) % + ne12; + + if (i00 >= ne00) { + return; + } + + const int i01 = src1[i10*s10 + i11*s11 + i12*s12]; + + dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3; + const void * src0_row = (const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03; + + const int ib = i00/qk; // block index + const int iqs = (i00%qk)/qr; // quant index + const int iybs = i00 - i00%qk; // dst block start index + const int y_offset = qr == 1 ? 1 : qk/2; + + // dequantize + dfloat2 v; + dequantize_kernel(src0_row, ib, iqs, v); + + dst_row[iybs + iqs + 0] = v.x(); + dst_row[iybs + iqs + y_offset] = v.y(); +} + +template +static void k_get_rows_reorder( + const void * src0, const void *src0_dq, const int32_t * src1, dst_t * dst, + int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/ + /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/ + /*size_t s0,*/ size_t s1, size_t s2, size_t s3, + /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03, + size_t s10, size_t s11, size_t s12, + const sycl::nd_item<3> &item_ct1/*, size_t s13*/) { + + const int i00 = (item_ct1.get_group(2) * item_ct1.get_local_range(2) + + item_ct1.get_local_id(2)) * + 2; + const int i10 = item_ct1.get_local_range(1) * item_ct1.get_group(1) + + item_ct1.get_local_id(1); + const int i11 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) + + item_ct1.get_local_id(0)) / + ne12; + const int i12 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) + + item_ct1.get_local_id(0)) % + ne12; + + if (i00 >= ne00) { + return; + } + auto ncols = ne00; + const int i01 = src1[i10*s10 + i11*s11 + i12*s12]; + + dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3; + + const int src0_off = i01 * ncols + i00; + const int ib = src0_off / QK4_0; // block index + const int iqs = (i00%qk)/qr; // x quant index + const int iybs = i00 - i00%qk; // dst block start index + const int y_offset = qr == 1 ? 1 : qk/2; + + // dequantize + dfloat2 v; + dequantize_kernel_recorder((const void *)src0_dq, ib, (const void *)src0, src0_off/2, v); + + dst_row[iybs + iqs + 0] = v.x(); + dst_row[iybs + iqs + y_offset] = v.y(); + + GGML_UNUSED(nb01); + GGML_UNUSED(nb02); + GGML_UNUSED(nb03); +} + +template +static void k_get_rows_float( + const src0_t * src0, const int32_t * src1, dst_t * dst, + int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/ + /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/ + /*size_t s0,*/ size_t s1, size_t s2, size_t s3, + /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03, + size_t s10, size_t s11, size_t s12, + const sycl::nd_item<3> &item_ct1/*, size_t s13*/) { + + const int i00 = item_ct1.get_group(2) * item_ct1.get_local_range(2) + + item_ct1.get_local_id(2); + const int i10 = item_ct1.get_local_range(1) * item_ct1.get_group(1) + + item_ct1.get_local_id(1); + const int i11 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) + + item_ct1.get_local_id(0)) / + ne12; + const int i12 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) + + item_ct1.get_local_id(0)) % + ne12; + + if (i00 >= ne00) { + return; + } + + const int i01 = src1[i10*s10 + i11*s11 + i12*s12]; + + dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3; + const src0_t * src0_row = (const src0_t *)((const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03); + + dst_row[i00] = src0_row[i00]; +} + +template +static void get_rows_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, + ggml_tensor *dst, const void *src0_dd, + const int32_t *src1_dd, float *dst_dd, + queue_ptr stream) { + + GGML_TENSOR_BINARY_OP_LOCALS + + const sycl::range<3> block_dims(1, 1, SYCL_GET_ROWS_BLOCK_SIZE); + const int block_num_x = (ne00 + 2*SYCL_GET_ROWS_BLOCK_SIZE - 1) / (2*SYCL_GET_ROWS_BLOCK_SIZE); + const sycl::range<3> block_nums(ne11 * ne12, ne10, block_num_x); + + // strides in elements + //const size_t s0 = nb0 / ggml_element_size(dst); + const size_t s1 = nb1 / ggml_element_size(dst); + const size_t s2 = nb2 / ggml_element_size(dst); + const size_t s3 = nb3 / ggml_element_size(dst); + + const size_t s10 = nb10 / ggml_element_size(src1); + const size_t s11 = nb11 / ggml_element_size(src1); + const size_t s12 = nb12 / ggml_element_size(src1); + //const size_t s13 = nb13 / ggml_element_size(src1); + + GGML_ASSERT(ne00 % 2 == 0); + + stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + k_get_rows( + src0_dd, src1_dd, dst_dd, ne00, ne12, s1, s2, + s3, nb01, nb02, nb03, s10, s11, s12, item_ct1); + }); + + GGML_UNUSED(dst); + GGML_UNUSED(ctx); +} + +template +static void get_rows_sycl_reorder(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, + ggml_tensor *dst, const void *src0_dd, + const int32_t *src1_dd, float *dst_dd, + queue_ptr stream) { + + GGML_TENSOR_BINARY_OP_LOCALS + + const sycl::range<3> block_dims(1, 1, SYCL_GET_ROWS_BLOCK_SIZE); + const int block_num_x = (ne00 + 2*SYCL_GET_ROWS_BLOCK_SIZE - 1) / (2*SYCL_GET_ROWS_BLOCK_SIZE); + const sycl::range<3> block_nums(ne11 * ne12, ne10, block_num_x); + + // strides in elements + //const size_t s0 = nb0 / ggml_element_size(dst); + const size_t s1 = nb1 / ggml_element_size(dst); + const size_t s2 = nb2 / ggml_element_size(dst); + const size_t s3 = nb3 / ggml_element_size(dst); + + const size_t s10 = nb10 / ggml_element_size(src1); + const size_t s11 = nb11 / ggml_element_size(src1); + const size_t s12 = nb12 / ggml_element_size(src1); + //const size_t s13 = nb13 / ggml_element_size(src1); + + GGML_ASSERT(ne00 % 2 == 0); + + const uint8_t* src0_q = (const uint8_t*)src0_dd; + const size_t ncols = ne00; + const size_t nrows = ne01; + const sycl::half* src0_dq = (const sycl::half*)(src0_q + nrows * ncols / 2); + stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]]{ + k_get_rows_reorder( + src0_dd, src0_dq, src1_dd, dst_dd, ne00, ne12, s1, s2, + s3, nb01, nb02, nb03, s10, s11, s12, item_ct1); + }); + + GGML_UNUSED(dst); + GGML_UNUSED(ctx); +} + + +template +static void get_rows_sycl_float(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, + const ggml_tensor *src1, ggml_tensor *dst, + const src0_t *src0_dd, const int32_t *src1_dd, + float *dst_dd, queue_ptr stream) { + + GGML_TENSOR_BINARY_OP_LOCALS + + const sycl::range<3> block_dims(1, 1, SYCL_GET_ROWS_BLOCK_SIZE); + const int block_num_x = (ne00 + SYCL_GET_ROWS_BLOCK_SIZE - 1) / SYCL_GET_ROWS_BLOCK_SIZE; + const sycl::range<3> block_nums(ne11 * ne12, ne10, block_num_x); + + // strides in elements + //const size_t s0 = nb0 / ggml_element_size(dst); + const size_t s1 = nb1 / ggml_element_size(dst); + const size_t s2 = nb2 / ggml_element_size(dst); + const size_t s3 = nb3 / ggml_element_size(dst); + + const size_t s10 = nb10 / ggml_element_size(src1); + const size_t s11 = nb11 / ggml_element_size(src1); + const size_t s12 = nb12 / ggml_element_size(src1); + //const size_t s13 = nb13 / ggml_element_size(src1); + + { + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + k_get_rows_float(src0_dd, src1_dd, dst_dd, ne00, ne12, s1, s2, + s3, nb01, nb02, nb03, s10, s11, s12, item_ct1); + }); + } + + GGML_UNUSED(dst); + GGML_UNUSED(ctx); +} + +void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, + const ggml_tensor *src1, ggml_tensor *dst, + const float *src0_d, const float *src1_d, + float *dst_d, const queue_ptr &stream) { + + GGML_ASSERT(src1->type == GGML_TYPE_I32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type)); + GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type)); + GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type)); + + const int32_t * src1_i32 = (const int32_t *) src1_d; + + switch (src0->type) { + case GGML_TYPE_F16: + get_rows_sycl_float(ctx, src0, src1, dst, (const sycl::half *)src0_d, + src1_i32, dst_d, stream); + break; + case GGML_TYPE_F32: + get_rows_sycl_float(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream); + break; + case GGML_TYPE_Q4_0: + if (ctx.opt_feature.reorder && dst->op == GGML_OP_MUL_MAT) { + get_rows_sycl_reorder(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream); + } else { + get_rows_sycl(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream); + } + break; + case GGML_TYPE_Q4_1: + get_rows_sycl(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream); + break; + case GGML_TYPE_Q5_0: + get_rows_sycl(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream); + break; + case GGML_TYPE_Q5_1: + get_rows_sycl(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream); + break; + case GGML_TYPE_Q8_0: + get_rows_sycl(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream); + break; + default: + // TODO: k-quants + GGML_LOG_ERROR("%s: unsupported type: %s\n", __func__, ggml_type_name(src0->type)); + GGML_ABORT("fatal error"); + break; + } +} + diff --git a/ggml/src/ggml-sycl/getrows.hpp b/ggml/src/ggml-sycl/getrows.hpp new file mode 100644 index 00000000000..cdbe6c2f41b --- /dev/null +++ b/ggml/src/ggml-sycl/getrows.hpp @@ -0,0 +1,23 @@ +// +// MIT license +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: MIT +// + +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// + +#ifndef GGML_SYCL_GETROWS_HPP +#define GGML_SYCL_GETROWS_HPP + +#include "common.hpp" + +void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, + const ggml_tensor *src1, ggml_tensor *dst, + const float *src0_d, const float *src1_d, + float *dst_d, const queue_ptr &stream); + +#endif // GGML_SYCL_GETROWS_HPP diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index d4c97ad17b8..792e0569ca6 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -39,9 +39,12 @@ #include "ggml-sycl/backend.hpp" #include "ggml-sycl/presets.hpp" #include "ggml-sycl/gemm.hpp" +#include "ggml-sycl/sycl_hw.hpp" +#include "ggml-sycl/getrows.hpp" static bool g_sycl_loaded = false; int g_ggml_sycl_debug = 0; +int g_ggml_sycl_disable_optimize = 0; static ggml_sycl_device_info ggml_sycl_init() { ggml_sycl_device_info info = {}; @@ -64,14 +67,18 @@ static ggml_sycl_device_info ggml_sycl_init() { for (int i = 0; i < info.device_count; ++i) { info.devices[i].vmm = 0; dpct::device_info prop; + sycl::device device = dpct::dev_mgr::instance().get_device(i); + SYCL_CHECK(CHECK_TRY_ERROR(dpct::get_device_info( - prop, dpct::dev_mgr::instance().get_device(i)))); + prop, device))); info.default_tensor_split[i] = total_vram; total_vram += prop.get_global_mem_size(); info.devices[i].cc = 100 * prop.get_major_version() + 10 * prop.get_minor_version(); + info.devices[i].hw_info = get_device_hw_info(&device); + info.devices[i].opt_feature = check_gpu_optimize_feature(info.devices[i].hw_info.arch); info.max_work_group_sizes[i] = prop.get_max_work_group_size(); } @@ -110,6 +117,27 @@ void print_device_detail(int id, sycl::device &device, std::string device_type) global_mem_size, device.get_info().c_str()); } +void print_device_opt_feature(int device_count) { + GGML_LOG_INFO("SYCL Optimization Feature:\n"); + GGML_LOG_INFO( + "|ID| Device Type|Reorder|\n"); + GGML_LOG_INFO( + "|--|-------------------|-------|\n"); + std::map DeviceNums; + for (int id = 0; id < device_count; ++id) { + sycl::device device = dpct::dev_mgr::instance().get_device(id); + std::string backend_type = get_device_backend_and_type(device); + int type_id = DeviceNums[backend_type]++; + std::stringstream device_type; + device_type << "[" << backend_type << ":" << std::to_string(type_id) + << "]"; + std::string device_type_s = device_type.str(); + device_type_s = std::regex_replace(device_type_s, std::regex("ext_oneapi_"), ""); + GGML_LOG_INFO("|%2d|%19s|%7s|\n", id, device_type_s.c_str(), + ggml_sycl_info().devices[id].opt_feature.reorder ? "Y": "N"); + } + +} void ggml_backend_sycl_print_sycl_devices() { GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_print_sycl_devices\n"); int device_count = dpct::dev_mgr::instance().device_count(); @@ -138,6 +166,8 @@ void ggml_backend_sycl_print_sycl_devices() { << "]"; print_device_detail(id, device, device_type.str()); } + + print_device_opt_feature(device_count); } static inline int get_sycl_env(const char *env_name, int default_val) { @@ -159,17 +189,21 @@ static void ggml_check_sycl() try { if (!initialized) { g_ggml_sycl_debug = get_sycl_env("GGML_SYCL_DEBUG", 0); + g_ggml_sycl_disable_optimize= get_sycl_env("GGML_SYCL_DISABLE_OPT", 0); GGML_SYCL_DEBUG("[SYCL] call ggml_check_sycl\n"); - GGML_LOG_INFO("GGML_SYCL_DEBUG: %d\n", g_ggml_sycl_debug); + GGML_LOG_INFO("Running with Environment Variables:\n"); + GGML_LOG_INFO(" GGML_SYCL_DEBUG: %d\n", g_ggml_sycl_debug); + GGML_LOG_INFO(" GGML_SYCL_DISABLE_OPT: %d\n", g_ggml_sycl_disable_optimize); + GGML_LOG_INFO("Build with Macros:\n"); #if defined(GGML_SYCL_FORCE_MMQ) - GGML_LOG_INFO("GGML_SYCL_FORCE_MMQ: yes\n"); + GGML_LOG_INFO(" GGML_SYCL_FORCE_MMQ: yes\n"); #else - GGML_LOG_INFO("GGML_SYCL_FORCE_MMQ: no\n"); + GGML_LOG_INFO(" GGML_SYCL_FORCE_MMQ: no\n"); #endif #if defined(GGML_SYCL_F16) - GGML_LOG_INFO("GGML_SYCL_F16: yes\n"); + GGML_LOG_INFO(" GGML_SYCL_F16: yes\n"); #else - GGML_LOG_INFO("GGML_SYCL_F16: no\n"); + GGML_LOG_INFO(" GGML_SYCL_F16: no\n"); #endif /* NOT REMOVE, keep it for next optimize for XMX. @@ -241,19 +275,27 @@ struct ggml_backend_sycl_buffer_context { void * dev_ptr = nullptr; queue_ptr stream; std::string name; + optimize_feature opt_feature; + std::vector tensor_extras; - ggml_backend_sycl_buffer_context(int device, void * dev_ptr, queue_ptr stream) : + ggml_backend_sycl_buffer_context(int device, void * dev_ptr, queue_ptr stream) : device(device), dev_ptr(dev_ptr), stream(stream) { check_allow_gpu_index(device); name = (GGML_SYCL_NAME + std::to_string(device)); + opt_feature = ggml_sycl_info().devices[device].opt_feature; } - ~ggml_backend_sycl_buffer_context() { if (dev_ptr != nullptr) { ggml_sycl_set_device(device); SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(dev_ptr, *stream))); } + + //release extra used by tensors + for (ggml_tensor_extra_gpu * extra : tensor_extras) { + release_extra_gpu(extra); + } + } }; @@ -291,6 +333,9 @@ ggml_backend_sycl_buffer_init_tensor(ggml_backend_buffer_t buffer, return; } + ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{}; + tensor->extra = extra; + ctx->tensor_extras.push_back(extra); //used to release it when destroy ctx. if (ggml_is_quantized(tensor->type)) { // initialize padding to 0 to avoid possible NaN values @@ -316,7 +361,6 @@ static void ggml_backend_sycl_buffer_set_tensor(ggml_backend_buffer_t buffer, size_t size) try { ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context; - ggml_sycl_set_device(ctx->device); auto stream = &(dpct::dev_mgr::instance().get_device(ctx->device).default_queue()); SYCL_CHECK( @@ -660,32 +704,7 @@ struct ggml_backend_sycl_split_buffer_type_context { struct ggml_backend_sycl_split_buffer_context { ~ggml_backend_sycl_split_buffer_context() try { for (ggml_tensor_extra_gpu * extra : tensor_extras) { - for (int i = 0; i < ggml_sycl_info().device_count; ++i) { - for (int64_t is = 0; is < GGML_SYCL_MAX_STREAMS; ++is) { - if (extra->events[i][is] != nullptr) { - /* - DPCT1009:206: SYCL uses exceptions to report errors and - does not use the error codes. The original code was - commented out and a warning string was inserted. You - need to rewrite this code. - */ - SYCL_CHECK(CHECK_TRY_ERROR( - dpct::destroy_event(extra->events[i][is]))); - } - } - if (extra->data_device[i] != nullptr) { - /* - DPCT1009:207: SYCL uses exceptions to report errors and does - not use the error codes. The original code was commented out - and a warning string was inserted. You need to rewrite this - code. - */ - ggml_sycl_set_device(i); - SYCL_CHECK(CHECK_TRY_ERROR(sycl::free( - extra->data_device[i], *(streams[i])))); - } - } - delete extra; + release_extra_gpu(extra, streams); } } catch (sycl::exception const &exc) { @@ -723,7 +742,7 @@ ggml_backend_sycl_split_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{}; ctx->tensor_extras.push_back(extra); - ctx->streams.push_back(&(dpct::get_current_device().default_queue())); + ctx->streams.push_back(&(dpct::get_current_device().default_queue())); for (int i = 0; i < ggml_sycl_info().device_count; ++i) { int64_t row_low, row_high; @@ -1337,83 +1356,6 @@ static void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, reinterpret_cast(y[ib].ds.y()) = sum; } -template -static void k_get_rows( - const void * src0, const int32_t * src1, dst_t * dst, - int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/ - /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/ - /*size_t s0,*/ size_t s1, size_t s2, size_t s3, - /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03, - size_t s10, size_t s11, size_t s12, - const sycl::nd_item<3> &item_ct1/*, size_t s13*/) { - - const int i00 = (item_ct1.get_group(2) * item_ct1.get_local_range(2) + - item_ct1.get_local_id(2)) * - 2; - const int i10 = item_ct1.get_local_range(1) * item_ct1.get_group(1) + - item_ct1.get_local_id(1); - const int i11 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) + - item_ct1.get_local_id(0)) / - ne12; - const int i12 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) + - item_ct1.get_local_id(0)) % - ne12; - - if (i00 >= ne00) { - return; - } - - const int i01 = src1[i10*s10 + i11*s11 + i12*s12]; - - dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3; - const void * src0_row = (const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03; - - const int ib = i00/qk; // block index - const int iqs = (i00%qk)/qr; // quant index - const int iybs = i00 - i00%qk; // dst block start index - const int y_offset = qr == 1 ? 1 : qk/2; - - // dequantize - dfloat2 v; - dequantize_kernel(src0_row, ib, iqs, v); - - dst_row[iybs + iqs + 0] = v.x(); - dst_row[iybs + iqs + y_offset] = v.y(); -} - -template -static void k_get_rows_float( - const src0_t * src0, const int32_t * src1, dst_t * dst, - int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/ - /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/ - /*size_t s0,*/ size_t s1, size_t s2, size_t s3, - /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03, - size_t s10, size_t s11, size_t s12, - const sycl::nd_item<3> &item_ct1/*, size_t s13*/) { - - const int i00 = item_ct1.get_group(2) * item_ct1.get_local_range(2) + - item_ct1.get_local_id(2); - const int i10 = item_ct1.get_local_range(1) * item_ct1.get_group(1) + - item_ct1.get_local_id(1); - const int i11 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) + - item_ct1.get_local_id(0)) / - ne12; - const int i12 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) + - item_ct1.get_local_id(0)) % - ne12; - - if (i00 >= ne00) { - return; - } - - const int i01 = src1[i10*s10 + i11*s11 + i12*s12]; - - dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3; - const src0_t * src0_row = (const src0_t *)((const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03); - - dst_row[i00] = src0_row[i00]; -} - static void mul_mat_p021_f16_f32( const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x, const int nchannels_x, const int nchannels_y, @@ -1896,81 +1838,6 @@ static void pool2d_nchw_kernel( o_ptr[cur_oh * ow + cur_ow] = res; } -template -static void get_rows_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, - ggml_tensor *dst, const void *src0_dd, - const int32_t *src1_dd, float *dst_dd, - queue_ptr stream) { - - GGML_TENSOR_BINARY_OP_LOCALS - - const sycl::range<3> block_dims(1, 1, SYCL_GET_ROWS_BLOCK_SIZE); - const int block_num_x = (ne00 + 2*SYCL_GET_ROWS_BLOCK_SIZE - 1) / (2*SYCL_GET_ROWS_BLOCK_SIZE); - const sycl::range<3> block_nums(ne11 * ne12, ne10, block_num_x); - - // strides in elements - //const size_t s0 = nb0 / ggml_element_size(dst); - const size_t s1 = nb1 / ggml_element_size(dst); - const size_t s2 = nb2 / ggml_element_size(dst); - const size_t s3 = nb3 / ggml_element_size(dst); - - const size_t s10 = nb10 / ggml_element_size(src1); - const size_t s11 = nb11 / ggml_element_size(src1); - const size_t s12 = nb12 / ggml_element_size(src1); - //const size_t s13 = nb13 / ggml_element_size(src1); - - GGML_ASSERT(ne00 % 2 == 0); - - stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) { - k_get_rows( - src0_dd, src1_dd, dst_dd, ne00, ne12, s1, s2, - s3, nb01, nb02, nb03, s10, s11, s12, item_ct1); - }); - - GGML_UNUSED(dst); - GGML_UNUSED(ctx); -} - -template -static void get_rows_sycl_float(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, - const ggml_tensor *src1, ggml_tensor *dst, - const src0_t *src0_dd, const int32_t *src1_dd, - float *dst_dd, queue_ptr stream) { - - GGML_TENSOR_BINARY_OP_LOCALS - - const sycl::range<3> block_dims(1, 1, SYCL_GET_ROWS_BLOCK_SIZE); - const int block_num_x = (ne00 + SYCL_GET_ROWS_BLOCK_SIZE - 1) / SYCL_GET_ROWS_BLOCK_SIZE; - const sycl::range<3> block_nums(ne11 * ne12, ne10, block_num_x); - - // strides in elements - //const size_t s0 = nb0 / ggml_element_size(dst); - const size_t s1 = nb1 / ggml_element_size(dst); - const size_t s2 = nb2 / ggml_element_size(dst); - const size_t s3 = nb3 / ggml_element_size(dst); - - const size_t s10 = nb10 / ggml_element_size(src1); - const size_t s11 = nb11 / ggml_element_size(src1); - const size_t s12 = nb12 / ggml_element_size(src1); - //const size_t s13 = nb13 / ggml_element_size(src1); - - { - dpct::has_capability_or_fail(stream->get_device(), - {sycl::aspect::fp16}); - - stream->parallel_for( - sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) { - k_get_rows_float(src0_dd, src1_dd, dst_dd, ne00, ne12, s1, s2, - s3, nb01, nb02, nb03, s10, s11, s12, item_ct1); - }); - } - - GGML_UNUSED(dst); - GGML_UNUSED(ctx); -} - static void quantize_row_q8_1_sycl(const float *x, void *vy, const int kx, const int ky, const int kx_padded, queue_ptr stream) { @@ -2494,52 +2361,6 @@ catch (sycl::exception const &exc) { std::exit(1); } -static void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, - const ggml_tensor *src1, ggml_tensor *dst, - const float *src0_d, const float *src1_d, - float *dst_d, const queue_ptr &stream) { - - GGML_ASSERT(src1->type == GGML_TYPE_I32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); - - GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type)); - GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type)); - GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type)); - - const int32_t * src1_i32 = (const int32_t *) src1_d; - - switch (src0->type) { - case GGML_TYPE_F16: - get_rows_sycl_float(ctx, src0, src1, dst, (const sycl::half *)src0_d, - src1_i32, dst_d, stream); - break; - case GGML_TYPE_F32: - get_rows_sycl_float(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream); - break; - case GGML_TYPE_Q4_0: - get_rows_sycl(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream); - break; - case GGML_TYPE_Q4_1: - get_rows_sycl(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream); - break; - case GGML_TYPE_Q5_0: - get_rows_sycl(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream); - break; - case GGML_TYPE_Q5_1: - get_rows_sycl(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream); - break; - case GGML_TYPE_Q8_0: - get_rows_sycl(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream); - break; - default: - // TODO: k-quants - GGML_LOG_ERROR("%s: unsupported type: %s\n", __func__, ggml_type_name(src0->type)); - GGML_ABORT("fatal error"); - break; - } -} - - static void ggml_sycl_op_repeat(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const float *src0_d, const float *src1_d, @@ -2589,11 +2410,10 @@ inline void ggml_sycl_op_mul_mat_sycl( if ((src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && use_fp16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) { - // GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat_sycl - fp16 path\n"); ggml_sycl_pool_alloc src0_as_f16(ctx.pool()); if (src0->type != GGML_TYPE_F16) { - const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src0->type); + const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src0->type, dst); GGML_ASSERT(to_fp16_sycl != nullptr); size_t ne = row_diff*ne00; src0_as_f16.alloc(ne); @@ -2605,7 +2425,7 @@ inline void ggml_sycl_op_mul_mat_sycl( ggml_sycl_pool_alloc src1_as_f16(ctx.pool()); if (src1->type != GGML_TYPE_F16) { - const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type); + const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type, dst); GGML_ASSERT(to_fp16_sycl != nullptr); size_t ne = src1_ncols*ne10; src1_as_f16.alloc(ne); @@ -2626,13 +2446,13 @@ inline void ggml_sycl_op_mul_mat_sycl( src1_ptr, dpct::library_data_t::real_half, ne10, &beta_f16, dst_f16.get(), dpct::library_data_t::real_half, ldc, dpct::library_data_t::real_half))); - const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16); + const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst); to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream); #else auto dnnl_stream = ctx.stream_dnnl(stream); DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ptr, DnnlGemmWrapper::to_dt(), src0_ptr, DnnlGemmWrapper::to_dt(), dst_f16.get(), DnnlGemmWrapper::to_dt()); - const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16); + const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst); to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream); #endif } @@ -2641,13 +2461,13 @@ inline void ggml_sycl_op_mul_mat_sycl( ggml_sycl_pool_alloc src0_ddq_as_f32(ctx.pool()); ggml_sycl_pool_alloc src1_ddq_as_f32(ctx.pool()); if (src0->type != GGML_TYPE_F32) { - const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(src0->type); + const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(src0->type, dst); GGML_ASSERT(to_fp32_sycl != nullptr); src0_ddq_as_f32.alloc(row_diff*ne00); to_fp32_sycl(src0_dd_i, src0_ddq_as_f32.get(), row_diff*ne00, stream); } if (src1->type != GGML_TYPE_F32) { - const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(src1->type); + const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(src1->type, dst); GGML_ASSERT(to_fp32_sycl != nullptr); src1_ddq_as_f32.alloc(src1_ncols*ne10); to_fp32_sycl(src1_ddf_i, src1_ddq_as_f32.get(), src1_ncols*ne10, stream); @@ -3085,7 +2905,6 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten for (int64_t src1_col_0 = 0; src1_col_0 < ne11; src1_col_0 += src1_col_stride) { const int64_t is = split ? (src1_col_0/src1_col_stride) % GGML_SYCL_MAX_STREAMS : 0; const int64_t src1_ncols = src1_col_0 + src1_col_stride > ne11 ? ne11 - src1_col_0 : src1_col_stride; - for (int i = 0; i < ggml_sycl_info().device_count; ++i) { if ((!split && i != ctx.device) || dev[i].row_low == dev[i].row_high) { continue; @@ -3393,7 +3212,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, // convert src1 to fp16 ggml_sycl_pool_alloc src1_f16_alloc(ctx.pool()); if (src1->type != GGML_TYPE_F16) { - const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type); + const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type, dst); const int64_t ne_src1 = ggml_nelements(src1); src1_f16_alloc.alloc(ne_src1); GGML_ASSERT(to_fp16_sycl != nullptr); @@ -3509,6 +3328,7 @@ bool ggml_sycl_supports_dmmv(enum ggml_type type) { } static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + const bool split = ggml_backend_buffer_is_sycl_split(src0->buffer); int64_t min_compute_capability = INT_MAX; @@ -3570,6 +3390,7 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst); } else if (use_dequantize_mul_mat_vec) { ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, false); + // save_tensor_txt("1/dst_1.txt", (float*) dst->data, src0->ne[1], sizeof(float), ctx.stream()); } else if (use_mul_mat_vec_q) { ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, true); } else if (use_mul_mat_q) { @@ -4251,10 +4072,72 @@ catch (sycl::exception const &exc) { std::exit(1); } +void reorder_qw(char *data_device, const int ncols, const int nrows, + size_t size, size_t offset, dpct::queue_ptr stream) { + auto tmp_buf = sycl::malloc_shared(size, *stream); + SYCL_CHECK( + CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size) + .wait())); + GGML_ASSERT((size % sizeof(block_q4_0) == 0)); + GGML_ASSERT((offset % sizeof(block_q4_0) == 0)); + int offset_blks = offset / sizeof(block_q4_0); + auto qs_ptr = (uint8_t*)data_device + offset_blks * QK4_0 / 2;; + auto d_ptr = (sycl::half*)(qs_ptr + ncols * nrows / 2) + offset_blks; + + stream->parallel_for( + size / sizeof(block_q4_0), + [=](auto i) [[intel::reqd_sub_group_size(WARP_SIZE)]] { + const block_q4_0* x = (const block_q4_0*)tmp_buf; + const int ib = i; + + for (int j = 0; j < QK4_0/2; j ++) + { + *(qs_ptr + ib * QK4_0 / 2 + j) = x[ib].qs[j]; + } + *(d_ptr + ib) = x[ib].d; + }); + + sycl::free(tmp_buf, *stream); +} + +void reorder_qw(ggml_tensor * src0, dpct::queue_ptr stream) { + char*data_device = (char*)src0->data; + size_t ncols = src0->ne[0]; + size_t nrows = src0->ne[1]; + size_t size = ggml_nbytes(src0); + + reorder_qw(data_device, ncols, nrows, size, 0, stream); +} + +void opt_for_reorder(ggml_tensor * dst, dpct::queue_ptr stream) { + ggml_tensor *src0 = dst->src[0]; + ggml_tensor *src1 = dst->src[1]; + + if (dst->op == GGML_OP_MUL_MAT && src0->type == GGML_TYPE_Q4_0 && + src1->ne[2]==1 && src1->ne[3]==1) { + reorder_qw(src0, stream); + ggml_tensor_extra_gpu* extra = (ggml_tensor_extra_gpu*)src0->extra; + GGML_ASSERT(extra); + extra->optimized_feature.reorder = true; //used to decode/dequan in next steps. + } +} + +void optimize_graph_once(ggml_cgraph * cgraph, ggml_backend_sycl_context * ctx) { + dpct::queue_ptr stream = ctx->stream(); + if (ctx->optimized_graph) { + return; + } + ctx->optimized_graph = true; + + for (int i = 0; i < cgraph->n_nodes; i++) { + if (ctx->opt_feature.reorder) opt_for_reorder(cgraph->nodes[i], stream); + } +} static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context; ggml_sycl_set_main_device(sycl_ctx->device); + if (!g_ggml_sycl_disable_optimize) optimize_graph_once(cgraph, sycl_ctx); for (int i = 0; i < cgraph->n_nodes; i++) { ggml_tensor * node = cgraph->nodes[i]; diff --git a/ggml/src/ggml-sycl/sycl_hw.cpp b/ggml/src/ggml-sycl/sycl_hw.cpp new file mode 100644 index 00000000000..da121ffc261 --- /dev/null +++ b/ggml/src/ggml-sycl/sycl_hw.cpp @@ -0,0 +1,13 @@ +#include "sycl_hw.hpp" + + +sycl_hw_info get_device_hw_info(sycl::device *device_ptr) { + sycl_hw_info res; + int32_t id = device_ptr->get_info(); + res.device_id = id; + + syclex::architecture arch = device_ptr->get_info(); + res.arch = arch; + + return res; +} diff --git a/ggml/src/ggml-sycl/sycl_hw.hpp b/ggml/src/ggml-sycl/sycl_hw.hpp new file mode 100644 index 00000000000..bf689450ce6 --- /dev/null +++ b/ggml/src/ggml-sycl/sycl_hw.hpp @@ -0,0 +1,23 @@ +#ifndef SYCL_HW_HPP +#define SYCL_HW_HPP + +#include +#include +#include +#include + +#include + +namespace syclex = sycl::ext::oneapi::experimental; + +struct sycl_hw_info { + syclex::architecture arch; + int32_t device_id; +}; + +bool is_in_vector(std::vector &vec, int item); + +sycl_hw_info get_device_hw_info(sycl::device *device_ptr); + + +#endif // SYCL_HW_HPP From fa016e317528d6f6db2b6a3c03f7600199f618ff Mon Sep 17 00:00:00 2001 From: lhez Date: Mon, 24 Feb 2025 13:47:07 -0800 Subject: [PATCH 55/58] opencl: fix for small models (llama/11950) * opencl: fix small shape gemv, remove unused extensions * opencl: fix `transpose_16`, `dump_tensor`, enforce subgroup size * opencl: fix for token length < 4 * opencl: use wave size of 64 for all Adreno GPUs --------- Co-authored-by: Shawn Gu Co-authored-by: Skyler Szot --- ggml/src/ggml-opencl/ggml-opencl.cpp | 54 +++++++++---------- ggml/src/ggml-opencl/kernels/ggml-opencl.cl | 3 ++ .../kernels/ggml-opencl_gemv_noshuffle.cl | 13 +++-- .../ggml-opencl_gemv_noshuffle_general.cl | 13 +++-- .../kernels/ggml-opencl_mul_mat_Ab_Bi_8x4.cl | 11 +++- .../kernels/ggml-opencl_transpose_16.cl | 32 +++++------ 6 files changed, 67 insertions(+), 59 deletions(-) diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 7a0f94cf24c..f590624608c 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -444,19 +444,8 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { backend_ctx->gpu_family = GPU_FAMILY::ADRENO; backend_ctx->adreno_gen = get_adreno_gpu_gen(default_device->name); - // Default wave size is 128, A8x uses 64. - if (backend_ctx->adreno_gen == ADRENO_GPU_GEN::A8X) { - backend_ctx->adreno_wave_size = 64; - } else if (backend_ctx->adreno_gen == ADRENO_GPU_GEN::A7X || - backend_ctx->adreno_gen == ADRENO_GPU_GEN::X1E) { - backend_ctx->adreno_wave_size = 128; - } else { - backend_ctx->adreno_wave_size = 128; - GGML_LOG_WARN("ggml_opencl: Unsupported Adreno GPU: %s, " - "using wave size %d, " - "may not work as expected\n", - backend_ctx->device_name.c_str(), backend_ctx->adreno_wave_size); - } + // Use wave size of 64 for all Adreno GPUs. + backend_ctx->adreno_wave_size = 64; } else if (strstr(default_device->name, "Intel")) { backend_ctx->gpu_family = GPU_FAMILY::INTEL; } else { @@ -1376,6 +1365,11 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, int M = tensor->ne[1]; // ne01 int K = tensor->ne[0]; // ne00 + //For matrix-vector multiplication kernel, we assume K is a multiple of 32 + GGML_ASSERT(K % 32 == 0); + //For transpose kernels, we assume K is a multiple of 4 (satisfied by prior assert), and M is a multiple of 4 + GGML_ASSERT(M % 4 == 0); + // transpose is out of place, so we need to allocate transposed buffers // <----------------------------------------------------------------------------------> // // use sub_buffer of max buffer size instead @@ -1416,36 +1410,36 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, cl_mem qT_d_image1D; cl_mem dT_d_image1D; - cl_image_format img_fmt_1d = { CL_RGBA, CL_FLOAT }; + cl_image_format img_fmt_1d = { CL_RGBA, CL_HALF_FLOAT }; cl_image_desc img_desc_1d; memset(&img_desc_1d, 0, sizeof(img_desc_1d)); img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; - img_desc_1d.image_width = M * K / 8 / 4; + img_desc_1d.image_width = M * K / 4 / 4; img_desc_1d.buffer = extra->q; q_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err); CL_CHECK(err); - img_fmt_1d = { CL_RGBA, CL_FLOAT }; + img_fmt_1d = { CL_RGBA, CL_HALF_FLOAT }; memset(&img_desc_1d, 0, sizeof(img_desc_1d)); img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; - img_desc_1d.image_width = M * K / 8 / 4; + img_desc_1d.image_width = M * K / 4 / 4; img_desc_1d.buffer = qT_d; qT_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err); CL_CHECK(err); - img_fmt_1d = { CL_RGBA, CL_FLOAT }; + img_fmt_1d = { CL_RGBA, CL_HALF_FLOAT }; memset(&img_desc_1d, 0, sizeof(img_desc_1d)); img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; - img_desc_1d.image_width = M * K / 32 / 4 / 2; + img_desc_1d.image_width = M * K / 32 / 4; img_desc_1d.buffer = extra->d; d_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err); CL_CHECK(err); - img_fmt_1d = { CL_RGBA, CL_FLOAT }; + img_fmt_1d = { CL_RGBA, CL_HALF_FLOAT }; memset(&img_desc_1d, 0, sizeof(img_desc_1d)); img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; - img_desc_1d.image_width = M * K / 32 / 4 / 2; + img_desc_1d.image_width = M * K / 32 / 4; img_desc_1d.buffer = dT_d; dT_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err); CL_CHECK(err); @@ -1454,8 +1448,8 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, // set up and call the transpose kernels // <----------------------------------------------------------------------------------> // // weights - int height_q = M / 8; - int width_q = K / 8 / 4; + int height_q = M / 4; + int width_q = K / 4 / 4; kernel = backend_ctx->kernel_transpose_16; CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &q_d_image1D)); @@ -1469,8 +1463,8 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, CL_CHECK(clWaitForEvents(1, &evt)); // scales - int height_s = M / 8; - int width_s = K / 32 / 8; + int height_s = M / 4; + int width_s = K / 32 / 4; kernel = backend_ctx->kernel_transpose_16; CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &d_d_image1D)); @@ -1864,7 +1858,6 @@ static void dump_tensor(ggml_backend_t backend, const struct ggml_tensor * tenso void * buf_d; #endif -#ifdef GGML_USE_OPENCL // Make sure everything is done. CL_CHECK(clFinish(queue)); @@ -1900,7 +1893,6 @@ static void dump_tensor(ggml_backend_t backend, const struct ggml_tensor * tenso extra->offset, ggml_nbytes(tensor), buf, 0, NULL, NULL)); CL_CHECK(clFinish(queue)); #endif // GGML_OPENCL_SOA_Q -#endif // GGML_USE_OPENCL // Open file and dump. char fname[512]; @@ -2865,6 +2857,9 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co CL_CHECK(status); int height_B = N/4; + if (height_B == 0) { + height_B = 1; + } int width_B = K/4; int padded_height_B = (N + padding)/4; @@ -3013,11 +3008,12 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co } if (N == 1) { - local_work_size[0] = backend_ctx->adreno_wave_size; // localsize + size_t wavesize = backend_ctx->adreno_wave_size; + local_work_size[0] = wavesize; // localsize local_work_size[1] = 4; // reduce factor local_work_size[2] = 1; - global_work_size[0] = M / 2; + global_work_size[0] = (((M / 2) + wavesize - 1) / wavesize) * wavesize; global_work_size[1] = 4; // reduce factor global_work_size[2] = 1; } diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl.cl b/ggml/src/ggml-opencl/kernels/ggml-opencl.cl index d3cfb2f91e1..8882a8c9c62 100644 --- a/ggml/src/ggml-opencl/kernels/ggml-opencl.cl +++ b/ggml/src/ggml-opencl/kernels/ggml-opencl.cl @@ -1797,6 +1797,9 @@ kernel void kernel_mul_mat_f16_f16( //------------------------------------------------------------------------------ // mul_mat_f16_f32_1row //------------------------------------------------------------------------------ +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif kernel void kernel_mul_mat_f16_f32_1row( global char * src0, ulong offset0, diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl_gemv_noshuffle.cl b/ggml/src/ggml-opencl/kernels/ggml-opencl_gemv_noshuffle.cl index 5e195411d69..ee5c79f000d 100644 --- a/ggml/src/ggml-opencl/kernels/ggml-opencl_gemv_noshuffle.cl +++ b/ggml/src/ggml-opencl/kernels/ggml-opencl_gemv_noshuffle.cl @@ -1,9 +1,11 @@ #pragma OPENCL EXTENSION cl_khr_fp16 : enable #pragma OPENCL EXTENSION cl_khr_subgroups : enable -#pragma OPENCL EXTENSION cl_qcom_subgroup_uniform_load: enable -#pragma OPENCL EXTENSION cl_qcom_subgroup_constant_load: enable -#pragma OPENCL EXTENSION cl_qcom_extra_vector_types : enable + +#ifdef cl_qcom_reqd_sub_group_size #pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#endif // assume #define QK4_0 32 @@ -186,8 +188,9 @@ total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s6; \ total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s7; \ - -__attribute__((qcom_reqd_sub_group_size("full"))) +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif __kernel void kernel_gemv_noshuffle( __read_only image1d_buffer_t src0_q, // quantized A global half2 * src0_d, // A scales diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl_gemv_noshuffle_general.cl b/ggml/src/ggml-opencl/kernels/ggml-opencl_gemv_noshuffle_general.cl index 5bdd4d06763..469d3edef00 100644 --- a/ggml/src/ggml-opencl/kernels/ggml-opencl_gemv_noshuffle_general.cl +++ b/ggml/src/ggml-opencl/kernels/ggml-opencl_gemv_noshuffle_general.cl @@ -1,9 +1,11 @@ #pragma OPENCL EXTENSION cl_khr_fp16 : enable #pragma OPENCL EXTENSION cl_khr_subgroups : enable -#pragma OPENCL EXTENSION cl_qcom_subgroup_uniform_load: enable -#pragma OPENCL EXTENSION cl_qcom_subgroup_constant_load: enable -#pragma OPENCL EXTENSION cl_qcom_extra_vector_types : enable + +#ifdef cl_qcom_reqd_sub_group_size #pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#endif // assume #define QK4_0 32 @@ -186,8 +188,9 @@ total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s6; \ total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s7; \ - -__attribute__((qcom_reqd_sub_group_size("full"))) +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif __kernel void kernel_gemv_noshuffle( __read_only image1d_buffer_t src0_q, // quantized A global half2 * src0_d, // A scales diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl_mul_mat_Ab_Bi_8x4.cl b/ggml/src/ggml-opencl/kernels/ggml-opencl_mul_mat_Ab_Bi_8x4.cl index 57768c80334..ecb577b9933 100644 --- a/ggml/src/ggml-opencl/kernels/ggml-opencl_mul_mat_Ab_Bi_8x4.cl +++ b/ggml/src/ggml-opencl/kernels/ggml-opencl_mul_mat_Ab_Bi_8x4.cl @@ -7,7 +7,16 @@ #pragma OPENCL EXTENSION cl_khr_fp16 : enable #pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable -__attribute__((qcom_reqd_sub_group_size("full"))) +#ifdef cl_qcom_reqd_sub_group_size +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_128 +#endif + kernel void kernel_mul_mat_Ab_Bi_8x4( global const ushort * src0_q, // quantized A global const half * src0_d, // A scales diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_16.cl b/ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_16.cl index d59a0c05ddf..cd4e0afbad2 100644 --- a/ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_16.cl +++ b/ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_16.cl @@ -1,4 +1,6 @@ -// 16-bit transpose, loading/storing an 8x8 tile of elements +// 16-bit transpose, loading/storing a 4x4 tile of elements + +#pragma OPENCL EXTENSION cl_khr_fp16 : enable kernel void kernel_transpose_16( __read_only image1d_buffer_t input, @@ -9,24 +11,16 @@ kernel void kernel_transpose_16( const int i = get_global_id(0); const int j = get_global_id(1); - const int i_3 = i<<3; - const int j_3 = j<<3; + const int i_2 = i<<2; + const int j_2 = j<<2; - ushort8 temp0 = as_ushort8(read_imagef(input, (j_3+0)*cols+i)); - ushort8 temp1 = as_ushort8(read_imagef(input, (j_3+1)*cols+i)); - ushort8 temp2 = as_ushort8(read_imagef(input, (j_3+2)*cols+i)); - ushort8 temp3 = as_ushort8(read_imagef(input, (j_3+3)*cols+i)); - ushort8 temp4 = as_ushort8(read_imagef(input, (j_3+4)*cols+i)); - ushort8 temp5 = as_ushort8(read_imagef(input, (j_3+5)*cols+i)); - ushort8 temp6 = as_ushort8(read_imagef(input, (j_3+6)*cols+i)); - ushort8 temp7 = as_ushort8(read_imagef(input, (j_3+7)*cols+i)); + half4 temp0 = read_imageh(input, (j_2+0)*cols+i); + half4 temp1 = read_imageh(input, (j_2+1)*cols+i); + half4 temp2 = read_imageh(input, (j_2+2)*cols+i); + half4 temp3 = read_imageh(input, (j_2+3)*cols+i); - write_imagef(output, (i_3+0)*rows+j, as_float4((ushort8)(temp0.s0, temp1.s0, temp2.s0, temp3.s0, temp4.s0, temp5.s0, temp6.s0, temp7.s0))); - write_imagef(output, (i_3+1)*rows+j, as_float4((ushort8)(temp0.s1, temp1.s1, temp2.s1, temp3.s1, temp4.s1, temp5.s1, temp6.s1, temp7.s1))); - write_imagef(output, (i_3+2)*rows+j, as_float4((ushort8)(temp0.s2, temp1.s2, temp2.s2, temp3.s2, temp4.s2, temp5.s2, temp6.s2, temp7.s2))); - write_imagef(output, (i_3+3)*rows+j, as_float4((ushort8)(temp0.s3, temp1.s3, temp2.s3, temp3.s3, temp4.s3, temp5.s3, temp6.s3, temp7.s3))); - write_imagef(output, (i_3+4)*rows+j, as_float4((ushort8)(temp0.s4, temp1.s4, temp2.s4, temp3.s4, temp4.s4, temp5.s4, temp6.s4, temp7.s4))); - write_imagef(output, (i_3+5)*rows+j, as_float4((ushort8)(temp0.s5, temp1.s5, temp2.s5, temp3.s5, temp4.s5, temp5.s5, temp6.s5, temp7.s5))); - write_imagef(output, (i_3+6)*rows+j, as_float4((ushort8)(temp0.s6, temp1.s6, temp2.s6, temp3.s6, temp4.s6, temp5.s6, temp6.s6, temp7.s6))); - write_imagef(output, (i_3+7)*rows+j, as_float4((ushort8)(temp0.s7, temp1.s7, temp2.s7, temp3.s7, temp4.s7, temp5.s7, temp6.s7, temp7.s7))); + write_imageh(output, (i_2+0)*rows+j, (half4)(temp0.s0, temp1.s0, temp2.s0, temp3.s0)); + write_imageh(output, (i_2+1)*rows+j, (half4)(temp0.s1, temp1.s1, temp2.s1, temp3.s1)); + write_imageh(output, (i_2+2)*rows+j, (half4)(temp0.s2, temp1.s2, temp2.s2, temp3.s2)); + write_imageh(output, (i_2+3)*rows+j, (half4)(temp0.s3, temp1.s3, temp2.s3, temp3.s3)); } From 351801a1c5544298e621ab6c8a44565e3d9e253e Mon Sep 17 00:00:00 2001 From: Gian-Carlo Pascutto Date: Tue, 25 Feb 2025 10:27:58 +0100 Subject: [PATCH 56/58] metal : copy kernels for quant to F32/F16 conversions (llama/12017) metal: use dequantize_q templates --------- Co-authored-by: Georgi Gerganov --- ggml/src/ggml-metal/ggml-metal.m | 82 ++++++++++++++++++++++++++-- ggml/src/ggml-metal/ggml-metal.metal | 43 +++++++++++++++ 2 files changed, 120 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index 087e7f58149..c550142a7d0 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -407,6 +407,16 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, + GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32, + GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F16, + GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32, + GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F16, + GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32, + GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F16, + GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32, + GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F16, + GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32, + GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16, GGML_METAL_KERNEL_TYPE_CONCAT, GGML_METAL_KERNEL_TYPE_SQR, GGML_METAL_KERNEL_TYPE_SQRT, @@ -1012,6 +1022,16 @@ @implementation GGMLMetalClass GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32, cpy_q4_0_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F16, cpy_q4_0_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32, cpy_q4_1_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F16, cpy_q4_1_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32, cpy_q5_0_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F16, cpy_q5_0_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32, cpy_q5_1_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F16, cpy_q5_1_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32, cpy_q8_0_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16, cpy_q8_0_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQRT, sqrt, true); @@ -1287,6 +1307,18 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex default: return false; } + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + switch (op->type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + return true; + default: + return false; + } default: return false; }; @@ -3899,10 +3931,6 @@ static void ggml_metal_encode_node( case GGML_OP_CPY: case GGML_OP_CONT: { - GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0); - - int nth = MIN(1024, ne00/ggml_blck_size(src0->type)); - id pipeline = nil; switch (src0t) { @@ -3936,7 +3964,47 @@ static void ggml_metal_encode_node( switch (dstt) { case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_BF16_F32].pipeline; break; case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16].pipeline; break; - default: GGML_ASSERT(false && "not implemented"); + default: GGML_ABORT("not implemented"); + }; + } break; + case GGML_TYPE_Q4_0: + { + switch (dstt) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F16].pipeline; break; + default: GGML_ABORT("not implemented"); + }; + } break; + case GGML_TYPE_Q4_1: + { + switch (dstt) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F16].pipeline; break; + default: GGML_ABORT("not implemented"); + }; + } break; + case GGML_TYPE_Q5_0: + { + switch (dstt) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F16].pipeline; break; + default: GGML_ABORT("not implemented"); + }; + } break; + case GGML_TYPE_Q5_1: + { + switch (dstt) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F16].pipeline; break; + default: GGML_ABORT("not implemented"); + }; + } break; + case GGML_TYPE_Q8_0: + { + switch (dstt) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16].pipeline; break; + default: GGML_ABORT("not implemented"); }; } break; default: GGML_ABORT("not implemented"); @@ -3966,7 +4034,11 @@ static void ggml_metal_encode_node( [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0); + int nth = MIN(1024, ne00/ggml_blck_size(src0->type)); + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; case GGML_OP_SET: { diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 83e7ac9f411..d092a169061 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -4341,6 +4341,49 @@ kernel void kernel_cpy_f32_iq4_nl( } } +template +kernel void kernel_cpy_q_f32( + constant ggml_metal_kargs_cpy & args, + device const char * src0, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int i03 = tgpig[2]; + const int i02 = tgpig[1]; + const int i01 = tgpig[0]; + + const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; + + const int64_t i3 = n/(args.ne2*args.ne1*args.ne0); + const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0); + const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0; + const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0); + + device const block_q * src_data = (device const block_q *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01); + device T4x4 * dst_data = (device T4x4 *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); + + for (int64_t i00 = tpitg.x; i00 < args.ne00/16; i00 += ntg.x) { + T4x4 temp; + dequantize_func(src_data + i00/nl, i00%nl, temp); + dst_data[i00] = temp; + } +} + +typedef decltype(kernel_cpy_q_f32) cpy_q_f_t; + +template [[host_name("kernel_cpy_q4_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32; +template [[host_name("kernel_cpy_q4_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32; +template [[host_name("kernel_cpy_q5_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32; +template [[host_name("kernel_cpy_q5_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32; +template [[host_name("kernel_cpy_q8_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32; + +template [[host_name("kernel_cpy_q4_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32; +template [[host_name("kernel_cpy_q4_1_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32; +template [[host_name("kernel_cpy_q5_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32; +template [[host_name("kernel_cpy_q5_1_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32; +template [[host_name("kernel_cpy_q8_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32; + kernel void kernel_concat( constant ggml_metal_kargs_concat & args, device const char * src0, From eaa7e836e6595166323a38c213fa96ec519b142f Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Tue, 25 Feb 2025 18:06:34 +0530 Subject: [PATCH 57/58] Support pure float16 add/sub/mul/div operations in the CUDA (and CPU) backend (ggml/1121) * Support float16-to-float16 add/sub/mul/div operations in the CUDA backend * Add fp16 support for add/sub/mul/div on the CPU backend * Add test cases for fp16 add/sub/mul/div --- ggml/src/ggml-cpu/ggml-cpu.c | 213 +++++++++++++++++++++++++++++++-- ggml/src/ggml-cuda/binbcast.cu | 6 +- 2 files changed, 206 insertions(+), 13 deletions(-) diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 723253495a7..33ab5e9c6e7 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -1415,15 +1415,35 @@ inline static void ggml_vec_cpy_i32(const int n, int32_t * y, const int32_t * x) inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; } inline static void ggml_vec_set_bf16(const int n, ggml_bf16_t * x, const ggml_bf16_t v) { for (int i = 0; i < n; ++i) x[i] = v; } inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] + y[i]; } +inline static void ggml_vec_add_f16 (const int n, ggml_fp16_t * z, const ggml_fp16_t * x, const ggml_fp16_t * y) { + for (int i = 0; i < n; ++i) { + z[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(x[i]) + GGML_FP16_TO_FP32(y[i])); + } +} inline static void ggml_vec_add1_f32(const int n, float * z, const float * x, const float v) { for (int i = 0; i < n; ++i) z[i] = x[i] + v; } inline static void ggml_vec_acc_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] += x[i]; } inline static void ggml_vec_acc1_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] += v; } inline static void ggml_vec_sub_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] - y[i]; } +inline static void ggml_vec_sub_f16 (const int n, ggml_fp16_t * z, const ggml_fp16_t * x, const ggml_fp16_t * y) { + for (int i = 0; i < n; ++i) { + z[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(x[i]) - GGML_FP16_TO_FP32(y[i])); + } +} inline static void ggml_vec_set_f32 (const int n, float * x, const float v) { for (int i = 0; i < n; ++i) x[i] = v; } inline static void ggml_vec_cpy_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]; } inline static void ggml_vec_neg_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = -x[i]; } inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]*y[i]; } +inline static void ggml_vec_mul_f16 (const int n, ggml_fp16_t * z, const ggml_fp16_t * x, const ggml_fp16_t * y) { + for (int i = 0; i < n; ++i) { + z[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(x[i]) * GGML_FP16_TO_FP32(y[i])); + } +} inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]/y[i]; } +inline static void ggml_vec_div_f16 (const int n, ggml_fp16_t * z, const ggml_fp16_t * x, const ggml_fp16_t * y) { + for (int i = 0; i < n; ++i) { + z[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(x[i]) / GGML_FP16_TO_FP32(y[i])); + } +} static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * restrict x, size_t bx, const float * restrict y, size_t by, int nrc) { assert(nrc == 1); @@ -4379,7 +4399,7 @@ static void ggml_compute_forward_add_f16_f16( const struct ggml_tensor * src0 = dst->src[0]; const struct ggml_tensor * src1 = dst->src[1]; - GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); + GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst)); const int ith = params->ith; const int nth = params->nth; @@ -4404,17 +4424,22 @@ static void ggml_compute_forward_add_f16_f16( if (nb10 == sizeof(ggml_fp16_t)) { for (int ir = ir0; ir < ir1; ++ir) { - // src0, src1 and dst are same shape => same indices - const int i3 = ir/(ne2*ne1); - const int i2 = (ir - i3*ne2*ne1)/ne1; - const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + // src1 is broadcastable across src0 and dst in i1, i2, i3 + const int64_t i03 = ir/(ne02*ne01); + const int64_t i02 = (ir - i03*ne02*ne01)/ne01; + const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); - ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1); - ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); - ggml_fp16_t * src1_ptr = (ggml_fp16_t *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11); + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + const int64_t nr0 = ne00 / ne10; - for (int i = 0; i < ne0; i++) { - dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + GGML_FP16_TO_FP32(src1_ptr[i])); + ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); + ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); + ggml_fp16_t * src1_ptr = (ggml_fp16_t *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11); + + for (int64_t r = 0; r < nr0; ++r) { + ggml_vec_add_f16(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr); } } } @@ -5202,6 +5227,62 @@ static void ggml_compute_forward_sub_f32( } } +static void ggml_compute_forward_sub_f16( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + assert(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst)); + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = ggml_nrows(src0); + + GGML_TENSOR_BINARY_OP_LOCALS + + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F16); + + GGML_ASSERT( nb0 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + if (nb10 == sizeof(ggml_fp16_t)) { + for (int ir = ir0; ir < ir1; ++ir) { + // src1 is broadcastable across src0 and dst in i1, i2, i3 + const int64_t i03 = ir/(ne02*ne01); + const int64_t i02 = (ir - i03*ne02*ne01)/ne01; + const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + const int64_t nr0 = ne00 / ne10; + + ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); + ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); + ggml_fp16_t * src1_ptr = (ggml_fp16_t *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11); + + for (int64_t r = 0; r < nr0; ++r) { + ggml_vec_sub_f16(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr); + } + } + } else { + // src1 is not contiguous + GGML_ABORT("unimplemented error"); + } +} + static void ggml_compute_forward_sub( const struct ggml_compute_params * params, struct ggml_tensor * dst) { @@ -5213,6 +5294,10 @@ static void ggml_compute_forward_sub( { ggml_compute_forward_sub_f32(params, dst); } break; + case GGML_TYPE_F16: + { + ggml_compute_forward_sub_f16(params, dst); + } break; default: { GGML_ABORT("fatal error"); @@ -5293,6 +5378,55 @@ static void ggml_compute_forward_mul_f32( } } +static void ggml_compute_forward_mul_f16( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst)); + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t nr = ggml_nrows(src0); + + GGML_TENSOR_BINARY_OP_LOCALS + + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F16); + + GGML_ASSERT( nb0 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); + + if (nb10 == sizeof(ggml_fp16_t)) { + for (int64_t ir = ith; ir < nr; ir += nth) { + // src0 and dst are same shape => same indices + const int64_t i03 = ir/(ne02*ne01); + const int64_t i02 = (ir - i03*ne02*ne01)/ne01; + const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + const int64_t nr0 = ne00 / ne10; + + ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); + ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); + ggml_fp16_t * src1_ptr = (ggml_fp16_t *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11); + + for (int64_t r = 0 ; r < nr0; ++r) { + ggml_vec_mul_f16(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr); + } + } + } else { + // src1 is not contiguous + GGML_ABORT("unimplemented error"); + } +} + static void ggml_compute_forward_mul( const struct ggml_compute_params * params, struct ggml_tensor * dst) { @@ -5300,13 +5434,17 @@ static void ggml_compute_forward_mul( const struct ggml_tensor * src0 = dst->src[0]; const struct ggml_tensor * src1 = dst->src[1]; - GGML_ASSERT(src1->type == GGML_TYPE_F32 && "only f32 src1 supported for now"); + GGML_ASSERT((src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16) && "only f32/f16 src1 supported for now"); switch (src0->type) { case GGML_TYPE_F32: { ggml_compute_forward_mul_f32(params, dst); } break; + case GGML_TYPE_F16: + { + ggml_compute_forward_mul_f16(params, dst); + } break; default: { GGML_ABORT("fatal error"); @@ -5387,6 +5525,55 @@ static void ggml_compute_forward_div_f32( } } +static void ggml_compute_forward_div_f16( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst)); + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t nr = ggml_nrows(src0); + + GGML_TENSOR_BINARY_OP_LOCALS + + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F16); + + GGML_ASSERT( nb0 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); + + if (nb10 == sizeof(ggml_fp16_t)) { + for (int64_t ir = ith; ir < nr; ir += nth) { + // src0 and dst are same shape => same indices + const int64_t i03 = ir/(ne02*ne01); + const int64_t i02 = (ir - i03*ne02*ne01)/ne01; + const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + const int64_t nr0 = ne00 / ne10; + + ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); + ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); + ggml_fp16_t * src1_ptr = (ggml_fp16_t *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11); + + for (int64_t r = 0; r < nr0; ++r) { + ggml_vec_div_f16(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr); + } + } + } else { + // src1 is not contiguous + GGML_ABORT("unimplemented error"); + } +} + static void ggml_compute_forward_div( const struct ggml_compute_params * params, struct ggml_tensor * dst) { @@ -5398,6 +5585,10 @@ static void ggml_compute_forward_div( { ggml_compute_forward_div_f32(params, dst); } break; + case GGML_TYPE_F16: + { + ggml_compute_forward_div_f16(params, dst); + } break; default: { GGML_ABORT("fatal error"); diff --git a/ggml/src/ggml-cuda/binbcast.cu b/ggml/src/ggml-cuda/binbcast.cu index ce4b9cfb51b..e1fbf0e1366 100644 --- a/ggml/src/ggml-cuda/binbcast.cu +++ b/ggml/src/ggml-cuda/binbcast.cu @@ -294,11 +294,13 @@ static void ggml_cuda_op_bin_bcast( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const void * src0_dd, const void * src1_dd, void * dst_dd, cudaStream_t stream) { - GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { op()(src0, src1, dst, (const float *)src0_dd, (const float *)src1_dd, (float *)dst_dd, stream); - } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { + } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { + op()(src0, src1, dst, (const half *) src0_dd, (const half *)src1_dd, (half *) dst_dd, stream); + } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) { op()(src0, src1, dst, (const half *) src0_dd, (const float *)src1_dd, (half *) dst_dd, stream); } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) { op()(src0, src1, dst, (const half *) src0_dd, (const float *)src1_dd, (float *)dst_dd, stream); From ee6ab0bda7a6cc17b82428f66cc5542e0e0d6e01 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 26 Feb 2025 22:39:12 +0200 Subject: [PATCH 58/58] sync : ggml --- scripts/sync-ggml.last | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index 26a105f64f6..bf644cd86a5 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -694244a6e40dc255f6bb4376fb17431c06633e6c +738a3aea59f1c0c7751d65307d1228c1dbbf6a84