Skip to content

Commit 2b86f84

Browse files
committed
metal : GGML_OP_CPY
1 parent d7488ba commit 2b86f84

File tree

3 files changed

+182
-261
lines changed

3 files changed

+182
-261
lines changed

ggml/src/ggml-common.h

Lines changed: 19 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -494,6 +494,25 @@ typedef struct {
494494
uint64_t nb3;
495495
} ggml_metal_kargs_repeat;
496496

497+
// Argument block for the Metal CPY kernels. The host code fills one instance
// and uploads it with a single setBytes call at buffer index 0; the src0 and
// dst buffers are then bound at indices 1 and 2.
// Field order is significant: the host-side initializer is positional (the
// /*.name =*/ markers there are only comments), so members must keep this
// exact order.
// NOTE(review): presumably ne* are per-dimension element counts and nb* are
// per-dimension byte strides, with the two-digit suffix (ne00, nb00, ...)
// referring to src0 and the one-digit suffix (ne0, nb0, ...) to dst — confirm
// against the kernel source, which is not visible here.
typedef struct {
498+
int64_t ne00;
499+
int64_t ne01;
500+
int64_t ne02;
501+
int64_t ne03;
502+
uint64_t nb00;
503+
uint64_t nb01;
504+
uint64_t nb02;
505+
uint64_t nb03;
506+
int64_t ne0;
507+
int64_t ne1;
508+
int64_t ne2;
509+
int64_t ne3;
510+
uint64_t nb0;
511+
uint64_t nb1;
512+
uint64_t nb2;
513+
uint64_t nb3;
514+
} ggml_metal_kargs_cpy;
515+
497516
typedef struct {
498517
int32_t ne00;
499518
int32_t ne01;

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 44 additions & 36 deletions
Original file line number | Diff line number | Diff line change
@@ -1381,25 +1381,29 @@ static void ggml_metal_encode_node(
13811381

13821382
const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline;
13831383

1384+
ggml_metal_kargs_cpy args = {
1385+
/*.ne00 =*/ ne00,
1386+
/*.ne01 =*/ ne01,
1387+
/*.ne02 =*/ ne02,
1388+
/*.ne03 =*/ ne03,
1389+
/*.nb00 =*/ nb00,
1390+
/*.nb01 =*/ nb01,
1391+
/*.nb02 =*/ nb02,
1392+
/*.nb03 =*/ nb03,
1393+
/*.ne0 =*/ ne0,
1394+
/*.ne1 =*/ ne1,
1395+
/*.ne2 =*/ ne2,
1396+
/*.ne3 =*/ ne3,
1397+
/*.nb0 =*/ nb0,
1398+
/*.nb1 =*/ nb1,
1399+
/*.nb2 =*/ nb2,
1400+
/*.nb3 =*/ nb3,
1401+
};
1402+
13841403
[encoder setComputePipelineState:pipeline];
1385-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1386-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1387-
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
1388-
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
1389-
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
1390-
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
1391-
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
1392-
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
1393-
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
1394-
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
1395-
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
1396-
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
1397-
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
1398-
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
1399-
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
1400-
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
1401-
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
1402-
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
1404+
[encoder setBytes:&args length:sizeof(args) atIndex:0];
1405+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
1406+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
14031407

14041408
const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
14051409

@@ -3429,25 +3433,29 @@ static void ggml_metal_encode_node(
34293433
default: GGML_ABORT("not implemented");
34303434
}
34313435

3436+
ggml_metal_kargs_cpy args = {
3437+
/*.ne00 =*/ ne00,
3438+
/*.ne01 =*/ ne01,
3439+
/*.ne02 =*/ ne02,
3440+
/*.ne03 =*/ ne03,
3441+
/*.nb00 =*/ nb00,
3442+
/*.nb01 =*/ nb01,
3443+
/*.nb02 =*/ nb02,
3444+
/*.nb03 =*/ nb03,
3445+
/*.ne0 =*/ ne0,
3446+
/*.ne1 =*/ ne1,
3447+
/*.ne2 =*/ ne2,
3448+
/*.ne3 =*/ ne3,
3449+
/*.nb0 =*/ nb0,
3450+
/*.nb1 =*/ nb1,
3451+
/*.nb2 =*/ nb2,
3452+
/*.nb3 =*/ nb3,
3453+
};
3454+
34323455
[encoder setComputePipelineState:pipeline];
3433-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
3434-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
3435-
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
3436-
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
3437-
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
3438-
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
3439-
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
3440-
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
3441-
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
3442-
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
3443-
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
3444-
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
3445-
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
3446-
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
3447-
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
3448-
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
3449-
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
3450-
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
3456+
[encoder setBytes:&args length:sizeof(args) atIndex:0];
3457+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
3458+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
34513459

34523460
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
34533461
} break;

0 commit comments

Comments
 (0)