Skip to content

Commit 5b8ec2b

Browse files
committed
metal : fix SSM_SCAN state head offset
1 parent 8b15bc6 commit 5b8ec2b

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

ggml/src/ggml-metal.metal

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -850,8 +850,8 @@ kernel void kernel_ssm_scan_f32(
850 850

851 851
device const int32_t * ids = (device const int32_t *) src7;
852 852

853-
device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb01 + ids[i3]*nb03);
854-
device float * s = (device float *) ((device char *) dst + ir*nb01 + i3*nb03 + s_off);
853+
device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb02 + ids[i3]*nb03);
854+
device float * s = (device float *) ((device char *) dst + ir*nb02 + i3*nb03 + s_off);
855 855

856 856
for (int64_t i2 = 0; i2 < n_t; ++i2) {
857 857
device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*nb11 + i2*nb12 + i3*nb13); // {dim, nh, nt, ns}
@@ -935,8 +935,8 @@ kernel void kernel_ssm_scan_f32_group(
935 935

936 936
device const int32_t * ids = (device const int32_t *) src7;
937 937

938-
device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb01 + ids[i3]*nb03);
939-
device float * s = (device float *) ((device char *) dst + ir*nb01 + i3*nb03 + s_off);
938+
device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb02 + ids[i3]*nb03);
939+
device float * s = (device float *) ((device char *) dst + ir*nb02 + i3*nb03 + s_off);
940 940

941 941
for (int64_t i2 = 0; i2 < n_t; ++i2) {
942 942
device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*nb11 + i2*nb12 + i3*nb13); // {dim, nh, nt, ns}

0 commit comments

Comments
 (0)