Skip to content

Commit 229bf68

Browse files
authored
cuda : fix argsort with 64k+ rows (ggml-org#16849)
1 parent d739511 commit 229bf68

File tree

2 files changed

+4
-3
lines changed

2 files changed

+4
-3
lines changed

ggml/src/ggml-cuda/argsort.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ template<ggml_sort_order order>
8787
static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols, int ncols_pad) {
8888
// bitonic sort
8989
int col = threadIdx.x;
90-
int row = blockIdx.y;
90+
int row = blockIdx.x;
9191

9292
if (col >= ncols_pad) {
9393
return;
@@ -151,7 +151,7 @@ static void argsort_f32_i32_cuda_bitonic(const float * x,
151151
const int ncols_pad = next_power_of_2(ncols);
152152

153153
const dim3 block_dims(ncols_pad, 1, 1);
154-
const dim3 block_nums(1, nrows, 1);
154+
const dim3 block_nums(nrows, 1, 1);
155155
const size_t shared_mem = ncols_pad * sizeof(int);
156156

157157
// FIXME: this limit could be raised by ~2-4x on Ampere or newer

tests/test-backend-ops.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7111,7 +7111,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
71117111
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16, 10, 10, 10}, order));
71127112
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {60, 10, 10, 10}, order)); // qwen
71137113
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1024, 1, 1, 1}, order));
7114-
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16384, 1, 1, 1}, order)); // bailingmoe2 (group selection)
7114+
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16384, 1, 1, 1}, order)); // many backends only handle up to 1024
7115+
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2, 8, 8192, 1}, order)); // bailingmoe2 (group selection)
71157116
}
71167117

71177118
for (ggml_scale_mode mode : {GGML_SCALE_MODE_NEAREST, GGML_SCALE_MODE_BILINEAR}) {

0 commit comments

Comments
 (0)