Skip to content

Commit 678b2e7

Browse files
author
alexju
committed
metal : refactor ssm_scan parameters into a struct
1 parent 0f0f28f commit 678b2e7

File tree

3 files changed

+64
-60
lines changed

3 files changed

+64
-60
lines changed

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,4 +366,29 @@ typedef struct {
366366
uint64_t nb2;
367367
} ggml_metal_kargs_ssm_conv;
368368

369+
typedef struct {
370+
int64_t d_state;
371+
int64_t d_inner;
372+
int64_t n_seq_tokens;
373+
int64_t n_seqs;
374+
uint64_t nb00;
375+
uint64_t nb01;
376+
uint64_t nb02;
377+
uint64_t nb10;
378+
uint64_t nb11;
379+
uint64_t nb12;
380+
uint64_t nb13;
381+
uint64_t nb20;
382+
uint64_t nb21;
383+
uint64_t nb22;
384+
uint64_t nb30;
385+
uint64_t nb31;
386+
uint64_t nb40;
387+
uint64_t nb41;
388+
uint64_t nb42;
389+
uint64_t nb50;
390+
uint64_t nb51;
391+
uint64_t nb52;
392+
} ggml_metal_kargs_ssm_scan;
393+
369394
#endif // GGML_METAL_IMPL

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 26 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -2163,7 +2163,31 @@ static void ggml_metal_encode_node(
21632163

21642164
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline;
21652165

2166-
// TODO: add ggml_metal_kargs struct
2166+
ggml_metal_kargs_ssm_scan args = {
2167+
/*.d_state =*/ d_state,
2168+
/*.d_inner =*/ d_inner,
2169+
/*.n_seq_tokens =*/ n_seq_tokens,
2170+
/*.n_seqs =*/ n_seqs,
2171+
/*.nb00 =*/ nb00,
2172+
/*.nb01 =*/ nb01,
2173+
/*.nb02 =*/ nb02,
2174+
/*.nb10 =*/ nb10,
2175+
/*.nb11 =*/ nb11,
2176+
/*.nb12 =*/ nb12,
2177+
/*.nb13 =*/ nb13,
2178+
/*.nb20 =*/ nb20,
2179+
/*.nb21 =*/ nb21,
2180+
/*.nb22 =*/ nb22,
2181+
/*.nb30 =*/ nb30,
2182+
/*.nb31 =*/ nb31,
2183+
/*.nb40 =*/ nb40,
2184+
/*.nb41 =*/ nb41,
2185+
/*.nb42 =*/ nb42,
2186+
/*.nb50 =*/ nb50,
2187+
/*.nb51 =*/ nb51,
2188+
/*.nb52 =*/ nb52,
2189+
};
2190+
21672191
[encoder setComputePipelineState:pipeline];
21682192
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
21692193
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
@@ -2172,30 +2196,7 @@ static void ggml_metal_encode_node(
21722196
[encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
21732197
[encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
21742198
[encoder setBuffer:id_dst offset:offs_dst atIndex:6];
2175-
2176-
[encoder setBytes:&d_state length:sizeof(d_state) atIndex:7];
2177-
[encoder setBytes:&d_inner length:sizeof(d_inner) atIndex:8];
2178-
[encoder setBytes:&n_seq_tokens length:sizeof(n_seq_tokens) atIndex:9];
2179-
[encoder setBytes:&n_seqs length:sizeof(n_seqs) atIndex:10];
2180-
2181-
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:11];
2182-
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:12];
2183-
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:13];
2184-
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
2185-
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
2186-
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
2187-
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
2188-
[encoder setBytes:&nb20 length:sizeof(nb20) atIndex:18];
2189-
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:19];
2190-
[encoder setBytes:&nb22 length:sizeof(nb22) atIndex:20];
2191-
[encoder setBytes:&nb30 length:sizeof(nb30) atIndex:21];
2192-
[encoder setBytes:&nb31 length:sizeof(nb31) atIndex:22];
2193-
[encoder setBytes:&nb40 length:sizeof(nb40) atIndex:23];
2194-
[encoder setBytes:&nb41 length:sizeof(nb41) atIndex:24];
2195-
[encoder setBytes:&nb42 length:sizeof(nb42) atIndex:25];
2196-
[encoder setBytes:&nb50 length:sizeof(nb50) atIndex:26];
2197-
[encoder setBytes:&nb51 length:sizeof(nb51) atIndex:27];
2198-
[encoder setBytes:&nb52 length:sizeof(nb52) atIndex:28];
2199+
[encoder setBytes:&args length:sizeof(args) atIndex:7];
21992200

22002201
[encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
22012202
} break;

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 13 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1246,7 +1246,6 @@ kernel void kernel_ssm_conv_f32(
12461246
}
12471247

12481248
// ref: ggml.c:ggml_compute_forward_ssm_scan_f32
1249-
// TODO: optimize
12501249
kernel void kernel_ssm_scan_f32(
12511250
device const void * src0,
12521251
device const void * src1,
@@ -1255,48 +1254,27 @@ kernel void kernel_ssm_scan_f32(
12551254
device const void * src4,
12561255
device const void * src5,
12571256
device float * dst,
1258-
constant int64_t & d_state,
1259-
constant int64_t & d_inner,
1260-
constant int64_t & n_seq_tokens,
1261-
constant int64_t & n_seqs,
1262-
constant uint64_t & nb00,
1263-
constant uint64_t & nb01,
1264-
constant uint64_t & nb02,
1265-
constant uint64_t & nb10,
1266-
constant uint64_t & nb11,
1267-
constant uint64_t & nb12,
1268-
constant uint64_t & nb13,
1269-
constant uint64_t & nb20,
1270-
constant uint64_t & nb21,
1271-
constant uint64_t & nb22,
1272-
constant uint64_t & nb30,
1273-
constant uint64_t & nb31,
1274-
constant uint64_t & nb40,
1275-
constant uint64_t & nb41,
1276-
constant uint64_t & nb42,
1277-
constant uint64_t & nb50,
1278-
constant uint64_t & nb51,
1279-
constant uint64_t & nb52,
1257+
constant ggml_metal_kargs_ssm_scan & args,
12801258
uint3 tgpig[[threadgroup_position_in_grid]],
12811259
uint3 tpitg[[thread_position_in_threadgroup]],
12821260
uint3 ntg[[threads_per_threadgroup]]) {
12831261
const int64_t ir = tgpig.x;
12841262
const int64_t i3 = tgpig.y;
12851263

1286-
const int64_t nc = d_state;
1287-
//const int64_t nr = d_inner;
1288-
const int64_t n_t = n_seq_tokens;
1289-
//const int64_t n_s = n_seqs;
1264+
const int64_t nc = args.d_state;
1265+
// const int64_t nr = args.d_inner;
1266+
const int64_t n_t = args.n_seq_tokens;
1267+
// const int64_t n_s = args.n_seqs;
12901268

12911269
for (int64_t i2 = 0; i2 < n_t; ++i2) {
1292-
device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb01 + i3*nb02);
1293-
device const float * x = (device const float *) ((device const char *) src1 + ir*nb10 + i2*nb11 + i3*nb12);
1294-
device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*nb21 + i3*nb22);
1295-
device const float * A = (device const float *) ((device const char *) src3 + ir*nb31);
1296-
device const float * B = (device const float *) ((device const char *) src4 + i2*nb41 + i3*nb42);
1297-
device const float * C = (device const float *) ((device const char *) src5 + i2*nb51 + i3*nb52);
1298-
device float * y = (device float *) ((device char *) dst + ir*nb10 + i2*nb11 + i3*nb12); // TODO: do not use src1 strides
1299-
device float * s = (device float *) ((device char *) dst + ir*nb01 + i3*nb02 + nb13);
1270+
device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb01 + i3*args.nb02);
1271+
device const float * x = (device const float *) ((device const char *) src1 + ir*args.nb10 + i2*args.nb11 + i3*args.nb12);
1272+
device const float * dt = (device const float *) ((device const char *) src2 + ir*args.nb20 + i2*args.nb21 + i3*args.nb22);
1273+
device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31);
1274+
device const float * B = (device const float *) ((device const char *) src4 + i2*args.nb41 + i3*args.nb42);
1275+
device const float * C = (device const float *) ((device const char *) src5 + i2*args.nb51 + i3*args.nb52);
1276+
device float * y = (device float *) ((device char *) dst + ir*args.nb10 + i2*args.nb11 + i3*args.nb12); // TODO: do not use src1 strides
1277+
device float * s = (device float *) ((device char *) dst + ir*args.nb01 + i3*args.nb02 + args.nb13);
13001278

13011279
if (i2 > 0) {
13021280
s0 = s;

0 commit comments

Comments
 (0)