@@ -1870,6 +1870,8 @@ kernel void kernel_mul_mv_q8_0_f32(
18701870 kernel_mul_mv_q8_0_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr , tgpig, tiisg, sgitg);
18711871}
18721872
1873+ // mat-vec kernel processing in chunks of float4
1874+ // chpb - chunks per quantization block
18731875template <short nxpsg, short r1ptg, typename q_t , short chpb, void (*deq_t4)(device const q_t *, short , thread float4 &) >
18741876void kernel_mul_mv_ext_q4_f32_impl (
18751877 constant ggml_metal_kargs_mul_mv_ext & args,
@@ -1879,7 +1881,7 @@ void kernel_mul_mv_ext_q4_f32_impl(
18791881 uint3 tgpig[[threadgroup_position_in_grid]],
18801882 ushort tiisg[[thread_index_in_simdgroup]],
18811883 ushort sgitg[[simdgroup_index_in_threadgroup]]) {
1882- const short chpt = 4 ;
1884+ const short chpt = 4 ; // chunks per thread
18831885
18841886 // const short nxpsg = (32);
18851887 const short nypsg = (32 /nxpsg);
@@ -1907,7 +1909,7 @@ void kernel_mul_mv_ext_q4_f32_impl(
19071909
19081910 float sumf[r1ptg] = { [ 0 ... r1ptg - 1 ] = 0 .0f };
19091911
1910- short cch = tx%chpb;
1912+ short cch = tx%chpb; // current chunk index
19111913
19121914 for (int ich = tx; 4 *ich < args.ne00 ; ich += chpt*nxpsg) {
19131915 float4 lx[chpt];
@@ -1938,6 +1940,7 @@ void kernel_mul_mv_ext_q4_f32_impl(
19381940 }
19391941 }
19401942
1943+ // reduce only the threads in each row
19411944 for (short ir1 = 0 ; ir1 < r1ptg; ++ir1) {
19421945 if (nxpsg >= 32 ) {
19431946 sumf[ir1] += simd_shuffle_down (sumf[ir1], 16 );
@@ -1969,6 +1972,7 @@ void kernel_mul_mv_ext_q4_f32_impl(
19691972 }
19701973}
19711974
1975+ // mat-vec kernel processing in chunks of float4x4
19721976template <short nxpsg, short r1ptg, typename q_t , short chpb, void (*deq_t4x4)(device const q_t *, short , thread float4x4 &) >
19731977void kernel_mul_mv_ext_q4x4_f32_impl (
19741978 constant ggml_metal_kargs_mul_mv_ext & args,
@@ -2072,6 +2076,8 @@ void kernel_mul_mv_ext_q4x4_f32_impl(
20722076 }
20732077}
20742078
2079+ // dispatchers needed for compile-time nxpsg
2080+ // epb - elements per quantization block
20752081template <short r1ptg, typename q_t , short epb, void (*deq_t4)(device const q_t *, short , thread float4 &)>
20762082kernel void kernel_mul_mv_ext_q4_f32_disp (
20772083 constant ggml_metal_kargs_mul_mv_ext & args,
0 commit comments