@@ -1663,10 +1663,16 @@ kernel void kernel_ssm_conv_f32(
1663
1663
device const void * src0,
1664
1664
device const void * src1,
1665
1665
device float * dst,
1666
+ threadgroup float * shared [[threadgroup(0 )]],
1666
1667
constant ggml_metal_kargs_ssm_conv & args,
1667
- uint3 tgpig[[threadgroup_position_in_grid]],
1668
- uint3 tpitg[[thread_position_in_threadgroup]],
1669
- uint3 ntg[[threads_per_threadgroup]]) {
1668
+ uint3 tgpig[[threadgroup_position_in_grid]],
1669
+ uint3 tpitg[[thread_position_in_threadgroup]],
1670
+ ushort sgitg[[simdgroup_index_in_threadgroup]],
1671
+ ushort tiisg[[thread_index_in_simdgroup]],
1672
+ ushort sgptg[[simdgroups_per_threadgroup]],
1673
+ uint3 tgpg[[threadgroups_per_grid]]) {
1674
+
1675
+ const int64_t i0 = tpitg.x ;
1670
1676
const int64_t ir = tgpig.x ;
1671
1677
const int64_t i2 = tgpig.y ;
1672
1678
const int64_t i3 = tgpig.z ;
@@ -1681,13 +1687,31 @@ kernel void kernel_ssm_conv_f32(
1681
1687
device const float * c = (device const float *) ((device const char *) src1 + ir*args.nb11 );
1682
1688
device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2 );
1683
1689
1684
- float sumf = 0 . 0f ;
1690
+ float sumf = s[i0] * c[i0] ;
1685
1691
1686
- for ( int64_t i0 = 0 ; i0 < nc; ++i0) {
1687
- sumf += s[i0] * c[i0];
1688
- }
1692
+ // Parallel sum: first sum over threads in simd group, then sum over simd
1693
+ // group sums
1694
+ sumf = simd_sum (sumf);
1689
1695
1690
- x[0 ] = sumf;
1696
+ // If multiple simd groups per threadgroup, sum over simd group sums
1697
+ if (sgptg > 1 ) {
1698
+ if (tiisg == 0 ) {
1699
+ shared[sgitg] = sumf;
1700
+ }
1701
+ threadgroup_barrier (mem_flags::mem_threadgroup);
1702
+ sumf = 0 .0f ;
1703
+ if (sgitg == 0 ) {
1704
+ if (tiisg < sgptg) {
1705
+ sumf = shared[tiisg];
1706
+ }
1707
+ sumf = simd_sum (sumf);
1708
+ if (tiisg == 0 ) {
1709
+ x[0 ] = sumf;
1710
+ }
1711
+ }
1712
+ } else if (tiisg == 0 ) {
1713
+ x[0 ] = sumf;
1714
+ }
1691
1715
}
1692
1716
1693
1717
// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-1 part
0 commit comments