File tree Expand file tree Collapse file tree 1 file changed +10
-7
lines changed Expand file tree Collapse file tree 1 file changed +10
-7
lines changed Original file line number Diff line number Diff line change @@ -1840,16 +1840,19 @@ kernel void kernel_ssm_scan_f32_group(
1840
1840
// sum of the individual simd groups.
1841
1841
threadgroup_barrier (mem_flags::mem_threadgroup);
1842
1842
1843
- // Sum the simd buckets => threadgroup sum
1843
+ // For simd group 0 at indices < num simd groups, extract the shared
1844
+ // simd sum
1844
1845
sumf = 0 .0f ;
1845
- for (int64_t i0 = 0 ; i0 < sgptg; ++i0) {
1846
- sumf += shared[i0];
1846
+ if (sgitg == 0 ) {
1847
+ if (tiisg < sgptg) {
1848
+ sumf = shared[tiisg];
1849
+ }
1850
+ sumf = simd_sum (sumf);
1851
+ if (tiisg == 0 ) {
1852
+ y[0 ] = sumf;
1853
+ }
1847
1854
}
1848
1855
1849
- threadgroup_barrier (mem_flags::mem_threadgroup);
1850
-
1851
- y[0 ] = sumf;
1852
-
1853
1856
// recurse
1854
1857
s0 = s;
1855
1858
}
You can’t perform that action at this time.
0 commit comments