Skip to content

Commit fcebb3c

Browse files
committed
metal : ssm_scan simplify
1 parent 9954577 commit fcebb3c

File tree

4 files changed

+30
-145
lines changed

4 files changed

+30
-145
lines changed

ggml/src/ggml-metal/ggml-metal-device.cpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -358,14 +358,12 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv(ggml_metal_librar
358358
}
359359

360360
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan(ggml_metal_library_t lib, const ggml_tensor * op) {
361+
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
362+
361363
char base[256];
362364
char name[256];
363365

364-
if (op->src[3]->ne[0] == 1) {
365-
snprintf(base, 256, "kernel_ssm_scan_group_%s", ggml_type_name(op->src[0]->type));
366-
} else {
367-
snprintf(base, 256, "kernel_ssm_scan_%s", ggml_type_name(op->src[0]->type));
368-
}
366+
snprintf(base, 256, "kernel_ssm_scan_%s", ggml_type_name(op->src[0]->type));
369367
snprintf(name, 256, "%s", base);
370368

371369
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -581,6 +581,7 @@ typedef struct {
581581
uint64_t nb13;
582582
uint64_t nb21;
583583
uint64_t nb22;
584+
int64_t ne30;
584585
uint64_t nb31;
585586
uint64_t nb41;
586587
uint64_t nb42;

ggml/src/ggml-metal/ggml-metal-ops.cpp

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1189,6 +1189,7 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) {
11891189
/*.nb13 =*/ nb13,
11901190
/*.nb21 =*/ nb21,
11911191
/*.nb22 =*/ nb22,
1192+
/*.ne30 =*/ ne30,
11921193
/*.nb31 =*/ nb31,
11931194
/*.nb41 =*/ nb41,
11941195
/*.nb42 =*/ nb42,
@@ -1200,6 +1201,8 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) {
12001201

12011202
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_ssm_scan(lib, op);
12021203

1204+
GGML_ASSERT(d_state <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
1205+
12031206
const size_t sms = ggml_metal_pipeline_get_smem(pipeline);
12041207

12051208
ggml_metal_encoder_set_pipeline(enc, pipeline);
@@ -1215,13 +1218,7 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) {
12151218

12161219
ggml_metal_encoder_set_threadgroup_memory_size(enc, sms, 0);
12171220

1218-
if (ne30 == 1) {
1219-
// Mamba-2
1220-
ggml_metal_encoder_dispatch_threadgroups(enc, d_inner, n_head, n_seqs, d_state, 1, 1);
1221-
} else {
1222-
GGML_ASSERT(d_inner == 1);
1223-
ggml_metal_encoder_dispatch_threadgroups(enc, n_head, n_seqs, 1, d_state, 1, 1);
1224-
}
1221+
ggml_metal_encoder_dispatch_threadgroups(enc, d_inner, n_head, n_seqs, d_state, 1, 1);
12251222

12261223
return 1;
12271224
}

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 22 additions & 133 deletions
Original file line numberDiff line numberDiff line change
@@ -2063,7 +2063,7 @@ kernel void kernel_ssm_conv_f32_f32_4(
20632063
x[0] = sumf;
20642064
}
20652065

2066-
// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-1 part
2066+
// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part
20672067
kernel void kernel_ssm_scan_f32(
20682068
constant ggml_metal_kargs_ssm_scan & args,
20692069
device const void * src0,
@@ -2081,120 +2081,8 @@ kernel void kernel_ssm_scan_f32(
20812081
ushort tiisg[[thread_index_in_simdgroup]],
20822082
ushort sgptg[[simdgroups_per_threadgroup]],
20832083
uint3 tgpg[[threadgroups_per_grid]]) {
2084-
const int64_t i0 = tpitg.x;
2085-
const int64_t i1 = 0;
2086-
const int64_t ir = tgpig.x; // current head
2087-
const int64_t i3 = tgpig.y; // current seq
2088-
2089-
const uint64_t nb00 = sizeof(float);
2090-
const uint64_t nb10 = sizeof(float);
2091-
const uint64_t nb20 = sizeof(float);
2092-
2093-
const int64_t nc = args.d_state;
2094-
const int64_t nr = args.d_inner;
2095-
const int64_t nh = args.n_head;
2096-
const int64_t ng = args.n_group;
2097-
const int64_t n_t = args.n_seq_tokens;
2098-
2099-
const int64_t s_off = args.s_off;
21002084

2101-
device const int32_t * ids = (device const int32_t *) src6;
2102-
2103-
device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
2104-
device float * s_buff = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
2105-
const int64_t i = i0 + i1*nc;
2106-
const int64_t g = ir / (nh / ng); // repeat_interleave
2107-
float s0 = s0_buff[i];
2108-
float s = 0.0f;
2109-
2110-
device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31);
2111-
device const float * x_block = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13);
2112-
device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22);
2113-
device const float * B_block = (device const float *) ((device const char *) src4 + g*args.nb41 + i3*args.nb43);
2114-
device const float * C_block = (device const float *) ((device const char *) src5 + g*args.nb51 + i3*args.nb53);
2115-
device float * y_block = (device float *) ((device char *) dst + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00);
2116-
2117-
for (int64_t i2 = 0; i2 < n_t; ++i2) {
2118-
device const float * x = (device const float *) ((device const char *) x_block + i2*args.nb12); // {dim, nh, nt, ns}
2119-
device const float * dt = (device const float *) ((device const char *) dt_block + i2*args.nb21); // {nh, nt, ns}
2120-
device const float * B = (device const float *) ((device const char *) B_block + i2*args.nb42); // {d_state, ng, nt, ns}
2121-
device const float * C = (device const float *) ((device const char *) C_block + i2*args.nb52); // {d_state, ng, nt, ns}
2122-
device float * y = (device float *) ((device char *) y_block + i2*(nh*nr*nb00)); // {dim, nh, nt, ns}
2123-
2124-
const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
2125-
const float x_dt = x[0] * dt_soft_plus;
2126-
2127-
s = (s0 * exp(dt_soft_plus * A[i0])) + (B[i0] * x_dt);
2128-
2129-
// Parallel sum: This relies on the fact that this kernel will be
2130-
// dispatched with each threadgroup having (d_state, 1, 1) threads which
2131-
// are subdivided into SIMD groups of size `sgptg`. The goal is to
2132-
// compute y = sum({state * C[i] for i in range(d_state)}).
2133-
// To parallelize this effectively, we first use simd_sum over each SIMD
2134-
// group to compute the sum of each SIMD group, then place the result in
2135-
// the SIMD group's indexed bucket in the shared memory. We then sum
2136-
// over the individual group sums to compute the final sum.
2137-
2138-
// Computed for each thread
2139-
float sumf = s * C[i0];
2140-
2141-
// Sum the threads in the simd group => simd sum
2142-
sumf = simd_sum(sumf);
2143-
2144-
if (sgptg > 1) {
2145-
2146-
// Once per simd group, place the group sum into the shared buffer
2147-
if (tiisg == 0) {
2148-
shared[sgitg] = sumf;
2149-
}
2150-
2151-
// Wait for all threads in the threadgroup to reach this point. This
2152-
// ensures that all elements of the shared buffer are populated with the
2153-
// sum of the individual simd groups.
2154-
threadgroup_barrier(mem_flags::mem_threadgroup);
2155-
2156-
// For simd group 0 at indices < num simd groups, extract the shared
2157-
// simd sum
2158-
sumf = 0.0f;
2159-
if (sgitg == 0) {
2160-
if (tiisg < sgptg) {
2161-
sumf = shared[tiisg];
2162-
}
2163-
sumf = simd_sum(sumf);
2164-
if (tiisg == 0) {
2165-
y[0] = sumf;
2166-
}
2167-
}
2168-
} else if (tiisg == 0) {
2169-
y[0] = sumf;
2170-
}
2171-
2172-
// recurse
2173-
s0 = s;
2174-
}
2175-
2176-
// Assign the final state to the output buffer
2177-
s_buff[i] = s;
2178-
}
2179-
2180-
// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part
2181-
kernel void kernel_ssm_scan_group_f32(
2182-
constant ggml_metal_kargs_ssm_scan & args,
2183-
device const void * src0,
2184-
device const void * src1,
2185-
device const void * src2,
2186-
device const void * src3,
2187-
device const void * src4,
2188-
device const void * src5,
2189-
device const void * src6,
2190-
device float * dst,
2191-
threadgroup float * shared [[threadgroup(0)]],
2192-
uint3 tgpig[[threadgroup_position_in_grid]],
2193-
uint3 tpitg[[thread_position_in_threadgroup]],
2194-
ushort sgitg[[simdgroup_index_in_threadgroup]],
2195-
ushort tiisg[[thread_index_in_simdgroup]],
2196-
ushort sgptg[[simdgroups_per_threadgroup]],
2197-
uint3 tgpg[[threadgroups_per_grid]]) {
2085+
shared[tpitg.x] = 0.0f;
21982086

21992087
const int64_t i0 = tpitg.x;
22002088
const int64_t i1 = tgpig.x;
@@ -2222,23 +2110,28 @@ kernel void kernel_ssm_scan_group_f32(
22222110
float s0 = s0_buff[i];
22232111
float s = 0.0f;
22242112

2225-
device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh}
2226-
device const float * x_block = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13);
2227-
device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22);
2228-
device const float * B_block = (device const float *) ((device const char *) src4 + g*args.nb41 + i3*args.nb43);
2229-
device const float * C_block = (device const float *) ((device const char *) src5 + g*args.nb51 + i3*args.nb53);
2230-
device float * y_block = (device float *) ((device char *) dst + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00);
2113+
device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {ne30, nh}
2114+
const float A0 = A[i0%args.ne30];
22312115

2232-
for (int64_t i2 = 0; i2 < n_t; ++i2) {
2233-
device const float * x = (device const float *) ((device const char *) x_block + i2*args.nb12); // {dim, nh, nt, ns}
2234-
device const float * dt = (device const float *) ((device const char *) dt_block + i2*args.nb21); // {nh, nt, ns}
2235-
device const float * B = (device const float *) ((device const char *) B_block + i2*args.nb42); // {d_state, ng, nt, ns}
2236-
device const float * C = (device const float *) ((device const char *) C_block + i2*args.nb52); // {d_state, ng, nt, ns}
2237-
device float * y = (device float *) ((device char *) y_block + i2*(nh*nr*nb00)); // {dim, nh, nt, ns}
2116+
device const char * x_block = ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13);
2117+
device const char * dt_block = ((device const char *) src2 + ir*nb20 + i3*args.nb22);
2118+
device const char * B_block = ((device const char *) src4 + g*args.nb41 + i3*args.nb43);
2119+
device const char * C_block = ((device const char *) src5 + g*args.nb51 + i3*args.nb53);
2120+
device char * y_block = ((device char *) dst + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00);
2121+
2122+
threadgroup_barrier(mem_flags::mem_threadgroup);
22382123

2239-
const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
2124+
for (int64_t i2 = 0; i2 < n_t; ++i2) {
2125+
device const float * x = (device const float *) (x_block + i2*args.nb12); // {dim, nh, nt, ns}
2126+
device const float * dt = (device const float *) (dt_block + i2*args.nb21); // {nh, nt, ns}
2127+
device const float * B = (device const float *) (B_block + i2*args.nb42); // {d_state, ng, nt, ns}
2128+
device const float * C = (device const float *) (C_block + i2*args.nb52); // {d_state, ng, nt, ns}
2129+
device float * y = (device float *) (y_block + i2*(nh*nr*nb00)); // {dim, nh, nt, ns}
2130+
2131+
const float dt0 = dt[0];
2132+
const float dt_soft_plus = dt0 <= 20.0f ? log(1.0f + exp(dt0)) : dt0;
22402133
const float x_dt = x[0] * dt_soft_plus;
2241-
const float dA = exp(dt_soft_plus * A[0]);
2134+
const float dA = exp(dt_soft_plus * A0);
22422135

22432136
s = (s0 * dA) + (B[i0] * x_dt);
22442137

@@ -2269,12 +2162,8 @@ kernel void kernel_ssm_scan_group_f32(
22692162

22702163
// For simd group 0 at indices < num simd groups, extract the shared
22712164
// simd sum
2272-
sumf = 0.0f;
22732165
if (sgitg == 0) {
2274-
if (tiisg < sgptg) {
2275-
sumf = shared[tiisg];
2276-
}
2277-
sumf = simd_sum(sumf);
2166+
sumf = simd_sum(shared[tiisg]);
22782167
if (tiisg == 0) {
22792168
y[0] = sumf;
22802169
}

0 commit comments

Comments
 (0)