Skip to content

Commit 2a12dee

Browse files
authored
Merge pull request boostorg#803 from rosenrodt/fix-dupe-with-custom-compare-bitonic-block-sort
Fix duplicates using custom compare with bitonic block sort
2 parents 2c16bbc + 3cba294 commit 2a12dee

File tree

2 files changed

+11
-0
lines changed

2 files changed

+11
-0
lines changed

include/boost/compute/algorithm/detail/merge_sort_on_gpu.hpp

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -170,8 +170,12 @@ inline size_t bitonic_block_sort(KeyIterator keys_first,
170170
k.decl<bool>("compare") << " = " <<
171171
compare(k.var<key_type>("sibling_key"),
172172
k.var<key_type>("my_key")) << ";\n" <<
173+
k.decl<bool>("equal") << " = !(compare || " <<
174+
compare(k.var<key_type>("my_key"),
175+
k.var<key_type>("sibling_key")) << ");\n" <<
173176
k.decl<bool>("swap") <<
174177
" = compare ^ (sibling_idx < lid) ^ direction;\n" <<
178+
"swap = equal ? false : swap;\n" <<
175179
"my_key = swap ? sibling_key : my_key;\n";
176180
if(sort_by_key)
177181
{
@@ -220,8 +224,12 @@ inline size_t bitonic_block_sort(KeyIterator keys_first,
220224
k.decl<bool>("compare") << " = " <<
221225
compare(k.var<key_type>("sibling_key"),
222226
k.var<key_type>("my_key")) << ";\n" <<
227+
k.decl<bool>("equal") << " = !(compare || " <<
228+
compare(k.var<key_type>("my_key"),
229+
k.var<key_type>("sibling_key")) << ");\n" <<
223230
k.decl<bool>("swap") <<
224231
" = compare ^ (sibling_idx < lid);\n" <<
232+
"swap = equal ? false : swap;\n" <<
225233
"my_key = swap ? sibling_key : my_key;\n";
226234
if(sort_by_key)
227235
{

test/test_sort.cpp

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -340,6 +340,7 @@ BOOST_AUTO_TEST_CASE(sort_int2)
340340
host[size/4] = int2_(20.f, 0.f);
341341
host[(size*3)/4] = int2_(9.f, 0.f);
342342
host[size-3] = int2_(-10.0f, 0.f);
343+
host[size/2+1] = int2_(-10.0f, -1.f);
343344

344345
boost::compute::vector<int2_> vector(size, context);
345346
boost::compute::copy(host.begin(), host.end(), vector.begin(), queue);
@@ -356,9 +357,11 @@ BOOST_AUTO_TEST_CASE(sort_int2)
356357
);
357358
boost::compute::copy(vector.begin(), vector.end(), host.begin(), queue);
358359
BOOST_CHECK_CLOSE(host[0][0], -10.f, 0.1);
360+
BOOST_CHECK_CLOSE(host[1][0], -10.f, 0.1);
359361
BOOST_CHECK_CLOSE(host[(size - 3)][0], 9.f, 0.1);
360362
BOOST_CHECK_CLOSE(host[(size - 2)][0], 20.f, 0.1);
361363
BOOST_CHECK_CLOSE(host[(size - 1)][0], 100.f, 0.1);
364+
BOOST_CHECK_NE(host[0], host[1]);
362365
}
363366

364367
BOOST_AUTO_TEST_SUITE_END()

0 commit comments

Comments
 (0)