@@ -1219,10 +1219,10 @@ @implementation GGMLMetalClass
12191219 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL, set_rows_iq4_nl, true );
12201220 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_L2_NORM, l2_norm, has_simdgroup_reduction);
12211221 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction);
1222- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_NORM, norm, true );
1222+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_NORM, norm, has_simdgroup_reduction );
12231223 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true );
1224- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true );
1225- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP, ssm_scan_f32_group, true );
1224+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, has_simdgroup_reduction );
1225+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP, ssm_scan_f32_group, has_simdgroup_reduction );
12261226 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32, rwkv_wkv6_f32, true );
12271227 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32, rwkv_wkv7_f32, true );
12281228 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction);
@@ -1443,9 +1443,9 @@ @implementation GGMLMetalClass
14431443 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SWIGLU_OAI, swiglu_oai, true );
14441444 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GEGLU_ERF, geglu_erf, true );
14451445 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GEGLU_QUICK, geglu_quick, true );
1446- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true );
1447- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MEAN, mean, true );
1448- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true );
1446+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, has_simdgroup_reduction );
1447+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MEAN, mean, has_simdgroup_reduction );
1448+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, has_simdgroup_reduction );
14491449 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true );
14501450 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true );
14511451 }
@@ -1982,7 +1982,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
19821982 case GGML_OP_L2_NORM:
19831983 return has_simdgroup_reduction && (op->ne [0 ] % 4 == 0 && ggml_is_contiguous_1 (op->src [0 ]));
19841984 case GGML_OP_ARGMAX:
1985- return true ;
1985+ return has_simdgroup_reduction ;
19861986 case GGML_OP_NORM:
19871987 return has_simdgroup_reduction && (op->ne [0 ] % 4 == 0 && ggml_is_contiguous_1 (op->src [0 ]));
19881988 case GGML_OP_ROPE:
@@ -2028,6 +2028,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
20282028 return has_simdgroup_mm; // TODO: over-restricted for vec-kernels
20292029 case GGML_OP_SSM_CONV:
20302030 case GGML_OP_SSM_SCAN:
2031+ return has_simdgroup_reduction;
20312032 case GGML_OP_RWKV_WKV6:
20322033 case GGML_OP_RWKV_WKV7:
20332034 return true ;
0 commit comments