19 changes: 19 additions & 0 deletions include/nbl/builtin/hlsl/bitonic_sort/common.hlsl
@@ -0,0 +1,19 @@
#ifndef _NBL_BUILTIN_HLSL_BITONIC_SORT_COMMON_INCLUDED_
#define _NBL_BUILTIN_HLSL_BITONIC_SORT_COMMON_INCLUDED_

#include <nbl/builtin/hlsl/cpp_compat.hlsl>
#include <nbl/builtin/hlsl/concepts.hlsl>
#include <nbl/builtin/hlsl/math/intutil.hlsl>

namespace nbl
{
namespace hlsl
{
namespace bitonic_sort
{

}
}
}

#endif
96 changes: 96 additions & 0 deletions include/nbl/builtin/hlsl/subgroup/bitonic_sort.hlsl
@@ -0,0 +1,96 @@
#ifndef NBL_BUILTIN_HLSL_SUBGROUP_BITONIC_SORT_INCLUDED
#define NBL_BUILTIN_HLSL_SUBGROUP_BITONIC_SORT_INCLUDED
#include "nbl/builtin/hlsl/bitonic_sort/common.hlsl"
#include "nbl/builtin/hlsl/glsl_compat/subgroup_basic.hlsl"
#include "nbl/builtin/hlsl/glsl_compat/subgroup_shuffle.hlsl"
#include "nbl/builtin/hlsl/functional.hlsl"
namespace nbl
{
namespace hlsl
{
namespace subgroup
{
template<typename KeyType, typename ValueType, typename Comparator = less<KeyType> >
struct bitonic_sort_config
{
using key_t = KeyType;
using value_t = ValueType;
using comparator_t = Comparator;
};
template<bool Ascending, typename Config, class device_capabilities = void>
struct bitonic_sort;
template<bool Ascending, typename KeyType, typename ValueType, typename Comparator, class device_capabilities>
Contributor

I get that Ascending is there because, once you move on to the workgroup sort, you will need to call subgroup sorts in alternating directions. As a front-facing API, though, if I wanted a single subgroup sort I'd usually want the result in the order specified by the Comparator. Maybe move it after the Config and give it a default value of true. Or better yet, since Ascending can be confusing, consider calling it ReverseOrder or something simpler that conveys the intent better.
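To make the suggestion concrete, here is a rough C++ sketch of the reordered template parameters. It is not the PR's code: the `sketch` namespace, the stand-in `less`, the `reversed` member and the alias names are all assumptions made for illustration, and `ReverseOrder = false` is just one possible default.

```cpp
#include <type_traits>

namespace sketch
{
// Stand-in for nbl::hlsl::less so the sketch compiles on its own (assumption).
template<typename T>
struct less
{
    bool operator()(const T& a, const T& b) const { return a < b; }
};

// Same shape as the config struct in the diff.
template<typename KeyType, typename ValueType, typename Comparator = less<KeyType> >
struct bitonic_sort_config
{
    using key_t = KeyType;
    using value_t = ValueType;
    using comparator_t = Comparator;
};

// Hypothetical reordering: Config comes first, the direction flag is second
// and defaulted, so a plain subgroup sort follows the Comparator's order.
template<typename Config, bool ReverseOrder = false, class device_capabilities = void>
struct bitonic_sort
{
    static constexpr bool reversed = ReverseOrder; // illustrative member only
};

using config_t = bitonic_sort_config<unsigned, unsigned>;
using default_sort_t = bitonic_sort<config_t>;        // Comparator order
using reversed_sort_t = bitonic_sort<config_t, true>; // what a workgroup sort would alternate with

static_assert(!default_sort_t::reversed, "default follows the Comparator");
static_assert(reversed_sort_t::reversed, "explicit opt-in to the reverse order");
}
```

With this shape a caller who only wants one subgroup sort never has to spell out the direction, while the workgroup-level code can still instantiate both variants.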

Contributor

Ascending and later names like takeLarger implicitly assume the comparator is less (lo and hi don't; those refer to the "lane" order in the bitonic sort diagram). That's fine on its own, since it makes the code more readable than more generic names would. However, there should be comments mentioning that the names assume this implicitly, so there's no confusion.

struct bitonic_sort<Ascending, bitonic_sort_config<KeyType, ValueType, Comparator>, device_capabilities>
{
using config_t = bitonic_sort_config<KeyType, ValueType, Comparator>;
using key_t = typename config_t::key_t;
using value_t = typename config_t::value_t;
using comparator_t = typename config_t::comparator_t;
// Thread-level compare and swap (operates on lo/hi in registers)
static void compareAndSwap(bool ascending, NBL_REF_ARG(key_t) loKey, NBL_REF_ARG(key_t) hiKey,
NBL_REF_ARG(value_t) loVal, NBL_REF_ARG(value_t) hiVal)
{
comparator_t comp;
const bool shouldSwap = ascending ? comp(hiKey, loKey) : comp(loKey, hiKey);
Contributor

The compiler is probably dumb and might not realize the right term is the negation of the left term. Ternaries in SPIR-V usually get compiled to an OpSelect which treats both terms after the ? not as branches to conditionally execute, but as operands whose result must be evaluated before the select operation runs. That is to say, if the compiler is stupid you're going to run two comparisons. If you make the right term the negation of the left one, CSE is likely to kick in and evaluate the comparison only once.
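A minimal C++ model of what this could look like, assuming a plain `less` on `unsigned` keys. The function names and the `sketch` namespace are made up for the example. Note that `ascending ? c : !c` collapses to `c == ascending`, so only one comparator call remains.

```cpp
#include <cassert>

namespace sketch
{
// Stand-in comparator (assumption): strict "less than" on unsigned keys.
struct less_u32
{
    bool operator()(unsigned a, unsigned b) const { return a < b; }
};

// As written in the diff: two comparator calls feed the select.
inline bool shouldSwapTwoCompares(bool ascending, unsigned loKey, unsigned hiKey)
{
    less_u32 comp;
    return ascending ? comp(hiKey, loKey) : comp(loKey, hiKey);
}

// Suggested shape: one comparator call, negated for the other direction.
// (ascending ? c : !c) is the same predicate as (c == ascending).
inline bool shouldSwapOneCompare(bool ascending, unsigned loKey, unsigned hiKey)
{
    less_u32 comp;
    const bool hiBeforeLo = comp(hiKey, loKey);
    return hiBeforeLo == ascending;
}

// Both forms agree whenever the two keys differ.
inline void checkDistinctKeysAgree()
{
    const unsigned pairs[2][2] = { {1u, 2u}, {2u, 1u} };
    for (int i = 0; i < 2; ++i)
    {
        for (int dir = 0; dir < 2; ++dir)
        {
            const bool ascending = (dir != 0);
            assert(shouldSwapTwoCompares(ascending, pairs[i][0], pairs[i][1]) ==
                   shouldSwapOneCompare(ascending, pairs[i][0], pairs[i][1]));
        }
    }
}
}
```

One caveat: with a strict comparator the two terms are exact negations only when the keys differ. For equal keys in the descending case the single-compare form swaps where the original does not. Because the swap is local to one thread and moves key and value together, that should be harmless here, but it is worth a comment in the code.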

if (shouldSwap)
{
// Swap keys
key_t tempKey = loKey;
loKey = hiKey;
hiKey = tempKey;
// Swap values
value_t tempVal = loVal;
loVal = hiVal;
hiVal = tempVal;
}
}
static void __call(NBL_REF_ARG(key_t) loKey, NBL_REF_ARG(key_t) hiKey,
NBL_REF_ARG(value_t) loVal, NBL_REF_ARG(value_t) hiVal)
{
const uint32_t invocationID = glsl::gl_SubgroupInvocationID();
const uint32_t subgroupSizeLog2 = glsl::gl_SubgroupSizeLog2();
[unroll]
for (uint32_t stage = 0; stage <= subgroupSizeLog2; stage++)
Contributor

don't add indentation after compiler directives

{
const bool bitonicAscending = (stage == subgroupSizeLog2) ? Ascending : !bool(invocationID & (1u << stage));
// Passes within this stage
[unroll]
for (uint32_t pass = 0; pass <= stage; pass++)
{
const uint32_t stride = 1u << (stage - pass); // Element stride
const uint32_t threadStride = stride >> 1;
if (threadStride == 0)
{
// Local compare and swap for stage 0
compareAndSwap(bitonicAscending, loKey, hiKey, loVal, hiVal);
}
else
{
// Shuffle from partner using XOR
const key_t pLoKey = glsl::subgroupShuffleXor<key_t>(loKey, threadStride);
const key_t pHiKey = glsl::subgroupShuffleXor<key_t>(hiKey, threadStride);
const value_t pLoVal = glsl::subgroupShuffleXor<value_t>(loVal, threadStride);
const value_t pHiVal = glsl::subgroupShuffleXor<value_t>(hiVal, threadStride);
// Determine if we're upper or lower half
const bool upperHalf = bool(invocationID & threadStride);
const bool takeLarger = upperHalf == bitonicAscending;
comparator_t comp;
if (takeLarger)
{
if (comp(loKey, pLoKey)) { loKey = pLoKey; loVal = pLoVal; }
if (comp(hiKey, pHiKey)) { hiKey = pHiKey; hiVal = pHiVal; }
Contributor

Are you sure this isn't reversed? Assume a less comparator, bitonicAscending = true for the current stage and upperHalf = true for the current thread. Then takeLarger semantically conveys that this thread wants to keep the larger values. And yet this code assigns the smaller values.
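One way to settle the question is to trace the branch on concrete keys. The C++ model below mirrors the two branches from the diff for a single key/value pair, assuming a `less` comparator. `KV`, `exchangeAsWritten` and the `sketch` namespace are hypothetical names used only for this trace.

```cpp
#include <cassert>

namespace sketch
{
struct KV { unsigned key; unsigned val; };

// Stand-in comparator (assumption): strict "less than" on unsigned keys.
struct less_u32
{
    bool operator()(unsigned a, unsigned b) const { return a < b; }
};

// Mirrors the branches in the diff for one element:
// "own" plays the role of loKey/loVal, "partner" of pLoKey/pLoVal.
inline KV exchangeAsWritten(bool takeLarger, KV own, KV partner)
{
    less_u32 comp;
    if (takeLarger)
    {
        // comp(own, partner) is true when own < partner, so the partner's
        // (larger) element is the one that gets assigned.
        if (comp(own.key, partner.key)) own = partner;
    }
    else
    {
        // comp(partner, own) is true when partner < own, so the partner's
        // (smaller) element is the one that gets assigned.
        if (comp(partner.key, own.key)) own = partner;
    }
    return own;
}

inline void checkTrace()
{
    const KV small = {3u, 30u};
    const KV large = {7u, 70u};
    assert(exchangeAsWritten(true, small, large).key == 7u);
    assert(exchangeAsWritten(false, large, small).key == 3u);
}
}
```

In this model the takeLarger branch ends up holding the larger key in both argument orders, which suggests the branch as written is not reversed under a `less` comparator. It only models the per-element exchange, not the stage/pass logic around it, so it is still worth confirming on the GPU.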

}
else
{
if (comp(pLoKey, loKey)) { loKey = pLoKey; loVal = pLoVal; }
if (comp(pHiKey, hiKey)) { hiKey = pHiKey; hiVal = pHiVal; }
Contributor

Again, if the compiler is dumb this code is very costly: half the threads in a subgroup will have upperHalf = true and the other half will have it set to false. Execution has to be uniform across the threads of a subgroup, so this section of code will run twice: first for one half of your threads (say, those in the upper half), then for the other half. This kills your throughput.

Inside each branch, the inner ifs will likely get compiled down to two OpSelects each. You can make this whole code branchless by doing

loKey = loCondition ? loKey : pLoKey;
loVal = loCondition ? loVal : pLoVal;
hiKey = hiCondition ? hiKey : pHiKey;
hiVal = hiCondition ? hiVal : pHiVal;

where loCondition and hiCondition are predicates that depend on both takeLarger and the result of the key comparison.
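A possible way to build those predicates, sketched in C++ for a single key/value pair under a `less` comparator. `KV`, `exchangeBranchy`, `exchangeBranchless` and the `sketch` namespace are illustrative names; `keepOwn` plays the role of loCondition / hiCondition.

```cpp
#include <cassert>

namespace sketch
{
struct KV { unsigned key; unsigned val; };

// Stand-in comparator (assumption): strict "less than" on unsigned keys.
struct less_u32
{
    bool operator()(unsigned a, unsigned b) const { return a < b; }
};

// Reference: the branchy exchange from the diff, reduced to one element.
inline KV exchangeBranchy(bool takeLarger, KV own, KV partner)
{
    less_u32 comp;
    if (takeLarger) { if (comp(own.key, partner.key)) own = partner; }
    else            { if (comp(partner.key, own.key)) own = partner; }
    return own;
}

// Branchless form: compute one "keep my own element" predicate, then select.
inline KV exchangeBranchless(bool takeLarger, KV own, KV partner)
{
    less_u32 comp;
    const bool ownFirst = comp(own.key, partner.key);
    const bool partnerFirst = comp(partner.key, own.key);
    // Both comparisons are kept on purpose so that ties behave as in the
    // branchy code: on equal keys each side keeps its own element.
    const bool keepOwn = !(takeLarger ? ownFirst : partnerFirst);
    KV out;
    out.key = keepOwn ? own.key : partner.key;
    out.val = keepOwn ? own.val : partner.val;
    return out;
}

// Exhaustive comparison over a small key range, ties included.
inline void checkBranchlessMatchesBranchy()
{
    for (unsigned a = 0; a < 4; ++a)
        for (unsigned b = 0; b < 4; ++b)
            for (int t = 0; t < 2; ++t)
            {
                const KV own = {a, 100u + a};
                const KV partner = {b, 200u + b};
                const KV x = exchangeBranchy(t != 0, own, partner);
                const KV y = exchangeBranchless(t != 0, own, partner);
                assert(x.key == y.key && x.val == y.val);
            }
}
}
```

It is tempting to fold the predicate into a single comparison, for example `comp(own, partner) != takeLarger`. That changes what happens on equal keys: both partners can then end up holding the same value while the other value is lost. Keeping the two comparisons inside one select avoids that and still leaves no divergent branch.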

}
}
}
}
}
};
}
}
}
#endif
6 changes: 6 additions & 0 deletions include/nbl/builtin/hlsl/workgroup/bitonic_sort.hlsl
@@ -0,0 +1,6 @@
#include <nbl/builtin/hlsl/cpp_compat.hlsl>
#include <nbl/builtin/hlsl/concepts.hlsl>
#include <nbl/builtin/hlsl/bitonic_sort/common.hlsl>

#ifndef _NBL_BUILTIN_HLSL_WORKGROUP_BITONIC_SORT_INCLUDED_
#define _NBL_BUILTIN_HLSL_WORKGROUP_BITONIC_SORT_INCLUDED_