Skip to content

Commit d7488ba

Browse files
committed
metal : GGML_OP_REPEAT
1 parent 281fa05 commit d7488ba

File tree

3 files changed

+56
-48
lines changed

3 files changed

+56
-48
lines changed

ggml/src/ggml-common.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -475,6 +475,25 @@ typedef struct {
475475
uint64_t offs;
476476
} ggml_metal_kargs_bin;
477477

478+
// Argument block for kernel_repeat. The host fills one instance and binds it
// with setBytes at buffer index 0; the kernel receives it as
// `constant ggml_metal_kargs_repeat & args`.
//   ne00..ne03 : source (src0) extents; the kernel wraps indices with them (i % ne0x)
//   nb00..nb03 : source strides, added to a char pointer, i.e. byte offsets
//   ne0        : bound of the kernel's i0 loop; ne1..ne3 are not read by the
//                kernel body shown in this commit
//   nb0..nb3   : destination strides, added to a char pointer, i.e. byte offsets
// NOTE(review): the extents are int32_t here, while the per-argument kernel
// signature this struct replaces took them as int64_t - confirm they always fit.
typedef struct {
479+
int32_t ne00;
480+
int32_t ne01;
481+
int32_t ne02;
482+
int32_t ne03;
483+
uint64_t nb00;
484+
uint64_t nb01;
485+
uint64_t nb02;
486+
uint64_t nb03;
487+
int32_t ne0;
488+
int32_t ne1;
489+
int32_t ne2;
490+
int32_t ne3;
491+
uint64_t nb0;
492+
uint64_t nb1;
493+
uint64_t nb2;
494+
uint64_t nb3;
495+
} ggml_metal_kargs_repeat;
496+
478497
typedef struct {
479498
int32_t ne00;
480499
int32_t ne01;

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1330,25 +1330,29 @@ static void ggml_metal_encode_node(
13301330
default: GGML_ABORT("fatal error");
13311331
}
13321332

1333+
ggml_metal_kargs_repeat args = {
1334+
/*.ne00 =*/ ne00,
1335+
/*.ne01 =*/ ne01,
1336+
/*.ne02 =*/ ne02,
1337+
/*.ne03 =*/ ne03,
1338+
/*.nb00 =*/ nb00,
1339+
/*.nb01 =*/ nb01,
1340+
/*.nb02 =*/ nb02,
1341+
/*.nb03 =*/ nb03,
1342+
/*.ne0 =*/ ne0,
1343+
/*.ne1 =*/ ne1,
1344+
/*.ne2 =*/ ne2,
1345+
/*.ne3 =*/ ne3,
1346+
/*.nb0 =*/ nb0,
1347+
/*.nb1 =*/ nb1,
1348+
/*.nb2 =*/ nb2,
1349+
/*.nb3 =*/ nb3,
1350+
};
1351+
13331352
[encoder setComputePipelineState:pipeline];
1334-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1335-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1336-
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
1337-
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
1338-
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
1339-
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
1340-
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
1341-
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
1342-
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
1343-
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
1344-
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
1345-
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
1346-
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
1347-
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
1348-
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
1349-
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
1350-
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
1351-
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
1353+
[encoder setBytes:&args length:sizeof(args) atIndex:0];
1354+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
1355+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
13521356

13531357
const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
13541358

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 15 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -603,41 +603,26 @@ kernel void kernel_div(
603603

604604
// Copies src0 into dst with wrap-around indexing on every dimension.
// The threadgroup position selects the destination indices
// (i1, i2, i3) = (tgpig.x, tgpig.y, tgpig.z); the threads of a group then cover
// i0 in [0, ne0), starting at tpitg.x and stepping by ntg.x.
// Source indices are the destination indices modulo the source extents.
// All offsets are byte offsets applied to char pointers; elements are moved as T.
// This diff view shows both the removed per-argument signature (int64_t scalars
// bound at indices 2..17) and the added one that takes a single
// ggml_metal_kargs_repeat block ahead of src0 and dst.
template<typename T>
605605
kernel void kernel_repeat(
606+
constant ggml_metal_kargs_repeat & args,
606607
device const char * src0,
607608
device char * dst,
608-
constant int64_t & ne00,
609-
constant int64_t & ne01,
610-
constant int64_t & ne02,
611-
constant int64_t & ne03,
612-
constant uint64_t & nb00,
613-
constant uint64_t & nb01,
614-
constant uint64_t & nb02,
615-
constant uint64_t & nb03,
616-
constant int64_t & ne0,
617-
constant int64_t & ne1,
618-
constant int64_t & ne2,
619-
constant int64_t & ne3,
620-
constant uint64_t & nb0,
621-
constant uint64_t & nb1,
622-
constant uint64_t & nb2,
623-
constant uint64_t & nb3,
624-
uint3 tgpig[[threadgroup_position_in_grid]],
625-
uint3 tpitg[[thread_position_in_threadgroup]],
626-
uint3 ntg[[threads_per_threadgroup]]) {
627-
const int64_t i3 = tgpig.z;
628-
const int64_t i2 = tgpig.y;
629-
const int64_t i1 = tgpig.x;
609+
uint3 tgpig[[threadgroup_position_in_grid]],
610+
ushort3 tpitg[[thread_position_in_threadgroup]],
611+
ushort3 ntg[[threads_per_threadgroup]]) {
612+
const int i3 = tgpig.z;
613+
const int i2 = tgpig.y;
614+
const int i1 = tgpig.x;
630615

631-
const int64_t i03 = i3 % ne03;
632-
const int64_t i02 = i2 % ne02;
633-
const int64_t i01 = i1 % ne01;
616+
const int i03 = i3%args.ne03; // wrap destination index into the source extent
617+
const int i02 = i2%args.ne02;
618+
const int i01 = i1%args.ne01;
634619

635-
device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
636-
device char * dst_ptr = dst + i3*nb3 + i2*nb2 + i1*nb1 ;
620+
device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01;
621+
device char * dst_ptr = dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1;
637622

638-
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
639-
const int i00 = i0 % ne00;
640-
*((device T *)(dst_ptr + i0*nb0)) = *((device T *)(src0_ptr + i00*nb00));
623+
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
624+
const int i00 = i0%args.ne00;
625+
*((device T *)(dst_ptr + i0*args.nb0)) = *((device T *)(src0_ptr + i00*args.nb00)); // the cast drops const from src0_ptr; the value is only read
641626
}
642627
}
643628

0 commit comments

Comments
 (0)