@@ -3530,38 +3530,95 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
35303530 ggml_metal_library_t lib = ctx->lib ;
35313531 ggml_metal_encoder_t enc = ctx->enc ;
35323532
3533+ GGML_ASSERT (ggml_is_contiguous_rows (op->src [0 ]));
3534+
35333535 GGML_TENSOR_LOCALS ( int32_t , ne0, op->src [0 ], ne);
35343536 GGML_TENSOR_LOCALS (uint64_t , nb0, op->src [0 ], nb);
35353537 GGML_TENSOR_LOCALS ( int32_t , ne, op, ne);
35363538 GGML_TENSOR_LOCALS (uint32_t , nb, op, nb);
35373539
3540+ ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_argsort (lib, op);
3541+
35383542 // bitonic sort requires the number of elements to be power of 2
3539- int64_t ne00_padded = 1 ;
3540- while (ne00_padded < ne00) {
3541- ne00_padded *= 2 ;
3543+ int nth = 1 ;
3544+ while (nth < ne00 && 2 *nth <= ggml_metal_pipeline_max_theads_per_threadgroup (pipeline) ) {
3545+ nth *= 2 ;
35423546 }
35433547
3544- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_argsort (lib, op);
3545-
3546- const int64_t nrows = ggml_nrows (op->src [0 ]);
3548+ const int nptg = (ne00 + nth - 1 )/nth;
35473549
35483550 // Metal kernels require the buffer size to be multiple of 16 bytes
35493551 // https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength
3550- const size_t smem = GGML_PAD (ne00_padded*sizeof (int32_t ), 16 );
3552+ const size_t smem = GGML_PAD (nth*sizeof (int32_t ), 16 );
3553+
3554+ ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id (op->src [0 ]);
3555+ ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id (op);
3556+
3557+ ggml_metal_buffer_id bid_tmp = bid_dst;
3558+ bid_tmp.offs += ggml_nbytes (op);
3559+
3560+ if ((int ) ceil (std::log (nptg) / std::log (2 )) % 2 == 1 ) {
3561+ std::swap (bid_dst, bid_tmp);
3562+ }
35513563
35523564 ggml_metal_kargs_argsort args = {
3553- /* .ncols =*/ ne00,
3554- /* .ncols_pad =*/ ne00_padded
3565+ /* .ne00 =*/ ne00,
3566+ /* .ne01 =*/ ne01,
3567+ /* .ne02 =*/ ne02,
3568+ /* .ne03 =*/ ne03,
3569+ /* .nb00 =*/ nb00,
3570+ /* .nb01 =*/ nb01,
3571+ /* .nb02 =*/ nb02,
3572+ /* .nb03 =*/ nb03,
35553573 };
35563574
35573575 ggml_metal_encoder_set_pipeline (enc, pipeline);
35583576 ggml_metal_encoder_set_bytes (enc, &args, sizeof (args), 0 );
3559- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op-> src [ 0 ]) , 1 );
3560- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op), 2 );
3577+ ggml_metal_encoder_set_buffer (enc, bid_src0 , 1 );
3578+ ggml_metal_encoder_set_buffer (enc, bid_dst, 2 );
35613579
35623580 ggml_metal_encoder_set_threadgroup_memory_size (enc, smem, 0 );
35633581
3564- ggml_metal_encoder_dispatch_threadgroups (enc, 1 , nrows, 1 , ne00_padded, 1 , 1 );
3582+ ggml_metal_encoder_dispatch_threadgroups (enc, nptg*ne01, ne02, ne03, nth, 1 , 1 );
3583+
3584+ ggml_metal_pipeline_t pipeline_merge = ggml_metal_library_get_pipeline_argsort_merge (lib, op);
3585+
3586+ int len = nth;
3587+
3588+ while (len < ne00) {
3589+ ggml_metal_op_concurrency_reset (ctx);
3590+
3591+ ggml_metal_kargs_argsort_merge args_merge = {
3592+ .ne00 = ne00,
3593+ .ne01 = ne01,
3594+ .ne02 = ne02,
3595+ .ne03 = ne03,
3596+ .nb00 = nb00,
3597+ .nb01 = nb01,
3598+ .nb02 = nb02,
3599+ .nb03 = nb03,
3600+ .len = len,
3601+ };
3602+
3603+ // merges per row
3604+ const int nm = (ne00 + 2 *len - 1 ) / (2 *len);
3605+
3606+ const int nth = std::min (512 , ggml_metal_pipeline_max_theads_per_threadgroup (pipeline_merge));
3607+
3608+ ggml_metal_encoder_set_pipeline (enc, pipeline_merge);
3609+ ggml_metal_encoder_set_bytes (enc, &args_merge, sizeof (args_merge), 0 );
3610+ ggml_metal_encoder_set_buffer (enc, bid_src0, 1 );
3611+ ggml_metal_encoder_set_buffer (enc, bid_dst, 2 );
3612+ ggml_metal_encoder_set_buffer (enc, bid_tmp, 3 );
3613+
3614+ ggml_metal_encoder_set_threadgroup_memory_size (enc, 0 , 0 );
3615+
3616+ ggml_metal_encoder_dispatch_threadgroups (enc, nm*ne01, ne02, ne03, nth, 1 , 1 );
3617+
3618+ std::swap (bid_dst, bid_tmp);
3619+
3620+ len <<= 1 ;
3621+ }
35653622
35663623 return 1 ;
35673624}
0 commit comments