Skip to content

Commit 45c6ef7

Browse files
authored
metal : support argsort for ne00 > 1024 (ggml-org#17247)
* metal : refactor argsort * cont : sort chunks * cont : merge sorted buckets * cont : cleanup
1 parent 2606b0a commit 45c6ef7

File tree

8 files changed

+265
-44
lines changed

8 files changed

+265
-44
lines changed

ggml/src/ggml-metal/ggml-metal-device.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -943,6 +943,34 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort(ggml_metal_library
943943
return res;
944944
}
945945

946+
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort_merge(ggml_metal_library_t lib, const ggml_tensor * op) {
947+
assert(op->op == GGML_OP_ARGSORT);
948+
949+
char base[256];
950+
char name[256];
951+
952+
ggml_sort_order order = (ggml_sort_order) op->op_params[0];
953+
954+
const char * order_str = "undefined";
955+
switch (order) {
956+
case GGML_SORT_ORDER_ASC: order_str = "asc"; break;
957+
case GGML_SORT_ORDER_DESC: order_str = "desc"; break;
958+
default: GGML_ABORT("fatal error");
959+
};
960+
961+
snprintf(base, 256, "kernel_argsort_merge_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str);
962+
snprintf(name, 256, "%s", base);
963+
964+
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
965+
if (res) {
966+
return res;
967+
}
968+
969+
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
970+
971+
return res;
972+
}
973+
946974
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad(
947975
ggml_metal_library_t lib,
948976
const struct ggml_tensor * op,

ggml/src/ggml-metal/ggml-metal-device.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id (ggml_me
125125
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id (ggml_metal_library_t lib, const struct ggml_tensor * op);
126126
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argmax (ggml_metal_library_t lib, const struct ggml_tensor * op);
127127
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort (ggml_metal_library_t lib, const struct ggml_tensor * op);
128+
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort_merge (ggml_metal_library_t lib, const struct ggml_tensor * op);
128129
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row);
129130
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
130131
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);

ggml/src/ggml-metal/ggml-metal-device.m

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -904,8 +904,6 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
904904
case GGML_OP_LEAKY_RELU:
905905
return op->src[0]->type == GGML_TYPE_F32;
906906
case GGML_OP_ARGSORT:
907-
// TODO: Support arbitrary column width
908-
return op->src[0]->ne[0] <= 1024;
909907
case GGML_OP_ARANGE:
910908
return true;
911909
case GGML_OP_FLASH_ATTN_EXT:

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -793,10 +793,28 @@ typedef struct {
793793
} ggml_metal_kargs_leaky_relu;
794794

795795
typedef struct {
796-
int64_t ncols;
797-
int64_t ncols_pad;
796+
int64_t ne00;
797+
int64_t ne01;
798+
int64_t ne02;
799+
int64_t ne03;
800+
uint64_t nb00;
801+
uint64_t nb01;
802+
uint64_t nb02;
803+
uint64_t nb03;
798804
} ggml_metal_kargs_argsort;
799805

806+
typedef struct {
807+
int64_t ne00;
808+
int64_t ne01;
809+
int64_t ne02;
810+
int64_t ne03;
811+
uint64_t nb00;
812+
uint64_t nb01;
813+
uint64_t nb02;
814+
uint64_t nb03;
815+
int32_t len;
816+
} ggml_metal_kargs_argsort_merge;
817+
800818
typedef struct {
801819
int64_t ne0;
802820
float start;

ggml/src/ggml-metal/ggml-metal-ops.cpp

Lines changed: 69 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3530,38 +3530,95 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
35303530
ggml_metal_library_t lib = ctx->lib;
35313531
ggml_metal_encoder_t enc = ctx->enc;
35323532

3533+
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
3534+
35333535
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
35343536
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
35353537
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
35363538
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
35373539

3540+
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_argsort(lib, op);
3541+
35383542
// bitonic sort requires the number of elements to be power of 2
3539-
int64_t ne00_padded = 1;
3540-
while (ne00_padded < ne00) {
3541-
ne00_padded *= 2;
3543+
int nth = 1;
3544+
while (nth < ne00 && 2*nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
3545+
nth *= 2;
35423546
}
35433547

3544-
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_argsort(lib, op);
3545-
3546-
const int64_t nrows = ggml_nrows(op->src[0]);
3548+
const int nptg = (ne00 + nth - 1)/nth;
35473549

35483550
// Metal kernels require the buffer size to be multiple of 16 bytes
35493551
// https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength
3550-
const size_t smem = GGML_PAD(ne00_padded*sizeof(int32_t), 16);
3552+
const size_t smem = GGML_PAD(nth*sizeof(int32_t), 16);
3553+
3554+
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
3555+
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
3556+
3557+
ggml_metal_buffer_id bid_tmp = bid_dst;
3558+
bid_tmp.offs += ggml_nbytes(op);
3559+
3560+
if ((int) ceil(std::log(nptg) / std::log(2)) % 2 == 1) {
3561+
std::swap(bid_dst, bid_tmp);
3562+
}
35513563

35523564
ggml_metal_kargs_argsort args = {
3553-
/*.ncols =*/ ne00,
3554-
/*.ncols_pad =*/ ne00_padded
3565+
/*.ne00 =*/ ne00,
3566+
/*.ne01 =*/ ne01,
3567+
/*.ne02 =*/ ne02,
3568+
/*.ne03 =*/ ne03,
3569+
/*.nb00 =*/ nb00,
3570+
/*.nb01 =*/ nb01,
3571+
/*.nb02 =*/ nb02,
3572+
/*.nb03 =*/ nb03,
35553573
};
35563574

35573575
ggml_metal_encoder_set_pipeline(enc, pipeline);
35583576
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
3559-
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
3560-
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
3577+
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
3578+
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
35613579

35623580
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
35633581

3564-
ggml_metal_encoder_dispatch_threadgroups(enc, 1, nrows, 1, ne00_padded, 1, 1);
3582+
ggml_metal_encoder_dispatch_threadgroups(enc, nptg*ne01, ne02, ne03, nth, 1, 1);
3583+
3584+
ggml_metal_pipeline_t pipeline_merge = ggml_metal_library_get_pipeline_argsort_merge(lib, op);
3585+
3586+
int len = nth;
3587+
3588+
while (len < ne00) {
3589+
ggml_metal_op_concurrency_reset(ctx);
3590+
3591+
ggml_metal_kargs_argsort_merge args_merge = {
3592+
.ne00 = ne00,
3593+
.ne01 = ne01,
3594+
.ne02 = ne02,
3595+
.ne03 = ne03,
3596+
.nb00 = nb00,
3597+
.nb01 = nb01,
3598+
.nb02 = nb02,
3599+
.nb03 = nb03,
3600+
.len = len,
3601+
};
3602+
3603+
// merges per row
3604+
const int nm = (ne00 + 2*len - 1) / (2*len);
3605+
3606+
const int nth = std::min(512, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_merge));
3607+
3608+
ggml_metal_encoder_set_pipeline(enc, pipeline_merge);
3609+
ggml_metal_encoder_set_bytes (enc, &args_merge, sizeof(args_merge), 0);
3610+
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
3611+
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
3612+
ggml_metal_encoder_set_buffer (enc, bid_tmp, 3);
3613+
3614+
ggml_metal_encoder_set_threadgroup_memory_size(enc, 0, 0);
3615+
3616+
ggml_metal_encoder_dispatch_threadgroups(enc, nm*ne01, ne02, ne03, nth, 1, 1);
3617+
3618+
std::swap(bid_dst, bid_tmp);
3619+
3620+
len <<= 1;
3621+
}
35653622

35663623
return 1;
35673624
}

ggml/src/ggml-metal/ggml-metal.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,10 @@ static size_t ggml_backend_metal_buffer_type_get_alloc_size(ggml_backend_buffer_
197197
res += ggml_metal_op_flash_attn_ext_extra_blk(tensor);
198198
res += ggml_metal_op_flash_attn_ext_extra_tmp(tensor);
199199
} break;
200+
case GGML_OP_ARGSORT:
201+
{
202+
res *= 2;
203+
} break;
200204
default:
201205
break;
202206
}

0 commit comments

Comments
 (0)