@@ -1219,10 +1219,10 @@ @implementation GGMLMetalClass
1219
1219
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL, set_rows_iq4_nl, true );
1220
1220
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_L2_NORM, l2_norm, has_simdgroup_reduction);
1221
1221
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction);
1222
- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_NORM, norm, true );
1222
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_NORM, norm, has_simdgroup_reduction );
1223
1223
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true );
1224
- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true );
1225
- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP, ssm_scan_f32_group, true );
1224
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, has_simdgroup_reduction );
1225
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP, ssm_scan_f32_group, has_simdgroup_reduction );
1226
1226
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32, rwkv_wkv6_f32, true );
1227
1227
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32, rwkv_wkv7_f32, true );
1228
1228
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction);
@@ -1443,9 +1443,9 @@ @implementation GGMLMetalClass
1443
1443
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SWIGLU_OAI, swiglu_oai, true );
1444
1444
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GEGLU_ERF, geglu_erf, true );
1445
1445
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GEGLU_QUICK, geglu_quick, true );
1446
- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true );
1447
- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MEAN, mean, true );
1448
- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true );
1446
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, has_simdgroup_reduction );
1447
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MEAN, mean, has_simdgroup_reduction );
1448
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, has_simdgroup_reduction );
1449
1449
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true );
1450
1450
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true );
1451
1451
}
@@ -1982,7 +1982,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
1982
1982
case GGML_OP_L2_NORM:
1983
1983
return has_simdgroup_reduction && (op->ne [0 ] % 4 == 0 && ggml_is_contiguous_1 (op->src [0 ]));
1984
1984
case GGML_OP_ARGMAX:
1985
- return true ;
1985
+ return has_simdgroup_reduction ;
1986
1986
case GGML_OP_NORM:
1987
1987
return has_simdgroup_reduction && (op->ne [0 ] % 4 == 0 && ggml_is_contiguous_1 (op->src [0 ]));
1988
1988
case GGML_OP_ROPE:
@@ -2028,6 +2028,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
2028
2028
return has_simdgroup_mm; // TODO: over-restricted for vec-kernels
2029
2029
case GGML_OP_SSM_CONV:
2030
2030
case GGML_OP_SSM_SCAN:
2031
+ return has_simdgroup_reduction;
2031
2032
case GGML_OP_RWKV_WKV6:
2032
2033
case GGML_OP_RWKV_WKV7:
2033
2034
return true ;
0 commit comments