From 051ff11140c910ab905f9c82b50b772ec6f44e84 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 9 Nov 2024 15:28:55 +0200 Subject: [PATCH 01/20] metal : add kernel arg structs (wip) --- ggml/src/ggml-common.h | 30 ++++++ ggml/src/ggml-metal/ggml-metal.m | 68 +++++++------ ggml/src/ggml-metal/ggml-metal.metal | 143 +++++++++------------------ 3 files changed, 116 insertions(+), 125 deletions(-) diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index 050161393456e..f3d12df5176d6 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -418,6 +418,36 @@ typedef struct { } block_iq4_xs; static_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/64 + QK_K/2, "wrong iq4_xs block size/padding"); +#if defined(GGML_COMMON_DECL_METAL_KARGS) +typedef struct { + int32_t ne00; + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne0; + int32_t ne1; + int32_t ne2; + int32_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; + int32_t n_past; + int32_t n_dims; + int32_t n_ctx_orig; + float freq_base; + float freq_scale; + float ext_factor; + float attn_factor; + float beta_fast; + float beta_slow; +} ggml_metal_kargs_rope; +#endif + #endif // GGML_COMMON_DECL #endif // GGML_COMMON_DECL diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index 95b21fbf9c503..7282db0dd0bda 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -3,6 +3,10 @@ #import "ggml-impl.h" #import "ggml-backend-impl.h" +#define GGML_COMMON_DECL_C +#define GGML_COMMON_DECL_METAL_KARGS +#include "ggml-common.h" + #import #import @@ -2706,40 +2710,44 @@ static void ggml_metal_encode_node( }; } + ggml_metal_kargs_rope args = { + .ne00 = ne00, + .ne01 = ne01, + .ne02 = ne02, + .ne03 = ne03, + .nb00 = nb00, + .nb01 = nb01, + .nb02 = nb02, + .nb03 = nb03, + .ne0 = ne0, + .ne1 = ne1, + .ne2 = ne2, + .ne3 = ne3, + .nb0 = nb0, + .nb1 = nb1, + .nb2 = nb2, + .nb3 = nb3, + .n_past = n_past, + .n_dims = n_dims, + .n_ctx_orig = n_ctx_orig, + .freq_base = freq_base, + .freq_scale = freq_scale, + .ext_factor = ext_factor, + .attn_factor = attn_factor, + .beta_fast = beta_fast, + .beta_slow = beta_slow, + }; + [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; if (id_src2 != nil) { - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; } else { - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:2]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:2]; } - [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:4]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6]; - [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7]; - [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:8]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:10]; - [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:11]; - [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:12]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:14]; - [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:15]; - [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:16]; - [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:17]; - [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:18]; - [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:19]; - [encoder setBytes:&n_past length:sizeof( int) atIndex:20]; - [encoder setBytes:&n_dims length:sizeof( int) atIndex:21]; - [encoder setBytes:&n_ctx_orig length:sizeof( int) atIndex:22]; - [encoder setBytes:&freq_base length:sizeof( float) atIndex:23]; - [encoder setBytes:&freq_scale length:sizeof( float) atIndex:24]; - [encoder setBytes:&ext_factor length:sizeof( float) atIndex:25]; - [encoder setBytes:&attn_factor length:sizeof( float) atIndex:26]; - [encoder setBytes:&beta_fast length:sizeof( float) atIndex:27]; - [encoder setBytes:&beta_slow length:sizeof( float) atIndex:28]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; + [encoder setBytes:&args length:sizeof(args) atIndex:4]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } break; diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 8c7fcb11303b8..f1c95b1210abe 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1,4 +1,5 @@ #define GGML_COMMON_DECL_METAL +#define GGML_COMMON_DECL_METAL_KARGS #define GGML_COMMON_IMPL_METAL #if defined(GGML_METAL_EMBED_LIBRARY) __embed_ggml-common.h__ @@ -2234,7 +2235,7 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) { // YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng. static void rope_yarn( - float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale, + float theta_extrap, float freq_scale, float corr_dims[2], int i0, float ext_factor, float mscale, thread float * cos_theta, thread float * sin_theta) { // Get n-d rotational scaling corrected for extrapolation float theta_interp = freq_scale * theta_extrap; @@ -2266,65 +2267,41 @@ static void rope_yarn_corr_dims( template kernel void kernel_rope_norm( - device const void * src0, - device const int32_t * src1, - device const float * src2, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - constant int & n_past, - constant int & n_dims, - constant int & n_ctx_orig, - constant float & freq_base, - constant float & freq_scale, - constant float & ext_factor, - constant float & attn_factor, - constant float & beta_fast, - constant float & beta_slow, + device const char * src0, + device const char * src1, + device const char * src2, + device char * dst, + constant ggml_metal_kargs_rope & args, uint tiitg[[thread_index_in_threadgroup]], - uint3 tptg[[threads_per_threadgroup]], + uint3 tptg [[threads_per_threadgroup]], uint3 tgpig[[threadgroup_position_in_grid]]) { - const int64_t i3 = tgpig[2]; - const int64_t i2 = tgpig[1]; - const int64_t i1 = tgpig[0]; + const int i3 = tgpig[2]; + const int i2 = tgpig[1]; + const int i1 = tgpig[0]; float corr_dims[2]; - rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); + rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims); - device const int32_t * pos = src1; + device const int32_t * pos = (device const int32_t *) src1; const float theta_base = (float) pos[i2]; - const float inv_ndims = -1.f/n_dims; + const float inv_ndims = -1.f/args.n_dims; float cos_theta; float sin_theta; - for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) { - if (i0 < n_dims) { - const int64_t ic = i0/2; + for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) { + if (i0 < args.n_dims) { + const int ic = i0/2; - const float theta = theta_base * pow(freq_base, inv_ndims*i0); + const float theta = theta_base * pow(args.freq_base, inv_ndims*i0); - const float freq_factor = src2 != src0 ? src2[ic] : 1.0f; + const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f; - rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); + rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta); - device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00); + device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); const float x0 = src[0]; const float x1 = src[1]; @@ -2332,8 +2309,8 @@ kernel void kernel_rope_norm( dst_data[0] = x0*cos_theta - x1*sin_theta; dst_data[1] = x0*sin_theta + x1*cos_theta; } else { - device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00); + device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); dst_data[0] = src[0]; dst_data[1] = src[1]; @@ -2343,74 +2320,50 @@ kernel void kernel_rope_norm( template kernel void kernel_rope_neox( - device const void * src0, - device const int32_t * src1, - device const float * src2, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - constant int & n_past, - constant int & n_dims, - constant int & n_ctx_orig, - constant float & freq_base, - constant float & freq_scale, - constant float & ext_factor, - constant float & attn_factor, - constant float & beta_fast, - constant float & beta_slow, + device const char * src0, + device const char * src1, + device const char * src2, + device char * dst, + constant ggml_metal_kargs_rope & args, uint tiitg[[thread_index_in_threadgroup]], uint3 tptg[[threads_per_threadgroup]], uint3 tgpig[[threadgroup_position_in_grid]]) { - const int64_t i3 = tgpig[2]; - const int64_t i2 = tgpig[1]; - const int64_t i1 = tgpig[0]; + const int i3 = tgpig[2]; + const int i2 = tgpig[1]; + const int i1 = tgpig[0]; float corr_dims[2]; - rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); + rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims); - device const int32_t * pos = src1; + device const int32_t * pos = (device const int32_t *) src1; const float theta_base = (float) pos[i2]; - const float inv_ndims = -1.f/n_dims; + const float inv_ndims = -1.f/args.n_dims; float cos_theta; float sin_theta; - for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) { - if (i0 < n_dims) { - const int64_t ic = i0/2; + for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) { + if (i0 < args.n_dims) { + const int ic = i0/2; - const float theta = theta_base * pow(freq_base, inv_ndims*i0); + const float theta = theta_base * pow(args.freq_base, inv_ndims*i0); - const float freq_factor = src2 != src0 ? src2[ic] : 1.0f; + const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f; - rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); + rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta); - device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); - device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); + device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00); + device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + ic*args.nb0); const float x0 = src[0]; - const float x1 = src[n_dims/2]; + const float x1 = src[args.n_dims/2]; - dst_data[0] = x0*cos_theta - x1*sin_theta; - dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta; + dst_data[0] = x0*cos_theta - x1*sin_theta; + dst_data[args.n_dims/2] = x0*sin_theta + x1*cos_theta; } else { - device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00); + device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); dst_data[0] = src[0]; dst_data[1] = src[1]; From 362a3f3433fdb3597768752bf5ac6bc60307f254 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 9 Nov 2024 16:09:31 +0200 Subject: [PATCH 02/20] metal : fattn args ggml-ci --- ggml/src/ggml-common.h | 24 +++++ ggml/src/ggml-metal/ggml-metal.m | 58 +++++----- ggml/src/ggml-metal/ggml-metal.metal | 156 ++++++++++----------------- 3 files changed, 113 insertions(+), 125 deletions(-) diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index f3d12df5176d6..6529d71ebeb7f 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -446,6 +446,30 @@ typedef struct { float beta_fast; float beta_slow; } ggml_metal_kargs_rope; + +typedef struct { + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne11; + int32_t ne_12_2; // assume K and V are same shape + int32_t ne_12_3; + uint64_t nb_12_1; + uint64_t nb_12_2; + uint64_t nb_12_3; + uint64_t nb31; + int32_t ne1; + int32_t ne2; + float scale; + float max_bias; + float m0; + float m1; + uint16_t n_head_log2; + float logit_softcap; +} ggml_metal_kargs_flash_attn_ext; #endif #endif // GGML_COMMON_DECL diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index 7282db0dd0bda..8ee9c67916537 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -3232,37 +3232,41 @@ static void ggml_metal_encode_node( } } + ggml_metal_kargs_flash_attn_ext args = { + .ne01 = ne01, + .ne02 = ne02, + .ne03 = ne03, + .nb01 = nb01, + .nb02 = nb02, + .nb03 = nb03, + .ne11 = ne11, + .ne_12_2 = ne12, + .ne_12_3 = ne13, + .nb_12_1 = nb11, + .nb_12_2 = nb12, + .nb_12_3 = nb13, + .nb31 = nb31, + .ne1 = ne1, + .ne2 = ne2, + .scale = scale, + .max_bias = max_bias, + .m0 = m0, + .m1 = m1, + .n_head_log2 = n_head_log2, + .logit_softcap = logit_softcap, + }; + [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; if (id_src3) { - [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3]; + [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3]; } else { - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:3]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:3]; } - [encoder setBuffer:id_dst offset:offs_dst atIndex:4]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6]; - [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10]; - [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:11]; - [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:12]; - [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:14]; - [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:15]; - [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:16]; - [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:17]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:18]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:19]; - [encoder setBytes:&scale length:sizeof( float) atIndex:20]; - [encoder setBytes:&max_bias length:sizeof( float) atIndex:21]; - [encoder setBytes:&m0 length:sizeof(m0) atIndex:22]; - [encoder setBytes:&m1 length:sizeof(m1) atIndex:23]; - [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:24]; - [encoder setBytes:&logit_softcap length:sizeof(logit_softcap) atIndex:25]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:4]; + [encoder setBytes:&args length:sizeof(args) atIndex:5]; if (!use_vec_kernel) { // half8x8 kernel diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index f1c95b1210abe..24c714eaebf7e 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -2272,9 +2272,9 @@ kernel void kernel_rope_norm( device const char * src2, device char * dst, constant ggml_metal_kargs_rope & args, - uint tiitg[[thread_index_in_threadgroup]], - uint3 tptg [[threads_per_threadgroup]], - uint3 tgpig[[threadgroup_position_in_grid]]) { + ushort tiitg[[thread_index_in_threadgroup]], + ushort3 tptg [[threads_per_threadgroup]], + uint3 tgpig[[threadgroup_position_in_grid]]) { const int i3 = tgpig[2]; const int i2 = tgpig[1]; const int i1 = tgpig[0]; @@ -2325,9 +2325,9 @@ kernel void kernel_rope_neox( device const char * src2, device char * dst, constant ggml_metal_kargs_rope & args, - uint tiitg[[thread_index_in_threadgroup]], - uint3 tptg[[threads_per_threadgroup]], - uint3 tgpig[[threadgroup_position_in_grid]]) { + ushort tiitg[[thread_index_in_threadgroup]], + ushort3 tptg [[threads_per_threadgroup]], + uint3 tgpig[[threadgroup_position_in_grid]]) { const int i3 = tgpig[2]; const int i2 = tgpig[1]; const int i1 = tgpig[0]; @@ -2766,32 +2766,12 @@ kernel void kernel_flash_attn_ext( device const char * v, device const char * mask, device float * dst, - constant int32_t & ne01, - constant int32_t & ne02, - constant int32_t & ne03, - constant uint32_t & nb01, - constant uint32_t & nb02, - constant uint32_t & nb03, - constant int32_t & ne11, - constant int32_t & ne_12_2, // assume K and V are same shape - constant int32_t & ne_12_3, - constant uint32_t & nb_12_1, - constant uint32_t & nb_12_2, - constant uint32_t & nb_12_3, - constant uint32_t & nb31, - constant int32_t & ne1, - constant int32_t & ne2, - constant float & scale, - constant float & max_bias, - constant float & m0, - constant float & m1, - constant uint16_t & n_head_log2, - constant float & logit_softcap, + constant ggml_metal_kargs_flash_attn_ext & args, threadgroup half * shared [[threadgroup(0)]], - ushort3 tgpig[[threadgroup_position_in_grid]], - ushort3 ntg[[threads_per_threadgroup]], - ushort tiisg[[thread_index_in_simdgroup]], - ushort sgitg[[simdgroup_index_in_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 ntg[[threads_per_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { const short nsg = ntg.y; // number of simdgroups const int iq3 = tgpig[2]; @@ -2824,10 +2804,10 @@ kernel void kernel_flash_attn_ext( // load heads from Q to shared memory for (short j = sgitg; j < Q; j += nsg) { - device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)); + device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*args.nb01 + iq2*args.nb02 + iq3*args.nb03)); for (short i = tiisg; i < D4; i += NW) { - if (iq1 + j < ne01) { + if (iq1 + j < args.ne01) { sq4[j*D4 + i] = (q4_t) q4[i]; } else { sq4[j*D4 + i] = (q4_t) 0.0f; @@ -2860,11 +2840,11 @@ kernel void kernel_flash_attn_ext( const short ty = tiisg/4; // broadcast kv - //const short rk2 = ne02/ne12; - //const short rk3 = ne03/ne13; + //const short rk2 = args.ne02/args.ne12; + //const short rk3 = args.ne03/args.ne13; - const short ikv2 = iq2/(ne02/ne_12_2); - const short ikv3 = iq3/(ne03/ne_12_3); + const short ikv2 = iq2/(args.ne02/args.ne_12_2); + const short ikv3 = iq3/(args.ne03/args.ne_12_3); // load the queries from shared memory into local memory q8x8_t mq[D8]; @@ -2878,20 +2858,20 @@ kernel void kernel_flash_attn_ext( half slope = 1.0f; // ALiBi - if (max_bias > 0.0f) { + if (args.max_bias > 0.0f) { const short h = iq2; - const half base = h < n_head_log2 ? m0 : m1; - const short exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + const half base = h < args.n_head_log2 ? args.m0 : args.m1; + const short exph = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1; slope = pow(base, exph); } // loop over the KV cache // each simdgroup handles blocks of Q rows and C columns - for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) { + for (int ic0 = 0; ic0 < args.ne11; ic0 += C*nsg) { const int ic = ic0 + C*sgitg; - if (ic >= ne11) { + if (ic >= args.ne11) { break; } @@ -2902,7 +2882,7 @@ kernel void kernel_flash_attn_ext( // load the mask in shared memory #pragma unroll(Q) for (short j = 0; j < Q; ++j) { - device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*nb31); + device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31); const half m = pm[ic + tiisg]; @@ -2925,18 +2905,18 @@ kernel void kernel_flash_attn_ext( // this is compile-time check, so it does not have runtime overhead if (is_same::value) { // we can read directly from global memory - device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3)); + device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3)); #pragma unroll(D8) for (short i = 0; i < D8; ++i) { k8x8_t mk; - simdgroup_load(mk, pk + i*8, nb_12_1/sizeof(k_t), 0, true); // transpose // TODO: use ne10 + simdgroup_load(mk, pk + i*8, args.nb_12_1/sizeof(k_t), 0, true); // transpose // TODO: use ne10 simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk); } } else { for (short ii = 0; ii < D16; ii += 4) { - device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3)); + device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3)); if (D16%4 == 0) { // the head is evenly divisible by 4*16 = 64, so no need for bound checks @@ -2995,10 +2975,10 @@ kernel void kernel_flash_attn_ext( const half m = M[j]; // scale and apply the logitcap / mask - half s = ss[j*TS + tiisg]*scale; + half s = ss[j*TS + tiisg]*args.scale; - if (logit_softcap != 0.0f) { - s = logit_softcap*precise::tanh(s); + if (args.logit_softcap != 0.0f) { + s = args.logit_softcap*precise::tanh(s); } // mqk = mqk + mask*slope @@ -3040,18 +3020,18 @@ kernel void kernel_flash_attn_ext( if (is_same::value) { // we can read directly from global memory - device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3)); + device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3)); #pragma unroll(D8) for (short i = 0; i < D8; ++i) { v8x8_t mv; - simdgroup_load(mv, pv + i*8, nb_12_1/sizeof(v_t), 0, false); // TODO: use ne20 + simdgroup_load(mv, pv + i*8, args.nb_12_1/sizeof(v_t), 0, false); // TODO: use ne20 simdgroup_multiply_accumulate(lo[i], ms, mv, lo[i]); } } else { for (short ii = 0; ii < D16; ii += 4) { - device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3)); + device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3)); if (D16%4 == 0) { // no need for bound checks @@ -3180,11 +3160,11 @@ kernel void kernel_flash_attn_ext( // final rescale with 1/S and store to global memory if (sgitg == 0) { - for (short j = 0; j < Q && iq1 + j < ne01; ++j) { + for (short j = 0; j < Q && iq1 + j < args.ne01; ++j) { const float S = ss[j*TS + 0]; for (short i = tiisg; i < D4; i += NW) { - dst4[((int64_t)iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) so4[j*D4 + i]/S; + dst4[((int64_t)iq3*args.ne2*args.ne1 + iq2 + (iq1 + j)*args.ne1)*D4 + i] = (float4) so4[j*D4 + i]/S; } } } @@ -3281,33 +3261,13 @@ kernel void kernel_flash_attn_ext_vec( device const char * v, device const char * mask, device float * dst, - constant int32_t & ne01, - constant int32_t & ne02, - constant int32_t & ne03, - constant uint32_t & nb01, - constant uint32_t & nb02, - constant uint32_t & nb03, - constant int32_t & ne11, - constant int32_t & ne_12_2, // assume K and V are same shape - constant int32_t & ne_12_3, - constant uint32_t & nb_12_1, - constant uint32_t & nb_12_2, - constant uint32_t & nb_12_3, - constant uint32_t & nb31, - constant int32_t & ne1, - constant int32_t & ne2, - constant float & scale, - constant float & max_bias, - constant float & m0, - constant float & m1, - constant uint16_t & n_head_log2, - constant float & logit_softcap, + constant ggml_metal_kargs_flash_attn_ext & args, threadgroup half * shared [[threadgroup(0)]], - ushort3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 ntg[[threads_per_threadgroup]], - ushort tiisg[[thread_index_in_simdgroup]], - ushort sgitg[[simdgroup_index_in_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { const short nsg = ntg.y; // number of simdgroups const int iq3 = tgpig[2]; @@ -3334,10 +3294,10 @@ kernel void kernel_flash_attn_ext_vec( o4x4_t lo[D16/NL]; // load heads from Q to shared memory - device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)); + device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*args.nb01 + iq2*args.nb02 + iq3*args.nb03)); for (short i = tiisg; i < D4; i += NW) { - if (iq1 < ne01) { + if (iq1 < args.ne01) { sq4[i] = (q4_t) q4[i]; } else { sq4[i] = (q4_t) 0.0f; @@ -3365,11 +3325,11 @@ kernel void kernel_flash_attn_ext_vec( const short ty = tiisg/NL; // broadcast kv - //const short rk2 = ne02/ne12; - //const short rk3 = ne03/ne13; + //const short rk2 = args.ne02/args.ne12; + //const short rk3 = args.ne03/args.ne13; - const short ikv2 = iq2/(ne02/ne_12_2); - const short ikv3 = iq3/(ne03/ne_12_3); + const short ikv2 = iq2/(args.ne02/args.ne_12_2); + const short ikv3 = iq3/(args.ne03/args.ne_12_3); // load the queries from shared memory into local memory q4x4_t mq[D16/NL]; @@ -3382,25 +3342,25 @@ kernel void kernel_flash_attn_ext_vec( const bool has_mask = mask != q; // pointer to the mask - device const half * pm = (device const half *) (mask + iq1*nb31); + device const half * pm = (device const half *) (mask + iq1*args.nb31); half slope = 1.0f; // ALiBi - if (max_bias > 0.0f) { + if (args.max_bias > 0.0f) { const short h = iq2; - const half base = h < n_head_log2 ? m0 : m1; - const short exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + const half base = h < args.n_head_log2 ? args.m0 : args.m1; + const short exph = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1; slope = pow(base, exph); } // loop over the KV cache // each simdgroup handles blocks of Q rows and C columns - for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) { + for (int ic0 = 0; ic0 < args.ne11; ic0 += C*nsg) { const int ic = ic0 + C*sgitg; - if (ic >= ne11) { + if (ic >= args.ne11) { break; } @@ -3414,7 +3374,7 @@ kernel void kernel_flash_attn_ext_vec( for (short cc = 0; cc < C/4; ++cc) { qk_t mqka[4] = { 0.0, 0.0, 0.0, 0.0 }; - device const kd4x4_t * pk = (device const kd4x4_t *) ((device const char *) k + ((ic + 4*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3)); + device const kd4x4_t * pk = (device const kd4x4_t *) ((device const char *) k + ((ic + 4*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3)); #pragma unroll(D16/NL) for (short ii = 0; ii < D16; ii += NL) { @@ -3450,10 +3410,10 @@ kernel void kernel_flash_attn_ext_vec( // mqk = mqk*scale + mask*slope if (tx == 0) { - mqk *= scale; + mqk *= args.scale; - if (logit_softcap != 0.0f) { - mqk = logit_softcap*precise::tanh(mqk); + if (args.logit_softcap != 0.0f) { + mqk = args.logit_softcap*precise::tanh(mqk); } mqk += sm[4*cc + ty]*slope; @@ -3492,7 +3452,7 @@ kernel void kernel_flash_attn_ext_vec( // O = O + (Q*K^T)*V { for (short cc = 0; cc < C/4; ++cc) { - device const vd4x4_t * pv4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 4*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3)); + device const vd4x4_t * pv4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 4*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3)); const s4x4_t ms(ss[4*cc + ty]); @@ -3597,7 +3557,7 @@ kernel void kernel_flash_attn_ext_vec( const float S = ss[0]; for (short i = tiisg; i < D16; i += NW) { - dst44[((int64_t)iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D16 + i] = (float4x4) sr4x4[i]/S; + dst44[((int64_t)iq3*args.ne2*args.ne1 + iq2 + (iq1)*args.ne1)*D16 + i] = (float4x4) sr4x4[i]/S; } } } From cbae08872133dfaedee976997c9305dfc8903e96 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 9 Nov 2024 16:39:36 +0200 Subject: [PATCH 03/20] metal : cont + avoid potential int overflow [no ci] --- ggml/src/ggml-metal/ggml-metal.metal | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 24c714eaebf7e..0ee5d0c445029 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -2761,11 +2761,11 @@ template< short KV = 8, // key/value processed per each simdgroup short C = 32> // cache items per threadgroup kernel void kernel_flash_attn_ext( - device const char * q, - device const char * k, - device const char * v, - device const char * mask, - device float * dst, + device const char * q, + device const char * k, + device const char * v, + device const char * mask, + device char * dst, constant ggml_metal_kargs_flash_attn_ext & args, threadgroup half * shared [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], @@ -3164,7 +3164,7 @@ kernel void kernel_flash_attn_ext( const float S = ss[j*TS + 0]; for (short i = tiisg; i < D4; i += NW) { - dst4[((int64_t)iq3*args.ne2*args.ne1 + iq2 + (iq1 + j)*args.ne1)*D4 + i] = (float4) so4[j*D4 + i]/S; + dst4[((int64_t)iq3*args.ne2*args.ne1 + iq2 + (int64_t)(iq1 + j)*args.ne1)*D4 + i] = (float4) so4[j*D4 + i]/S; } } } @@ -3256,11 +3256,11 @@ template< short Q = 1, // queries per threadgroup short C = 32> // cache items per threadgroup kernel void kernel_flash_attn_ext_vec( - device const char * q, - device const char * k, - device const char * v, - device const char * mask, - device float * dst, + device const char * q, + device const char * k, + device const char * v, + device const char * mask, + device char * dst, constant ggml_metal_kargs_flash_attn_ext & args, threadgroup half * shared [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], @@ -3557,7 +3557,7 @@ kernel void kernel_flash_attn_ext_vec( const float S = ss[0]; for (short i = tiisg; i < D16; i += NW) { - dst44[((int64_t)iq3*args.ne2*args.ne1 + iq2 + (iq1)*args.ne1)*D16 + i] = (float4x4) sr4x4[i]/S; + dst44[((int64_t)iq3*args.ne2*args.ne1 + iq2 + (int64_t)iq1*args.ne1)*D16 + i] = (float4x4) sr4x4[i]/S; } } } From 0d0c54fc5a1da4bdb64eb4d6a1bb4f488e1b3f5e Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 9 Nov 2024 17:54:40 +0200 Subject: [PATCH 04/20] metal : mul mat struct (wip) --- ggml/src/ggml-common.h | 17 ++++ ggml/src/ggml-metal/ggml-metal.m | 125 ++++++++++++++------------- ggml/src/ggml-metal/ggml-metal.metal | 73 +++++++--------- 3 files changed, 113 insertions(+), 102 deletions(-) diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index 6529d71ebeb7f..b03e1a6bf8250 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -470,6 +470,23 @@ typedef struct { uint16_t n_head_log2; float logit_softcap; } ggml_metal_kargs_flash_attn_ext; + +typedef struct { + int32_t ne00; + int32_t ne02; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne12; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + int32_t ne0; + int32_t ne1; + int16_t r2; + int16_t r3; +} ggml_metal_kargs_mul_mm; #endif #endif // GGML_COMMON_DECL diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index 8ee9c67916537..846f83e9714e2 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -1963,24 +1963,29 @@ static void ggml_metal_encode_node( default: GGML_ABORT("MUL MAT-MAT not implemented"); } + ggml_metal_kargs_mul_mm args = { + /*.ne00 =*/ ne00, + /*.ne02 =*/ ne02, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne12 =*/ ne12, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.r2 =*/ r2, + /*.r3 =*/ r3, + }; + [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:7]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:8]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:9]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:10]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:11]; - [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:12]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14]; - [encoder setBytes:&r2 length:sizeof(r2) atIndex:15]; - [encoder setBytes:&r3 length:sizeof(r3) atIndex:16]; + [encoder setBytes:&args length:sizeof(args) atIndex:3]; + [encoder setThreadgroupMemoryLength:8192 atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; } else { @@ -2711,31 +2716,31 @@ static void ggml_metal_encode_node( } ggml_metal_kargs_rope args = { - .ne00 = ne00, - .ne01 = ne01, - .ne02 = ne02, - .ne03 = ne03, - .nb00 = nb00, - .nb01 = nb01, - .nb02 = nb02, - .nb03 = nb03, - .ne0 = ne0, - .ne1 = ne1, - .ne2 = ne2, - .ne3 = ne3, - .nb0 = nb0, - .nb1 = nb1, - .nb2 = nb2, - .nb3 = nb3, - .n_past = n_past, - .n_dims = n_dims, - .n_ctx_orig = n_ctx_orig, - .freq_base = freq_base, - .freq_scale = freq_scale, - .ext_factor = ext_factor, - .attn_factor = attn_factor, - .beta_fast = beta_fast, - .beta_slow = beta_slow, + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + /*.n_past =*/ n_past, + /*.n_dims =*/ n_dims, + /*.n_ctx_orig =*/ n_ctx_orig, + /*.freq_base =*/ freq_base, + /*.freq_scale =*/ freq_scale, + /*.ext_factor =*/ ext_factor, + /*.attn_factor =*/ attn_factor, + /*.beta_fast =*/ beta_fast, + /*.beta_slow =*/ beta_slow, }; [encoder setComputePipelineState:pipeline]; @@ -3233,27 +3238,27 @@ static void ggml_metal_encode_node( } ggml_metal_kargs_flash_attn_ext args = { - .ne01 = ne01, - .ne02 = ne02, - .ne03 = ne03, - .nb01 = nb01, - .nb02 = nb02, - .nb03 = nb03, - .ne11 = ne11, - .ne_12_2 = ne12, - .ne_12_3 = ne13, - .nb_12_1 = nb11, - .nb_12_2 = nb12, - .nb_12_3 = nb13, - .nb31 = nb31, - .ne1 = ne1, - .ne2 = ne2, - .scale = scale, - .max_bias = max_bias, - .m0 = m0, - .m1 = m1, - .n_head_log2 = n_head_log2, - .logit_softcap = logit_softcap, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne11 =*/ ne11, + /*.ne_12_2 =*/ ne12, + /*.ne_12_3 =*/ ne13, + /*.nb_12_1 =*/ nb11, + /*.nb_12_2 =*/ nb12, + /*.nb_12_3 =*/ nb13, + /*.nb31 =*/ nb31, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.scale =*/ scale, + /*.max_bias =*/ max_bias, + /*.m0 =*/ m0, + /*.m1 =*/ m1, + /*.n_head_log2 =*/ n_head_log2, + /*.logit_softcap =*/ logit_softcap, }; [encoder setComputePipelineState:pipeline]; diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 0ee5d0c445029..df4e5f77e4075 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -3264,7 +3264,6 @@ kernel void kernel_flash_attn_ext_vec( constant ggml_metal_kargs_flash_attn_ext & args, threadgroup half * shared [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], ushort3 ntg[[threads_per_threadgroup]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { @@ -6215,38 +6214,26 @@ kernel void kernel_get_rows_i32( // each block_q contains 16*nl weights template -kernel void kernel_mul_mm(device const uchar * src0, - device const uchar * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne02, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup uchar * shared_memory [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { +kernel void kernel_mul_mm( + device const char * src0, + device const char * src1, + device char * dst, + constant ggml_metal_kargs_mul_mm & args, + threadgroup char * shared_memory [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiitg[[thread_index_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { threadgroup T * sa = (threadgroup T *)(shared_memory); threadgroup float * sb = (threadgroup float *)(shared_memory + 4096); - const uint r0 = tgpig.y; - const uint r1 = tgpig.x; - const uint im = tgpig.z; + const int r0 = tgpig.y; + const int r1 = tgpig.x; + const int im = tgpig.z; // if this block is of 64x32 shape or smaller - short n_rows = (ne0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M; - short n_cols = (ne1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N; + short n_rows = (args.ne0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (args.ne0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M; + short n_cols = (args.ne1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? (args.ne1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N; // a thread shouldn't load data outside of the matrix short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1; @@ -6262,20 +6249,20 @@ kernel void kernel_mul_mm(device const uchar * src0, short il = (tiitg % THREAD_PER_ROW); - const uint i12 = im%ne12; - const uint i13 = im/ne12; + const int i12 = im%args.ne12; + const int i13 = im/args.ne12; - uint offset0 = (i12/r2)*nb02 + (i13/r3)*nb03; - ushort offset1 = il/nl; + int offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + short offset1 = il/nl; - device const block_q * x = (device const block_q *)(src0 + (r0*BLOCK_SIZE_M + thread_row)*nb01 + offset0) + offset1; + device const block_q * x = (device const block_q *)(src0 + (r0*BLOCK_SIZE_M + thread_row)*args.nb01 + offset0) + offset1; device const float * y = (device const float *)(src1 - + nb13 * i13 - + nb12 * i12 - + nb11 * (r1 * BLOCK_SIZE_N + thread_col) - + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL))); + + args.nb13*i13 + + args.nb12*i12 + + args.nb11*(r1 * BLOCK_SIZE_N + thread_col) + + args.nb10*(BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL))); - for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) { + for (int loop_k = 0; loop_k < args.ne00; loop_k += BLOCK_SIZE_K) { // load data and store to threadgroup memory T4x4 temp_a; dequantize_func(x, il, temp_a); @@ -6322,11 +6309,13 @@ kernel void kernel_mul_mm(device const uchar * src0, } } - if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) { - device float * C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg & 1)) \ - + (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * ne0 + im*ne1*ne0; + if ((r0 + 1) * BLOCK_SIZE_M <= args.ne0 && (r1 + 1) * BLOCK_SIZE_N <= args.ne1) { + device float * C = (device float *) dst + + (BLOCK_SIZE_M * r0 + 32 * (sgitg & 1)) + \ + (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * args.ne0 + im*args.ne1*args.ne0; + for (short i = 0; i < 8; i++) { - simdgroup_store(mc[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0); + simdgroup_store(mc[i], C + 8 * (i%4) + 8 * args.ne0 * (i/4), args.ne0); } } else { // block is smaller than 64x32, we should avoid writing data outside of the matrix @@ -6341,7 +6330,7 @@ kernel void kernel_mul_mm(device const uchar * src0, if (sgitg == 0) { for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) { - device float * D = dst + (r0*BLOCK_SIZE_M) + (r1*BLOCK_SIZE_N + j)*ne0 + im*ne1*ne0; + device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + (r1*BLOCK_SIZE_N + j)*args.ne0 + im*args.ne1*args.ne0; device float4 * D4 = (device float4 *) D; threadgroup float * C = temp_str + (j*BLOCK_SIZE_M); From 07bc7610ad09488658d07ce21d0de2a44ffddce2 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 9 Nov 2024 22:56:39 +0200 Subject: [PATCH 05/20] cont : mul mat vec --- ggml/src/ggml-common.h | 43 + ggml/src/ggml-metal/ggml-metal.m | 82 +- ggml/src/ggml-metal/ggml-metal.metal | 1346 ++++++-------------------- 3 files changed, 402 insertions(+), 1069 deletions(-) diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index b03e1a6bf8250..fcf0d997c52af 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -487,6 +487,49 @@ typedef struct { int16_t r2; int16_t r3; } ggml_metal_kargs_mul_mm; + +typedef struct { + int32_t ne00; + int32_t ne01; + int32_t ne02; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne10; + int32_t ne11; + int32_t ne12; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + int32_t ne0; + int32_t ne1; + int16_t r2; + int16_t r3; +} ggml_metal_kargs_mul_mv; + +typedef struct { + int32_t nei0; + int32_t nei1; + uint64_t nbi1; + int32_t ne00; + int32_t ne01; + int32_t ne02; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + int32_t ne10; + int32_t ne11; + int32_t ne12; + int32_t ne13; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + int32_t ne0; + int32_t ne1; + uint64_t nb1; +} ggml_metal_kargs_mul_mv_id; #endif #endif // GGML_COMMON_DECL diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index 846f83e9714e2..92506e8772d9a 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -2163,28 +2163,32 @@ static void ggml_metal_encode_node( } }; + ggml_metal_kargs_mul_mv args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne10 =*/ ne10, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.r2 =*/ r2, + /*.r3 =*/ r3, + }; + [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:13]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:14]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:15]; - [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:16]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:18]; - [encoder setBytes:&r2 length:sizeof(r2) atIndex:19]; - [encoder setBytes:&r3 length:sizeof(r3) atIndex:20]; + [encoder setBytes:&args length:sizeof(args) atIndex:3]; if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K || @@ -2476,30 +2480,34 @@ static void ggml_metal_encode_node( GGML_ASSERT(ne00 >= nth0*nth1); } + ggml_metal_kargs_mul_mv_id args = { + /*.nei0 =*/ ne20, + /*.nei1 =*/ ne21, + /*.nbi1 =*/ nb21, + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.ne10 =*/ ne10, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.ne13 =*/ ne13, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.nb1 =*/ nb1, + }; + [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3]; - [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4]; - [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5]; - [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:8]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:9]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:10]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:11]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:12]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:13]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:14]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:15]; - [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:16]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:17]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:18]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:19]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:20]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:21]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:22]; + [encoder setBytes:&args length:sizeof(args) atIndex:4]; const int64_t _ne1 = 1; const int tgz = dst_rows; diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index df4e5f77e4075..81f1aeedcbceb 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1634,26 +1634,12 @@ void mul_vec_q_n_f32_impl( device const void * src0, device const float * src1, device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb01, - uint64_t nb02, - uint64_t nb03, - int64_t ne10, - int64_t ne12, - uint64_t nb11, - uint64_t nb12, - uint64_t nb13, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, + ggml_metal_kargs_mul_mv args, threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint sgitg) { - const int nb = ne00/QK4_0; + uint3 tgpig, + uint tiisg, + uint sgitg) { + const int nb = args.ne00/QK4_0; const int r0 = tgpig.x; const int r1 = tgpig.y; @@ -1661,11 +1647,11 @@ void mul_vec_q_n_f32_impl( const int first_row = (r0 * nsg + sgitg) * nr; - const uint i12 = im%ne12; - const uint i13 = im/ne12; + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; - //const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + //const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; //device const block_q_type * x = (device const block_q_type *) ((device char *) src0 + offset0); device const float * y = (device const float *) ((device char *) src1 + offset1); @@ -1673,7 +1659,7 @@ void mul_vec_q_n_f32_impl( // pointers to src0 rows device const block_q_type * ax[nr]; for (int row = 0; row < nr; ++row) { - const uint offset0 = (first_row + row)*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + const uint offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; ax[row] = (device const block_q_type *) ((device char *) src0 + offset0); } @@ -1711,8 +1697,8 @@ void mul_vec_q_n_f32_impl( for (int row = 0; row < nr; ++row) { const float tot = simd_sum(sumf[row]); - if (tiisg == 0 && first_row + row < ne01) { - dst[im*ne0*ne1 + r1*ne0 + first_row + row] = tot; + if (tiisg == 0 && first_row + row < args.ne01) { + dst[im*args.ne0*args.ne1 + r1*args.ne0 + first_row + row] = tot; } } } @@ -1721,136 +1707,53 @@ kernel void kernel_mul_mv_q4_0_f32( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, + constant ggml_metal_kargs_mul_mv & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); + mul_vec_q_n_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); } kernel void kernel_mul_mv_q4_1_f32( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, + constant ggml_metal_kargs_mul_mv & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); + mul_vec_q_n_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); } kernel void kernel_mul_mv_q5_0_f32( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, + constant ggml_metal_kargs_mul_mv & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); + mul_vec_q_n_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); } kernel void kernel_mul_mv_q5_1_f32( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, + constant ggml_metal_kargs_mul_mv & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); + mul_vec_q_n_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); } - #define NB_Q8_0 8 void kernel_mul_mv_q8_0_f32_impl( device const void * src0, device const float * src1, device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb01, - uint64_t nb02, - uint64_t nb03, - int64_t ne10, - int64_t ne12, - uint64_t nb11, - uint64_t nb12, - uint64_t nb13, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, + ggml_metal_kargs_mul_mv args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -1859,18 +1762,18 @@ void kernel_mul_mv_q8_0_f32_impl( const int nsg = N_SIMDGROUP; const int nw = N_SIMDWIDTH; - const int nb = ne00/QK8_0; + const int nb = args.ne00/QK8_0; const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; const int first_row = (r0 * nsg + sgitg) * nr; - const uint i12 = im%ne12; - const uint i13 = im/ne12; + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; - //const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + //const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; //device const block_q8_0 * x = (device const block_q8_0 *) ((device char *) src0 + offset0); device const float * y = (device const float *) ((device char *) src1 + offset1); @@ -1878,7 +1781,7 @@ void kernel_mul_mv_q8_0_f32_impl( // pointers to src0 rows device const block_q8_0 * ax[nr]; for (int row = 0; row < nr; ++row) { - const uint offset0 = (first_row + row)*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + const uint offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; ax[row] = (device const block_q8_0 *) ((device char *) src0 + offset0); } @@ -1911,8 +1814,8 @@ void kernel_mul_mv_q8_0_f32_impl( for (int row = 0; row < nr; ++row) { const float tot = simd_sum(sumf[row]); - if (tiisg == 0 && first_row + row < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot; + if (tiisg == 0 && first_row + row < args.ne01) { + dst[r1*args.ne0 + im*args.ne0*args.ne1 + first_row + row] = tot; } } } @@ -1922,28 +1825,11 @@ kernel void kernel_mul_mv_q8_0_f32( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, + constant ggml_metal_kargs_mul_mv & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); + kernel_mul_mv_q8_0_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); } #define N_MV_T_T 4 @@ -1953,80 +1839,63 @@ void kernel_mul_mv_impl( device const char * src0, device const char * src1, device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb00, - uint64_t nb01, - uint64_t nb02, - uint64_t nb03, - int64_t ne10, - int64_t ne11, - int64_t ne12, - uint64_t nb10, - uint64_t nb11, - uint64_t nb12, - uint64_t nb13, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - uint3 tgpig, - uint tiisg) { + ggml_metal_kargs_mul_mv args, + uint3 tgpig, + uint tiisg) { const int64_t r0 = tgpig.x; const int64_t rb = tgpig.y*N_MV_T_T; const int64_t im = tgpig.z; - const uint i12 = im%ne12; - const uint i13 = im/ne12; + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; - const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + const uint offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; device const T0 * x = (device const T0 *) (src0 + offset0); - if (ne00 < 128) { + if (args.ne00 < 128) { for (int row = 0; row < N_MV_T_T; ++row) { int r1 = rb + row; - if (r1 >= ne11) { + if (r1 >= args.ne11) { break; } - const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const T1 * y = (device const T1 *) (src1 + offset1); float sumf = 0; - for (int i = tiisg; i < ne00; i += 32) { + for (int i = tiisg; i < args.ne00; i += 32) { sumf += (T0) x[i] * (T1) y[i]; } float all_sum = simd_sum(sumf); if (tiisg == 0) { - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + dst[im*args.ne1*args.ne0 + r1*args.ne0 + r0] = all_sum; } } } else { device const T04 * x4 = (device const T04 *) x; for (int row = 0; row < N_MV_T_T; ++row) { int r1 = rb + row; - if (r1 >= ne11) { + if (r1 >= args.ne11) { break; } - const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const T1 * y = (device const T1 *) (src1 + offset1); device const T14 * y4 = (device const T14 *) y; float sumf = 0; - for (int i = tiisg; i < ne00/4; i += 32) { + for (int i = tiisg; i < args.ne00/4; i += 32) { for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]); } float all_sum = simd_sum(sumf); if (tiisg == 0) { - for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) (x[i] * y[i]); - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + for (int i = 4*(args.ne00/4); i < args.ne00; ++i) all_sum += (float) (x[i] * y[i]); + dst[im*args.ne1*args.ne0 + r1*args.ne0 + r0] = all_sum; } } } @@ -2037,48 +1906,14 @@ kernel void kernel_mul_mv( device const char * src0, device const char * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, + constant ggml_metal_kargs_mul_mv & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]]) { kernel_mul_mv_impl( src0, src1, dst, - ne00, - ne01, - ne02, - nb00, - nb01, - nb02, - nb03, - ne10, - ne11, - ne12, - nb10, - nb11, - nb12, - nb13, - ne0, - ne1, - r2, - r3, + args, tgpig, tiisg); } @@ -2098,24 +1933,7 @@ kernel void kernel_mul_mv_1row( device const char * src0, device const char * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, + constant ggml_metal_kargs_mul_mv & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]]) { @@ -2123,37 +1941,37 @@ kernel void kernel_mul_mv_1row( const int64_t r1 = tgpig.y; const int64_t im = tgpig.z; - const uint i12 = im%ne12; - const uint i13 = im/ne12; + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; - const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + const uint offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const T * x = (device const T *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); float sumf = 0; - if (ne00 < 128) { - for (int i = tiisg; i < ne00; i += 32) { + if (args.ne00 < 128) { + for (int i = tiisg; i < args.ne00; i += 32) { sumf += (float) x[i] * (float) y[i]; } float all_sum = simd_sum(sumf); if (tiisg == 0) { - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + dst[im*args.ne1*args.ne0 + r1*args.ne0 + r0] = all_sum; } } else { device const T4 * x4 = (device const T4 *) x; device const float4 * y4 = (device const float4 *) y; - for (int i = tiisg; i < ne00/4; i += 32) { + for (int i = tiisg; i < args.ne00/4; i += 32) { for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]); } float all_sum = simd_sum(sumf); if (tiisg == 0) { - for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) (x[i] * y[i]); - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + for (int i = 4*(args.ne00/4); i < args.ne00; ++i) all_sum += (float) (x[i] * y[i]); + dst[im*args.ne1*args.ne0 + r1*args.ne0 + r0] = all_sum; } } } @@ -2171,51 +1989,34 @@ kernel void kernel_mul_mv_l4( device const char * src0, device const char * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, + constant ggml_metal_kargs_mul_mv & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]]) { - const int nrows = ne11; + const int nrows = args.ne11; const int64_t r0 = tgpig.x; const int64_t im = tgpig.z; - const uint i12 = im%ne12; - const uint i13 = im/ne12; + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; - const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + const uint offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; device const T4 * x4 = (device const T4 *) (src0 + offset0); for (int r1 = 0; r1 < nrows; ++r1) { - const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const float4 * y4 = (device const float4 *) (src1 + offset1); float sumf = 0; - for (int i = tiisg; i < ne00/4; i += 32) { + for (int i = tiisg; i < args.ne00/4; i += 32) { for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]); } float all_sum = simd_sum(sumf); if (tiisg == 0) { - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + dst[im*args.ne1*args.ne0 + r1*args.ne0 + r0] = all_sum; } } } @@ -4134,38 +3935,24 @@ void kernel_mul_mv_q2_K_f32_impl( device const void * src0, device const float * src1, device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb01, - uint64_t nb02, - uint64_t nb03, - int64_t ne10, - int64_t ne12, - uint64_t nb11, - uint64_t nb12, - uint64_t nb13, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, + ggml_metal_kargs_mul_mv args, threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint sgitg) { + uint3 tgpig, + uint tiisg, + uint sgitg) { - const int nb = ne00/QK_K; + const int nb = args.ne00/QK_K; const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const uint i12 = im%ne12; - const uint i13 = im/ne12; + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; - const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_q2_K * x = (device const block_q2_K *) ((device char *) src0 + offset0); device const float * y = (device const float *) ((device char *) src1 + offset1); @@ -4217,9 +4004,9 @@ void kernel_mul_mv_q2_K_f32_impl( (acc1[3] + 1.f/256.f * acc2[3]) * (sc[6] & 0xF) * 1.f/64.f) - dmin * (sumy[0] * (sc[0] & 0xF0) + sumy[1] * (sc[2] & 0xF0) + sumy[2] * (sc[4] & 0xF0) + sumy[3] * (sc[6] & 0xF0)); - qs += nb01/2; - sc += nb01; - dh += nb01/2; + qs += args.nb01/2; + sc += args.nb01; + dh += args.nb01/2; } y4 += 4 * QK_K; @@ -4228,7 +4015,7 @@ void kernel_mul_mv_q2_K_f32_impl( for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; + dst[r1*args.ne0 + im*args.ne0*args.ne1 + first_row + row] = all_sum; } } } @@ -4238,56 +4025,25 @@ kernel void kernel_mul_mv_q2_K_f32( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, + constant ggml_metal_kargs_mul_mv & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); } void kernel_mul_mv_q3_K_f32_impl( device const void * src0, device const float * src1, device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb01, - uint64_t nb02, - uint64_t nb03, - int64_t ne10, - int64_t ne12, - uint64_t nb11, - uint64_t nb12, - uint64_t nb13, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, + ggml_metal_kargs_mul_mv args, threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint sgitg) { + uint3 tgpig, + uint tiisg, + uint sgitg) { - const int nb = ne00/QK_K; + const int nb = args.ne00/QK_K; const int64_t r0 = tgpig.x; const int64_t r1 = tgpig.y; @@ -4295,11 +4051,11 @@ void kernel_mul_mv_q3_K_f32_impl( const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2; - const uint i12 = im%ne12; - const uint i13 = im/ne12; + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; - const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_q3_K * x = (device const block_q3_K *) ((device char *) src0 + offset0); device const float * yy = (device const float *) ((device char *) src1 + offset1); @@ -4403,10 +4159,10 @@ void kernel_mul_mv_q3_K_f32_impl( sumf1[row] += d1 * (scales[1] - 32); sumf2[row] += d2 * (scales[3] - 32); - q += nb01/2; - h += nb01/2; - a += nb01/2; - dh += nb01/2; + q += args.nb01/2; + h += args.nb01/2; + a += args.nb01/2; + dh += args.nb01/2; } y1 += 4 * QK_K; @@ -4418,7 +4174,7 @@ void kernel_mul_mv_q3_K_f32_impl( } if (tiisg == 0) { for (int row = 0; row < 2; ++row) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = sumf1[row]; + dst[r1*args.ne0 + im*args.ne0*args.ne1 + first_row + row] = sumf1[row]; } } } @@ -4428,54 +4184,23 @@ kernel void kernel_mul_mv_q3_K_f32( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, + constant ggml_metal_kargs_mul_mv & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); } void kernel_mul_mv_q4_K_f32_impl( device const void * src0, device const float * src1, device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb01, - uint64_t nb02, - uint64_t nb03, - int64_t ne10, - int64_t ne12, - uint64_t nb11, - uint64_t nb12, - uint64_t nb13, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, + ggml_metal_kargs_mul_mv args, threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint sgitg) { + uint3 tgpig, + uint tiisg, + uint sgitg) { const uint16_t kmask1 = 0x3f3f; const uint16_t kmask2 = 0x0f0f; @@ -4486,18 +4211,18 @@ void kernel_mul_mv_q4_K_f32_impl( const int iq = it/4; // 0 or 1 const int ir = it%4; // 0...3 - const int nb = ne00/QK_K; + const int nb = args.ne00/QK_K; const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; //const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; const int first_row = r0 * N_DST; - const uint i12 = im%ne12; - const uint i13 = im/ne12; + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; - const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_q4_K * x = (device const block_q4_K *) ((device char *) src0 + offset0); device const float * y = (device const float *) ((device char *) src1 + offset1); @@ -4553,9 +4278,9 @@ void kernel_mul_mv_q4_K_f32_impl( (acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) - dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]); - q1 += nb01/2; - sc += nb01/2; - dh += nb01/2; + q1 += args.nb01/2; + sc += args.nb01/2; + dh += args.nb01/2; } y4 += 4 * QK_K; @@ -4564,7 +4289,7 @@ void kernel_mul_mv_q4_K_f32_impl( for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; + dst[r1*args.ne0 + im*args.ne0*args.ne1 + first_row + row] = all_sum; } } } @@ -4574,56 +4299,25 @@ kernel void kernel_mul_mv_q4_K_f32( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, + constant ggml_metal_kargs_mul_mv & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); } void kernel_mul_mv_q5_K_f32_impl( device const void * src0, device const float * src1, device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb01, - uint64_t nb02, - uint64_t nb03, - int64_t ne10, - int64_t ne12, - uint64_t nb11, - uint64_t nb12, - uint64_t nb13, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, + ggml_metal_kargs_mul_mv args, threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint sgitg) { + uint3 tgpig, + uint tiisg, + uint sgitg) { - const int nb = ne00/QK_K; + const int nb = args.ne00/QK_K; const int64_t r0 = tgpig.x; const int64_t r1 = tgpig.y; @@ -4631,11 +4325,11 @@ void kernel_mul_mv_q5_K_f32_impl( const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2; - const uint i12 = im%ne12; - const uint i13 = im/ne12; + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; - const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_q5_K * x = (device const block_q5_K *) ((device char *) src0 + offset0); device const float * yy = (device const float *) ((device char *) src1 + offset1); @@ -4712,10 +4406,10 @@ void kernel_mul_mv_q5_K_f32_impl( sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) - dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]); - q1 += nb01; - qh += nb01; - dh += nb01/2; - a += nb01/2; + q1 += args.nb01; + qh += args.nb01; + dh += args.nb01/2; + a += args.nb01/2; } y1 += 4 * QK_K; @@ -4724,7 +4418,7 @@ void kernel_mul_mv_q5_K_f32_impl( for (int row = 0; row < 2; ++row) { const float tot = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot; + dst[r1*args.ne0 + im*args.ne0*args.ne1 + first_row + row] = tot; } } } @@ -4734,61 +4428,30 @@ kernel void kernel_mul_mv_q5_K_f32( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, + constant ggml_metal_kargs_mul_mv & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); } void kernel_mul_mv_q6_K_f32_impl( device const void * src0, device const float * src1, device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb01, - uint64_t nb02, - uint64_t nb03, - int64_t ne10, - int64_t ne12, - uint64_t nb11, - uint64_t nb12, - uint64_t nb13, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, + ggml_metal_kargs_mul_mv args, threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint sgitg) { + uint3 tgpig, + uint tiisg, + uint sgitg) { const uint8_t kmask1 = 0x03; const uint8_t kmask2 = 0x0C; const uint8_t kmask3 = 0x30; const uint8_t kmask4 = 0xC0; - const int nb = ne00/QK_K; + const int nb = args.ne00/QK_K; const int64_t r0 = tgpig.x; const int64_t r1 = tgpig.y; @@ -4796,11 +4459,11 @@ void kernel_mul_mv_q6_K_f32_impl( const int row = 2 * r0 + sgitg; - const uint i12 = im%ne12; - const uint i13 = im/ne12; + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; - const uint offset0 = row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + const uint offset0 = row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_q6_K * x = (device const block_q6_K *) ((device char *) src0 + offset0); device const float * yy = (device const float *) ((device char *) src1 + offset1); @@ -4844,7 +4507,7 @@ void kernel_mul_mv_q6_K_f32_impl( const float tot = simd_sum(sumf); if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + row] = tot; + dst[r1*args.ne0 + im*args.ne0*args.ne1 + row] = tot; } } @@ -4853,29 +4516,12 @@ kernel void kernel_mul_mv_q6_K_f32( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, + constant ggml_metal_kargs_mul_mv & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); } // ======================= "True" 2-bit @@ -4884,38 +4530,24 @@ void kernel_mul_mv_iq2_xxs_f32_impl( device const void * src0, device const float * src1, device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb01, - uint64_t nb02, - uint64_t nb03, - int64_t ne10, - int64_t ne12, - uint64_t nb11, - uint64_t nb12, - uint64_t nb13, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, + ggml_metal_kargs_mul_mv args, threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint sgitg) { + uint3 tgpig, + uint tiisg, + uint sgitg) { - const int nb = ne00/QK_K; + const int nb = args.ne00/QK_K; const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const uint i12 = im%ne12; - const uint i13 = im/ne12; + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; - const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq2_xxs * x = (device const block_iq2_xxs *) ((device char *) src0 + offset0); device const float * y = (device const float *) ((device char *) src1 + offset1); @@ -4971,8 +4603,8 @@ void kernel_mul_mv_iq2_xxs_f32_impl( } sumf[row] += d * sum; - dh += nb01/2; - q2 += nb01/2; + dh += args.nb01/2; + q2 += args.nb01/2; } y4 += 32 * 32; @@ -4981,7 +4613,7 @@ void kernel_mul_mv_iq2_xxs_f32_impl( for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f; + dst[r1*args.ne0 + im*args.ne0*args.ne1 + first_row + row] = all_sum * 0.25f; } } } @@ -4991,68 +4623,37 @@ kernel void kernel_mul_mv_iq2_xxs_f32( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, + constant ggml_metal_kargs_mul_mv & args, threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq2_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq2_xxs_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg); } void kernel_mul_mv_iq2_xs_f32_impl( device const void * src0, device const float * src1, device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb01, - uint64_t nb02, - uint64_t nb03, - int64_t ne10, - int64_t ne12, - uint64_t nb11, - uint64_t nb12, - uint64_t nb13, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, + ggml_metal_kargs_mul_mv args, threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint sgitg) { + uint3 tgpig, + uint tiisg, + uint sgitg) { - const int nb = ne00/QK_K; + const int nb = args.ne00/QK_K; const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const uint i12 = im%ne12; - const uint i13 = im/ne12; + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; - const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq2_xs * x = (device const block_iq2_xs *) ((device char *) src0 + offset0); device const float * y = (device const float *) ((device char *) src1 + offset1); @@ -5117,9 +4718,9 @@ void kernel_mul_mv_iq2_xs_f32_impl( } sumf[row] += d1 * sum1 + d2 * sum2; - dh += nb01/2; - q2 += nb01/2; - sc += nb01; + dh += args.nb01/2; + q2 += args.nb01/2; + sc += args.nb01; } y4 += 32 * 32; @@ -5128,7 +4729,7 @@ void kernel_mul_mv_iq2_xs_f32_impl( for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f; + dst[r1*args.ne0 + im*args.ne0*args.ne1 + first_row + row] = all_sum * 0.25f; } } } @@ -5138,68 +4739,37 @@ kernel void kernel_mul_mv_iq2_xs_f32( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, + constant ggml_metal_kargs_mul_mv & args, threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq2_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq2_xs_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg); } void kernel_mul_mv_iq3_xxs_f32_impl( device const void * src0, device const float * src1, device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb01, - uint64_t nb02, - uint64_t nb03, - int64_t ne10, - int64_t ne12, - uint64_t nb11, - uint64_t nb12, - uint64_t nb13, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, + ggml_metal_kargs_mul_mv args, threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint sgitg) { + uint3 tgpig, + uint tiisg, + uint sgitg) { - const int nb = ne00/QK_K; + const int nb = args.ne00/QK_K; const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const uint i12 = im%ne12; - const uint i13 = im/ne12; + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; - const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq3_xxs * x = (device const block_iq3_xxs *) ((device char *) src0 + offset0); device const float * y = (device const float *) ((device char *) src1 + offset1); @@ -5257,9 +4827,9 @@ void kernel_mul_mv_iq3_xxs_f32_impl( } sumf[row] += d * (sum[0] + sum[1]); - dh += nb01/2; - q3 += nb01; - gas += nb01/2; + dh += args.nb01/2; + q3 += args.nb01; + gas += args.nb01/2; } y4 += 32 * 32; @@ -5268,7 +4838,7 @@ void kernel_mul_mv_iq3_xxs_f32_impl( for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.5f; + dst[r1*args.ne0 + im*args.ne0*args.ne1 + first_row + row] = all_sum * 0.5f; } } } @@ -5278,68 +4848,37 @@ kernel void kernel_mul_mv_iq3_xxs_f32( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, + constant ggml_metal_kargs_mul_mv & args, threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq3_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq3_xxs_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg); } void kernel_mul_mv_iq3_s_f32_impl( device const void * src0, device const float * src1, device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb01, - uint64_t nb02, - uint64_t nb03, - int64_t ne10, - int64_t ne12, - uint64_t nb11, - uint64_t nb12, - uint64_t nb13, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, + ggml_metal_kargs_mul_mv args, threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint sgitg) { + uint3 tgpig, + uint tiisg, + uint sgitg) { - const int nb = ne00/QK_K; + const int nb = args.ne00/QK_K; const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const uint i12 = im%ne12; - const uint i13 = im/ne12; + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; - const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq3_s * x = (device const block_iq3_s *) ((device char *) src0 + offset0); device const float * y = (device const float *) ((device char *) src1 + offset1); @@ -5395,11 +4934,11 @@ void kernel_mul_mv_iq3_s_f32_impl( } sumf[row] += d * (sum[0] + sum[1]); - dh += nb01/2; - qs += nb01; - qh += nb01; - sc += nb01; - signs += nb01; + dh += args.nb01/2; + qs += args.nb01; + qh += args.nb01; + sc += args.nb01; + signs += args.nb01; } y4 += 32 * 32; @@ -5408,7 +4947,7 @@ void kernel_mul_mv_iq3_s_f32_impl( for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; + dst[r1*args.ne0 + im*args.ne0*args.ne1 + first_row + row] = all_sum; } } } @@ -5418,68 +4957,37 @@ kernel void kernel_mul_mv_iq3_s_f32( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, + constant ggml_metal_kargs_mul_mv & args, threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq3_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq3_s_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg); } void kernel_mul_mv_iq2_s_f32_impl( device const void * src0, device const float * src1, device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb01, - uint64_t nb02, - uint64_t nb03, - int64_t ne10, - int64_t ne12, - uint64_t nb11, - uint64_t nb12, - uint64_t nb13, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, + ggml_metal_kargs_mul_mv args, threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint sgitg) { + uint3 tgpig, + uint tiisg, + uint sgitg) { - const int nb = ne00/QK_K; + const int nb = args.ne00/QK_K; const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const uint i12 = im%ne12; - const uint i13 = im/ne12; + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; - const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq2_s * x = (device const block_iq2_s *) ((device char *) src0 + offset0); device const float * y = (device const float *) ((device char *) src1 + offset1); @@ -5536,11 +5044,11 @@ void kernel_mul_mv_iq2_s_f32_impl( } sumf[row] += d1 * sum[0] + d2 * sum[1]; - dh += nb01/2; - qs += nb01; - qh += nb01; - sc += nb01; - signs += nb01; + dh += args.nb01/2; + qs += args.nb01; + qh += args.nb01; + sc += args.nb01; + signs += args.nb01; } y4 += 32 * 32; @@ -5549,7 +5057,7 @@ void kernel_mul_mv_iq2_s_f32_impl( for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f; + dst[r1*args.ne0 + im*args.ne0*args.ne1 + first_row + row] = all_sum * 0.25f; } } } @@ -5559,68 +5067,37 @@ kernel void kernel_mul_mv_iq2_s_f32( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, + constant ggml_metal_kargs_mul_mv & args, threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq2_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq2_s_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg); } void kernel_mul_mv_iq1_s_f32_impl( device const void * src0, device const float * src1, device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb01, - uint64_t nb02, - uint64_t nb03, - int64_t ne10, - int64_t ne12, - uint64_t nb11, - uint64_t nb12, - uint64_t nb13, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, + ggml_metal_kargs_mul_mv args, threadgroup int8_t * shared_value, - uint3 tgpig, - uint tiisg, - uint sgitg) { + uint3 tgpig, + uint tiisg, + uint sgitg) { - const int nb = ne00/QK_K; + const int nb = args.ne00/QK_K; const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const uint i12 = im%ne12; - const uint i13 = im/ne12; + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; - const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq1_s * x = (device const block_iq1_s *) ((device char *) src0 + offset0); device const float * y = (device const float *) ((device char *) src1 + offset1); @@ -5666,9 +5143,9 @@ void kernel_mul_mv_iq1_s_f32_impl( } sumf[row] += (float)dh[0] * (sum + sumy * (qh[0] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA)) * (2*((qh[0] >> 12) & 7) + 1); - dh += nb01/2; - qs += nb01; - qh += nb01/2; + dh += args.nb01/2; + qs += args.nb01; + qh += args.nb01/2; } y4 += 32 * 32; @@ -5677,7 +5154,7 @@ void kernel_mul_mv_iq1_s_f32_impl( for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; + dst[r1*args.ne0 + im*args.ne0*args.ne1 + first_row + row] = all_sum; } } } @@ -5686,38 +5163,24 @@ void kernel_mul_mv_iq1_m_f32_impl( device const void * src0, device const float * src1, device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb01, - uint64_t nb02, - uint64_t nb03, - int64_t ne10, - int64_t ne12, - uint64_t nb11, - uint64_t nb12, - uint64_t nb13, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, + ggml_metal_kargs_mul_mv args, threadgroup int8_t * shared_value, - uint3 tgpig, - uint tiisg, - uint sgitg) { + uint3 tgpig, + uint tiisg, + uint sgitg) { - const int nb = ne00/QK_K; + const int nb = args.ne00/QK_K; const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const uint i12 = im%ne12; - const uint i13 = im/ne12; + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; - const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq1_m * x = (device const block_iq1_m *) ((device char *) src0 + offset0); device const float * y = (device const float *) ((device char *) src1 + offset1); @@ -5772,9 +5235,9 @@ void kernel_mul_mv_iq1_m_f32_impl( sumf[row] += (float)scale.f16 * ((sum[0] + delta1) * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 7) + 1) + (sum[1] + delta2) * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 7) + 1)); - sc += nb01/2; - qs += nb01; - qh += nb01; + sc += args.nb01/2; + qs += args.nb01; + qh += args.nb01; } y4 += 32 * 32; @@ -5783,7 +5246,7 @@ void kernel_mul_mv_iq1_m_f32_impl( for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; + dst[r1*args.ne0 + im*args.ne0*args.ne1 + first_row + row] = all_sum; } } } @@ -5792,38 +5255,24 @@ void kernel_mul_mv_iq4_nl_f32_impl( device const void * src0, device const float * src1, device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb01, - uint64_t nb02, - uint64_t nb03, - int64_t ne10, - int64_t ne12, - uint64_t nb11, - uint64_t nb12, - uint64_t nb13, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, + ggml_metal_kargs_mul_mv args, threadgroup int8_t * shared_values_i8, - uint3 tgpig, - uint tiisg, - uint sgitg) { + uint3 tgpig, + uint tiisg, + uint sgitg) { threadgroup float * shared_values = (threadgroup float *)shared_values_i8; - const int nb = ne00/QK4_NL; + const int nb = args.ne00/QK4_NL; const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; const int first_row = (r0 * 2 + sgitg) * 2; - const uint i12 = im%ne12; - const uint i13 = im/ne12; + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; - const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq4_nl * x = (device const block_iq4_nl *) ((device char *) src0 + offset0); device const float * y = (device const float *) ((device char *) src1 + offset1); @@ -5849,7 +5298,7 @@ void kernel_mul_mv_iq4_nl_f32_impl( device const float4 * y4 = (device const float4 *)yb; yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5]; - for (int row = 0; row < 2 && first_row + row < ne01; ++row) { + for (int row = 0; row < 2 && first_row + row < args.ne01; ++row) { device const block_iq4_nl & xb = x[row*nb + ib]; device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it); @@ -5881,10 +5330,10 @@ void kernel_mul_mv_iq4_nl_f32_impl( yb += 16 * QK4_NL; } - for (int row = 0; row < 2 && first_row + row < ne01; ++row) { + for (int row = 0; row < 2 && first_row + row < args.ne01; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; + dst[r1*args.ne0 + im*args.ne0*args.ne1 + first_row + row] = all_sum; } } } @@ -5893,38 +5342,24 @@ void kernel_mul_mv_iq4_xs_f32_impl( device const void * src0, device const float * src1, device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb01, - uint64_t nb02, - uint64_t nb03, - int64_t ne10, - int64_t ne12, - uint64_t nb11, - uint64_t nb12, - uint64_t nb13, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, + ggml_metal_kargs_mul_mv args, threadgroup int8_t * shared_values_i8, - uint3 tgpig, - uint tiisg, - uint sgitg) { + uint3 tgpig, + uint tiisg, + uint sgitg) { threadgroup float * shared_values = (threadgroup float *)shared_values_i8; - const int nb = ne00/QK_K; + const int nb = args.ne00/QK_K; const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; const int first_row = (r0 * 2 + sgitg) * 2; - const uint i12 = im%ne12; - const uint i13 = im/ne12; + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; - const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq4_xs * x = (device const block_iq4_xs *) ((device char *) src0 + offset0); device const float * y = (device const float *) ((device char *) src1 + offset1); @@ -5986,7 +5421,7 @@ void kernel_mul_mv_iq4_xs_f32_impl( for (int row = 0; row < 2; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; + dst[r1*args.ne0 + im*args.ne0*args.ne1 + first_row + row] = all_sum; } } } @@ -5996,29 +5431,12 @@ kernel void kernel_mul_mv_iq1_s_f32( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, + constant ggml_metal_kargs_mul_mv & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); } [[host_name("kernel_mul_mv_iq1_m_f32")]] @@ -6026,29 +5444,12 @@ kernel void kernel_mul_mv_iq1_m_f32( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, + constant ggml_metal_kargs_mul_mv & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); } [[host_name("kernel_mul_mv_iq4_nl_f32")]] @@ -6056,30 +5457,13 @@ kernel void kernel_mul_mv_iq4_nl_f32( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, + constant ggml_metal_kargs_mul_mv & args, threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg); } [[host_name("kernel_mul_mv_iq4_xs_f32")]] @@ -6087,30 +5471,13 @@ kernel void kernel_mul_mv_iq4_xs_f32( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, + constant ggml_metal_kargs_mul_mv & args, threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq4_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq4_xs_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg); } template @@ -6653,46 +6020,15 @@ typedef void (kernel_mul_mv_impl_t)( device const char * src0, device const char * src1, device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb00, - uint64_t nb01, - uint64_t nb02, - uint64_t nb03, - int64_t ne10, - int64_t ne11, - int64_t ne12, - uint64_t nb10, - uint64_t nb11, - uint64_t nb12, - uint64_t nb13, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - uint3 tgpig, - uint tiisg); + ggml_metal_kargs_mul_mv args, + uint3 tgpig, + uint tiisg); typedef void (kernel_mul_mv2_impl_t)( device const void * src0, device const float * src1, device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb01, - uint64_t nb02, - uint64_t nb03, - int64_t ne10, - int64_t ne12, - uint64_t nb11, - uint64_t nb12, - uint64_t nb13, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, + ggml_metal_kargs_mul_mv args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -6703,32 +6039,13 @@ void mmv_fn( device const char * src0, device const char * src1, device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb00, - uint64_t nb01, - uint64_t nb02, - uint64_t nb03, - int64_t ne10, - int64_t ne11, - int64_t ne12, - int64_t ne13, - uint64_t nb10, - uint64_t nb11, - uint64_t nb12, - uint64_t nb13, - int64_t ne0, - int64_t ne1, - uint64_t nb1, - uint r2, - uint r3, + ggml_metal_kargs_mul_mv args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiitg, uint tiisg, uint sgitg) { - impl_fn(src0,src1,dst,ne00,ne01,ne02,nb00,nb01,nb02,nb03,ne10,ne11,ne12,nb10,nb11,nb12,nb13,ne0,ne1,r2,r3,tgpig,tiisg); + impl_fn(src0, src1, dst, args, tgpig, tiisg); } template @@ -6736,32 +6053,13 @@ void mmv_fn( device const char * src0, device const char * src1, device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb00, - uint64_t nb01, - uint64_t nb02, - uint64_t nb03, - int64_t ne10, - int64_t ne11, - int64_t ne12, - int64_t ne13, - uint64_t nb10, - uint64_t nb11, - uint64_t nb12, - uint64_t nb13, - int64_t ne0, - int64_t ne1, - uint64_t nb1, - uint r2, - uint r3, + ggml_metal_kargs_mul_mv args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiitg, uint tiisg, uint sgitg) { - impl_fn(src0,(const device float *)src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,shared_values,tgpig,tiisg,sgitg); + impl_fn(src0,(const device float *) src1, dst, args, shared_values, tgpig, tiisg, sgitg); } typedef decltype(mmv_fn>) mul_mv_impl_fn_t; @@ -6772,71 +6070,55 @@ kernel void kernel_mul_mv_id( device const char * src1, device float * dst, device const char * ids, - constant int64_t & nei0, - constant int64_t & nei1, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, + constant ggml_metal_kargs_mul_mv_id & args, threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - const int iid1 = tgpig.z/nei0; - const int idx = tgpig.z%nei0; + const int iid1 = tgpig.z/args.nei0; + const int idx = tgpig.z%args.nei0; tgpig.z = 0; - const int32_t i02 = ((device const int32_t *) (ids + iid1*nbi1))[idx]; + const int32_t i02 = ((device const int32_t *) (ids + iid1*args.nbi1))[idx]; - const int64_t i11 = idx % ne11; + const int64_t i11 = idx % args.ne11; const int64_t i12 = iid1; const int64_t i1 = idx; const int64_t i2 = i12; - device const char * src0_cur = src0s + i02*nb02; - device const char * src1_cur = src1 + i11*nb11 + i12*nb12; - device float * dst_cur = dst + i1*ne0 + i2*ne1*ne0; + device const char * src0_cur = src0s + i02*args.nb02; + device const char * src1_cur = src1 + i11*args.nb11 + i12*args.nb12; + device float * dst_cur = dst + i1*args.ne0 + i2*args.ne1*args.ne0; + + ggml_metal_kargs_mul_mv args0 = { + /*.ne00 =*/ args.ne00, + /*.ne01 =*/ args.ne01, + /*.ne02 =*/ 1, // args.ne02, + /*.nb00 =*/ args.nb00, + /*.nb01 =*/ args.nb01, + /*.nb02 =*/ args.nb02, + /*.nb03 =*/ args.nb02, // args.ne02 == 1 + /*.ne10 =*/ args.ne10, + /*.ne11 =*/ 1, // args.ne11, + /*.ne12 =*/ 1, // args.ne12, + /*.nb10 =*/ args.nb10, + /*.nb11 =*/ args.nb11, + /*.nb12 =*/ args.nb12, + /*.nb13 =*/ args.nb12, // ne12 == 1 + /*.ne0 =*/ args.ne0, + /*.ne1 =*/ 1, // args.ne1, + /*.r2 =*/ 1, + /*.r3 =*/ 1, + }; impl_fn( /* src0 */ src0_cur, /* src1 */ src1_cur, /* dst */ dst_cur, - /* ne00 */ ne00, - /* ne01 */ ne01, - /* ne02 */ 1, // ne02, - /* nb00 */ nb00, - /* nb01 */ nb01, - /* nb02 */ nb02, - /* nb03 */ nb02, // ne02 == 1 - /* ne10 */ ne10, - /* ne11 */ 1, // ne11, - /* ne12 */ 1, // ne12, - /* ne13 */ 1, // ne13, - /* nb10 */ nb10, - /* nb11 */ nb11, - /* nb12 */ nb12, - /* ne13 */ nb12, // ne12 == 1 - /* ne0 */ ne0, - /* ne1 */ 1, // ne1, - /* nb1 */ nb1, - /* r2 */ 1, - /* r3 */ 1, + args0, shared_values, tgpig, tiitg, From 4af3a879628bce4ab09e7398fa34da3e68b00060 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 10 Nov 2024 08:10:22 +0200 Subject: [PATCH 06/20] cont : pass by reference --- ggml/src/ggml-metal/ggml-metal.metal | 99 ++++++++++++++++------------ 1 file changed, 57 insertions(+), 42 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 81f1aeedcbceb..8e3a9de2388f7 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1629,12 +1629,12 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre // quantizations where the block size is 32. It also does not // guard against the number of rows not being divisible by // N_DST, so this is another explicit assumption of the implementation. -template +template void mul_vec_q_n_f32_impl( device const void * src0, device const float * src1, device float * dst, - ggml_metal_kargs_mul_mv args, + A args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -1711,7 +1711,7 @@ kernel void kernel_mul_mv_q4_0_f32( uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); + mul_vec_q_n_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); } kernel void kernel_mul_mv_q4_1_f32( @@ -1720,9 +1720,9 @@ kernel void kernel_mul_mv_q4_1_f32( device float * dst, constant ggml_metal_kargs_mul_mv & args, uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + mul_vec_q_n_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); } kernel void kernel_mul_mv_q5_0_f32( @@ -1733,7 +1733,7 @@ kernel void kernel_mul_mv_q5_0_f32( uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); + mul_vec_q_n_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); } kernel void kernel_mul_mv_q5_1_f32( @@ -1744,16 +1744,17 @@ kernel void kernel_mul_mv_q5_1_f32( uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); + mul_vec_q_n_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); } #define NB_Q8_0 8 +template void kernel_mul_mv_q8_0_f32_impl( device const void * src0, device const float * src1, device float * dst, - ggml_metal_kargs_mul_mv args, + A args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -1829,17 +1830,17 @@ kernel void kernel_mul_mv_q8_0_f32( uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q8_0_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q8_0_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); } #define N_MV_T_T 4 -template +template void kernel_mul_mv_impl( device const char * src0, device const char * src1, device float * dst, - ggml_metal_kargs_mul_mv args, + A args, uint3 tgpig, uint tiisg) { const int64_t r0 = tgpig.x; @@ -1909,7 +1910,7 @@ kernel void kernel_mul_mv( constant ggml_metal_kargs_mul_mv & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]]) { - kernel_mul_mv_impl( + kernel_mul_mv_impl( src0, src1, dst, @@ -3931,11 +3932,12 @@ kernel void kernel_concat( } } +template void kernel_mul_mv_q2_K_f32_impl( device const void * src0, device const float * src1, device float * dst, - ggml_metal_kargs_mul_mv args, + A args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -4030,14 +4032,15 @@ kernel void kernel_mul_mv_q2_K_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); } +template void kernel_mul_mv_q3_K_f32_impl( device const void * src0, device const float * src1, device float * dst, - ggml_metal_kargs_mul_mv args, + A args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -4189,14 +4192,15 @@ kernel void kernel_mul_mv_q3_K_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); } +template void kernel_mul_mv_q4_K_f32_impl( device const void * src0, device const float * src1, device float * dst, - ggml_metal_kargs_mul_mv args, + A args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -4304,14 +4308,15 @@ kernel void kernel_mul_mv_q4_K_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); } +template void kernel_mul_mv_q5_K_f32_impl( device const void * src0, device const float * src1, device float * dst, - ggml_metal_kargs_mul_mv args, + A args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -4433,14 +4438,15 @@ kernel void kernel_mul_mv_q5_K_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); } +template void kernel_mul_mv_q6_K_f32_impl( device const void * src0, device const float * src1, device float * dst, - ggml_metal_kargs_mul_mv args, + A args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -4521,16 +4527,17 @@ kernel void kernel_mul_mv_q6_K_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); } // ======================= "True" 2-bit +template void kernel_mul_mv_iq2_xxs_f32_impl( device const void * src0, device const float * src1, device float * dst, - ggml_metal_kargs_mul_mv args, + A args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -4629,14 +4636,15 @@ kernel void kernel_mul_mv_iq2_xxs_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq2_xxs_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq2_xxs_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg); } +template void kernel_mul_mv_iq2_xs_f32_impl( device const void * src0, device const float * src1, device float * dst, - ggml_metal_kargs_mul_mv args, + A args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -4745,14 +4753,15 @@ kernel void kernel_mul_mv_iq2_xs_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq2_xs_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq2_xs_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg); } +template void kernel_mul_mv_iq3_xxs_f32_impl( device const void * src0, device const float * src1, device float * dst, - ggml_metal_kargs_mul_mv args, + A args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -4854,14 +4863,15 @@ kernel void kernel_mul_mv_iq3_xxs_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq3_xxs_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq3_xxs_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg); } +template void kernel_mul_mv_iq3_s_f32_impl( device const void * src0, device const float * src1, device float * dst, - ggml_metal_kargs_mul_mv args, + A args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -4963,14 +4973,15 @@ kernel void kernel_mul_mv_iq3_s_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq3_s_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq3_s_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg); } +template void kernel_mul_mv_iq2_s_f32_impl( device const void * src0, device const float * src1, device float * dst, - ggml_metal_kargs_mul_mv args, + A args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -5073,14 +5084,15 @@ kernel void kernel_mul_mv_iq2_s_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq2_s_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq2_s_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg); } +template void kernel_mul_mv_iq1_s_f32_impl( device const void * src0, device const float * src1, device float * dst, - ggml_metal_kargs_mul_mv args, + A args, threadgroup int8_t * shared_value, uint3 tgpig, uint tiisg, @@ -5159,11 +5171,12 @@ void kernel_mul_mv_iq1_s_f32_impl( } } +template void kernel_mul_mv_iq1_m_f32_impl( device const void * src0, device const float * src1, device float * dst, - ggml_metal_kargs_mul_mv args, + A args, threadgroup int8_t * shared_value, uint3 tgpig, uint tiisg, @@ -5251,11 +5264,12 @@ void kernel_mul_mv_iq1_m_f32_impl( } } +template void kernel_mul_mv_iq4_nl_f32_impl( device const void * src0, device const float * src1, device float * dst, - ggml_metal_kargs_mul_mv args, + A args, threadgroup int8_t * shared_values_i8, uint3 tgpig, uint tiisg, @@ -5338,11 +5352,12 @@ void kernel_mul_mv_iq4_nl_f32_impl( } } +template void kernel_mul_mv_iq4_xs_f32_impl( device const void * src0, device const float * src1, device float * dst, - ggml_metal_kargs_mul_mv args, + A args, threadgroup int8_t * shared_values_i8, uint3 tgpig, uint tiisg, @@ -5436,7 +5451,7 @@ kernel void kernel_mul_mv_iq1_s_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); } [[host_name("kernel_mul_mv_iq1_m_f32")]] @@ -5449,7 +5464,7 @@ kernel void kernel_mul_mv_iq1_m_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); } [[host_name("kernel_mul_mv_iq4_nl_f32")]] @@ -5463,7 +5478,7 @@ kernel void kernel_mul_mv_iq4_nl_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg); } [[host_name("kernel_mul_mv_iq4_xs_f32")]] @@ -5477,7 +5492,7 @@ kernel void kernel_mul_mv_iq4_xs_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq4_xs_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq4_xs_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg); } template @@ -6062,7 +6077,7 @@ void mmv_fn( impl_fn(src0,(const device float *) src1, dst, args, shared_values, tgpig, tiisg, sgitg); } -typedef decltype(mmv_fn>) mul_mv_impl_fn_t; +typedef decltype(mmv_fn>) mul_mv_impl_fn_t; template kernel void kernel_mul_mv_id( From 481b05df22cacc9d9699f5265cd73752cfc85258 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 10 Nov 2024 08:47:30 +0200 Subject: [PATCH 07/20] cont : args is first argument --- ggml/src/ggml-metal/ggml-metal.m | 52 ++++---- ggml/src/ggml-metal/ggml-metal.metal | 182 +++++++++++++-------------- 2 files changed, 117 insertions(+), 117 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index 92506e8772d9a..f4b611b886555 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -1981,10 +1981,10 @@ static void ggml_metal_encode_node( }; [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&args length:sizeof(args) atIndex:3]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; [encoder setThreadgroupMemoryLength:8192 atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; @@ -2185,10 +2185,10 @@ static void ggml_metal_encode_node( }; [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&args length:sizeof(args) atIndex:3]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K || @@ -2503,11 +2503,11 @@ static void ggml_metal_encode_node( }; [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3]; - [encoder setBytes:&args length:sizeof(args) atIndex:4]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:4]; const int64_t _ne1 = 1; const int tgz = dst_rows; @@ -2752,15 +2752,15 @@ static void ggml_metal_encode_node( }; [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; if (id_src2 != nil) { - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3]; } else { - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:2]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:3]; } - [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; - [encoder setBytes:&args length:sizeof(args) atIndex:4]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:4]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } break; @@ -3270,16 +3270,16 @@ static void ggml_metal_encode_node( }; [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3]; if (id_src3) { - [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3]; + [encoder setBuffer:id_src3 offset:offs_src3 atIndex:4]; } else { - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:3]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:4]; } - [encoder setBuffer:id_dst offset:offs_dst atIndex:4]; - [encoder setBytes:&args length:sizeof(args) atIndex:5]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:5]; if (!use_vec_kernel) { // half8x8 kernel diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 8e3a9de2388f7..77698b5761931 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1629,12 +1629,12 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre // quantizations where the block size is 32. It also does not // guard against the number of rows not being divisible by // N_DST, so this is another explicit assumption of the implementation. -template +template void mul_vec_q_n_f32_impl( + args_t args, device const void * src0, device const float * src1, device float * dst, - A args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -1704,57 +1704,57 @@ void mul_vec_q_n_f32_impl( } kernel void kernel_mul_mv_q4_0_f32( + constant ggml_metal_kargs_mul_mv & args, device const void * src0, device const float * src1, device float * dst, - constant ggml_metal_kargs_mul_mv & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); + mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } kernel void kernel_mul_mv_q4_1_f32( + constant ggml_metal_kargs_mul_mv & args, device const void * src0, device const float * src1, device float * dst, - constant ggml_metal_kargs_mul_mv & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); + mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } kernel void kernel_mul_mv_q5_0_f32( + constant ggml_metal_kargs_mul_mv & args, device const void * src0, device const float * src1, device float * dst, - constant ggml_metal_kargs_mul_mv & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); + mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } kernel void kernel_mul_mv_q5_1_f32( + constant ggml_metal_kargs_mul_mv & args, device const void * src0, device const float * src1, device float * dst, - constant ggml_metal_kargs_mul_mv & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); + mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } #define NB_Q8_0 8 -template +template void kernel_mul_mv_q8_0_f32_impl( + args_t args, device const void * src0, device const float * src1, device float * dst, - A args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -1823,24 +1823,24 @@ void kernel_mul_mv_q8_0_f32_impl( [[host_name("kernel_mul_mv_q8_0_f32")]] kernel void kernel_mul_mv_q8_0_f32( + constant ggml_metal_kargs_mul_mv & args, device const void * src0, device const float * src1, device float * dst, - constant ggml_metal_kargs_mul_mv & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q8_0_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q8_0_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } #define N_MV_T_T 4 -template +template void kernel_mul_mv_impl( + args_t args, device const char * src0, device const char * src1, device float * dst, - A args, uint3 tgpig, uint tiisg) { const int64_t r0 = tgpig.x; @@ -1904,17 +1904,17 @@ void kernel_mul_mv_impl( template kernel void kernel_mul_mv( + constant ggml_metal_kargs_mul_mv & args, device const char * src0, device const char * src1, device float * dst, - constant ggml_metal_kargs_mul_mv & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]]) { kernel_mul_mv_impl( + args, src0, src1, dst, - args, tgpig, tiisg); } @@ -1931,10 +1931,10 @@ template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t kernel_mul_mv< template kernel void kernel_mul_mv_1row( + constant ggml_metal_kargs_mul_mv & args, device const char * src0, device const char * src1, device float * dst, - constant ggml_metal_kargs_mul_mv & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]]) { @@ -1987,10 +1987,10 @@ template [[host_name("kernel_mul_mv_bf16_f32_1row")]] kernel mul_mv_1row_t kerne // Assumes row size (ne00) is a multiple of 4 template kernel void kernel_mul_mv_l4( + constant ggml_metal_kargs_mul_mv & args, device const char * src0, device const char * src1, device float * dst, - constant ggml_metal_kargs_mul_mv & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]]) { @@ -2069,11 +2069,11 @@ static void rope_yarn_corr_dims( template kernel void kernel_rope_norm( + constant ggml_metal_kargs_rope & args, device const char * src0, device const char * src1, device const char * src2, device char * dst, - constant ggml_metal_kargs_rope & args, ushort tiitg[[thread_index_in_threadgroup]], ushort3 tptg [[threads_per_threadgroup]], uint3 tgpig[[threadgroup_position_in_grid]]) { @@ -2122,11 +2122,11 @@ kernel void kernel_rope_norm( template kernel void kernel_rope_neox( + constant ggml_metal_kargs_rope & args, device const char * src0, device const char * src1, device const char * src2, device char * dst, - constant ggml_metal_kargs_rope & args, ushort tiitg[[thread_index_in_threadgroup]], ushort3 tptg [[threads_per_threadgroup]], uint3 tgpig[[threadgroup_position_in_grid]]) { @@ -2563,13 +2563,13 @@ template< short KV = 8, // key/value processed per each simdgroup short C = 32> // cache items per threadgroup kernel void kernel_flash_attn_ext( + constant ggml_metal_kargs_flash_attn_ext & args, device const char * q, device const char * k, device const char * v, device const char * mask, device char * dst, - constant ggml_metal_kargs_flash_attn_ext & args, - threadgroup half * shared [[threadgroup(0)]], + threadgroup half * shared [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], ushort3 ntg[[threads_per_threadgroup]], ushort tiisg[[thread_index_in_simdgroup]], @@ -3058,13 +3058,13 @@ template< short Q = 1, // queries per threadgroup short C = 32> // cache items per threadgroup kernel void kernel_flash_attn_ext_vec( + constant ggml_metal_kargs_flash_attn_ext & args, device const char * q, device const char * k, device const char * v, device const char * mask, device char * dst, - constant ggml_metal_kargs_flash_attn_ext & args, - threadgroup half * shared [[threadgroup(0)]], + threadgroup half * shared [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], ushort3 ntg[[threads_per_threadgroup]], ushort tiisg[[thread_index_in_simdgroup]], @@ -3932,12 +3932,12 @@ kernel void kernel_concat( } } -template +template void kernel_mul_mv_q2_K_f32_impl( + args_t args, device const void * src0, device const float * src1, device float * dst, - A args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -4024,23 +4024,23 @@ void kernel_mul_mv_q2_K_f32_impl( [[host_name("kernel_mul_mv_q2_K_f32")]] kernel void kernel_mul_mv_q2_K_f32( + constant ggml_metal_kargs_mul_mv & args, device const void * src0, device const float * src1, device float * dst, - constant ggml_metal_kargs_mul_mv & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q2_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_q3_K_f32_impl( + args_t args, device const void * src0, device const float * src1, device float * dst, - A args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -4184,23 +4184,23 @@ void kernel_mul_mv_q3_K_f32_impl( [[host_name("kernel_mul_mv_q3_K_f32")]] kernel void kernel_mul_mv_q3_K_f32( + constant ggml_metal_kargs_mul_mv & args, device const void * src0, device const float * src1, device float * dst, - constant ggml_metal_kargs_mul_mv & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q3_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_q4_K_f32_impl( + args_t args, device const void * src0, device const float * src1, device float * dst, - A args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -4300,23 +4300,23 @@ void kernel_mul_mv_q4_K_f32_impl( [[host_name("kernel_mul_mv_q4_K_f32")]] kernel void kernel_mul_mv_q4_K_f32( + constant ggml_metal_kargs_mul_mv & args, device const void * src0, device const float * src1, device float * dst, - constant ggml_metal_kargs_mul_mv & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q4_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_q5_K_f32_impl( + args_t args, device const void * src0, device const float * src1, device float * dst, - A args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -4430,23 +4430,23 @@ void kernel_mul_mv_q5_K_f32_impl( [[host_name("kernel_mul_mv_q5_K_f32")]] kernel void kernel_mul_mv_q5_K_f32( + constant ggml_metal_kargs_mul_mv & args, device const void * src0, device const float * src1, device float * dst, - constant ggml_metal_kargs_mul_mv & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q5_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_q6_K_f32_impl( + args_t args, device const void * src0, device const float * src1, device float * dst, - A args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -4519,25 +4519,25 @@ void kernel_mul_mv_q6_K_f32_impl( [[host_name("kernel_mul_mv_q6_K_f32")]] kernel void kernel_mul_mv_q6_K_f32( + constant ggml_metal_kargs_mul_mv & args, device const void * src0, device const float * src1, device float * dst, - constant ggml_metal_kargs_mul_mv & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q6_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } // ======================= "True" 2-bit -template +template void kernel_mul_mv_iq2_xxs_f32_impl( + args_t args, device const void * src0, device const float * src1, device float * dst, - A args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -4627,24 +4627,24 @@ void kernel_mul_mv_iq2_xxs_f32_impl( [[host_name("kernel_mul_mv_iq2_xxs_f32")]] kernel void kernel_mul_mv_iq2_xxs_f32( + constant ggml_metal_kargs_mul_mv & args, device const void * src0, device const float * src1, device float * dst, - constant ggml_metal_kargs_mul_mv & args, threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq2_xxs_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq2_xxs_f32_impl(args, src0, src1, dst, shared_values, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_iq2_xs_f32_impl( + args_t args, device const void * src0, device const float * src1, device float * dst, - A args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -4744,24 +4744,24 @@ void kernel_mul_mv_iq2_xs_f32_impl( [[host_name("kernel_mul_mv_iq2_xs_f32")]] kernel void kernel_mul_mv_iq2_xs_f32( + constant ggml_metal_kargs_mul_mv & args, device const void * src0, device const float * src1, device float * dst, - constant ggml_metal_kargs_mul_mv & args, threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq2_xs_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq2_xs_f32_impl(args, src0, src1, dst, shared_values, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_iq3_xxs_f32_impl( + args_t args, device const void * src0, device const float * src1, device float * dst, - A args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -4854,24 +4854,24 @@ void kernel_mul_mv_iq3_xxs_f32_impl( [[host_name("kernel_mul_mv_iq3_xxs_f32")]] kernel void kernel_mul_mv_iq3_xxs_f32( + constant ggml_metal_kargs_mul_mv & args, device const void * src0, device const float * src1, device float * dst, - constant ggml_metal_kargs_mul_mv & args, threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq3_xxs_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq3_xxs_f32_impl(args, src0, src1, dst, shared_values, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_iq3_s_f32_impl( + args_t args, device const void * src0, device const float * src1, device float * dst, - A args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -4964,24 +4964,24 @@ void kernel_mul_mv_iq3_s_f32_impl( [[host_name("kernel_mul_mv_iq3_s_f32")]] kernel void kernel_mul_mv_iq3_s_f32( + constant ggml_metal_kargs_mul_mv & args, device const void * src0, device const float * src1, device float * dst, - constant ggml_metal_kargs_mul_mv & args, threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq3_s_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq3_s_f32_impl(args, src0, src1, dst, shared_values, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_iq2_s_f32_impl( + args_t args, device const void * src0, device const float * src1, device float * dst, - A args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -5075,24 +5075,24 @@ void kernel_mul_mv_iq2_s_f32_impl( [[host_name("kernel_mul_mv_iq2_s_f32")]] kernel void kernel_mul_mv_iq2_s_f32( + constant ggml_metal_kargs_mul_mv & args, device const void * src0, device const float * src1, device float * dst, - constant ggml_metal_kargs_mul_mv & args, threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq2_s_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq2_s_f32_impl(args, src0, src1, dst, shared_values, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_iq1_s_f32_impl( + args_t args, device const void * src0, device const float * src1, device float * dst, - A args, threadgroup int8_t * shared_value, uint3 tgpig, uint tiisg, @@ -5171,12 +5171,12 @@ void kernel_mul_mv_iq1_s_f32_impl( } } -template +template void kernel_mul_mv_iq1_m_f32_impl( + args_t args, device const void * src0, device const float * src1, device float * dst, - A args, threadgroup int8_t * shared_value, uint3 tgpig, uint tiisg, @@ -5264,12 +5264,12 @@ void kernel_mul_mv_iq1_m_f32_impl( } } -template +template void kernel_mul_mv_iq4_nl_f32_impl( + args_t args, device const void * src0, device const float * src1, device float * dst, - A args, threadgroup int8_t * shared_values_i8, uint3 tgpig, uint tiisg, @@ -5352,12 +5352,12 @@ void kernel_mul_mv_iq4_nl_f32_impl( } } -template +template void kernel_mul_mv_iq4_xs_f32_impl( + args_t args, device const void * src0, device const float * src1, device float * dst, - A args, threadgroup int8_t * shared_values_i8, uint3 tgpig, uint tiisg, @@ -5443,56 +5443,56 @@ void kernel_mul_mv_iq4_xs_f32_impl( [[host_name("kernel_mul_mv_iq1_s_f32")]] kernel void kernel_mul_mv_iq1_s_f32( + constant ggml_metal_kargs_mul_mv & args, device const void * src0, device const float * src1, device float * dst, - constant ggml_metal_kargs_mul_mv & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_iq1_s_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } [[host_name("kernel_mul_mv_iq1_m_f32")]] kernel void kernel_mul_mv_iq1_m_f32( + constant ggml_metal_kargs_mul_mv & args, device const void * src0, device const float * src1, device float * dst, - constant ggml_metal_kargs_mul_mv & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_iq1_m_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } [[host_name("kernel_mul_mv_iq4_nl_f32")]] kernel void kernel_mul_mv_iq4_nl_f32( + constant ggml_metal_kargs_mul_mv & args, device const void * src0, device const float * src1, device float * dst, - constant ggml_metal_kargs_mul_mv & args, threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq4_nl_f32_impl(args, src0, src1, dst, shared_values, tgpig, tiisg, sgitg); } [[host_name("kernel_mul_mv_iq4_xs_f32")]] kernel void kernel_mul_mv_iq4_xs_f32( + constant ggml_metal_kargs_mul_mv & args, device const void * src0, device const float * src1, device float * dst, - constant ggml_metal_kargs_mul_mv & args, threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq4_xs_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq4_xs_f32_impl(args, src0, src1, dst, shared_values, tgpig, tiisg, sgitg); } template @@ -5597,10 +5597,10 @@ kernel void kernel_get_rows_i32( // each block_q contains 16*nl weights template kernel void kernel_mul_mm( + constant ggml_metal_kargs_mul_mm & args, device const char * src0, device const char * src1, device char * dst, - constant ggml_metal_kargs_mul_mm & args, threadgroup char * shared_memory [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], ushort tiitg[[thread_index_in_threadgroup]], @@ -6032,18 +6032,18 @@ template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel // typedef void (kernel_mul_mv_impl_t)( + ggml_metal_kargs_mul_mv args, device const char * src0, device const char * src1, device float * dst, - ggml_metal_kargs_mul_mv args, uint3 tgpig, uint tiisg); typedef void (kernel_mul_mv2_impl_t)( + ggml_metal_kargs_mul_mv args, device const void * src0, device const float * src1, device float * dst, - ggml_metal_kargs_mul_mv args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -6051,41 +6051,41 @@ typedef void (kernel_mul_mv2_impl_t)( template void mmv_fn( + ggml_metal_kargs_mul_mv args, device const char * src0, device const char * src1, device float * dst, - ggml_metal_kargs_mul_mv args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiitg, uint tiisg, uint sgitg) { - impl_fn(src0, src1, dst, args, tgpig, tiisg); + impl_fn(args, src0, src1, dst, tgpig, tiisg); } template void mmv_fn( + ggml_metal_kargs_mul_mv args, device const char * src0, device const char * src1, device float * dst, - ggml_metal_kargs_mul_mv args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiitg, uint tiisg, uint sgitg) { - impl_fn(src0,(const device float *) src1, dst, args, shared_values, tgpig, tiisg, sgitg); + impl_fn(args, src0,(const device float *) src1, dst, shared_values, tgpig, tiisg, sgitg); } typedef decltype(mmv_fn>) mul_mv_impl_fn_t; template kernel void kernel_mul_mv_id( + constant ggml_metal_kargs_mul_mv_id & args, device const char * src0s, device const char * src1, device float * dst, device const char * ids, - constant ggml_metal_kargs_mul_mv_id & args, threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], @@ -6130,10 +6130,10 @@ kernel void kernel_mul_mv_id( }; impl_fn( + args0, /* src0 */ src0_cur, /* src1 */ src1_cur, /* dst */ dst_cur, - args0, shared_values, tgpig, tiitg, From d2a055059e8102fe32092263131f389626125091 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 10 Nov 2024 09:26:53 +0200 Subject: [PATCH 08/20] cont : use char ptr --- ggml/src/ggml-metal/ggml-metal.metal | 395 +++++++++++++++------------ 1 file changed, 216 insertions(+), 179 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 77698b5761931..d586c7b614ed0 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1632,9 +1632,9 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre template void mul_vec_q_n_f32_impl( args_t args, - device const void * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -1653,8 +1653,8 @@ void mul_vec_q_n_f32_impl( //const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; - //device const block_q_type * x = (device const block_q_type *) ((device char *) src0 + offset0); - device const float * y = (device const float *) ((device char *) src1 + offset1); + //device const block_q_type * x = (device const block_q_type *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); // pointers to src0 rows device const block_q_type * ax[nr]; @@ -1695,19 +1695,22 @@ void mul_vec_q_n_f32_impl( yb += QK4_0 * 16; } + device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + for (int row = 0; row < nr; ++row) { const float tot = simd_sum(sumf[row]); + if (tiisg == 0 && first_row + row < args.ne01) { - dst[im*args.ne0*args.ne1 + r1*args.ne0 + first_row + row] = tot; + dst_f32[first_row + row] = tot; } } } kernel void kernel_mul_mv_q4_0_f32( constant ggml_metal_kargs_mul_mv & args, - device const void * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -1716,9 +1719,9 @@ kernel void kernel_mul_mv_q4_0_f32( kernel void kernel_mul_mv_q4_1_f32( constant ggml_metal_kargs_mul_mv & args, - device const void * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -1727,9 +1730,9 @@ kernel void kernel_mul_mv_q4_1_f32( kernel void kernel_mul_mv_q5_0_f32( constant ggml_metal_kargs_mul_mv & args, - device const void * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -1738,9 +1741,9 @@ kernel void kernel_mul_mv_q5_0_f32( kernel void kernel_mul_mv_q5_1_f32( constant ggml_metal_kargs_mul_mv & args, - device const void * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -1752,9 +1755,9 @@ kernel void kernel_mul_mv_q5_1_f32( template void kernel_mul_mv_q8_0_f32_impl( args_t args, - device const void * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -1776,8 +1779,8 @@ void kernel_mul_mv_q8_0_f32_impl( //const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; - //device const block_q8_0 * x = (device const block_q8_0 *) ((device char *) src0 + offset0); - device const float * y = (device const float *) ((device char *) src1 + offset1); + //device const block_q8_0 * x = (device const block_q8_0 *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); // pointers to src0 rows device const block_q8_0 * ax[nr]; @@ -1813,10 +1816,12 @@ void kernel_mul_mv_q8_0_f32_impl( yb += NB_Q8_0 * nw; } + device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + for (int row = 0; row < nr; ++row) { const float tot = simd_sum(sumf[row]); if (tiisg == 0 && first_row + row < args.ne01) { - dst[r1*args.ne0 + im*args.ne0*args.ne1 + first_row + row] = tot; + dst_f32[first_row + row] = tot; } } } @@ -1824,9 +1829,9 @@ void kernel_mul_mv_q8_0_f32_impl( [[host_name("kernel_mul_mv_q8_0_f32")]] kernel void kernel_mul_mv_q8_0_f32( constant ggml_metal_kargs_mul_mv & args, - device const void * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -1838,9 +1843,9 @@ kernel void kernel_mul_mv_q8_0_f32( template void kernel_mul_mv_impl( args_t args, - device const char * src0, - device const char * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, uint3 tgpig, uint tiisg) { const int64_t r0 = tgpig.x; @@ -1854,6 +1859,8 @@ void kernel_mul_mv_impl( device const T0 * x = (device const T0 *) (src0 + offset0); + device float * dst_f32 = (device float *) dst + (int64_t)im*args.ne0*args.ne1; + if (args.ne00 < 128) { for (int row = 0; row < N_MV_T_T; ++row) { int r1 = rb + row; @@ -1872,7 +1879,7 @@ void kernel_mul_mv_impl( float all_sum = simd_sum(sumf); if (tiisg == 0) { - dst[im*args.ne1*args.ne0 + r1*args.ne0 + r0] = all_sum; + dst_f32[r1*args.ne0 + r0] = all_sum; } } } else { @@ -1896,7 +1903,7 @@ void kernel_mul_mv_impl( float all_sum = simd_sum(sumf); if (tiisg == 0) { for (int i = 4*(args.ne00/4); i < args.ne00; ++i) all_sum += (float) (x[i] * y[i]); - dst[im*args.ne1*args.ne0 + r1*args.ne0 + r0] = all_sum; + dst_f32[r1*args.ne0 + r0] = all_sum; } } } @@ -1905,9 +1912,9 @@ void kernel_mul_mv_impl( template kernel void kernel_mul_mv( constant ggml_metal_kargs_mul_mv & args, - device const char * src0, - device const char * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]]) { kernel_mul_mv_impl( @@ -3935,9 +3942,9 @@ kernel void kernel_concat( template void kernel_mul_mv_q2_K_f32_impl( args_t args, - device const void * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -3956,8 +3963,8 @@ void kernel_mul_mv_q2_K_f32_impl( const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; - device const block_q2_K * x = (device const block_q2_K *) ((device char *) src0 + offset0); - device const float * y = (device const float *) ((device char *) src1 + offset1); + device const block_q2_K * x = (device const block_q2_K *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); float yl[32]; float sumf[N_DST]={0.f}, all_sum; @@ -4014,10 +4021,12 @@ void kernel_mul_mv_q2_K_f32_impl( y4 += 4 * QK_K; } + device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*args.ne0 + im*args.ne0*args.ne1 + first_row + row] = all_sum; + dst_f32[first_row + row] = all_sum; } } } @@ -4025,9 +4034,9 @@ void kernel_mul_mv_q2_K_f32_impl( [[host_name("kernel_mul_mv_q2_K_f32")]] kernel void kernel_mul_mv_q2_K_f32( constant ggml_metal_kargs_mul_mv & args, - device const void * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -4038,9 +4047,9 @@ kernel void kernel_mul_mv_q2_K_f32( template void kernel_mul_mv_q3_K_f32_impl( args_t args, - device const void * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -4060,8 +4069,8 @@ void kernel_mul_mv_q3_K_f32_impl( const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; - device const block_q3_K * x = (device const block_q3_K *) ((device char *) src0 + offset0); - device const float * yy = (device const float *) ((device char *) src1 + offset1); + device const block_q3_K * x = (device const block_q3_K *) (src0 + offset0); + device const float * yy = (device const float *) (src1 + offset1); float yl[32]; @@ -4175,9 +4184,12 @@ void kernel_mul_mv_q3_K_f32_impl( const float sumf = (sumf1[row] + 0.25f * sumf2[row]) / (1 << shift); sumf1[row] = simd_sum(sumf); } + + device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + if (tiisg == 0) { for (int row = 0; row < 2; ++row) { - dst[r1*args.ne0 + im*args.ne0*args.ne1 + first_row + row] = sumf1[row]; + dst_f32[first_row + row] = sumf1[row]; } } } @@ -4185,9 +4197,9 @@ void kernel_mul_mv_q3_K_f32_impl( [[host_name("kernel_mul_mv_q3_K_f32")]] kernel void kernel_mul_mv_q3_K_f32( constant ggml_metal_kargs_mul_mv & args, - device const void * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -4198,9 +4210,9 @@ kernel void kernel_mul_mv_q3_K_f32( template void kernel_mul_mv_q4_K_f32_impl( args_t args, - device const void * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -4228,8 +4240,8 @@ void kernel_mul_mv_q4_K_f32_impl( const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; - device const block_q4_K * x = (device const block_q4_K *) ((device char *) src0 + offset0); - device const float * y = (device const float *) ((device char *) src1 + offset1); + device const block_q4_K * x = (device const block_q4_K *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); float yl[16]; float yh[16]; @@ -4290,10 +4302,12 @@ void kernel_mul_mv_q4_K_f32_impl( y4 += 4 * QK_K; } + device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*args.ne0 + im*args.ne0*args.ne1 + first_row + row] = all_sum; + dst_f32[first_row + row] = all_sum; } } } @@ -4301,9 +4315,9 @@ void kernel_mul_mv_q4_K_f32_impl( [[host_name("kernel_mul_mv_q4_K_f32")]] kernel void kernel_mul_mv_q4_K_f32( constant ggml_metal_kargs_mul_mv & args, - device const void * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -4314,9 +4328,9 @@ kernel void kernel_mul_mv_q4_K_f32( template void kernel_mul_mv_q5_K_f32_impl( args_t args, - device const void * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -4336,8 +4350,8 @@ void kernel_mul_mv_q5_K_f32_impl( const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; - device const block_q5_K * x = (device const block_q5_K *) ((device char *) src0 + offset0); - device const float * yy = (device const float *) ((device char *) src1 + offset1); + device const block_q5_K * x = (device const block_q5_K *) (src0 + offset0); + device const float * yy = (device const float *) (src1 + offset1); float sumf[2]={0.f}; @@ -4420,10 +4434,12 @@ void kernel_mul_mv_q5_K_f32_impl( y1 += 4 * QK_K; } + device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + for (int row = 0; row < 2; ++row) { const float tot = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*args.ne0 + im*args.ne0*args.ne1 + first_row + row] = tot; + dst_f32[first_row + row] = tot; } } } @@ -4431,9 +4447,9 @@ void kernel_mul_mv_q5_K_f32_impl( [[host_name("kernel_mul_mv_q5_K_f32")]] kernel void kernel_mul_mv_q5_K_f32( constant ggml_metal_kargs_mul_mv & args, - device const void * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -4444,9 +4460,9 @@ kernel void kernel_mul_mv_q5_K_f32( template void kernel_mul_mv_q6_K_f32_impl( args_t args, - device const void * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -4471,8 +4487,8 @@ void kernel_mul_mv_q6_K_f32_impl( const uint offset0 = row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; - device const block_q6_K * x = (device const block_q6_K *) ((device char *) src0 + offset0); - device const float * yy = (device const float *) ((device char *) src1 + offset1); + device const block_q6_K * x = (device const block_q6_K *) (src0 + offset0); + device const float * yy = (device const float *) (src1 + offset1); float sumf = 0; @@ -4511,18 +4527,20 @@ void kernel_mul_mv_q6_K_f32_impl( } + device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + const float tot = simd_sum(sumf); if (tiisg == 0) { - dst[r1*args.ne0 + im*args.ne0*args.ne1 + row] = tot; + dst_f32[row] = tot; } } [[host_name("kernel_mul_mv_q6_K_f32")]] kernel void kernel_mul_mv_q6_K_f32( constant ggml_metal_kargs_mul_mv & args, - device const void * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -4535,9 +4553,9 @@ kernel void kernel_mul_mv_q6_K_f32( template void kernel_mul_mv_iq2_xxs_f32_impl( args_t args, - device const void * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -4556,8 +4574,8 @@ void kernel_mul_mv_iq2_xxs_f32_impl( const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; - device const block_iq2_xxs * x = (device const block_iq2_xxs *) ((device char *) src0 + offset0); - device const float * y = (device const float *) ((device char *) src1 + offset1); + device const block_iq2_xxs * x = (device const block_iq2_xxs *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); float yl[32]; float sumf[N_DST]={0.f}, all_sum; @@ -4617,10 +4635,12 @@ void kernel_mul_mv_iq2_xxs_f32_impl( y4 += 32 * 32; } + device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*args.ne0 + im*args.ne0*args.ne1 + first_row + row] = all_sum * 0.25f; + dst_f32[first_row + row] = all_sum * 0.25f; } } } @@ -4628,9 +4648,9 @@ void kernel_mul_mv_iq2_xxs_f32_impl( [[host_name("kernel_mul_mv_iq2_xxs_f32")]] kernel void kernel_mul_mv_iq2_xxs_f32( constant ggml_metal_kargs_mul_mv & args, - device const void * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], @@ -4642,9 +4662,9 @@ kernel void kernel_mul_mv_iq2_xxs_f32( template void kernel_mul_mv_iq2_xs_f32_impl( args_t args, - device const void * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -4663,8 +4683,8 @@ void kernel_mul_mv_iq2_xs_f32_impl( const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; - device const block_iq2_xs * x = (device const block_iq2_xs *) ((device char *) src0 + offset0); - device const float * y = (device const float *) ((device char *) src1 + offset1); + device const block_iq2_xs * x = (device const block_iq2_xs *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); float yl[32]; float sumf[N_DST]={0.f}, all_sum; @@ -4734,10 +4754,12 @@ void kernel_mul_mv_iq2_xs_f32_impl( y4 += 32 * 32; } + device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*args.ne0 + im*args.ne0*args.ne1 + first_row + row] = all_sum * 0.25f; + dst_f32[first_row + row] = all_sum * 0.25f; } } } @@ -4745,9 +4767,9 @@ void kernel_mul_mv_iq2_xs_f32_impl( [[host_name("kernel_mul_mv_iq2_xs_f32")]] kernel void kernel_mul_mv_iq2_xs_f32( constant ggml_metal_kargs_mul_mv & args, - device const void * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], @@ -4759,9 +4781,9 @@ kernel void kernel_mul_mv_iq2_xs_f32( template void kernel_mul_mv_iq3_xxs_f32_impl( args_t args, - device const void * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -4780,8 +4802,8 @@ void kernel_mul_mv_iq3_xxs_f32_impl( const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; - device const block_iq3_xxs * x = (device const block_iq3_xxs *) ((device char *) src0 + offset0); - device const float * y = (device const float *) ((device char *) src1 + offset1); + device const block_iq3_xxs * x = (device const block_iq3_xxs *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); float yl[32]; float sumf[N_DST]={0.f}, all_sum; @@ -4844,10 +4866,12 @@ void kernel_mul_mv_iq3_xxs_f32_impl( y4 += 32 * 32; } + device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*args.ne0 + im*args.ne0*args.ne1 + first_row + row] = all_sum * 0.5f; + dst_f32[first_row + row] = all_sum * 0.5f; } } } @@ -4855,9 +4879,9 @@ void kernel_mul_mv_iq3_xxs_f32_impl( [[host_name("kernel_mul_mv_iq3_xxs_f32")]] kernel void kernel_mul_mv_iq3_xxs_f32( constant ggml_metal_kargs_mul_mv & args, - device const void * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], @@ -4869,9 +4893,9 @@ kernel void kernel_mul_mv_iq3_xxs_f32( template void kernel_mul_mv_iq3_s_f32_impl( args_t args, - device const void * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -4890,8 +4914,8 @@ void kernel_mul_mv_iq3_s_f32_impl( const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; - device const block_iq3_s * x = (device const block_iq3_s *) ((device char *) src0 + offset0); - device const float * y = (device const float *) ((device char *) src1 + offset1); + device const block_iq3_s * x = (device const block_iq3_s *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); float yl[32]; float sumf[N_DST]={0.f}, all_sum; @@ -4954,10 +4978,12 @@ void kernel_mul_mv_iq3_s_f32_impl( y4 += 32 * 32; } + device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*args.ne0 + im*args.ne0*args.ne1 + first_row + row] = all_sum; + dst_f32[first_row + row] = all_sum; } } } @@ -4965,9 +4991,9 @@ void kernel_mul_mv_iq3_s_f32_impl( [[host_name("kernel_mul_mv_iq3_s_f32")]] kernel void kernel_mul_mv_iq3_s_f32( constant ggml_metal_kargs_mul_mv & args, - device const void * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], @@ -4979,9 +5005,9 @@ kernel void kernel_mul_mv_iq3_s_f32( template void kernel_mul_mv_iq2_s_f32_impl( args_t args, - device const void * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -5000,8 +5026,8 @@ void kernel_mul_mv_iq2_s_f32_impl( const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; - device const block_iq2_s * x = (device const block_iq2_s *) ((device char *) src0 + offset0); - device const float * y = (device const float *) ((device char *) src1 + offset1); + device const block_iq2_s * x = (device const block_iq2_s *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); float yl[32]; float sumf[N_DST]={0.f}, all_sum; @@ -5065,10 +5091,12 @@ void kernel_mul_mv_iq2_s_f32_impl( y4 += 32 * 32; } + device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*args.ne0 + im*args.ne0*args.ne1 + first_row + row] = all_sum * 0.25f; + dst_f32[first_row + row] = all_sum * 0.25f; } } } @@ -5076,9 +5104,9 @@ void kernel_mul_mv_iq2_s_f32_impl( [[host_name("kernel_mul_mv_iq2_s_f32")]] kernel void kernel_mul_mv_iq2_s_f32( constant ggml_metal_kargs_mul_mv & args, - device const void * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], @@ -5090,9 +5118,9 @@ kernel void kernel_mul_mv_iq2_s_f32( template void kernel_mul_mv_iq1_s_f32_impl( args_t args, - device const void * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, threadgroup int8_t * shared_value, uint3 tgpig, uint tiisg, @@ -5111,8 +5139,8 @@ void kernel_mul_mv_iq1_s_f32_impl( const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; - device const block_iq1_s * x = (device const block_iq1_s *) ((device char *) src0 + offset0); - device const float * y = (device const float *) ((device char *) src1 + offset1); + device const block_iq1_s * x = (device const block_iq1_s *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); float yl[32]; float sumf[N_DST]={0.f}, all_sum; @@ -5163,10 +5191,12 @@ void kernel_mul_mv_iq1_s_f32_impl( y4 += 32 * 32; } + device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*args.ne0 + im*args.ne0*args.ne1 + first_row + row] = all_sum; + dst_f32[first_row + row] = all_sum; } } } @@ -5174,9 +5204,9 @@ void kernel_mul_mv_iq1_s_f32_impl( template void kernel_mul_mv_iq1_m_f32_impl( args_t args, - device const void * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, threadgroup int8_t * shared_value, uint3 tgpig, uint tiisg, @@ -5195,8 +5225,8 @@ void kernel_mul_mv_iq1_m_f32_impl( const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; - device const block_iq1_m * x = (device const block_iq1_m *) ((device char *) src0 + offset0); - device const float * y = (device const float *) ((device char *) src1 + offset1); + device const block_iq1_m * x = (device const block_iq1_m *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); float yl[32]; float sumf[N_DST]={0.f}, all_sum; @@ -5256,10 +5286,12 @@ void kernel_mul_mv_iq1_m_f32_impl( y4 += 32 * 32; } + device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*args.ne0 + im*args.ne0*args.ne1 + first_row + row] = all_sum; + dst_f32[first_row + row] = all_sum; } } } @@ -5267,9 +5299,9 @@ void kernel_mul_mv_iq1_m_f32_impl( template void kernel_mul_mv_iq4_nl_f32_impl( args_t args, - device const void * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, threadgroup int8_t * shared_values_i8, uint3 tgpig, uint tiisg, @@ -5288,8 +5320,8 @@ void kernel_mul_mv_iq4_nl_f32_impl( const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; - device const block_iq4_nl * x = (device const block_iq4_nl *) ((device char *) src0 + offset0); - device const float * y = (device const float *) ((device char *) src1 + offset1); + device const block_iq4_nl * x = (device const block_iq4_nl *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); const int ix = tiisg/2; // 0...15 const int it = tiisg%2; // 0 or 1 @@ -5344,10 +5376,12 @@ void kernel_mul_mv_iq4_nl_f32_impl( yb += 16 * QK4_NL; } + device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + for (int row = 0; row < 2 && first_row + row < args.ne01; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*args.ne0 + im*args.ne0*args.ne1 + first_row + row] = all_sum; + dst_f32[first_row + row] = all_sum; } } } @@ -5355,9 +5389,9 @@ void kernel_mul_mv_iq4_nl_f32_impl( template void kernel_mul_mv_iq4_xs_f32_impl( args_t args, - device const void * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, threadgroup int8_t * shared_values_i8, uint3 tgpig, uint tiisg, @@ -5376,8 +5410,8 @@ void kernel_mul_mv_iq4_xs_f32_impl( const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; - device const block_iq4_xs * x = (device const block_iq4_xs *) ((device char *) src0 + offset0); - device const float * y = (device const float *) ((device char *) src1 + offset1); + device const block_iq4_xs * x = (device const block_iq4_xs *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); const int ix = tiisg/16; // 0 or 1 const int it = tiisg%16; // 0...15 @@ -5433,10 +5467,12 @@ void kernel_mul_mv_iq4_xs_f32_impl( yb += 2 * QK_K; } + device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + for (int row = 0; row < 2; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*args.ne0 + im*args.ne0*args.ne1 + first_row + row] = all_sum; + dst_f32[first_row + row] = all_sum; } } } @@ -5444,9 +5480,9 @@ void kernel_mul_mv_iq4_xs_f32_impl( [[host_name("kernel_mul_mv_iq1_s_f32")]] kernel void kernel_mul_mv_iq1_s_f32( constant ggml_metal_kargs_mul_mv & args, - device const void * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -5457,9 +5493,9 @@ kernel void kernel_mul_mv_iq1_s_f32( [[host_name("kernel_mul_mv_iq1_m_f32")]] kernel void kernel_mul_mv_iq1_m_f32( constant ggml_metal_kargs_mul_mv & args, - device const void * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -5470,9 +5506,9 @@ kernel void kernel_mul_mv_iq1_m_f32( [[host_name("kernel_mul_mv_iq4_nl_f32")]] kernel void kernel_mul_mv_iq4_nl_f32( constant ggml_metal_kargs_mul_mv & args, - device const void * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], @@ -5484,9 +5520,9 @@ kernel void kernel_mul_mv_iq4_nl_f32( [[host_name("kernel_mul_mv_iq4_xs_f32")]] kernel void kernel_mul_mv_iq4_xs_f32( constant ggml_metal_kargs_mul_mv & args, - device const void * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], @@ -6033,17 +6069,17 @@ template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel typedef void (kernel_mul_mv_impl_t)( ggml_metal_kargs_mul_mv args, - device const char * src0, - device const char * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, uint3 tgpig, uint tiisg); typedef void (kernel_mul_mv2_impl_t)( ggml_metal_kargs_mul_mv args, - device const void * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -6052,9 +6088,9 @@ typedef void (kernel_mul_mv2_impl_t)( template void mmv_fn( ggml_metal_kargs_mul_mv args, - device const char * src0, - device const char * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, threadgroup int8_t * shared_values, uint3 tgpig, uint tiitg, @@ -6066,15 +6102,15 @@ void mmv_fn( template void mmv_fn( ggml_metal_kargs_mul_mv args, - device const char * src0, - device const char * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, threadgroup int8_t * shared_values, uint3 tgpig, uint tiitg, uint tiisg, uint sgitg) { - impl_fn(args, src0,(const device float *) src1, dst, shared_values, tgpig, tiisg, sgitg); + impl_fn(args, src0, src1, dst, shared_values, tgpig, tiisg, sgitg); } typedef decltype(mmv_fn>) mul_mv_impl_fn_t; @@ -6082,10 +6118,10 @@ typedef decltype(mmv_fn kernel void kernel_mul_mv_id( constant ggml_metal_kargs_mul_mv_id & args, - device const char * src0s, - device const char * src1, - device float * dst, - device const char * ids, + device const char * src0s, + device const char * src1, + device char * dst, + device const char * ids, threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], @@ -6106,7 +6142,8 @@ kernel void kernel_mul_mv_id( device const char * src0_cur = src0s + i02*args.nb02; device const char * src1_cur = src1 + i11*args.nb11 + i12*args.nb12; - device float * dst_cur = dst + i1*args.ne0 + i2*args.ne1*args.ne0; + + device char * dst_cur = dst + (i1*args.ne0 + i2*args.ne1*args.ne0)*sizeof(float); ggml_metal_kargs_mul_mv args0 = { /*.ne00 =*/ args.ne00, From f759814c667028f444e90dd4f8da1b36a3ed2950 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 10 Nov 2024 09:45:06 +0200 Subject: [PATCH 09/20] cont : shmem style --- ggml/src/ggml-metal/ggml-metal.metal | 215 +++++++++++++-------------- 1 file changed, 107 insertions(+), 108 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index d586c7b614ed0..a19df23a113cd 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1635,7 +1635,7 @@ void mul_vec_q_n_f32_impl( device const char * src0, device const char * src1, device char * dst, - threadgroup int8_t * shared_values, + threadgroup char * shmem, uint3 tgpig, uint tiisg, uint sgitg) { @@ -1758,7 +1758,7 @@ void kernel_mul_mv_q8_0_f32_impl( device const char * src0, device const char * src1, device char * dst, - threadgroup int8_t * shared_values, + threadgroup char * shmem, uint3 tgpig, uint tiisg, uint sgitg) { @@ -2576,7 +2576,7 @@ kernel void kernel_flash_attn_ext( device const char * v, device const char * mask, device char * dst, - threadgroup half * shared [[threadgroup(0)]], + threadgroup half * shmem_f16 [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], ushort3 ntg[[threads_per_threadgroup]], ushort tiisg[[thread_index_in_simdgroup]], @@ -2596,17 +2596,17 @@ kernel void kernel_flash_attn_ext( const short TS = nsg*SH; // shared memory size per query in (s_t == float) const short T = D + 2*TS; // shared memory size per query in (half) - threadgroup q_t * sq = (threadgroup q_t *) (shared + 0*D); // holds the query data - threadgroup q4_t * sq4 = (threadgroup q4_t *) (shared + 0*D); // same as above but in q4_t - threadgroup o_t * so = (threadgroup o_t *) (shared + 0*D); // reuse query data for accumulation - threadgroup o4_t * so4 = (threadgroup o4_t *) (shared + 0*D); // same as above but in o4_t - threadgroup s_t * ss = (threadgroup s_t *) (shared + 2*sgitg*SH + Q*D); // scratch buffer for attention, mask and diagonal matrix + threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*D); // holds the query data + threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*D); // same as above but in q4_t + threadgroup o_t * so = (threadgroup o_t *) (shmem_f16 + 0*D); // reuse query data for accumulation + threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 0*D); // same as above but in o4_t + threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + 2*sgitg*SH + Q*D); // scratch buffer for attention, mask and diagonal matrix - threadgroup k_t * sk = (threadgroup k_t *) (shared + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory - threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shared + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t + threadgroup k_t * sk = (threadgroup k_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory + threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t - threadgroup v_t * sv = (threadgroup v_t *) (shared + sgitg*(4*16*KV) + Q*T); // scratch buffer to load V in shared memory - threadgroup v4x4_t * sv4x4 = (threadgroup v4x4_t *) (shared + sgitg*(4*16*KV) + Q*T); // same as above but in v4x4_t + threadgroup v_t * sv = (threadgroup v_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load V in shared memory + threadgroup v4x4_t * sv4x4 = (threadgroup v4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in v4x4_t // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) o8x8_t lo[D8]; @@ -3071,7 +3071,7 @@ kernel void kernel_flash_attn_ext_vec( device const char * v, device const char * mask, device char * dst, - threadgroup half * shared [[threadgroup(0)]], + threadgroup half * shmem_f16 [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], ushort3 ntg[[threads_per_threadgroup]], ushort tiisg[[thread_index_in_simdgroup]], @@ -3090,13 +3090,13 @@ kernel void kernel_flash_attn_ext_vec( const short T = D + nsg*SH; // shared memory size per query in (half) - //threadgroup q_t * sq = (threadgroup q_t *) (shared + 0*D); // holds the query data - threadgroup q4_t * sq4 = (threadgroup q4_t *) (shared + 0*D); // same as above but in q4_t - threadgroup q4x4_t * sq4x4 = (threadgroup q4x4_t *) (shared + 0*D); // same as above but in q4x4_t - threadgroup s_t * ss = (threadgroup s_t *) (shared + sgitg*SH + Q*D); // scratch buffer for attention - threadgroup s4_t * ss4 = (threadgroup s4_t *) (shared + sgitg*SH + Q*D); // same as above but in s4_t - threadgroup half * sm = (threadgroup half *) (shared + sgitg*SH + C + Q*D); // scratch buffer for mask - threadgroup o4x4_t * sr4x4 = (threadgroup o4x4_t *) (shared + sgitg*D + Q*T); // scratch buffer for the results + //threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*D); // holds the query data + threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*D); // same as above but in q4_t + threadgroup q4x4_t * sq4x4 = (threadgroup q4x4_t *) (shmem_f16 + 0*D); // same as above but in q4x4_t + threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*D); // scratch buffer for attention + threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*D); // same as above but in s4_t + threadgroup half * sm = (threadgroup half *) (shmem_f16 + sgitg*SH + C + Q*D); // scratch buffer for mask + threadgroup o4x4_t * sr4x4 = (threadgroup o4x4_t *) (shmem_f16 + sgitg*D + Q*T); // scratch buffer for the results // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) o4x4_t lo[D16/NL]; @@ -3945,7 +3945,7 @@ void kernel_mul_mv_q2_K_f32_impl( device const char * src0, device const char * src1, device char * dst, - threadgroup int8_t * shared_values, + threadgroup char * shmem, uint3 tgpig, uint tiisg, uint sgitg) { @@ -4050,7 +4050,7 @@ void kernel_mul_mv_q3_K_f32_impl( device const char * src0, device const char * src1, device char * dst, - threadgroup int8_t * shared_values, + threadgroup char * shmem, uint3 tgpig, uint tiisg, uint sgitg) { @@ -4213,7 +4213,7 @@ void kernel_mul_mv_q4_K_f32_impl( device const char * src0, device const char * src1, device char * dst, - threadgroup int8_t * shared_values, + threadgroup char * shmem, uint3 tgpig, uint tiisg, uint sgitg) { @@ -4331,7 +4331,7 @@ void kernel_mul_mv_q5_K_f32_impl( device const char * src0, device const char * src1, device char * dst, - threadgroup int8_t * shared_values, + threadgroup char * shmem, uint3 tgpig, uint tiisg, uint sgitg) { @@ -4463,7 +4463,7 @@ void kernel_mul_mv_q6_K_f32_impl( device const char * src0, device const char * src1, device char * dst, - threadgroup int8_t * shared_values, + threadgroup char * shmem, uint3 tgpig, uint tiisg, uint sgitg) { @@ -4556,7 +4556,7 @@ void kernel_mul_mv_iq2_xxs_f32_impl( device const char * src0, device const char * src1, device char * dst, - threadgroup int8_t * shared_values, + threadgroup char * shmem, uint3 tgpig, uint tiisg, uint sgitg) { @@ -4582,15 +4582,15 @@ void kernel_mul_mv_iq2_xxs_f32_impl( const int nb32 = nb * (QK_K / 32); - threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values; - threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 256); + threadgroup uint64_t * svalues = (threadgroup uint64_t *)(shmem); + threadgroup uint8_t * ssigns = (threadgroup uint8_t *)(svalues + 256); { int nval = 4; int pos = (32*sgitg + tiisg)*nval; - for (int i = 0; i < nval; ++i) values[pos + i] = iq2xxs_grid[pos + i]; + for (int i = 0; i < nval; ++i) svalues[pos + i] = iq2xxs_grid[pos + i]; nval = 2; pos = (32*sgitg + tiisg)*nval; - for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i]; + for (int i = 0; i < nval; ++i) ssigns[pos+i] = ksigns_iq2xs[pos+i]; threadgroup_barrier(mem_flags::mem_threadgroup); } @@ -4620,8 +4620,8 @@ void kernel_mul_mv_iq2_xxs_f32_impl( float sum = 0; for (int l = 0; l < 4; ++l) { - const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + aux8[l]); - const uint8_t signs = shared_signs[(aux32 >> 7*l) & 127]; + const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + aux8[l]); + const uint8_t signs = ssigns[(aux32 >> 7*l) & 127]; for (int j = 0; j < 8; ++j) { sum += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); } @@ -4651,12 +4651,11 @@ kernel void kernel_mul_mv_iq2_xxs_f32( device const char * src0, device const char * src1, device char * dst, - threadgroup int8_t * shared_values [[threadgroup(0)]], + threadgroup char * shmem [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - - kernel_mul_mv_iq2_xxs_f32_impl(args, src0, src1, dst, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq2_xxs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } template @@ -4665,7 +4664,7 @@ void kernel_mul_mv_iq2_xs_f32_impl( device const char * src0, device const char * src1, device char * dst, - threadgroup int8_t * shared_values, + threadgroup char * shmem, uint3 tgpig, uint tiisg, uint sgitg) { @@ -4691,15 +4690,15 @@ void kernel_mul_mv_iq2_xs_f32_impl( const int nb32 = nb * (QK_K / 32); - threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values; - threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 512); + threadgroup uint64_t * svalues = (threadgroup uint64_t *)(shmem); + threadgroup uint8_t * ssigns = (threadgroup uint8_t *)(svalues + 512); { int nval = 8; int pos = (32*sgitg + tiisg)*nval; - for (int i = 0; i < nval; ++i) values[pos + i] = iq2xs_grid[pos + i]; + for (int i = 0; i < nval; ++i) svalues[pos + i] = iq2xs_grid[pos + i]; nval = 2; pos = (32*sgitg + tiisg)*nval; - for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i]; + for (int i = 0; i < nval; ++i) ssigns[pos+i] = ksigns_iq2xs[pos+i]; threadgroup_barrier(mem_flags::mem_threadgroup); } @@ -4731,15 +4730,15 @@ void kernel_mul_mv_iq2_xs_f32_impl( float sum1 = 0, sum2 = 0; for (int l = 0; l < 2; ++l) { - const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + (q2[l] & 511)); - const uint8_t signs = shared_signs[(q2[l] >> 9)]; + const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + (q2[l] & 511)); + const uint8_t signs = ssigns[(q2[l] >> 9)]; for (int j = 0; j < 8; ++j) { sum1 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); } } for (int l = 2; l < 4; ++l) { - const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + (q2[l] & 511)); - const uint8_t signs = shared_signs[(q2[l] >> 9)]; + const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + (q2[l] & 511)); + const uint8_t signs = ssigns[(q2[l] >> 9)]; for (int j = 0; j < 8; ++j) { sum2 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); } @@ -4770,12 +4769,12 @@ kernel void kernel_mul_mv_iq2_xs_f32( device const char * src0, device const char * src1, device char * dst, - threadgroup int8_t * shared_values [[threadgroup(0)]], + threadgroup char * shmem [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq2_xs_f32_impl(args, src0, src1, dst, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq2_xs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } template @@ -4784,7 +4783,7 @@ void kernel_mul_mv_iq3_xxs_f32_impl( device const char * src0, device const char * src1, device char * dst, - threadgroup int8_t * shared_values, + threadgroup char * shmem, uint3 tgpig, uint tiisg, uint sgitg) { @@ -4810,15 +4809,15 @@ void kernel_mul_mv_iq3_xxs_f32_impl( const int nb32 = nb * (QK_K / 32); - threadgroup uint32_t * values = (threadgroup uint32_t *)shared_values; - threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 256); + threadgroup uint32_t * svalues = (threadgroup uint32_t *)(shmem); + threadgroup uint8_t * ssigns = (threadgroup uint8_t *)(svalues + 256); { int nval = 4; int pos = (32*sgitg + tiisg)*nval; - for (int i = 0; i < nval; ++i) values[pos + i] = iq3xxs_grid[pos + i]; + for (int i = 0; i < nval; ++i) svalues[pos + i] = iq3xxs_grid[pos + i]; nval = 2; pos = (32*sgitg + tiisg)*nval; - for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i]; + for (int i = 0; i < nval; ++i) ssigns[pos+i] = ksigns_iq2xs[pos+i]; threadgroup_barrier(mem_flags::mem_threadgroup); } @@ -4848,9 +4847,9 @@ void kernel_mul_mv_iq3_xxs_f32_impl( float2 sum = {0}; for (int l = 0; l < 4; ++l) { - const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(values + q3[2*l+0]); - const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(values + q3[2*l+1]); - const uint8_t signs = shared_signs[(aux32 >> 7*l) & 127]; + const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(svalues + q3[2*l+0]); + const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(svalues + q3[2*l+1]); + const uint8_t signs = ssigns[(aux32 >> 7*l) & 127]; for (int j = 0; j < 4; ++j) { sum[0] += yl[8*l + j + 0] * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f); sum[1] += yl[8*l + j + 4] * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f); @@ -4882,12 +4881,12 @@ kernel void kernel_mul_mv_iq3_xxs_f32( device const char * src0, device const char * src1, device char * dst, - threadgroup int8_t * shared_values [[threadgroup(0)]], + threadgroup char * shmem [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq3_xxs_f32_impl(args, src0, src1, dst, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq3_xxs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } template @@ -4896,7 +4895,7 @@ void kernel_mul_mv_iq3_s_f32_impl( device const char * src0, device const char * src1, device char * dst, - threadgroup int8_t * shared_values, + threadgroup char * shmem, uint3 tgpig, uint tiisg, uint sgitg) { @@ -4922,11 +4921,11 @@ void kernel_mul_mv_iq3_s_f32_impl( const int nb32 = nb * (QK_K / 32); - threadgroup uint32_t * values = (threadgroup uint32_t *)shared_values; + threadgroup uint32_t * svalues = (threadgroup uint32_t *) shmem; { int nval = 8; int pos = (32*sgitg + tiisg)*nval; - for (int i = 0; i < nval; ++i) values[pos + i] = iq3s_grid[pos + i]; + for (int i = 0; i < nval; ++i) svalues[pos + i] = iq3s_grid[pos + i]; threadgroup_barrier(mem_flags::mem_threadgroup); } @@ -4957,8 +4956,8 @@ void kernel_mul_mv_iq3_s_f32_impl( float2 sum = {0}; for (int l = 0; l < 4; ++l) { - const threadgroup uint32_t * table1 = qh[0] & kmask_iq2xs[2*l+0] ? values + 256 : values; - const threadgroup uint32_t * table2 = qh[0] & kmask_iq2xs[2*l+1] ? values + 256 : values; + const threadgroup uint32_t * table1 = qh[0] & kmask_iq2xs[2*l+0] ? svalues + 256 : svalues; + const threadgroup uint32_t * table2 = qh[0] & kmask_iq2xs[2*l+1] ? svalues + 256 : svalues; const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(table1 + qs[2*l+0]); const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(table2 + qs[2*l+1]); for (int j = 0; j < 4; ++j) { @@ -4994,12 +4993,12 @@ kernel void kernel_mul_mv_iq3_s_f32( device const char * src0, device const char * src1, device char * dst, - threadgroup int8_t * shared_values [[threadgroup(0)]], + threadgroup char * shmem [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq3_s_f32_impl(args, src0, src1, dst, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq3_s_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } template @@ -5008,7 +5007,7 @@ void kernel_mul_mv_iq2_s_f32_impl( device const char * src0, device const char * src1, device char * dst, - threadgroup int8_t * shared_values, + threadgroup char * shmem, uint3 tgpig, uint tiisg, uint sgitg) { @@ -5034,11 +5033,11 @@ void kernel_mul_mv_iq2_s_f32_impl( const int nb32 = nb * (QK_K / 32); - //threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values; + //threadgroup uint64_t * svalues = (threadgroup uint64_t *) shmem; //{ // int nval = 32; // int pos = (32*sgitg + tiisg)*nval; - // for (int i = 0; i < nval; ++i) values[pos + i] = iq2s_grid[pos + i]; + // for (int i = 0; i < nval; ++i) svalues[pos + i] = iq2s_grid[pos + i]; // threadgroup_barrier(mem_flags::mem_threadgroup); //} @@ -5070,8 +5069,8 @@ void kernel_mul_mv_iq2_s_f32_impl( float2 sum = {0}; for (int l = 0; l < 2; ++l) { - //const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(values + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300))); - //const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(values + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300))); + //const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(svalues + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300))); + //const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(svalues + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300))); constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300))); constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300))); for (int j = 0; j < 8; ++j) { @@ -5107,12 +5106,12 @@ kernel void kernel_mul_mv_iq2_s_f32( device const char * src0, device const char * src1, device char * dst, - threadgroup int8_t * shared_values [[threadgroup(0)]], + threadgroup char * shmem [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq2_s_f32_impl(args, src0, src1, dst, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq2_s_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } template @@ -5121,7 +5120,7 @@ void kernel_mul_mv_iq1_s_f32_impl( device const char * src0, device const char * src1, device char * dst, - threadgroup int8_t * shared_value, + threadgroup char * shmem, uint3 tgpig, uint tiisg, uint sgitg) { @@ -5207,7 +5206,7 @@ void kernel_mul_mv_iq1_m_f32_impl( device const char * src0, device const char * src1, device char * dst, - threadgroup int8_t * shared_value, + threadgroup char * shmem, uint3 tgpig, uint tiisg, uint sgitg) { @@ -5302,12 +5301,12 @@ void kernel_mul_mv_iq4_nl_f32_impl( device const char * src0, device const char * src1, device char * dst, - threadgroup int8_t * shared_values_i8, + threadgroup char * shmem, uint3 tgpig, uint tiisg, uint sgitg) { - threadgroup float * shared_values = (threadgroup float *)shared_values_i8; + threadgroup float * shmem_f32 = (threadgroup float *) shmem; const int nb = args.ne00/QK4_NL; const int r0 = tgpig.x; const int r1 = tgpig.y; @@ -5326,7 +5325,7 @@ void kernel_mul_mv_iq4_nl_f32_impl( const int ix = tiisg/2; // 0...15 const int it = tiisg%2; // 0 or 1 - shared_values[tiisg] = kvalues_iq4nl_f[tiisg%16]; + shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16]; threadgroup_barrier(mem_flags::mem_threadgroup); float4 yl[4]; @@ -5354,16 +5353,16 @@ void kernel_mul_mv_iq4_nl_f32_impl( aux32[0] = q4[0] | (q4[1] << 16); aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f; aux32[0] &= 0x0f0f0f0f; - qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]}; - qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]}; + qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]}; + qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]}; acc1 += yl[0] * qf1; acc2 += yl[1] * qf2; aux32[0] = q4[2] | (q4[3] << 16); aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f; aux32[0] &= 0x0f0f0f0f; - qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]}; - qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]}; + qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]}; + qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]}; acc1 += yl[2] * qf1; acc2 += yl[3] * qf2; @@ -5392,12 +5391,12 @@ void kernel_mul_mv_iq4_xs_f32_impl( device const char * src0, device const char * src1, device char * dst, - threadgroup int8_t * shared_values_i8, + threadgroup char * shmem, uint3 tgpig, uint tiisg, uint sgitg) { - threadgroup float * shared_values = (threadgroup float *)shared_values_i8; + threadgroup float * shmem_f32 = (threadgroup float *) shmem; const int nb = args.ne00/QK_K; const int r0 = tgpig.x; const int r1 = tgpig.y; @@ -5418,7 +5417,7 @@ void kernel_mul_mv_iq4_xs_f32_impl( const int ib = it/2; const int il = it%2; - shared_values[tiisg] = kvalues_iq4nl_f[tiisg%16]; + shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16]; threadgroup_barrier(mem_flags::mem_threadgroup); float4 yl[4]; @@ -5445,15 +5444,15 @@ void kernel_mul_mv_iq4_xs_f32_impl( aux32[0] = q4[0] & 0x0f0f0f0f; aux32[1] = (q4[0] >> 4) & 0x0f0f0f0f; - qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]}; - qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]}; + qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]}; + qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]}; acc1 += yl[0] * qf1; acc2 += yl[1] * qf2; aux32[0] = q4[1] & 0x0f0f0f0f; aux32[1] = (q4[1] >> 4) & 0x0f0f0f0f; - qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]}; - qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]}; + qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]}; + qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]}; acc1 += yl[2] * qf1; acc2 += yl[3] * qf2; @@ -5509,12 +5508,12 @@ kernel void kernel_mul_mv_iq4_nl_f32( device const char * src0, device const char * src1, device char * dst, - threadgroup int8_t * shared_values [[threadgroup(0)]], + threadgroup char * shmem [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq4_nl_f32_impl(args, src0, src1, dst, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq4_nl_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } [[host_name("kernel_mul_mv_iq4_xs_f32")]] @@ -5523,12 +5522,12 @@ kernel void kernel_mul_mv_iq4_xs_f32( device const char * src0, device const char * src1, device char * dst, - threadgroup int8_t * shared_values [[threadgroup(0)]], + threadgroup char * shmem [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq4_xs_f32_impl(args, src0, src1, dst, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq4_xs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } template @@ -6080,7 +6079,7 @@ typedef void (kernel_mul_mv2_impl_t)( device const char * src0, device const char * src1, device char * dst, - threadgroup int8_t * shared_values, + threadgroup char * shmem, uint3 tgpig, uint tiisg, uint sgitg); @@ -6091,11 +6090,11 @@ void mmv_fn( device const char * src0, device const char * src1, device char * dst, - threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiitg, - uint tiisg, - uint sgitg) { + threadgroup char * shmem, + uint3 tgpig, + uint tiitg, + uint tiisg, + uint sgitg) { impl_fn(args, src0, src1, dst, tgpig, tiisg); } @@ -6105,12 +6104,12 @@ void mmv_fn( device const char * src0, device const char * src1, device char * dst, - threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiitg, - uint tiisg, - uint sgitg) { - impl_fn(args, src0, src1, dst, shared_values, tgpig, tiisg, sgitg); + threadgroup char * shmem, + uint3 tgpig, + uint tiitg, + uint tiisg, + uint sgitg) { + impl_fn(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } typedef decltype(mmv_fn>) mul_mv_impl_fn_t; @@ -6122,11 +6121,11 @@ kernel void kernel_mul_mv_id( device const char * src1, device char * dst, device const char * ids, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + threadgroup char * shmem [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { const int iid1 = tgpig.z/args.nei0; const int idx = tgpig.z%args.nei0; @@ -6171,7 +6170,7 @@ kernel void kernel_mul_mv_id( /* src0 */ src0_cur, /* src1 */ src1_cur, /* dst */ dst_cur, - shared_values, + shmem, tgpig, tiitg, tiisg, From cd89d1a877b468d9b5f063a56ee6a6550ab43ff7 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 10 Nov 2024 09:57:41 +0200 Subject: [PATCH 10/20] cont : thread counters style --- ggml/src/ggml-metal/ggml-metal.metal | 281 ++++++++++++++------------- 1 file changed, 141 insertions(+), 140 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index a19df23a113cd..0bdb04e814c2c 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1636,9 +1636,9 @@ void mul_vec_q_n_f32_impl( device const char * src1, device char * dst, threadgroup char * shmem, - uint3 tgpig, - uint tiisg, - uint sgitg) { + uint3 tgpig, + ushort tiisg, + ushort sgitg) { const int nb = args.ne00/QK4_0; const int r0 = tgpig.x; @@ -1711,9 +1711,9 @@ kernel void kernel_mul_mv_q4_0_f32( device const char * src0, device const char * src1, device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } @@ -1722,9 +1722,9 @@ kernel void kernel_mul_mv_q4_1_f32( device const char * src0, device const char * src1, device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } @@ -1733,9 +1733,9 @@ kernel void kernel_mul_mv_q5_0_f32( device const char * src0, device const char * src1, device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } @@ -1744,9 +1744,9 @@ kernel void kernel_mul_mv_q5_1_f32( device const char * src0, device const char * src1, device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } @@ -1759,9 +1759,9 @@ void kernel_mul_mv_q8_0_f32_impl( device const char * src1, device char * dst, threadgroup char * shmem, - uint3 tgpig, - uint tiisg, - uint sgitg) { + uint3 tgpig, + ushort tiisg, + ushort sgitg) { const int nr = N_DST; const int nsg = N_SIMDGROUP; const int nw = N_SIMDWIDTH; @@ -1771,7 +1771,7 @@ void kernel_mul_mv_q8_0_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * nsg + sgitg) * nr; + const int first_row = (r0*nsg + sgitg)*nr; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -1791,12 +1791,12 @@ void kernel_mul_mv_q8_0_f32_impl( } float yl[NB_Q8_0]; - float sumf[nr]={0.f}; + float sumf[nr] = { 0.f }; const int ix = tiisg/4; const int il = tiisg%4; - device const float * yb = y + ix * QK8_0 + NB_Q8_0*il; + device const float * yb = y + ix*QK8_0 + il*NB_Q8_0; // each thread in a SIMD group deals with NB_Q8_0 quants at a time for (int ib = ix; ib < nb; ib += nw/4) { @@ -1805,7 +1805,7 @@ void kernel_mul_mv_q8_0_f32_impl( } for (int row = 0; row < nr; row++) { - device const int8_t * qs = ax[row][ib].qs + NB_Q8_0*il; + device const int8_t * qs = ax[row][ib].qs + il*NB_Q8_0; float sumq = 0.f; for (int iq = 0; iq < NB_Q8_0; ++iq) { sumq += qs[iq] * yl[iq]; @@ -1813,13 +1813,14 @@ void kernel_mul_mv_q8_0_f32_impl( sumf[row] += sumq*ax[row][ib].d; } - yb += NB_Q8_0 * nw; + yb += nw*NB_Q8_0; } - device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + device float * dst_f32 = (device float *) dst + (int64_t)im*args.ne0*args.ne1 + (int64_t)r1*args.ne0; for (int row = 0; row < nr; ++row) { const float tot = simd_sum(sumf[row]); + if (tiisg == 0 && first_row + row < args.ne01) { dst_f32[first_row + row] = tot; } @@ -1832,9 +1833,9 @@ kernel void kernel_mul_mv_q8_0_f32( device const char * src0, device const char * src1, device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { kernel_mul_mv_q8_0_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } @@ -1846,8 +1847,8 @@ void kernel_mul_mv_impl( device const char * src0, device const char * src1, device char * dst, - uint3 tgpig, - uint tiisg) { + uint3 tgpig, + ushort tiisg) { const int64_t r0 = tgpig.x; const int64_t rb = tgpig.y*N_MV_T_T; const int64_t im = tgpig.z; @@ -1915,8 +1916,8 @@ kernel void kernel_mul_mv( device const char * src0, device const char * src1, device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]]) { kernel_mul_mv_impl( args, src0, @@ -1942,8 +1943,8 @@ kernel void kernel_mul_mv_1row( device const char * src0, device const char * src1, device float * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]]) { const int64_t r0 = tgpig.x; const int64_t r1 = tgpig.y; @@ -1998,8 +1999,8 @@ kernel void kernel_mul_mv_l4( device const char * src0, device const char * src1, device float * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]]) { const int nrows = args.ne11; const int64_t r0 = tgpig.x; @@ -3946,9 +3947,9 @@ void kernel_mul_mv_q2_K_f32_impl( device const char * src1, device char * dst, threadgroup char * shmem, - uint3 tgpig, - uint tiisg, - uint sgitg) { + uint3 tgpig, + ushort tiisg, + ushort sgitg) { const int nb = args.ne00/QK_K; const int r0 = tgpig.x; @@ -4037,9 +4038,9 @@ kernel void kernel_mul_mv_q2_K_f32( device const char * src0, device const char * src1, device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { kernel_mul_mv_q2_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } @@ -4051,9 +4052,9 @@ void kernel_mul_mv_q3_K_f32_impl( device const char * src1, device char * dst, threadgroup char * shmem, - uint3 tgpig, - uint tiisg, - uint sgitg) { + uint3 tgpig, + ushort tiisg, + ushort sgitg) { const int nb = args.ne00/QK_K; @@ -4200,9 +4201,9 @@ kernel void kernel_mul_mv_q3_K_f32( device const char * src0, device const char * src1, device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { kernel_mul_mv_q3_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } @@ -4214,9 +4215,9 @@ void kernel_mul_mv_q4_K_f32_impl( device const char * src1, device char * dst, threadgroup char * shmem, - uint3 tgpig, - uint tiisg, - uint sgitg) { + uint3 tgpig, + ushort tiisg, + ushort sgitg) { const uint16_t kmask1 = 0x3f3f; const uint16_t kmask2 = 0x0f0f; @@ -4318,9 +4319,9 @@ kernel void kernel_mul_mv_q4_K_f32( device const char * src0, device const char * src1, device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { kernel_mul_mv_q4_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } @@ -4332,9 +4333,9 @@ void kernel_mul_mv_q5_K_f32_impl( device const char * src1, device char * dst, threadgroup char * shmem, - uint3 tgpig, - uint tiisg, - uint sgitg) { + uint3 tgpig, + ushort tiisg, + ushort sgitg) { const int nb = args.ne00/QK_K; @@ -4450,9 +4451,9 @@ kernel void kernel_mul_mv_q5_K_f32( device const char * src0, device const char * src1, device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { kernel_mul_mv_q5_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } @@ -4464,9 +4465,9 @@ void kernel_mul_mv_q6_K_f32_impl( device const char * src1, device char * dst, threadgroup char * shmem, - uint3 tgpig, - uint tiisg, - uint sgitg) { + uint3 tgpig, + ushort tiisg, + ushort sgitg) { const uint8_t kmask1 = 0x03; const uint8_t kmask2 = 0x0C; @@ -4541,9 +4542,9 @@ kernel void kernel_mul_mv_q6_K_f32( device const char * src0, device const char * src1, device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { kernel_mul_mv_q6_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } @@ -4557,9 +4558,9 @@ void kernel_mul_mv_iq2_xxs_f32_impl( device const char * src1, device char * dst, threadgroup char * shmem, - uint3 tgpig, - uint tiisg, - uint sgitg) { + uint3 tgpig, + ushort tiisg, + ushort sgitg) { const int nb = args.ne00/QK_K; const int r0 = tgpig.x; @@ -4652,9 +4653,9 @@ kernel void kernel_mul_mv_iq2_xxs_f32( device const char * src1, device char * dst, threadgroup char * shmem [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { kernel_mul_mv_iq2_xxs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } @@ -4665,9 +4666,9 @@ void kernel_mul_mv_iq2_xs_f32_impl( device const char * src1, device char * dst, threadgroup char * shmem, - uint3 tgpig, - uint tiisg, - uint sgitg) { + uint3 tgpig, + ushort tiisg, + ushort sgitg) { const int nb = args.ne00/QK_K; const int r0 = tgpig.x; @@ -4770,9 +4771,9 @@ kernel void kernel_mul_mv_iq2_xs_f32( device const char * src1, device char * dst, threadgroup char * shmem [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { kernel_mul_mv_iq2_xs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } @@ -4784,9 +4785,9 @@ void kernel_mul_mv_iq3_xxs_f32_impl( device const char * src1, device char * dst, threadgroup char * shmem, - uint3 tgpig, - uint tiisg, - uint sgitg) { + uint3 tgpig, + ushort tiisg, + ushort sgitg) { const int nb = args.ne00/QK_K; const int r0 = tgpig.x; @@ -4882,9 +4883,9 @@ kernel void kernel_mul_mv_iq3_xxs_f32( device const char * src1, device char * dst, threadgroup char * shmem [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { kernel_mul_mv_iq3_xxs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } @@ -4896,9 +4897,9 @@ void kernel_mul_mv_iq3_s_f32_impl( device const char * src1, device char * dst, threadgroup char * shmem, - uint3 tgpig, - uint tiisg, - uint sgitg) { + uint3 tgpig, + ushort tiisg, + ushort sgitg) { const int nb = args.ne00/QK_K; const int r0 = tgpig.x; @@ -4994,9 +4995,9 @@ kernel void kernel_mul_mv_iq3_s_f32( device const char * src1, device char * dst, threadgroup char * shmem [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { kernel_mul_mv_iq3_s_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } @@ -5008,9 +5009,9 @@ void kernel_mul_mv_iq2_s_f32_impl( device const char * src1, device char * dst, threadgroup char * shmem, - uint3 tgpig, - uint tiisg, - uint sgitg) { + uint3 tgpig, + ushort tiisg, + ushort sgitg) { const int nb = args.ne00/QK_K; const int r0 = tgpig.x; @@ -5107,9 +5108,9 @@ kernel void kernel_mul_mv_iq2_s_f32( device const char * src1, device char * dst, threadgroup char * shmem [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { kernel_mul_mv_iq2_s_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } @@ -5121,9 +5122,9 @@ void kernel_mul_mv_iq1_s_f32_impl( device const char * src1, device char * dst, threadgroup char * shmem, - uint3 tgpig, - uint tiisg, - uint sgitg) { + uint3 tgpig, + ushort tiisg, + ushort sgitg) { const int nb = args.ne00/QK_K; const int r0 = tgpig.x; @@ -5207,9 +5208,9 @@ void kernel_mul_mv_iq1_m_f32_impl( device const char * src1, device char * dst, threadgroup char * shmem, - uint3 tgpig, - uint tiisg, - uint sgitg) { + uint3 tgpig, + ushort tiisg, + ushort sgitg) { const int nb = args.ne00/QK_K; const int r0 = tgpig.x; @@ -5302,9 +5303,9 @@ void kernel_mul_mv_iq4_nl_f32_impl( device const char * src1, device char * dst, threadgroup char * shmem, - uint3 tgpig, - uint tiisg, - uint sgitg) { + uint3 tgpig, + ushort tiisg, + ushort sgitg) { threadgroup float * shmem_f32 = (threadgroup float *) shmem; const int nb = args.ne00/QK4_NL; @@ -5392,9 +5393,9 @@ void kernel_mul_mv_iq4_xs_f32_impl( device const char * src1, device char * dst, threadgroup char * shmem, - uint3 tgpig, - uint tiisg, - uint sgitg) { + uint3 tgpig, + ushort tiisg, + ushort sgitg) { threadgroup float * shmem_f32 = (threadgroup float *) shmem; const int nb = args.ne00/QK_K; @@ -5482,9 +5483,9 @@ kernel void kernel_mul_mv_iq1_s_f32( device const char * src0, device const char * src1, device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { kernel_mul_mv_iq1_s_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } @@ -5495,9 +5496,9 @@ kernel void kernel_mul_mv_iq1_m_f32( device const char * src0, device const char * src1, device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { kernel_mul_mv_iq1_m_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } @@ -5509,9 +5510,9 @@ kernel void kernel_mul_mv_iq4_nl_f32( device const char * src1, device char * dst, threadgroup char * shmem [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { kernel_mul_mv_iq4_nl_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } @@ -5523,9 +5524,9 @@ kernel void kernel_mul_mv_iq4_xs_f32( device const char * src1, device char * dst, threadgroup char * shmem [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { kernel_mul_mv_iq4_xs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } @@ -5636,13 +5637,13 @@ kernel void kernel_mul_mm( device const char * src0, device const char * src1, device char * dst, - threadgroup char * shared_memory [[threadgroup(0)]], + threadgroup char * shmem [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], ushort tiitg[[thread_index_in_threadgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - threadgroup T * sa = (threadgroup T *)(shared_memory); - threadgroup float * sb = (threadgroup float *)(shared_memory + 4096); + threadgroup T * sa = (threadgroup T *)(shmem); + threadgroup float * sb = (threadgroup float *)(shmem + 4096); const int r0 = tgpig.y; const int r1 = tgpig.x; @@ -5737,7 +5738,7 @@ kernel void kernel_mul_mm( } else { // block is smaller than 64x32, we should avoid writing data outside of the matrix threadgroup_barrier(mem_flags::mem_threadgroup); - threadgroup float * temp_str = ((threadgroup float *) shared_memory) \ + threadgroup float * temp_str = ((threadgroup float *) shmem) \ + 32 * (sgitg&1) + (16 * (sgitg>>1))*BLOCK_SIZE_M; for (short i = 0; i < 8; i++) { simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*BLOCK_SIZE_M*(i/4), BLOCK_SIZE_M); @@ -6071,8 +6072,8 @@ typedef void (kernel_mul_mv_impl_t)( device const char * src0, device const char * src1, device char * dst, - uint3 tgpig, - uint tiisg); + uint3 tgpig, + ushort tiisg); typedef void (kernel_mul_mv2_impl_t)( ggml_metal_kargs_mul_mv args, @@ -6080,9 +6081,9 @@ typedef void (kernel_mul_mv2_impl_t)( device const char * src1, device char * dst, threadgroup char * shmem, - uint3 tgpig, - uint tiisg, - uint sgitg); + uint3 tgpig, + ushort tiisg, + ushort sgitg); template void mmv_fn( @@ -6091,10 +6092,10 @@ void mmv_fn( device const char * src1, device char * dst, threadgroup char * shmem, - uint3 tgpig, - uint tiitg, - uint tiisg, - uint sgitg) { + uint3 tgpig, + ushort tiitg, + ushort tiisg, + ushort sgitg) { impl_fn(args, src0, src1, dst, tgpig, tiisg); } @@ -6105,10 +6106,10 @@ void mmv_fn( device const char * src1, device char * dst, threadgroup char * shmem, - uint3 tgpig, - uint tiitg, - uint tiisg, - uint sgitg) { + uint3 tgpig, + ushort tiitg, + ushort tiisg, + ushort sgitg) { impl_fn(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } @@ -6122,10 +6123,10 @@ kernel void kernel_mul_mv_id( device char * dst, device const char * ids, threadgroup char * shmem [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiitg[[thread_index_in_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { const int iid1 = tgpig.z/args.nei0; const int idx = tgpig.z%args.nei0; From ec18f96891ccdaabfe42c21d440591632679be8b Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 10 Nov 2024 10:32:15 +0200 Subject: [PATCH 11/20] cont : mul mm id ggml-ci --- ggml/src/ggml-common.h | 18 ++++ ggml/src/ggml-metal/ggml-metal.m | 43 ++++---- ggml/src/ggml-metal/ggml-metal.metal | 153 ++++++++++++++------------- 3 files changed, 119 insertions(+), 95 deletions(-) diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index fcf0d997c52af..a207032693539 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -509,6 +509,24 @@ typedef struct { int16_t r3; } ggml_metal_kargs_mul_mv; +typedef struct { + int32_t nei0; + int32_t nei1; + uint64_t nbi1; + int32_t ne00; + int32_t ne02; + uint64_t nb01; + uint64_t nb02; + int32_t ne11; + int32_t ne12; + int32_t ne13; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + int32_t ne0; + int32_t ne1; +} ggml_metal_kargs_mul_mm_id; + typedef struct { int32_t nei0; int32_t nei1; diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index f4b611b886555..fb697b7f77bb6 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -2301,27 +2301,30 @@ static void ggml_metal_encode_node( default: GGML_ABORT("MUL_MAT_ID not implemented"); } + ggml_metal_kargs_mul_mm_id args = { + /*.nei0 =*/ ne20, + /*.nei1 =*/ ne21, + /*.nbi1 =*/ nb21, + /*.ne00 =*/ ne00, + /*.ne02 =*/ ne02, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.ne13 =*/ ne13, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + }; + [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3]; - [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4]; - [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5]; - [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:8]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:9]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:10]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12]; - [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:18]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:4]; [encoder setThreadgroupMemoryLength:GGML_PAD(8192 + dst_rows*4/*sizeof(ushort2)*/, 16) atIndex:0]; diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 0bdb04e814c2c..e0b4b75439927 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -5769,31 +5769,32 @@ kernel void kernel_mul_mm( } // same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in rowids +// TODO: this kernel needs to be reimplemented from scratch for better performance template void kernel_mul_mm_id_impl( - device const uchar * src0, - device const uchar * src1, + int32_t ne00, + int32_t ne02, + uint64_t nb01, + uint64_t nb02, + int32_t ne11, + int32_t ne12, + uint64_t nb10, + uint64_t nb11, + uint64_t nb12, + int32_t ne0, + int32_t ne1, + int64_t ne0ne1, + device const char * src0, + device const char * src1, threadgroup ushort2 * rowids, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne02, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - int64_t ne1, - int64_t ne0ne1, - threadgroup uchar * shared_memory, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - threadgroup half * sa = (threadgroup half *)(shared_memory); - threadgroup float * sb = (threadgroup float *)(shared_memory + 4096); + device char * dst, + threadgroup char * shmem, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiitg[[thread_index_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + + threadgroup half * sa = (threadgroup half *)(shmem); + threadgroup float * sb = (threadgroup float *)(shmem + 4096); const uint r0 = tgpig.y; const uint r1 = tgpig.x; @@ -5810,9 +5811,9 @@ void kernel_mul_mm_id_impl( simdgroup_half8x8 ma[4]; simdgroup_float8x8 mb[2]; - simdgroup_float8x8 c_res[8]; + simdgroup_float8x8 mc[8]; for (int i = 0; i < 8; i++){ - c_res[i] = make_filled_simdgroup_matrix(0.f); + mc[i] = make_filled_simdgroup_matrix(0.f); } short il = (tiitg % THREAD_PER_ROW); @@ -5850,11 +5851,14 @@ void kernel_mul_mm_id_impl( threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2)); threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2)); + #pragma unroll(BLOCK_SIZE_K/8) for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) { + #pragma unroll(4) for (int i = 0; i < 4; i++) { simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i); } simdgroup_barrier(mem_flags::mem_none); + #pragma unroll(2) for (int i = 0; i < 2; i++) { simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i); } @@ -5862,29 +5866,42 @@ void kernel_mul_mm_id_impl( lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE; lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE; + #pragma unroll(8) for (int i = 0; i < 8; i++){ - simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]); + simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]); } } } { threadgroup_barrier(mem_flags::mem_threadgroup); - threadgroup float * temp_str = ((threadgroup float *)shared_memory) \ + threadgroup float * temp_str = ((threadgroup float *) shmem) \ + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M; for (int i = 0; i < 8; i++) { - simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M); + simdgroup_store(mc[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M); } threadgroup_barrier(mem_flags::mem_threadgroup); - device float * C = dst + (BLOCK_SIZE_M * r0); if (sgitg == 0) { for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) { threadgroup const auto & jid = rowids[r1 * BLOCK_SIZE_N + j]; - int joff = jid[0] * ne0 + jid[1] * ne0ne1; - for (int i = 0; i < n_rows; i++) { - *(C + i + joff) = *(temp_str + i + j * BLOCK_SIZE_M); + int64_t joff = jid[0]*ne0 + jid[1]*ne0ne1; + + device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + joff; + device float4 * D4 = (device float4 *) D; + + threadgroup float * C = temp_str + (j*BLOCK_SIZE_M); + threadgroup float4 * C4 = (threadgroup float4 *) C; + + int i = 0; + for (; i < n_rows/4; i++) { + *(D4 + i) = *(C4 + i); + } + + i *= 4; + for (; i < n_rows; i++) { + *(D + i) = *(C + i); } } } @@ -5893,48 +5910,34 @@ void kernel_mul_mm_id_impl( template kernel void kernel_mul_mm_id( - device const uchar * src0s, - device const uchar * src1, - device float * dst, - device const uchar * ids, - constant int64_t & nei0, - constant int64_t & nei1, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne02, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - threadgroup uchar * shared_memory [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + constant ggml_metal_kargs_mul_mm_id & args, + device const char * src0s, + device const char * src1, + device char * dst, + device const char * ids, + threadgroup char * shmem [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiitg[[thread_index_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { const int32_t i02 = tgpig.z; + tgpig.z = 0; - device const uchar * src0 = src0s + i02*nb02; + device const char * src0 = src0s + i02*args.nb02; // row indices - threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shared_memory + 8192); + threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shmem + 8192); // TODO: parallelize this loop int64_t _ne1 = 0; - for (ushort ii1 = 0; ii1 < nei1; ii1++) { - for (ushort ii0 = 0; ii0 < nei0; ii0++) { - int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0]; + for (ushort ii1 = 0; ii1 < args.nei1; ii1++) { + for (ushort ii0 = 0; ii0 < args.nei0; ii0++) { + int32_t id = ((device int32_t *) (ids + ii1*args.nbi1))[ii0]; if (id == i02) { - //if (tiitg == 0) { + if (tiitg == 0) { rowids[_ne1] = ushort2(ii0, ii1); - //} + } _ne1++; } } @@ -5943,23 +5946,23 @@ kernel void kernel_mul_mm_id( threadgroup_barrier(mem_flags::mem_threadgroup); kernel_mul_mm_id_impl( + args.ne00, + args.ne02, + args.nb01, + args.nb02, + args.ne11, + args.ne12, + args.nb10, + args.nb11, + args.nb12, + args.ne0, + _ne1, + (int64_t)args.ne0*args.ne1, src0, src1, rowids, dst, - ne00, - ne02, - nb01, - nb02, - ne11, - ne12, - nb10, - nb11, - nb12, - ne0, - _ne1, - ne0*ne1, - shared_memory, + shmem, tgpig, tiitg, sgitg); From 1a8f8df35d0f4ec9849df96d33e14ae1256956ec Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 10 Nov 2024 11:05:10 +0200 Subject: [PATCH 12/20] cont : int safety + register optimizations ggml-ci --- ggml/src/ggml-metal/ggml-metal.metal | 226 +++++++++++++-------------- 1 file changed, 113 insertions(+), 113 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index e0b4b75439927..2a71d8c275677 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1650,8 +1650,8 @@ void mul_vec_q_n_f32_impl( const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - //const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + //const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; //device const block_q_type * x = (device const block_q_type *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); @@ -1659,7 +1659,7 @@ void mul_vec_q_n_f32_impl( // pointers to src0 rows device const block_q_type * ax[nr]; for (int row = 0; row < nr; ++row) { - const uint offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; ax[row] = (device const block_q_type *) ((device char *) src0 + offset0); } @@ -1667,10 +1667,10 @@ void mul_vec_q_n_f32_impl( float yl[16]; // src1 vector cache float sumf[nr] = {0.f}; - const int ix = (tiisg/2); - const int il = (tiisg%2)*8; + const short ix = (tiisg/2); + const short il = (tiisg%2)*8; - device const float * yb = y + ix * QK4_0 + il; + device const float * yb = y + ix*QK4_0 + il; // each thread in a SIMD group deals with half a block. for (int ib = ix; ib < nb; ib += nw/2) { @@ -1776,8 +1776,8 @@ void kernel_mul_mv_q8_0_f32_impl( const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - //const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + //const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; //device const block_q8_0 * x = (device const block_q8_0 *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); @@ -1785,7 +1785,7 @@ void kernel_mul_mv_q8_0_f32_impl( // pointers to src0 rows device const block_q8_0 * ax[nr]; for (int row = 0; row < nr; ++row) { - const uint offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; ax[row] = (device const block_q8_0 *) ((device char *) src0 + offset0); } @@ -1793,21 +1793,21 @@ void kernel_mul_mv_q8_0_f32_impl( float yl[NB_Q8_0]; float sumf[nr] = { 0.f }; - const int ix = tiisg/4; - const int il = tiisg%4; + const short ix = tiisg/4; + const short il = tiisg%4; device const float * yb = y + ix*QK8_0 + il*NB_Q8_0; // each thread in a SIMD group deals with NB_Q8_0 quants at a time for (int ib = ix; ib < nb; ib += nw/4) { - for (int i = 0; i < NB_Q8_0; ++i) { + for (short i = 0; i < NB_Q8_0; ++i) { yl[i] = yb[i]; } for (int row = 0; row < nr; row++) { device const int8_t * qs = ax[row][ib].qs + il*NB_Q8_0; float sumq = 0.f; - for (int iq = 0; iq < NB_Q8_0; ++iq) { + for (short iq = 0; iq < NB_Q8_0; ++iq) { sumq += qs[iq] * yl[iq]; } sumf[row] += sumq*ax[row][ib].d; @@ -1816,7 +1816,7 @@ void kernel_mul_mv_q8_0_f32_impl( yb += nw*NB_Q8_0; } - device float * dst_f32 = (device float *) dst + (int64_t)im*args.ne0*args.ne1 + (int64_t)r1*args.ne0; + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; for (int row = 0; row < nr; ++row) { const float tot = simd_sum(sumf[row]); @@ -1849,18 +1849,18 @@ void kernel_mul_mv_impl( device char * dst, uint3 tgpig, ushort tiisg) { - const int64_t r0 = tgpig.x; - const int64_t rb = tgpig.y*N_MV_T_T; - const int64_t im = tgpig.z; + const int r0 = tgpig.x; + const int rb = tgpig.y*N_MV_T_T; + const int im = tgpig.z; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - const uint offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; device const T0 * x = (device const T0 *) (src0 + offset0); - device float * dst_f32 = (device float *) dst + (int64_t)im*args.ne0*args.ne1; + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1; if (args.ne00 < 128) { for (int row = 0; row < N_MV_T_T; ++row) { @@ -1869,7 +1869,7 @@ void kernel_mul_mv_impl( break; } - const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const T1 * y = (device const T1 *) (src1 + offset1); @@ -1880,7 +1880,7 @@ void kernel_mul_mv_impl( float all_sum = simd_sum(sumf); if (tiisg == 0) { - dst_f32[r1*args.ne0 + r0] = all_sum; + dst_f32[(uint64_t)r1*args.ne0 + r0] = all_sum; } } } else { @@ -1891,20 +1891,20 @@ void kernel_mul_mv_impl( break; } - const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const T1 * y = (device const T1 *) (src1 + offset1); device const T14 * y4 = (device const T14 *) y; float sumf = 0; for (int i = tiisg; i < args.ne00/4; i += 32) { - for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]); + sumf += dot((T14) x4[i], y4[i]); } float all_sum = simd_sum(sumf); if (tiisg == 0) { for (int i = 4*(args.ne00/4); i < args.ne00; ++i) all_sum += (float) (x[i] * y[i]); - dst_f32[r1*args.ne0 + r0] = all_sum; + dst_f32[(uint64_t)r1*args.ne0 + r0] = all_sum; } } } @@ -1940,25 +1940,27 @@ template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t kernel_mul_mv< template kernel void kernel_mul_mv_1row( constant ggml_metal_kargs_mul_mv & args, - device const char * src0, - device const char * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]]) { - const int64_t r0 = tgpig.x; - const int64_t r1 = tgpig.y; - const int64_t im = tgpig.z; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - const uint offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const T * x = (device const T *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; + float sumf = 0; if (args.ne00 < 128) { for (int i = tiisg; i < args.ne00; i += 32) { @@ -1966,21 +1968,21 @@ kernel void kernel_mul_mv_1row( } float all_sum = simd_sum(sumf); if (tiisg == 0) { - dst[im*args.ne1*args.ne0 + r1*args.ne0 + r0] = all_sum; + dst_f32[r0] = all_sum; } } else { device const T4 * x4 = (device const T4 *) x; device const float4 * y4 = (device const float4 *) y; for (int i = tiisg; i < args.ne00/4; i += 32) { - for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]); + sumf += dot((float4) x4[i], y4[i]); } float all_sum = simd_sum(sumf); if (tiisg == 0) { for (int i = 4*(args.ne00/4); i < args.ne00; ++i) all_sum += (float) (x[i] * y[i]); - dst[im*args.ne1*args.ne0 + r1*args.ne0 + r0] = all_sum; + dst_f32[r0] = all_sum; } } } @@ -1996,36 +1998,38 @@ template [[host_name("kernel_mul_mv_bf16_f32_1row")]] kernel mul_mv_1row_t kerne template kernel void kernel_mul_mv_l4( constant ggml_metal_kargs_mul_mv & args, - device const char * src0, - device const char * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]]) { const int nrows = args.ne11; - const int64_t r0 = tgpig.x; - const int64_t im = tgpig.z; + const int r0 = tgpig.x; + const int im = tgpig.z; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - const uint offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; device const T4 * x4 = (device const T4 *) (src0 + offset0); + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1; + for (int r1 = 0; r1 < nrows; ++r1) { - const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const float4 * y4 = (device const float4 *) (src1 + offset1); float sumf = 0; for (int i = tiisg; i < args.ne00/4; i += 32) { - for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]); + sumf += dot((float4) x4[i], y4[i]); } float all_sum = simd_sum(sumf); if (tiisg == 0) { - dst[im*args.ne1*args.ne0 + r1*args.ne0 + r0] = all_sum; + dst_f32[(uint64_t)r1*args.ne0 + r0] = all_sum; } } } @@ -2974,7 +2978,7 @@ kernel void kernel_flash_attn_ext( const float S = ss[j*TS + 0]; for (short i = tiisg; i < D4; i += NW) { - dst4[((int64_t)iq3*args.ne2*args.ne1 + iq2 + (int64_t)(iq1 + j)*args.ne1)*D4 + i] = (float4) so4[j*D4 + i]/S; + dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*D4 + i] = (float4) so4[j*D4 + i]/S; } } } @@ -3366,7 +3370,7 @@ kernel void kernel_flash_attn_ext_vec( const float S = ss[0]; for (short i = tiisg; i < D16; i += NW) { - dst44[((int64_t)iq3*args.ne2*args.ne1 + iq2 + (int64_t)iq1*args.ne1)*D16 + i] = (float4x4) sr4x4[i]/S; + dst44[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)iq1*args.ne1)*D16 + i] = (float4x4) sr4x4[i]/S; } } } @@ -3961,8 +3965,8 @@ void kernel_mul_mv_q2_K_f32_impl( const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_q2_K * x = (device const block_q2_K *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); @@ -4022,7 +4026,7 @@ void kernel_mul_mv_q2_K_f32_impl( y4 += 4 * QK_K; } - device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); @@ -4058,17 +4062,17 @@ void kernel_mul_mv_q3_K_f32_impl( const int nb = args.ne00/QK_K; - const int64_t r0 = tgpig.x; - const int64_t r1 = tgpig.y; - const int64_t im = tgpig.z; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_q3_K * x = (device const block_q3_K *) (src0 + offset0); device const float * yy = (device const float *) (src1 + offset1); @@ -4101,9 +4105,10 @@ void kernel_mul_mv_q3_K_f32_impl( const ushort4 hm = mm[2*ip + il/2]; - const int shift = 2*il; - const float v1 = il == 0 ? 4.f : 64.f; - const float v2 = 4.f * v1; + const short shift = 2*il; + + const float v1 = il == 0 ? 4.f : 64.f; + const float v2 = 4.f * v1; const uint16_t s_shift1 = 4*ip; const uint16_t s_shift2 = s_shift1 + il; @@ -4186,7 +4191,7 @@ void kernel_mul_mv_q3_K_f32_impl( sumf1[row] = simd_sum(sumf); } - device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; if (tiisg == 0) { for (int row = 0; row < 2; ++row) { @@ -4238,8 +4243,8 @@ void kernel_mul_mv_q4_K_f32_impl( const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_q4_K * x = (device const block_q4_K *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); @@ -4303,7 +4308,7 @@ void kernel_mul_mv_q4_K_f32_impl( y4 += 4 * QK_K; } - device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + device float * dst_f32 = (device float *) dst + (int64_t)im*args.ne0*args.ne1 + (int64_t)r1*args.ne0; for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); @@ -4339,8 +4344,8 @@ void kernel_mul_mv_q5_K_f32_impl( const int nb = args.ne00/QK_K; - const int64_t r0 = tgpig.x; - const int64_t r1 = tgpig.y; + const int r0 = tgpig.x; + const int r1 = tgpig.y; const int im = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2; @@ -4348,8 +4353,8 @@ void kernel_mul_mv_q5_K_f32_impl( const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_q5_K * x = (device const block_q5_K *) (src0 + offset0); device const float * yy = (device const float *) (src1 + offset1); @@ -4435,7 +4440,7 @@ void kernel_mul_mv_q5_K_f32_impl( y1 += 4 * QK_K; } - device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; for (int row = 0; row < 2; ++row) { const float tot = simd_sum(sumf[row]); @@ -4476,17 +4481,17 @@ void kernel_mul_mv_q6_K_f32_impl( const int nb = args.ne00/QK_K; - const int64_t r0 = tgpig.x; - const int64_t r1 = tgpig.y; - const int im = tgpig.z; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; - const int row = 2 * r0 + sgitg; + const int row = 2*r0 + sgitg; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - const uint offset0 = row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + const uint64_t offset0 = row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_q6_K * x = (device const block_q6_K *) (src0 + offset0); device const float * yy = (device const float *) (src1 + offset1); @@ -4506,7 +4511,6 @@ void kernel_mul_mv_q6_K_f32_impl( const int q_offset_h = 32*ip + l0; for (int i = ix; i < nb; i += 2) { - device const uint8_t * q1 = x[i].ql + q_offset_l; device const uint8_t * q2 = q1 + 32; device const uint8_t * qh = x[i].qh + q_offset_h; @@ -4528,7 +4532,7 @@ void kernel_mul_mv_q6_K_f32_impl( } - device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; const float tot = simd_sum(sumf); if (tiisg == 0) { @@ -4572,8 +4576,8 @@ void kernel_mul_mv_iq2_xxs_f32_impl( const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq2_xxs * x = (device const block_iq2_xxs *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); @@ -4636,7 +4640,7 @@ void kernel_mul_mv_iq2_xxs_f32_impl( y4 += 32 * 32; } - device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); @@ -4680,8 +4684,8 @@ void kernel_mul_mv_iq2_xs_f32_impl( const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq2_xs * x = (device const block_iq2_xs *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); @@ -4754,7 +4758,7 @@ void kernel_mul_mv_iq2_xs_f32_impl( y4 += 32 * 32; } - device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); @@ -4799,8 +4803,8 @@ void kernel_mul_mv_iq3_xxs_f32_impl( const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq3_xxs * x = (device const block_iq3_xxs *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); @@ -4827,7 +4831,6 @@ void kernel_mul_mv_iq3_xxs_f32_impl( device const float * y4 = y + 32 * ix; for (int ib32 = ix; ib32 < nb32; ib32 += 32) { - for (int i = 0; i < 32; ++i) { yl[i] = y4[i]; } @@ -4841,7 +4844,6 @@ void kernel_mul_mv_iq3_xxs_f32_impl( device const half * dh = &xr->d; for (int row = 0; row < N_DST; row++) { - const float db = dh[0]; const uint32_t aux32 = gas[0] | (gas[1] << 16); const float d = db * (0.5f + (aux32 >> 28)); @@ -4866,7 +4868,7 @@ void kernel_mul_mv_iq3_xxs_f32_impl( y4 += 32 * 32; } - device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); @@ -4911,8 +4913,8 @@ void kernel_mul_mv_iq3_s_f32_impl( const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq3_s * x = (device const block_iq3_s *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); @@ -4978,7 +4980,7 @@ void kernel_mul_mv_iq3_s_f32_impl( y4 += 32 * 32; } - device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); @@ -5023,8 +5025,8 @@ void kernel_mul_mv_iq2_s_f32_impl( const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq2_s * x = (device const block_iq2_s *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); @@ -5091,7 +5093,7 @@ void kernel_mul_mv_iq2_s_f32_impl( y4 += 32 * 32; } - device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); @@ -5136,8 +5138,8 @@ void kernel_mul_mv_iq1_s_f32_impl( const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq1_s * x = (device const block_iq1_s *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); @@ -5191,7 +5193,7 @@ void kernel_mul_mv_iq1_s_f32_impl( y4 += 32 * 32; } - device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); @@ -5222,8 +5224,8 @@ void kernel_mul_mv_iq1_m_f32_impl( const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq1_m * x = (device const block_iq1_m *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); @@ -5286,7 +5288,7 @@ void kernel_mul_mv_iq1_m_f32_impl( y4 += 32 * 32; } - device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); @@ -5317,8 +5319,8 @@ void kernel_mul_mv_iq4_nl_f32_impl( const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq4_nl * x = (device const block_iq4_nl *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); @@ -5376,7 +5378,7 @@ void kernel_mul_mv_iq4_nl_f32_impl( yb += 16 * QK4_NL; } - device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; for (int row = 0; row < 2 && first_row + row < args.ne01; ++row) { all_sum = simd_sum(sumf[row]); @@ -5407,8 +5409,8 @@ void kernel_mul_mv_iq4_xs_f32_impl( const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq4_xs * x = (device const block_iq4_xs *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); @@ -5432,25 +5434,23 @@ void kernel_mul_mv_iq4_xs_f32_impl( float4 qf1, qf2; for (int ibl = ix; ibl < nb; ibl += 2) { - device const float4 * y4 = (device const float4 *)yb; yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5]; for (int row = 0; row < 2; ++row) { - device const block_iq4_xs & xb = x[row*nb + ibl]; device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il); float4 acc1 = {0.f}, acc2 = {0.f}; - aux32[0] = q4[0] & 0x0f0f0f0f; + aux32[0] = (q4[0] ) & 0x0f0f0f0f; aux32[1] = (q4[0] >> 4) & 0x0f0f0f0f; qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]}; qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]}; acc1 += yl[0] * qf1; acc2 += yl[1] * qf2; - aux32[0] = q4[1] & 0x0f0f0f0f; + aux32[0] = (q4[1] ) & 0x0f0f0f0f; aux32[1] = (q4[1] >> 4) & 0x0f0f0f0f; qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]}; qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]}; @@ -5467,7 +5467,7 @@ void kernel_mul_mv_iq4_xs_f32_impl( yb += 2 * QK_K; } - device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; for (int row = 0; row < 2; ++row) { all_sum = simd_sum(sumf[row]); @@ -5670,8 +5670,8 @@ kernel void kernel_mul_mm( const int i12 = im%args.ne12; const int i13 = im/args.ne12; - int offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - short offset1 = il/nl; + uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + short offset1 = il/nl; device const block_q * x = (device const block_q *)(src0 + (r0*BLOCK_SIZE_M + thread_row)*args.nb01 + offset0) + offset1; device const float * y = (device const float *)(src1 @@ -5796,10 +5796,10 @@ void kernel_mul_mm_id_impl( threadgroup half * sa = (threadgroup half *)(shmem); threadgroup float * sb = (threadgroup float *)(shmem + 4096); - const uint r0 = tgpig.y; - const uint r1 = tgpig.x; + const int r0 = tgpig.y; + const int r1 = tgpig.x; - if (r1 * BLOCK_SIZE_N >= ne1) return; + if (r1*BLOCK_SIZE_N >= ne1) return; // if this block is of 64x32 shape or smaller short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M; @@ -5930,7 +5930,7 @@ kernel void kernel_mul_mm_id( threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shmem + 8192); // TODO: parallelize this loop - int64_t _ne1 = 0; + int32_t _ne1 = 0; for (ushort ii1 = 0; ii1 < args.nei1; ii1++) { for (ushort ii0 = 0; ii0 < args.nei0; ii0++) { int32_t id = ((device int32_t *) (ids + ii1*args.nbi1))[ii0]; From 4c1c7213e2e3d76b080537dfbe8642399b141edc Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 10 Nov 2024 13:03:25 +0200 Subject: [PATCH 13/20] metal : GGML_OP_CONCAT ggml-ci --- ggml/src/ggml-common.h | 28 +++++++++++++ ggml/src/ggml-metal/ggml-metal.m | 60 +++++++++++++++------------- ggml/src/ggml-metal/ggml-metal.metal | 54 +++++++------------------ 3 files changed, 75 insertions(+), 67 deletions(-) diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index a207032693539..cd54c18ed0956 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -419,6 +419,34 @@ typedef struct { static_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/64 + QK_K/2, "wrong iq4_xs block size/padding"); #if defined(GGML_COMMON_DECL_METAL_KARGS) +typedef struct { + int32_t ne00; + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne10; + int32_t ne11; + int32_t ne12; + int32_t ne13; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + int32_t ne0; + int32_t ne1; + int32_t ne2; + int32_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; + int32_t dim; +} ggml_metal_kargs_concat; + typedef struct { int32_t ne00; int32_t ne01; diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index fb697b7f77bb6..9f7810126f5e7 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -1197,35 +1197,39 @@ static void ggml_metal_encode_node( const int32_t dim = ((const int32_t *) dst->op_params)[0]; + ggml_metal_kargs_concat args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne10 =*/ ne10, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.ne13 =*/ ne13, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + /*.dim =*/ dim, + }; + [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; - [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13]; - [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17]; - [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20]; - [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21]; - [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22]; - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24]; - [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25]; - [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26]; - [encoder setBytes:&dim length:sizeof(dim) atIndex:27]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; const int nth = MIN(1024, ne0); diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 2a71d8c275677..d8d44bdb20c49 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1898,7 +1898,7 @@ void kernel_mul_mv_impl( float sumf = 0; for (int i = tiisg; i < args.ne00/4; i += 32) { - sumf += dot((T14) x4[i], y4[i]); + sumf += dot((float4) x4[i], (float4) y4[i]); } float all_sum = simd_sum(sumf); @@ -3890,55 +3890,31 @@ kernel void kernel_cpy_f32_iq4_nl( } kernel void kernel_concat( + constant ggml_metal_kargs_concat & args, device const char * src0, device const char * src1, device char * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - constant int32_t & dim, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { - const int64_t i3 = tgpig.z; - const int64_t i2 = tgpig.y; - const int64_t i1 = tgpig.x; + const int i3 = tgpig.z; + const int i2 = tgpig.y; + const int i1 = tgpig.x; - int64_t o[4] = {0, 0, 0, 0}; - o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03)); + int o[4] = {0, 0, 0, 0}; + o[args.dim] = args.dim == 0 ? args.ne00 : (args.dim == 1 ? args.ne01 : (args.dim == 2 ? args.ne02 : args.ne03)); device const float * x; - for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { - if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) { - x = (device const float *)(src0 + (i3 )*nb03 + (i2 )*nb02 + (i1 )*nb01 + (i0 )*nb00); + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + if (i0 < args.ne00 && i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) { + x = (device const float *)(src0 + (i3 )*args.nb03 + (i2 )*args.nb02 + (i1 )*args.nb01 + (i0 )*args.nb00); } else { - x = (device const float *)(src1 + (i3 - o[3])*nb13 + (i2 - o[2])*nb12 + (i1 - o[1])*nb11 + (i0 - o[0])*nb10); + x = (device const float *)(src1 + (i3 - o[3])*args.nb13 + (i2 - o[2])*args.nb12 + (i1 - o[1])*args.nb11 + (i0 - o[0])*args.nb10); } - device float * y = (device float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + device float * y = (device float *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); *y = *x; } From 281fa05e831c9ea3e1ed5c5322555aaf959ffd59 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 10 Nov 2024 13:16:54 +0200 Subject: [PATCH 14/20] metal : GGML_OP_ADD, GGML_OP_SUB, GGML_OP_MUL, GGML_OP_DIV --- ggml/src/ggml-common.h | 28 ++++ ggml/src/ggml-metal/ggml-metal.m | 124 +++++++------- ggml/src/ggml-metal/ggml-metal.metal | 234 +++++++++------------------ 3 files changed, 164 insertions(+), 222 deletions(-) diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index cd54c18ed0956..6d7e07ee6c20b 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -447,6 +447,34 @@ typedef struct { int32_t dim; } ggml_metal_kargs_concat; +typedef struct { + int32_t ne00; + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne10; + int32_t ne11; + int32_t ne12; + int32_t ne13; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + int32_t ne0; + int32_t ne1; + int32_t ne2; + int32_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; + uint64_t offs; +} ggml_metal_kargs_bin; + typedef struct { int32_t ne00; int32_t ne01; diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index 9f7810126f5e7..1dcf79754f599 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -1247,8 +1247,6 @@ static void ggml_metal_encode_node( bool bcast_row = false; - int64_t nb = ne00; // used by the "row" kernels - id pipeline = nil; if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) { @@ -1257,7 +1255,6 @@ static void ggml_metal_encode_node( // src1 is a row GGML_ASSERT(ne11 == 1); - nb = ne00 / 4; switch (dst->op) { case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break; case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB_ROW].pipeline; break; @@ -1277,36 +1274,39 @@ static void ggml_metal_encode_node( } } + ggml_metal_kargs_bin args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne10 =*/ ne10, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.ne13 =*/ ne13, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + /*.offs =*/ offs, + }; + [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; - [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13]; - [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17]; - [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20]; - [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21]; - [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22]; - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24]; - [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25]; - [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26]; - [encoder setBytes:&offs length:sizeof(offs) atIndex:27]; - [encoder setBytes:&nb length:sizeof(nb) atIndex:28]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; if (bcast_row) { const int64_t n = ggml_nelements(dst)/4; @@ -1404,35 +1404,39 @@ static void ggml_metal_encode_node( const id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; + ggml_metal_kargs_bin args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ pnb1, + /*.nb02 =*/ pnb2, + /*.nb03 =*/ pnb3, + /*.ne10 =*/ ne10, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.ne13 =*/ ne13, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ pnb1, + /*.nb2 =*/ pnb2, + /*.nb3 =*/ pnb3, + /*.offs =*/ offs, + }; + [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; - [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7]; - [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:8]; - [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:9]; - [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:10]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13]; - [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17]; - [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20]; - [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21]; - [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22]; - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23]; - [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:24]; - [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:25]; - [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26]; - [encoder setBytes:&offs length:sizeof(offs) atIndex:27]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00); diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index d8d44bdb20c49..49342b88c5c79 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -498,200 +498,106 @@ enum ggml_sort_order { // pros: works for non-contiguous tensors, supports broadcast across all dims // cons: not very efficient kernel void kernel_add( + constant ggml_metal_kargs_bin & args, device const char * src0, device const char * src1, device char * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - constant int64_t & offs, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig.z; - const int64_t i02 = tgpig.y; - const int64_t i01 = tgpig.x; + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int i03 = tgpig.z; + const int i02 = tgpig.y; + const int i01 = tgpig.x; - const int64_t i13 = i03 % ne13; - const int64_t i12 = i02 % ne12; - const int64_t i11 = i01 % ne11; + const int i13 = i03%args.ne13; + const int i12 = i02%args.ne12; + const int i11 = i01%args.ne11; - device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs; - device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11; - device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + offs; + device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs; + device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11; + device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs; - for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { - const int i10 = i0 % ne10; - *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) + *((device float *)(src1_ptr + i10*nb10)); + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + const int i10 = i0%args.ne10; + *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) + *((device float *)(src1_ptr + i10*args.nb10)); } } kernel void kernel_sub( + constant ggml_metal_kargs_bin & args, device const char * src0, device const char * src1, device char * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - constant int64_t & offs, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig.z; - const int64_t i02 = tgpig.y; - const int64_t i01 = tgpig.x; + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int i03 = tgpig.z; + const int i02 = tgpig.y; + const int i01 = tgpig.x; - const int64_t i13 = i03 % ne13; - const int64_t i12 = i02 % ne12; - const int64_t i11 = i01 % ne11; + const int i13 = i03%args.ne13; + const int i12 = i02%args.ne12; + const int i11 = i01%args.ne11; - device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs; - device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11; - device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + offs; + device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs; + device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11; + device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs; - for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { - const int i10 = i0 % ne10; - *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) - *((device float *)(src1_ptr + i10*nb10)); + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + const int i10 = i0%args.ne10; + *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) - *((device float *)(src1_ptr + i10*args.nb10)); } } kernel void kernel_mul( + constant ggml_metal_kargs_bin & args, device const char * src0, device const char * src1, device char * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig.z; - const int64_t i02 = tgpig.y; - const int64_t i01 = tgpig.x; + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int i03 = tgpig.z; + const int i02 = tgpig.y; + const int i01 = tgpig.x; - const int64_t i13 = i03 % ne13; - const int64_t i12 = i02 % ne12; - const int64_t i11 = i01 % ne11; + const int i13 = i03%args.ne13; + const int i12 = i02%args.ne12; + const int i11 = i01%args.ne11; - device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01; - device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11; - device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1; + device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01; + device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11; + device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1; - for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { - const int i10 = i0 % ne10; - *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) * *((device float *)(src1_ptr + i10*nb10)); + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + const int i10 = i0%args.ne10; + *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * *((device float *)(src1_ptr + i10*args.nb10)); } } kernel void kernel_div( + constant ggml_metal_kargs_bin & args, device const char * src0, device const char * src1, device char * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig.z; - const int64_t i02 = tgpig.y; - const int64_t i01 = tgpig.x; + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int i03 = tgpig.z; + const int i02 = tgpig.y; + const int i01 = tgpig.x; - const int64_t i13 = i03 % ne13; - const int64_t i12 = i02 % ne12; - const int64_t i11 = i01 % ne11; + const int i13 = i03%args.ne13; + const int i12 = i02%args.ne12; + const int i11 = i01%args.ne11; - device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01; - device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11; - device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1; + device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01; + device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11; + device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1; - for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { - const int i10 = i0 % ne10; - *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) / *((device float *)(src1_ptr + i10*nb10)); + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + const int i10 = i0%args.ne10; + *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) / *((device float *)(src1_ptr + i10*args.nb10)); } } @@ -745,38 +651,42 @@ template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat // assumption: src1 is a row // broadcast src1 into src0 kernel void kernel_add_row( + constant ggml_metal_kargs_bin & args, device const float4 * src0, device const float4 * src1, device float4 * dst, - constant uint64_t & nb [[buffer(28)]], uint tpig[[thread_position_in_grid]]) { + const uint nb = args.ne00/4; dst[tpig] = src0[tpig] + src1[tpig % nb]; } kernel void kernel_sub_row( + constant ggml_metal_kargs_bin & args, device const float4 * src0, device const float4 * src1, device float4 * dst, - constant uint64_t & nb [[buffer(28)]], uint tpig[[thread_position_in_grid]]) { + const uint nb = args.ne00/4; dst[tpig] = src0[tpig] - src1[tpig % nb]; } kernel void kernel_mul_row( + constant ggml_metal_kargs_bin & args, device const float4 * src0, device const float4 * src1, device float4 * dst, - constant uint64_t & nb [[buffer(28)]], uint tpig[[thread_position_in_grid]]) { + const uint nb = args.ne00/4; dst[tpig] = src0[tpig] * src1[tpig % nb]; } kernel void kernel_div_row( + constant ggml_metal_kargs_bin & args, device const float4 * src0, device const float4 * src1, device float4 * dst, - constant uint64_t & nb [[buffer(28)]], uint tpig[[thread_position_in_grid]]) { + const uint nb = args.ne00/4; dst[tpig] = src0[tpig] / src1[tpig % nb]; } From d7488ba09c33a4c686478745f8f4a9f6a9e3f1ed Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 10 Nov 2024 13:21:59 +0200 Subject: [PATCH 15/20] metal : GGML_OP_REPEAT --- ggml/src/ggml-common.h | 19 ++++++++++++ ggml/src/ggml-metal/ggml-metal.m | 40 ++++++++++++++----------- ggml/src/ggml-metal/ggml-metal.metal | 45 ++++++++++------------------ 3 files changed, 56 insertions(+), 48 deletions(-) diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index 6d7e07ee6c20b..a0720e9828010 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -475,6 +475,25 @@ typedef struct { uint64_t offs; } ggml_metal_kargs_bin; +typedef struct { + int32_t ne00; + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne0; + int32_t ne1; + int32_t ne2; + int32_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; +} ggml_metal_kargs_repeat; + typedef struct { int32_t ne00; int32_t ne01; diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index 1dcf79754f599..664457fd2ad4a 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -1330,25 +1330,29 @@ static void ggml_metal_encode_node( default: GGML_ABORT("fatal error"); } + ggml_metal_kargs_repeat args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + }; + [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; - [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11]; - [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12]; - [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13]; - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15]; - [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16]; - [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0); diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 49342b88c5c79..915119b0a1eab 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -603,41 +603,26 @@ kernel void kernel_div( template kernel void kernel_repeat( + constant ggml_metal_kargs_repeat & args, device const char * src0, device char * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i3 = tgpig.z; - const int64_t i2 = tgpig.y; - const int64_t i1 = tgpig.x; + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int i3 = tgpig.z; + const int i2 = tgpig.y; + const int i1 = tgpig.x; - const int64_t i03 = i3 % ne03; - const int64_t i02 = i2 % ne02; - const int64_t i01 = i1 % ne01; + const int i03 = i3%args.ne03; + const int i02 = i2%args.ne02; + const int i01 = i1%args.ne01; - device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01; - device char * dst_ptr = dst + i3*nb3 + i2*nb2 + i1*nb1 ; + device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01; + device char * dst_ptr = dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1; - for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { - const int i00 = i0 % ne00; - *((device T *)(dst_ptr + i0*nb0)) = *((device T *)(src0_ptr + i00*nb00)); + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + const int i00 = i0%args.ne00; + *((device T *)(dst_ptr + i0*args.nb0)) = *((device T *)(src0_ptr + i00*args.nb00)); } } From 2b86f84839c18e81b3436baa4f298ab63ac3fefb Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 10 Nov 2024 13:55:26 +0200 Subject: [PATCH 16/20] metal : GGML_OP_CPY --- ggml/src/ggml-common.h | 19 ++ ggml/src/ggml-metal/ggml-metal.m | 80 ++++--- ggml/src/ggml-metal/ggml-metal.metal | 344 +++++++++------------------ 3 files changed, 182 insertions(+), 261 deletions(-) diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index a0720e9828010..2d45311a37e37 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -494,6 +494,25 @@ typedef struct { uint64_t nb3; } ggml_metal_kargs_repeat; +typedef struct { + int64_t ne00; + int64_t ne01; + int64_t ne02; + int64_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int64_t ne0; + int64_t ne1; + int64_t ne2; + int64_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; +} ggml_metal_kargs_cpy; + typedef struct { int32_t ne00; int32_t ne01; diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index 664457fd2ad4a..1178434ebf37a 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -1381,25 +1381,29 @@ static void ggml_metal_encode_node( const id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; + ggml_metal_kargs_cpy args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + }; + [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; - [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12]; - [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14]; - [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15]; - [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; - [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00); @@ -3429,25 +3433,29 @@ static void ggml_metal_encode_node( default: GGML_ABORT("not implemented"); } + ggml_metal_kargs_cpy args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + }; + [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; - [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12]; - [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14]; - [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15]; - [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; - [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } break; diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 915119b0a1eab..e3c827316aa80 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -3307,42 +3307,27 @@ template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ template kernel void kernel_cpy( - device const void * src0, - device void * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig[2]; - const int64_t i02 = tgpig[1]; - const int64_t i01 = tgpig[0]; + constant ggml_metal_kargs_cpy & args, + device const char * src0, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int i03 = tgpig[2]; + const int i02 = tgpig[1]; + const int i01 = tgpig[0]; - const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; - const int64_t i3 = n / (ne2*ne1*ne0); - const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); - const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; - const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); + const int64_t i3 = n/(args.ne2*args.ne1*args.ne0); + const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0); + const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0; + const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0); - device T1 * dst_data = (device T1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + device T1 * dst_data = (device T1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { - device const T0 * src = (device T0 *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + for (int64_t i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) { + device const T0 * src = (device T0 *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); dst_data[i00] = (T1) src[0]; } } @@ -3362,42 +3347,27 @@ template [[host_name("kernel_cpy_bf16_bf16")]] kernel kernel_cpy_t kernel_cpy 0 ? sumqx/sumq2 : d; - } } From b438ff7e7c301483ac81f161e66caa334046d552 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 10 Nov 2024 15:31:43 +0200 Subject: [PATCH 17/20] metal : GGML_OP_RMS_NORM --- ggml/src/ggml-common.h | 7 ++++ ggml/src/ggml-metal/ggml-metal.m | 22 ++++++---- ggml/src/ggml-metal/ggml-metal.metal | 63 +++++++++++++--------------- 3 files changed, 51 insertions(+), 41 deletions(-) diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index 2d45311a37e37..86f4753b17218 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -642,6 +642,13 @@ typedef struct { int32_t ne1; uint64_t nb1; } ggml_metal_kargs_mul_mv_id; + +typedef struct { + int32_t ne00; + int32_t ne00_4; + uint64_t nb01; + float eps; +} ggml_metal_kargs_rms_norm; #endif #endif // GGML_COMMON_DECL diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index 1178434ebf37a..aa202438324b4 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -2622,20 +2622,28 @@ static void ggml_metal_encode_node( float eps; memcpy(&eps, dst->op_params, sizeof(float)); + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline; + int nth = 32; // SIMD width - while (nth < ne00/4 && nth < 1024) { + while (nth < ne00/4 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) { nth *= 2; } - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline; + nth = MIN(nth, ne00/4); + + ggml_metal_kargs_rms_norm args = { + /*.ne00 =*/ ne00, + /*.ne00_4 =*/ ne00/4, + /*.nb01 =*/ nb01, + /*.eps =*/ eps, + }; [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3]; - [encoder setBytes:&eps length:sizeof( float) atIndex:4]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; const int64_t nrows = ggml_nrows(src0); diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index e3c827316aa80..45a205eacaaf7 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1293,50 +1293,45 @@ kernel void kernel_norm( } kernel void kernel_rms_norm( - device const void * src0, - device float * dst, - constant int64_t & ne00, - constant uint64_t & nb01, - constant float & eps, - threadgroup float * buf [[threadgroup(0)]], - uint tgpig[[threadgroup_position_in_grid]], - uint tpitg[[thread_position_in_threadgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint ntg[[threads_per_threadgroup]]) { - device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01); + constant ggml_metal_kargs_rms_norm & args, + device const char * src0, + device char * dst, + threadgroup float * shmem_f32 [[threadgroup(0)]], + uint tgpig[[threadgroup_position_in_grid]], + ushort tpitg[[thread_position_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort ntg[[threads_per_threadgroup]]) { + if (sgitg == 0) { + shmem_f32[tiisg] = 0.0f; + } - float4 sumf = 0; - float all_sum = 0; + device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01); + + float sumf = 0.0f; // parallel sum - for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { - sumf += x[i00] * x[i00]; + for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) { + sumf += dot(x[i00], x[i00]); } - all_sum = sumf[0] + sumf[1] + sumf[2] + sumf[3]; - all_sum = simd_sum(all_sum); - if (ntg > N_SIMDWIDTH) { - if (sgitg == 0) { - buf[tiisg] = 0.0f; - } + sumf = simd_sum(sumf); - threadgroup_barrier(mem_flags::mem_threadgroup); + threadgroup_barrier(mem_flags::mem_threadgroup); - if (tiisg == 0) { - buf[sgitg] = all_sum; - } + if (tiisg == 0) { + shmem_f32[sgitg] = sumf; + } - threadgroup_barrier(mem_flags::mem_threadgroup); + threadgroup_barrier(mem_flags::mem_threadgroup); - all_sum = buf[tiisg]; - all_sum = simd_sum(all_sum); - } + sumf = shmem_f32[tiisg]; + sumf = simd_sum(sumf); - const float mean = all_sum/ne00; - const float scale = 1.0f/sqrt(mean + eps); + const float mean = sumf/args.ne00; + const float scale = 1.0f/sqrt(mean + args.eps); - device float4 * y = (device float4 *) (dst + tgpig*ne00); - for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { + device float4 * y = (device float4 *) dst + tgpig*args.ne00_4; + for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) { y[i00] = x[i00] * scale; } } From f018669cf5cf7eca08959d826f0a675ce6a4ed48 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 10 Nov 2024 17:17:18 +0200 Subject: [PATCH 18/20] metal : GGML_OP_NORM --- ggml/src/ggml-common.h | 7 +++ ggml/src/ggml-metal/ggml-metal.m | 29 +++++++--- ggml/src/ggml-metal/ggml-metal.metal | 87 ++++++++++++++++------------ 3 files changed, 79 insertions(+), 44 deletions(-) diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index 86f4753b17218..d25100693bdda 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -643,6 +643,13 @@ typedef struct { uint64_t nb1; } ggml_metal_kargs_mul_mv_id; +typedef struct { + int32_t ne00; + int32_t ne00_4; + uint64_t nb01; + float eps; +} ggml_metal_kargs_norm; + typedef struct { int32_t ne00; int32_t ne00_4; diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index aa202438324b4..28a6092268758 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -2685,22 +2685,35 @@ static void ggml_metal_encode_node( } break; case GGML_OP_NORM: { + GGML_ASSERT(ne00 % 4 == 0); GGML_ASSERT(ggml_is_contiguous_1(src0)); float eps; memcpy(&eps, dst->op_params, sizeof(float)); - const int nth = MIN(256, ne00); - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_NORM].pipeline; + int nth = 32; // SIMD width + + while (nth < ne00/4 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) { + nth *= 2; + } + + nth = MIN(nth, ne00/4); + + ggml_metal_kargs_norm args = { + /*.ne00 =*/ ne00, + /*.ne00_4 =*/ ne00/4, + /*.nb01 =*/ nb01, + /*.eps =*/ eps, + }; + [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3]; - [encoder setBytes:&eps length:sizeof( float) atIndex:4]; - [encoder setThreadgroupMemoryLength:GGML_PAD(nth*sizeof(float), 16) atIndex:0]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + + [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; const int64_t nrows = ggml_nrows(src0); diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 45a205eacaaf7..6bdf4e4cc3e09 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1241,53 +1241,68 @@ kernel void kernel_ssm_scan_f32( } kernel void kernel_norm( - device const void * src0, - device float * dst, - constant int64_t & ne00, - constant uint64_t & nb01, - constant float & eps, - threadgroup float * sum [[threadgroup(0)]], - uint tgpig[[threadgroup_position_in_grid]], - uint tpitg[[thread_position_in_threadgroup]], - uint ntg[[threads_per_threadgroup]]) { - device const float * x = (device const float *) ((device const char *) src0 + tgpig*nb01); - // MEAN - // parallel sum - sum[tpitg] = 0.0f; - for (int i00 = tpitg; i00 < ne00; i00 += ntg) { - sum[tpitg] += x[i00]; + constant ggml_metal_kargs_norm & args, + device const char * src0, + device char * dst, + threadgroup float * shmem_f32 [[threadgroup(0)]], + uint tgpig[[threadgroup_position_in_grid]], + ushort tpitg[[thread_position_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort ntg[[threads_per_threadgroup]]) { + if (sgitg == 0) { + shmem_f32[tiisg] = 0.0f; } - // reduce + + device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01); + + float4 sumf4(0.0f); + + float sumf = 0.0f; + + for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) { + sumf4 += x[i00]; + } + sumf = sumf4[0] + sumf4[1] + sumf4[2] + sumf4[3]; + sumf = simd_sum(sumf); + threadgroup_barrier(mem_flags::mem_threadgroup); - for (uint i = ntg/2; i > 0; i /= 2) { - if (tpitg < i) { - sum[tpitg] += sum[tpitg + i]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + shmem_f32[sgitg] = sumf; } - const float mean = sum[0] / ne00; - // recenter and VARIANCE threadgroup_barrier(mem_flags::mem_threadgroup); - device float * y = dst + tgpig*ne00; - sum[tpitg] = 0.0f; - for (int i00 = tpitg; i00 < ne00; i00 += ntg) { + + sumf = shmem_f32[tiisg]; + sumf = simd_sum(sumf); + + const float mean = sumf/args.ne00; + + device float4 * y = (device float4 *) dst + tgpig*args.ne00_4; + + sumf = 0.0f; + for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) { y[i00] = x[i00] - mean; - sum[tpitg] += y[i00] * y[i00]; + sumf += dot(y[i00], y[i00]); } + sumf = simd_sum(sumf); - // reduce threadgroup_barrier(mem_flags::mem_threadgroup); - for (uint i = ntg/2; i > 0; i /= 2) { - if (tpitg < i) { - sum[tpitg] += sum[tpitg + i]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + shmem_f32[sgitg] = sumf; } - const float variance = sum[0] / ne00; - const float scale = 1.0f/sqrt(variance + eps); - for (int i00 = tpitg; i00 < ne00; i00 += ntg) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + sumf = shmem_f32[tiisg]; + sumf = simd_sum(sumf); + + const float variance = sumf/args.ne00; + + const float scale = 1.0f/sqrt(variance + args.eps); + for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) { y[i00] = y[i00] * scale; } } From 1c603023ed1bb3fb28e3fede683e224a22388105 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 10 Nov 2024 17:56:12 +0200 Subject: [PATCH 19/20] metal : add TODOs for rest of ops --- ggml/src/ggml-metal/ggml-metal.m | 25 +++++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index 28a6092268758..b683d54318af4 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -1485,10 +1485,10 @@ static void ggml_metal_encode_node( memcpy(&max, ((const int32_t *) dst->op_params) + 1, sizeof(float)); [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&min length:sizeof(min) atIndex:2]; - [encoder setBytes:&max length:sizeof(max) atIndex:3]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&min length:sizeof(min) atIndex:2]; + [encoder setBytes:&max length:sizeof(max) atIndex:3]; const int64_t n = ggml_nelements(dst); @@ -1660,6 +1660,7 @@ static void ggml_metal_encode_node( id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline; + // TODO: add ggml_metal_kargs struct [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; @@ -1735,6 +1736,8 @@ static void ggml_metal_encode_node( const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + // TODO: add ggml_metal_kargs struct + // TODO: optimize (see https://github.com/ggerganov/llama.cpp/pull/10238/commits/7941b6b9ec29a2866fec6fa6c51612515ca509f6) [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; if (id_src1) { @@ -1751,6 +1754,7 @@ static void ggml_metal_encode_node( [encoder setBytes:&m0 length:sizeof(m0) atIndex:8]; [encoder setBytes:&m1 length:sizeof(m1) atIndex:9]; [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:10]; + [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; @@ -1767,6 +1771,7 @@ static void ggml_metal_encode_node( pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF].pipeline; } + // TODO: add ggml_metal_kargs struct [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; @@ -1791,6 +1796,7 @@ static void ggml_metal_encode_node( id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_CONV_F32].pipeline; + // TODO: add ggml_metal_kargs struct [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; @@ -1861,6 +1867,7 @@ static void ggml_metal_encode_node( id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline; + // TODO: add ggml_metal_kargs struct [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; @@ -2599,6 +2606,7 @@ static void ggml_metal_encode_node( default: GGML_ABORT("not implemented"); } + // TODO: add ggml_metal_kargs struct [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; @@ -2668,6 +2676,7 @@ static void ggml_metal_encode_node( id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GROUP_NORM].pipeline; + // TODO: add ggml_metal_kargs struct [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; @@ -2857,6 +2866,7 @@ static void ggml_metal_encode_node( default: GGML_ABORT("fatal error"); }; + // TODO: add ggml_metal_kargs struct [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; @@ -2897,6 +2907,7 @@ static void ggml_metal_encode_node( const id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline; + // TODO: add ggml_metal_kargs struct [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; @@ -2931,6 +2942,7 @@ static void ggml_metal_encode_node( id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_F32].pipeline; + // TODO: add ggml_metal_kargs struct [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; @@ -2967,6 +2979,7 @@ static void ggml_metal_encode_node( id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARANGE_F32].pipeline; + // TODO: add ggml_metal_kargs struct [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_dst offset:offs_dst atIndex:0]; [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:1]; @@ -2988,6 +3001,7 @@ static void ggml_metal_encode_node( id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32].pipeline; + // TODO: add ggml_metal_kargs struct [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; @@ -3026,6 +3040,7 @@ static void ggml_metal_encode_node( default: GGML_ABORT("fatal error"); }; + // TODO: add ggml_metal_kargs struct [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; @@ -3044,6 +3059,7 @@ static void ggml_metal_encode_node( id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline; + // TODO: add ggml_metal_kargs struct [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; @@ -3521,6 +3537,7 @@ static void ggml_metal_encode_node( const int64_t n_threads = MIN((int64_t)[pipeline maxTotalThreadsPerThreadgroup], parallel_elements); const int64_t n_tg = (parallel_elements + n_threads - 1) / n_threads; + // TODO: add ggml_metal_kargs struct [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; From a112eb45c4584328c4a47f00e3369ae309147b64 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 10 Nov 2024 18:29:09 +0200 Subject: [PATCH 20/20] ggml : add ggml-metal-impl.h ggml-ci --- Makefile | 5 +- ggml/src/ggml-common.h | 240 ------------------------- ggml/src/ggml-metal/CMakeLists.txt | 18 +- ggml/src/ggml-metal/ggml-metal-impl.h | 249 ++++++++++++++++++++++++++ ggml/src/ggml-metal/ggml-metal.m | 5 +- ggml/src/ggml-metal/ggml-metal.metal | 2 +- 6 files changed, 266 insertions(+), 253 deletions(-) create mode 100644 ggml/src/ggml-metal/ggml-metal-impl.h diff --git a/Makefile b/Makefile index fecf1f693e195..647da232b1f9e 100644 --- a/Makefile +++ b/Makefile @@ -963,6 +963,7 @@ endif # GGML_METAL ifdef GGML_METAL ggml/src/ggml-metal/ggml-metal.o: \ ggml/src/ggml-metal/ggml-metal.m \ + ggml/src/ggml-metal/ggml-metal-impl.h \ ggml/include/ggml-metal.h \ ggml/include/ggml.h $(CC) $(CFLAGS) -c $< -o $@ @@ -970,9 +971,11 @@ ggml/src/ggml-metal/ggml-metal.o: \ ifdef GGML_METAL_EMBED_LIBRARY ggml/src/ggml-metal-embed.o: \ ggml/src/ggml-metal/ggml-metal.metal \ + ggml/src/ggml-metal/ggml-metal-impl.h \ ggml/src/ggml-common.h @echo "Embedding Metal library" - @sed -e '/__embed_ggml-common.h__/r ggml/src/ggml-common.h' -e '/__embed_ggml-common.h__/d' < ggml/src/ggml-metal/ggml-metal.metal > ggml/src/ggml-metal/ggml-metal-embed.metal + @sed -e '/__embed_ggml-common.h__/r ggml/src/ggml-common.h' -e '/__embed_ggml-common.h__/d' < ggml/src/ggml-metal/ggml-metal.metal > ggml/src/ggml-metal/ggml-metal-embed.metal.tmp + @sed -e '/#include "ggml-metal-impl.h"/r ggml/src/ggml-metal/ggml-metal-impl.h' -e '/#include "ggml-metal-impl.h"/d' < ggml/src/ggml-metal/ggml-metal-embed.metal.tmp > ggml/src/ggml-metal/ggml-metal-embed.metal $(eval TEMP_ASSEMBLY=$(shell mktemp -d)) @echo ".section __DATA, __ggml_metallib" > $(TEMP_ASSEMBLY)/ggml-metal-embed.s @echo ".globl _ggml_metallib_start" >> $(TEMP_ASSEMBLY)/ggml-metal-embed.s diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index d25100693bdda..050161393456e 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -418,246 +418,6 @@ typedef struct { } block_iq4_xs; static_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/64 + QK_K/2, "wrong iq4_xs block size/padding"); -#if defined(GGML_COMMON_DECL_METAL_KARGS) -typedef struct { - int32_t ne00; - int32_t ne01; - int32_t ne02; - int32_t ne03; - uint64_t nb00; - uint64_t nb01; - uint64_t nb02; - uint64_t nb03; - int32_t ne10; - int32_t ne11; - int32_t ne12; - int32_t ne13; - uint64_t nb10; - uint64_t nb11; - uint64_t nb12; - uint64_t nb13; - int32_t ne0; - int32_t ne1; - int32_t ne2; - int32_t ne3; - uint64_t nb0; - uint64_t nb1; - uint64_t nb2; - uint64_t nb3; - int32_t dim; -} ggml_metal_kargs_concat; - -typedef struct { - int32_t ne00; - int32_t ne01; - int32_t ne02; - int32_t ne03; - uint64_t nb00; - uint64_t nb01; - uint64_t nb02; - uint64_t nb03; - int32_t ne10; - int32_t ne11; - int32_t ne12; - int32_t ne13; - uint64_t nb10; - uint64_t nb11; - uint64_t nb12; - uint64_t nb13; - int32_t ne0; - int32_t ne1; - int32_t ne2; - int32_t ne3; - uint64_t nb0; - uint64_t nb1; - uint64_t nb2; - uint64_t nb3; - uint64_t offs; -} ggml_metal_kargs_bin; - -typedef struct { - int32_t ne00; - int32_t ne01; - int32_t ne02; - int32_t ne03; - uint64_t nb00; - uint64_t nb01; - uint64_t nb02; - uint64_t nb03; - int32_t ne0; - int32_t ne1; - int32_t ne2; - int32_t ne3; - uint64_t nb0; - uint64_t nb1; - uint64_t nb2; - uint64_t nb3; -} ggml_metal_kargs_repeat; - -typedef struct { - int64_t ne00; - int64_t ne01; - int64_t ne02; - int64_t ne03; - uint64_t nb00; - uint64_t nb01; - uint64_t nb02; - uint64_t nb03; - int64_t ne0; - int64_t ne1; - int64_t ne2; - int64_t ne3; - uint64_t nb0; - uint64_t nb1; - uint64_t nb2; - uint64_t nb3; -} ggml_metal_kargs_cpy; - -typedef struct { - int32_t ne00; - int32_t ne01; - int32_t ne02; - int32_t ne03; - uint64_t nb00; - uint64_t nb01; - uint64_t nb02; - uint64_t nb03; - int32_t ne0; - int32_t ne1; - int32_t ne2; - int32_t ne3; - uint64_t nb0; - uint64_t nb1; - uint64_t nb2; - uint64_t nb3; - int32_t n_past; - int32_t n_dims; - int32_t n_ctx_orig; - float freq_base; - float freq_scale; - float ext_factor; - float attn_factor; - float beta_fast; - float beta_slow; -} ggml_metal_kargs_rope; - -typedef struct { - int32_t ne01; - int32_t ne02; - int32_t ne03; - uint64_t nb01; - uint64_t nb02; - uint64_t nb03; - int32_t ne11; - int32_t ne_12_2; // assume K and V are same shape - int32_t ne_12_3; - uint64_t nb_12_1; - uint64_t nb_12_2; - uint64_t nb_12_3; - uint64_t nb31; - int32_t ne1; - int32_t ne2; - float scale; - float max_bias; - float m0; - float m1; - uint16_t n_head_log2; - float logit_softcap; -} ggml_metal_kargs_flash_attn_ext; - -typedef struct { - int32_t ne00; - int32_t ne02; - uint64_t nb01; - uint64_t nb02; - uint64_t nb03; - int32_t ne12; - uint64_t nb10; - uint64_t nb11; - uint64_t nb12; - uint64_t nb13; - int32_t ne0; - int32_t ne1; - int16_t r2; - int16_t r3; -} ggml_metal_kargs_mul_mm; - -typedef struct { - int32_t ne00; - int32_t ne01; - int32_t ne02; - uint64_t nb00; - uint64_t nb01; - uint64_t nb02; - uint64_t nb03; - int32_t ne10; - int32_t ne11; - int32_t ne12; - uint64_t nb10; - uint64_t nb11; - uint64_t nb12; - uint64_t nb13; - int32_t ne0; - int32_t ne1; - int16_t r2; - int16_t r3; -} ggml_metal_kargs_mul_mv; - -typedef struct { - int32_t nei0; - int32_t nei1; - uint64_t nbi1; - int32_t ne00; - int32_t ne02; - uint64_t nb01; - uint64_t nb02; - int32_t ne11; - int32_t ne12; - int32_t ne13; - uint64_t nb10; - uint64_t nb11; - uint64_t nb12; - int32_t ne0; - int32_t ne1; -} ggml_metal_kargs_mul_mm_id; - -typedef struct { - int32_t nei0; - int32_t nei1; - uint64_t nbi1; - int32_t ne00; - int32_t ne01; - int32_t ne02; - uint64_t nb00; - uint64_t nb01; - uint64_t nb02; - int32_t ne10; - int32_t ne11; - int32_t ne12; - int32_t ne13; - uint64_t nb10; - uint64_t nb11; - uint64_t nb12; - int32_t ne0; - int32_t ne1; - uint64_t nb1; -} ggml_metal_kargs_mul_mv_id; - -typedef struct { - int32_t ne00; - int32_t ne00_4; - uint64_t nb01; - float eps; -} ggml_metal_kargs_norm; - -typedef struct { - int32_t ne00; - int32_t ne00_4; - uint64_t nb01; - float eps; -} ggml_metal_kargs_rms_norm; -#endif - #endif // GGML_COMMON_DECL #endif // GGML_COMMON_DECL diff --git a/ggml/src/ggml-metal/CMakeLists.txt b/ggml/src/ggml-metal/CMakeLists.txt index e0992c7449b62..b237d79f47ddb 100644 --- a/ggml/src/ggml-metal/CMakeLists.txt +++ b/ggml/src/ggml-metal/CMakeLists.txt @@ -25,9 +25,10 @@ if (GGML_METAL_USE_BF16) add_compile_definitions(GGML_METAL_USE_BF16) endif() -# copy ggml-common.h and ggml-metal.metal to bin directory -configure_file(../ggml-common.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h COPYONLY) -configure_file(ggml-metal.metal ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal COPYONLY) +# copy metal files to bin directory +configure_file(../ggml-common.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h COPYONLY) +configure_file(ggml-metal.metal ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal COPYONLY) +configure_file(ggml-metal-impl.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal-impl.h COPYONLY) if (GGML_METAL_EMBED_LIBRARY) enable_language(ASM) @@ -36,24 +37,27 @@ if (GGML_METAL_EMBED_LIBRARY) set(METALLIB_COMMON "${CMAKE_CURRENT_SOURCE_DIR}/../ggml-common.h") set(METALLIB_SOURCE "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal.metal") + set(METALLIB_IMPL "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal-impl.h") file(MAKE_DIRECTORY "${CMAKE_BINARY_DIR}/autogenerated") # merge ggml-common.h and ggml-metal.metal into a single file - set(METALLIB_EMBED_ASM "${CMAKE_BINARY_DIR}/autogenerated/ggml-metal-embed.s") - set(METALLIB_SOURCE_EMBED "${CMAKE_BINARY_DIR}/autogenerated/ggml-metal-embed.metal") + set(METALLIB_EMBED_ASM "${CMAKE_BINARY_DIR}/autogenerated/ggml-metal-embed.s") + set(METALLIB_SOURCE_EMBED "${CMAKE_BINARY_DIR}/autogenerated/ggml-metal-embed.metal") + set(METALLIB_SOURCE_EMBED_TMP "${CMAKE_BINARY_DIR}/autogenerated/ggml-metal-embed.metal.tmp") add_custom_command( OUTPUT ${METALLIB_EMBED_ASM} COMMAND echo "Embedding Metal library" - COMMAND sed -e '/__embed_ggml-common.h__/r ${METALLIB_COMMON}' -e '/__embed_ggml-common.h__/d' < ${METALLIB_SOURCE} > ${METALLIB_SOURCE_EMBED} + COMMAND sed -e '/__embed_ggml-common.h__/r ${METALLIB_COMMON}' -e '/__embed_ggml-common.h__/d' < ${METALLIB_SOURCE} > ${METALLIB_SOURCE_EMBED_TMP} + COMMAND sed -e '/\#include \"ggml-metal-impl.h\"/r ${METALLIB_IMPL}' -e '/\#include \"ggml-metal-impl.h\"/d' < ${METALLIB_SOURCE_EMBED_TMP} > ${METALLIB_SOURCE_EMBED} COMMAND echo ".section __DATA,__ggml_metallib" > ${METALLIB_EMBED_ASM} COMMAND echo ".globl _ggml_metallib_start" >> ${METALLIB_EMBED_ASM} COMMAND echo "_ggml_metallib_start:" >> ${METALLIB_EMBED_ASM} COMMAND echo ".incbin \\\"${METALLIB_SOURCE_EMBED}\\\"" >> ${METALLIB_EMBED_ASM} COMMAND echo ".globl _ggml_metallib_end" >> ${METALLIB_EMBED_ASM} COMMAND echo "_ggml_metallib_end:" >> ${METALLIB_EMBED_ASM} - DEPENDS ggml-metal.metal ../ggml-common.h + DEPENDS ../ggml-common.h ggml-metal.metal ggml-metal-impl.h COMMENT "Generate assembly for embedded Metal library" ) diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h new file mode 100644 index 0000000000000..53c13549650c8 --- /dev/null +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -0,0 +1,249 @@ +#ifndef GGML_METAL_IMPL +#define GGML_METAL_IMPL + +// kernel argument structs +// +// - element counters (e.g. ne00) typically use int32_t to reduce register usage +// however, be careful from int overflows when using those in the kernel implementation +// +// - strides (e.g. nb00) use uint64_t + +typedef struct { + int32_t ne00; + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne10; + int32_t ne11; + int32_t ne12; + int32_t ne13; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + int32_t ne0; + int32_t ne1; + int32_t ne2; + int32_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; + int32_t dim; +} ggml_metal_kargs_concat; + +typedef struct { + int32_t ne00; + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne10; + int32_t ne11; + int32_t ne12; + int32_t ne13; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + int32_t ne0; + int32_t ne1; + int32_t ne2; + int32_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; + uint64_t offs; +} ggml_metal_kargs_bin; + +typedef struct { + int32_t ne00; + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne0; + int32_t ne1; + int32_t ne2; + int32_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; +} ggml_metal_kargs_repeat; + +typedef struct { + int64_t ne00; + int64_t ne01; + int64_t ne02; + int64_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int64_t ne0; + int64_t ne1; + int64_t ne2; + int64_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; +} ggml_metal_kargs_cpy; + +typedef struct { + int32_t ne00; + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne0; + int32_t ne1; + int32_t ne2; + int32_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; + int32_t n_past; + int32_t n_dims; + int32_t n_ctx_orig; + float freq_base; + float freq_scale; + float ext_factor; + float attn_factor; + float beta_fast; + float beta_slow; +} ggml_metal_kargs_rope; + +typedef struct { + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne11; + int32_t ne_12_2; // assume K and V are same shape + int32_t ne_12_3; + uint64_t nb_12_1; + uint64_t nb_12_2; + uint64_t nb_12_3; + uint64_t nb31; + int32_t ne1; + int32_t ne2; + float scale; + float max_bias; + float m0; + float m1; + uint16_t n_head_log2; + float logit_softcap; +} ggml_metal_kargs_flash_attn_ext; + +typedef struct { + int32_t ne00; + int32_t ne02; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne12; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + int32_t ne0; + int32_t ne1; + int16_t r2; + int16_t r3; +} ggml_metal_kargs_mul_mm; + +typedef struct { + int32_t ne00; + int32_t ne01; + int32_t ne02; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne10; + int32_t ne11; + int32_t ne12; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + int32_t ne0; + int32_t ne1; + int16_t r2; + int16_t r3; +} ggml_metal_kargs_mul_mv; + +typedef struct { + int32_t nei0; + int32_t nei1; + uint64_t nbi1; + int32_t ne00; + int32_t ne02; + uint64_t nb01; + uint64_t nb02; + int32_t ne11; + int32_t ne12; + int32_t ne13; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + int32_t ne0; + int32_t ne1; +} ggml_metal_kargs_mul_mm_id; + +typedef struct { + int32_t nei0; + int32_t nei1; + uint64_t nbi1; + int32_t ne00; + int32_t ne01; + int32_t ne02; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + int32_t ne10; + int32_t ne11; + int32_t ne12; + int32_t ne13; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + int32_t ne0; + int32_t ne1; + uint64_t nb1; +} ggml_metal_kargs_mul_mv_id; + +typedef struct { + int32_t ne00; + int32_t ne00_4; + uint64_t nb01; + float eps; +} ggml_metal_kargs_norm; + +typedef struct { + int32_t ne00; + int32_t ne00_4; + uint64_t nb01; + float eps; +} ggml_metal_kargs_rms_norm; + +#endif // GGML_METAL_IMPL diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index b683d54318af4..58fee4bfd1296 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -2,10 +2,7 @@ #import "ggml-impl.h" #import "ggml-backend-impl.h" - -#define GGML_COMMON_DECL_C -#define GGML_COMMON_DECL_METAL_KARGS -#include "ggml-common.h" +#import "ggml-metal-impl.h" #import diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 6bdf4e4cc3e09..86fdf1c18cfb6 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1,5 +1,4 @@ #define GGML_COMMON_DECL_METAL -#define GGML_COMMON_DECL_METAL_KARGS #define GGML_COMMON_IMPL_METAL #if defined(GGML_METAL_EMBED_LIBRARY) __embed_ggml-common.h__ @@ -7,6 +6,7 @@ __embed_ggml-common.h__ // TODO: this should not be a relative path, but can't figure out how to set Metal include paths in Package.swift #include "../ggml-common.h" #endif +#include "ggml-metal-impl.h" #include