@@ -577,6 +577,7 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
577577 ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_cpy (lib, op->src [0 ]->type , op->type );
578578
579579 ggml_metal_kargs_cpy args = {
580+ /* .nk0 =*/ ne00,
580581 /* .ne00 =*/ ne00,
581582 /* .ne01 =*/ ne01,
582583 /* .ne02 =*/ ne02,
@@ -906,23 +907,31 @@ int ggml_metal_op_get_rows(ggml_metal_op_t ctx, int idx) {
906907 ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_get_rows (lib, op->src [0 ]->type );
907908
908909 ggml_metal_kargs_get_rows args = {
909- /* .ne00 =*/ ne00,
910- /* .nb01 =*/ nb01,
911- /* .nb02 =*/ nb02,
912- /* .ne10 =*/ ne10,
913- /* .nb10 =*/ nb10,
914- /* .nb11 =*/ nb11,
915- /* .nb1 =*/ nb1,
916- /* .nb2 =*/ nb2,
910+ /* .ne00t =*/ ggml_is_quantized (op->src [0 ]->type ) ? ne00/16 : ne00,
911+ /* .ne00 =*/ ne00,
912+ /* .nb01 =*/ nb01,
913+ /* .nb02 =*/ nb02,
914+ /* .nb03 =*/ nb03,
915+ /* .ne10 =*/ ne10,
916+ /* .nb10 =*/ nb10,
917+ /* .nb11 =*/ nb11,
918+ /* .nb12 =*/ nb12,
919+ /* .nb1 =*/ nb1,
920+ /* .nb2 =*/ nb2,
921+ /* .nb3 =*/ nb3,
917922 };
918923
924+ const int nth = std::min (args.ne00t , ggml_metal_pipeline_max_theads_per_threadgroup (pipeline));
925+
926+ const int nw0 = (args.ne00t + nth - 1 )/nth;
927+
919928 ggml_metal_encoder_set_pipeline (enc, pipeline);
920929 ggml_metal_encoder_set_bytes (enc, &args, sizeof (args), 0 );
921930 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [0 ]), 1 );
922931 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [1 ]), 2 );
923932 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op), 3 );
924933
925- ggml_metal_encoder_dispatch_threadgroups (enc, ne10, ne11, ne12, 32 , 1 , 1 );
934+ ggml_metal_encoder_dispatch_threadgroups (enc, nw0* ne10, ne11, ne12, nth , 1 , 1 );
926935
927936 return 1 ;
928937}
@@ -1117,7 +1126,7 @@ int ggml_metal_op_ssm_conv(ggml_metal_op_t ctx, int idx) {
11171126 ggml_metal_encoder_set_bytes (enc, &args, sizeof (args), 0 );
11181127 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [0 ]), 1 );
11191128 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [1 ]), 2 );
1120- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op), 3 );
1129+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op), 3 );
11211130
11221131 ggml_metal_encoder_dispatch_threadgroups (enc, ne01, ne1, ne02, 1 , 1 , 1 );
11231132
@@ -1172,25 +1181,36 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) {
11721181 /* .n_seq_tokens =*/ n_seq_tokens,
11731182 /* .n_seqs =*/ n_seqs,
11741183 /* .s_off =*/ ggml_nelements (op->src [1 ]) * sizeof (float ),
1184+ /* .nb00 =*/ nb00,
11751185 /* .nb01 =*/ nb01,
11761186 /* .nb02 =*/ nb02,
11771187 /* .nb03 =*/ nb03,
1188+ /* .nb10 =*/ nb10,
11781189 /* .nb11 =*/ nb11,
11791190 /* .nb12 =*/ nb12,
1191+ /* .ns12 =*/ nb12/nb10,
11801192 /* .nb13 =*/ nb13,
1193+ /* .nb20 =*/ nb20,
11811194 /* .nb21 =*/ nb21,
1195+ /* .ns21 =*/ nb21/nb20,
11821196 /* .nb22 =*/ nb22,
1197+ /* .ne30 =*/ ne30,
11831198 /* .nb31 =*/ nb31,
11841199 /* .nb41 =*/ nb41,
11851200 /* .nb42 =*/ nb42,
1201+ /* .ns42 =*/ nb42/nb40,
11861202 /* .nb43 =*/ nb43,
11871203 /* .nb51 =*/ nb51,
11881204 /* .nb52 =*/ nb52,
1205+ /* .ns52 =*/ nb52/nb50,
11891206 /* .nb53 =*/ nb53,
1207+ /* .nb0 =*/ nb0,
11901208 };
11911209
11921210 ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_ssm_scan (lib, op);
11931211
1212+ GGML_ASSERT (d_state <= ggml_metal_pipeline_max_theads_per_threadgroup (pipeline));
1213+
11941214 const size_t sms = ggml_metal_pipeline_get_smem (pipeline);
11951215
11961216 ggml_metal_encoder_set_pipeline (enc, pipeline);
@@ -1206,13 +1226,7 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) {
12061226
12071227 ggml_metal_encoder_set_threadgroup_memory_size (enc, sms, 0 );
12081228
1209- if (ne30 == 1 ) {
1210- // Mamba-2
1211- ggml_metal_encoder_dispatch_threadgroups (enc, d_inner, n_head, n_seqs, d_state, 1 , 1 );
1212- } else {
1213- GGML_ASSERT (d_inner == 1 );
1214- ggml_metal_encoder_dispatch_threadgroups (enc, n_head, n_seqs, 1 , d_state, 1 , 1 );
1215- }
1229+ ggml_metal_encoder_dispatch_threadgroups (enc, d_inner, n_head, n_seqs, d_state, 1 , 1 );
12161230
12171231 return 1 ;
12181232}
@@ -1273,37 +1287,35 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
12731287
12741288 GGML_ASSERT (ne00 % ggml_blck_size (op->src [0 ]->type ) == 0 );
12751289
1276- // TODO: support
1277- // const int32_t nk00 = ne00/ggml_blck_size(op->type);
1278- const int32_t nk00 = ne00;
1279-
1280- int nth = 32 ; // SIMD width
1281-
1282- while (nth < nk00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup (pipeline)) {
1283- nth *= 2 ;
1290+ int64_t nk0 = ne00;
1291+ if (ggml_is_quantized (op->src [0 ]->type )) {
1292+ nk0 = ne00/16 ;
1293+ } else if (ggml_is_quantized (op->type )) {
1294+ nk0 = ne00/ggml_blck_size (op->type );
12841295 }
12851296
1286- nth = std::min (nth , ggml_metal_pipeline_max_theads_per_threadgroup (pipeline));
1297+ int nth = std::min< int >(nk0 , ggml_metal_pipeline_max_theads_per_threadgroup (pipeline));
12871298
12881299 // when rows are small, we can batch them together in a single threadgroup
12891300 int nrptg = 1 ;
12901301
12911302 // TODO: relax this constraint in the future
12921303 if (ggml_blck_size (op->src [0 ]->type ) == 1 && ggml_blck_size (op->type ) == 1 ) {
1293- if (nth > nk00 ) {
1294- nrptg = (nth + nk00 - 1 )/nk00 ;
1295- nth = nk00 ;
1304+ if (nth > nk0 ) {
1305+ nrptg = (nth + nk0 - 1 )/nk0 ;
1306+ nth = nk0 ;
12961307
12971308 if (nrptg*nth > ggml_metal_pipeline_max_theads_per_threadgroup (pipeline)) {
12981309 nrptg--;
12991310 }
13001311 }
13011312 }
13021313
1303- nth = std::min (nth, nk00 );
1314+ nth = std::min< int > (nth, nk0 );
13041315
13051316 ggml_metal_kargs_cpy args = {
1306- /* .ne00 =*/ nk00,
1317+ /* .nk0 =*/ nk0,
1318+ /* .ne00 =*/ ne00,
13071319 /* .ne01 =*/ ne01,
13081320 /* .ne02 =*/ ne02,
13091321 /* .ne03 =*/ ne03,
@@ -1321,12 +1333,14 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
13211333 /* .nb3 =*/ nb3,
13221334 };
13231335
1336+ const int nw0 = nrptg == 1 ? (nk0 + nth - 1 )/nth : 1 ;
1337+
13241338 ggml_metal_encoder_set_pipeline (enc, pipeline);
13251339 ggml_metal_encoder_set_bytes (enc, &args, sizeof (args), 0 );
13261340 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [0 ]), 1 );
13271341 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op), 2 );
13281342
1329- ggml_metal_encoder_dispatch_threadgroups (enc, ne01, ne02, ne03, nth, nrptg, 1 );
1343+ ggml_metal_encoder_dispatch_threadgroups (enc, nw0*( ne01 + nrptg - 1 )/nrptg , ne02, ne03, nth, nrptg, 1 );
13301344
13311345 return 1 ;
13321346}
0 commit comments