Skip to content

Commit a14bd35

Browse files
authored
metal : fix kernel requirements (ggml-org#15983)
* metal : fix kernel requirements

ggml-ci

* cont : fix supports_op

* cont : fix supports_op for ARGMAX
1 parent 918b26f commit a14bd35

File tree

1 file changed

+8
-7
lines changed

1 file changed

+8
-7
lines changed

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1219,10 +1219,10 @@ @implementation GGMLMetalClass
12191219
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL, set_rows_iq4_nl, true);
12201220
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_L2_NORM, l2_norm, has_simdgroup_reduction);
12211221
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction);
1222-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
1222+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, has_simdgroup_reduction);
12231223
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true);
1224-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true);
1225-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP, ssm_scan_f32_group, true);
1224+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, has_simdgroup_reduction);
1225+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP, ssm_scan_f32_group, has_simdgroup_reduction);
12261226
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32, rwkv_wkv6_f32, true);
12271227
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32, rwkv_wkv7_f32, true);
12281228
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction);
@@ -1443,9 +1443,9 @@ @implementation GGMLMetalClass
14431443
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU_OAI, swiglu_oai, true);
14441444
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU_ERF, geglu_erf, true);
14451445
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU_QUICK, geglu_quick, true);
1446-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
1447-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MEAN, mean, true);
1448-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
1446+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, has_simdgroup_reduction);
1447+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MEAN, mean, has_simdgroup_reduction);
1448+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, has_simdgroup_reduction);
14491449
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
14501450
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true);
14511451
}
@@ -1982,7 +1982,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
19821982
case GGML_OP_L2_NORM:
19831983
return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
19841984
case GGML_OP_ARGMAX:
1985-
return true;
1985+
return has_simdgroup_reduction;
19861986
case GGML_OP_NORM:
19871987
return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
19881988
case GGML_OP_ROPE:
@@ -2028,6 +2028,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
20282028
return has_simdgroup_mm; // TODO: over-restricted for vec-kernels
20292029
case GGML_OP_SSM_CONV:
20302030
case GGML_OP_SSM_SCAN:
2031+
return has_simdgroup_reduction;
20312032
case GGML_OP_RWKV_WKV6:
20322033
case GGML_OP_RWKV_WKV7:
20332034
return true;

0 commit comments

Comments
 (0)