Skip to content

Commit fa78d2b

Browse files
committed
refactor common.hlsl and prep for new implementation
1 parent d80f449 commit fa78d2b

File tree

1 file changed

+168
-34
lines changed

1 file changed

+168
-34
lines changed
Lines changed: 168 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
#ifndef _NBL_BUILTIN_HLSL_BITONIC_SORT_COMMON_INCLUDED_
22
#define _NBL_BUILTIN_HLSL_BITONIC_SORT_COMMON_INCLUDED_
33

4-
#include <nbl/builtin/hlsl/cpp_compat.hlsl>
5-
#include <nbl/builtin/hlsl/concepts.hlsl>
6-
#include <nbl/builtin/hlsl/math/intutil.hlsl>
7-
#include <nbl/builtin/hlsl/utility.hlsl>
4+
#include "nbl/builtin/hlsl/cpp_compat.hlsl"
5+
#include "nbl/builtin/hlsl/type_traits.hlsl"
6+
#include "nbl/builtin/hlsl/workgroup/basic.hlsl"
7+
#include "nbl/builtin/hlsl/functional.hlsl"
8+
#include "nbl/builtin/hlsl/concepts/impl/base.hlsl"
89

910
namespace nbl
1011
{
@@ -13,41 +14,174 @@ namespace hlsl
1314
namespace bitonic_sort
1415
{
1516

16-
template<typename KeyType, typename ValueType, typename Comparator>
17-
void compareExchangeWithPartner(
18-
bool takeLarger,
19-
NBL_REF_ARG(pair<KeyType, ValueType>) loPair,
20-
NBL_CONST_REF_ARG(pair<KeyType, ValueType>) partnerLoPair,
21-
NBL_REF_ARG(pair<KeyType, ValueType>) hiPair,
22-
NBL_CONST_REF_ARG(pair<KeyType, ValueType>) partnerHiPair,
23-
NBL_CONST_REF_ARG(Comparator) comp)
24-
{
25-
const bool loSelfSmaller = comp(loPair.first, partnerLoPair.first);
26-
const bool takePartnerLo = takeLarger ? loSelfSmaller : !loSelfSmaller;
27-
if (takePartnerLo)
28-
loPair = partnerLoPair;
29-
30-
const bool hiSelfSmaller = comp(hiPair.first, partnerHiPair.first);
31-
const bool takePartnerHi = takeLarger ? hiSelfSmaller : !hiSelfSmaller;
32-
if (takePartnerHi)
33-
hiPair = partnerHiPair;
34-
}
35-
36-
template<typename KeyType, typename ValueType, typename Comparator>
37-
void compareSwap(
38-
bool ascending,
39-
NBL_REF_ARG(pair<KeyType, ValueType>) loPair,
40-
NBL_REF_ARG(pair<KeyType, ValueType>) hiPair,
41-
NBL_CONST_REF_ARG(Comparator) comp)
17+
// Compile-time configuration bundle for a bitonic sort.
//
// KeyType          - type of the keys being ordered
// ValueType        - type of the value associated with each key
// SubgroupSizelog2 - base-2 logarithm of the subgroup size
// Comparator       - functor type used to compare two keys
template<typename KeyType, typename ValueType, uint32_t SubgroupSizelog2, typename Comparator>
struct bitonic_sort_config
{
    using key_t = KeyType;
    using value_t = ValueType;
    using comparator_t = Comparator;

    // The template argument is a logarithm, so it is exposed under that name too.
    static const uint32_t SubgroupSizeLog2 = SubgroupSizelog2;
    // Size derived from the logarithm; this member used to hold the raw log2 value.
    static const uint32_t SubgroupSize = uint32_t(1u) << SubgroupSizelog2;
};
25+
26+
// Primary template of the sorter, parameterised on `Config` (presumably a
// bitonic_sort_config instantiation - confirm) and on an optional
// `device_capabilities` type that defaults to void.
// Only declared here; no definition appears in this section.
template<typename Config, class device_capabilities = void>
27+
struct bitonic_sort;
4528

46-
if (doSwap)
47-
swap(loPair, hiPair);
29+
30+
// Key accessor for key/value pairs: returns a copy of the pair's `first` member.
template<typename K, typename V>
inline K get_key(NBL_CONST_REF_ARG(pair<K, V>) kv)
{
    const K key = kv.first;
    return key;
}
35+
36+
// Value accessor for key/value pairs: returns a copy of the pair's `second` member.
template<typename K, typename V>
inline V get_value(NBL_CONST_REF_ARG(pair<K, V>) kv)
{
    const V value = kv.second;
    return value;
}
41+
42+
// A key stored next to a 32-bit index.
// Going by the member name, the index is relative to the workgroup.
template<typename KeyType>
struct WorkgroupType
{
    using index_t = uint32_t;
    using key_t = KeyType;

    // Key the element is ordered by.
    key_t key;
    // Index carried along with the key.
    index_t workgroupRelativeIndex;
};
51+
52+
53+
// A key and a subgroup-relative index packed into one integer of type StorageType.
// The key occupies the low KeyBits bits; the index occupies the remaining high bits.
//
// KeyType     - type the key is converted to and from
// KeyBits     - number of low bits reserved for the key; the mask computation
//               assumes this is smaller than the bit width of StorageType
// StorageType - integer type holding the packed representation
template<typename KeyType, uint32_t KeyBits, typename StorageType = uint32_t>
struct SubgroupType
{
    using key_t = KeyType;
    using storage_t = StorageType;

    // Number of bits left for the index once the key bits are taken.
    static const uint32_t IndexBits = sizeof(storage_t) * 8u - KeyBits;
    // Selects the key bits of `packed`.
    static const storage_t KeyMask = (storage_t(1u) << KeyBits) - 1u;
    // Selects the index bits of `packed`.
    static const storage_t IndexMask = ~KeyMask;

    storage_t packed;

    // Builds a packed element from `key` (masked to KeyBits bits) and `subgroupIndex`
    // (stored above the key bits). No range check is performed on either input.
    static inline SubgroupType create(key_t key, uint32_t subgroupIndex)
    {
        SubgroupType st;
        st.packed = (storage_t(key) & KeyMask) | (storage_t(subgroupIndex) << KeyBits);
        return st;
    }

    // Returns the low KeyBits bits converted back to key_t.
    inline key_t getKey() { return key_t(packed & KeyMask); }
    // Returns the bits stored above the key.
    inline uint32_t getSubgroupIndex() { return uint32_t(packed >> KeyBits); }

    // Converts to a WorkgroupType carrying the same key.
    //
    // subgroupID          - index of the subgroup this element belongs to
    // elementsPerSubgroup - number of elements each subgroup holds
    //
    // The index is rebuilt from the index packed in this element, so it follows
    // the element rather than the invocation that happens to hold it. Previously
    // both parameters were ignored and the caller's own invocation index was used.
    inline WorkgroupType<key_t> toWorkgroupType(uint32_t subgroupID, uint32_t elementsPerSubgroup)
    {
        WorkgroupType<key_t> wg;
        wg.key = getKey();
        wg.workgroupRelativeIndex = subgroupID * elementsPerSubgroup + getSubgroupIndex();
        return wg;
    }
};
83+
84+
// Convenience alias: SubgroupType with 27 key bits and the default storage type
// (uint32_t), which leaves 32 - 27 = 5 bits for the subgroup-relative index.
template<typename KeyType>
85+
using SubgroupType27 = SubgroupType<KeyType, 27u>;
86+
87+
// Key accessor for WorkgroupType elements: returns a copy of the `key` member.
template<typename K>
inline K get_key(NBL_CONST_REF_ARG(WorkgroupType<K>) wt)
{
    const K key = wt.key;
    return key;
}
92+
93+
94+
// Key accessor for packed SubgroupType elements: returns the low KeyBits bits of
// `packed`, converted to K.
//
// The field is read directly instead of calling getKey(): getKey() is not
// const-qualified, so it cannot be invoked through a const reference when this
// header is compiled as C++ (assuming NBL_CONST_REF_ARG expands to a const
// reference there - confirm). The result is the same value getKey() computes.
template<typename K, uint32_t KeyBits, typename StorageType>
inline K get_key(NBL_CONST_REF_ARG(SubgroupType<K, KeyBits, StorageType>) st)
{
    return K(st.packed & SubgroupType<K, KeyBits, StorageType>::KeyMask);
}
5299

100+
101+
//template<typename KeyType, typename Comp>
102+
//inline void compareSwap(
103+
// bool ascending,
104+
// NBL_REF_ARG(WorkgroupType<KeyType>) a,
105+
// NBL_REF_ARG(WorkgroupType<KeyType>) b,
106+
// NBL_CONST_REF_ARG(Comp) comp)
107+
//{
108+
// const bool swap = comp(b.key, a.key) == ascending;
109+
// WorkgroupType<KeyType> tmp = a;
110+
// a = swap ? b : a;
111+
// b = swap ? tmp : b;
112+
//}
113+
114+
115+
//template<typename KeyType, uint32_t KeyBits, typename StorageType, typename Comp>
116+
//inline void compareSwap(
117+
// bool ascending,
118+
// NBL_REF_ARG(SubgroupType<KeyType, KeyBits, StorageType>) a,
119+
// NBL_REF_ARG(SubgroupType<KeyType, KeyBits, StorageType>) b,
120+
// NBL_CONST_REF_ARG(Comp) comp)
121+
//{
122+
// const bool swap = comp(b.getKey(), a.getKey()) == ascending;
123+
// SubgroupType<KeyType, KeyBits, StorageType> tmp = a;
124+
// a = swap ? b : a;
125+
// b = swap ? tmp : b;
126+
//}
127+
128+
// Primary template for the local passes over an array of N = 2^Log2N elements.
// operator() is only declared here; the Log2N == 1 specializations that follow
// provide definitions.
template<typename KeyValue, uint32_t Log2N, typename Comparator>
129+
struct LocalPasses
130+
{
131+
// Number of elements handled by one invocation of operator().
static const uint32_t N = 1u << Log2N;
132+
// NOTE(review): HLSL passes array parameters by value unless they are marked
// inout - confirm that writes to `data` are meant to reach the caller's array.
void operator()(bool ascending, KeyValue data[N], NBL_CONST_REF_ARG(Comparator) comp);
133+
};
134+
135+
// Two-element base case of the local passes: a single compare-and-swap.
//
// The elements are exchanged when `comp(key of data[1], key of data[0])` equals
// `ascending`. For example, with a less-than comparator and ascending == true,
// they are exchanged when data[1] orders before data[0].
//
// NOTE(review): HLSL passes array parameters by value unless they are marked
// inout - confirm that writes to `data` are meant to reach the caller's array.
template<typename KeyValue, typename Comparator>
struct LocalPasses<KeyValue, 1, Comparator>
{
    // Element count, exposed as in the primary template and the sibling specializations.
    static const uint32_t N = 2;

    void operator()(bool ascending, KeyValue data[N], NBL_CONST_REF_ARG(Comparator) comp)
    {
        // Named doSwap so it does not shadow a library swap function.
        const bool doSwap = comp(get_key(data[1]), get_key(data[0])) == ascending;

        // Both results are selected with ternaries rather than branching.
        const KeyValue first = data[0];
        data[0] = doSwap ? data[1] : data[0];
        data[1] = doSwap ? first : data[1];
    }
};
147+
148+
// Two-element base case for WorkgroupType elements.
// Exchanges data[0] and data[1] when `comp(key of data[1], key of data[0])`
// equals `ascending`; otherwise leaves both untouched.
template<typename KeyType, typename Comparator>
struct LocalPasses<WorkgroupType<KeyType>, 1, Comparator>
{
    static const uint32_t N = 2;

    void operator()(bool ascending,
        WorkgroupType<KeyType> data[N],
        NBL_CONST_REF_ARG(Comparator) comp)
    {
        // Snapshot both elements first, then write back the chosen ordering.
        const WorkgroupType<KeyType> lo = data[0];
        const WorkgroupType<KeyType> hi = data[1];
        const bool exchange = comp(get_key(hi), get_key(lo)) == ascending;

        data[0] = exchange ? hi : lo;
        data[1] = exchange ? lo : hi;
    }
};
164+
165+
// Two-element base case for packed SubgroupType elements.
// Exchanges data[0] and data[1] when `comp(key of data[1], key of data[0])`
// equals `ascending`; otherwise leaves both untouched.
template<typename KeyType, uint32_t KeyBits, typename StorageType, typename Comparator>
struct LocalPasses<SubgroupType<KeyType, KeyBits, StorageType>, 1, Comparator>
{
    static const uint32_t N = 2;

    void operator()(bool ascending,
        SubgroupType<KeyType, KeyBits, StorageType> data[N],
        NBL_CONST_REF_ARG(Comparator) comp)
    {
        // Snapshot both elements first, then write back the chosen ordering.
        const SubgroupType<KeyType, KeyBits, StorageType> lo = data[0];
        const SubgroupType<KeyType, KeyBits, StorageType> hi = data[1];
        const bool exchange = comp(get_key(hi), get_key(lo)) == ascending;

        data[0] = exchange ? hi : lo;
        data[1] = exchange ? lo : hi;
    }
};
181+
182+
183+
} // namespace bitonic_sort
184+
} // namespace hlsl
185+
} // namespace nbl
186+
53187
#endif

0 commit comments

Comments
 (0)