Skip to content

Commit de56279

Browse files
authored
vulkan: Optimize argsort (#15354)
- Launch an appropriate number of invocations (next larger power of two). 32 invocations is common and the barrier is much cheaper there. - Specialize for "needs bounds checking" vs not. - Make the code less branchy and [[unroll]] the loops. In the final code, I see no branches inside the main loop (only predicated stores) when needs_bounds_check is false. - Always sort ascending, then apply the ascending vs descending option when doing the final stores to memory. - Copy the values into shared memory, makes them slightly cheaper to access.
1 parent 65349f2 commit de56279

File tree

3 files changed

+51
-42
lines changed

3 files changed

+51
-42
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,9 @@ enum vk_conv_shapes {
345345
CONV_SHAPE_COUNT,
346346
};
347347

348+
static constexpr uint32_t num_argsort_pipelines = 11;
349+
static constexpr uint32_t max_argsort_cols = 1 << (num_argsort_pipelines-1);
350+
348351
struct vk_device_struct {
349352
std::recursive_mutex mutex;
350353

@@ -505,7 +508,7 @@ struct vk_device_struct {
505508
vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16;
506509
vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16;
507510
vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16;
508-
vk_pipeline pipeline_argsort_f32;
511+
vk_pipeline pipeline_argsort_f32[num_argsort_pipelines];
509512
vk_pipeline pipeline_sum_rows_f32;
510513
vk_pipeline pipeline_argmax_f32;
511514
vk_pipeline pipeline_count_equal_i32;
@@ -870,7 +873,6 @@ struct vk_op_soft_max_push_constants {
870873

871874
struct vk_op_argsort_push_constants {
872875
uint32_t ncols;
873-
uint32_t ncols_pad;
874876
int32_t order;
875877
};
876878

@@ -3099,7 +3101,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
30993101
ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_len, rope_vision_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
31003102
}
31013103

3102-
ggml_vk_create_pipeline(device, device->pipeline_argsort_f32, "argsort_f32", argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1024, 1, 1}, {}, 1);
3104+
for (uint32_t i = 0; i < num_argsort_pipelines; ++i) {
3105+
ggml_vk_create_pipeline(device, device->pipeline_argsort_f32[i], "argsort_f32_"+std::to_string(i), argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1u<<i, 1, 1}, {1u<<i, i}, 1, true);
3106+
}
31033107

31043108
ggml_vk_create_pipeline(device, device->pipeline_argmax_f32, "argmax_f32", argmax_f32_len, argmax_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
31053109

@@ -7160,7 +7164,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
71607164
}
71617165
case GGML_OP_ARGSORT:
71627166
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) {
7163-
return ctx->device->pipeline_argsort_f32;
7167+
uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
7168+
return ctx->device->pipeline_argsort_f32[idx];
71647169
}
71657170
return nullptr;
71667171
case GGML_OP_SUM:
@@ -8485,16 +8490,8 @@ static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, c
84858490

84868491
uint32_t ncols = src0->ne[0];
84878492

8488-
uint32_t ncols_pad = 1;
8489-
while (ncols_pad < ncols) {
8490-
ncols_pad *= 2;
8491-
}
8492-
8493-
GGML_ASSERT(ncols_pad <= 1024);
8494-
84958493
ggml_vk_op_f32<vk_op_argsort_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ARGSORT, {
84968494
ncols,
8497-
ncols_pad,
84988495
op_params[0],
84998496
}, dryrun);
85008497
}
@@ -11367,6 +11364,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1136711364
case GGML_OP_OPT_STEP_ADAMW:
1136811365
case GGML_OP_OPT_STEP_SGD:
1136911366
return op->src[0]->type == GGML_TYPE_F32;
11367+
case GGML_OP_ARGSORT:
11368+
return op->ne[0] <= max_argsort_cols;
1137011369
case GGML_OP_UPSCALE:
1137111370
case GGML_OP_ACC:
1137211371
case GGML_OP_CONCAT:
@@ -11376,7 +11375,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1137611375
case GGML_OP_DIAG_MASK_INF:
1137711376
case GGML_OP_SOFT_MAX:
1137811377
case GGML_OP_SOFT_MAX_BACK:
11379-
case GGML_OP_ARGSORT:
1138011378
case GGML_OP_SUM:
1138111379
case GGML_OP_SUM_ROWS:
1138211380
case GGML_OP_ARGMAX:
Lines changed: 39 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,69 +1,79 @@
11
#version 450
2+
#extension GL_EXT_control_flow_attributes : enable
23

34
#include "types.comp"
45

5-
#define BLOCK_SIZE 1024
6+
layout(constant_id = 0) const int BLOCK_SIZE = 1024;
7+
layout(constant_id = 1) const int BLOCK_SIZE_LOG2 = 10;
68
#define ASC 0
79

8-
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
10+
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
911

1012
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
1113
layout (binding = 1) buffer D {int data_d[];};
1214

1315
layout (push_constant) uniform parameter {
1416
uint ncols;
15-
uint ncols_pad;
1617
uint order;
1718
} p;
1819

1920
shared int dst_row[BLOCK_SIZE];
21+
shared A_TYPE a_sh[BLOCK_SIZE];
2022

2123
void swap(uint idx0, uint idx1) {
2224
int tmp = dst_row[idx0];
2325
dst_row[idx0] = dst_row[idx1];
2426
dst_row[idx1] = tmp;
2527
}
2628

27-
void main() {
29+
void argsort(bool needs_bounds_check) {
2830
// bitonic sort
2931
const int col = int(gl_LocalInvocationID.x);
3032
const uint row = gl_WorkGroupID.y;
3133

3234
const uint row_offset = row * p.ncols;
3335

3436
// initialize indices
35-
if (col < p.ncols_pad) {
36-
dst_row[col] = col;
37-
}
37+
dst_row[col] = col;
38+
a_sh[col] = data_a[row_offset + col];
3839
barrier();
3940

40-
for (uint k = 2; k <= p.ncols_pad; k *= 2) {
41-
for (uint j = k / 2; j > 0; j /= 2) {
42-
const uint ixj = col ^ j;
43-
if (col < p.ncols_pad && ixj > col) {
44-
if ((col & k) == 0) {
45-
if (dst_row[col] >= p.ncols ||
46-
(dst_row[ixj] < p.ncols && (p.order == ASC ?
47-
data_a[row_offset + dst_row[col]] > data_a[row_offset + dst_row[ixj]] :
48-
data_a[row_offset + dst_row[col]] < data_a[row_offset + dst_row[ixj]]))
49-
) {
50-
swap(col, ixj);
51-
}
52-
} else {
53-
if (dst_row[ixj] >= p.ncols ||
54-
(dst_row[col] < p.ncols && (p.order == ASC ?
55-
data_a[row_offset + dst_row[col]] < data_a[row_offset + dst_row[ixj]] :
56-
data_a[row_offset + dst_row[col]] > data_a[row_offset + dst_row[ixj]]))
57-
) {
58-
swap(col, ixj);
59-
}
60-
}
41+
uint num_outer_loop_iters = BLOCK_SIZE_LOG2;
42+
[[unroll]] for (uint k = 2, outer_idx = 0; outer_idx < num_outer_loop_iters; k *= 2, outer_idx++) {
43+
uint num_inner_loop_iters = outer_idx + 1;
44+
[[unroll]] for (uint j = k / 2, inner_idx = 0; inner_idx < num_inner_loop_iters; j /= 2, inner_idx++) {
45+
const int ixj = int(col ^ j);
46+
47+
int idx_0 = (col & k) == 0 ? col : ixj;
48+
int idx_1 = (col & k) == 0 ? ixj : col;
49+
50+
int sh_idx_0 = dst_row[idx_0];
51+
int sh_idx_1 = dst_row[idx_1];
52+
bool idx_0_oob = needs_bounds_check ? sh_idx_0 >= p.ncols : false;
53+
bool idx_1_oob = needs_bounds_check ? sh_idx_1 >= p.ncols : false;
54+
55+
if ((idx_0_oob ||
56+
(!idx_1_oob && a_sh[sh_idx_0] > a_sh[sh_idx_1])) && (ixj > col)) {
57+
swap(idx_0, idx_1);
6158
}
59+
6260
barrier();
6361
}
6462
}
6563

6664
if (col < p.ncols) {
67-
data_d[row_offset + col] = dst_row[col];
65+
if (p.order == ASC) {
66+
data_d[row_offset + col] = dst_row[col];
67+
} else {
68+
data_d[row_offset + p.ncols - col - 1] = dst_row[col];
69+
}
70+
}
71+
}
72+
73+
void main() {
74+
if (p.ncols == BLOCK_SIZE) {
75+
argsort(false);
76+
} else {
77+
argsort(true);
6878
}
6979
}

tests/test-backend-ops.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6028,6 +6028,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
60286028
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {8, 1, 1, 1}, order));
60296029
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16, 10, 10, 10}, order));
60306030
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {60, 10, 10, 10}, order)); // qwen
6031+
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1024, 1, 1, 1}, order));
60316032
}
60326033

60336034
for (ggml_scale_mode mode : {GGML_SCALE_MODE_NEAREST, GGML_SCALE_MODE_BILINEAR}) {

0 commit comments

Comments
 (0)