@@ -1246,7 +1246,6 @@ kernel void kernel_ssm_conv_f32(
12461246}
12471247
12481248// ref: ggml.c:ggml_compute_forward_ssm_scan_f32
1249- // TODO: optimize
12501249kernel void kernel_ssm_scan_f32 (
12511250 device const void * src0,
12521251 device const void * src1,
@@ -1255,48 +1254,27 @@ kernel void kernel_ssm_scan_f32(
12551254 device const void * src4,
12561255 device const void * src5,
12571256 device float * dst,
1258- constant int64_t & d_state,
1259- constant int64_t & d_inner,
1260- constant int64_t & n_seq_tokens,
1261- constant int64_t & n_seqs,
1262- constant uint64_t & nb00,
1263- constant uint64_t & nb01,
1264- constant uint64_t & nb02,
1265- constant uint64_t & nb10,
1266- constant uint64_t & nb11,
1267- constant uint64_t & nb12,
1268- constant uint64_t & nb13,
1269- constant uint64_t & nb20,
1270- constant uint64_t & nb21,
1271- constant uint64_t & nb22,
1272- constant uint64_t & nb30,
1273- constant uint64_t & nb31,
1274- constant uint64_t & nb40,
1275- constant uint64_t & nb41,
1276- constant uint64_t & nb42,
1277- constant uint64_t & nb50,
1278- constant uint64_t & nb51,
1279- constant uint64_t & nb52,
1257+ constant ggml_metal_kargs_ssm_scan & args,
12801258 uint3 tgpig[[threadgroup_position_in_grid]],
12811259 uint3 tpitg[[thread_position_in_threadgroup]],
12821260 uint3 ntg[[threads_per_threadgroup]]) {
12831261 const int64_t ir = tgpig.x ;
12841262 const int64_t i3 = tgpig.y ;
12851263
1286- const int64_t nc = d_state;
1287- // const int64_t nr = d_inner;
1288- const int64_t n_t = n_seq_tokens;
1289- // const int64_t n_s = n_seqs;
1264+ const int64_t nc = args. d_state ;
1265+ // const int64_t nr = args. d_inner;
1266+ const int64_t n_t = args. n_seq_tokens ;
1267+ // const int64_t n_s = args. n_seqs;
12901268
12911269 for (int64_t i2 = 0 ; i2 < n_t ; ++i2) {
1292- device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb01 + i3*nb02);
1293- device const float * x = (device const float *) ((device const char *) src1 + ir*nb10 + i2*nb11 + i3*nb12);
1294- device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*nb21 + i3*nb22);
1295- device const float * A = (device const float *) ((device const char *) src3 + ir*nb31);
1296- device const float * B = (device const float *) ((device const char *) src4 + i2*nb41 + i3*nb42);
1297- device const float * C = (device const float *) ((device const char *) src5 + i2*nb51 + i3*nb52);
1298- device float * y = (device float *) ((device char *) dst + ir*nb10 + i2*nb11 + i3*nb12); // TODO: do not use src1 strides
1299- device float * s = (device float *) ((device char *) dst + ir*nb01 + i3*nb02 + nb13);
1270+ device const float * s0 = (device const float *) ((device const char *) src0 + ir*args. nb01 + i3*args. nb02 );
1271+ device const float * x = (device const float *) ((device const char *) src1 + ir*args. nb10 + i2*args. nb11 + i3*args. nb12 );
1272+ device const float * dt = (device const float *) ((device const char *) src2 + ir*args. nb20 + i2*args. nb21 + i3*args. nb22 );
1273+ device const float * A = (device const float *) ((device const char *) src3 + ir*args. nb31 );
1274+ device const float * B = (device const float *) ((device const char *) src4 + i2*args. nb41 + i3*args. nb42 );
1275+ device const float * C = (device const float *) ((device const char *) src5 + i2*args. nb51 + i3*args. nb52 );
1276+ device float * y = (device float *) ((device char *) dst + ir*args. nb10 + i2*args. nb11 + i3*args. nb12 ); // TODO: do not use src1 strides
1277+ device float * s = (device float *) ((device char *) dst + ir*args. nb01 + i3*args. nb02 + args. nb13 );
13001278
13011279 if (i2 > 0 ) {
13021280 s0 = s;
0 commit comments