Skip to content

Commit 3347e6d

Browse files
authored
metal : faster argsort (#17315)
* metal : faster argsort * cont : keep data in registers
1 parent 1a13964 commit 3347e6d

File tree

2 files changed

+87
-41
lines changed

2 files changed

+87
-41
lines changed

ggml/src/ggml-metal/ggml-metal-ops.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3726,8 +3726,6 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
37263726
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
37273727
ggml_metal_encoder_set_buffer (enc, bid_tmp, 3);
37283728

3729-
ggml_metal_encoder_set_threadgroup_memory_size(enc, 0, 0);
3730-
37313729
ggml_metal_encoder_dispatch_threadgroups(enc, nm*ne01, ne02, ne03, nth, 1, 1);
37323730

37333731
std::swap(bid_dst, bid_tmp);

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 87 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -4739,12 +4739,13 @@ kernel void kernel_argsort_merge_f32_i32(
47394739
uint3 tgpig[[threadgroup_position_in_grid]],
47404740
ushort3 tpitg[[thread_position_in_threadgroup]],
47414741
ushort3 ntg[[threads_per_threadgroup]]) {
4742-
int im = tgpig[0] / args.ne01;
4743-
int i01 = tgpig[0] % args.ne01;
4744-
int i02 = tgpig[1];
4745-
int i03 = tgpig[2];
47464742

4747-
const int start = im * (2*args.len);
4743+
const int im = tgpig[0] / args.ne01;
4744+
const int i01 = tgpig[0] % args.ne01;
4745+
const int i02 = tgpig[1];
4746+
const int i03 = tgpig[2];
4747+
4748+
const int start = im * (2 * args.len);
47484749

47494750
const int len0 = MIN(args.len, MAX(0, args.ne00 - (int)(start)));
47504751
const int len1 = MIN(args.len, MAX(0, args.ne00 - (int)(start + args.len)));
@@ -4768,54 +4769,101 @@ kernel void kernel_argsort_merge_f32_i32(
47684769
+ args.nb02*i02
47694770
+ args.nb03*i03);
47704771

4771-
for (int k = tpitg.x; k < (int) total; k += ntg.x) {
4772-
// find partition (i,j) such that i+j = k
4773-
int low = k > len1 ? k - len1 : 0;
4774-
int high = MIN(k, len0);
4772+
if (total == 0) {
4773+
return;
4774+
}
47754775

4776-
while (low < high) {
4777-
const int mid = (low + high) >> 1;
4776+
const int chunk = (total + ntg.x - 1) / ntg.x;
47784777

4779-
const int32_t idx0 = tmp0[mid];
4780-
const int32_t idx1 = tmp1[k - mid - 1];
4778+
const int k0 = tpitg.x * chunk;
4779+
const int k1 = min(k0 + chunk, total);
47814780

4782-
const float val0 = src0_row[idx0];
4783-
const float val1 = src0_row[idx1];
4781+
if (k0 >= total) {
4782+
return;
4783+
}
47844784

4785-
if (order == GGML_SORT_ORDER_ASC) {
4786-
if (val0 <= val1) {
4787-
low = mid + 1;
4788-
} else {
4789-
high = mid;
4790-
}
4791-
} else {
4792-
if (val0 >= val1) {
4793-
low = mid + 1;
4794-
} else {
4795-
high = mid;
4796-
}
4797-
}
4785+
int low = k0 > len1 ? k0 - len1 : 0;
4786+
int high = MIN(k0, len0);
4787+
4788+
// binary-search the partition (i, j) such that i + j = k0
4789+
while (low < high) {
4790+
const int mid = (low + high) >> 1;
4791+
4792+
const int32_t idx0 = tmp0[mid];
4793+
const int32_t idx1 = tmp1[k0 - mid - 1];
4794+
4795+
const float val0 = src0_row[idx0];
4796+
const float val1 = src0_row[idx1];
4797+
4798+
bool take_left;
4799+
if (order == GGML_SORT_ORDER_ASC) {
4800+
take_left = (val0 <= val1);
4801+
} else {
4802+
take_left = (val0 >= val1);
47984803
}
47994804

4800-
const int i = low;
4801-
const int j = k - i;
4805+
if (take_left) {
4806+
low = mid + 1;
4807+
} else {
4808+
high = mid;
4809+
}
4810+
}
4811+
4812+
int i = low;
4813+
int j = k0 - i;
4814+
4815+
// keep the merge fronts in registers
4816+
int32_t idx0 = 0;
4817+
float val0 = 0.0f;
4818+
if (i < len0) {
4819+
idx0 = tmp0[i];
4820+
val0 = src0_row[idx0];
4821+
}
4822+
4823+
int32_t idx1 = 0;
4824+
float val1 = 0.0f;
4825+
if (j < len1) {
4826+
idx1 = tmp1[j];
4827+
val1 = src0_row[idx1];
4828+
}
48024829

4830+
for (int k = k0; k < k1; ++k) {
48034831
int32_t out_idx;
48044832

48054833
if (i >= len0) {
4806-
out_idx = tmp1[j];
4834+
while (k < k1) {
4835+
dst[k++] = tmp1[j++];
4836+
}
4837+
break;
48074838
} else if (j >= len1) {
4808-
out_idx = tmp0[i];
4839+
while (k < k1) {
4840+
dst[k++] = tmp0[i++];
4841+
}
4842+
break;
48094843
} else {
4810-
const int32_t idx0 = tmp0[i];
4811-
const int32_t idx1 = tmp1[j];
4844+
bool take_left;
48124845

4813-
const float val0 = src0_row[idx0];
4814-
const float val1 = src0_row[idx1];
4846+
if (order == GGML_SORT_ORDER_ASC) {
4847+
take_left = (val0 <= val1);
4848+
} else {
4849+
take_left = (val0 >= val1);
4850+
}
48154851

4816-
out_idx = (order == GGML_SORT_ORDER_ASC)
4817-
? (val0 <= val1 ? idx0 : idx1)
4818-
: (val0 >= val1 ? idx0 : idx1);
4852+
if (take_left) {
4853+
out_idx = idx0;
4854+
++i;
4855+
if (i < len0) {
4856+
idx0 = tmp0[i];
4857+
val0 = src0_row[idx0];
4858+
}
4859+
} else {
4860+
out_idx = idx1;
4861+
++j;
4862+
if (j < len1) {
4863+
idx1 = tmp1[j];
4864+
val1 = src0_row[idx1];
4865+
}
4866+
}
48194867
}
48204868

48214869
dst[k] = out_idx;

0 commit comments

Comments
 (0)