Skip to content

Commit a5334f9

Browse files
committed
refactor: Compute block offsets once rather than once per token
Branch: GraniteFourPerf
Signed-off-by: Gabe Goodhart <[email protected]>
1 parent 21db0b5 commit a5334f9

File tree

1 file changed

+12
-6
lines changed

1 file changed

+12
-6
lines changed

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 12 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -1791,13 +1791,19 @@ kernel void kernel_ssm_scan_f32_group(
1791 1791
device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
1792 1792
device float * s = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
1793 1793

1794+
device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh}
1795+
device const float * x_block = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13);
1796+
device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22);
1797+
device const float * B_block = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i3*args.nb43);
1798+
device const float * C_block = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i3*args.nb53);
1799+
device float * y_block = (device float *) ((device char *) dst + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00);
1800+
1794 1801
for (int64_t i2 = 0; i2 < n_t; ++i2) {
1795-
device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns}
1796-
device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns}
1797-
device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh}
1798-
device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns}
1799-
device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns}
1800-
device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns}
1802+
device const float * x = (device const float *) ((device const char *) x_block + i2*args.nb12); // {dim, nh, nt, ns}
1803+
device const float * dt = (device const float *) ((device const char *) dt_block + i2*args.nb21); // {nh, nt, ns}
1804+
device const float * B = (device const float *) ((device const char *) B_block + i2*args.nb42); // {d_state, ng, nt, ns}
1805+
device const float * C = (device const float *) ((device const char *) C_block + i2*args.nb52); // {d_state, ng, nt, ns}
1806+
device float * y = (device float *) ((device char *) y_block + i2*(nh*nr*nb00)); // {dim, nh, nt, ns}
1801 1807

1802 1808
const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
1803 1809
const float x_dt = x[0] * dt_soft_plus;

0 commit comments

Comments (0)