-
Notifications
You must be signed in to change notification settings - Fork 66
bitonic_sort #940
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
bitonic_sort #940
Changes from all commits
5d6322a
264650c
268949e
73aa820
dcb5e6e
b84a4bd
779815e
78a307a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
#ifndef _NBL_BUILTIN_HLSL_BITONIC_SORT_COMMON_INCLUDED_ | ||
#define _NBL_BUILTIN_HLSL_BITONIC_SORT_COMMON_INCLUDED_ | ||
|
||
#include <nbl/builtin/hlsl/cpp_compat.hlsl> | ||
#include <nbl/builtin/hlsl/concepts.hlsl> | ||
#include <nbl/builtin/hlsl/math/intutil.hlsl> | ||
|
||
namespace nbl | ||
{ | ||
namespace hlsl | ||
{ | ||
namespace bitonic_sort | ||
{ | ||
|
||
} | ||
} | ||
} | ||
|
||
#endif |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
#ifndef NBL_BUILTIN_HLSL_SUBGROUP_BITONIC_SORT_INCLUDED | ||
#define NBL_BUILTIN_HLSL_SUBGROUP_BITONIC_SORT_INCLUDED | ||
#include "nbl/builtin/hlsl/bitonic_sort/common.hlsl" | ||
#include "nbl/builtin/hlsl/glsl_compat/subgroup_basic.hlsl" | ||
#include "nbl/builtin/hlsl/glsl_compat/subgroup_shuffle.hlsl" | ||
#include "nbl/builtin/hlsl/functional.hlsl" | ||
namespace nbl | ||
{ | ||
namespace hlsl | ||
{ | ||
namespace subgroup | ||
{ | ||
template<typename KeyType, typename ValueType, typename Comparator = less<KeyType> > | ||
struct bitonic_sort_config | ||
{ | ||
using key_t = KeyType; | ||
using value_t = ValueType; | ||
using comparator_t = Comparator; | ||
}; | ||
template<bool Ascending, typename Config, class device_capabilities = void> | ||
struct bitonic_sort; | ||
template<bool Ascending, typename KeyType, typename ValueType, typename Comparator, class device_capabilities> | ||
struct bitonic_sort<Ascending, bitonic_sort_config<KeyType, ValueType, Comparator>, device_capabilities> | ||
{ | ||
using config_t = bitonic_sort_config<KeyType, ValueType, Comparator>; | ||
using key_t = typename config_t::key_t; | ||
using value_t = typename config_t::value_t; | ||
using comparator_t = typename config_t::comparator_t; | ||
// Thread-level compare and swap (operates on lo/hi in registers) | ||
static void compareAndSwap(bool ascending, NBL_REF_ARG(key_t) loKey, NBL_REF_ARG(key_t) hiKey, | ||
NBL_REF_ARG(value_t) loVal, NBL_REF_ARG(value_t) hiVal) | ||
{ | ||
comparator_t comp; | ||
const bool shouldSwap = ascending ? comp(hiKey, loKey) : comp(loKey, hiKey); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The compiler is probably dumb and might not realize the right term is the negation of the left term. Ternaries in SPIR-V usually get compiled to an |
||
if (shouldSwap) | ||
{ | ||
// Swap keys | ||
key_t tempKey = loKey; | ||
loKey = hiKey; | ||
hiKey = tempKey; | ||
// Swap values | ||
value_t tempVal = loVal; | ||
loVal = hiVal; | ||
hiVal = tempVal; | ||
} | ||
} | ||
static void __call(NBL_REF_ARG(key_t) loKey, NBL_REF_ARG(key_t) hiKey, | ||
NBL_REF_ARG(value_t) loVal, NBL_REF_ARG(value_t) hiVal) | ||
{ | ||
const uint32_t invocationID = glsl::gl_SubgroupInvocationID(); | ||
const uint32_t subgroupSizeLog2 = glsl::gl_SubgroupSizeLog2(); | ||
[unroll] | ||
for (uint32_t stage = 0; stage <= subgroupSizeLog2; stage++) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. don't add indentation after compiler directives |
||
{ | ||
const bool bitonicAscending = (stage == subgroupSizeLog2) ? Ascending : !bool(invocationID & (1u << stage)); | ||
// Passes within this stage | ||
[unroll] | ||
for (uint32_t pass = 0; pass <= stage; pass++) | ||
{ | ||
const uint32_t stride = 1u << (stage - pass); // Element stride | ||
const uint32_t threadStride = stride >> 1; | ||
if (threadStride == 0) | ||
{ | ||
// Local compare and swap for stage 0 | ||
compareAndSwap(bitonicAscending, loKey, hiKey, loVal, hiVal); | ||
} | ||
else | ||
{ | ||
// Shuffle from partner using XOR | ||
const key_t pLoKey = glsl::subgroupShuffleXor<key_t>(loKey, threadStride); | ||
const key_t pHiKey = glsl::subgroupShuffleXor<key_t>(hiKey, threadStride); | ||
const value_t pLoVal = glsl::subgroupShuffleXor<value_t>(loVal, threadStride); | ||
const value_t pHiVal = glsl::subgroupShuffleXor<value_t>(hiVal, threadStride); | ||
// Determine if we're upper or lower half | ||
const bool upperHalf = bool(invocationID & threadStride); | ||
const bool takeLarger = upperHalf == bitonicAscending; | ||
comparator_t comp; | ||
if (takeLarger) | ||
{ | ||
if (comp(loKey, pLoKey)) { loKey = pLoKey; loVal = pLoVal; } | ||
if (comp(hiKey, pHiKey)) { hiKey = pHiKey; hiVal = pHiVal; } | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Are you sure this isn't reversed? Assume a |
||
} | ||
else | ||
{ | ||
if (comp(pLoKey, loKey)) { loKey = pLoKey; loVal = pLoVal; } | ||
if (comp(pHiKey, hiKey)) { hiKey = pHiKey; hiVal = pHiVal; } | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Again, if the compiler is dumb this code is very costly: half your threads in a subgroup will have Inside each branch, the inner loKey = loCondition ? loKey : pLoKey;
loVal = loCondition ? loVal : pLoVal;
hiKey = hiCondition ? hiKey : pHiKey;
hiVal = hiCondition ? hiVal : pHiVal; where |
||
} | ||
} | ||
} | ||
} | ||
} | ||
}; | ||
} | ||
} | ||
} | ||
#endif |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
#include <nbl/builtin/hlsl/cpp_compat.hlsl> | ||
#include <nbl/builtin/hlsl/concepts.hlsl> | ||
#include <nbl/builtin/hlsl/bitonic_sort/common.hlsl> | ||
|
||
#ifndef _NBL_BUILTIN_HLSL_WORKGROUP_BITONIC_SORT_INCLUDED_ | ||
#define _NBL_BUILTIN_HLSL_WORKGROUP_BITONIC_SORT_INCLUDED_ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I get that
Ascending
is used because when moving onto workgroup you're going to need to call alternating subgroup sorts. However, as a front-facing API if I wanted a single subgroup shuffle I'd usually want it in the order specified by theComparator
. Maybe push it after theConfig
and give it a default value oftrue
. Or better yet, sinceAscending
can be confusing, consider calling itReverseOrder
or something simpler that conveys the intent betterThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ascending
and later names liketakeLarger
implicitly assume the comparator isless
(lo
andhi
don't, those are related to the "lane" order in the bitonic sort diagram). That's fine on its own, it makes the code more readable vs naming them with a more generic option. However, there should be comments mentioning that names assume this implicitly so there's no confusion.