Skip to content

Commit 97819a0

Browse files
committed
metal : batch rows copy in a single threadgroup
ggml-ci
1 parent 716301d commit 97819a0

File tree

2 files changed: 39 additions and 9 deletions.

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4974,8 +4974,37 @@ static bool ggml_metal_encode_node(
49744974
default: GGML_ABORT("not implemented");
49754975
}
49764976

4977+
GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
4978+
4979+
// TODO: support block-based types by counting blocks (ne00 / block size) instead of elements
4980+
//const int32_t nk00 = ne00/ggml_blck_size(dst->type);
4981+
const int32_t nk00 = ne00;
4982+
4983+
int nth = 32; // SIMD width
4984+
4985+
while (nth < nk00 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
4986+
nth *= 2;
4987+
}
4988+
4989+
// when rows are small, we can batch them together in a single threadgroup
4990+
int nrptg = 1;
4991+
4992+
// TODO: relax this constraint in the future
4993+
if (ggml_blck_size(src0->type) == 1 && ggml_blck_size(dst->type) == 1) {
4994+
if (nth > nk00) {
4995+
nrptg = (nth + nk00 - 1)/nk00;
4996+
nth = nk00;
4997+
4998+
if (nrptg*nth > (int) pipeline.maxTotalThreadsPerThreadgroup) {
4999+
nrptg--;
5000+
}
5001+
}
5002+
}
5003+
5004+
nth = MIN(nth, nk00);
5005+
49775006
ggml_metal_kargs_cpy args = {
4978-
/*.ne00 =*/ ne00,
5007+
/*.ne00 =*/ nk00,
49795008
/*.ne01 =*/ ne01,
49805009
/*.ne02 =*/ ne02,
49815010
/*.ne03 =*/ ne03,
@@ -4998,11 +5027,7 @@ static bool ggml_metal_encode_node(
49985027
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
49995028
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
50005029

5001-
GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
5002-
int nth = MIN(1024, ne00/ggml_blck_size(src0->type));
5003-
5004-
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
5005-
5030+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nrptg - 1)/nrptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, nrptg, 1)];
50065031
} break;
50075032
case GGML_OP_SET:
50085033
{

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4306,11 +4306,16 @@ kernel void kernel_cpy(
43064306
device const char * src0,
43074307
device char * dst,
43084308
uint3 tgpig[[threadgroup_position_in_grid]],
4309+
uint tiitg[[thread_index_in_threadgroup]],
43094310
ushort3 tpitg[[thread_position_in_threadgroup]],
4310-
ushort3 ntg[[threads_per_threadgroup]]) {
4311+
ushort3 tptg[[threads_per_threadgroup]]) {
43114312
const int i03 = tgpig[2];
43124313
const int i02 = tgpig[1];
4313-
const int i01 = tgpig[0];
4314+
const int i01 = tgpig[0]*tptg.y + tiitg/tptg.x;
4315+
4316+
if (i01 >= args.ne01) {
4317+
return;
4318+
}
43144319

43154320
const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
43164321

@@ -4321,7 +4326,7 @@ kernel void kernel_cpy(
43214326

43224327
device T1 * dst_data = (device T1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
43234328

4324-
for (int64_t i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) {
4329+
for (int64_t i00 = tiitg%tptg.x; i00 < args.ne00; i00 += tptg.x) {
43254330
device const T0 * src = (device T0 *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
43264331
dst_data[i00] = (T1) src[0];
43274332
}

0 commit comments

Comments
 (0)