@@ -4739,12 +4739,13 @@ kernel void kernel_argsort_merge_f32_i32(
47394739 uint3 tgpig[[threadgroup_position_in_grid]],
47404740 ushort3 tpitg[[thread_position_in_threadgroup]],
47414741 ushort3 ntg[[threads_per_threadgroup]]) {
4742- int im = tgpig[0 ] / args.ne01 ;
4743- int i01 = tgpig[0 ] % args.ne01 ;
4744- int i02 = tgpig[1 ];
4745- int i03 = tgpig[2 ];
47464742
4747- const int start = im * (2 *args.len );
4743+ const int im = tgpig[0 ] / args.ne01 ;
4744+ const int i01 = tgpig[0 ] % args.ne01 ;
4745+ const int i02 = tgpig[1 ];
4746+ const int i03 = tgpig[2 ];
4747+
4748+ const int start = im * (2 * args.len );
47484749
47494750 const int len0 = MIN (args.len , MAX (0 , args.ne00 - (int )(start)));
47504751 const int len1 = MIN (args.len , MAX (0 , args.ne00 - (int )(start + args.len )));
@@ -4768,54 +4769,101 @@ kernel void kernel_argsort_merge_f32_i32(
47684769 + args.nb02 *i02
47694770 + args.nb03 *i03);
47704771
4771- for (int k = tpitg.x ; k < (int ) total; k += ntg.x ) {
4772- // find partition (i,j) such that i+j = k
4773- int low = k > len1 ? k - len1 : 0 ;
4774- int high = MIN (k, len0);
4772+ if (total == 0 ) {
4773+ return ;
4774+ }
47754775
4776- while (low < high) {
4777- const int mid = (low + high) >> 1 ;
4776+ const int chunk = (total + ntg.x - 1 ) / ntg.x ;
47784777
4779- const int32_t idx0 = tmp0[mid] ;
4780- const int32_t idx1 = tmp1[k - mid - 1 ] ;
4778+ const int k0 = tpitg. x * chunk ;
4779+ const int k1 = min (k0 + chunk, total) ;
47814780
4782- const float val0 = src0_row[idx0];
4783- const float val1 = src0_row[idx1];
4781+ if (k0 >= total) {
4782+ return ;
4783+ }
47844784
4785- if (order == GGML_SORT_ORDER_ASC) {
4786- if (val0 <= val1) {
4787- low = mid + 1 ;
4788- } else {
4789- high = mid;
4790- }
4791- } else {
4792- if (val0 >= val1) {
4793- low = mid + 1 ;
4794- } else {
4795- high = mid;
4796- }
4797- }
4785+ int low = k0 > len1 ? k0 - len1 : 0 ;
4786+ int high = MIN (k0, len0);
4787+
4788+ // binary-search the partition (i, j) such that i + j = k0
4789+ while (low < high) {
4790+ const int mid = (low + high) >> 1 ;
4791+
4792+ const int32_t idx0 = tmp0[mid];
4793+ const int32_t idx1 = tmp1[k0 - mid - 1 ];
4794+
4795+ const float val0 = src0_row[idx0];
4796+ const float val1 = src0_row[idx1];
4797+
4798+ bool take_left;
4799+ if (order == GGML_SORT_ORDER_ASC) {
4800+ take_left = (val0 <= val1);
4801+ } else {
4802+ take_left = (val0 >= val1);
47984803 }
47994804
4800- const int i = low;
4801- const int j = k - i;
4805+ if (take_left) {
4806+ low = mid + 1 ;
4807+ } else {
4808+ high = mid;
4809+ }
4810+ }
4811+
4812+ int i = low;
4813+ int j = k0 - i;
4814+
4815+ // keep the merge fronts in registers
4816+ int32_t idx0 = 0 ;
4817+ float val0 = 0 .0f ;
4818+ if (i < len0) {
4819+ idx0 = tmp0[i];
4820+ val0 = src0_row[idx0];
4821+ }
4822+
4823+ int32_t idx1 = 0 ;
4824+ float val1 = 0 .0f ;
4825+ if (j < len1) {
4826+ idx1 = tmp1[j];
4827+ val1 = src0_row[idx1];
4828+ }
48024829
4830+ for (int k = k0; k < k1; ++k) {
48034831 int32_t out_idx;
48044832
48054833 if (i >= len0) {
4806- out_idx = tmp1[j];
4834+ while (k < k1) {
4835+ dst[k++] = tmp1[j++];
4836+ }
4837+ break ;
48074838 } else if (j >= len1) {
4808- out_idx = tmp0[i];
4839+ while (k < k1) {
4840+ dst[k++] = tmp0[i++];
4841+ }
4842+ break ;
48094843 } else {
4810- const int32_t idx0 = tmp0[i];
4811- const int32_t idx1 = tmp1[j];
4844+ bool take_left;
48124845
4813- const float val0 = src0_row[idx0];
4814- const float val1 = src0_row[idx1];
4846+ if (order == GGML_SORT_ORDER_ASC) {
4847+ take_left = (val0 <= val1);
4848+ } else {
4849+ take_left = (val0 >= val1);
4850+ }
48154851
4816- out_idx = (order == GGML_SORT_ORDER_ASC)
4817- ? (val0 <= val1 ? idx0 : idx1)
4818- : (val0 >= val1 ? idx0 : idx1);
4852+ if (take_left) {
4853+ out_idx = idx0;
4854+ ++i;
4855+ if (i < len0) {
4856+ idx0 = tmp0[i];
4857+ val0 = src0_row[idx0];
4858+ }
4859+ } else {
4860+ out_idx = idx1;
4861+ ++j;
4862+ if (j < len1) {
4863+ idx1 = tmp1[j];
4864+ val1 = src0_row[idx1];
4865+ }
4866+ }
48194867 }
48204868
48214869 dst[k] = out_idx;
0 commit comments