@@ -5730,9 +5730,9 @@ void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4
57305730}
57315731
57325732template <typename block_q, short nl, void (*dequantize_func)(device const block_q *, short , thread float4x4 &)>
5733- kernel void kernel_get_rows (
5733+ kernel void kernel_get_rows_q (
57345734 device const void * src0,
5735- device const char * src1,
5735+ device const void * src1,
57365736 device float * dst,
57375737 constant int64_t & ne00,
57385738 constant uint64_t & nb01,
@@ -5745,55 +5745,24 @@ kernel void kernel_get_rows(
57455745 uint3 tgpig[[threadgroup_position_in_grid]],
57465746 uint tiitg[[thread_index_in_threadgroup]],
57475747 uint3 tptg [[threads_per_threadgroup]]) {
5748- // const int64_t i = tgpig;
5749- // const int64_t r = ((device int32_t *) src1)[i];
5750-
57515748 const int64_t i10 = tgpig.x ;
57525749 const int64_t i11 = tgpig.y ;
57535750
5754- const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0 ];
5751+ const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0 ];
57555752
57565753 const int64_t i02 = i11;
57575754
57585755 for (int64_t ind = tiitg; ind < ne00/16 ; ind += tptg.x ) {
57595756 float4x4 temp;
5760- dequantize_func (
5761- ((device const block_q *) ((device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp);
5757+ dequantize_func (((device const block_q *) ((const device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp);
57625758 *(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp;
57635759 }
57645760}
57655761
5766- kernel void kernel_get_rows_f32 (
5767- device const void * src0,
5768- device const char * src1,
5769- device float * dst,
5770- constant int64_t & ne00,
5771- constant uint64_t & nb01,
5772- constant uint64_t & nb02,
5773- constant int64_t & ne10,
5774- constant uint64_t & nb10,
5775- constant uint64_t & nb11,
5776- constant uint64_t & nb1,
5777- constant uint64_t & nb2,
5778- uint3 tgpig[[threadgroup_position_in_grid]],
5779- uint tiitg[[thread_index_in_threadgroup]],
5780- uint3 tptg [[threads_per_threadgroup]]) {
5781- const int64_t i10 = tgpig.x ;
5782- const int64_t i11 = tgpig.y ;
5783-
5784- const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0 ];
5785-
5786- const int64_t i02 = i11;
5787-
5788- for (int ind = tiitg; ind < ne00; ind += tptg.x ) {
5789- ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
5790- ((device float *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
5791- }
5792- }
5793-
5794- kernel void kernel_get_rows_f16 (
5762+ template <typename T>
5763+ kernel void kernel_get_rows_f (
57955764 device const void * src0,
5796- device const char * src1,
5765+ device const void * src1,
57975766 device float * dst,
57985767 constant int64_t & ne00,
57995768 constant uint64_t & nb01,
@@ -5809,19 +5778,19 @@ kernel void kernel_get_rows_f16(
58095778 const int64_t i10 = tgpig.x ;
58105779 const int64_t i11 = tgpig.y ;
58115780
5812- const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0 ];
5781+ const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0 ];
58135782
58145783 const int64_t i02 = i11;
58155784
58165785 for (int ind = tiitg; ind < ne00; ind += tptg.x ) {
5817- ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
5818- ((device half *) ((device char *) src0 + r*nb01 + i02*nb02 ))[ind];
5786+ (( device float *) (( device char *) dst + i11*nb2 + i10*nb1))[ind] =
5787+ (( const device T *) ((const device char *) src0 + i02*nb02 + r*nb01 ))[ind];
58195788 }
58205789}
58215790
58225791kernel void kernel_get_rows_i32 (
58235792 device const void * src0,
5824- device const char * src1,
5793+ device const void * src1,
58255794 device int32_t * dst,
58265795 constant int64_t & ne00,
58275796 constant uint64_t & nb01,
@@ -5837,13 +5806,13 @@ kernel void kernel_get_rows_i32(
58375806 const int64_t i10 = tgpig.x ;
58385807 const int64_t i11 = tgpig.y ;
58395808
5840- const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0 ];
5809+ const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0 ];
58415810
58425811 const int64_t i02 = i11;
58435812
58445813 for (int ind = tiitg; ind < ne00; ind += tptg.x ) {
5845- ((device int32_t *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
5846- (( device int32_t *) ((device char *) src0 + r*nb01 + i02*nb02 ))[ind];
5814+ (( device int32_t *) (( device char *) dst + i11*nb2 + i10*nb1))[ind] =
5815+ (( const device int32_t *) ((const device char *) src0 + i02*nb02 + r*nb01 ))[ind];
58475816 }
58485817}
58495818
@@ -6237,41 +6206,33 @@ kernel void kernel_mul_mm_id(
62376206// get rows
62386207//
62396208
6240- typedef void (get_rows_t )(
6241- device const void * src0,
6242- device const char * src1,
6243- device float * dst,
6244- constant int64_t & ne00,
6245- constant uint64_t & nb01,
6246- constant uint64_t & nb02,
6247- constant int64_t & ne10,
6248- constant uint64_t & nb10,
6249- constant uint64_t & nb11,
6250- constant uint64_t & nb1,
6251- constant uint64_t & nb2,
6252- uint3, uint, uint3);
6253-
6254- // template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
6255- // template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
6256- template [[host_name(" kernel_get_rows_q4_0" )]] kernel get_rows_t kernel_get_rows<block_q4_0, 2 , dequantize_q4_0>;
6257- template [[host_name(" kernel_get_rows_q4_1" )]] kernel get_rows_t kernel_get_rows<block_q4_1, 2 , dequantize_q4_1>;
6258- template [[host_name(" kernel_get_rows_q5_0" )]] kernel get_rows_t kernel_get_rows<block_q5_0, 2 , dequantize_q5_0>;
6259- template [[host_name(" kernel_get_rows_q5_1" )]] kernel get_rows_t kernel_get_rows<block_q5_1, 2 , dequantize_q5_1>;
6260- template [[host_name(" kernel_get_rows_q8_0" )]] kernel get_rows_t kernel_get_rows<block_q8_0, 2 , dequantize_q8_0>;
6261- template [[host_name(" kernel_get_rows_q2_K" )]] kernel get_rows_t kernel_get_rows<block_q2_K, QK_NL, dequantize_q2_K>;
6262- template [[host_name(" kernel_get_rows_q3_K" )]] kernel get_rows_t kernel_get_rows<block_q3_K, QK_NL, dequantize_q3_K>;
6263- template [[host_name(" kernel_get_rows_q4_K" )]] kernel get_rows_t kernel_get_rows<block_q4_K, QK_NL, dequantize_q4_K>;
6264- template [[host_name(" kernel_get_rows_q5_K" )]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
6265- template [[host_name(" kernel_get_rows_q6_K" )]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
6266- template [[host_name(" kernel_get_rows_iq2_xxs" )]] kernel get_rows_t kernel_get_rows<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
6267- template [[host_name(" kernel_get_rows_iq2_xs" )]] kernel get_rows_t kernel_get_rows<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
6268- template [[host_name(" kernel_get_rows_iq3_xxs" )]] kernel get_rows_t kernel_get_rows<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
6269- template [[host_name(" kernel_get_rows_iq3_s" )]] kernel get_rows_t kernel_get_rows<block_iq3_s, QK_NL, dequantize_iq3_s>;
6270- template [[host_name(" kernel_get_rows_iq2_s" )]] kernel get_rows_t kernel_get_rows<block_iq2_s, QK_NL, dequantize_iq2_s>;
6271- template [[host_name(" kernel_get_rows_iq1_s" )]] kernel get_rows_t kernel_get_rows<block_iq1_s, QK_NL, dequantize_iq1_s>;
6272- template [[host_name(" kernel_get_rows_iq1_m" )]] kernel get_rows_t kernel_get_rows<block_iq1_m, QK_NL, dequantize_iq1_m>;
6273- template [[host_name(" kernel_get_rows_iq4_nl" )]] kernel get_rows_t kernel_get_rows<block_iq4_nl, 2 , dequantize_iq4_nl>;
6274- template [[host_name(" kernel_get_rows_iq4_xs" )]] kernel get_rows_t kernel_get_rows<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
6209+ typedef decltype (kernel_get_rows_f<float >) get_rows_f_t;
6210+
6211+ template [[host_name(" kernel_get_rows_f32" )]] kernel get_rows_f_t kernel_get_rows_f<float >;
6212+ template [[host_name(" kernel_get_rows_f16" )]] kernel get_rows_f_t kernel_get_rows_f<half>;
6213+ template [[host_name(" kernel_get_rows_bf16" )]] kernel get_rows_f_t kernel_get_rows_f<bfloat>;
6214+
6215+ typedef decltype (kernel_get_rows_q<block_q4_0, 2 , dequantize_q4_0>) get_rows_q_t;
6216+
6217+ template [[host_name(" kernel_get_rows_q4_0" )]] kernel get_rows_q_t kernel_get_rows_q<block_q4_0, 2 , dequantize_q4_0>;
6218+ template [[host_name(" kernel_get_rows_q4_1" )]] kernel get_rows_q_t kernel_get_rows_q<block_q4_1, 2 , dequantize_q4_1>;
6219+ template [[host_name(" kernel_get_rows_q5_0" )]] kernel get_rows_q_t kernel_get_rows_q<block_q5_0, 2 , dequantize_q5_0>;
6220+ template [[host_name(" kernel_get_rows_q5_1" )]] kernel get_rows_q_t kernel_get_rows_q<block_q5_1, 2 , dequantize_q5_1>;
6221+ template [[host_name(" kernel_get_rows_q8_0" )]] kernel get_rows_q_t kernel_get_rows_q<block_q8_0, 2 , dequantize_q8_0>;
6222+ template [[host_name(" kernel_get_rows_q2_K" )]] kernel get_rows_q_t kernel_get_rows_q<block_q2_K, QK_NL, dequantize_q2_K>;
6223+ template [[host_name(" kernel_get_rows_q3_K" )]] kernel get_rows_q_t kernel_get_rows_q<block_q3_K, QK_NL, dequantize_q3_K>;
6224+ template [[host_name(" kernel_get_rows_q4_K" )]] kernel get_rows_q_t kernel_get_rows_q<block_q4_K, QK_NL, dequantize_q4_K>;
6225+ template [[host_name(" kernel_get_rows_q5_K" )]] kernel get_rows_q_t kernel_get_rows_q<block_q5_K, QK_NL, dequantize_q5_K>;
6226+ template [[host_name(" kernel_get_rows_q6_K" )]] kernel get_rows_q_t kernel_get_rows_q<block_q6_K, QK_NL, dequantize_q6_K>;
6227+ template [[host_name(" kernel_get_rows_iq2_xxs" )]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
6228+ template [[host_name(" kernel_get_rows_iq2_xs" )]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
6229+ template [[host_name(" kernel_get_rows_iq3_xxs" )]] kernel get_rows_q_t kernel_get_rows_q<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
6230+ template [[host_name(" kernel_get_rows_iq3_s" )]] kernel get_rows_q_t kernel_get_rows_q<block_iq3_s, QK_NL, dequantize_iq3_s>;
6231+ template [[host_name(" kernel_get_rows_iq2_s" )]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_s, QK_NL, dequantize_iq2_s>;
6232+ template [[host_name(" kernel_get_rows_iq1_s" )]] kernel get_rows_q_t kernel_get_rows_q<block_iq1_s, QK_NL, dequantize_iq1_s>;
6233+ template [[host_name(" kernel_get_rows_iq1_m" )]] kernel get_rows_q_t kernel_get_rows_q<block_iq1_m, QK_NL, dequantize_iq1_m>;
6234+ template [[host_name(" kernel_get_rows_iq4_nl" )]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_nl, 2 , dequantize_iq4_nl>;
6235+ template [[host_name(" kernel_get_rows_iq4_xs" )]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
62756236
62766237//
62776238// matrix-matrix multiplication
0 commit comments