Skip to content

Commit 0f0f28f

Browse files
author
alexju
committed
metal : refactor ssm_conv parameters into a struct
1 parent 46e1747 commit 0f0f28f

File tree

3 files changed

+48
-42
lines changed

3 files changed

+48
-42
lines changed

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,4 +347,23 @@ typedef struct {
347347
int n_past;
348348
} ggml_metal_kargs_diag_mask_inf;
349349

350+
typedef struct {
351+
int64_t ne00;
352+
int64_t ne01;
353+
int64_t ne02;
354+
uint64_t nb00;
355+
uint64_t nb01;
356+
uint64_t nb02;
357+
int64_t ne10;
358+
int64_t ne11;
359+
uint64_t nb10;
360+
uint64_t nb11;
361+
int64_t ne0;
362+
int64_t ne1;
363+
int64_t ne2;
364+
uint64_t nb0;
365+
uint64_t nb1;
366+
uint64_t nb2;
367+
} ggml_metal_kargs_ssm_conv;
368+
350369
#endif // GGML_METAL_IMPL

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2089,27 +2089,30 @@ static void ggml_metal_encode_node(
20892089

20902090
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_CONV_F32].pipeline;
20912091

2092-
// TODO: add ggml_metal_kargs struct
2092+
ggml_metal_kargs_ssm_conv args = {
2093+
/*.ne00 =*/ ne00,
2094+
/*.ne01 =*/ ne01,
2095+
/*.ne02 =*/ ne02,
2096+
/*.nb00 =*/ nb00,
2097+
/*.nb01 =*/ nb01,
2098+
/*.nb02 =*/ nb02,
2099+
/*.ne10 =*/ ne10,
2100+
/*.ne11 =*/ ne11,
2101+
/*.nb10 =*/ nb10,
2102+
/*.nb11 =*/ nb11,
2103+
/*.ne0 =*/ ne0,
2104+
/*.ne1 =*/ ne1,
2105+
/*.ne2 =*/ ne2,
2106+
/*.nb0 =*/ nb0,
2107+
/*.nb1 =*/ nb1,
2108+
/*.nb2 =*/ nb2,
2109+
};
2110+
20932111
[encoder setComputePipelineState:pipeline];
20942112
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
20952113
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
20962114
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
2097-
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
2098-
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
2099-
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
2100-
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
2101-
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
2102-
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
2103-
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9];
2104-
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10];
2105-
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:11];
2106-
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:12];
2107-
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
2108-
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
2109-
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:15];
2110-
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:16];
2111-
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:17];
2112-
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:18];
2115+
[encoder setBytes:&args length:sizeof(args) atIndex:3];
21132116

21142117
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne1, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
21152118
} break;

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 9 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1214,43 +1214,27 @@ kernel void kernel_diag_mask_inf_8(
12141214
}
12151215

12161216
// ref: ggml.c:ggml_compute_forward_ssm_conv_f32
1217-
// TODO: optimize
12181217
kernel void kernel_ssm_conv_f32(
12191218
device const void * src0,
12201219
device const void * src1,
12211220
device float * dst,
1222-
constant int64_t & ne00,
1223-
constant int64_t & ne01,
1224-
constant int64_t & ne02,
1225-
constant uint64_t & nb00,
1226-
constant uint64_t & nb01,
1227-
constant uint64_t & nb02,
1228-
constant int64_t & ne10,
1229-
constant int64_t & ne11,
1230-
constant uint64_t & nb10,
1231-
constant uint64_t & nb11,
1232-
constant int64_t & ne0,
1233-
constant int64_t & ne1,
1234-
constant int64_t & ne2,
1235-
constant uint64_t & nb0,
1236-
constant uint64_t & nb1,
1237-
constant uint64_t & nb2,
1221+
constant ggml_metal_kargs_ssm_conv & args,
12381222
uint3 tgpig[[threadgroup_position_in_grid]],
12391223
uint3 tpitg[[thread_position_in_threadgroup]],
12401224
uint3 ntg[[threads_per_threadgroup]]) {
12411225
const int64_t ir = tgpig.x;
12421226
const int64_t i2 = tgpig.y;
12431227
const int64_t i3 = tgpig.z;
12441228

1245-
const int64_t nc = ne10;
1246-
//const int64_t ncs = ne00;
1247-
//const int64_t nr = ne01;
1248-
//const int64_t n_t = ne1;
1249-
//const int64_t n_s = ne2;
1229+
const int64_t nc = args.ne10;
1230+
//const int64_t ncs = args.ne00;
1231+
//const int64_t nr = args.ne01;
1232+
//const int64_t n_t = args.ne1;
1233+
//const int64_t n_s = args.ne2;
12501234

1251-
device const float * s = (device const float *) ((device const char *) src0 + ir*nb01 + i2*nb00 + i3*nb02);
1252-
device const float * c = (device const float *) ((device const char *) src1 + ir*nb11);
1253-
device float * x = (device float *) ((device char *) dst + ir*nb0 + i2*nb1 + i3*nb2);
1235+
device const float * s = (device const float *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02);
1236+
device const float * c = (device const float *) ((device const char *) src1 + ir*args.nb11);
1237+
device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2);
12541238

12551239
float sumf = 0.0f;
12561240

0 commit comments

Comments
 (0)