Skip to content

Commit 0fcadad

Browse files
author
alexju
committed
metal : refactor pad parameters into a struct
1 parent 619c157 commit 0fcadad

File tree

3 files changed

+46
-39
lines changed

3 files changed

+46
-39
lines changed

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -445,4 +445,23 @@ typedef struct {
445445
float sf3;
446446
} ggml_metal_kargs_upscale;
447447

448+
typedef struct {
449+
int64_t ne00;
450+
int64_t ne01;
451+
int64_t ne02;
452+
int64_t ne03;
453+
uint64_t nb00;
454+
uint64_t nb01;
455+
uint64_t nb02;
456+
uint64_t nb03;
457+
int64_t ne0;
458+
int64_t ne1;
459+
int64_t ne2;
460+
int64_t ne3;
461+
uint64_t nb0;
462+
uint64_t nb1;
463+
uint64_t nb2;
464+
uint64_t nb3;
465+
} ggml_metal_kargs_pad;
466+
448467
#endif // GGML_METAL_IMPL

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3449,26 +3449,29 @@ static void ggml_metal_encode_node(
34493449

34503450
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_F32].pipeline;
34513451

3452-
// TODO: add ggml_metal_kargs struct
3452+
ggml_metal_kargs_pad args = {
3453+
/*.ne00 =*/ ne00,
3454+
/*.ne01 =*/ ne01,
3455+
/*.ne02 =*/ ne02,
3456+
/*.ne03 =*/ ne03,
3457+
/*.nb00 =*/ nb00,
3458+
/*.nb01 =*/ nb01,
3459+
/*.nb02 =*/ nb02,
3460+
/*.nb03 =*/ nb03,
3461+
/*.ne0 =*/ ne0,
3462+
/*.ne1 =*/ ne1,
3463+
/*.ne2 =*/ ne2,
3464+
/*.ne3 =*/ ne3,
3465+
/*.nb0 =*/ nb0,
3466+
/*.nb1 =*/ nb1,
3467+
/*.nb2 =*/ nb2,
3468+
/*.nb3 =*/ nb3
3469+
};
3470+
34533471
[encoder setComputePipelineState:pipeline];
34543472
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
34553473
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
3456-
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
3457-
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
3458-
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
3459-
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
3460-
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
3461-
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
3462-
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
3463-
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
3464-
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
3465-
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
3466-
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
3467-
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
3468-
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
3469-
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
3470-
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
3471-
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
3474+
[encoder setBytes:&args length:sizeof(args) atIndex:2];
34723475

34733476
const int nth = MIN(1024, ne0);
34743477

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 7 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2687,22 +2687,7 @@ kernel void kernel_upscale_f32(
26872687
kernel void kernel_pad_f32(
26882688
device const char * src0,
26892689
device char * dst,
2690-
constant int64_t & ne00,
2691-
constant int64_t & ne01,
2692-
constant int64_t & ne02,
2693-
constant int64_t & ne03,
2694-
constant uint64_t & nb00,
2695-
constant uint64_t & nb01,
2696-
constant uint64_t & nb02,
2697-
constant uint64_t & nb03,
2698-
constant int64_t & ne0,
2699-
constant int64_t & ne1,
2700-
constant int64_t & ne2,
2701-
constant int64_t & ne3,
2702-
constant uint64_t & nb0,
2703-
constant uint64_t & nb1,
2704-
constant uint64_t & nb2,
2705-
constant uint64_t & nb3,
2690+
constant ggml_metal_kargs_pad & args,
27062691
uint3 tgpig[[threadgroup_position_in_grid]],
27072692
uint3 tpitg[[thread_position_in_threadgroup]],
27082693
uint3 ntg[[threads_per_threadgroup]]) {
@@ -2715,12 +2700,12 @@ kernel void kernel_pad_f32(
27152700
const int64_t i02 = i2;
27162701
const int64_t i01 = i1;
27172702

2718-
device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
2719-
device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1);
2703+
device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
2704+
device float * dst_ptr = (device float *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1);
27202705

2721-
if (i1 < ne01 && i2 < ne02 && i3 < ne03) {
2722-
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
2723-
if (i0 < ne00) {
2706+
if (i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) {
2707+
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
2708+
if (i0 < args.ne00) {
27242709
dst_ptr[i0] = src0_ptr[i0];
27252710
} else {
27262711
dst_ptr[i0] = 0.0f;
@@ -2730,7 +2715,7 @@ kernel void kernel_pad_f32(
27302715
return;
27312716
}
27322717

2733-
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
2718+
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
27342719
dst_ptr[i0] = 0.0f;
27352720
}
27362721
}

0 commit comments

Comments
 (0)