
Commit e55176a

Revert "feat: Parallel sum in SSM_CONV"
After discussion with @compilade, the amount of parallelism available here is not worth the cost in complexity or the overhead of the parallel sum: ggml-org#14743 (comment)

This reverts commit 16bc059.

Signed-off-by: Gabe Goodhart <[email protected]>
1 parent 16bc059 commit e55176a

2 files changed: +9 −52 lines

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 1 addition & 20 deletions
@@ -2909,26 +2909,7 @@ static bool ggml_metal_encode_node(
                 [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
                 [encoder setBytes:&args length:sizeof(args) atIndex:3];
 
-                const int64_t d_state = ne10;
-
-                // One shared memory bucket for each simd group in the threadgroup
-                if (d_state >= 32) {
-                    const int64_t shmem_size = d_state / 32;
-
-                    // The final simd_sum won't work if the number of simd groups is
-                    // larger than the size of a single simd group. If this case is
-                    // hit at some point, the logic in the second simd_sum could be
-                    // expanded to handle this with one more sequential simd_sum to
-                    // collapse simd group sums another time.
-                    GGML_ASSERT(shmem_size <= 32);
-
-                    // One thread pre element in d_state
-                    GGML_ASSERT(d_state <= (int64_t)pipeline.maxTotalThreadsPerThreadgroup);
-
-                    [encoder setThreadgroupMemoryLength:(shmem_size)*sizeof(float) atIndex:0];
-                }
-
-                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne1, ne02) threadsPerThreadgroup:MTLSizeMake(d_state, 1, 1)];
+                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne1, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
             } break;
         case GGML_OP_SSM_SCAN:
             {
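
For context on what the removed setup block computed: the kernel was launched with one thread per summed element, each simdgroup of 32 threads produced one partial sum, and threadgroup memory held one float per simdgroup, which is why the asserts bounded the element count by 32×32 and by the pipeline's thread limit. Below is a minimal CPU-side sketch of that arithmetic; it is illustrative only (not ggml code), and the simdgroup width of 32 and the helper name ssm_conv_shmem_bytes are assumptions made for the example.

    // Hypothetical CPU-side sketch (not ggml code) of the launch-geometry math in the
    // block removed above. SIMDGROUP_WIDTH and the helper name are assumptions.
    #include <assert.h>
    #include <stdint.h>
    #include <stdio.h>

    #define SIMDGROUP_WIDTH 32

    // Threadgroup memory (in bytes) that the removed block would have requested:
    // one float bucket per simdgroup when more than one simdgroup is needed.
    static size_t ssm_conv_shmem_bytes(int64_t d_state, int64_t max_threads_per_threadgroup) {
        if (d_state < SIMDGROUP_WIDTH) {
            return 0; // a single simdgroup covers all elements; no shared buckets needed
        }
        const int64_t n_buckets = d_state / SIMDGROUP_WIDTH;

        // The second simd_sum collapses the per-simdgroup buckets, so there can be at
        // most one simdgroup's worth of them (mirrors GGML_ASSERT(shmem_size <= 32)).
        assert(n_buckets <= SIMDGROUP_WIDTH);

        // One thread per element must fit in a single threadgroup (mirrors the second assert).
        assert(d_state <= max_threads_per_threadgroup);

        return (size_t) n_buckets * sizeof(float);
    }

    int main(void) {
        // e.g. 128 elements with a typical 1024-thread threadgroup limit -> 4 buckets, 16 bytes
        printf("%zu bytes\n", ssm_conv_shmem_bytes(128, 1024));
        return 0;
    }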

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 8 additions & 32 deletions
@@ -1663,16 +1663,10 @@ kernel void kernel_ssm_conv_f32(
         device const void * src0,
         device const void * src1,
         device       float * dst,
-        threadgroup  float * shared [[threadgroup(0)]],
         constant ggml_metal_kargs_ssm_conv & args,
-        uint3   tgpig[[threadgroup_position_in_grid]],
-        uint3   tpitg[[thread_position_in_threadgroup]],
-        ushort  sgitg[[simdgroup_index_in_threadgroup]],
-        ushort  tiisg[[thread_index_in_simdgroup]],
-        ushort  sgptg[[simdgroups_per_threadgroup]],
-        uint3   tgpg[[threadgroups_per_grid]]) {
-
-    const int64_t i0 = tpitg.x;
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
     const int64_t ir = tgpig.x;
     const int64_t i2 = tgpig.y;
     const int64_t i3 = tgpig.z;
@@ -1687,31 +1681,13 @@ kernel void kernel_ssm_conv_f32(
     device const float * c = (device const float *) ((device const char *) src1 + ir*args.nb11);
     device       float * x = (device       float *) ((device       char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2);
 
-    float sumf = s[i0] * c[i0];
-
-    // Parallel sum: first sum over threads in simd group, then sum over simd
-    // group sums
-    sumf = simd_sum(sumf);
+    float sumf = 0.0f;
 
-    // If multiple simd groups per threadgroup, sum over simd group sums
-    if (sgptg > 1) {
-        if (tiisg == 0) {
-            shared[sgitg] = sumf;
-        }
-        threadgroup_barrier(mem_flags::mem_threadgroup);
-        sumf = 0.0f;
-        if (sgitg == 0) {
-            if (tiisg < sgptg) {
-                sumf = shared[tiisg];
-            }
-            sumf = simd_sum(sumf);
-            if (tiisg == 0) {
-                x[0] = sumf;
-            }
-        }
-    } else if (tiisg == 0) {
-        x[0] = sumf;
+    for (int64_t i0 = 0; i0 < nc; ++i0) {
+        sumf += s[i0] * c[i0];
     }
+
+    x[0] = sumf;
 }
 
 // ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-1 part
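
To make the kernel change concrete, here is a hedged CPU-only sketch contrasting the restored sequential loop with an emulation of the two-level simd_sum reduction being reverted. It is plain C rather than Metal, and the function names and the fixed width of 32 are assumptions for illustration. Both strategies compute the same dot product (up to floating-point summation order); per the commit message, the revert trades the parallelism away because the sum is too small to justify the added complexity and overhead.

    // Hypothetical CPU-only sketch (plain C, not ggml/Metal) contrasting the two
    // strategies for the per-row dot product x[0] = sum_{i0 < nc} s[i0]*c[i0].
    // Function names and the fixed width of 32 are assumptions for illustration.
    #include <assert.h>
    #include <stdio.h>

    #define SIMDGROUP_WIDTH 32

    // Restored behaviour: a single thread walks the whole row sequentially.
    static float conv_dot_sequential(const float * s, const float * c, int nc) {
        float sumf = 0.0f;
        for (int i0 = 0; i0 < nc; ++i0) {
            sumf += s[i0] * c[i0];
        }
        return sumf;
    }

    // Reverted behaviour, emulated on the CPU: each "thread" computes one product,
    // every group of 32 collapses its values (the first simd_sum), and group 0 then
    // collapses the per-group partials held in "shared" memory (the second simd_sum).
    static float conv_dot_two_level(const float * s, const float * c, int nc) {
        float shared[SIMDGROUP_WIDTH] = {0};
        const int n_groups = (nc + SIMDGROUP_WIDTH - 1) / SIMDGROUP_WIDTH;
        assert(n_groups <= SIMDGROUP_WIDTH); // mirrors the assert in the removed host code

        for (int g = 0; g < n_groups; ++g) {
            float group_sum = 0.0f; // stands in for simd_sum across the group's threads
            for (int t = 0; t < SIMDGROUP_WIDTH; ++t) {
                const int i0 = g*SIMDGROUP_WIDTH + t;
                if (i0 < nc) {
                    group_sum += s[i0] * c[i0];
                }
            }
            shared[g] = group_sum; // thread 0 of each group writes its bucket
        }

        float sumf = 0.0f; // second reduction over the buckets
        for (int g = 0; g < n_groups; ++g) {
            sumf += shared[g];
        }
        return sumf;
    }

    int main(void) {
        const float s[4] = {1.0f, 2.0f, 3.0f, 4.0f};
        const float c[4] = {0.5f, 0.5f, 0.5f, 0.5f};
        // both print 5.000000: the revert changes the execution strategy, not the math
        printf("%f %f\n", conv_dot_sequential(s, c, 4), conv_dot_two_level(s, c, 4));
        return 0;
    }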
