Skip to content

Commit 052df28

Browse files
authored
vulkan: Handle argsort with a large number of rows (#16851)
1 parent 8b11dee commit 052df28

File tree

2 files changed

+16
-4
lines changed

2 files changed

+16
-4
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1082,6 +1082,7 @@ struct vk_op_soft_max_push_constants {
10821082

10831083
struct vk_op_argsort_push_constants {
10841084
uint32_t ncols;
1085+
uint32_t nrows;
10851086
int32_t order;
10861087
};
10871088

@@ -8708,6 +8709,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
87088709
break;
87098710
case GGML_OP_ARGSORT:
87108711
elements = { (uint32_t)ne00, (uint32_t)ggml_nrows(src0), 1 };
8712+
elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
87118713
break;
87128714
case GGML_OP_IM2COL:
87138715
{
@@ -9954,9 +9956,11 @@ static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, c
99549956
int32_t * op_params = (int32_t *)dst->op_params;
99559957

99569958
uint32_t ncols = src0->ne[0];
9959+
uint32_t nrows = ggml_nrows(src0);
99579960

99589961
ggml_vk_op_f32<vk_op_argsort_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ARGSORT, {
99599962
ncols,
9963+
nrows,
99609964
op_params[0],
99619965
}, dryrun);
99629966
}

ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ layout (binding = 1) buffer D {int data_d[];};
1414

1515
layout (push_constant) uniform parameter {
1616
uint ncols;
17+
uint nrows;
1718
uint order;
1819
} p;
1920

@@ -26,10 +27,9 @@ void swap(uint idx0, uint idx1) {
2627
dst_row[idx1] = tmp;
2728
}
2829

29-
void argsort(bool needs_bounds_check) {
30+
void argsort(bool needs_bounds_check, const uint row) {
3031
// bitonic sort
3132
const int col = int(gl_LocalInvocationID.x);
32-
const uint row = gl_WorkGroupID.y;
3333

3434
const uint row_offset = row * p.ncols;
3535

@@ -72,8 +72,16 @@ void argsort(bool needs_bounds_check) {
7272

7373
void main() {
7474
if (p.ncols == BLOCK_SIZE) {
75-
argsort(false);
75+
uint row = gl_WorkGroupID.y;
76+
while (row < p.nrows) {
77+
argsort(false, row);
78+
row += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
79+
}
7680
} else {
77-
argsort(true);
81+
uint row = gl_WorkGroupID.y;
82+
while (row < p.nrows) {
83+
argsort(true, row);
84+
row += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
85+
}
7886
}
7987
}

0 commit comments

Comments
 (0)