Commit a0f3897
vulkan: fix top_k bug when there are ties in the input (#17659)
* vulkan: Reduce temporary memory usage for TOP_K

  - Compute row size for the temp buffer based on the output of the first pass.
  - Update shader addressing math to use the output row size.
  - Pass the output row size as "ncols_output"; what used to be "ncols_output" is now "k".

  For the common case of K=40 and src0=(200000,1,1,1), this reduces the temporary buffer from about 3.2MB to 500KB.

* vulkan: fix top_k bug when there are ties in the input

  I noticed by inspection a bug in the vulkan top_k shader: if the least value in the top_k appears multiple times, we could end up writing those extra copies out rather than some larger values (if the larger values are on higher-numbered threads). I rewrote the test verification to handle this case, where the final index set is not necessarily the same.

* Update tests/test-backend-ops.cpp

  Co-authored-by: Georgi Gerganov <[email protected]>

---------

Co-authored-by: Georgi Gerganov <[email protected]>
1 parent e15cd06 commit a0f3897

3 files changed: +138, -37 lines changed
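To make the tie bug concrete, consider a made-up example (the numbers are illustrative, not taken from the commit): a row [9, 4, 4, 4, 7] with k = 3. The top-3 threshold (range_min) is 4. The previous shader kept, in effect, every value satisfying >= range_min and compacted them in thread order, giving [9, 4, 4, 4, 7]; only the first k = 3 slots are used, so the 7 held by a higher-numbered thread is displaced by extra copies of the tied 4. The fixed shader stores strictly-greater values (9 and 7) first and fills the remaining slot with a single tied 4, producing a correct top-3.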

ggml/src/ggml-vulkan/ggml-vulkan.cpp
Lines changed: 1 addition & 1 deletion

@@ -4013,7 +4013,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
         uint32_t nary_shmem = 2 * sizeof(int) * BLOCK_SIZE +
                               sizeof(int) * device->subgroup_size +
                               2 * sizeof(int) +
-                              (BLOCK_SIZE / device->subgroup_size) * sizeof(int);
+                              2 * (BLOCK_SIZE / device->subgroup_size) * sizeof(int);
         if (device->subgroup_arithmetic && device->subgroup_require_full_support && device->subgroup_shuffle && device->subgroup_ballot &&
             nary_shmem <= device->properties.limits.maxComputeSharedMemorySize) {
             ggml_vk_create_pipeline2(device, device->pipeline_topk_f32[i], "topk_f32_"+std::to_string(i), topk_nary_search_f32_len, topk_nary_search_f32_data, "main", 2, sizeof(vk_op_topk_push_constants), {BLOCK_SIZE, 1, 1}, {BLOCK_SIZE, device->subgroup_size, device->subgroup_size_log2}, 1, true, true, device->subgroup_size);
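As a rough sanity check of the shared-memory budget, with illustrative values BLOCK_SIZE = 512 and device->subgroup_size = 32 (these specific numbers are assumptions, not taken from the commit), the new expression evaluates to:

    nary_shmem = 2 * 4 * 512             // 2 * sizeof(int) * BLOCK_SIZE
               + 4 * 32                  // sizeof(int) * subgroup_size
               + 2 * 4                   // 2 * sizeof(int)
               + 2 * (512 / 32) * 4      // offset_partials + eq_min_partials
               = 4096 + 128 + 8 + 128 = 4360 bytes

The doubled last term accounts for the new eq_min_partials array declared in the shader, i.e. an extra (BLOCK_SIZE / subgroup_size) * sizeof(int) = 64 bytes under these assumptions, which must still fit under maxComputeSharedMemorySize.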

ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp
Lines changed: 58 additions & 16 deletions

@@ -38,6 +38,7 @@ shared int counts[SUBGROUP_SIZE];
 shared int sh_min_idx;
 shared uint sh_total;
 shared uint offset_partials[BLOCK_SIZE / SUBGROUP_SIZE];
+shared uint eq_min_partials[BLOCK_SIZE / SUBGROUP_SIZE];
 
 // Map float values to uint such that comparisons still work.
 // Positive values set the high bit, negative values are inverted.
@@ -156,25 +157,66 @@ void topk(const uint row) {
     // We need to compact these values to the start of the dst_row array.
     // Have each subgroup count how many items it'll store, so other
     // subgroups can compute their base offset.
-    bool top = f2ui(intBitsToFloat(v.y)) >= range_min;
-    uvec4 b = subgroupBallot(top);
-    uint bit_count = subgroupBallotBitCount(b);
-    if ((tid % SUBGROUP_SIZE) == 0) {
-        offset_partials[tid / SUBGROUP_SIZE] = bit_count;
-    }
-    barrier();
+    // Values strictly greater than range_min must be stored. For values equal
+    // to range_min, there can be ties and it's possible we'll need to store
+    // an arbitrary subset of them.
+    // If total == p.k, have a fast path where we don't need to handle ties.
+    if (total == p.k) {
+        bool top = f2ui(intBitsToFloat(v.y)) >= range_min;
+        uvec4 b = subgroupBallot(top);
+        uint bit_count = subgroupBallotBitCount(b);
+        if ((tid % SUBGROUP_SIZE) == 0) {
+            offset_partials[tid / SUBGROUP_SIZE] = bit_count;
+        }
+        barrier();
 
-    uint out_idx = 0;
-    [[unroll]] for (int i = 0; i < BLOCK_SIZE / SUBGROUP_SIZE; ++i) {
-        if (i < tid / SUBGROUP_SIZE) {
-            out_idx += offset_partials[i];
+        uint out_idx = 0;
+        [[unroll]] for (int i = 0; i < BLOCK_SIZE / SUBGROUP_SIZE; ++i) {
+            if (i < tid / SUBGROUP_SIZE) {
+                out_idx += offset_partials[i];
+            }
         }
-    }
 
-    uint bit_count_ex = subgroupBallotExclusiveBitCount(b);
-    if (top) {
-        // TODO: Copy directly to the output?
-        dst_row[out_idx + bit_count_ex] = v;
+        uint bit_count_ex = subgroupBallotExclusiveBitCount(b);
+        if (top) {
+            // TODO: Copy directly to the output?
+            dst_row[out_idx + bit_count_ex] = v;
+        }
+    } else {
+        bool top = f2ui(intBitsToFloat(v.y)) > range_min;
+        bool eq_min = f2ui(intBitsToFloat(v.y)) == range_min;
+        uvec4 b_top = subgroupBallot(top);
+        uvec4 b_eq_min = subgroupBallot(eq_min);
+        uint bit_count_top = subgroupBallotBitCount(b_top);
+        uint bit_count_eq_min = subgroupBallotBitCount(b_eq_min);
+        if ((tid % SUBGROUP_SIZE) == 0) {
+            offset_partials[tid / SUBGROUP_SIZE] = bit_count_top;
+            eq_min_partials[tid / SUBGROUP_SIZE] = bit_count_eq_min;
+        }
+        barrier();
+
+        uint out_idx = 0;
+        uint eq_min_base = 0;
+        uint eq_min_idx = 0;
+        [[unroll]] for (int i = 0; i < BLOCK_SIZE / SUBGROUP_SIZE; ++i) {
+            if (i < tid / SUBGROUP_SIZE) {
+                out_idx += offset_partials[i];
+                eq_min_idx += eq_min_partials[i];
+            }
+            eq_min_base += offset_partials[i];
+        }
+        // range_min values are stored at the end
+        eq_min_idx += eq_min_base;
+
+        uint bit_count_ex_top = subgroupBallotExclusiveBitCount(b_top);
+        uint bit_count_ex_eq_min = subgroupBallotExclusiveBitCount(b_eq_min);
+        if (top) {
+            // TODO: Copy directly to the output?
+            dst_row[out_idx + bit_count_ex_top] = v;
+        }
+        if (eq_min && eq_min_idx + bit_count_ex_eq_min < p.k) {
+            dst_row[eq_min_idx + bit_count_ex_eq_min] = v;
+        }
     }
 
     barrier();
tests/test-backend-ops.cpp
Lines changed: 79 additions & 20 deletions

@@ -286,10 +286,11 @@ static double nmse(const float * a, const float * b, size_t n) {
     return mse_a_b / mse_a_0;
 }
 
-// difference between 2 integer sets (Jaccard distance, 0 - no difference, 1 - no overlap)
-static double jdst(const int32_t * a, const int32_t * b, size_t n) {
-    std::unordered_map<int32_t, size_t> set_a;
-    std::unordered_map<int32_t, size_t> set_b;
+// difference between 2 sets (Jaccard distance, 0 - no difference, 1 - no overlap)
+template <typename T>
+static double jdst(const T * a, const T * b, size_t n) {
+    std::unordered_map<T, size_t> set_a;
+    std::unordered_map<T, size_t> set_b;
 
     for (size_t i = 0; i < n; ++i) {
         set_a[a[i]]++;
@@ -5001,42 +5002,94 @@ struct test_top_k : public test_case {
     const ggml_type type;
     const std::array<int64_t, 4> ne;
     const int k;
+    const bool ties;
+    ggml_tensor * input {};
 
     std::string vars() override {
-        return VARS_TO_STR3(type, ne, k);
+        return VARS_TO_STR4(type, ne, k, ties);
     }
 
     test_top_k(ggml_type type = GGML_TYPE_F32,
             std::array<int64_t, 4> ne = {16, 10, 10, 10},
-            int k = 4)
-        : type(type), ne(ne), k(k) {}
+            int k = 4, bool ties = false)
+        : type(type), ne(ne), k(k), ties(ties) {}
 
     double max_err() override {
         return 0.0;
     }
 
+    // When there are ties, only validate the final result.
+    // The logic in err can't handle the sentinel tensors.
+    bool run_whole_graph() override { return ties; }
+
     double err(const float * a, const float * b, size_t n) override {
-        std::vector<int32_t> ia(n);
-        std::vector<int32_t> ib(n);
+        // When there are no ties, we expect the exact same set of indices,
+        // but possibly in a different order. When there are ties, the indices
+        // can be different but the input values they correspond to should be
+        // the same. The logic for ties could work for non-ties, but only for
+        // the output tensor, not for the sentinel tensors.
+        if (ties) {
+            std::vector<float> src(ggml_nelements(input));
+
+            ggml_backend_tensor_get(input, src.data(), 0, ggml_nelements(input) * ggml_type_size(type));
+
+            double diff = 0.0f;
+
+            GGML_ASSERT(n == (size_t)(ggml_nrows(input) * k));
+            int64_t cols = input->ne[0];
+            std::vector<int32_t> ia(k);
+            std::vector<int32_t> ib(k);
+            std::vector<float> asrc(k);
+            std::vector<float> bsrc(k);
+            for (int64_t r = 0; r < ggml_nrows(input); r++) {
+                // Convert indices for the row back to integer
+                for (int64_t c = 0; c < k; c++) {
+                    ia[c] = (int32_t)a[r * k + c];
+                    ib[c] = (int32_t)b[r * k + c];
+                }
+                // The src values for each row should match.
+                for (int64_t c = 0; c < k; c++) {
+                    asrc[c] = src[r * cols + ia[c]];
+                    bsrc[c] = src[r * cols + ib[c]];
+                }
+                diff += jdst(asrc.data(), bsrc.data(), k);
+                // There should be no duplicate indices
+                std::sort(ia.begin(), ia.end());
+                std::sort(ib.begin(), ib.end());
+                if (std::adjacent_find(ia.begin(), ia.end()) != ia.end()) {
+                    diff += 1;
+                }
+                if (std::adjacent_find(ib.begin(), ib.end()) != ib.end()) {
+                    diff += 1;
+                }
+            }
+            return diff;
+        } else {
+            std::vector<int32_t> ia(n);
+            std::vector<int32_t> ib(n);
 
-        double diff = 0.0f;
+            double diff = 0.0f;
 
-        for (size_t i = 0; i < n; i++) {
-            ia[i] = (int32_t) a[i];
-            ib[i] = (int32_t) b[i];
+            for (size_t i = 0; i < n; i++) {
+                ia[i] = (int32_t) a[i];
+                ib[i] = (int32_t) b[i];
 
-            // penalize the result if the data is not integer valued
-            diff += std::fabs(a[i] - ia[i]);
-            diff += std::fabs(b[i] - ib[i]);
-        }
+                // penalize the result if the data is not integer valued
+                diff += std::fabs(a[i] - ia[i]);
+                diff += std::fabs(b[i] - ib[i]);
+            }
 
-        return diff + jdst(ia.data(), ib.data(), n);
+            return diff + jdst(ia.data(), ib.data(), n);
+        }
     }
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
        ggml_set_name(a, "a");
 
+        // Save 'a' for err()
+        input = a;
+
        ggml_tensor * out = ggml_top_k(ctx, a, k);
        ggml_set_name(out, "out");
 
@@ -5047,11 +5100,16 @@ struct test_top_k : public test_case {
        std::random_device rd;
        std::default_random_engine rng(rd());
        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
-            // initialize with unique values to avoid ties
+            int tie_denom = std::max(1, std::min(10, k / 2));
            for (int64_t r = 0; r < ggml_nrows(t); r++) {
                std::vector<float> data(t->ne[0]);
                for (int i = 0; i < t->ne[0]; i++) {
-                    data[i] = i;
+                    if (ties) {
+                        // integer division to introduce duplicates
+                        data[i] = i / tie_denom;
+                    } else {
+                        data[i] = i;
+                    }
                }
                std::shuffle(data.begin(), data.end(), rng);
                ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(float));
@@ -7657,6 +7715,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
            if (k <= 1<<i) {
                test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {(1<<i), 1, 1, 1}, k));
                test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {(1<<i) + 11, 1, 2, 1}, k));
+                test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {(1<<i) + 11, 1, 2, 1}, k, true));
            }
        }
    }
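The tie-aware verification added above compares value multisets rather than index sets: each reported index is mapped back to its source value, the two value lists are compared with the templated jdst, and duplicate indices within a row are penalized. A condensed, self-contained sketch of that per-row check is shown below; the helper name rows_equivalent and the boolean return are assumptions for illustration, while the real test accumulates a floating-point diff score and reuses jdst from test-backend-ops.cpp.

```cpp
#include <algorithm>
#include <cstdint>
#include <unordered_map>
#include <vector>

// Illustrative standalone version of the per-row "ties" check (not repository code):
// two index lists are equivalent if the source values they select form the same
// multiset and neither list contains a duplicate index.
static bool rows_equivalent(const std::vector<float> & src,
                            std::vector<int32_t> ia,
                            std::vector<int32_t> ib) {
    // Compare the multisets of source values selected by each index list.
    std::unordered_map<float, size_t> count_a, count_b;
    for (int32_t i : ia) count_a[src[i]]++;
    for (int32_t i : ib) count_b[src[i]]++;
    if (count_a != count_b) {
        return false;
    }
    // Reject duplicate indices within either list.
    std::sort(ia.begin(), ia.end());
    std::sort(ib.begin(), ib.end());
    return std::adjacent_find(ia.begin(), ia.end()) == ia.end() &&
           std::adjacent_find(ib.begin(), ib.end()) == ib.end();
}
```

This is why the ties test data is generated with integer division (data[i] = i / tie_denom): it guarantees repeated values, so backends that break ties differently still produce equivalent rows under this check.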
