Skip to content

Commit c2a96d7

Browse files
committed
metal : support mul_mm with src1->type == GGML_TYPE_F16
1 parent 72b24d9 commit c2a96d7

File tree

3 files changed

+54
-30
lines changed

3 files changed

+54
-30
lines changed

ggml/src/ggml-metal/ggml-metal-device.m

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -717,8 +717,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
717717
return true;
718718
case GGML_OP_MUL_MAT:
719719
case GGML_OP_MUL_MAT_ID:
720-
return has_simdgroup_reduction &&
721-
(op->src[0]->type != GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F32);
720+
return has_simdgroup_reduction;
722721
case GGML_OP_CPY:
723722
case GGML_OP_DUP:
724723
case GGML_OP_CONT:

ggml/src/ggml-metal/ggml-metal-ops.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1477,7 +1477,6 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
14771477
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
14781478
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
14791479
props_dev->has_simdgroup_mm &&
1480-
op->src[1]->type == GGML_TYPE_F32 &&
14811480
ne00 % 32 == 0 && ne00 >= 64 &&
14821481
(ne11 > ne11_mm_min || (ggml_is_quantized(op->src[0]->type) && ne12 > 1))) {
14831482
//printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 53 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -7868,7 +7868,7 @@ kernel void kernel_set_rows_f(
78687868
#define SG_MAT_ROW 8
78697869

78707870
// each block_q contains 16*nl weights
7871-
template<typename T, typename T4x4, typename simdgroup_T8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
7871+
template<typename T, typename T4x4, typename simdgroup_T8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &), typename U, typename U2x4>
78727872
kernel void kernel_mul_mm(
78737873
constant ggml_metal_kargs_mul_mm & args,
78747874
device const char * src0,
@@ -7913,7 +7913,7 @@ kernel void kernel_mul_mm(
79137913
device const block_q * x = (device const block_q *)(src0
79147914
+ args.nb01*(r0*BLOCK_SIZE_M + thread_row) + offset0) + offset1;
79157915

7916-
device const float * y = (device const float *)(src1
7916+
device const U * y = (device const U *)(src1
79177917
+ args.nb13*i13
79187918
+ args.nb12*i12
79197919
+ args.nb11*(r1*BLOCK_SIZE_N + thread_col)
@@ -7933,7 +7933,7 @@ kernel void kernel_mul_mm(
79337933
+ (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = temp_a[i/4][i%4];
79347934
}
79357935

7936-
*(threadgroup float2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = *((device float2x4 *) y);
7936+
*(threadgroup float2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = (float2x4)(*((device U2x4 *) y));
79377937

79387938
il = (il + 2 < nl) ? il + 2 : il % 2;
79397939
x = (il < 2) ? x + (2 + nl - 1)/nl : x;
@@ -8299,33 +8299,59 @@ template [[host_name("kernel_set_rows_iq4_nl_i32")]] kernel set_rows_q32_t kerne
82998299
// matrix-matrix multiplication
83008300
//
83018301

8302-
typedef decltype(kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>) mul_mm_t;
8302+
typedef decltype(kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float2x4>) mul_mm_t;
83038303

8304-
template [[host_name("kernel_mul_mm_f32_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>;
8305-
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half4x4, 1, dequantize_f16>;
8304+
template [[host_name("kernel_mul_mm_f32_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float2x4>;
8305+
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, float, float2x4>;
8306+
#if defined(GGML_METAL_HAS_BF16)
8307+
template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mul_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16, float, float2x4>;
8308+
#endif
8309+
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float2x4>;
8310+
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float2x4>;
8311+
template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float2x4>;
8312+
template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1, float, float2x4>;
8313+
template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0, float, float2x4>;
8314+
template [[host_name("kernel_mul_mm_mxfp4_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4, float, float2x4>;
8315+
template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float, float2x4>;
8316+
template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K, float, float2x4>;
8317+
template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, float, float2x4>;
8318+
template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K, float, float2x4>;
8319+
template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K, float, float2x4>;
8320+
template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, float, float2x4>;
8321+
template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs, float, float2x4>;
8322+
template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs, float, float2x4>;
8323+
template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s, float, float2x4>;
8324+
template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s, float, float2x4>;
8325+
template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s, float, float2x4>;
8326+
template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m, float, float2x4>;
8327+
template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl, float, float2x4>;
8328+
template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs, float, float2x4>;
8329+
8330+
template [[host_name("kernel_mul_mm_f32_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, half, half2x4>;
8331+
template [[host_name("kernel_mul_mm_f16_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half2x4>;
83068332
#if defined(GGML_METAL_HAS_BF16)
8307-
template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mul_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16>;
8333+
template [[host_name("kernel_mul_mm_bf16_f16")]] kernel mul_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16, half, half2x4>;
83088334
#endif
8309-
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0>;
8310-
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1>;
8311-
template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0>;
8312-
template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1>;
8313-
template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0>;
8314-
template [[host_name("kernel_mul_mm_mxfp4_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4>;
8315-
template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K>;
8316-
template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K>;
8317-
template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K>;
8318-
template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K>;
8319-
template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K>;
8320-
template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
8321-
template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs>;
8322-
template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
8323-
template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s>;
8324-
template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s>;
8325-
template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s>;
8326-
template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m>;
8327-
template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl>;
8328-
template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs>;
8335+
template [[host_name("kernel_mul_mm_q4_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, half, half2x4>;
8336+
template [[host_name("kernel_mul_mm_q4_1_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, half, half2x4>;
8337+
template [[host_name("kernel_mul_mm_q5_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, half, half2x4>;
8338+
template [[host_name("kernel_mul_mm_q5_1_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1, half, half2x4>;
8339+
template [[host_name("kernel_mul_mm_q8_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0, half, half2x4>;
8340+
template [[host_name("kernel_mul_mm_mxfp4_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4, half, half2x4>;
8341+
template [[host_name("kernel_mul_mm_q2_K_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, half, half2x4>;
8342+
template [[host_name("kernel_mul_mm_q3_K_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K, half, half2x4>;
8343+
template [[host_name("kernel_mul_mm_q4_K_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, half, half2x4>;
8344+
template [[host_name("kernel_mul_mm_q5_K_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K, half, half2x4>;
8345+
template [[host_name("kernel_mul_mm_q6_K_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K, half, half2x4>;
8346+
template [[host_name("kernel_mul_mm_iq2_xxs_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, half, half2x4>;
8347+
template [[host_name("kernel_mul_mm_iq2_xs_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs, half, half2x4>;
8348+
template [[host_name("kernel_mul_mm_iq3_xxs_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs, half, half2x4>;
8349+
template [[host_name("kernel_mul_mm_iq3_s_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s, half, half2x4>;
8350+
template [[host_name("kernel_mul_mm_iq2_s_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s, half, half2x4>;
8351+
template [[host_name("kernel_mul_mm_iq1_s_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s, half, half2x4>;
8352+
template [[host_name("kernel_mul_mm_iq1_m_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m, half, half2x4>;
8353+
template [[host_name("kernel_mul_mm_iq4_nl_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl, half, half2x4>;
8354+
template [[host_name("kernel_mul_mm_iq4_xs_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs, half, half2x4>;
83298355

83308356
//
83318357
// indirect matrix-matrix multiplication

0 commit comments

Comments
 (0)