@@ -1663,16 +1663,10 @@ kernel void kernel_ssm_conv_f32(
1663
1663
device const void * src0,
1664
1664
device const void * src1,
1665
1665
device float * dst,
1666
- threadgroup float * shared [[threadgroup(0 )]],
1667
1666
constant ggml_metal_kargs_ssm_conv & args,
1668
- uint3 tgpig[[threadgroup_position_in_grid]],
1669
- uint3 tpitg[[thread_position_in_threadgroup]],
1670
- ushort sgitg[[simdgroup_index_in_threadgroup]],
1671
- ushort tiisg[[thread_index_in_simdgroup]],
1672
- ushort sgptg[[simdgroups_per_threadgroup]],
1673
- uint3 tgpg[[threadgroups_per_grid]]) {
1674
-
1675
- const int64_t i0 = tpitg.x ;
1667
+ uint3 tgpig[[threadgroup_position_in_grid]],
1668
+ uint3 tpitg[[thread_position_in_threadgroup]],
1669
+ uint3 ntg[[threads_per_threadgroup]]) {
1676
1670
const int64_t ir = tgpig.x ;
1677
1671
const int64_t i2 = tgpig.y ;
1678
1672
const int64_t i3 = tgpig.z ;
@@ -1687,31 +1681,13 @@ kernel void kernel_ssm_conv_f32(
1687
1681
device const float * c = (device const float *) ((device const char *) src1 + ir*args.nb11 );
1688
1682
device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2 );
1689
1683
1690
- float sumf = s[i0] * c[i0];
1691
-
1692
- // Parallel sum: first sum over threads in simd group, then sum over simd
1693
- // group sums
1694
- sumf = simd_sum (sumf);
1684
+ float sumf = 0 .0f ;
1695
1685
1696
- // If multiple simd groups per threadgroup, sum over simd group sums
1697
- if (sgptg > 1 ) {
1698
- if (tiisg == 0 ) {
1699
- shared[sgitg] = sumf;
1700
- }
1701
- threadgroup_barrier (mem_flags::mem_threadgroup);
1702
- sumf = 0 .0f ;
1703
- if (sgitg == 0 ) {
1704
- if (tiisg < sgptg) {
1705
- sumf = shared[tiisg];
1706
- }
1707
- sumf = simd_sum (sumf);
1708
- if (tiisg == 0 ) {
1709
- x[0 ] = sumf;
1710
- }
1711
- }
1712
- } else if (tiisg == 0 ) {
1713
- x[0 ] = sumf;
1686
+ for (int64_t i0 = 0 ; i0 < nc; ++i0) {
1687
+ sumf += s[i0] * c[i0];
1714
1688
}
1689
+
1690
+ x[0 ] = sumf;
1715
1691
}
1716
1692
1717
1693
// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-1 part
0 commit comments