@@ -2883,7 +2883,8 @@ static inline void helper_mv_reduce_and_write(
28832883 }
28842884}
28852885
2886- constant short FC_mul_mv_nsg [[function_constant(FC_MUL_MV + 0 )]];
2886+ constant short FC_mul_mv_nsg [[function_constant(FC_MUL_MV + 0 )]];
2887+ constant short FC_mul_mv_nxpsg [[function_constant(FC_MUL_MV + 1 )]];
28872888
28882889template <typename block_q_type, short NR0, short NW, typename args_t >
28892890void mul_vec_q_n_f32_impl (
@@ -3108,7 +3109,7 @@ kernel void kernel_mul_mv_q8_0_f32(
31083109
31093110// mat-vec kernel processing in chunks of float4
31103111// chpb - chunks per quantization block
3111- template <short nxpsg, short r1ptg, typename q_t , short chpb, void (*deq_t4)(device const q_t *, short , thread float4 &) >
3112+ template <short r1ptg, typename q_t , short chpb, void (*deq_t4)(device const q_t *, short , thread float4 &) >
31123113void kernel_mul_mv_ext_q4_f32_impl (
31133114 constant ggml_metal_kargs_mul_mv_ext & args,
31143115 device const char * src0,
@@ -3117,6 +3118,9 @@ void kernel_mul_mv_ext_q4_f32_impl(
31173118 uint3 tgpig[[threadgroup_position_in_grid]],
31183119 ushort tiisg[[thread_index_in_simdgroup]],
31193120 ushort sgitg[[simdgroup_index_in_threadgroup]]) {
3121+ const short NSG = FC_mul_mv_nsg;
3122+ const short nxpsg = FC_mul_mv_nxpsg;
3123+
31203124 const short chpt = 4 ; // chunks per thread
31213125
31223126 // const short nxpsg = (32);
@@ -3125,7 +3129,7 @@ void kernel_mul_mv_ext_q4_f32_impl(
31253129 const short tx = tiisg%nxpsg;
31263130 const short ty = tiisg/nxpsg;
31273131
3128- const int i01 = tgpig.x *(nypsg*args. nsg ) + nypsg*sgitg + ty;
3132+ const int i01 = tgpig.x *(nypsg*NSG ) + nypsg*sgitg + ty;
31293133 const int i11 = tgpig.y *r1ptg;
31303134 const int i1m = tgpig.z ;
31313135
@@ -3208,7 +3212,7 @@ void kernel_mul_mv_ext_q4_f32_impl(
32083212}
32093213
32103214// mat-vec kernel processing in chunks of float4x4
3211- template <short nxpsg, short r1ptg, typename q_t , short chpb, void (*deq_t4x4)(device const q_t *, short , thread float4x4 &) >
3215+ template <short r1ptg, typename q_t , short chpb, void (*deq_t4x4)(device const q_t *, short , thread float4x4 &) >
32123216void kernel_mul_mv_ext_q4x4_f32_impl (
32133217 constant ggml_metal_kargs_mul_mv_ext & args,
32143218 device const char * src0,
@@ -3217,6 +3221,9 @@ void kernel_mul_mv_ext_q4x4_f32_impl(
32173221 uint3 tgpig[[threadgroup_position_in_grid]],
32183222 ushort tiisg[[thread_index_in_simdgroup]],
32193223 ushort sgitg[[simdgroup_index_in_threadgroup]]) {
3224+ const short NSG = FC_mul_mv_nsg;
3225+ const short nxpsg = FC_mul_mv_nxpsg;
3226+
32203227 const short chpt = 1 ;
32213228
32223229 // const short nxpsg = (32);
@@ -3225,7 +3232,7 @@ void kernel_mul_mv_ext_q4x4_f32_impl(
32253232 const short tx = tiisg%nxpsg;
32263233 const short ty = tiisg/nxpsg;
32273234
3228- const int i01 = tgpig.x *(nypsg*args. nsg ) + nypsg*sgitg + ty;
3235+ const int i01 = tgpig.x *(nypsg*NSG ) + nypsg*sgitg + ty;
32293236 const int i11 = tgpig.y *r1ptg;
32303237 const int i1m = tgpig.z ;
32313238
@@ -3322,12 +3329,7 @@ kernel void kernel_mul_mv_ext_q4_f32_disp(
33223329 uint3 tgpig[[threadgroup_position_in_grid]],
33233330 ushort tiisg[[thread_index_in_simdgroup]],
33243331 ushort sgitg[[simdgroup_index_in_threadgroup]]) {
3325- switch (args.nxpsg ) {
3326- case 4 : kernel_mul_mv_ext_q4_f32_impl<4 , r1ptg, q_t , epb/4 , deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break ;
3327- case 8 : kernel_mul_mv_ext_q4_f32_impl<8 , r1ptg, q_t , epb/4 , deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break ;
3328- case 16 : kernel_mul_mv_ext_q4_f32_impl<16 , r1ptg, q_t , epb/4 , deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break ;
3329- case 32 : kernel_mul_mv_ext_q4_f32_impl<32 , r1ptg, q_t , epb/4 , deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break ;
3330- }
3332+ kernel_mul_mv_ext_q4_f32_impl<r1ptg, q_t , epb/4 , deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg);
33313333}
33323334
33333335template <short r1ptg, typename q_t , short epb, void (*deq_t4x4)(device const q_t *, short , thread float4x4 &)>
@@ -3339,12 +3341,7 @@ kernel void kernel_mul_mv_ext_q4x4_f32_disp(
33393341 uint3 tgpig[[threadgroup_position_in_grid]],
33403342 ushort tiisg[[thread_index_in_simdgroup]],
33413343 ushort sgitg[[simdgroup_index_in_threadgroup]]) {
3342- switch (args.nxpsg ) {
3343- case 4 : kernel_mul_mv_ext_q4x4_f32_impl<4 , r1ptg, q_t , epb/16 , deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break ;
3344- case 8 : kernel_mul_mv_ext_q4x4_f32_impl<8 , r1ptg, q_t , epb/16 , deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break ;
3345- case 16 : kernel_mul_mv_ext_q4x4_f32_impl<16 , r1ptg, q_t , epb/16 , deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break ;
3346- case 32 : kernel_mul_mv_ext_q4x4_f32_impl<32 , r1ptg, q_t , epb/16 , deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break ;
3347- }
3344+ kernel_mul_mv_ext_q4x4_f32_impl<r1ptg, q_t , epb/16 , deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg);
33483345}
33493346
33503347typedef decltype (kernel_mul_mv_ext_q4_f32_disp <2 , block_q8_0, 32 , dequantize_q8_0_t4>) mul_mv_ext_q4_f32_t;
0 commit comments