Skip to content

Commit 2cd68ad

Browse files
author
alexju
committed
metal : refactor pad_reflect_1d parameters into a struct
1 parent 0fcadad commit 2cd68ad

File tree

3 files changed

+53
-39
lines changed

3 files changed

+53
-39
lines changed

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -464,4 +464,25 @@ typedef struct {
464464
uint64_t nb3;
465465
} ggml_metal_kargs_pad;
466466

467+
typedef struct {
468+
int64_t ne00;
469+
int64_t ne01;
470+
int64_t ne02;
471+
int64_t ne03;
472+
uint64_t nb00;
473+
uint64_t nb01;
474+
uint64_t nb02;
475+
uint64_t nb03;
476+
int64_t ne0;
477+
int64_t ne1;
478+
int64_t ne2;
479+
int64_t ne3;
480+
uint64_t nb0;
481+
uint64_t nb1;
482+
uint64_t nb2;
483+
uint64_t nb3;
484+
int32_t p0;
485+
int32_t p1;
486+
} ggml_metal_kargs_pad_reflect_1d;
487+
467488
#endif // GGML_METAL_IMPL

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3486,24 +3486,31 @@ static void ggml_metal_encode_node(
34863486

34873487
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32].pipeline;
34883488

3489+
ggml_metal_kargs_pad_reflect_1d args = {
3490+
/*.ne00 =*/ ne00,
3491+
/*.ne01 =*/ ne01,
3492+
/*.ne02 =*/ ne02,
3493+
/*.ne03 =*/ ne03,
3494+
/*.nb00 =*/ nb00,
3495+
/*.nb01 =*/ nb01,
3496+
/*.nb02 =*/ nb02,
3497+
/*.nb03 =*/ nb03,
3498+
/*.ne0 =*/ ne0,
3499+
/*.ne1 =*/ ne1,
3500+
/*.ne2 =*/ ne2,
3501+
/*.ne3 =*/ ne3,
3502+
/*.nb0 =*/ nb0,
3503+
/*.nb1 =*/ nb1,
3504+
/*.nb2 =*/ nb2,
3505+
/*.nb3 =*/ nb3,
3506+
/*.p0 =*/ p0,
3507+
/*.p1 =*/ p1
3508+
};
3509+
34893510
[encoder setComputePipelineState:pipeline];
34903511
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
34913512
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
3492-
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
3493-
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
3494-
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
3495-
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
3496-
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:6];
3497-
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
3498-
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
3499-
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
3500-
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
3501-
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:11];
3502-
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:12];
3503-
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:13];
3504-
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:14];
3505-
[encoder setBytes:&p0 length:sizeof(p0) atIndex:15];
3506-
[encoder setBytes:&p1 length:sizeof(p1) atIndex:16];
3513+
[encoder setBytes:&args length:sizeof(args) atIndex:2];
35073514

35083515
const int nth = MIN(1024, ne0);
35093516

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 10 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2723,21 +2723,7 @@ kernel void kernel_pad_f32(
27232723
kernel void kernel_pad_reflect_1d_f32(
27242724
device const char * src0,
27252725
device char * dst,
2726-
constant int64_t & ne00,
2727-
constant int64_t & ne01,
2728-
constant int64_t & ne02,
2729-
constant int64_t & ne03,
2730-
constant int64_t & ne0,
2731-
constant uint64_t & nb00,
2732-
constant uint64_t & nb01,
2733-
constant uint64_t & nb02,
2734-
constant uint64_t & nb03,
2735-
constant uint64_t & nb0,
2736-
constant uint64_t & nb1,
2737-
constant uint64_t & nb2,
2738-
constant uint64_t & nb3,
2739-
constant int32_t & p0,
2740-
constant int32_t & p1,
2726+
constant ggml_metal_kargs_pad_reflect_1d & args,
27412727
uint3 tgpig[[threadgroup_position_in_grid]],
27422728
uint3 tgpg[[threadgroups_per_grid]],
27432729
uint3 tpitg[[thread_position_in_threadgroup]],
@@ -2751,17 +2737,17 @@ kernel void kernel_pad_reflect_1d_f32(
27512737
const int64_t i02 = i2;
27522738
const int64_t i01 = i1;
27532739

2754-
device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
2755-
device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1);
2740+
device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
2741+
device float * dst_ptr = (device float *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1);
27562742

2757-
if (i1 < ne01 && i2 < ne02 && i3 < ne03) {
2758-
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
2759-
if (i0 < p0) {
2760-
dst_ptr[i0] = src0_ptr[p0 - i0];
2761-
} else if (i0 < ne0 - p1) {
2762-
dst_ptr[i0] = src0_ptr[i0 - p0];
2743+
if (i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) {
2744+
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
2745+
if (i0 < args.p0) {
2746+
dst_ptr[i0] = src0_ptr[args.p0 - i0];
2747+
} else if (i0 < args.ne0 - args.p1) {
2748+
dst_ptr[i0] = src0_ptr[i0 - args.p0];
27632749
} else {
2764-
dst_ptr[i0] = src0_ptr[(ne0 - p1 - p0) - (p1 + 1 - (ne0 - i0)) - 1];
2750+
dst_ptr[i0] = src0_ptr[(args.ne0 - args.p1 - args.p0) - (args.p1 + 1 - (args.ne0 - i0)) - 1];
27652751
}
27662752
}
27672753
}

0 commit comments

Comments
 (0)