Skip to content

Commit b9b25e2

Browse files
committed
metal : ssm_scan opt
1 parent fcebb3c commit b9b25e2

File tree

4 files changed

+70
-69
lines changed

4 files changed

+70
-69
lines changed

ggml/src/ggml-metal/ggml-metal-device.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -363,8 +363,10 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan(ggml_metal_librar
363363
char base[256];
364364
char name[256];
365365

366+
const int nsg = (ne00 + 31)/32;
367+
366368
snprintf(base, 256, "kernel_ssm_scan_%s", ggml_type_name(op->src[0]->type));
367-
snprintf(name, 256, "%s", base);
369+
snprintf(name, 256, "%s_nsg=%d", base, nsg);
368370

369371
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
370372
if (res) {
@@ -373,7 +375,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan(ggml_metal_librar
373375

374376
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
375377

376-
ggml_metal_pipeline_set_smem(res, 32*sizeof(float));
378+
ggml_metal_pipeline_set_smem(res, 32*sizeof(float)*nsg);
377379

378380
return res;
379381
}

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -573,22 +573,30 @@ typedef struct {
573573
int64_t n_seq_tokens;
574574
int64_t n_seqs;
575575
uint64_t s_off;
576+
uint64_t nb00;
576577
uint64_t nb01;
577578
uint64_t nb02;
578579
uint64_t nb03;
580+
uint64_t nb10;
579581
uint64_t nb11;
580582
uint64_t nb12;
583+
uint64_t ns12;
581584
uint64_t nb13;
585+
uint64_t nb20;
582586
uint64_t nb21;
587+
uint64_t ns21;
583588
uint64_t nb22;
584589
int64_t ne30;
585590
uint64_t nb31;
586591
uint64_t nb41;
587592
uint64_t nb42;
593+
uint64_t ns42;
588594
uint64_t nb43;
589595
uint64_t nb51;
590596
uint64_t nb52;
597+
uint64_t ns52;
591598
uint64_t nb53;
599+
uint64_t nb0;
592600
} ggml_metal_kargs_ssm_scan;
593601

594602
typedef struct {

ggml/src/ggml-metal/ggml-metal-ops.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1181,22 +1181,30 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) {
11811181
/*.n_seq_tokens =*/ n_seq_tokens,
11821182
/*.n_seqs =*/ n_seqs,
11831183
/*.s_off =*/ ggml_nelements(op->src[1]) * sizeof(float),
1184+
/*.nb00 =*/ nb00,
11841185
/*.nb01 =*/ nb01,
11851186
/*.nb02 =*/ nb02,
11861187
/*.nb03 =*/ nb03,
1188+
/*.nb10 =*/ nb10,
11871189
/*.nb11 =*/ nb11,
11881190
/*.nb12 =*/ nb12,
1191+
/*.ns12 =*/ nb12/nb10,
11891192
/*.nb13 =*/ nb13,
1193+
/*.nb20 =*/ nb20,
11901194
/*.nb21 =*/ nb21,
1195+
/*.ns21 =*/ nb21/nb20,
11911196
/*.nb22 =*/ nb22,
11921197
/*.ne30 =*/ ne30,
11931198
/*.nb31 =*/ nb31,
11941199
/*.nb41 =*/ nb41,
11951200
/*.nb42 =*/ nb42,
1201+
/*.ns42 =*/ nb42/nb40,
11961202
/*.nb43 =*/ nb43,
11971203
/*.nb51 =*/ nb51,
11981204
/*.nb52 =*/ nb52,
1205+
/*.ns52 =*/ nb52/nb50,
11991206
/*.nb53 =*/ nb53,
1207+
/*.nb0 =*/ nb0,
12001208
};
12011209

12021210
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_ssm_scan(lib, op);

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 50 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -2075,105 +2075,88 @@ kernel void kernel_ssm_scan_f32(
20752075
device const void * src6,
20762076
device float * dst,
20772077
threadgroup float * shared [[threadgroup(0)]],
2078-
uint3 tgpig[[threadgroup_position_in_grid]],
2079-
uint3 tpitg[[thread_position_in_threadgroup]],
2080-
ushort sgitg[[simdgroup_index_in_threadgroup]],
2081-
ushort tiisg[[thread_index_in_simdgroup]],
2082-
ushort sgptg[[simdgroups_per_threadgroup]],
2083-
uint3 tgpg[[threadgroups_per_grid]]) {
2078+
uint3 tgpig[[threadgroup_position_in_grid]],
2079+
ushort3 tpitg[[thread_position_in_threadgroup]],
2080+
ushort sgitg[[simdgroup_index_in_threadgroup]],
2081+
ushort tiisg[[thread_index_in_simdgroup]],
2082+
ushort sgptg[[simdgroups_per_threadgroup]],
2083+
uint3 tgpg[[threadgroups_per_grid]]) {
2084+
constexpr short NW = N_SIMDWIDTH;
20842085

20852086
shared[tpitg.x] = 0.0f;
20862087

2087-
const int64_t i0 = tpitg.x;
2088-
const int64_t i1 = tgpig.x;
2089-
const int64_t ir = tgpig.y; // current head
2090-
const int64_t i3 = tgpig.z; // current seq
2091-
2092-
const uint64_t nb00 = sizeof(float);
2093-
const uint64_t nb10 = sizeof(float);
2094-
const uint64_t nb20 = sizeof(float);
2088+
const int32_t i0 = tpitg.x;
2089+
const int32_t i1 = tgpig.x;
2090+
const int32_t ir = tgpig.y; // current head
2091+
const int32_t i3 = tgpig.z; // current seq
20952092

2096-
const int64_t nc = args.d_state;
2097-
const int64_t nr = args.d_inner;
2098-
const int64_t nh = args.n_head;
2099-
const int64_t ng = args.n_group;
2100-
const int64_t n_t = args.n_seq_tokens;
2093+
const int32_t nc = args.d_state;
2094+
const int32_t nr = args.d_inner;
2095+
const int32_t nh = args.n_head;
2096+
const int32_t ng = args.n_group;
2097+
const int32_t n_t = args.n_seq_tokens;
21012098

2102-
const int64_t s_off = args.s_off;
2099+
const int32_t s_off = args.s_off;
21032100

21042101
device const int32_t * ids = (device const int32_t *) src6;
21052102

21062103
device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
21072104
device float * s_buff = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
2108-
const int64_t i = i0 + i1*nc;
2109-
const int64_t g = ir / (nh / ng); // repeat_interleave
2105+
2106+
const int32_t i = i0 + i1*nc;
2107+
const int32_t g = ir / (nh / ng); // repeat_interleave
2108+
21102109
float s0 = s0_buff[i];
21112110
float s = 0.0f;
21122111

21132112
device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {ne30, nh}
2113+
21142114
const float A0 = A[i0%args.ne30];
21152115

2116-
device const char * x_block = ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13);
2117-
device const char * dt_block = ((device const char *) src2 + ir*nb20 + i3*args.nb22);
2118-
device const char * B_block = ((device const char *) src4 + g*args.nb41 + i3*args.nb43);
2119-
device const char * C_block = ((device const char *) src5 + g*args.nb51 + i3*args.nb53);
2120-
device char * y_block = ((device char *) dst + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00);
2116+
device const float * x = (device const float *)((device const char *) src1 + i1*args.nb10 + ir*args.nb11 + i3*args.nb13); // {dim, nh, nt, ns}
2117+
device const float * dt = (device const float *)((device const char *) src2 + ir*args.nb20 + i3*args.nb22); // {nh, nt, ns}
2118+
device const float * B = (device const float *)((device const char *) src4 + g*args.nb41 + i3*args.nb43); // {d_state, ng, nt, ns}
2119+
device const float * C = (device const float *)((device const char *) src5 + g*args.nb51 + i3*args.nb53); // {d_state, ng, nt, ns}
21212120

2122-
threadgroup_barrier(mem_flags::mem_threadgroup);
2121+
device float * y = dst + (i1 + ir*(nr) + i3*(n_t*nh*nr)); // {dim, nh, nt, ns}
21232122

2124-
for (int64_t i2 = 0; i2 < n_t; ++i2) {
2125-
device const float * x = (device const float *) (x_block + i2*args.nb12); // {dim, nh, nt, ns}
2126-
device const float * dt = (device const float *) (dt_block + i2*args.nb21); // {nh, nt, ns}
2127-
device const float * B = (device const float *) (B_block + i2*args.nb42); // {d_state, ng, nt, ns}
2128-
device const float * C = (device const float *) (C_block + i2*args.nb52); // {d_state, ng, nt, ns}
2129-
device float * y = (device float *) (y_block + i2*(nh*nr*nb00)); // {dim, nh, nt, ns}
2123+
for (int i2 = 0; i2 < n_t; i2 += sgptg) {
2124+
threadgroup_barrier(mem_flags::mem_threadgroup);
21302125

2131-
const float dt0 = dt[0];
2132-
const float dt_soft_plus = dt0 <= 20.0f ? log(1.0f + exp(dt0)) : dt0;
2133-
const float x_dt = x[0] * dt_soft_plus;
2134-
const float dA = exp(dt_soft_plus * A0);
2126+
for (int t = 0; t < sgptg && i2 + t < n_t; t++) {
2127+
const float dt0 = dt[0];
2128+
const float dtsp = dt0 <= 20.0f ? log(1.0f + exp(dt0)) : dt0;
2129+
const float x_dt = x[0] * dtsp;
2130+
const float dA = exp(dtsp * A0);
21352131

2136-
s = (s0 * dA) + (B[i0] * x_dt);
2132+
s = (s0 * dA) + (B[i0] * x_dt);
21372133

2138-
// Parallel sum: This relies on the fact that this kernel will be
2139-
// dispatched with each threadgroup having (d_state, 1, 1) threads which
2140-
// are subdivided into SIMD groups of size `sgptg`. The goal is to
2141-
// compute y = sum({state * C[i] for i in range(d_state)}).
2142-
// To parallelize this effectively, we first use simd_sum over each SIMD
2143-
// group to compute the sum of each SIMD group, then place the result in
2144-
// the SIMD group's indexed bucket in the shared memory. We then sum
2145-
// over the individual group sums to compute the final sum.
2134+
const float sumf = simd_sum(s * C[i0]);
21462135

2147-
// Computed for each thread
2148-
float sumf = s * C[i0];
2136+
if (tiisg == 0) {
2137+
shared[t*NW + sgitg] = sumf;
2138+
}
21492139

2150-
// Sum the threads in the simd group => simd sum
2151-
sumf = simd_sum(sumf);
2140+
// recurse
2141+
s0 = s;
21522142

2153-
// Once per simd group, place the group sum into the shared buffer
2154-
if (tiisg == 0) {
2155-
shared[sgitg] = sumf;
2143+
x += args.ns12;
2144+
dt += args.ns21;
2145+
B += args.ns42;
2146+
C += args.ns52;
21562147
}
21572148

2158-
// Wait for all threads in the threadgroup to reach this point. This
2159-
// ensures that all elements of the shared buffer are populated with the
2160-
// sum of the individual simd groups.
21612149
threadgroup_barrier(mem_flags::mem_threadgroup);
21622150

2163-
// For simd group 0 at indices < num simd groups, extract the shared
2164-
// simd sum
2165-
if (sgitg == 0) {
2166-
sumf = simd_sum(shared[tiisg]);
2167-
if (tiisg == 0) {
2168-
y[0] = sumf;
2169-
}
2151+
const float sumf = simd_sum(shared[sgitg*NW + tiisg]);
2152+
2153+
if (tiisg == 0 && i2 + sgitg < n_t) {
2154+
y[sgitg*nh*nr] = sumf;
21702155
}
21712156

2172-
// recurse
2173-
s0 = s;
2157+
y += sgptg*nh*nr;
21742158
}
21752159

2176-
// Assign the final state to the output buffer
21772160
s_buff[i] = s;
21782161
}
21792162

0 commit comments

Comments
 (0)