@@ -396,7 +396,12 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
396396 GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32,
397397 GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
398398 GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,
399- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16,
399+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_1,
400+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_2,
401+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_4,
402+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_6,
403+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_8,
404+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_16,
400405 GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16,
401406 GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16,
402407 GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16,
@@ -1411,7 +1416,12 @@ @implementation GGMLMetalClass
14111416 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, has_simdgroup_mm);
14121417 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, has_simdgroup_mm);
14131418 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, has_simdgroup_mm);
1414- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16, mul_mm_id_map0_f16, has_simdgroup_mm);
1419+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_1, mul_mm_id_map0_f16_ne20_1, has_simdgroup_mm);
1420+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_2, mul_mm_id_map0_f16_ne20_2, has_simdgroup_mm);
1421+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_4, mul_mm_id_map0_f16_ne20_4, has_simdgroup_mm);
1422+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_6, mul_mm_id_map0_f16_ne20_6, has_simdgroup_mm);
1423+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_8, mul_mm_id_map0_f16_ne20_8, has_simdgroup_mm);
1424+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_16, mul_mm_id_map0_f16_ne20_16, has_simdgroup_mm);
14151425 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16, mul_mm_id_f32_f16, has_simdgroup_mm);
14161426 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16, mul_mm_id_f16_f16, has_simdgroup_mm);
14171427 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16, mul_mm_id_bf16_f16, has_simdgroup_mm && use_bfloat);
@@ -3907,7 +3917,17 @@ static int ggml_metal_encode_node(
39073917
39083918 id <MTLComputePipelineState > pipeline = nil ;
39093919
3910- pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16].pipeline ;
3920+ pipeline = nil ;
3921+
3922+ switch (ne20) {
3923+ case 1 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_1 ].pipeline ; break ;
3924+ case 2 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_2 ].pipeline ; break ;
3925+ case 4 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_4 ].pipeline ; break ;
3926+ case 6 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_6 ].pipeline ; break ;
3927+ case 8 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_8 ].pipeline ; break ;
3928+ case 16 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_16].pipeline ; break ;
3929+ default : GGML_ABORT (" missing specialization for ne20 = %d " , (int ) ne20);
3930+ }
39113931
39123932 GGML_ASSERT (ne02 <= (int ) pipeline.maxTotalThreadsPerThreadgroup );
39133933
0 commit comments