Commit 80545ef

feat: Parallelize over d_state for mamba-1
Branch: GraniteFourPerf
Signed-off-by: Gabe Goodhart <[email protected]>
1 parent d06d087 commit 80545ef

2 files changed: +85 −29 lines changed


ggml/src/ggml-metal/ggml-metal.m

Lines changed: 7 additions & 6 deletions
@@ -3015,12 +3015,9 @@ static bool ggml_metal_encode_node(
                 [encoder setBuffer:id_dst offset:offs_dst atIndex:7];
                 [encoder setBytes:&args length:sizeof(args) atIndex:8];
 
-                if (ne30 == 1) {
-                    // Mamba-2
-
-                    // One shared memory bucket for each simd group in the threadgroup
+                // One shared memory bucket for each simd group in the threadgroup
+                if (d_state >= 32) {
                     const int64_t shmem_size = d_state / 32;
-                    GGML_ASSERT(shmem_size * 32 == d_state);
 
                     // The final simd_sum won't work if the number of simd groups is
                     // larger than the size of a single simd group. If this case is
@@ -3033,10 +3030,14 @@ static bool ggml_metal_encode_node(
                     GGML_ASSERT(d_state <= (int64_t)pipeline.maxTotalThreadsPerThreadgroup);
 
                     [encoder setThreadgroupMemoryLength:(shmem_size)*sizeof(float) atIndex:0];
+                }
+
+                if (ne30 == 1) {
+                    // Mamba-2
                     [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(d_state, 1, 1)];
                 } else {
                     GGML_ASSERT(d_inner == 1);
-                    [encoder dispatchThreadgroups:MTLSizeMake(n_head, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                    [encoder dispatchThreadgroups:MTLSizeMake(n_head, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(d_state, 1, 1)];
                 }
             } break;
         case GGML_OP_RWKV_WKV6:
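
Host-side, the threadgroup memory is now sized for both kernel variants whenever d_state >= 32: one float bucket per 32-wide simd group, with the existing assertion bounding the simd-group count so the final cross-group simd_sum fits in a single simd group. A minimal plain-C sketch of that sizing arithmetic (the standalone form and names are illustrative, not ggml code):

#include <assert.h>
#include <stdio.h>

// Sketch of the shared-memory sizing above, assuming 32-wide simd groups.
int main(void) {
    const long d_state = 128;                 // example: threads per threadgroup
    if (d_state >= 32) {
        const long shmem_size = d_state / 32; // one float bucket per simd group
        // The final simd_sum over the buckets runs inside one simd group, so
        // the number of simd groups must not exceed the simd-group width.
        assert(shmem_size <= 32);
        printf("%ld threads -> %ld simd groups -> %ld bytes shared\n",
               d_state, shmem_size, shmem_size * (long) sizeof(float));
    }
    return 0;
}

With d_state = 128 this reserves four buckets (16 bytes of threadgroup memory) and the encoder dispatches threadgroups of 128 threads, i.e. four simd groups.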

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 78 additions & 23 deletions
@@ -1700,10 +1700,16 @@ kernel void kernel_ssm_scan_f32(
        device const void * src5,
        device const void * src6,
        device       float * dst,
+       threadgroup  float * shared [[threadgroup(0)]],
        constant ggml_metal_kargs_ssm_scan & args,
-       uint3 tgpig[[threadgroup_position_in_grid]],
-       uint3 tpitg[[thread_position_in_threadgroup]],
-       uint3 ntg[[threads_per_threadgroup]]) {
+       uint3  tgpig[[threadgroup_position_in_grid]],
+       uint3  tpitg[[thread_position_in_threadgroup]],
+       ushort sgitg[[simdgroup_index_in_threadgroup]],
+       ushort tiisg[[thread_index_in_simdgroup]],
+       ushort sgptg[[simdgroups_per_threadgroup]],
+       uint3  tgpg[[threadgroups_per_grid]]) {
+
+   const int64_t i0 = tpitg.x;
    const int64_t i1 = 0;
    const int64_t ir = tgpig.x; // current head
    const int64_t i3 = tgpig.y; // current seq
@@ -1718,37 +1724,85 @@ kernel void kernel_ssm_scan_f32(
    const int64_t ng  = args.n_group;
    const int64_t n_t = args.n_seq_tokens;
 
-   const int64_t s_off = nr * nh * n_t * args.n_seqs * sizeof(float);
+   const int64_t s_off = args.s_off;
 
    device const int32_t * ids = (device const int32_t *) src6;
 
-   device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
-   device       float * s  = (device       float *) ((device       char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
+   device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
+   device       float * s_buff  = (device       float *) ((device       char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
+   const int64_t i = i0 + i1*nc;
+   float s0 = s0_buff[i];
+   float s  = s_buff[i];
+
+   device const float * A        = (device const float *) ((device const char *) src3 + ir*args.nb31);
+   device const float * x_block  = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13);
+   device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22);
+   device const float * B_block  = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i3*args.nb43);
+   device const float * C_block  = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i3*args.nb53);
+   device       float * y_block  = (device       float *) ((device       char *) dst + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00);
 
    for (int64_t i2 = 0; i2 < n_t; ++i2) {
-       device const float * x  = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns}
-       device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns}
-       device const float * A  = (device const float *) ((device const char *) src3 + ir*args.nb31); // {d_state, nh}
-       device const float * B  = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns}
-       device const float * C  = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns}
-       device       float * y  = (device       float *) ((device       char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns}
+       device const float * x  = (device const float *) ((device const char *) x_block  + i2*args.nb12); // {dim, nh, nt, ns}
+       device const float * dt = (device const float *) ((device const char *) dt_block + i2*args.nb21); // {nh, nt, ns}
+       device const float * B  = (device const float *) ((device const char *) B_block  + i2*args.nb42); // {d_state, ng, nt, ns}
+       device const float * C  = (device const float *) ((device const char *) C_block  + i2*args.nb52); // {d_state, ng, nt, ns}
+       device       float * y  = (device       float *) ((device       char *) y_block  + i2*(nh*nr*nb00)); // {dim, nh, nt, ns}
 
        const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
        const float x_dt = x[0] * dt_soft_plus;
-       float sumf = 0.0f;
 
-       for (int64_t i0 = 0; i0 < nc; ++i0) {
-           const int64_t i = i0 + i1*nc;
-           const float state = (s0[i] * exp(dt_soft_plus * A[i0])) + (B[i0] * x_dt);
-           sumf += state * C[i0];
-           s[i] = state;
-       }
+       const float state = (s0 * exp(dt_soft_plus * A[i0])) + (B[i0] * x_dt);
+       s = state;
+
+       // Parallel sum: This relies on the fact that this kernel will be
+       // dispatched with each threadgroup having (d_state, 1, 1) threads which
+       // are subdivided into SIMD groups of size `sgptg`. The goal is to
+       // compute y = sum({state * C[i] for i in range(d_state)}).
+       // To parallelize this effectively, we first use simd_sum over each SIMD
+       // group to compute the sum of each SIMD group, then place the result in
+       // the SIMD group's indexed bucket in the shared memory. We then sum
+       // over the individual group sums to compute the final sum.
+
+       // Computed for each thread
+       float sumf = state * C[i0];
+
+       // Sum the threads in the simd group => simd sum
+       sumf = simd_sum(sumf);
 
-       y[0] = sumf;
+       if (sgptg > 1) {
+
+           // Once per simd group, place the group sum into the shared buffer
+           if (tiisg == 0) {
+               shared[sgitg] = sumf;
+           }
+
+           // Wait for all threads in the threadgroup to reach this point. This
+           // ensures that all elements of the shared buffer are populated with the
+           // sum of the individual simd groups.
+           threadgroup_barrier(mem_flags::mem_threadgroup);
+
+           // For simd group 0 at indices < num simd groups, extract the shared
+           // simd sum
+           sumf = 0.0f;
+           if (sgitg == 0) {
+               if (tiisg < sgptg) {
+                   sumf = shared[tiisg];
+               }
+               sumf = simd_sum(sumf);
+               if (tiisg == 0) {
+                   y[0] = sumf;
+               }
+           }
+       } else if (tiisg == 0) {
+           y[0] = sumf;
+       }
 
        // recurse
        s0 = s;
    }
+
+   // Assign the final state to the output buffer
+   s_buff[i] = s;
 }
 
 // ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part
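
For orientation, the sequential inner loop removed above is the Mamba-1 recurrence that the kernel now spreads across d_state threads: each thread i carries state_i = s0_i * exp(dt_softplus * A_i) + B_i * x * dt_softplus, and the per-token output y = sum_i state_i * C_i is produced by the parallel reduction. A scalar C reference of the same recurrence (a sketch; the flat [n_t][nc] layout and the function name are illustrative, not the ggml API):

#include <math.h>
#include <stdio.h>

// Scalar reference for the recurrence parallelized above. The state is
// updated in place, standing in for the s0/s double-buffering in the kernel.
static void ssm_scan_ref(int nc, int n_t,
                         const float * A,  // [nc]     per-state coefficients
                         const float * x,  // [n_t]    input for this head dim
                         const float * dt, // [n_t]    raw time-step values
                         const float * B,  // [n_t*nc] input projections
                         const float * C,  // [n_t*nc] output projections
                         float * s,        // [nc]     recurrent state
                         float * y) {      // [n_t]    per-token outputs
    for (int t = 0; t < n_t; ++t) {
        // softplus(dt) with the same large-value shortcut as the kernel
        const float dt_sp = dt[t] <= 20.0f ? logf(1.0f + expf(dt[t])) : dt[t];
        const float x_dt  = x[t] * dt_sp;
        float sumf = 0.0f;
        for (int i = 0; i < nc; ++i) {
            const float state = s[i] * expf(dt_sp * A[i]) + B[t*nc + i] * x_dt;
            sumf += state * C[t*nc + i];
            s[i]  = state;
        }
        y[t] = sumf; // y_t = <state_t, C_t>
    }
}

int main(void) {
    float A[4] = {-1.0f, -1.0f, -1.0f, -1.0f}, s[4] = {0};
    float x[2] = {1.0f, 1.0f}, dt[2] = {0.5f, 0.5f};
    float B[8] = {1,1,1,1, 1,1,1,1}, C[8] = {1,1,1,1, 1,1,1,1};
    float y[2];
    ssm_scan_ref(4, 2, A, x, dt, B, C, s, y);
    printf("y = %f %f\n", y[0], y[1]);
    return 0;
}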
@@ -1770,6 +1824,7 @@ kernel void kernel_ssm_scan_f32_group(
        ushort sgptg[[simdgroups_per_threadgroup]],
        uint3  tgpg[[threadgroups_per_grid]]) {
 
+   const int64_t i0 = tpitg.x;
    const int64_t i1 = tgpig.x;
    const int64_t ir = tgpig.y; // current head
    const int64_t i3 = tgpig.z; // current seq
@@ -1790,7 +1845,7 @@
 
    device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
    device       float * s_buff  = (device       float *) ((device       char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
-   const int64_t i = tpitg.x + i1*nc;
+   const int64_t i = i0 + i1*nc;
    float s0 = s0_buff[i];
    float s  = s_buff[i];
@@ -1812,7 +1867,7 @@
        const float x_dt = x[0] * dt_soft_plus;
        const float dA   = exp(dt_soft_plus * A[0]);
 
-       const float state = (s0 * dA) + (B[tpitg.x] * x_dt);
+       const float state = (s0 * dA) + (B[i0] * x_dt);
        s = state;
 
        // Parallel sum: This relies on the fact that this kernel will be
@@ -1825,7 +1880,7 @@
        // over the individual group sums to compute the final sum.
 
        // Computed for each thread
-       float sumf = state * C[tpitg.x];
+       float sumf = state * C[i0];
 
        // Sum the threads in the simd group => simd sum
        sumf = simd_sum(sumf);
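
Both kernels now share the same two-stage reduction: a simd_sum inside each 32-wide simd group, one shared bucket per group, then a final simd_sum over the buckets performed by simd group 0. A plain-C emulation of that structure, sequential and for illustration only (SIMD_WIDTH and the helper name are assumptions, not Metal semantics):

#include <stdio.h>

#define SIMD_WIDTH 32

// Emulates the two-stage reduction: per-simd-group partial sums written to
// "shared" buckets, then a single pass over the buckets for the final sum.
// Assumes n is a multiple of SIMD_WIDTH and at most SIMD_WIDTH * SIMD_WIDTH.
static float two_stage_sum(const float * v, int n) {
    float shared[SIMD_WIDTH] = {0};       // one bucket per simd group
    const int n_groups = n / SIMD_WIDTH;
    for (int g = 0; g < n_groups; ++g) {  // stage 1: simd_sum within a group
        float partial = 0.0f;
        for (int t = 0; t < SIMD_WIDTH; ++t) {
            partial += v[g*SIMD_WIDTH + t];
        }
        shared[g] = partial;              // lane 0 (tiisg == 0) writes its bucket
    }
    float total = 0.0f;                   // stage 2: simd_sum over the buckets
    for (int g = 0; g < n_groups; ++g) {
        total += shared[g];
    }
    return total;
}

int main(void) {
    float v[128];
    for (int i = 0; i < 128; ++i) v[i] = 1.0f;
    printf("sum = %f\n", two_stage_sum(v, 128)); // prints sum = 128.000000
    return 0;
}

The threadgroup_barrier between the two stages has no analogue in this sequential emulation; in the kernel it is what guarantees every bucket is populated before simd group 0 reads them.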
