@@ -1791,13 +1791,19 @@ kernel void kernel_ssm_scan_f32_group(
1791
1791
device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03 );
1792
1792
device float * s = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
1793
1793
1794
+ device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31 ); // {1, nh}
1795
+ device const float * x_block = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13 );
1796
+ device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22 );
1797
+ device const float * B_block = (device const float *) ((device const char *) src4 + (ir & (ng - 1 ))*args.nb41 + i3*args.nb43 );
1798
+ device const float * C_block = (device const float *) ((device const char *) src5 + (ir & (ng - 1 ))*args.nb51 + i3*args.nb53 );
1799
+ device float * y_block = (device float *) ((device char *) dst + (i1 + ir*(nr) + i3*(n_t *nh*nr))*nb00);
1800
+
1794
1801
for (int64_t i2 = 0 ; i2 < n_t ; ++i2) {
1795
- device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13 ); // {dim, nh, nt, ns}
1796
- device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*args.nb21 + i3*args.nb22 ); // {nh, nt, ns}
1797
- device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31 ); // {1, nh}
1798
- device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1 ))*args.nb41 + i2*args.nb42 + i3*args.nb43 ); // {d_state, ng, nt, ns}
1799
- device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1 ))*args.nb51 + i2*args.nb52 + i3*args.nb53 ); // {d_state, ng, nt, ns}
1800
- device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t *nh*nr))*nb00); // {dim, nh, nt, ns}
1802
+ device const float * x = (device const float *) ((device const char *) x_block + i2*args.nb12 ); // {dim, nh, nt, ns}
1803
+ device const float * dt = (device const float *) ((device const char *) dt_block + i2*args.nb21 ); // {nh, nt, ns}
1804
+ device const float * B = (device const float *) ((device const char *) B_block + i2*args.nb42 ); // {d_state, ng, nt, ns}
1805
+ device const float * C = (device const float *) ((device const char *) C_block + i2*args.nb52 ); // {d_state, ng, nt, ns}
1806
+ device float * y = (device float *) ((device char *) y_block + i2*(nh*nr*nb00)); // {dim, nh, nt, ns}
1801
1807
1802
1808
const float dt_soft_plus = dt[0 ] <= 20 .0f ? log (1 .0f + exp (dt[0 ])) : dt[0 ];
1803
1809
const float x_dt = x[0 ] * dt_soft_plus;
0 commit comments