Skip to content

Commit dc0e6ce

Browse files
author
alexju
committed
metal : refactor conv_transpose_1d parameters into a struct
1 parent 08bad56 commit dc0e6ce

File tree

3 files changed

+30
-37
lines changed

3 files changed

+30
-37
lines changed

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,15 @@ typedef struct {
296296
float eps;
297297
} ggml_metal_kargs_group_norm;
298298

299+
typedef struct {
300+
int32_t IC;
301+
int32_t IL;
302+
int32_t K;
303+
int32_t s0;
304+
uint64_t nb0;
305+
uint64_t nb1;
306+
} ggml_metal_kargs_conv_transpose_1d;
307+
299308
typedef struct {
300309
uint64_t ofs0;
301310
uint64_t ofs1;

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3383,16 +3383,20 @@ static void ggml_metal_encode_node(
33833383
default: GGML_ABORT("fatal error");
33843384
};
33853385

3386+
ggml_metal_kargs_conv_transpose_1d args = {
3387+
/*.IC =*/ IC,
3388+
/*.IL =*/ IL,
3389+
/*.K =*/ K,
3390+
/*.s0 =*/ s0,
3391+
/*.nb0 =*/ nb0,
3392+
/*.nb1 =*/ nb1,
3393+
};
3394+
33863395
[encoder setComputePipelineState:pipeline];
33873396
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
33883397
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
33893398
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
3390-
[encoder setBytes:&IC length:sizeof( int32_t) atIndex:3];
3391-
[encoder setBytes:&IL length:sizeof( int32_t) atIndex:4];
3392-
[encoder setBytes:&K length:sizeof( int32_t) atIndex:5];
3393-
[encoder setBytes:&s0 length:sizeof( int32_t) atIndex:6];
3394-
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:7];
3395-
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:8];
3399+
[encoder setBytes:&args length:sizeof(args) atIndex:3];
33963400

33973401
[encoder dispatchThreadgroups:MTLSizeMake(OL, OC, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
33983402
} break;

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 11 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -2609,12 +2609,7 @@ typedef void (conv_transpose_1d_t)(
26092609
device const float * src0,
26102610
device const float * src1,
26112611
device char * dst,
2612-
constant int32_t & IC,
2613-
constant int32_t & IL,
2614-
constant int32_t & K,
2615-
constant int32_t & s0,
2616-
constant uint64_t & nb0,
2617-
constant uint64_t & nb1,
2612+
constant ggml_metal_kargs_conv_transpose_1d & args,
26182613
uint3 tgpig[[threadgroup_position_in_grid]],
26192614
uint3 tgpg[[threadgroups_per_grid]]);
26202615

@@ -2623,29 +2618,24 @@ kernel void kernel_conv_transpose_1d(
26232618
device const T * src0,
26242619
device const float * src1,
26252620
device char * dst,
2626-
constant int32_t & IC,
2627-
constant int32_t & IL,
2628-
constant int32_t & K,
2629-
constant int32_t & s0,
2630-
constant uint64_t & nb0,
2631-
constant uint64_t & nb1,
2621+
constant ggml_metal_kargs_conv_transpose_1d & args,
26322622
uint3 tgpig[[threadgroup_position_in_grid]],
26332623
uint3 tgpg[[threadgroups_per_grid]]) {
26342624

26352625
float v = 0.0f;
26362626

2637-
for (int64_t c = 0; c < IC; c++) {
2638-
const int32_t kernel_offset = c * tgpg[1] * K + K * tgpig[1];
2639-
const int32_t input_offset = c * IL;
2627+
for (int64_t c = 0; c < args.IC; c++) {
2628+
const int32_t kernel_offset = c * tgpg[1] * args.K + args.K * tgpig[1];
2629+
const int32_t input_offset = c * args.IL;
26402630

2641-
for (int64_t i = 0; i < IL; i++) {
2642-
if (tgpig[0] >= i * s0 && tgpig[0] < i * s0 + K) {
2643-
v += src0[kernel_offset + tgpig[0] - i * s0] * src1[input_offset + i];
2631+
for (int64_t i = 0; i < args.IL; i++) {
2632+
if (tgpig[0] >= i * args.s0 && tgpig[0] < i * args.s0 + args.K) {
2633+
v += src0[kernel_offset + tgpig[0] - i * args.s0] * src1[input_offset + i];
26442634
}
26452635
}
26462636
}
26472637

2648-
device float * dst_ptr = (device float *) (dst + tgpig[0] * nb0 + tgpig[1] * nb1);
2638+
device float * dst_ptr = (device float *) (dst + tgpig[0] * args.nb0 + tgpig[1] * args.nb1);
26492639

26502640
dst_ptr[0] = v;
26512641
}
@@ -2655,12 +2645,7 @@ kernel void kernel_conv_transpose_1d<float>(
26552645
device const float * src0,
26562646
device const float * src1,
26572647
device char * dst,
2658-
constant int32_t & IC,
2659-
constant int32_t & IL,
2660-
constant int32_t & K,
2661-
constant int32_t & s0,
2662-
constant uint64_t & nb0,
2663-
constant uint64_t & nb1,
2648+
constant ggml_metal_kargs_conv_transpose_1d & args,
26642649
uint3 tgpig[[threadgroup_position_in_grid]],
26652650
uint3 tgpg[[threadgroups_per_grid]]);
26662651

@@ -2669,12 +2654,7 @@ kernel void kernel_conv_transpose_1d<half>(
26692654
device const half * src0,
26702655
device const float * src1,
26712656
device char * dst,
2672-
constant int32_t & IC,
2673-
constant int32_t & IL,
2674-
constant int32_t & K,
2675-
constant int32_t & s0,
2676-
constant uint64_t & nb0,
2677-
constant uint64_t & nb1,
2657+
constant ggml_metal_kargs_conv_transpose_1d & args,
26782658
uint3 tgpig[[threadgroup_position_in_grid]],
26792659
uint3 tgpg[[threadgroups_per_grid]]);
26802660

0 commit comments

Comments
 (0)