@@ -7828,8 +7828,8 @@ kernel void kernel_mul_mm(
78287828 ushort tiitg[[thread_index_in_threadgroup]],
78297829 ushort sgitg[[simdgroup_index_in_threadgroup]]) {
78307830
7831- threadgroup T * sa = (threadgroup T *)(shmem);
7832- threadgroup float * sb = (threadgroup float *)(shmem + 4096 );
7831+ threadgroup T * sa = (threadgroup T *)(shmem);
7832+ threadgroup half * sb = (threadgroup half *)(shmem + 4096 );
78337833
78347834 const int r0 = tgpig.y ;
78357835 const int r1 = tgpig.x ;
@@ -7844,7 +7844,7 @@ kernel void kernel_mul_mm(
78447844 const short thread_col = ((short )tiitg/THREAD_PER_COL) < n_cols ? ((short )tiitg/THREAD_PER_COL) : n_cols - 1 ;
78457845
78467846 simdgroup_T8x8 ma[4 ];
7847- simdgroup_float8x8 mb[2 ];
7847+ simdgroup_half8x8 mb[2 ];
78487848 simdgroup_float8x8 mc[8 ];
78497849
78507850 for (short i = 0 ; i < 8 ; i++){
@@ -7882,7 +7882,7 @@ kernel void kernel_mul_mm(
78827882 + (tiitg/THREAD_PER_ROW)%8 + (i&7 )*8 ) = temp_a[i/4 ][i%4 ];
78837883 }
78847884
7885- *(threadgroup float2x4 *)(sb + 32 *8 *(tiitg%THREAD_PER_COL) + 8 *(tiitg/THREAD_PER_COL)) = (float2x4 )(*((device U2x4 *) y));
7885+ *(threadgroup half2x4 *)(sb + 32 *8 *(tiitg%THREAD_PER_COL) + 8 *(tiitg/THREAD_PER_COL)) = (half2x4 )(*((device U2x4 *) y));
78867886
78877887 il = (il + 2 < nl) ? il + 2 : il % 2 ;
78887888 x = (il < 2 ) ? x + (2 + nl - 1 )/nl : x;
@@ -7891,8 +7891,8 @@ kernel void kernel_mul_mm(
78917891 threadgroup_barrier (mem_flags::mem_threadgroup);
78927892
78937893 // load matrices from threadgroup memory and conduct outer products
7894- threadgroup const T * lsma = (sa + THREAD_MAT_M*SG_MAT_SIZE*(sgitg%2 ));
7895- threadgroup const float * lsmb = (sb + THREAD_MAT_N*SG_MAT_SIZE*(sgitg/2 ));
7894+ threadgroup const T * lsma = (sa + THREAD_MAT_M*SG_MAT_SIZE*(sgitg%2 ));
7895+ threadgroup const half * lsmb = (sb + THREAD_MAT_N*SG_MAT_SIZE*(sgitg/2 ));
78967896
78977897 #pragma unroll(4)
78987898 for (short ik = 0 ; ik < BLOCK_SIZE_K/8 ; ik++) {
@@ -8025,7 +8025,7 @@ template [[host_name("kernel_mul_mm_id_map0_ne20_8" )]] kernel kernel_mul_mm_id_
80258025template [[host_name(" kernel_mul_mm_id_map0_ne20_10" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<10 >;
80268026template [[host_name(" kernel_mul_mm_id_map0_ne20_16" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<16 >;
80278027
8028- template <typename T, typename T4x4, typename simdgroup_T8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short , thread T4x4 &)>
8028+ template <typename T, typename T4x4, typename simdgroup_T8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short , thread T4x4 &), typename U, typename U2x4 >
80298029kernel void kernel_mul_mm_id (
80308030 constant ggml_metal_kargs_mul_mm_id & args,
80318031 device const char * src0,
@@ -8085,7 +8085,7 @@ kernel void kernel_mul_mm_id(
80858085 device const block_q * x = (device const block_q *)(src0
80868086 + args.nb01 *(r0*BLOCK_SIZE_M + thread_row) + offset0) + offset1;
80878087
8088- device const float * y = (device const float *)(src1
8088+ device const U * y = (device const U *)(src1
80898089 + args.nb13 *i13
80908090 + args.nb12 *i12
80918091 + args.nb11 *i11
@@ -8105,7 +8105,7 @@ kernel void kernel_mul_mm_id(
81058105 + (tiitg/THREAD_PER_ROW)%8 + (i&7 )*8 ) = temp_a[i/4 ][i%4 ];
81068106 }
81078107
8108- *(threadgroup half2x4 *)(sb + 32 *8 *(tiitg%THREAD_PER_COL) + 8 *(tiitg/THREAD_PER_COL)) = (half2x4)(*((device float2x4 *) y));
8108+ *(threadgroup half2x4 *)(sb + 32 *8 *(tiitg%THREAD_PER_COL) + 8 *(tiitg/THREAD_PER_COL)) = (half2x4)(*((device U2x4 *) y));
81098109
81108110 il = (il + 2 < nl) ? il + 2 : il % 2 ;
81118111 x = (il < 2 ) ? x + (2 + nl - 1 )/nl : x;
@@ -8306,34 +8306,59 @@ template [[host_name("kernel_mul_mm_iq4_xs_f16")]] kernel mul_mm_t kernel_mul_m
83068306// indirect matrix-matrix multiplication
83078307//
83088308
8309- typedef decltype (kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, float4x4, 1 , dequantize_f32>) mul_mm_id;
8309+ typedef decltype (kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, float4x4, 1 , dequantize_f32, float , float2x4 >) mul_mm_id;
83108310
8311- template [[host_name(" kernel_mul_mm_id_f32_f16 " )]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, float4x4, 1 , dequantize_f32>;
8312- template [[host_name(" kernel_mul_mm_id_f16_f16 " )]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half4x4, 1 , dequantize_f16>;
8311+ template [[host_name(" kernel_mul_mm_id_f32_f32 " )]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, float4x4, 1 , dequantize_f32, float , float2x4 >;
8312+ template [[host_name(" kernel_mul_mm_id_f16_f32 " )]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half4x4, 1 , dequantize_f16, float , float2x4 >;
83138313#if defined(GGML_METAL_HAS_BF16)
8314- template [[host_name(" kernel_mul_mm_id_bf16_f16 " )]] kernel mul_mm_id kernel_mul_mm_id<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat4x4, 1 , dequantize_bf16>;
8314+ template [[host_name(" kernel_mul_mm_id_bf16_f32 " )]] kernel mul_mm_id kernel_mul_mm_id<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat4x4, 1 , dequantize_bf16, float , float2x4 >;
83158315#endif
8316- template [[host_name(" kernel_mul_mm_id_q4_0_f16" )]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q4_0, 2 , dequantize_q4_0>;
8317- template [[host_name(" kernel_mul_mm_id_q4_1_f16" )]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q4_1, 2 , dequantize_q4_1>;
8318- template [[host_name(" kernel_mul_mm_id_q5_0_f16" )]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q5_0, 2 , dequantize_q5_0>;
8319- template [[host_name(" kernel_mul_mm_id_q5_1_f16" )]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q5_1, 2 , dequantize_q5_1>;
8320- template [[host_name(" kernel_mul_mm_id_q8_0_f16" )]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q8_0, 2 , dequantize_q8_0>;
8321- template [[host_name(" kernel_mul_mm_id_mxfp4_f16" )]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_mxfp4, 2 , dequantize_mxfp4>;
8322- template [[host_name(" kernel_mul_mm_id_q2_K_f16" )]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K>;
8323- template [[host_name(" kernel_mul_mm_id_q3_K_f16" )]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K>;
8324- template [[host_name(" kernel_mul_mm_id_q4_K_f16" )]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K>;
8325- template [[host_name(" kernel_mul_mm_id_q5_K_f16" )]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K>;
8326- template [[host_name(" kernel_mul_mm_id_q6_K_f16" )]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K>;
8327- template [[host_name(" kernel_mul_mm_id_iq2_xxs_f16" )]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
8328- template [[host_name(" kernel_mul_mm_id_iq2_xs_f16" )]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs>;
8329- template [[host_name(" kernel_mul_mm_id_iq3_xxs_f16" )]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
8330- template [[host_name(" kernel_mul_mm_id_iq3_s_f16" )]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s>;
8331- template [[host_name(" kernel_mul_mm_id_iq2_s_f16" )]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s>;
8332- template [[host_name(" kernel_mul_mm_id_iq1_s_f16" )]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s>;
8333- template [[host_name(" kernel_mul_mm_id_iq1_m_f16" )]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m>;
8334- template [[host_name(" kernel_mul_mm_id_iq4_nl_f16" )]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq4_nl, 2 , dequantize_iq4_nl>;
8335- template [[host_name(" kernel_mul_mm_id_iq4_xs_f16" )]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs>;
8336-
8316+ template [[host_name(" kernel_mul_mm_id_q4_0_f32" )]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q4_0, 2 , dequantize_q4_0, float , float2x4>;
8317+ template [[host_name(" kernel_mul_mm_id_q4_1_f32" )]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q4_1, 2 , dequantize_q4_1, float , float2x4>;
8318+ template [[host_name(" kernel_mul_mm_id_q5_0_f32" )]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q5_0, 2 , dequantize_q5_0, float , float2x4>;
8319+ template [[host_name(" kernel_mul_mm_id_q5_1_f32" )]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q5_1, 2 , dequantize_q5_1, float , float2x4>;
8320+ template [[host_name(" kernel_mul_mm_id_q8_0_f32" )]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q8_0, 2 , dequantize_q8_0, float , float2x4>;
8321+ template [[host_name(" kernel_mul_mm_id_mxfp4_f32" )]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_mxfp4, 2 , dequantize_mxfp4, float , float2x4>;
8322+ template [[host_name(" kernel_mul_mm_id_q2_K_f32" )]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float , float2x4>;
8323+ template [[host_name(" kernel_mul_mm_id_q3_K_f32" )]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K, float , float2x4>;
8324+ template [[host_name(" kernel_mul_mm_id_q4_K_f32" )]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, float , float2x4>;
8325+ template [[host_name(" kernel_mul_mm_id_q5_K_f32" )]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K, float , float2x4>;
8326+ template [[host_name(" kernel_mul_mm_id_q6_K_f32" )]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K, float , float2x4>;
8327+ template [[host_name(" kernel_mul_mm_id_iq2_xxs_f32" )]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, float , float2x4>;
8328+ template [[host_name(" kernel_mul_mm_id_iq2_xs_f32" )]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs, float , float2x4>;
8329+ template [[host_name(" kernel_mul_mm_id_iq3_xxs_f32" )]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs, float , float2x4>;
8330+ template [[host_name(" kernel_mul_mm_id_iq3_s_f32" )]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s, float , float2x4>;
8331+ template [[host_name(" kernel_mul_mm_id_iq2_s_f32" )]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s, float , float2x4>;
8332+ template [[host_name(" kernel_mul_mm_id_iq1_s_f32" )]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s, float , float2x4>;
8333+ template [[host_name(" kernel_mul_mm_id_iq1_m_f32" )]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m, float , float2x4>;
8334+ template [[host_name(" kernel_mul_mm_id_iq4_nl_f32" )]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq4_nl, 2 , dequantize_iq4_nl, float , float2x4>;
8335+ template [[host_name(" kernel_mul_mm_id_iq4_xs_f32" )]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs, float , float2x4>;
8336+
8337+ template [[host_name(" kernel_mul_mm_id_f32_f16" )]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, float4x4, 1 , dequantize_f32, half, half2x4>;
8338+ template [[host_name(" kernel_mul_mm_id_f16_f16" )]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half4x4, 1 , dequantize_f16, half, half2x4>;
8339+ #if defined(GGML_METAL_HAS_BF16)
8340+ template [[host_name(" kernel_mul_mm_id_bf16_f16" )]] kernel mul_mm_id kernel_mul_mm_id<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat4x4, 1 , dequantize_bf16, half, half2x4>;
8341+ #endif
8342+ template [[host_name(" kernel_mul_mm_id_q4_0_f16" )]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q4_0, 2 , dequantize_q4_0, half, half2x4>;
8343+ template [[host_name(" kernel_mul_mm_id_q4_1_f16" )]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q4_1, 2 , dequantize_q4_1, half, half2x4>;
8344+ template [[host_name(" kernel_mul_mm_id_q5_0_f16" )]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q5_0, 2 , dequantize_q5_0, half, half2x4>;
8345+ template [[host_name(" kernel_mul_mm_id_q5_1_f16" )]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q5_1, 2 , dequantize_q5_1, half, half2x4>;
8346+ template [[host_name(" kernel_mul_mm_id_q8_0_f16" )]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q8_0, 2 , dequantize_q8_0, half, half2x4>;
8347+ template [[host_name(" kernel_mul_mm_id_mxfp4_f16" )]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_mxfp4, 2 , dequantize_mxfp4, half, half2x4>;
8348+ template [[host_name(" kernel_mul_mm_id_q2_K_f16" )]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, half, half2x4>;
8349+ template [[host_name(" kernel_mul_mm_id_q3_K_f16" )]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K, half, half2x4>;
8350+ template [[host_name(" kernel_mul_mm_id_q4_K_f16" )]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, half, half2x4>;
8351+ template [[host_name(" kernel_mul_mm_id_q5_K_f16" )]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K, half, half2x4>;
8352+ template [[host_name(" kernel_mul_mm_id_q6_K_f16" )]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K, half, half2x4>;
8353+ template [[host_name(" kernel_mul_mm_id_iq2_xxs_f16" )]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, half, half2x4>;
8354+ template [[host_name(" kernel_mul_mm_id_iq2_xs_f16" )]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs, half, half2x4>;
8355+ template [[host_name(" kernel_mul_mm_id_iq3_xxs_f16" )]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs, half, half2x4>;
8356+ template [[host_name(" kernel_mul_mm_id_iq3_s_f16" )]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s, half, half2x4>;
8357+ template [[host_name(" kernel_mul_mm_id_iq2_s_f16" )]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s, half, half2x4>;
8358+ template [[host_name(" kernel_mul_mm_id_iq1_s_f16" )]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s, half, half2x4>;
8359+ template [[host_name(" kernel_mul_mm_id_iq1_m_f16" )]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m, half, half2x4>;
8360+ template [[host_name(" kernel_mul_mm_id_iq4_nl_f16" )]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq4_nl, 2 , dequantize_iq4_nl, half, half2x4>;
8361+ template [[host_name(" kernel_mul_mm_id_iq4_xs_f16" )]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs, half, half2x4>;
83378362
83388363//
83398364// matrix-vector multiplication
0 commit comments