Skip to content

Commit 46e1747

Browse files
author
alexju
committed
metal : refactor diag_mask_inf parameters into a struct
1 parent dba23c7 commit 46e1747

File tree

3 files changed

+22
-17
lines changed

3 files changed

+22
-17
lines changed

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -341,4 +341,10 @@ typedef struct {
341341
uint32_t n_head_log2;
342342
} ggml_metal_kargs_soft_max;
343343

344+
// Argument block for the diag_mask_inf Metal kernels. The host code fills one
// instance and binds it with a single setBytes call at buffer index 2; both
// kernels receive it as `constant ggml_metal_kargs_diag_mask_inf & args`.
typedef struct {
345+
// Innermost stride used when the kernels flatten (i02, i01, i00) into an element offset.
int64_t ne00;
346+
// Second extent; ne00*ne01 is the stride between consecutive i02 slices.
int64_t ne01;
347+
// Offset added to the row index i01: elements with i00 > n_past + i01 are written as -INFINITY.
int n_past;
348+
} ggml_metal_kargs_diag_mask_inf;
349+
344350
#endif // GGML_METAL_IMPL

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 7 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -2061,13 +2061,16 @@ static void ggml_metal_encode_node(
20612061
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF].pipeline;
20622062
}
20632063

2064-
// TODO: add ggml_metal_kargs struct
2064+
ggml_metal_kargs_diag_mask_inf args = {
2065+
/*.ne00 =*/ ne00,
2066+
/*.ne01 =*/ ne01,
2067+
/*.n_past =*/ n_past,
2068+
};
2069+
20652070
[encoder setComputePipelineState:pipeline];
20662071
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
20672072
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2068-
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
2069-
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
2070-
[encoder setBytes:&n_past length:sizeof(int) atIndex:4];
2073+
[encoder setBytes:&args length:sizeof(args) atIndex:2];
20712074

20722075
if (ne00%8 == 0) {
20732076
[encoder dispatchThreadgroups:MTLSizeMake(ne00*ne01*ne02/8, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 9 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -1175,43 +1175,39 @@ template [[host_name("kernel_soft_max_f32_4")]] kernel kernel_soft_max_4_t kerne
11751175
// Element-wise diagonal mask: each thread handles the single float at grid
// position (i00, i01, i02). It writes -INFINITY to dst when i00 > n_past + i01
// and otherwise copies the matching element from src0.
// NOTE(review): this is a diff rendering. Lines following a `NNNN-` marker are the
// pre-refactor form (three separate constant parameters); lines following a
// `NNNN+` marker are the replacements that read the same values from `args`.
kernel void kernel_diag_mask_inf(
11761176
device const float * src0,
11771177
device float * dst,
1178-
constant int64_t & ne00,
1179-
constant int64_t & ne01,
1180-
constant int & n_past,
1178+
constant ggml_metal_kargs_diag_mask_inf & args,
11811179
uint3 tpig[[thread_position_in_grid]]) {
11821180
// Grid coordinates map directly to tensor indices: x -> i00, y -> i01, z -> i02.
const int64_t i02 = tpig[2];
11831181
const int64_t i01 = tpig[1];
11841182
const int64_t i00 = tpig[0];
11851183

1186-
if (i00 > n_past + i01) {
1187-
dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
1184+
// Mask condition: column index lies beyond n_past + row index.
if (i00 > args.n_past + i01) {
1185+
// Flat element offset is i02*ne01*ne00 + i01*ne00 + i00.
dst[i02*args.ne01*args.ne00 + i01*args.ne00 + i00] = -INFINITY;
11881186
} else {
1189-
dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00];
1187+
// Unmasked element: pass the source value through at the same flat offset.
dst[i02*args.ne01*args.ne00 + i01*args.ne00 + i00] = src0[i02*args.ne01*args.ne00 + i01*args.ne00 + i00];
11901188
}
11911189
}
11921190

11931191
kernel void kernel_diag_mask_inf_8(
11941192
device const float4 * src0,
11951193
device float4 * dst,
1196-
constant int64_t & ne00,
1197-
constant int64_t & ne01,
1198-
constant int & n_past,
1194+
constant ggml_metal_kargs_diag_mask_inf & args,
11991195
uint3 tpig[[thread_position_in_grid]]) {
12001196

12011197
const int64_t i = 2*tpig[0];
12021198

12031199
dst[i+0] = src0[i+0];
12041200
dst[i+1] = src0[i+1];
12051201
int64_t i4 = 4*i;
1206-
const int64_t i02 = i4/(ne00*ne01); i4 -= i02*ne00*ne01;
1207-
const int64_t i01 = i4/(ne00); i4 -= i01*ne00;
1202+
const int64_t i02 = i4/(args.ne00*args.ne01); i4 -= i02*args.ne00*args.ne01;
1203+
const int64_t i01 = i4/(args.ne00); i4 -= i01*args.ne00;
12081204
const int64_t i00 = i4;
12091205
for (int k = 3; k >= 0; --k) {
1210-
if (i00 + 4 + k <= n_past + i01) {
1206+
if (i00 + 4 + k <= args.n_past + i01) {
12111207
break;
12121208
}
12131209
dst[i+1][k] = -INFINITY;
1214-
if (i00 + k > n_past + i01) {
1210+
if (i00 + k > args.n_past + i01) {
12151211
dst[i][k] = -INFINITY;
12161212
}
12171213
}

0 commit comments

Comments
 (0)