Skip to content

Commit 619c157

Browse files
author
alexju
committed
metal : refactor upscale parameters into a struct
1 parent dc0e6ce commit 619c157

File tree

3 files changed

+55
-48
lines changed

3 files changed

+55
-48
lines changed

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -422,4 +422,27 @@ typedef struct {
422422
uint64_t nb2;
423423
} ggml_metal_kargs_get_rows;
424424

425+
typedef struct {
426+
int64_t ne00;
427+
int64_t ne01;
428+
int64_t ne02;
429+
int64_t ne03;
430+
uint64_t nb00;
431+
uint64_t nb01;
432+
uint64_t nb02;
433+
uint64_t nb03;
434+
int64_t ne0;
435+
int64_t ne1;
436+
int64_t ne2;
437+
int64_t ne3;
438+
uint64_t nb0;
439+
uint64_t nb1;
440+
uint64_t nb2;
441+
uint64_t nb3;
442+
float sf0;
443+
float sf1;
444+
float sf2;
445+
float sf3;
446+
} ggml_metal_kargs_upscale;
447+
425448
#endif // GGML_METAL_IMPL

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 24 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3411,30 +3411,33 @@ static void ggml_metal_encode_node(
34113411

34123412
const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline;
34133413

3414-
// TODO: add ggml_metal_kargs struct
3414+
ggml_metal_kargs_upscale args = {
3415+
/*.ne00 =*/ ne00,
3416+
/*.ne01 =*/ ne01,
3417+
/*.ne02 =*/ ne02,
3418+
/*.ne03 =*/ ne03,
3419+
/*.nb00 =*/ nb00,
3420+
/*.nb01 =*/ nb01,
3421+
/*.nb02 =*/ nb02,
3422+
/*.nb03 =*/ nb03,
3423+
/*.ne0 =*/ ne0,
3424+
/*.ne1 =*/ ne1,
3425+
/*.ne2 =*/ ne2,
3426+
/*.ne3 =*/ ne3,
3427+
/*.nb0 =*/ nb0,
3428+
/*.nb1 =*/ nb1,
3429+
/*.nb2 =*/ nb2,
3430+
/*.nb3 =*/ nb3,
3431+
/*.sf0 =*/ sf0,
3432+
/*.sf1 =*/ sf1,
3433+
/*.sf2 =*/ sf2,
3434+
/*.sf3 =*/ sf3
3435+
};
3436+
34153437
[encoder setComputePipelineState:pipeline];
34163438
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
34173439
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
3418-
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
3419-
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
3420-
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
3421-
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
3422-
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
3423-
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
3424-
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
3425-
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
3426-
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
3427-
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
3428-
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
3429-
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
3430-
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
3431-
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
3432-
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
3433-
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
3434-
[encoder setBytes:&sf0 length:sizeof(sf0) atIndex:18];
3435-
[encoder setBytes:&sf1 length:sizeof(sf1) atIndex:19];
3436-
[encoder setBytes:&sf2 length:sizeof(sf2) atIndex:20];
3437-
[encoder setBytes:&sf3 length:sizeof(sf3) atIndex:21];
3440+
[encoder setBytes:&args length:sizeof(args) atIndex:2];
34383441

34393442
const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
34403443

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 8 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -2661,26 +2661,7 @@ kernel void kernel_conv_transpose_1d<half>(
26612661
kernel void kernel_upscale_f32(
26622662
device const char * src0,
26632663
device char * dst,
2664-
constant int64_t & ne00,
2665-
constant int64_t & ne01,
2666-
constant int64_t & ne02,
2667-
constant int64_t & ne03,
2668-
constant uint64_t & nb00,
2669-
constant uint64_t & nb01,
2670-
constant uint64_t & nb02,
2671-
constant uint64_t & nb03,
2672-
constant int64_t & ne0,
2673-
constant int64_t & ne1,
2674-
constant int64_t & ne2,
2675-
constant int64_t & ne3,
2676-
constant uint64_t & nb0,
2677-
constant uint64_t & nb1,
2678-
constant uint64_t & nb2,
2679-
constant uint64_t & nb3,
2680-
constant float & sf0,
2681-
constant float & sf1,
2682-
constant float & sf2,
2683-
constant float & sf3,
2664+
constant ggml_metal_kargs_upscale & args,
26842665
uint3 tgpig[[threadgroup_position_in_grid]],
26852666
uint3 tpitg[[thread_position_in_threadgroup]],
26862667
uint3 ntg[[threads_per_threadgroup]]) {
@@ -2689,15 +2670,15 @@ kernel void kernel_upscale_f32(
26892670
const int64_t i2 = tgpig.y;
26902671
const int64_t i1 = tgpig.x;
26912672

2692-
const int64_t i03 = i3/sf3;
2693-
const int64_t i02 = i2/sf2;
2694-
const int64_t i01 = i1/sf1;
2673+
const int64_t i03 = i3/args.sf3;
2674+
const int64_t i02 = i2/args.sf2;
2675+
const int64_t i01 = i1/args.sf1;
26952676

2696-
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
2697-
const int64_t i00 = i0/sf0;
2677+
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
2678+
const int64_t i00 = i0/args.sf0;
26982679

2699-
device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
2700-
device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
2680+
device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
2681+
device float * dst_ptr = (device float *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
27012682

27022683
dst_ptr[0] = src0_ptr[0];
27032684
}

0 commit comments

Comments
 (0)