@@ -398,8 +398,12 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
398
398
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32,
399
399
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
400
400
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,
401
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16,
402
- GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32,
401
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_1,
402
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_2,
403
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_4,
404
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_6,
405
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_8,
406
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_16,
403
407
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16,
404
408
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16,
405
409
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16,
@@ -1428,8 +1432,12 @@ @implementation GGMLMetalClass
1428
1432
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, has_simdgroup_mm);
1429
1433
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, has_simdgroup_mm);
1430
1434
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, has_simdgroup_mm);
1431
- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16, mul_mm_id_map0_f16, has_simdgroup_mm);
1432
- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32, mul_mm_id_map1_f32, has_simdgroup_mm);
1435
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_1, mul_mm_id_map0_f16_ne20_1, has_simdgroup_mm);
1436
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_2, mul_mm_id_map0_f16_ne20_2, has_simdgroup_mm);
1437
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_4, mul_mm_id_map0_f16_ne20_4, has_simdgroup_mm);
1438
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_6, mul_mm_id_map0_f16_ne20_6, has_simdgroup_mm);
1439
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_8, mul_mm_id_map0_f16_ne20_8, has_simdgroup_mm);
1440
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_16, mul_mm_id_map0_f16_ne20_16, has_simdgroup_mm);
1433
1441
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16, mul_mm_id_f32_f16, has_simdgroup_mm);
1434
1442
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16, mul_mm_id_f16_f16, has_simdgroup_mm);
1435
1443
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16, mul_mm_id_bf16_f16, has_simdgroup_mm && use_bfloat);
@@ -3908,38 +3916,6 @@ static int ggml_metal_encode_node(
3908
3916
default : break ;
3909
3917
}
3910
3918
3911
- const int64_t neh10 = ne10; // n_embd
3912
- const int64_t neh11 = ne21; // n_tokens
3913
- const int64_t neh12 = ne02; // n_expert
3914
-
3915
- const uint64_t nbh10 = ggml_type_size (GGML_TYPE_F16);
3916
- const uint64_t nbh11 = nbh10*neh10;
3917
- const uint64_t nbh12 = nbh11*neh11;
3918
- const uint64_t nbh13 = nbh12*neh12;
3919
-
3920
- const size_t s_src1 = ggml_type_size (GGML_TYPE_F16)*neh10*neh11*neh12;
3921
- id <MTLBuffer > h_src1 = ggml_metal_mem_pool_alloc (mem_pool, s_src1);
3922
- if (!h_src1) {
3923
- GGML_LOG_ERROR (" %s : failed to allocate buffer from memory pool, size = %zu \n " , __func__, s_src1);
3924
- return 0 ;
3925
- }
3926
-
3927
- const int64_t neh0 = ne0;
3928
- const int64_t neh1 = ne21;
3929
- const int64_t neh2 = ne02;
3930
-
3931
- const uint64_t nbh0 = ggml_type_size (GGML_TYPE_F32);
3932
- const uint64_t nbh1 = nbh0*neh0;
3933
- const uint64_t nbh2 = nbh1*neh1;
3934
- // const uint64_t nbh3 = nbh2*neh2;
3935
-
3936
- const size_t s_dst = ggml_type_size (GGML_TYPE_F32)*neh0*neh1*neh2;
3937
- id <MTLBuffer > h_dst = ggml_metal_mem_pool_alloc (mem_pool, s_dst);
3938
- if (!h_dst) {
3939
- GGML_LOG_ERROR (" %s : failed to allocate buffer from memory pool, size = %zu \n " , __func__, s_dst);
3940
- return 0 ;
3941
- }
3942
-
3943
3919
// tokens per expert
3944
3920
const size_t s_tpe = ggml_type_size (GGML_TYPE_I32)*ne02;
3945
3921
id <MTLBuffer > h_tpe = ggml_metal_mem_pool_alloc (mem_pool, s_tpe);
@@ -3949,41 +3925,54 @@ static int ggml_metal_encode_node(
3949
3925
}
3950
3926
3951
3927
// id map
3952
- // [n_expert_used, n_tokens ]
3953
- const size_t s_ids = ggml_type_size (GGML_TYPE_I32)*ne20* ne21;
3928
+ // [n_tokens, n_expert ]
3929
+ const size_t s_ids = ggml_type_size (GGML_TYPE_I32)*ne21*ne02 ;
3954
3930
id <MTLBuffer > h_ids = ggml_metal_mem_pool_alloc (mem_pool, s_ids);
3955
3931
if (!h_ids) {
3956
3932
GGML_LOG_ERROR (" %s : failed to allocate buffer from memory pool, size = %zu \n " , __func__, s_ids);
3957
3933
return 0 ;
3958
3934
}
3959
3935
3960
3936
{
    // Encode the mul_mm_id "map0" pass.
    //
    // Bound buffers: the src2 tensor (index 1), the "tokens per expert" scratch buffer h_tpe
    // (index 2) and the id map scratch buffer h_ids (index 3). The pass is launched as a single
    // threadgroup with one thread per expert (ne02 threads).
    // NOTE(review): presumably the kernel fills h_tpe/h_ids from the ids in src2 - confirm
    // against the kernel source.
    ggml_metal_kargs_mul_mm_id_map0 args = {
        ne02, // n_expert
        ne10,
        ne11, // n_expert_used (bcast)
        nb11,
        nb12,
        ne21, // n_tokens
        ne20, // n_expert_used
        nb21,
    };

    id<MTLComputePipelineState> pipeline = nil;

    // the kernel is specialized per number of used experts - abort on a count without a variant
    switch (ne20) {
        case 1:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_1 ].pipeline; break;
        case 2:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_2 ].pipeline; break;
        case 4:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_4 ].pipeline; break;
        case 6:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_6 ].pipeline; break;
        case 8:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_8 ].pipeline; break;
        case 16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_16].pipeline; break;
        default: GGML_ABORT(" missing specialization for ne20 = %d ", (int) ne20);
    }

    // one thread per expert, all in the same threadgroup
    GGML_ASSERT(ne02 <= (int) pipeline.maxTotalThreadsPerThreadgroup);

    // ne02*ne20 16-bit entries of threadgroup memory, rounded up to a multiple of 16 bytes:
    // Metal rejects threadgroup memory lengths that are not 16-byte aligned, which would
    // otherwise happen for small expert counts (e.g. ne02 = 4, ne20 = 1 -> 8 bytes)
    const size_t smem_align = 16;
    const size_t smem = ((size_t) ne02*ne20*sizeof(uint16_t) + smem_align - 1) & ~(smem_align - 1);

    GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);

    [encoder setComputePipelineState:pipeline];
    [encoder setBytes:&args length:sizeof(args) atIndex:0];
    [encoder setBuffer:id_src2 offset:offs_src2 atIndex:1];
    [encoder setBuffer: h_tpe  offset:0         atIndex:2];
    [encoder setBuffer: h_ids  offset:0         atIndex:3];
    [encoder setThreadgroupMemoryLength:smem atIndex:0];

    [encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(ne02, 1, 1)];
}
3988
3977
3989
3978
{
@@ -4022,56 +4011,30 @@ static int ggml_metal_encode_node(
4022
4011
/* .nb01 =*/ nb01,
4023
4012
/* .nb02 =*/ nb02,
4024
4013
/* .nb03 =*/ nb03,
4025
- /* .neh12 =*/ neh12,
4026
- /* .nbh10 =*/ nbh10,
4027
- /* .nbh11 =*/ nbh11,
4028
- /* .nbh12 =*/ nbh12,
4029
- /* .nbh13 =*/ nbh13,
4030
- /* .neh0 =*/ neh0,
4031
- /* .neh1 =*/ neh1,
4014
+ /* .ne11 =*/ ne11, // n_expert_used (bcast)
4015
+ /* .nb10 =*/ nb10,
4016
+ /* .nb11 =*/ nb11,
4017
+ /* .nb12 =*/ nb12,
4018
+ /* .nb13 =*/ nb13,
4019
+ /* .ne20 =*/ ne20, // n_expert_used
4020
+ /* .ne21 =*/ ne21, // n_tokens
4021
+ /* .ne0 =*/ ne0,
4022
+ /* .ne1 =*/ ne1,
4032
4023
/* .r2 =*/ r2,
4033
4024
/* .r3 =*/ r3,
4034
4025
};
4035
4026
4036
4027
[encoder setComputePipelineState: pipeline];
4037
4028
[encoder setBytes: &args length: sizeof (args) atIndex: 0 ];
4038
4029
[encoder setBuffer: id_src0 offset: offs_src0 atIndex: 1 ];
4039
- [encoder setBuffer: h_src1 offset: 0 atIndex: 2 ];
4030
+ [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 2 ];
4040
4031
[encoder setBuffer: h_tpe offset: 0 atIndex: 3 ];
4041
- [encoder setBuffer: h_dst offset: 0 atIndex: 4 ];
4032
+ [encoder setBuffer: h_ids offset: 0 atIndex: 4 ];
4033
+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 5 ];
4042
4034
4043
4035
[encoder setThreadgroupMemoryLength: 8192 atIndex: 0 ];
4044
4036
[encoder dispatchThreadgroups: MTLSizeMake ((ne21 + 31 )/32 , (ne01 + 63 )/64 , ne02) threadsPerThreadgroup: MTLSizeMake (128 , 1 , 1 )];
4045
4037
}
4046
-
4047
- {
4048
- GGML_ASSERT (ne0 % 4 == 0 );
4049
-
4050
- const int nth = MIN (1024 , ne0/4 );
4051
-
4052
- ggml_metal_kargs_mul_mm_id_map1 args = {
4053
- ne20, // n_expert_used
4054
- neh0,
4055
- neh1,
4056
- nbh1,
4057
- nbh2,
4058
- ne0,
4059
- nb1,
4060
- nb2,
4061
- };
4062
-
4063
- id <MTLComputePipelineState > pipeline = nil ;
4064
-
4065
- pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32].pipeline ;
4066
-
4067
- [encoder setComputePipelineState: pipeline];
4068
- [encoder setBytes: &args length: sizeof (args) atIndex: 0 ];
4069
- [encoder setBuffer: h_dst offset: 0 atIndex: 1 ];
4070
- [encoder setBuffer: h_ids offset: 0 atIndex: 2 ];
4071
- [encoder setBuffer: id_dst offset: offs_dst atIndex: 3 ];
4072
-
4073
- [encoder dispatchThreadgroups: MTLSizeMake (ne20, ne21, 1 ) threadsPerThreadgroup: MTLSizeMake (nth, 1 , 1 )];
4074
- }
4075
4038
} else {
4076
4039
id <MTLComputePipelineState > pipeline = nil ;
4077
4040
0 commit comments