@@ -1788,8 +1788,11 @@ kernel void kernel_ssm_scan_f32_group(
1788
1788
1789
1789
device const int32_t * ids = (device const int32_t *) src6;
1790
1790
1791
- device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03 );
1792
- device float * s = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
1791
+ device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03 );
1792
+ device float * s_buff = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
1793
+ const int64_t i = tpitg.x + i1*nc;
1794
+ float s0 = s0_buff[i];
1795
+ float s = s_buff[i];
1793
1796
1794
1797
device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31 ); // {1, nh}
1795
1798
device const float * x_block = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13 );
@@ -1809,9 +1812,8 @@ kernel void kernel_ssm_scan_f32_group(
1809
1812
const float x_dt = x[0 ] * dt_soft_plus;
1810
1813
const float dA = exp (dt_soft_plus * A[0 ]);
1811
1814
1812
- const int64_t i = tpitg.x + i1*nc;
1813
- const float state = (s0[i] * dA) + (B[tpitg.x ] * x_dt);
1814
- s[i] = state;
1815
+ const float state = (s0 * dA) + (B[tpitg.x ] * x_dt);
1816
+ s = state;
1815
1817
1816
1818
// Parallel sum: This relies on the fact that this kernel will be
1817
1819
// dispatched with each threadgroup having (d_state, 1, 1) threads which
@@ -1851,6 +1853,9 @@ kernel void kernel_ssm_scan_f32_group(
1851
1853
// recurse
1852
1854
s0 = s;
1853
1855
}
1856
+
1857
+ // Assign the final state to the output buffer
1858
+ s_buff[i] = s;
1854
1859
}
1855
1860
1856
1861
kernel void kernel_rwkv_wkv6_f32 (
0 commit comments