Commit 434fc45

metal : add comments
ggml-ci
1 parent 5590160

2 files changed: +24 -8 lines

ggml/src/ggml-metal/ggml-metal.m (16 additions, 6 deletions)

@@ -2008,8 +2008,10 @@ static void ggml_metal_encode_node(
 
             // find the break-even point where the matrix-matrix kernel becomes more efficient compared
             // to the matrix-vector kernel
-            int ne11_mm_min = 4;
+            const int ne11_mm_min = 4;
 
+            // first try to use small-batch mat-mv kernels
+            // these should be efficient for BS [2, ~8]
             if (src1t == GGML_TYPE_F32 && (ne00%256 == 0) &&
                 (
                   (
@@ -2033,12 +2035,20 @@ static void ggml_metal_encode_node(
                   )
                 ) {
                 // TODO: determine the optimal parameters based on grid utilization
-                const int nsg = 2; // TODO: or 4?
-                const int nxpsg = ne11 < 3 ? 16 : 8;
-                const int nypsg = 32/nxpsg;
-                const int r0ptg = nypsg*nsg;
-                int r1ptg = 4;
+                // I still don't know why we should not always use the maximum available threads:
+                //
+                //   nsg = pipeline.maxTotalThreadsPerThreadgroup / 32
+                //
+                // my current hypothesis is that the work grid is not evenly divisible for different nsg
+                // values and there can be some tail effects when nsg is high. need to confirm this
+                //
+                const int nsg   = 2;                 // num simdgroups per threadgroup
+                const int nxpsg = ne11 < 3 ? 16 : 8; // num threads along row per simdgroup
+                const int nypsg = 32/nxpsg;          // num threads along col per simdgroup (i.e. a simdgroup processes that many src0 rows at a time)
+                const int r0ptg = nypsg*nsg;         // num src0 rows per threadgroup
+                int       r1ptg = 4;                 // num src1 rows per threadgroup
 
+                // note: not sure how optimal these are across all different hardware. there might be something cleverer
                 switch (ne11) {
                     case 2:
                         r1ptg = 2; break;
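
To make the new parameters concrete: a minimal standalone C++ sketch of the threadgroup geometry they imply. The ceil-division dispatch and the example shapes are assumptions for illustration, not the actual encoder code.

#include <cstdio>

int main() {
    const int ne01 = 4096; // src0 rows (example shape, assumed)
    const int ne11 = 4;    // src1 rows, i.e. the small batch (assumed)

    const int nsg   = 2;                 // num simdgroups per threadgroup
    const int nxpsg = ne11 < 3 ? 16 : 8; // num threads along row per simdgroup
    const int nypsg = 32/nxpsg;          // num src0 rows per simdgroup
    const int r0ptg = nypsg*nsg;         // num src0 rows per threadgroup
    const int r1ptg = 4;                 // num src1 rows per threadgroup

    // assumed dispatch: one threadgroup per r0ptg x r1ptg tile of the output
    const int tg_x = (ne01 + r0ptg - 1)/r0ptg;
    const int tg_y = (ne11 + r1ptg - 1)/r1ptg;
    printf("%d x %d threadgroups of %d threads\n", tg_x, tg_y, 32*nsg);

    // the tail effect hypothesized in the comment: the rows left idle in the
    // last threadgroup along x grow with r0ptg, and hence with nsg
    printf("idle tail rows: %d\n", tg_x*r0ptg - ne01);
    return 0;
}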

ggml/src/ggml-metal/ggml-metal.metal (8 additions, 2 deletions)

@@ -1870,6 +1870,8 @@ kernel void kernel_mul_mv_q8_0_f32(
     kernel_mul_mv_q8_0_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
+// mat-vec kernel processing in chunks of float4
+// chpb - chunks per quantization block
 template<short nxpsg, short r1ptg, typename q_t, short chpb, void (*deq_t4)(device const q_t *, short, thread float4 &) >
 void kernel_mul_mv_ext_q4_f32_impl(
         constant ggml_metal_kargs_mul_mv_ext & args,
@@ -1879,7 +1881,7 @@ void kernel_mul_mv_ext_q4_f32_impl(
         uint3  tgpig[[threadgroup_position_in_grid]],
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
-    const short chpt = 4;
+    const short chpt = 4; // chunks per thread
 
   //const short nxpsg = (32);
     const short nypsg = (32/nxpsg);
@@ -1907,7 +1909,7 @@ void kernel_mul_mv_ext_q4_f32_impl(
 
     float sumf[r1ptg] = { [ 0 ... r1ptg - 1 ] = 0.0f };
 
-    short cch = tx%chpb;
+    short cch = tx%chpb; // current chunk index
 
     for (int ich = tx; 4*ich < args.ne00; ich += chpt*nxpsg) {
         float4 lx[chpt];
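
As an aside, the chunk bookkeeping above is easy to check on the host. A hypothetical C++ walk-through, assuming Q8_0's 32-element blocks (so chpb = 32/4 = 8 float4 chunks per block):

#include <cstdio>

int main() {
    const int epb  = 32;    // elements per quantization block (Q8_0, assumed)
    const int chpb = epb/4; // chunks per quantization block = 8

    // a global float4 chunk index ich splits into a block and an in-block
    // chunk; cch = tx%chpb in the kernel is the in-block chunk a thread
    // starts at
    for (int ich = 0; ich < 24; ich += 5) {
        const int ib  = ich/chpb; // quantization block index
        const int cch = ich%chpb; // current chunk index within the block
        printf("ich=%2d -> block %d, chunk %d (elements %d..%d)\n",
               ich, ib, cch, 4*ich, 4*ich + 3);
    }
    return 0;
}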
@@ -1938,6 +1940,7 @@ void kernel_mul_mv_ext_q4_f32_impl(
         }
     }
 
+    // reduce only the threads in each row
     for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
         if (nxpsg >= 32) {
             sumf[ir1] += simd_shuffle_down(sumf[ir1], 16);
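
The new comment marks a shuffle-down ladder that starts at offset nxpsg/2, so each row of nxpsg threads reduces independently inside the 32-lane simdgroup. A plain C++ emulation of that pattern (not Metal code; the snapshot array stands in for the lockstep shuffle):

#include <cstdio>

int main() {
    const int nxpsg = 8; // num threads along row per simdgroup (example)

    float v[32]; // one simdgroup's partial sums; lane i holds i for checking
    for (int i = 0; i < 32; ++i) v[i] = (float) i;

    // the kernel's `if (nxpsg >= 32) ... shuffle 16; if (nxpsg >= 16) ...`
    // chain applies offsets nxpsg/2, nxpsg/4, ..., 1
    for (int off = nxpsg/2; off >= 1; off /= 2) {
        float prev[32]; // shuffle reads everyone's pre-step value
        for (int i = 0; i < 32; ++i) prev[i] = v[i];
        for (int lane = 0; lane + off < 32; ++lane) {
            v[lane] = prev[lane] + prev[lane + off]; // simd_shuffle_down(v, off)
        }
    }

    // lane 0 of each row of nxpsg threads now holds that row's total
    for (int row = 0; row < 32/nxpsg; ++row) {
        printf("row %d sum = %g\n", row, v[row*nxpsg]);
    }
    return 0;
}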
@@ -1969,6 +1972,7 @@ void kernel_mul_mv_ext_q4_f32_impl(
         }
     }
 
+// mat-vec kernel processing in chunks of float4x4
 template<short nxpsg, short r1ptg, typename q_t, short chpb, void (*deq_t4x4)(device const q_t *, short, thread float4x4 &) >
 void kernel_mul_mv_ext_q4x4_f32_impl(
         constant ggml_metal_kargs_mul_mv_ext & args,
@@ -2072,6 +2076,8 @@ void kernel_mul_mv_ext_q4x4_f32_impl(
     }
 }
 
+// dispatchers needed for compile-time nxpsg
+// epb - elements per quantization block
 template<short r1ptg, typename q_t, short epb, void (*deq_t4)(device const q_t *, short, thread float4 &)>
 kernel void kernel_mul_mv_ext_q4_f32_disp(
         constant ggml_metal_kargs_mul_mv_ext & args,
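
Since nxpsg is a template parameter, the _disp kernels presumably map a runtime value onto a fixed set of instantiations. A minimal C++ sketch of that runtime-to-compile-time pattern; the function names and the candidate set {4, 8, 16, 32} are assumptions, not the actual dispatcher:

#include <cstdio>

template<short nxpsg>
void mul_mv_impl() { // hypothetical stand-in for the templated kernel body
    printf("instantiation with nxpsg = %d\n", (int) nxpsg);
}

void mul_mv_disp(short nxpsg) {
    // switching on the runtime value selects a compile-time instantiation,
    // letting the compiler specialize the per-lane loops on nxpsg
    switch (nxpsg) {
        case 4:  mul_mv_impl<4>();  break;
        case 8:  mul_mv_impl<8>();  break;
        case 16: mul_mv_impl<16>(); break;
        default: mul_mv_impl<32>(); break;
    }
}

int main() {
    mul_mv_disp(8);
    return 0;
}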
