diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index 0bf7fe9f9237a..819f31c8a300c 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -495,22 +495,17 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv(ggml_metal_library_ case GGML_TYPE_F16: case GGML_TYPE_BF16: { - if (ne00 == 4) { + if (ne00 < 32) { nsg = 1; nr0 = 32; - nr1 = 4; - suffix = "_c4"; - } else if (ne00 % 4 == 0) { - nsg = N_SG_F; - nr0 = N_R0_F; nr1 = 1; - smem = 32*sizeof(float)*N_R0_F; - suffix = "_4"; + suffix = "_short"; } else { - nsg = N_SG_F; - nr0 = N_R0_F; + nsg = std::min(4, (ne00 + 127) / 128); + nr0 = 2; nr1 = 1; - smem = 32*sizeof(float)*N_R0_F; + smem = 32*sizeof(float)*nr0; + suffix = ne00 % 4 == 0 ? "_4" : ""; } } break; case GGML_TYPE_Q4_0: @@ -727,18 +722,11 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_libra case GGML_TYPE_F16: case GGML_TYPE_BF16: { - if (ne00 % 4 == 0) { - nsg = N_SG_F; - nr0 = N_R0_F; - nr1 = 1; - smem = 32*sizeof(float)*N_R0_F; - suffix = "_4"; - } else { - nsg = N_SG_F; - nr0 = N_R0_F; - nr1 = 1; - smem = 32*sizeof(float)*N_R0_F; - } + nsg = std::min(4, (ne00 + 127) / 128); + nr0 = 2; + nr1 = 1; + smem = 32*sizeof(float)*nr0; + suffix = ne00 % 4 == 0 ? "_4" : ""; } break; case GGML_TYPE_Q4_0: { diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index d355c6dfc7526..88c98423ebec0 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -8,9 +8,6 @@ // // TODO: for optimal performance, become function of the device and work size -#define N_R0_F 2 -#define N_SG_F 4 - #define N_R0_Q4_0 4 #define N_SG_Q4_0 2 @@ -352,6 +349,7 @@ typedef struct { uint64_t nb13; int32_t ne0; int32_t ne1; + int32_t nr0; int16_t r2; int16_t r3; } ggml_metal_kargs_mul_mv; @@ -427,6 +425,7 @@ typedef struct { int32_t ne0; int32_t ne1; uint64_t nb1; + int32_t nr0; } ggml_metal_kargs_mul_mv_id; // NORM diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index d7267a6aedfff..e85a223c01dc3 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -1565,6 +1565,12 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) { } else { ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv(lib, op); + const int nr0 = ggml_metal_pipeline_get_nr0(pipeline); + const int nr1 = ggml_metal_pipeline_get_nr1(pipeline); + const int nsg = ggml_metal_pipeline_get_nsg(pipeline); + + const size_t smem = ggml_metal_pipeline_get_smem(pipeline); + ggml_metal_kargs_mul_mv args = { /*.ne00 =*/ ne00, /*.ne01 =*/ ne01, @@ -1582,16 +1588,11 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) { /*.nb13 =*/ nb13, /*.ne0 =*/ ne0, /*.ne1 =*/ ne1, + /*.nr0 =*/ nr0, /*.r2 =*/ r2, /*.r3 =*/ r3, }; - const int nr0 = ggml_metal_pipeline_get_nr0(pipeline); - const int nr1 = ggml_metal_pipeline_get_nr1(pipeline); - const int nsg = ggml_metal_pipeline_get_nsg(pipeline); - - const size_t smem = ggml_metal_pipeline_get_smem(pipeline); - ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); @@ -1758,6 +1759,14 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) { ggml_metal_encoder_dispatch_threadgroups(enc, (ne21 + 31)/32, (ne01 + 63)/64, ne02, 128, 1, 1); } } else { + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv_id(lib, op); + + const int nr0 = ggml_metal_pipeline_get_nr0(pipeline); + const int nr1 = ggml_metal_pipeline_get_nr1(pipeline); + const int nsg = ggml_metal_pipeline_get_nsg(pipeline); + + const size_t smem = ggml_metal_pipeline_get_smem(pipeline); + ggml_metal_kargs_mul_mv_id args = { /*.nei0 =*/ ne20, /*.nei1 =*/ ne21, @@ -1778,16 +1787,9 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) { /*.ne0 =*/ ne0, /*.ne1 =*/ ne1, /*.nb1 =*/ nb1, + /*.nr0 =*/ nr0, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv_id(lib, op); - - const int nr0 = ggml_metal_pipeline_get_nr0(pipeline); - const int nr1 = ggml_metal_pipeline_get_nr1(pipeline); - const int nsg = ggml_metal_pipeline_get_nsg(pipeline); - - const size_t smem = ggml_metal_pipeline_get_smem(pipeline); - if (ggml_is_quantized(op->src[0]->type)) { GGML_ASSERT(ne00 >= nsg*nr0); } diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 0271fd5b25d73..96df6f0ce62de 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -3531,7 +3531,25 @@ void kernel_mul_mv_t_t_impl( helper_mv_reduce_and_write(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem); } -template +template +void kernel_mul_mv_t_t_disp( + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { + switch (args.nr0) { + //case 1: kernel_mul_mv_t_t_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break; + case 2: kernel_mul_mv_t_t_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break; + //case 3: kernel_mul_mv_t_t_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break; + //case 4: kernel_mul_mv_t_t_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break; + } +} + +template kernel void kernel_mul_mv_t_t( constant ggml_metal_kargs_mul_mv & args, device const char * src0, @@ -3541,17 +3559,17 @@ kernel void kernel_mul_mv_t_t( uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_t_t_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); + kernel_mul_mv_t_t_disp(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } -typedef decltype(kernel_mul_mv_t_t) mul_mv_t_t; +typedef decltype(kernel_mul_mv_t_t) mul_mv_t_t; -template [[host_name("kernel_mul_mv_f32_f32")]] kernel mul_mv_t_t kernel_mul_mv_t_t; -template [[host_name("kernel_mul_mv_f16_f32")]] kernel mul_mv_t_t kernel_mul_mv_t_t; -template [[host_name("kernel_mul_mv_f16_f16")]] kernel mul_mv_t_t kernel_mul_mv_t_t; +template [[host_name("kernel_mul_mv_f32_f32")]] kernel mul_mv_t_t kernel_mul_mv_t_t; +template [[host_name("kernel_mul_mv_f16_f32")]] kernel mul_mv_t_t kernel_mul_mv_t_t; +template [[host_name("kernel_mul_mv_f16_f16")]] kernel mul_mv_t_t kernel_mul_mv_t_t; #if defined(GGML_METAL_HAS_BF16) -template [[host_name("kernel_mul_mv_bf16_f32")]] kernel mul_mv_t_t kernel_mul_mv_t_t; -template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t_t kernel_mul_mv_t_t; +template [[host_name("kernel_mul_mv_bf16_f32")]] kernel mul_mv_t_t kernel_mul_mv_t_t; +template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t_t kernel_mul_mv_t_t; #endif template @@ -3637,7 +3655,25 @@ void kernel_mul_mv_t_t_4_impl( helper_mv_reduce_and_write(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem); } -template +template +void kernel_mul_mv_t_t_4_disp( + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { + switch (args.nr0) { + //case 1: kernel_mul_mv_t_t_4_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break; + case 2: kernel_mul_mv_t_t_4_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break; + //case 3: kernel_mul_mv_t_t_4_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break; + //case 4: kernel_mul_mv_t_t_4_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break; + }; +} + +template kernel void kernel_mul_mv_t_t_4( constant ggml_metal_kargs_mul_mv & args, device const char * src0, @@ -3647,23 +3683,21 @@ kernel void kernel_mul_mv_t_t_4( uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_t_t_4_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); + kernel_mul_mv_t_t_4_disp(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } -typedef decltype(kernel_mul_mv_t_t_4) mul_mv_t_t_4; +typedef decltype(kernel_mul_mv_t_t_4) mul_mv_t_t_4; -template [[host_name("kernel_mul_mv_f32_f32_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4; -template [[host_name("kernel_mul_mv_f16_f32_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4; -template [[host_name("kernel_mul_mv_f16_f16_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4; +template [[host_name("kernel_mul_mv_f32_f32_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4; +template [[host_name("kernel_mul_mv_f16_f32_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4; +template [[host_name("kernel_mul_mv_f16_f16_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4; #if defined(GGML_METAL_HAS_BF16) -template [[host_name("kernel_mul_mv_bf16_f32_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4; -template [[host_name("kernel_mul_mv_bf16_bf16_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4; +template [[host_name("kernel_mul_mv_bf16_f32_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4; +template [[host_name("kernel_mul_mv_bf16_bf16_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4; #endif -#define N_MV_T_T 4 - -template -void kernel_mul_mv_c4_impl( +template +void kernel_mul_mv_t_t_short_impl( args_t args, device const char * src0, device const char * src1, @@ -3671,7 +3705,7 @@ void kernel_mul_mv_c4_impl( uint3 tgpig, ushort tiisg) { const int r0 = tgpig.x*32 + tiisg; - const int rb = tgpig.y*N_MV_T_T; + const int r1 = tgpig.y; const int im = tgpig.z; if (r0 >= args.ne01) { @@ -3683,33 +3717,32 @@ void kernel_mul_mv_c4_impl( const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - device const T04 * x = (device const T04 *) (src0 + offset0); + device const T0 * x = (device const T0 *) (src0 + offset0); device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1; - for (int row = 0; row < N_MV_T_T; ++row) { - int r1 = rb + row; - if (r1 >= args.ne11) { - break; - } + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; - const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + device const T1 * y = (device const T1 *) (src1 + offset1); - device const T14 * y = (device const T14 *) (src1 + offset1); + float res = 0.0f; - dst_f32[(uint64_t)r1*args.ne0 + r0] = dot((float4) x[0], (float4) y[0]); + for (int i = 0; i < args.ne00; ++i) { + res += (float) x[i] * (float) y[i]; } + + dst_f32[(uint64_t)r1*args.ne0 + r0] = res; } -template -kernel void kernel_mul_mv_c4( +template +kernel void kernel_mul_mv_t_t_short( constant ggml_metal_kargs_mul_mv & args, device const char * src0, device const char * src1, device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]]) { - kernel_mul_mv_c4_impl( + kernel_mul_mv_t_t_short_impl( args, src0, src1, @@ -3718,14 +3751,14 @@ kernel void kernel_mul_mv_c4( tiisg); } -typedef decltype(kernel_mul_mv_c4) mul_mv_c4_t; +typedef decltype(kernel_mul_mv_t_t_short) mul_mv_t_t_short_t; -template [[host_name("kernel_mul_mv_f32_f32_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4; -template [[host_name("kernel_mul_mv_f16_f32_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4; -template [[host_name("kernel_mul_mv_f16_f16_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4; +template [[host_name("kernel_mul_mv_f32_f32_short")]] kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short; +template [[host_name("kernel_mul_mv_f16_f32_short")]] kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short; +template [[host_name("kernel_mul_mv_f16_f16_short")]] kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short; #if defined(GGML_METAL_HAS_BF16) -template [[host_name("kernel_mul_mv_bf16_f32_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4; -template [[host_name("kernel_mul_mv_bf16_bf16_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4; +template [[host_name("kernel_mul_mv_bf16_f32_short")]] kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short; +template [[host_name("kernel_mul_mv_bf16_bf16_short")]] kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short; #endif static float rope_yarn_ramp(const float low, const float high, const int i0) { @@ -8458,7 +8491,7 @@ template [[host_name("kernel_mul_mm_id_iq4_xs_f16")]] kernel mul_mm_id kernel_m // matrix-vector multiplication // -typedef void (kernel_mul_mv_impl_t)( +typedef void (kernel_mul_mv_disp_t)( ggml_metal_kargs_mul_mv args, device const char * src0, device const char * src1, @@ -8466,7 +8499,7 @@ typedef void (kernel_mul_mv_impl_t)( uint3 tgpig, ushort tiisg); -typedef void (kernel_mul_mv2_impl_t)( +typedef void (kernel_mul_mv2_disp_t)( ggml_metal_kargs_mul_mv args, device const char * src0, device const char * src1, @@ -8476,7 +8509,7 @@ typedef void (kernel_mul_mv2_impl_t)( ushort tiisg, ushort sgitg); -template +template void mmv_fn( ggml_metal_kargs_mul_mv args, device const char * src0, @@ -8487,10 +8520,10 @@ void mmv_fn( ushort tiitg, ushort tiisg, ushort sgitg) { - impl_fn(args, src0, src1, dst, tgpig, tiisg); + disp_fn(args, src0, src1, dst, tgpig, tiisg); } -template +template void mmv_fn( ggml_metal_kargs_mul_mv args, device const char * src0, @@ -8501,12 +8534,12 @@ void mmv_fn( ushort tiitg, ushort tiisg, ushort sgitg) { - impl_fn(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); + disp_fn(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } -typedef decltype(mmv_fn>) mul_mv_impl_fn_t; +typedef decltype(mmv_fn>) mul_mv_disp_fn_t; -template +template kernel void kernel_mul_mv_id( constant ggml_metal_kargs_mul_mv_id & args, device const char * src0s, @@ -8553,11 +8586,12 @@ kernel void kernel_mul_mv_id( /*.nb13 =*/ args.nb12, // ne12 == 1 /*.ne0 =*/ args.ne0, /*.ne1 =*/ 1, // args.ne1, + /*.nr0 =*/ args.nr0, /*.r2 =*/ 1, /*.r3 =*/ 1, }; - impl_fn( + disp_fn( args0, /* src0 */ src0_cur, /* src1 */ src1_cur, @@ -8569,19 +8603,19 @@ kernel void kernel_mul_mv_id( sgitg); } -typedef decltype(kernel_mul_mv_id>>) kernel_mul_mv_id_t; +typedef decltype(kernel_mul_mv_id>>) kernel_mul_mv_id_t; -typedef decltype(kernel_mul_mv_id>>) kernel_mul_mv_id_4_t; +typedef decltype(kernel_mul_mv_id>>) kernel_mul_mv_id_4_t; -template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; #if defined(GGML_METAL_HAS_BF16) -template [[host_name("kernel_mul_mv_id_bf16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_bf16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; #endif -template [[host_name("kernel_mul_mv_id_f32_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_f16_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_f32_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_f16_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id>>; #if defined(GGML_METAL_HAS_BF16) -template [[host_name("kernel_mul_mv_id_bf16_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_bf16_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id>>; #endif template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>;